sqlmesh.utils.concurrency
import typing as t
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from threading import Lock

from sqlmesh.core.snapshot import SnapshotId, SnapshotInfoLike
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.errors import ConfigError, SQLMeshError

H = t.TypeVar("H", bound=t.Hashable)
S = t.TypeVar("S", bound=SnapshotInfoLike)
A = t.TypeVar("A")
R = t.TypeVar("R")


class NodeExecutionFailedError(t.Generic[H], SQLMeshError):
    """Raised or recorded when applying a function to a DAG node fails.

    The original exception is attached as ``__cause__``.
    """

    def __init__(self, node: H):
        self.node = node
        super().__init__(f"Execution failed for node {node}")


class ConcurrentDAGExecutor(t.Generic[H]):
    """Concurrently traverses the given DAG in topological order while applying a function to each node.

    If `raise_on_error` is set to False maintains a state of execution errors as well as of skipped nodes.

    Args:
        dag: The target DAG.
        fn: The function that will be applied concurrently to each snapshot.
        tasks_num: The number of concurrent tasks.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.
    """

    def __init__(
        self,
        dag: DAG[H],
        fn: t.Callable[[H], None],
        tasks_num: int,
        raise_on_error: bool,
    ):
        self.dag = dag
        self.fn = fn
        self.tasks_num = tasks_num
        self.raise_on_error = raise_on_error

        self._init_state()

    def run(self) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
        """Runs the executor.

        Raises:
            NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any snapshot.

        Returns:
            A pair which contains a list of node errors and a list of skipped nodes.
        """
        # Reset traversal state if this instance has been run before.
        if self._finished_future.done():
            self._init_state()

        with ThreadPoolExecutor(max_workers=self.tasks_num) as pool:
            with self._unprocessed_nodes_lock:
                self._submit_next_nodes(pool)
            # Blocks until the traversal completes; re-raises the first failure
            # when `raise_on_error` is set.
            self._finished_future.result()
        return self._node_errors, self._skipped_nodes

    def _process_node(self, node: H, executor: Executor) -> None:
        """Applies `fn` to a single node and schedules its now-unblocked dependents."""
        try:
            self.fn(node)

            with self._unprocessed_nodes_lock:
                self._unprocessed_nodes_num -= 1
                self._submit_next_nodes(executor, node)
        except Exception as ex:
            error = NodeExecutionFailedError(node)
            error.__cause__ = ex

            if self.raise_on_error:
                with self._unprocessed_nodes_lock:
                    # Several workers may fail concurrently; only the first
                    # failure may set the exception (a second `set_exception`
                    # call would raise InvalidStateError).
                    if not self._finished_future.done():
                        self._finished_future.set_exception(error)
                return

            with self._unprocessed_nodes_lock:
                self._unprocessed_nodes_num -= 1
                self._node_errors.append(error)
                self._skip_next_nodes(node)

    def _submit_next_nodes(self, executor: Executor, processed_node: t.Optional[H] = None) -> None:
        """Submits all nodes whose dependencies have been fully processed.

        Must be called while holding `_unprocessed_nodes_lock`.
        """
        if not self._unprocessed_nodes_num:
            if not self._finished_future.done():
                self._finished_future.set_result(None)
            return

        submitted_nodes = []
        for next_node, deps in self._unprocessed_nodes.items():
            # Explicit None check so that falsy nodes (e.g. 0 or "") are still
            # discarded from the dependency sets.
            if processed_node is not None:
                deps.discard(processed_node)
            if not deps:
                submitted_nodes.append(next_node)

        for submitted_node in submitted_nodes:
            self._unprocessed_nodes.pop(submitted_node)
            executor.submit(self._process_node, submitted_node, executor)

    def _skip_next_nodes(self, parent: H) -> None:
        """Recursively marks all unprocessed descendants of a failed node as skipped.

        Must be called while holding `_unprocessed_nodes_lock`.
        """
        if not self._unprocessed_nodes_num:
            # The recursion can reach this branch more than once after the last
            # node is accounted for; guard against setting the result twice.
            if not self._finished_future.done():
                self._finished_future.set_result(None)
            return

        skipped_nodes = [node for node, deps in self._unprocessed_nodes.items() if parent in deps]

        self._skipped_nodes.extend(skipped_nodes)

        for skipped_node in skipped_nodes:
            self._unprocessed_nodes_num -= 1
            self._unprocessed_nodes.pop(skipped_node)

        for skipped_node in skipped_nodes:
            self._skip_next_nodes(skipped_node)

    def _init_state(self) -> None:
        """(Re)initializes the mutable traversal state."""
        # Copy the graph so that draining it during traversal doesn't mutate the
        # caller's DAG, which would otherwise make a subsequent `run()` a no-op.
        self._unprocessed_nodes: t.Dict[H, t.Set[H]] = {
            node: set(deps) for node, deps in self.dag.graph.items()
        }
        self._unprocessed_nodes_num = len(self._unprocessed_nodes)
        self._unprocessed_nodes_lock = Lock()
        self._finished_future: Future = Future()

        self._node_errors: t.List[NodeExecutionFailedError[H]] = []
        self._skipped_nodes: t.List[H] = []


def concurrent_apply_to_snapshots(
    snapshots: t.Iterable[S],
    fn: t.Callable[[S], None],
    tasks_num: int,
    reverse_order: bool = False,
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[SnapshotId]], t.List[SnapshotId]]:
    """Applies a function to the given collection of snapshots concurrently while
    preserving the topological order between snapshots.

    Args:
        snapshots: Target snapshots.
        fn: The function that will be applied concurrently to each snapshot.
        tasks_num: The number of concurrent tasks.
        reverse_order: Whether the order should be reversed. Default: False.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot.

    Returns:
        A pair which contains a list of errors and a list of skipped snapshot IDs.
    """
    snapshots_by_id = {s.snapshot_id: s for s in snapshots}

    dag: DAG[SnapshotId] = DAG[SnapshotId]()
    # Iterate the materialized mapping rather than `snapshots` itself so that a
    # one-shot iterable (e.g. a generator) isn't silently exhausted by the dict
    # comprehension above, which would produce an empty DAG.
    for snapshot in snapshots_by_id.values():
        dag.add(
            snapshot.snapshot_id,
            # Parents outside the target collection are ignored.
            [p_sid for p_sid in snapshot.parents if p_sid in snapshots_by_id],
        )

    return concurrent_apply_to_dag(
        dag if not reverse_order else dag.reversed,
        lambda s_id: fn(snapshots_by_id[s_id]),
        tasks_num,
        raise_on_error=raise_on_error,
    )


def concurrent_apply_to_dag(
    dag: DAG[H],
    fn: t.Callable[[H], None],
    tasks_num: int,
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Applies a function to the given DAG concurrently while preserving the topological
    order between snapshots.

    Args:
        dag: The target DAG.
        fn: The function that will be applied concurrently to each snapshot.
        tasks_num: The number of concurrent tasks.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot.

    Returns:
        A pair which contains a list of node errors and a list of skipped nodes.
    """
    if tasks_num <= 0:
        raise ConfigError(f"Invalid number of concurrent tasks {tasks_num}")

    # A single task degenerates to a plain sequential traversal; skip the
    # thread-pool overhead entirely.
    if tasks_num == 1:
        return sequential_apply_to_dag(dag, fn, raise_on_error)

    return ConcurrentDAGExecutor(
        dag,
        fn,
        tasks_num,
        raise_on_error,
    ).run()


def sequential_apply_to_dag(
    dag: DAG[H],
    fn: t.Callable[[H], None],
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Applies a function to the given DAG sequentially in topological order.

    A node is skipped (not executed) when any of its dependencies failed or was
    itself skipped.

    Args:
        dag: The target DAG.
        fn: The function that will be applied to each node.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any node.

    Returns:
        A pair which contains a list of node errors and a list of skipped nodes.
    """
    dependencies = dag.graph

    node_errors: t.List[NodeExecutionFailedError[H]] = []
    skipped_nodes: t.List[H] = []

    failed_or_skipped_nodes: t.Set[H] = set()

    for node in dag.sorted:
        # Skip nodes that depend (directly) on a failed or skipped node.
        if not failed_or_skipped_nodes.isdisjoint(dependencies[node]):
            skipped_nodes.append(node)
            failed_or_skipped_nodes.add(node)
            continue

        try:
            fn(node)
        except Exception as ex:
            if raise_on_error:
                raise NodeExecutionFailedError(node) from ex

            error = NodeExecutionFailedError(node)
            error.__cause__ = ex

            node_errors.append(error)
            failed_or_skipped_nodes.add(node)

    return node_errors, skipped_nodes


def concurrent_apply_to_values(
    values: t.Sequence[A],
    fn: t.Callable[[A], R],
    tasks_num: int,
) -> t.List[R]:
    """Applies a function to the given collection of values concurrently.

    Results are returned in the same order as `values`. If `fn` raises for any
    value, the first such exception is re-raised after all tasks complete.

    Args:
        values: Target values.
        fn: The function that will be applied concurrently to each value.
        tasks_num: The number of concurrent tasks.

    Raises:
        ConfigError if `tasks_num` is not a positive integer.

    Returns:
        A list of results.
    """
    # Validate up front for consistency with `concurrent_apply_to_dag`.
    if tasks_num <= 0:
        raise ConfigError(f"Invalid number of concurrent tasks {tasks_num}")

    if tasks_num == 1:
        return [fn(value) for value in values]

    futures: t.List[Future] = [Future() for _ in values]

    def _process_value(value: A, index: int) -> None:
        try:
            futures[index].set_result(fn(value))
        except Exception as ex:
            futures[index].set_exception(ex)

    with ThreadPoolExecutor(max_workers=tasks_num) as pool:
        for index, value in enumerate(values):
            pool.submit(_process_value, value, index)

    # The pool has been shut down, so every future is resolved by now;
    # `result()` re-raises the first stored exception if any.
    return [f.result() for f in futures]
16class NodeExecutionFailedError(t.Generic[H], SQLMeshError): 17 def __init__(self, node: H): 18 self.node = node 19 super().__init__(f"Execution failed for node {node}")
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
Inherited Members
- builtins.BaseException
- with_traceback
22class ConcurrentDAGExecutor(t.Generic[H]): 23 """Concurrently traverses the given DAG in topological order while applying a function to each node. 24 25 If `raise_on_error` is set to False maintains a state of execution errors as well as of skipped nodes. 26 27 Args: 28 dag: The target DAG. 29 fn: The function that will be applied concurrently to each snapshot. 30 tasks_num: The number of concurrent tasks. 31 raise_on_error: If set to True raises an exception on a first encountered error, 32 otherwises returns a tuple which contains a list of failed nodes and a list of 33 skipped nodes. 34 """ 35 36 def __init__( 37 self, 38 dag: DAG[H], 39 fn: t.Callable[[H], None], 40 tasks_num: int, 41 raise_on_error: bool, 42 ): 43 self.dag = dag 44 self.fn = fn 45 self.tasks_num = tasks_num 46 self.raise_on_error = raise_on_error 47 48 self._init_state() 49 50 def run(self) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]: 51 """Runs the executor. 52 53 Raises: 54 NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any snapshot. 55 56 Returns: 57 A pair which contains a list of node errors and a list of skipped nodes. 
58 """ 59 if self._finished_future.done(): 60 self._init_state() 61 62 with ThreadPoolExecutor(max_workers=self.tasks_num) as pool: 63 with self._unprocessed_nodes_lock: 64 self._submit_next_nodes(pool) 65 self._finished_future.result() 66 return self._node_errors, self._skipped_nodes 67 68 def _process_node(self, node: H, executor: Executor) -> None: 69 try: 70 self.fn(node) 71 72 with self._unprocessed_nodes_lock: 73 self._unprocessed_nodes_num -= 1 74 self._submit_next_nodes(executor, node) 75 except Exception as ex: 76 error = NodeExecutionFailedError(node) 77 error.__cause__ = ex 78 79 if self.raise_on_error: 80 self._finished_future.set_exception(error) 81 return 82 83 with self._unprocessed_nodes_lock: 84 self._unprocessed_nodes_num -= 1 85 self._node_errors.append(error) 86 self._skip_next_nodes(node) 87 88 def _submit_next_nodes(self, executor: Executor, processed_node: t.Optional[H] = None) -> None: 89 if not self._unprocessed_nodes_num: 90 self._finished_future.set_result(None) 91 return 92 93 submitted_nodes = [] 94 for next_node, deps in self._unprocessed_nodes.items(): 95 if processed_node: 96 deps.discard(processed_node) 97 if not deps: 98 submitted_nodes.append(next_node) 99 100 for submitted_node in submitted_nodes: 101 self._unprocessed_nodes.pop(submitted_node) 102 executor.submit(self._process_node, submitted_node, executor) 103 104 def _skip_next_nodes(self, parent: H) -> None: 105 if not self._unprocessed_nodes_num: 106 self._finished_future.set_result(None) 107 return 108 109 skipped_nodes = [node for node, deps in self._unprocessed_nodes.items() if parent in deps] 110 111 self._skipped_nodes.extend(skipped_nodes) 112 113 for skipped_node in skipped_nodes: 114 self._unprocessed_nodes_num -= 1 115 self._unprocessed_nodes.pop(skipped_node) 116 117 for skipped_node in skipped_nodes: 118 self._skip_next_nodes(skipped_node) 119 120 def _init_state(self) -> None: 121 self._unprocessed_nodes = self.dag.graph 122 self._unprocessed_nodes_num = 
len(self._unprocessed_nodes) 123 self._unprocessed_nodes_lock = Lock() 124 self._finished_future = Future() # type: ignore 125 126 self._node_errors: t.List[NodeExecutionFailedError[H]] = [] 127 self._skipped_nodes: t.List[H] = []
Concurrently traverses the given DAG in topological order while applying a function to each node.
If `raise_on_error` is set to False, maintains a record of execution errors as well as of skipped nodes.
Arguments:
- dag: The target DAG.
- fn: The function that will be applied concurrently to each snapshot.
- tasks_num: The number of concurrent tasks.
- raise_on_error: If set to True raises an exception on the first encountered error, otherwise returns a tuple which contains a list of failed nodes and a list of skipped nodes.
50 def run(self) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]: 51 """Runs the executor. 52 53 Raises: 54 NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any snapshot. 55 56 Returns: 57 A pair which contains a list of node errors and a list of skipped nodes. 58 """ 59 if self._finished_future.done(): 60 self._init_state() 61 62 with ThreadPoolExecutor(max_workers=self.tasks_num) as pool: 63 with self._unprocessed_nodes_lock: 64 self._submit_next_nodes(pool) 65 self._finished_future.result() 66 return self._node_errors, self._skipped_nodes
Runs the executor.
Raises:
- NodeExecutionFailedError: if `raise_on_error` was set to True and execution fails for any snapshot.
Returns:
A pair which contains a list of node errors and a list of skipped nodes.
130def concurrent_apply_to_snapshots( 131 snapshots: t.Iterable[S], 132 fn: t.Callable[[S], None], 133 tasks_num: int, 134 reverse_order: bool = False, 135 raise_on_error: bool = True, 136) -> t.Tuple[t.List[NodeExecutionFailedError[SnapshotId]], t.List[SnapshotId]]: 137 """Applies a function to the given collection of snapshots concurrently while 138 preserving the topological order between snapshots. 139 140 Args: 141 snapshots: Target snapshots. 142 fn: The function that will be applied concurrently to each snapshot. 143 tasks_num: The number of concurrent tasks. 144 reverse_order: Whether the order should be reversed. Default: False. 145 raise_on_error: If set to True raises an exception on a first encountered error, 146 otherwises returns a tuple which contains a list of failed nodes and a list of 147 skipped nodes. 148 149 Raises: 150 NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot. 151 152 Returns: 153 A pair which contains a list of errors and a list of skipped snapshot IDs. 154 """ 155 snapshots_by_id = {s.snapshot_id: s for s in snapshots} 156 157 dag: DAG[SnapshotId] = DAG[SnapshotId]() 158 for snapshot in snapshots: 159 dag.add( 160 snapshot.snapshot_id, 161 [p_sid for p_sid in snapshot.parents if p_sid in snapshots_by_id], 162 ) 163 164 return concurrent_apply_to_dag( 165 dag if not reverse_order else dag.reversed, 166 lambda s_id: fn(snapshots_by_id[s_id]), 167 tasks_num, 168 raise_on_error=raise_on_error, 169 )
Applies a function to the given collection of snapshots concurrently while preserving the topological order between snapshots.
Arguments:
- snapshots: Target snapshots.
- fn: The function that will be applied concurrently to each snapshot.
- tasks_num: The number of concurrent tasks.
- reverse_order: Whether the order should be reversed. Default: False.
- raise_on_error: If set to True raises an exception on the first encountered error, otherwise returns a tuple which contains a list of failed nodes and a list of skipped nodes.
Raises:
- NodeExecutionFailedError: if `raise_on_error` is set to True and execution fails for any snapshot.
Returns:
A pair which contains a list of errors and a list of skipped snapshot IDs.
172def concurrent_apply_to_dag( 173 dag: DAG[H], 174 fn: t.Callable[[H], None], 175 tasks_num: int, 176 raise_on_error: bool = True, 177) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]: 178 """Applies a function to the given DAG concurrently while preserving the topological 179 order between snapshots. 180 181 Args: 182 dag: The target DAG. 183 fn: The function that will be applied concurrently to each snapshot. 184 tasks_num: The number of concurrent tasks. 185 raise_on_error: If set to True raises an exception on a first encountered error, 186 otherwises returns a tuple which contains a list of failed nodes and a list of 187 skipped nodes. 188 189 Raises: 190 NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot. 191 192 Returns: 193 A pair which contains a list of node errors and a list of skipped nodes. 194 """ 195 if tasks_num <= 0: 196 raise ConfigError(f"Invalid number of concurrent tasks {tasks_num}") 197 198 if tasks_num == 1: 199 return sequential_apply_to_dag(dag, fn, raise_on_error) 200 201 return ConcurrentDAGExecutor( 202 dag, 203 fn, 204 tasks_num, 205 raise_on_error, 206 ).run()
Applies a function to the given DAG concurrently while preserving the topological order between snapshots.
Arguments:
- dag: The target DAG.
- fn: The function that will be applied concurrently to each snapshot.
- tasks_num: The number of concurrent tasks.
- raise_on_error: If set to True raises an exception on the first encountered error, otherwise returns a tuple which contains a list of failed nodes and a list of skipped nodes.
Raises:
- NodeExecutionFailedError: if `raise_on_error` is set to True and execution fails for any snapshot.
Returns:
A pair which contains a list of node errors and a list of skipped nodes.
209def sequential_apply_to_dag( 210 dag: DAG[H], 211 fn: t.Callable[[H], None], 212 raise_on_error: bool = True, 213) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]: 214 dependencies = dag.graph 215 216 node_errors: t.List[NodeExecutionFailedError[H]] = [] 217 skipped_nodes: t.List[H] = [] 218 219 failed_or_skipped_nodes: t.Set[H] = set() 220 221 for node in dag.sorted: 222 if not failed_or_skipped_nodes.isdisjoint(dependencies[node]): 223 skipped_nodes.append(node) 224 failed_or_skipped_nodes.add(node) 225 continue 226 227 try: 228 fn(node) 229 except Exception as ex: 230 if raise_on_error: 231 raise NodeExecutionFailedError(node) from ex 232 233 error = NodeExecutionFailedError(node) 234 error.__cause__ = ex 235 236 node_errors.append(error) 237 failed_or_skipped_nodes.add(node) 238 239 return node_errors, skipped_nodes
242def concurrent_apply_to_values( 243 values: t.Sequence[A], 244 fn: t.Callable[[A], R], 245 tasks_num: int, 246) -> t.List[R]: 247 """Applies a function to the given collection of values concurrently. 248 249 Args: 250 values: Target values. 251 fn: The function that will be applied concurrently to each value. 252 tasks_num: The number of concurrent tasks. 253 254 Returns: 255 A list of results. 256 """ 257 if tasks_num == 1: 258 return [fn(value) for value in values] 259 260 futures: t.List[Future] = [Future() for _ in values] 261 262 def _process_value(value: A, index: int) -> None: 263 try: 264 futures[index].set_result(fn(value)) 265 except Exception as ex: 266 futures[index].set_exception(ex) 267 268 with ThreadPoolExecutor(max_workers=tasks_num) as pool: 269 for index, value in enumerate(values): 270 pool.submit(_process_value, value, index) 271 272 return [f.result() for f in futures]
Applies a function to the given collection of values concurrently.
Arguments:
- values: Target values.
- fn: The function that will be applied concurrently to each value.
- tasks_num: The number of concurrent tasks.
Returns:
A list of results.