sqlmesh.utils.concurrency
"""Utilities for applying functions concurrently to DAG nodes, snapshots and plain values."""
import typing as t
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from threading import Lock

from sqlmesh.core.snapshot import SnapshotId, SnapshotInfoLike
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.errors import ConfigError, SQLMeshError

H = t.TypeVar("H", bound=t.Hashable)
S = t.TypeVar("S", bound=SnapshotInfoLike)
A = t.TypeVar("A")
R = t.TypeVar("R")


class NodeExecutionFailedError(t.Generic[H], SQLMeshError):
    """Raised (or collected) when the applied function fails for a DAG node.

    The original exception is attached by the callers as `__cause__`.

    Args:
        node: The node for which the execution failed.
    """

    def __init__(self, node: H):
        self.node = node
        super().__init__(f"Execution failed for node {node}")


class ConcurrentDAGExecutor(t.Generic[H]):
    """Concurrently traverses the given DAG in topological order while applying a function to each node.

    If `raise_on_error` is set to False, keeps track of execution errors as well as of skipped nodes.

    Args:
        dag: The target DAG.
        fn: The function that will be applied concurrently to each node.
        tasks_num: The number of concurrent tasks.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.
    """

    def __init__(
        self,
        dag: DAG[H],
        fn: t.Callable[[H], None],
        tasks_num: int,
        raise_on_error: bool,
    ):
        self.dag = dag
        self.fn = fn
        self.tasks_num = tasks_num
        self.raise_on_error = raise_on_error

        self._init_state()

    def run(self) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
        """Runs the executor.

        Raises:
            NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any node.

        Returns:
            A pair which contains a list of node errors and a list of skipped nodes.
        """
        # A completed future means a previous run has finished; start from a fresh state.
        if self._finished_future.done():
            self._init_state()

        with ThreadPoolExecutor(max_workers=self.tasks_num) as pool:
            with self._unprocessed_nodes_lock:
                self._submit_next_nodes(pool)
            # Blocks until the traversal completes; re-raises the error set on the future.
            self._finished_future.result()
        return self._node_errors, self._skipped_nodes

    def _process_node(self, node: H, executor: Executor) -> None:
        """Applies `fn` to a single node and then schedules or skips its dependents."""
        try:
            self.fn(node)

            with self._unprocessed_nodes_lock:
                self._unprocessed_nodes_num -= 1
                self._submit_next_nodes(executor, node)
        except Exception as ex:
            error = NodeExecutionFailedError(node)
            error.__cause__ = ex

            # The lock serializes completion of the future between concurrently failing nodes.
            with self._unprocessed_nodes_lock:
                if self.raise_on_error:
                    self._finish(error)
                    return

                self._unprocessed_nodes_num -= 1
                self._node_errors.append(error)
                self._skip_next_nodes(node)

    def _submit_next_nodes(self, executor: Executor, processed_node: t.Optional[H] = None) -> None:
        """Submits every node whose dependencies have all been processed.

        Must be called with `_unprocessed_nodes_lock` held.
        """
        if self._finished_future.done():
            # The run has already completed (e.g. with an error); don't schedule more work.
            return

        if not self._unprocessed_nodes_num:
            self._finish()
            return

        submitted_nodes = []
        for next_node, deps in self._unprocessed_nodes.items():
            # Compare against None explicitly: a node may be a falsy value such as 0 or "".
            if processed_node is not None:
                deps.discard(processed_node)
            if not deps:
                submitted_nodes.append(next_node)

        for submitted_node in submitted_nodes:
            self._unprocessed_nodes.pop(submitted_node)
            executor.submit(self._process_node, submitted_node, executor)

    def _skip_next_nodes(self, parent: H) -> None:
        """Marks all direct and transitive dependents of a failed node as skipped.

        Must be called with `_unprocessed_nodes_lock` held.
        """
        if not self._unprocessed_nodes_num:
            self._finish()
            return

        skipped_nodes = {node for node, deps in self._unprocessed_nodes.items() if parent in deps}

        while skipped_nodes:
            self._skipped_nodes.extend(skipped_nodes)

            for skipped_node in skipped_nodes:
                self._unprocessed_nodes_num -= 1
                self._unprocessed_nodes.pop(skipped_node)

            # Next wave: nodes that depend on anything that has just been skipped.
            skipped_nodes = {
                node
                for node, deps in self._unprocessed_nodes.items()
                if skipped_nodes.intersection(deps)
            }

        if not self._unprocessed_nodes_num:
            self._finish()

    def _finish(self, error: t.Optional[BaseException] = None) -> None:
        """Completes the run exactly once, with a result or with the given error.

        Must be called with `_unprocessed_nodes_lock` held. Completing an already completed
        future is an error in `concurrent.futures`, hence the guard.
        """
        if self._finished_future.done():
            return
        if error is None:
            self._finished_future.set_result(None)
        else:
            self._finished_future.set_exception(error)

    def _init_state(self) -> None:
        """Resets the traversal state so that the executor can be (re-)run."""
        # The mapping and its dependency sets are mutated during traversal, so work on a copy.
        # NOTE(review): `DAG.graph` may already return a copy; the extra copy is defensive.
        self._unprocessed_nodes = {node: set(deps) for node, deps in self.dag.graph.items()}
        self._unprocessed_nodes_num = len(self._unprocessed_nodes)
        self._unprocessed_nodes_lock = Lock()
        self._finished_future = Future()  # type: ignore

        self._node_errors: t.List[NodeExecutionFailedError[H]] = []
        self._skipped_nodes: t.List[H] = []


def concurrent_apply_to_snapshots(
    snapshots: t.Iterable[S],
    fn: t.Callable[[S], None],
    tasks_num: int,
    reverse_order: bool = False,
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[SnapshotId]], t.List[SnapshotId]]:
    """Applies a function to the given collection of snapshots concurrently while
    preserving the topological order between snapshots.

    Args:
        snapshots: Target snapshots.
        fn: The function that will be applied concurrently to each snapshot.
        tasks_num: The number of concurrent tasks.
        reverse_order: Whether the order should be reversed. Default: False.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot.

    Returns:
        A pair which contains a list of errors and a list of skipped snapshot IDs.
    """
    # Materialize once: the input is traversed twice and may be a one-shot iterator.
    all_snapshots = list(snapshots)
    snapshots_by_id = {s.snapshot_id: s for s in all_snapshots}

    dag: DAG[SnapshotId] = DAG[SnapshotId]()
    for snapshot in all_snapshots:
        dag.add(
            snapshot.snapshot_id,
            # Parents outside of the given collection impose no ordering constraint here.
            [p_sid for p_sid in snapshot.parents if p_sid in snapshots_by_id],
        )

    return concurrent_apply_to_dag(
        dag if not reverse_order else dag.reversed,
        lambda s_id: fn(snapshots_by_id[s_id]),
        tasks_num,
        raise_on_error=raise_on_error,
    )


def concurrent_apply_to_dag(
    dag: DAG[H],
    fn: t.Callable[[H], None],
    tasks_num: int,
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Applies a function to the given DAG concurrently while preserving the topological
    order between nodes.

    Args:
        dag: The target DAG.
        fn: The function that will be applied concurrently to each node.
        tasks_num: The number of concurrent tasks.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        ConfigError if `tasks_num` is not a positive number.
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any node.

    Returns:
        A pair which contains a list of node errors and a list of skipped nodes.
    """
    if tasks_num <= 0:
        raise ConfigError(f"Invalid number of concurrent tasks {tasks_num}")

    # A single task needs no thread pool.
    if tasks_num == 1:
        return sequential_apply_to_dag(dag, fn, raise_on_error)

    return ConcurrentDAGExecutor(
        dag,
        fn,
        tasks_num,
        raise_on_error,
    ).run()


def sequential_apply_to_dag(
    dag: DAG[H],
    fn: t.Callable[[H], None],
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Applies a function to each node of the given DAG sequentially, in the DAG's sorted order.

    Args:
        dag: The target DAG.
        fn: The function that will be applied to each node.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise failed nodes are collected and their dependents are skipped.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any node.

    Returns:
        A pair which contains a list of node errors and a list of skipped nodes.
    """
    dependencies = dag.graph

    node_errors: t.List[NodeExecutionFailedError[H]] = []
    skipped_nodes: t.List[H] = []

    failed_or_skipped_nodes: t.Set[H] = set()

    for node in dag.sorted:
        # A node is skipped as soon as any of its dependencies has failed or was skipped.
        if not failed_or_skipped_nodes.isdisjoint(dependencies[node]):
            skipped_nodes.append(node)
            failed_or_skipped_nodes.add(node)
            continue

        try:
            fn(node)
        except Exception as ex:
            error = NodeExecutionFailedError(node)
            error.__cause__ = ex

            if raise_on_error:
                raise error

            node_errors.append(error)
            failed_or_skipped_nodes.add(node)

    return node_errors, skipped_nodes


def concurrent_apply_to_values(
    values: t.Sequence[A],
    fn: t.Callable[[A], R],
    tasks_num: int,
) -> t.List[R]:
    """Applies a function to the given collection of values concurrently.

    Args:
        values: Target values.
        fn: The function that will be applied concurrently to each value.
        tasks_num: The number of concurrent tasks.

    Raises:
        The first error (in the order of `values`) raised by `fn`, if any.

    Returns:
        A list of results, in the order of `values`.
    """
    if tasks_num == 1:
        return [fn(value) for value in values]

    # Use the futures returned by the pool: they are always completed, whatever `fn` raises,
    # so collecting the results can never block forever.
    with ThreadPoolExecutor(max_workers=tasks_num) as pool:
        futures = [pool.submit(fn, value) for value in values]

    # Leaving the `with` block waits for all tasks, so every future is done at this point.
    return [future.result() for future in futures]
class NodeExecutionFailedError(t.Generic[H], SQLMeshError):
    """Error which signals that the execution failed for a particular node.

    Args:
        node: The node for which the execution failed. It is kept on the
            instance as the `node` attribute.
    """

    def __init__(self, node: H):
        message = f"Execution failed for node {node}"
        super().__init__(message)
        self.node = node
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.
This class can then be used as follows::
    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default
Inherited Members
- builtins.BaseException
- with_traceback
- args
class ConcurrentDAGExecutor(t.Generic[H]):
    """Concurrently traverses the given DAG in topological order while applying a function to each node.

    If `raise_on_error` is set to False, keeps track of execution errors as well as of skipped nodes.

    Args:
        dag: The target DAG.
        fn: The function that will be applied concurrently to each node.
        tasks_num: The number of concurrent tasks.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.
    """

    def __init__(
        self,
        dag: DAG[H],
        fn: t.Callable[[H], None],
        tasks_num: int,
        raise_on_error: bool,
    ):
        self.dag = dag
        self.fn = fn
        self.tasks_num = tasks_num
        self.raise_on_error = raise_on_error

        self._init_state()

    def run(self) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
        """Runs the executor.

        Raises:
            NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any node.

        Returns:
            A pair which contains a list of node errors and a list of skipped nodes.
        """
        # A completed future means a previous run has finished; start from a fresh state.
        if self._finished_future.done():
            self._init_state()

        with ThreadPoolExecutor(max_workers=self.tasks_num) as pool:
            with self._unprocessed_nodes_lock:
                self._submit_next_nodes(pool)
            # Blocks until the traversal completes; re-raises the error set on the future.
            self._finished_future.result()
        return self._node_errors, self._skipped_nodes

    def _process_node(self, node: H, executor: Executor) -> None:
        """Applies `fn` to a single node and then schedules or skips its dependents."""
        try:
            self.fn(node)

            with self._unprocessed_nodes_lock:
                self._unprocessed_nodes_num -= 1
                self._submit_next_nodes(executor, node)
        except Exception as ex:
            error = NodeExecutionFailedError(node)
            error.__cause__ = ex

            # The lock serializes completion of the future between concurrently failing nodes.
            with self._unprocessed_nodes_lock:
                if self.raise_on_error:
                    self._finish(error)
                    return

                self._unprocessed_nodes_num -= 1
                self._node_errors.append(error)
                self._skip_next_nodes(node)

    def _submit_next_nodes(self, executor: Executor, processed_node: t.Optional[H] = None) -> None:
        """Submits every node whose dependencies have all been processed.

        Must be called with `_unprocessed_nodes_lock` held.
        """
        if self._finished_future.done():
            # The run has already completed (e.g. with an error); don't schedule more work.
            return

        if not self._unprocessed_nodes_num:
            self._finish()
            return

        submitted_nodes = []
        for next_node, deps in self._unprocessed_nodes.items():
            # Compare against None explicitly: a node may be a falsy value such as 0 or "".
            if processed_node is not None:
                deps.discard(processed_node)
            if not deps:
                submitted_nodes.append(next_node)

        for submitted_node in submitted_nodes:
            self._unprocessed_nodes.pop(submitted_node)
            executor.submit(self._process_node, submitted_node, executor)

    def _skip_next_nodes(self, parent: H) -> None:
        """Marks all direct and transitive dependents of a failed node as skipped.

        Must be called with `_unprocessed_nodes_lock` held.
        """
        if not self._unprocessed_nodes_num:
            self._finish()
            return

        skipped_nodes = {node for node, deps in self._unprocessed_nodes.items() if parent in deps}

        while skipped_nodes:
            self._skipped_nodes.extend(skipped_nodes)

            for skipped_node in skipped_nodes:
                self._unprocessed_nodes_num -= 1
                self._unprocessed_nodes.pop(skipped_node)

            # Next wave: nodes that depend on anything that has just been skipped.
            skipped_nodes = {
                node
                for node, deps in self._unprocessed_nodes.items()
                if skipped_nodes.intersection(deps)
            }

        if not self._unprocessed_nodes_num:
            self._finish()

    def _finish(self, error: t.Optional[BaseException] = None) -> None:
        """Completes the run exactly once, with a result or with the given error.

        Must be called with `_unprocessed_nodes_lock` held. Completing an already completed
        future is an error in `concurrent.futures`, hence the guard.
        """
        if self._finished_future.done():
            return
        if error is None:
            self._finished_future.set_result(None)
        else:
            self._finished_future.set_exception(error)

    def _init_state(self) -> None:
        """Resets the traversal state so that the executor can be (re-)run."""
        # The mapping and its dependency sets are mutated during traversal, so work on a copy.
        # NOTE(review): `DAG.graph` may already return a copy; the extra copy is defensive.
        self._unprocessed_nodes = {node: set(deps) for node, deps in self.dag.graph.items()}
        self._unprocessed_nodes_num = len(self._unprocessed_nodes)
        self._unprocessed_nodes_lock = Lock()
        self._finished_future = Future()  # type: ignore

        self._node_errors: t.List[NodeExecutionFailedError[H]] = []
        self._skipped_nodes: t.List[H] = []
Concurrently traverses the given DAG in topological order while applying a function to each node.
If `raise_on_error` is set to False, the executor keeps track of execution errors as well as of skipped nodes.
Arguments:
- dag: The target DAG.
- fn: The function that will be applied concurrently to each node.
- tasks_num: The number of concurrent tasks.
- raise_on_error: If set to True raises an exception on the first encountered error, otherwise returns a tuple which contains a list of failed nodes and a list of skipped nodes.
def run(self) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Executes the traversal and waits for it to complete.

    Raises:
        NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any snapshot.

    Returns:
        A pair which contains a list of node errors and a list of skipped nodes.
    """
    # A completed future is left over from a previous run; reset before starting again.
    already_ran = self._finished_future.done()
    if already_ran:
        self._init_state()

    pool = ThreadPoolExecutor(max_workers=self.tasks_num)
    with pool:
        # Seed the pool while holding the lock which guards the unprocessed nodes.
        with self._unprocessed_nodes_lock:
            self._submit_next_nodes(pool)

        # Wait for completion; an exception set on the future is re-raised here.
        self._finished_future.result()

    return self._node_errors, self._skipped_nodes
Runs the executor.
Raises:
- NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any node.
Returns:
A pair which contains a list of node errors and a list of skipped nodes.
def concurrent_apply_to_snapshots(
    snapshots: t.Iterable[S],
    fn: t.Callable[[S], None],
    tasks_num: int,
    reverse_order: bool = False,
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[SnapshotId]], t.List[SnapshotId]]:
    """Applies a function to the given collection of snapshots concurrently while
    preserving the topological order between snapshots.

    Args:
        snapshots: Target snapshots. Any iterable is accepted, including one-shot iterators.
        fn: The function that will be applied concurrently to each snapshot.
        tasks_num: The number of concurrent tasks.
        reverse_order: Whether the order should be reversed. Default: False.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot.

    Returns:
        A pair which contains a list of errors and a list of skipped snapshot IDs.
    """
    # Materialize once: the input is traversed twice below, and a generator would be
    # exhausted after the first pass, silently producing an empty DAG.
    all_snapshots = list(snapshots)
    snapshots_by_id = {s.snapshot_id: s for s in all_snapshots}

    dag: DAG[SnapshotId] = DAG[SnapshotId]()
    for snapshot in all_snapshots:
        dag.add(
            snapshot.snapshot_id,
            # Parents outside of the given collection impose no ordering constraint here.
            [p_sid for p_sid in snapshot.parents if p_sid in snapshots_by_id],
        )

    return concurrent_apply_to_dag(
        dag if not reverse_order else dag.reversed,
        lambda s_id: fn(snapshots_by_id[s_id]),
        tasks_num,
        raise_on_error=raise_on_error,
    )
Applies a function to the given collection of snapshots concurrently while preserving the topological order between snapshots.
Arguments:
- snapshots: Target snapshots.
- fn: The function that will be applied concurrently to each snapshot.
- tasks_num: The number of concurrent tasks.
- reverse_order: Whether the order should be reversed. Default: False.
- raise_on_error: If set to True raises an exception on the first encountered error, otherwise returns a tuple which contains a list of failed nodes and a list of skipped nodes.
Raises:
- NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot.
Returns:
A pair which contains a list of errors and a list of skipped snapshot IDs.
def concurrent_apply_to_dag(
    dag: DAG[H],
    fn: t.Callable[[H], None],
    tasks_num: int,
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Applies a function to every node of the given DAG while preserving the topological
    order between nodes.

    A single task is executed sequentially; more than one task is delegated to
    `ConcurrentDAGExecutor`.

    Args:
        dag: The target DAG.
        fn: The function that will be applied to each node.
        tasks_num: The number of concurrent tasks.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        ConfigError if `tasks_num` is zero or negative.
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any node.

    Returns:
        A pair which contains a list of node errors and a list of skipped nodes.
    """
    if tasks_num <= 0:
        raise ConfigError(f"Invalid number of concurrent tasks {tasks_num}")

    run_sequentially = tasks_num == 1
    if run_sequentially:
        # No thread pool is needed for a single task.
        return sequential_apply_to_dag(dag, fn, raise_on_error)

    executor = ConcurrentDAGExecutor(dag, fn, tasks_num, raise_on_error)
    return executor.run()
Applies a function to the given DAG concurrently while preserving the topological order between nodes.
Arguments:
- dag: The target DAG.
- fn: The function that will be applied concurrently to each node.
- tasks_num: The number of concurrent tasks.
- raise_on_error: If set to True raises an exception on the first encountered error, otherwise returns a tuple which contains a list of failed nodes and a list of skipped nodes.
Raises:
- NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any node.
Returns:
A pair which contains a list of node errors and a list of skipped nodes.
def sequential_apply_to_dag(
    dag: DAG[H],
    fn: t.Callable[[H], None],
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Applies a function to each node of the given DAG one at a time, in the DAG's sorted order.

    Args:
        dag: The target DAG.
        fn: The function that will be applied to each node.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise failed nodes are collected and the nodes which depend on them are skipped.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any node.

    Returns:
        A pair which contains a list of node errors and a list of skipped nodes.
    """
    deps_by_node = dag.graph

    errors: t.List[NodeExecutionFailedError[H]] = []
    skipped: t.List[H] = []
    # Nodes that either failed or were skipped; anything depending on them is skipped too.
    blocked: t.Set[H] = set()

    for node in dag.sorted:
        if any(dep in blocked for dep in deps_by_node[node]):
            skipped.append(node)
            blocked.add(node)
        else:
            try:
                fn(node)
            except Exception as ex:
                failure = NodeExecutionFailedError(node)
                # Keep the original exception reachable from the wrapping error.
                failure.__cause__ = ex

                if raise_on_error:
                    raise failure

                errors.append(failure)
                blocked.add(node)

    return errors, skipped
def concurrent_apply_to_values(
    values: t.Sequence[A],
    fn: t.Callable[[A], R],
    tasks_num: int,
) -> t.List[R]:
    """Applies a function to the given collection of values concurrently.

    Args:
        values: Target values.
        fn: The function that will be applied concurrently to each value.
        tasks_num: The number of concurrent tasks. With a single task the function is
            applied sequentially in the calling thread.

    Raises:
        The first error (in the order of `values`) raised by `fn`, if any.

    Returns:
        A list of results, in the order of `values`.
    """
    if tasks_num == 1:
        return [fn(value) for value in values]

    # Use the futures returned by the pool instead of hand-made ones: the pool completes
    # them whatever `fn` raises (including non-`Exception` errors), so collecting the
    # results can never block forever. The values are also traversed only once.
    with ThreadPoolExecutor(max_workers=tasks_num) as pool:
        futures = [pool.submit(fn, value) for value in values]

    # Leaving the `with` block waits for all tasks, so every future is done at this point.
    return [future.result() for future in futures]
Applies a function to the given collection of values concurrently.
Arguments:
- values: Target values.
- fn: The function that will be applied concurrently to each value.
- tasks_num: The number of concurrent tasks.
Returns:
A list of results.