Edit on GitHub

sqlmesh.utils.concurrency

  1import typing as t
  2from concurrent.futures import Executor, Future, ThreadPoolExecutor
  3from threading import Lock
  4
  5from sqlmesh.core.snapshot import SnapshotId, SnapshotInfoLike
  6from sqlmesh.utils.dag import DAG
  7from sqlmesh.utils.errors import ConfigError, SQLMeshError
  8
  9H = t.TypeVar("H", bound=t.Hashable)  # Node type of a DAG; must be hashable.
 10S = t.TypeVar("S", bound=SnapshotInfoLike)  # Snapshot-like type accepted by concurrent_apply_to_snapshots.
 11A = t.TypeVar("A")  # Input value type of concurrent_apply_to_values.
 12R = t.TypeVar("R")  # Result type of the function passed to concurrent_apply_to_values.
 13
 14
 15class NodeExecutionFailedError(t.Generic[H], SQLMeshError):
 16    def __init__(self, node: H):
 17        self.node = node
 18        super().__init__(f"Execution failed for node {node}")
 19
 20
 21class ConcurrentDAGExecutor(t.Generic[H]):
 22    """Concurrently traverses the given DAG in topological order while applying a function to each node.
 23
 24    If `raise_on_error` is set to False maintains a state of execution errors as well as of skipped nodes.
 25
 26    Args:
 27        dag: The target DAG.
 28        fn: The function that will be applied concurrently to each snapshot.
 29        tasks_num: The number of concurrent tasks.
 30        raise_on_error: If set to True raises an exception on a first encountered error,
 31            otherwises returns a tuple which contains a list of failed nodes and a list of
 32            skipped nodes.
 33    """
 34
 35    def __init__(
 36        self,
 37        dag: DAG[H],
 38        fn: t.Callable[[H], None],
 39        tasks_num: int,
 40        raise_on_error: bool,
 41    ):
 42        self.dag = dag
 43        self.fn = fn
 44        self.tasks_num = tasks_num
 45        self.raise_on_error = raise_on_error
 46
 47        self._init_state()
 48
 49    def run(self) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
 50        """Runs the executor.
 51
 52        Raises:
 53            NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any snapshot.
 54
 55        Returns:
 56            A pair which contains a list of node errors and a list of skipped nodes.
 57        """
 58        if self._finished_future.done():
 59            self._init_state()
 60
 61        with ThreadPoolExecutor(max_workers=self.tasks_num) as pool:
 62            with self._unprocessed_nodes_lock:
 63                self._submit_next_nodes(pool)
 64            self._finished_future.result()
 65        return self._node_errors, self._skipped_nodes
 66
 67    def _process_node(self, node: H, executor: Executor) -> None:
 68        try:
 69            self.fn(node)
 70
 71            with self._unprocessed_nodes_lock:
 72                self._unprocessed_nodes_num -= 1
 73                self._submit_next_nodes(executor, node)
 74        except Exception as ex:
 75            error = NodeExecutionFailedError(node)
 76            error.__cause__ = ex
 77
 78            if self.raise_on_error:
 79                self._finished_future.set_exception(error)
 80                return
 81
 82            with self._unprocessed_nodes_lock:
 83                self._unprocessed_nodes_num -= 1
 84                self._node_errors.append(error)
 85                self._skip_next_nodes(node)
 86
 87    def _submit_next_nodes(self, executor: Executor, processed_node: t.Optional[H] = None) -> None:
 88        if not self._unprocessed_nodes_num:
 89            self._finished_future.set_result(None)
 90            return
 91
 92        submitted_nodes = []
 93        for next_node, deps in self._unprocessed_nodes.items():
 94            if processed_node:
 95                deps.discard(processed_node)
 96            if not deps:
 97                submitted_nodes.append(next_node)
 98
 99        for submitted_node in submitted_nodes:
100            self._unprocessed_nodes.pop(submitted_node)
101            executor.submit(self._process_node, submitted_node, executor)
102
103    def _skip_next_nodes(self, parent: H) -> None:
104        if not self._unprocessed_nodes_num:
105            self._finished_future.set_result(None)
106            return
107
108        skipped_nodes = [node for node, deps in self._unprocessed_nodes.items() if parent in deps]
109
110        self._skipped_nodes.extend(skipped_nodes)
111
112        for skipped_node in skipped_nodes:
113            self._unprocessed_nodes_num -= 1
114            self._unprocessed_nodes.pop(skipped_node)
115
116        for skipped_node in skipped_nodes:
117            self._skip_next_nodes(skipped_node)
118
119    def _init_state(self) -> None:
120        self._unprocessed_nodes = self.dag.graph
121        self._unprocessed_nodes_num = len(self._unprocessed_nodes)
122        self._unprocessed_nodes_lock = Lock()
123        self._finished_future = Future()  # type: ignore
124
125        self._node_errors: t.List[NodeExecutionFailedError[H]] = []
126        self._skipped_nodes: t.List[H] = []
127
128
def concurrent_apply_to_snapshots(
    snapshots: t.Iterable[S],
    fn: t.Callable[[S], None],
    tasks_num: int,
    reverse_order: bool = False,
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[SnapshotId]], t.List[SnapshotId]]:
    """Applies a function to the given collection of snapshots concurrently while
    preserving the topological order between snapshots.

    Args:
        snapshots: Target snapshots. Parents that are not part of this collection
            are ignored when the execution order is determined.
        fn: The function that will be applied concurrently to each snapshot.
        tasks_num: The number of concurrent tasks.
        reverse_order: Whether the order should be reversed. Default: False.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot.

    Returns:
        A pair which contains a list of errors and a list of skipped snapshot IDs.
    """
    # `snapshots` is consumed exactly once here. It may be a one-shot iterator, so the
    # DAG below is built from this mapping rather than by iterating `snapshots` again.
    snapshots_by_id = {s.snapshot_id: s for s in snapshots}

    dag: DAG[SnapshotId] = DAG[SnapshotId]()
    for snapshot_id, snapshot in snapshots_by_id.items():
        dag.add(
            snapshot_id,
            [p_sid for p_sid in snapshot.parents if p_sid in snapshots_by_id],
        )

    return concurrent_apply_to_dag(
        dag.reversed if reverse_order else dag,
        lambda s_id: fn(snapshots_by_id[s_id]),
        tasks_num,
        raise_on_error=raise_on_error,
    )
169
170
def concurrent_apply_to_dag(
    dag: DAG[H],
    fn: t.Callable[[H], None],
    tasks_num: int,
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Applies a function to every node of the given DAG while preserving the
    topological order between nodes.

    A single task is handled sequentially; any other valid number of tasks is handled
    by a `ConcurrentDAGExecutor`.

    Args:
        dag: The target DAG.
        fn: The function that will be applied to each node.
        tasks_num: The number of concurrent tasks.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        ConfigError if `tasks_num` is not a positive number.
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any node.

    Returns:
        A pair which contains a list of node errors and a list of skipped nodes.
    """
    if tasks_num <= 0:
        raise ConfigError(f"Invalid number of concurrent tasks {tasks_num}")

    if tasks_num != 1:
        executor = ConcurrentDAGExecutor(dag, fn, tasks_num, raise_on_error)
        return executor.run()

    # Exactly one task: run sequentially instead of going through the concurrent executor.
    return sequential_apply_to_dag(dag, fn, raise_on_error)
206
207
def sequential_apply_to_dag(
    dag: DAG[H],
    fn: t.Callable[[H], None],
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Applies a function to every node of the given DAG one node at a time,
    following the order of `dag.sorted`.

    A node is skipped (and `fn` is not called for it) when any of its dependencies
    has failed or has been skipped itself.

    Args:
        dag: The target DAG.
        fn: The function that will be applied to each node.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any node.

    Returns:
        A pair which contains a list of node errors and a list of skipped nodes.
    """
    deps_by_node = dag.graph

    errors: t.List[NodeExecutionFailedError[H]] = []
    skipped: t.List[H] = []

    # Nodes that either failed or were skipped; their dependents must not be executed.
    blocked: t.Set[H] = set()

    for node in dag.sorted:
        if any(dep in blocked for dep in deps_by_node[node]):
            blocked.add(node)
            skipped.append(node)
            continue

        try:
            fn(node)
        except Exception as ex:
            error = NodeExecutionFailedError(node)
            if raise_on_error:
                raise error from ex

            error.__cause__ = ex
            blocked.add(node)
            errors.append(error)

    return errors, skipped
239
240
def concurrent_apply_to_values(
    values: t.Sequence[A],
    fn: t.Callable[[A], R],
    tasks_num: int,
) -> t.List[R]:
    """Applies a function to the given collection of values concurrently.

    Args:
        values: Target values.
        fn: The function that will be applied concurrently to each value.
        tasks_num: The number of concurrent tasks. With a single task the function
            is applied sequentially in the calling thread.

    Raises:
        Whatever `fn` raises. In the concurrent case every value is still processed
        first, and then the error of the earliest failing value (in input order)
        is re-raised.

    Returns:
        A list of results, in the same order as `values`.
    """
    if tasks_num == 1:
        return [fn(value) for value in values]

    # Use the futures returned by the pool. The pool completes them for any outcome,
    # including exceptions that don't derive from `Exception`, so collecting the
    # results below can never block on a future that is never resolved.
    with ThreadPoolExecutor(max_workers=tasks_num) as pool:
        futures: t.List[Future] = [pool.submit(fn, value) for value in values]

    # Leaving the `with` block waits for all submitted tasks, so every future is done here.
    return [future.result() for future in futures]
class NodeExecutionFailedError(typing.Generic[~H], sqlmesh.utils.errors.SQLMeshError):
class NodeExecutionFailedError(t.Generic[H], SQLMeshError):
    """Signals that applying the function to a particular node has failed.

    Args:
        node: The node for which the execution failed. It is kept on the
            instance as the `node` attribute.
    """

    def __init__(self, node: H):
        self.node = node
        message = f"Execution failed for node {node}"
        super().__init__(message)

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
            # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

NodeExecutionFailedError(node: ~H)
17    def __init__(self, node: H):
18        self.node = node
19        super().__init__(f"Execution failed for node {node}")
Inherited Members
builtins.BaseException
with_traceback
class ConcurrentDAGExecutor(typing.Generic[~H]):
 22class ConcurrentDAGExecutor(t.Generic[H]):
 23    """Concurrently traverses the given DAG in topological order while applying a function to each node.
 24
 25    If `raise_on_error` is set to False maintains a state of execution errors as well as of skipped nodes.
 26
 27    Args:
 28        dag: The target DAG.
 29        fn: The function that will be applied concurrently to each snapshot.
 30        tasks_num: The number of concurrent tasks.
 31        raise_on_error: If set to True raises an exception on a first encountered error,
 32            otherwises returns a tuple which contains a list of failed nodes and a list of
 33            skipped nodes.
 34    """
 35
 36    def __init__(
 37        self,
 38        dag: DAG[H],
 39        fn: t.Callable[[H], None],
 40        tasks_num: int,
 41        raise_on_error: bool,
 42    ):
 43        self.dag = dag
 44        self.fn = fn
 45        self.tasks_num = tasks_num
 46        self.raise_on_error = raise_on_error
 47
 48        self._init_state()
 49
 50    def run(self) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
 51        """Runs the executor.
 52
 53        Raises:
 54            NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any snapshot.
 55
 56        Returns:
 57            A pair which contains a list of node errors and a list of skipped nodes.
 58        """
 59        if self._finished_future.done():
 60            self._init_state()
 61
 62        with ThreadPoolExecutor(max_workers=self.tasks_num) as pool:
 63            with self._unprocessed_nodes_lock:
 64                self._submit_next_nodes(pool)
 65            self._finished_future.result()
 66        return self._node_errors, self._skipped_nodes
 67
 68    def _process_node(self, node: H, executor: Executor) -> None:
 69        try:
 70            self.fn(node)
 71
 72            with self._unprocessed_nodes_lock:
 73                self._unprocessed_nodes_num -= 1
 74                self._submit_next_nodes(executor, node)
 75        except Exception as ex:
 76            error = NodeExecutionFailedError(node)
 77            error.__cause__ = ex
 78
 79            if self.raise_on_error:
 80                self._finished_future.set_exception(error)
 81                return
 82
 83            with self._unprocessed_nodes_lock:
 84                self._unprocessed_nodes_num -= 1
 85                self._node_errors.append(error)
 86                self._skip_next_nodes(node)
 87
 88    def _submit_next_nodes(self, executor: Executor, processed_node: t.Optional[H] = None) -> None:
 89        if not self._unprocessed_nodes_num:
 90            self._finished_future.set_result(None)
 91            return
 92
 93        submitted_nodes = []
 94        for next_node, deps in self._unprocessed_nodes.items():
 95            if processed_node:
 96                deps.discard(processed_node)
 97            if not deps:
 98                submitted_nodes.append(next_node)
 99
100        for submitted_node in submitted_nodes:
101            self._unprocessed_nodes.pop(submitted_node)
102            executor.submit(self._process_node, submitted_node, executor)
103
104    def _skip_next_nodes(self, parent: H) -> None:
105        if not self._unprocessed_nodes_num:
106            self._finished_future.set_result(None)
107            return
108
109        skipped_nodes = [node for node, deps in self._unprocessed_nodes.items() if parent in deps]
110
111        self._skipped_nodes.extend(skipped_nodes)
112
113        for skipped_node in skipped_nodes:
114            self._unprocessed_nodes_num -= 1
115            self._unprocessed_nodes.pop(skipped_node)
116
117        for skipped_node in skipped_nodes:
118            self._skip_next_nodes(skipped_node)
119
120    def _init_state(self) -> None:
121        self._unprocessed_nodes = self.dag.graph
122        self._unprocessed_nodes_num = len(self._unprocessed_nodes)
123        self._unprocessed_nodes_lock = Lock()
124        self._finished_future = Future()  # type: ignore
125
126        self._node_errors: t.List[NodeExecutionFailedError[H]] = []
127        self._skipped_nodes: t.List[H] = []

Concurrently traverses the given DAG in topological order while applying a function to each node.

If raise_on_error is set to False, the executor keeps track of execution errors as well as of skipped nodes.

Arguments:
  • dag: The target DAG.
  • fn: The function that will be applied concurrently to each snapshot.
  • tasks_num: The number of concurrent tasks.
  • raise_on_error: If set to True raises an exception on the first encountered error, otherwise returns a tuple which contains a list of failed nodes and a list of skipped nodes.
ConcurrentDAGExecutor( dag: sqlmesh.utils.dag.DAG[~H], fn: Callable[[~H], NoneType], tasks_num: int, raise_on_error: bool)
36    def __init__(
37        self,
38        dag: DAG[H],
39        fn: t.Callable[[H], None],
40        tasks_num: int,
41        raise_on_error: bool,
42    ):
43        self.dag = dag
44        self.fn = fn
45        self.tasks_num = tasks_num
46        self.raise_on_error = raise_on_error
47
48        self._init_state()
def run( self) -> Tuple[List[sqlmesh.utils.concurrency.NodeExecutionFailedError[~H]], List[~H]]:
50    def run(self) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
51        """Runs the executor.
52
53        Raises:
54            NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any snapshot.
55
56        Returns:
57            A pair which contains a list of node errors and a list of skipped nodes.
58        """
59        if self._finished_future.done():
60            self._init_state()
61
62        with ThreadPoolExecutor(max_workers=self.tasks_num) as pool:
63            with self._unprocessed_nodes_lock:
64                self._submit_next_nodes(pool)
65            self._finished_future.result()
66        return self._node_errors, self._skipped_nodes

Runs the executor.

Raises:
  • NodeExecutionFailedError if raise_on_error was set to True and execution fails for any snapshot.
Returns:

A pair which contains a list of node errors and a list of skipped nodes.

def concurrent_apply_to_snapshots( snapshots: Iterable[~S], fn: Callable[[~S], NoneType], tasks_num: int, reverse_order: bool = False, raise_on_error: bool = True) -> Tuple[List[sqlmesh.utils.concurrency.NodeExecutionFailedError[sqlmesh.core.snapshot.definition.SnapshotId]], List[sqlmesh.core.snapshot.definition.SnapshotId]]:
def concurrent_apply_to_snapshots(
    snapshots: t.Iterable[S],
    fn: t.Callable[[S], None],
    tasks_num: int,
    reverse_order: bool = False,
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[SnapshotId]], t.List[SnapshotId]]:
    """Applies a function to the given collection of snapshots concurrently while
    preserving the topological order between snapshots.

    Args:
        snapshots: Target snapshots. Parents that are not part of this collection
            are ignored when the execution order is determined.
        fn: The function that will be applied concurrently to each snapshot.
        tasks_num: The number of concurrent tasks.
        reverse_order: Whether the order should be reversed. Default: False.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot.

    Returns:
        A pair which contains a list of errors and a list of skipped snapshot IDs.
    """
    # `snapshots` is consumed exactly once here. It may be a one-shot iterator, so the
    # DAG below is built from this mapping rather than by iterating `snapshots` again.
    snapshots_by_id = {s.snapshot_id: s for s in snapshots}

    dag: DAG[SnapshotId] = DAG[SnapshotId]()
    for snapshot_id, snapshot in snapshots_by_id.items():
        dag.add(
            snapshot_id,
            [p_sid for p_sid in snapshot.parents if p_sid in snapshots_by_id],
        )

    return concurrent_apply_to_dag(
        dag.reversed if reverse_order else dag,
        lambda s_id: fn(snapshots_by_id[s_id]),
        tasks_num,
        raise_on_error=raise_on_error,
    )

Applies a function to the given collection of snapshots concurrently while preserving the topological order between snapshots.

Arguments:
  • snapshots: Target snapshots.
  • fn: The function that will be applied concurrently to each snapshot.
  • tasks_num: The number of concurrent tasks.
  • reverse_order: Whether the order should be reversed. Default: False.
  • raise_on_error: If set to True raises an exception on the first encountered error, otherwise returns a tuple which contains a list of failed nodes and a list of skipped nodes.
Raises:
  • NodeExecutionFailedError if raise_on_error is set to True and execution fails for any snapshot.
Returns:

A pair which contains a list of errors and a list of skipped snapshot IDs.

def concurrent_apply_to_dag( dag: sqlmesh.utils.dag.DAG[~H], fn: Callable[[~H], NoneType], tasks_num: int, raise_on_error: bool = True) -> Tuple[List[sqlmesh.utils.concurrency.NodeExecutionFailedError[~H]], List[~H]]:
def concurrent_apply_to_dag(
    dag: DAG[H],
    fn: t.Callable[[H], None],
    tasks_num: int,
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Applies a function to every node of the given DAG while preserving the
    topological order between nodes.

    A single task is handled sequentially; any other valid number of tasks is handled
    by a `ConcurrentDAGExecutor`.

    Args:
        dag: The target DAG.
        fn: The function that will be applied to each node.
        tasks_num: The number of concurrent tasks.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        ConfigError if `tasks_num` is not a positive number.
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any node.

    Returns:
        A pair which contains a list of node errors and a list of skipped nodes.
    """
    if tasks_num <= 0:
        raise ConfigError(f"Invalid number of concurrent tasks {tasks_num}")

    if tasks_num != 1:
        executor = ConcurrentDAGExecutor(dag, fn, tasks_num, raise_on_error)
        return executor.run()

    # Exactly one task: run sequentially instead of going through the concurrent executor.
    return sequential_apply_to_dag(dag, fn, raise_on_error)

Applies a function to the given DAG concurrently while preserving the topological order between snapshots.

Arguments:
  • dag: The target DAG.
  • fn: The function that will be applied concurrently to each snapshot.
  • tasks_num: The number of concurrent tasks.
  • raise_on_error: If set to True raises an exception on the first encountered error, otherwise returns a tuple which contains a list of failed nodes and a list of skipped nodes.
Raises:
  • NodeExecutionFailedError if raise_on_error is set to True and execution fails for any snapshot.
Returns:

A pair which contains a list of node errors and a list of skipped nodes.

def sequential_apply_to_dag( dag: sqlmesh.utils.dag.DAG[~H], fn: Callable[[~H], NoneType], raise_on_error: bool = True) -> Tuple[List[sqlmesh.utils.concurrency.NodeExecutionFailedError[~H]], List[~H]]:
def sequential_apply_to_dag(
    dag: DAG[H],
    fn: t.Callable[[H], None],
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Applies a function to every node of the given DAG one node at a time,
    following the order of `dag.sorted`.

    A node is skipped (and `fn` is not called for it) when any of its dependencies
    has failed or has been skipped itself.

    Args:
        dag: The target DAG.
        fn: The function that will be applied to each node.
        raise_on_error: If set to True raises an exception on the first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any node.

    Returns:
        A pair which contains a list of node errors and a list of skipped nodes.
    """
    deps_by_node = dag.graph

    errors: t.List[NodeExecutionFailedError[H]] = []
    skipped: t.List[H] = []

    # Nodes that either failed or were skipped; their dependents must not be executed.
    blocked: t.Set[H] = set()

    for node in dag.sorted:
        if any(dep in blocked for dep in deps_by_node[node]):
            blocked.add(node)
            skipped.append(node)
            continue

        try:
            fn(node)
        except Exception as ex:
            error = NodeExecutionFailedError(node)
            if raise_on_error:
                raise error from ex

            error.__cause__ = ex
            blocked.add(node)
            errors.append(error)

    return errors, skipped
def concurrent_apply_to_values(values: Sequence[~A], fn: Callable[[~A], ~R], tasks_num: int) -> List[~R]:
def concurrent_apply_to_values(
    values: t.Sequence[A],
    fn: t.Callable[[A], R],
    tasks_num: int,
) -> t.List[R]:
    """Applies a function to the given collection of values concurrently.

    Args:
        values: Target values.
        fn: The function that will be applied concurrently to each value.
        tasks_num: The number of concurrent tasks. With a single task the function
            is applied sequentially in the calling thread.

    Raises:
        Whatever `fn` raises. In the concurrent case every value is still processed
        first, and then the error of the earliest failing value (in input order)
        is re-raised.

    Returns:
        A list of results, in the same order as `values`.
    """
    if tasks_num == 1:
        return [fn(value) for value in values]

    # Use the futures returned by the pool. The pool completes them for any outcome,
    # including exceptions that don't derive from `Exception`, so collecting the
    # results below can never block on a future that is never resolved.
    with ThreadPoolExecutor(max_workers=tasks_num) as pool:
        futures: t.List[Future] = [pool.submit(fn, value) for value in values]

    # Leaving the `with` block waits for all submitted tasks, so every future is done here.
    return [future.result() for future in futures]

Applies a function to the given collection of values concurrently.

Arguments:
  • values: Target values.
  • fn: The function that will be applied concurrently to each value.
  • tasks_num: The number of concurrent tasks.
Returns:

A list of results.