Edit on GitHub

sqlmesh.utils.concurrency

  1import typing as t
  2from concurrent.futures import Executor, Future, ThreadPoolExecutor
  3from threading import Lock
  4
  5from sqlmesh.core.snapshot import SnapshotId, SnapshotInfoLike
  6from sqlmesh.utils.dag import DAG
  7from sqlmesh.utils.errors import ConfigError, SQLMeshError
  8
  9H = t.TypeVar("H", bound=t.Hashable)
 10S = t.TypeVar("S", bound=SnapshotInfoLike)
 11A = t.TypeVar("A")
 12R = t.TypeVar("R")
 13
 14
 15class NodeExecutionFailedError(t.Generic[H], SQLMeshError):
 16    def __init__(self, node: H):
 17        self.node = node
 18        super().__init__(f"Execution failed for node {node}")
 19
 20
 21class ConcurrentDAGExecutor(t.Generic[H]):
 22    """Concurrently traverses the given DAG in topological order while applying a function to each node.
 23
 24    If `raise_on_error` is set to False maintains a state of execution errors as well as of skipped nodes.
 25
 26    Args:
 27        dag: The target DAG.
 28        fn: The function that will be applied concurrently to each snapshot.
 29        tasks_num: The number of concurrent tasks.
 30        raise_on_error: If set to True raises an exception on a first encountered error,
 31            otherwises returns a tuple which contains a list of failed nodes and a list of
 32            skipped nodes.
 33    """
 34
 35    def __init__(
 36        self,
 37        dag: DAG[H],
 38        fn: t.Callable[[H], None],
 39        tasks_num: int,
 40        raise_on_error: bool,
 41    ):
 42        self.dag = dag
 43        self.fn = fn
 44        self.tasks_num = tasks_num
 45        self.raise_on_error = raise_on_error
 46
 47        self._init_state()
 48
 49    def run(self) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
 50        """Runs the executor.
 51
 52        Raises:
 53            NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any snapshot.
 54
 55        Returns:
 56            A pair which contains a list of node errors and a list of skipped nodes.
 57        """
 58        if self._finished_future.done():
 59            self._init_state()
 60
 61        with ThreadPoolExecutor(max_workers=self.tasks_num) as pool:
 62            with self._unprocessed_nodes_lock:
 63                self._submit_next_nodes(pool)
 64            self._finished_future.result()
 65        return self._node_errors, self._skipped_nodes
 66
 67    def _process_node(self, node: H, executor: Executor) -> None:
 68        try:
 69            self.fn(node)
 70
 71            with self._unprocessed_nodes_lock:
 72                self._unprocessed_nodes_num -= 1
 73                self._submit_next_nodes(executor, node)
 74        except Exception as ex:
 75            error = NodeExecutionFailedError(node)
 76            error.__cause__ = ex
 77
 78            if self.raise_on_error:
 79                self._finished_future.set_exception(error)
 80                return
 81
 82            with self._unprocessed_nodes_lock:
 83                self._unprocessed_nodes_num -= 1
 84                self._node_errors.append(error)
 85                self._skip_next_nodes(node)
 86
 87    def _submit_next_nodes(self, executor: Executor, processed_node: t.Optional[H] = None) -> None:
 88        if not self._unprocessed_nodes_num:
 89            self._finished_future.set_result(None)
 90            return
 91
 92        submitted_nodes = []
 93        for next_node, deps in self._unprocessed_nodes.items():
 94            if processed_node:
 95                deps.discard(processed_node)
 96            if not deps:
 97                submitted_nodes.append(next_node)
 98
 99        for submitted_node in submitted_nodes:
100            self._unprocessed_nodes.pop(submitted_node)
101            executor.submit(self._process_node, submitted_node, executor)
102
103    def _skip_next_nodes(self, parent: H) -> None:
104        if not self._unprocessed_nodes_num:
105            self._finished_future.set_result(None)
106            return
107
108        skipped_nodes = {node for node, deps in self._unprocessed_nodes.items() if parent in deps}
109
110        while skipped_nodes:
111            self._skipped_nodes.extend(skipped_nodes)
112
113            for skipped_node in skipped_nodes:
114                self._unprocessed_nodes_num -= 1
115                self._unprocessed_nodes.pop(skipped_node)
116
117            skipped_nodes = {
118                node
119                for node, deps in self._unprocessed_nodes.items()
120                if skipped_nodes.intersection(deps)
121            }
122
123        if not self._unprocessed_nodes_num:
124            self._finished_future.set_result(None)
125
126    def _init_state(self) -> None:
127        self._unprocessed_nodes = self.dag.graph
128        self._unprocessed_nodes_num = len(self._unprocessed_nodes)
129        self._unprocessed_nodes_lock = Lock()
130        self._finished_future = Future()  # type: ignore
131
132        self._node_errors: t.List[NodeExecutionFailedError[H]] = []
133        self._skipped_nodes: t.List[H] = []
134
135
def concurrent_apply_to_snapshots(
    snapshots: t.Iterable[S],
    fn: t.Callable[[S], None],
    tasks_num: int,
    reverse_order: bool = False,
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[SnapshotId]], t.List[SnapshotId]]:
    """Applies a function to the given collection of snapshots concurrently while
    preserving the topological order between snapshots.

    Only parent relationships between snapshots in the given collection constrain the order;
    parents outside of it are ignored.

    Args:
        snapshots: Target snapshots. Any iterable is accepted, including one-shot iterators.
        fn: The function that will be applied concurrently to each snapshot.
        tasks_num: The number of concurrent tasks.
        reverse_order: Whether the order should be reversed. Default: False.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot.

    Returns:
        A pair which contains a list of errors and a list of skipped snapshot IDs.
    """
    # Consume `snapshots` exactly once: it may be a one-shot iterator (e.g. a generator), and
    # iterating it a second time to build the DAG would silently yield nothing.
    snapshots_by_id = {s.snapshot_id: s for s in snapshots}

    dag: DAG[SnapshotId] = DAG[SnapshotId]()
    for snapshot_id, snapshot in snapshots_by_id.items():
        dag.add(
            snapshot_id,
            [p_sid for p_sid in snapshot.parents if p_sid in snapshots_by_id],
        )

    return concurrent_apply_to_dag(
        dag.reversed if reverse_order else dag,
        lambda s_id: fn(snapshots_by_id[s_id]),
        tasks_num,
        raise_on_error=raise_on_error,
    )
176
177
def concurrent_apply_to_dag(
    dag: DAG[H],
    fn: t.Callable[[H], None],
    tasks_num: int,
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Applies a function to every node of the given DAG concurrently while preserving
    the topological order between nodes.

    When `tasks_num` is 1 the nodes are processed sequentially instead of using a thread pool.

    Args:
        dag: The target DAG.
        fn: The function that will be applied concurrently to each node.
        tasks_num: The number of concurrent tasks. Must be positive.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        ConfigError if `tasks_num` is not positive.
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any node.

    Returns:
        A pair which contains a list of node errors and a list of skipped nodes.
    """
    if tasks_num <= 0:
        raise ConfigError(f"Invalid number of concurrent tasks {tasks_num}")

    if tasks_num != 1:
        executor = ConcurrentDAGExecutor(dag, fn, tasks_num, raise_on_error)
        return executor.run()

    return sequential_apply_to_dag(dag, fn, raise_on_error)
213
214
def sequential_apply_to_dag(
    dag: DAG[H],
    fn: t.Callable[[H], None],
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Applies a function to each node of the given DAG one at a time, in topological order.

    A node is skipped (and `fn` is not called for it) when any of its direct dependencies
    either failed or was itself skipped.

    Args:
        dag: The target DAG.
        fn: The function that will be applied to each node.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any node.

    Returns:
        A pair which contains a list of node errors and a list of skipped nodes.
    """
    upstream_by_node = dag.graph

    errors: t.List[NodeExecutionFailedError[H]] = []
    skipped: t.List[H] = []
    # Nodes whose dependents must not run: `fn` either failed on them or they were skipped.
    blocked: t.Set[H] = set()

    for node in dag.sorted:
        if blocked.isdisjoint(upstream_by_node[node]):
            try:
                fn(node)
            except Exception as ex:
                error = NodeExecutionFailedError(node)
                error.__cause__ = ex
                if raise_on_error:
                    raise error
                errors.append(error)
                blocked.add(node)
        else:
            skipped.append(node)
            blocked.add(node)

    return errors, skipped
246
247
def concurrent_apply_to_values(
    values: t.Sequence[A],
    fn: t.Callable[[A], R],
    tasks_num: int,
) -> t.List[R]:
    """Applies a function to the given collection of values concurrently.

    Results are returned in the same order as `values`. If `fn` raises for any value, the
    exception of the first failing value (in input order) is re-raised once all submitted
    calls have finished.

    Args:
        values: Target values.
        fn: The function that will be applied concurrently to each value.
        tasks_num: The number of concurrent tasks.

    Returns:
        A list of results.
    """
    if tasks_num == 1:
        return [fn(value) for value in values]

    # The futures returned by `submit` capture *any* exception raised by `fn`, including
    # `BaseException` subclasses. Hand-rolled futures that only caught `Exception` could be
    # left unresolved, making `result()` below block forever.
    with ThreadPoolExecutor(max_workers=tasks_num) as pool:
        futures = [pool.submit(fn, value) for value in values]

    return [future.result() for future in futures]
class NodeExecutionFailedError(typing.Generic[~H], sqlmesh.utils.errors.SQLMeshError):
class NodeExecutionFailedError(t.Generic[H], SQLMeshError):
    """Signals that applying a function to a single node has failed.

    The executors in this module create it for the failing node and attach the original
    exception as its ``__cause__``.

    Args:
        node: The node for which execution failed; exposed as the ``node`` attribute.
    """

    def __init__(self, node: H):
        self.node = node
        message = f"Execution failed for node {node}"
        super().__init__(message)

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

class Mapping(Generic[KT, VT]):
    def __getitem__(self, key: KT) -> VT:
        ...
        # Etc.

This class can then be used as follows::

def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
    try:
        return mapping[key]
    except KeyError:
        return default

NodeExecutionFailedError(node: ~H)
17    def __init__(self, node: H):
18        self.node = node
19        super().__init__(f"Execution failed for node {node}")
node
Inherited Members
builtins.BaseException
with_traceback
args
class ConcurrentDAGExecutor(typing.Generic[~H]):
 22class ConcurrentDAGExecutor(t.Generic[H]):
 23    """Concurrently traverses the given DAG in topological order while applying a function to each node.
 24
 25    If `raise_on_error` is set to False maintains a state of execution errors as well as of skipped nodes.
 26
 27    Args:
 28        dag: The target DAG.
 29        fn: The function that will be applied concurrently to each snapshot.
 30        tasks_num: The number of concurrent tasks.
 31        raise_on_error: If set to True raises an exception on a first encountered error,
 32            otherwises returns a tuple which contains a list of failed nodes and a list of
 33            skipped nodes.
 34    """
 35
 36    def __init__(
 37        self,
 38        dag: DAG[H],
 39        fn: t.Callable[[H], None],
 40        tasks_num: int,
 41        raise_on_error: bool,
 42    ):
 43        self.dag = dag
 44        self.fn = fn
 45        self.tasks_num = tasks_num
 46        self.raise_on_error = raise_on_error
 47
 48        self._init_state()
 49
 50    def run(self) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
 51        """Runs the executor.
 52
 53        Raises:
 54            NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any snapshot.
 55
 56        Returns:
 57            A pair which contains a list of node errors and a list of skipped nodes.
 58        """
 59        if self._finished_future.done():
 60            self._init_state()
 61
 62        with ThreadPoolExecutor(max_workers=self.tasks_num) as pool:
 63            with self._unprocessed_nodes_lock:
 64                self._submit_next_nodes(pool)
 65            self._finished_future.result()
 66        return self._node_errors, self._skipped_nodes
 67
 68    def _process_node(self, node: H, executor: Executor) -> None:
 69        try:
 70            self.fn(node)
 71
 72            with self._unprocessed_nodes_lock:
 73                self._unprocessed_nodes_num -= 1
 74                self._submit_next_nodes(executor, node)
 75        except Exception as ex:
 76            error = NodeExecutionFailedError(node)
 77            error.__cause__ = ex
 78
 79            if self.raise_on_error:
 80                self._finished_future.set_exception(error)
 81                return
 82
 83            with self._unprocessed_nodes_lock:
 84                self._unprocessed_nodes_num -= 1
 85                self._node_errors.append(error)
 86                self._skip_next_nodes(node)
 87
 88    def _submit_next_nodes(self, executor: Executor, processed_node: t.Optional[H] = None) -> None:
 89        if not self._unprocessed_nodes_num:
 90            self._finished_future.set_result(None)
 91            return
 92
 93        submitted_nodes = []
 94        for next_node, deps in self._unprocessed_nodes.items():
 95            if processed_node:
 96                deps.discard(processed_node)
 97            if not deps:
 98                submitted_nodes.append(next_node)
 99
100        for submitted_node in submitted_nodes:
101            self._unprocessed_nodes.pop(submitted_node)
102            executor.submit(self._process_node, submitted_node, executor)
103
104    def _skip_next_nodes(self, parent: H) -> None:
105        if not self._unprocessed_nodes_num:
106            self._finished_future.set_result(None)
107            return
108
109        skipped_nodes = {node for node, deps in self._unprocessed_nodes.items() if parent in deps}
110
111        while skipped_nodes:
112            self._skipped_nodes.extend(skipped_nodes)
113
114            for skipped_node in skipped_nodes:
115                self._unprocessed_nodes_num -= 1
116                self._unprocessed_nodes.pop(skipped_node)
117
118            skipped_nodes = {
119                node
120                for node, deps in self._unprocessed_nodes.items()
121                if skipped_nodes.intersection(deps)
122            }
123
124        if not self._unprocessed_nodes_num:
125            self._finished_future.set_result(None)
126
127    def _init_state(self) -> None:
128        self._unprocessed_nodes = self.dag.graph
129        self._unprocessed_nodes_num = len(self._unprocessed_nodes)
130        self._unprocessed_nodes_lock = Lock()
131        self._finished_future = Future()  # type: ignore
132
133        self._node_errors: t.List[NodeExecutionFailedError[H]] = []
134        self._skipped_nodes: t.List[H] = []

Concurrently traverses the given DAG in topological order while applying a function to each node.

If raise_on_error is set to False, the executor records execution errors as well as skipped nodes instead of raising.

Arguments:
  • dag: The target DAG.
  • fn: The function that will be applied concurrently to each node.
  • tasks_num: The number of concurrent tasks.
  • raise_on_error: If set to True raises an exception on the first encountered error, otherwise returns a tuple which contains a list of failed nodes and a list of skipped nodes.
ConcurrentDAGExecutor( dag: sqlmesh.utils.dag.DAG[~H], fn: Callable[[~H], NoneType], tasks_num: int, raise_on_error: bool)
36    def __init__(
37        self,
38        dag: DAG[H],
39        fn: t.Callable[[H], None],
40        tasks_num: int,
41        raise_on_error: bool,
42    ):
43        self.dag = dag
44        self.fn = fn
45        self.tasks_num = tasks_num
46        self.raise_on_error = raise_on_error
47
48        self._init_state()
dag
fn
tasks_num
raise_on_error
def run( self) -> Tuple[List[NodeExecutionFailedError[~H]], List[~H]]:
50    def run(self) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
51        """Runs the executor.
52
53        Raises:
54            NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any snapshot.
55
56        Returns:
57            A pair which contains a list of node errors and a list of skipped nodes.
58        """
59        if self._finished_future.done():
60            self._init_state()
61
62        with ThreadPoolExecutor(max_workers=self.tasks_num) as pool:
63            with self._unprocessed_nodes_lock:
64                self._submit_next_nodes(pool)
65            self._finished_future.result()
66        return self._node_errors, self._skipped_nodes

Runs the executor.

Raises:
  • NodeExecutionFailedError if raise_on_error was set to True and execution fails for any node.
Returns:

A pair which contains a list of node errors and a list of skipped nodes.

def concurrent_apply_to_snapshots( snapshots: Iterable[~S], fn: Callable[[~S], NoneType], tasks_num: int, reverse_order: bool = False, raise_on_error: bool = True) -> Tuple[List[NodeExecutionFailedError[sqlmesh.core.snapshot.definition.SnapshotId]], List[sqlmesh.core.snapshot.definition.SnapshotId]]:
def concurrent_apply_to_snapshots(
    snapshots: t.Iterable[S],
    fn: t.Callable[[S], None],
    tasks_num: int,
    reverse_order: bool = False,
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[SnapshotId]], t.List[SnapshotId]]:
    """Applies a function to the given collection of snapshots concurrently while
    preserving the topological order between snapshots.

    Only parent relationships between snapshots in the given collection constrain the order;
    parents outside of it are ignored.

    Args:
        snapshots: Target snapshots. Any iterable is accepted, including one-shot iterators.
        fn: The function that will be applied concurrently to each snapshot.
        tasks_num: The number of concurrent tasks.
        reverse_order: Whether the order should be reversed. Default: False.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot.

    Returns:
        A pair which contains a list of errors and a list of skipped snapshot IDs.
    """
    # Consume `snapshots` exactly once: it may be a one-shot iterator (e.g. a generator), and
    # iterating it a second time to build the DAG would silently yield nothing.
    snapshots_by_id = {s.snapshot_id: s for s in snapshots}

    dag: DAG[SnapshotId] = DAG[SnapshotId]()
    for snapshot_id, snapshot in snapshots_by_id.items():
        dag.add(
            snapshot_id,
            [p_sid for p_sid in snapshot.parents if p_sid in snapshots_by_id],
        )

    return concurrent_apply_to_dag(
        dag.reversed if reverse_order else dag,
        lambda s_id: fn(snapshots_by_id[s_id]),
        tasks_num,
        raise_on_error=raise_on_error,
    )

Applies a function to the given collection of snapshots concurrently while preserving the topological order between snapshots.

Arguments:
  • snapshots: Target snapshots.
  • fn: The function that will be applied concurrently to each snapshot.
  • tasks_num: The number of concurrent tasks.
  • reverse_order: Whether the order should be reversed. Default: False.
  • raise_on_error: If set to True raises an exception on the first encountered error, otherwise returns a tuple which contains a list of failed nodes and a list of skipped nodes.
Raises:
  • NodeExecutionFailedError if raise_on_error is set to True and execution fails for any snapshot.
Returns:

A pair which contains a list of errors and a list of skipped snapshot IDs.

def concurrent_apply_to_dag( dag: sqlmesh.utils.dag.DAG[~H], fn: Callable[[~H], NoneType], tasks_num: int, raise_on_error: bool = True) -> Tuple[List[NodeExecutionFailedError[~H]], List[~H]]:
def concurrent_apply_to_dag(
    dag: DAG[H],
    fn: t.Callable[[H], None],
    tasks_num: int,
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Applies a function to every node of the given DAG concurrently while preserving
    the topological order between nodes.

    When `tasks_num` is 1 the nodes are processed sequentially instead of using a thread pool.

    Args:
        dag: The target DAG.
        fn: The function that will be applied concurrently to each node.
        tasks_num: The number of concurrent tasks. Must be positive.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        ConfigError if `tasks_num` is not positive.
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any node.

    Returns:
        A pair which contains a list of node errors and a list of skipped nodes.
    """
    if tasks_num <= 0:
        raise ConfigError(f"Invalid number of concurrent tasks {tasks_num}")

    if tasks_num != 1:
        executor = ConcurrentDAGExecutor(dag, fn, tasks_num, raise_on_error)
        return executor.run()

    return sequential_apply_to_dag(dag, fn, raise_on_error)

Applies a function to the given DAG concurrently while preserving the topological order between nodes.

Arguments:
  • dag: The target DAG.
  • fn: The function that will be applied concurrently to each node.
  • tasks_num: The number of concurrent tasks.
  • raise_on_error: If set to True raises an exception on the first encountered error, otherwise returns a tuple which contains a list of failed nodes and a list of skipped nodes.
Raises:
  • NodeExecutionFailedError if raise_on_error is set to True and execution fails for any node.
Returns:

A pair which contains a list of node errors and a list of skipped nodes.

def sequential_apply_to_dag( dag: sqlmesh.utils.dag.DAG[~H], fn: Callable[[~H], NoneType], raise_on_error: bool = True) -> Tuple[List[NodeExecutionFailedError[~H]], List[~H]]:
def sequential_apply_to_dag(
    dag: DAG[H],
    fn: t.Callable[[H], None],
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Applies a function to each node of the given DAG one at a time, in topological order.

    A node is skipped (and `fn` is not called for it) when any of its direct dependencies
    either failed or was itself skipped.

    Args:
        dag: The target DAG.
        fn: The function that will be applied to each node.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any node.

    Returns:
        A pair which contains a list of node errors and a list of skipped nodes.
    """
    upstream_by_node = dag.graph

    errors: t.List[NodeExecutionFailedError[H]] = []
    skipped: t.List[H] = []
    # Nodes whose dependents must not run: `fn` either failed on them or they were skipped.
    blocked: t.Set[H] = set()

    for node in dag.sorted:
        if blocked.isdisjoint(upstream_by_node[node]):
            try:
                fn(node)
            except Exception as ex:
                error = NodeExecutionFailedError(node)
                error.__cause__ = ex
                if raise_on_error:
                    raise error
                errors.append(error)
                blocked.add(node)
        else:
            skipped.append(node)
            blocked.add(node)

    return errors, skipped
def concurrent_apply_to_values(values: Sequence[~A], fn: Callable[[~A], ~R], tasks_num: int) -> List[~R]:
def concurrent_apply_to_values(
    values: t.Sequence[A],
    fn: t.Callable[[A], R],
    tasks_num: int,
) -> t.List[R]:
    """Applies a function to the given collection of values concurrently.

    Results are returned in the same order as `values`. If `fn` raises for any value, the
    exception of the first failing value (in input order) is re-raised once all submitted
    calls have finished.

    Args:
        values: Target values.
        fn: The function that will be applied concurrently to each value.
        tasks_num: The number of concurrent tasks.

    Returns:
        A list of results.
    """
    if tasks_num == 1:
        return [fn(value) for value in values]

    # The futures returned by `submit` capture *any* exception raised by `fn`, including
    # `BaseException` subclasses. Hand-rolled futures that only caught `Exception` could be
    # left unresolved, making `result()` below block forever.
    with ThreadPoolExecutor(max_workers=tasks_num) as pool:
        futures = [pool.submit(fn, value) for value in values]

    return [future.result() for future in futures]

Applies a function to the given collection of values concurrently.

Arguments:
  • values: Target values.
  • fn: The function that will be applied concurrently to each value.
  • tasks_num: The number of concurrent tasks.
Returns:

A list of results.