Edit on GitHub

DAG

A DAG, or directed acyclic graph, is a graph whose edges are directed and which contains no directed cycles: following edges in their direction never leads back to the node you started from. SQLMesh uses a DAG to keep track of a project's models. This allows SQLMesh to easily determine a model's lineage and to identify upstream and downstream dependencies.

  1"""
  2# DAG
  3
  4A DAG, or directed acyclic graph, is a graph where the edges are directional and there are no cycles with
  5all the edges pointing in the same direction. SQLMesh uses a DAG to keep track of a project's models. This
  6allows SQLMesh to easily determine a model's lineage and to identify upstream and downstream dependencies.
  7"""
  8
  9from __future__ import annotations
 10
 11import typing as t
 12
 13from sqlmesh.utils.errors import SQLMeshError
 14
 15T = t.TypeVar("T", bound=t.Hashable)
 16
 17
 18class DAG(t.Generic[T]):
 19    def __init__(self, graph: t.Optional[t.Dict[T, t.Set[T]]] = None):
 20        self._dag: t.Dict[T, t.Set[T]] = {}
 21        self._sorted: t.Optional[t.List[T]] = None
 22
 23        for node, dependencies in (graph or {}).items():
 24            self.add(node, dependencies)
 25
 26    def add(self, node: T, dependencies: t.Optional[t.Iterable[T]] = None) -> None:
 27        """Add a node to the graph with an optional upstream dependency.
 28
 29        Args:
 30            node: The node to add.
 31            dependencies: Optional dependencies to add to the node.
 32        """
 33        self._sorted = None
 34        if node not in self._dag:
 35            self._dag[node] = set()
 36        if dependencies:
 37            self._dag[node].update(dependencies)
 38            for d in dependencies:
 39                self.add(d)
 40
 41    @property
 42    def reversed(self) -> DAG[T]:
 43        """Returns a copy of this DAG with all its edges reversed."""
 44        result = DAG[T]()
 45
 46        for node, deps in self._dag.items():
 47            result.add(node)
 48            for dep in deps:
 49                result.add(dep, [node])
 50
 51        return result
 52
 53    def subdag(self, *nodes: T) -> DAG[T]:
 54        """Create a new subdag given node(s).
 55
 56        Args:
 57            nodes: The nodes of the new subdag.
 58
 59        Returns:
 60            A new dag consisting of the specified nodes and upstream.
 61        """
 62        queue = set(nodes)
 63        graph = {}
 64
 65        while queue:
 66            node = queue.pop()
 67            deps = self._dag.get(node, set())
 68            graph[node] = deps
 69            queue.update(deps)
 70
 71        return DAG(graph)
 72
 73    def prune(self, *nodes: T) -> DAG[T]:
 74        """Create a dag keeping only the included nodes.
 75
 76        Args:
 77            nodes: The nodes of the new pruned dag.
 78
 79        Returns:
 80            A new dag consisting of the specified nodes.
 81        """
 82        graph = {}
 83
 84        for node, deps in self._dag.items():
 85            if node in nodes:
 86                graph[node] = {dep for dep in deps if dep in nodes}
 87
 88        return DAG(graph)
 89
 90    def upstream(self, node: T) -> t.List[T]:
 91        """Returns all upstream dependencies in topologically sorted order."""
 92        return self.subdag(node).sorted[:-1]
 93
 94    @property
 95    def roots(self) -> t.Set[T]:
 96        """Returns all nodes in the graph without any upstream dependencies."""
 97        return {node for node, deps in self._dag.items() if not deps}
 98
 99    @property
100    def graph(self) -> t.Dict[T, t.Set[T]]:
101        graph = {}
102        for node, deps in self._dag.items():
103            graph[node] = deps.copy()
104        return graph
105
106    @property
107    def sorted(self) -> t.List[T]:
108        """Returns a list of nodes sorted in topological order."""
109        if self._sorted is None:
110            self._sorted = []
111
112            unprocessed_nodes = self.graph
113
114            last_processed_nodes: t.Set[T] = set()
115            cycle_candidates: t.Collection = unprocessed_nodes
116
117            while unprocessed_nodes:
118                next_nodes = {node for node, deps in unprocessed_nodes.items() if not deps}
119
120                if not next_nodes:
121                    # Sort cycle candidates to make the order deterministic
122                    cycle_candidates_msg = (
123                        "\nPossible candidates to check for circular references: "
124                        + ", ".join(str(node) for node in sorted(cycle_candidates))
125                    )
126
127                    if last_processed_nodes:
128                        last_processed_msg = "\nLast nodes added to the DAG: " + ", ".join(
129                            str(node) for node in last_processed_nodes
130                        )
131                    else:
132                        last_processed_msg = ""
133
134                    raise SQLMeshError(
135                        "Detected a cycle in the DAG. "
136                        "Please make sure there are no circular references between nodes."
137                        f"{last_processed_msg}{cycle_candidates_msg}"
138                    )
139
140                for node in next_nodes:
141                    unprocessed_nodes.pop(node)
142
143                nodes_with_unaffected_deps: t.Set[T] = set()
144                for node, deps in unprocessed_nodes.items():
145                    deps_before_subtraction = deps
146
147                    deps -= next_nodes
148                    if deps_before_subtraction == deps:
149                        nodes_with_unaffected_deps.add(node)
150
151                cycle_candidates = nodes_with_unaffected_deps or unprocessed_nodes
152
153                # Sort to make the order deterministic
154                # TODO: Make protocol that makes the type var both hashable and sortable once we are on Python 3.8+
155                last_processed_nodes = sorted(next_nodes)  # type: ignore
156                self._sorted.extend(last_processed_nodes)
157
158        return self._sorted
159
160    def downstream(self, node: T) -> t.List[T]:
161        """Get all nodes that have the input node as an upstream dependency.
162
163        Args:
164            node: The ancestor node.
165
166        Returns:
167            A list of descendant nodes sorted in topological order.
168        """
169        sorted_nodes = self.sorted
170        try:
171            node_index = sorted_nodes.index(node)
172        except ValueError:
173            return []
174
175        def visit() -> t.Iterator[T]:
176            """Visit topologically sorted nodes after input node and yield downstream dependants."""
177            downstream = {node}
178            for current_node in sorted_nodes[node_index + 1 :]:
179                upstream = self._dag.get(current_node, set())
180                if not upstream.isdisjoint(downstream):
181                    downstream.add(current_node)
182                    yield current_node
183
184        return list(visit())
185
186    def lineage(self, node: T) -> DAG[T]:
187        """Get a dag of the node and its upstream dependencies and downstream dependents.
188
189        Args:
190            node: The node used to determine lineage.
191
192        Returns:
193            A new dag consisting of the dependent and descendant nodes.
194        """
195        return self.subdag(node, *self.downstream(node))
196
197    def __contains__(self, item: T) -> bool:
198        return item in self.graph
199
200    def __iter__(self) -> t.Iterator[T]:
201        for node in self.sorted:
202            yield node
class DAG(typing.Generic[~T]):
 19class DAG(t.Generic[T]):
 20    def __init__(self, graph: t.Optional[t.Dict[T, t.Set[T]]] = None):
 21        self._dag: t.Dict[T, t.Set[T]] = {}
 22        self._sorted: t.Optional[t.List[T]] = None
 23
 24        for node, dependencies in (graph or {}).items():
 25            self.add(node, dependencies)
 26
 27    def add(self, node: T, dependencies: t.Optional[t.Iterable[T]] = None) -> None:
 28        """Add a node to the graph with an optional upstream dependency.
 29
 30        Args:
 31            node: The node to add.
 32            dependencies: Optional dependencies to add to the node.
 33        """
 34        self._sorted = None
 35        if node not in self._dag:
 36            self._dag[node] = set()
 37        if dependencies:
 38            self._dag[node].update(dependencies)
 39            for d in dependencies:
 40                self.add(d)
 41
 42    @property
 43    def reversed(self) -> DAG[T]:
 44        """Returns a copy of this DAG with all its edges reversed."""
 45        result = DAG[T]()
 46
 47        for node, deps in self._dag.items():
 48            result.add(node)
 49            for dep in deps:
 50                result.add(dep, [node])
 51
 52        return result
 53
 54    def subdag(self, *nodes: T) -> DAG[T]:
 55        """Create a new subdag given node(s).
 56
 57        Args:
 58            nodes: The nodes of the new subdag.
 59
 60        Returns:
 61            A new dag consisting of the specified nodes and upstream.
 62        """
 63        queue = set(nodes)
 64        graph = {}
 65
 66        while queue:
 67            node = queue.pop()
 68            deps = self._dag.get(node, set())
 69            graph[node] = deps
 70            queue.update(deps)
 71
 72        return DAG(graph)
 73
 74    def prune(self, *nodes: T) -> DAG[T]:
 75        """Create a dag keeping only the included nodes.
 76
 77        Args:
 78            nodes: The nodes of the new pruned dag.
 79
 80        Returns:
 81            A new dag consisting of the specified nodes.
 82        """
 83        graph = {}
 84
 85        for node, deps in self._dag.items():
 86            if node in nodes:
 87                graph[node] = {dep for dep in deps if dep in nodes}
 88
 89        return DAG(graph)
 90
 91    def upstream(self, node: T) -> t.List[T]:
 92        """Returns all upstream dependencies in topologically sorted order."""
 93        return self.subdag(node).sorted[:-1]
 94
 95    @property
 96    def roots(self) -> t.Set[T]:
 97        """Returns all nodes in the graph without any upstream dependencies."""
 98        return {node for node, deps in self._dag.items() if not deps}
 99
100    @property
101    def graph(self) -> t.Dict[T, t.Set[T]]:
102        graph = {}
103        for node, deps in self._dag.items():
104            graph[node] = deps.copy()
105        return graph
106
107    @property
108    def sorted(self) -> t.List[T]:
109        """Returns a list of nodes sorted in topological order."""
110        if self._sorted is None:
111            self._sorted = []
112
113            unprocessed_nodes = self.graph
114
115            last_processed_nodes: t.Set[T] = set()
116            cycle_candidates: t.Collection = unprocessed_nodes
117
118            while unprocessed_nodes:
119                next_nodes = {node for node, deps in unprocessed_nodes.items() if not deps}
120
121                if not next_nodes:
122                    # Sort cycle candidates to make the order deterministic
123                    cycle_candidates_msg = (
124                        "\nPossible candidates to check for circular references: "
125                        + ", ".join(str(node) for node in sorted(cycle_candidates))
126                    )
127
128                    if last_processed_nodes:
129                        last_processed_msg = "\nLast nodes added to the DAG: " + ", ".join(
130                            str(node) for node in last_processed_nodes
131                        )
132                    else:
133                        last_processed_msg = ""
134
135                    raise SQLMeshError(
136                        "Detected a cycle in the DAG. "
137                        "Please make sure there are no circular references between nodes."
138                        f"{last_processed_msg}{cycle_candidates_msg}"
139                    )
140
141                for node in next_nodes:
142                    unprocessed_nodes.pop(node)
143
144                nodes_with_unaffected_deps: t.Set[T] = set()
145                for node, deps in unprocessed_nodes.items():
146                    deps_before_subtraction = deps
147
148                    deps -= next_nodes
149                    if deps_before_subtraction == deps:
150                        nodes_with_unaffected_deps.add(node)
151
152                cycle_candidates = nodes_with_unaffected_deps or unprocessed_nodes
153
154                # Sort to make the order deterministic
155                # TODO: Make protocol that makes the type var both hashable and sortable once we are on Python 3.8+
156                last_processed_nodes = sorted(next_nodes)  # type: ignore
157                self._sorted.extend(last_processed_nodes)
158
159        return self._sorted
160
161    def downstream(self, node: T) -> t.List[T]:
162        """Get all nodes that have the input node as an upstream dependency.
163
164        Args:
165            node: The ancestor node.
166
167        Returns:
168            A list of descendant nodes sorted in topological order.
169        """
170        sorted_nodes = self.sorted
171        try:
172            node_index = sorted_nodes.index(node)
173        except ValueError:
174            return []
175
176        def visit() -> t.Iterator[T]:
177            """Visit topologically sorted nodes after input node and yield downstream dependants."""
178            downstream = {node}
179            for current_node in sorted_nodes[node_index + 1 :]:
180                upstream = self._dag.get(current_node, set())
181                if not upstream.isdisjoint(downstream):
182                    downstream.add(current_node)
183                    yield current_node
184
185        return list(visit())
186
187    def lineage(self, node: T) -> DAG[T]:
188        """Get a dag of the node and its upstream dependencies and downstream dependents.
189
190        Args:
191            node: The node used to determine lineage.
192
193        Returns:
194            A new dag consisting of the dependent and descendant nodes.
195        """
196        return self.subdag(node, *self.downstream(node))
197
198    def __contains__(self, item: T) -> bool:
199        return item in self.graph
200
201    def __iter__(self) -> t.Iterator[T]:
202        for node in self.sorted:
203            yield node

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows::

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default

DAG(graph: Union[Dict[~T, Set[~T]], NoneType] = None)
20    def __init__(self, graph: t.Optional[t.Dict[T, t.Set[T]]] = None):
21        self._dag: t.Dict[T, t.Set[T]] = {}
22        self._sorted: t.Optional[t.List[T]] = None
23
24        for node, dependencies in (graph or {}).items():
25            self.add(node, dependencies)
def add( self, node: ~T, dependencies: Union[Iterable[~T], NoneType] = None) -> None:
27    def add(self, node: T, dependencies: t.Optional[t.Iterable[T]] = None) -> None:
28        """Add a node to the graph with an optional upstream dependency.
29
30        Args:
31            node: The node to add.
32            dependencies: Optional dependencies to add to the node.
33        """
34        self._sorted = None
35        if node not in self._dag:
36            self._dag[node] = set()
37        if dependencies:
38            self._dag[node].update(dependencies)
39            for d in dependencies:
40                self.add(d)

Add a node to the graph with an optional upstream dependency.

Arguments:
  • node: The node to add.
  • dependencies: Optional dependencies to add to the node.
reversed: sqlmesh.utils.dag.DAG[~T]

Returns a copy of this DAG with all its edges reversed.

def subdag(self, *nodes: ~T) -> sqlmesh.utils.dag.DAG[~T]:
54    def subdag(self, *nodes: T) -> DAG[T]:
55        """Create a new subdag given node(s).
56
57        Args:
58            nodes: The nodes of the new subdag.
59
60        Returns:
61            A new dag consisting of the specified nodes and upstream.
62        """
63        queue = set(nodes)
64        graph = {}
65
66        while queue:
67            node = queue.pop()
68            deps = self._dag.get(node, set())
69            graph[node] = deps
70            queue.update(deps)
71
72        return DAG(graph)

Create a new subdag given node(s).

Arguments:
  • nodes: The nodes of the new subdag.
Returns:

A new dag consisting of the specified nodes and all of their upstream dependencies.

def prune(self, *nodes: ~T) -> sqlmesh.utils.dag.DAG[~T]:
74    def prune(self, *nodes: T) -> DAG[T]:
75        """Create a dag keeping only the included nodes.
76
77        Args:
78            nodes: The nodes of the new pruned dag.
79
80        Returns:
81            A new dag consisting of the specified nodes.
82        """
83        graph = {}
84
85        for node, deps in self._dag.items():
86            if node in nodes:
87                graph[node] = {dep for dep in deps if dep in nodes}
88
89        return DAG(graph)

Create a dag keeping only the included nodes.

Arguments:
  • nodes: The nodes of the new pruned dag.
Returns:

A new dag consisting of the specified nodes.

def upstream(self, node: ~T) -> List[~T]:
91    def upstream(self, node: T) -> t.List[T]:
92        """Returns all upstream dependencies in topologically sorted order."""
93        return self.subdag(node).sorted[:-1]

Returns all upstream dependencies in topologically sorted order.

roots: Set[~T]

Returns all nodes in the graph without any upstream dependencies.

sorted: List[~T]

Returns a list of nodes sorted in topological order.

def downstream(self, node: ~T) -> List[~T]:
161    def downstream(self, node: T) -> t.List[T]:
162        """Get all nodes that have the input node as an upstream dependency.
163
164        Args:
165            node: The ancestor node.
166
167        Returns:
168            A list of descendant nodes sorted in topological order.
169        """
170        sorted_nodes = self.sorted
171        try:
172            node_index = sorted_nodes.index(node)
173        except ValueError:
174            return []
175
176        def visit() -> t.Iterator[T]:
177            """Visit topologically sorted nodes after input node and yield downstream dependants."""
178            downstream = {node}
179            for current_node in sorted_nodes[node_index + 1 :]:
180                upstream = self._dag.get(current_node, set())
181                if not upstream.isdisjoint(downstream):
182                    downstream.add(current_node)
183                    yield current_node
184
185        return list(visit())

Get all nodes that have the input node as an upstream dependency.

Arguments:
  • node: The ancestor node.
Returns:

A list of descendant nodes sorted in topological order.

def lineage(self, node: ~T) -> sqlmesh.utils.dag.DAG[~T]:
187    def lineage(self, node: T) -> DAG[T]:
188        """Get a dag of the node and its upstream dependencies and downstream dependents.
189
190        Args:
191            node: The node used to determine lineage.
192
193        Returns:
194            A new dag consisting of the dependent and descendant nodes.
195        """
196        return self.subdag(node, *self.downstream(node))

Get a dag of the node and its upstream dependencies and downstream dependents.

Arguments:
  • node: The node used to determine lineage.
Returns:

A new dag consisting of the node, its downstream dependents, and the upstream dependencies of all of them.