DAG
A DAG, or directed acyclic graph, is a graph whose edges are directed and which contains no cycles: following the edges in their direction never leads back to the node you started from. SQLMesh uses a DAG to keep track of a project's models. This allows SQLMesh to easily determine a model's lineage and to identify upstream and downstream dependencies.
"""
# DAG

A DAG, or directed acyclic graph, is a graph whose edges are directed and which contains no cycles:
following the edges in their direction never leads back to the node you started from. SQLMesh uses a
DAG to keep track of a project's models. This allows SQLMesh to easily determine a model's lineage
and to identify upstream and downstream dependencies.
"""

from __future__ import annotations

import typing as t

from sqlmesh.utils.errors import SQLMeshError

T = t.TypeVar("T", bound=t.Hashable)


class DAG(t.Generic[T]):
    """A directed graph that maps every node to the set of its upstream dependencies."""

    def __init__(self, graph: t.Optional[t.Dict[T, t.Set[T]]] = None):
        """Create a DAG, optionally seeded from an existing mapping.

        Args:
            graph: Optional mapping of each node to the set of nodes it depends on.
        """
        self._dag: t.Dict[T, t.Set[T]] = {}
        # Cached topological order; reset to None whenever the graph is modified.
        self._sorted: t.Optional[t.List[T]] = None

        for node, dependencies in (graph or {}).items():
            self.add(node, dependencies)

    def add(self, node: T, dependencies: t.Optional[t.Iterable[T]] = None) -> None:
        """Add a node to the graph with optional upstream dependencies.

        Every dependency is also registered as a node of the graph.

        Args:
            node: The node to add.
            dependencies: Optional dependencies to add to the node.
        """
        self._sorted = None
        if node not in self._dag:
            self._dag[node] = set()
        if dependencies is None:
            return

        # Materialize once: a one-shot iterable (e.g. a generator) would otherwise be
        # exhausted by `update` and the dependency nodes would never be registered.
        new_dependencies = tuple(dependencies)
        self._dag[node].update(new_dependencies)
        for dependency in new_dependencies:
            self.add(dependency)

    @property
    def reversed(self) -> DAG[T]:
        """Returns a copy of this DAG with all its edges reversed."""
        result = DAG[T]()

        for node, deps in self._dag.items():
            result.add(node)
            for dep in deps:
                result.add(dep, [node])

        return result

    def subdag(self, *nodes: T) -> DAG[T]:
        """Create a new subdag given node(s).

        Args:
            nodes: The nodes of the new subdag.

        Returns:
            A new dag consisting of the specified nodes and all of their upstream dependencies.
        """
        pending = set(nodes)
        graph: t.Dict[T, t.Set[T]] = {}

        while pending:
            node = pending.pop()
            deps = self._dag.get(node, set())
            graph[node] = deps
            # Skip nodes that were already collected; this avoids repeated work and
            # guarantees termination even if the graph accidentally contains a cycle.
            pending.update(dep for dep in deps if dep not in graph)

        return DAG(graph)

    def prune(self, *nodes: T) -> DAG[T]:
        """Create a dag keeping only the included nodes.

        Args:
            nodes: The nodes of the new pruned dag.

        Returns:
            A new dag consisting of the specified nodes and only the edges between them.
        """
        # Build the lookup set once instead of scanning the argument tuple per check.
        keep = set(nodes)
        return DAG({node: deps & keep for node, deps in self._dag.items() if node in keep})

    def upstream(self, node: T) -> t.List[T]:
        """Returns all upstream dependencies in topologically sorted order."""
        # The requested node is the only sink of its subdag, so it is always sorted last.
        return self.subdag(node).sorted[:-1]

    @property
    def roots(self) -> t.Set[T]:
        """Returns all nodes in the graph without any upstream dependencies."""
        return {node for node, deps in self._dag.items() if not deps}

    @property
    def graph(self) -> t.Dict[T, t.Set[T]]:
        """Returns a copy of the node -> dependencies mapping that is safe to mutate."""
        return {node: deps.copy() for node, deps in self._dag.items()}

    @property
    def sorted(self) -> t.List[T]:
        """Returns a list of nodes sorted in topological order.

        Raises:
            SQLMeshError: If the graph contains a cycle.
        """
        if self._sorted is None:
            # Build the order locally so that a cycle error never leaves a partial
            # result behind in the cache.
            ordered: t.List[T] = []

            unprocessed_nodes = self.graph

            last_processed_nodes: t.List[T] = []
            cycle_candidates: t.Collection = unprocessed_nodes

            while unprocessed_nodes:
                next_nodes = {node for node, deps in unprocessed_nodes.items() if not deps}

                if not next_nodes:
                    # Sort cycle candidates to make the order deterministic
                    cycle_candidates_msg = (
                        "\nPossible candidates to check for circular references: "
                        + ", ".join(str(node) for node in sorted(cycle_candidates))
                    )

                    if last_processed_nodes:
                        last_processed_msg = "\nLast nodes added to the DAG: " + ", ".join(
                            str(node) for node in last_processed_nodes
                        )
                    else:
                        last_processed_msg = ""

                    raise SQLMeshError(
                        "Detected a cycle in the DAG. "
                        "Please make sure there are no circular references between nodes."
                        f"{last_processed_msg}{cycle_candidates_msg}"
                    )

                for node in next_nodes:
                    unprocessed_nodes.pop(node)

                nodes_with_unaffected_deps: t.Set[T] = set()
                for node, deps in unprocessed_nodes.items():
                    # Check before subtracting: the subtraction is in-place, so comparing
                    # the set with itself afterwards would always report "unaffected".
                    if deps.isdisjoint(next_nodes):
                        nodes_with_unaffected_deps.add(node)
                    else:
                        deps -= next_nodes

                cycle_candidates = nodes_with_unaffected_deps or unprocessed_nodes

                # Sort to make the order deterministic
                # TODO: Make protocol that makes the type var both hashable and sortable once we are on Python 3.8+
                last_processed_nodes = sorted(next_nodes)  # type: ignore
                ordered.extend(last_processed_nodes)

            self._sorted = ordered

        return self._sorted

    def downstream(self, node: T) -> t.List[T]:
        """Get all nodes that have the input node as an upstream dependency.

        Args:
            node: The ancestor node.

        Returns:
            A list of descendant nodes sorted in topological order.
        """
        sorted_nodes = self.sorted
        try:
            node_index = sorted_nodes.index(node)
        except ValueError:
            return []

        def visit() -> t.Iterator[T]:
            """Visit topologically sorted nodes after input node and yield downstream dependants."""
            downstream = {node}
            for current_node in sorted_nodes[node_index + 1 :]:
                upstream = self._dag.get(current_node, set())
                if not upstream.isdisjoint(downstream):
                    downstream.add(current_node)
                    yield current_node

        return list(visit())

    def lineage(self, node: T) -> DAG[T]:
        """Get a dag of the node and its upstream dependencies and downstream dependents.

        Args:
            node: The node used to determine lineage.

        Returns:
            A new dag consisting of the node, its downstream dependents, and every
            upstream dependency of those nodes.
        """
        return self.subdag(node, *self.downstream(node))

    def __contains__(self, item: T) -> bool:
        # Look up the internal mapping directly; `self.graph` would copy every set.
        return item in self._dag

    def __iter__(self) -> t.Iterator[T]:
        yield from self.sorted
class DAG(t.Generic[T]):
    """A directed graph that maps every node to the set of its upstream dependencies."""

    def __init__(self, graph: t.Optional[t.Dict[T, t.Set[T]]] = None):
        """Create a DAG, optionally seeded from an existing mapping.

        Args:
            graph: Optional mapping of each node to the set of nodes it depends on.
        """
        self._dag: t.Dict[T, t.Set[T]] = {}
        # Cached topological order; reset to None whenever the graph is modified.
        self._sorted: t.Optional[t.List[T]] = None

        for node, dependencies in (graph or {}).items():
            self.add(node, dependencies)

    def add(self, node: T, dependencies: t.Optional[t.Iterable[T]] = None) -> None:
        """Add a node to the graph with optional upstream dependencies.

        Every dependency is also registered as a node of the graph.

        Args:
            node: The node to add.
            dependencies: Optional dependencies to add to the node.
        """
        self._sorted = None
        if node not in self._dag:
            self._dag[node] = set()
        if dependencies is None:
            return

        # Materialize once: a one-shot iterable (e.g. a generator) would otherwise be
        # exhausted by `update` and the dependency nodes would never be registered.
        new_dependencies = tuple(dependencies)
        self._dag[node].update(new_dependencies)
        for dependency in new_dependencies:
            self.add(dependency)

    @property
    def reversed(self) -> DAG[T]:
        """Returns a copy of this DAG with all its edges reversed."""
        result = DAG[T]()

        for node, deps in self._dag.items():
            result.add(node)
            for dep in deps:
                result.add(dep, [node])

        return result

    def subdag(self, *nodes: T) -> DAG[T]:
        """Create a new subdag given node(s).

        Args:
            nodes: The nodes of the new subdag.

        Returns:
            A new dag consisting of the specified nodes and all of their upstream dependencies.
        """
        pending = set(nodes)
        graph: t.Dict[T, t.Set[T]] = {}

        while pending:
            node = pending.pop()
            deps = self._dag.get(node, set())
            graph[node] = deps
            # Skip nodes that were already collected; this avoids repeated work and
            # guarantees termination even if the graph accidentally contains a cycle.
            pending.update(dep for dep in deps if dep not in graph)

        return DAG(graph)

    def prune(self, *nodes: T) -> DAG[T]:
        """Create a dag keeping only the included nodes.

        Args:
            nodes: The nodes of the new pruned dag.

        Returns:
            A new dag consisting of the specified nodes and only the edges between them.
        """
        # Build the lookup set once instead of scanning the argument tuple per check.
        keep = set(nodes)
        return DAG({node: deps & keep for node, deps in self._dag.items() if node in keep})

    def upstream(self, node: T) -> t.List[T]:
        """Returns all upstream dependencies in topologically sorted order."""
        # The requested node is the only sink of its subdag, so it is always sorted last.
        return self.subdag(node).sorted[:-1]

    @property
    def roots(self) -> t.Set[T]:
        """Returns all nodes in the graph without any upstream dependencies."""
        return {node for node, deps in self._dag.items() if not deps}

    @property
    def graph(self) -> t.Dict[T, t.Set[T]]:
        """Returns a copy of the node -> dependencies mapping that is safe to mutate."""
        return {node: deps.copy() for node, deps in self._dag.items()}

    @property
    def sorted(self) -> t.List[T]:
        """Returns a list of nodes sorted in topological order.

        Raises:
            SQLMeshError: If the graph contains a cycle.
        """
        if self._sorted is None:
            # Build the order locally so that a cycle error never leaves a partial
            # result behind in the cache.
            ordered: t.List[T] = []

            unprocessed_nodes = self.graph

            last_processed_nodes: t.List[T] = []
            cycle_candidates: t.Collection = unprocessed_nodes

            while unprocessed_nodes:
                next_nodes = {node for node, deps in unprocessed_nodes.items() if not deps}

                if not next_nodes:
                    # Sort cycle candidates to make the order deterministic
                    cycle_candidates_msg = (
                        "\nPossible candidates to check for circular references: "
                        + ", ".join(str(node) for node in sorted(cycle_candidates))
                    )

                    if last_processed_nodes:
                        last_processed_msg = "\nLast nodes added to the DAG: " + ", ".join(
                            str(node) for node in last_processed_nodes
                        )
                    else:
                        last_processed_msg = ""

                    raise SQLMeshError(
                        "Detected a cycle in the DAG. "
                        "Please make sure there are no circular references between nodes."
                        f"{last_processed_msg}{cycle_candidates_msg}"
                    )

                for node in next_nodes:
                    unprocessed_nodes.pop(node)

                nodes_with_unaffected_deps: t.Set[T] = set()
                for node, deps in unprocessed_nodes.items():
                    # Check before subtracting: the subtraction is in-place, so comparing
                    # the set with itself afterwards would always report "unaffected".
                    if deps.isdisjoint(next_nodes):
                        nodes_with_unaffected_deps.add(node)
                    else:
                        deps -= next_nodes

                cycle_candidates = nodes_with_unaffected_deps or unprocessed_nodes

                # Sort to make the order deterministic
                # TODO: Make protocol that makes the type var both hashable and sortable once we are on Python 3.8+
                last_processed_nodes = sorted(next_nodes)  # type: ignore
                ordered.extend(last_processed_nodes)

            self._sorted = ordered

        return self._sorted

    def downstream(self, node: T) -> t.List[T]:
        """Get all nodes that have the input node as an upstream dependency.

        Args:
            node: The ancestor node.

        Returns:
            A list of descendant nodes sorted in topological order.
        """
        sorted_nodes = self.sorted
        try:
            node_index = sorted_nodes.index(node)
        except ValueError:
            return []

        def visit() -> t.Iterator[T]:
            """Visit topologically sorted nodes after input node and yield downstream dependants."""
            downstream = {node}
            for current_node in sorted_nodes[node_index + 1 :]:
                upstream = self._dag.get(current_node, set())
                if not upstream.isdisjoint(downstream):
                    downstream.add(current_node)
                    yield current_node

        return list(visit())

    def lineage(self, node: T) -> DAG[T]:
        """Get a dag of the node and its upstream dependencies and downstream dependents.

        Args:
            node: The node used to determine lineage.

        Returns:
            A new dag consisting of the node, its downstream dependents, and every
            upstream dependency of those nodes.
        """
        return self.subdag(node, *self.downstream(node))

    def __contains__(self, item: T) -> bool:
        # Look up the internal mapping directly; `self.graph` would copy every set.
        return item in self._dag

    def __iter__(self) -> t.Iterator[T]:
        yield from self.sorted
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]):
    def __getitem__(self, key: KT) -> VT:
        ...
    # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
    try:
        return mapping[key]
    except KeyError:
        return default
27 def add(self, node: T, dependencies: t.Optional[t.Iterable[T]] = None) -> None: 28 """Add a node to the graph with an optional upstream dependency. 29 30 Args: 31 node: The node to add. 32 dependencies: Optional dependencies to add to the node. 33 """ 34 self._sorted = None 35 if node not in self._dag: 36 self._dag[node] = set() 37 if dependencies: 38 self._dag[node].update(dependencies) 39 for d in dependencies: 40 self.add(d)
Add a node to the graph with optional upstream dependencies.
Arguments:
- node: The node to add.
- dependencies: Optional dependencies to add to the node.
54 def subdag(self, *nodes: T) -> DAG[T]: 55 """Create a new subdag given node(s). 56 57 Args: 58 nodes: The nodes of the new subdag. 59 60 Returns: 61 A new dag consisting of the specified nodes and upstream. 62 """ 63 queue = set(nodes) 64 graph = {} 65 66 while queue: 67 node = queue.pop() 68 deps = self._dag.get(node, set()) 69 graph[node] = deps 70 queue.update(deps) 71 72 return DAG(graph)
Create a new subdag given node(s).
Arguments:
- nodes: The nodes of the new subdag.
Returns:
A new dag consisting of the specified nodes and all of their upstream dependencies.
74 def prune(self, *nodes: T) -> DAG[T]: 75 """Create a dag keeping only the included nodes. 76 77 Args: 78 nodes: The nodes of the new pruned dag. 79 80 Returns: 81 A new dag consisting of the specified nodes. 82 """ 83 graph = {} 84 85 for node, deps in self._dag.items(): 86 if node in nodes: 87 graph[node] = {dep for dep in deps if dep in nodes} 88 89 return DAG(graph)
Create a dag keeping only the included nodes.
Arguments:
- nodes: The nodes of the new pruned dag.
Returns:
A new dag consisting of the specified nodes.
91 def upstream(self, node: T) -> t.List[T]: 92 """Returns all upstream dependencies in topologically sorted order.""" 93 return self.subdag(node).sorted[:-1]
Returns all upstream dependencies in topologically sorted order.
161 def downstream(self, node: T) -> t.List[T]: 162 """Get all nodes that have the input node as an upstream dependency. 163 164 Args: 165 node: The ancestor node. 166 167 Returns: 168 A list of descendant nodes sorted in topological order. 169 """ 170 sorted_nodes = self.sorted 171 try: 172 node_index = sorted_nodes.index(node) 173 except ValueError: 174 return [] 175 176 def visit() -> t.Iterator[T]: 177 """Visit topologically sorted nodes after input node and yield downstream dependants.""" 178 downstream = {node} 179 for current_node in sorted_nodes[node_index + 1 :]: 180 upstream = self._dag.get(current_node, set()) 181 if not upstream.isdisjoint(downstream): 182 downstream.add(current_node) 183 yield current_node 184 185 return list(visit())
Get all nodes that have the input node as an upstream dependency.
Arguments:
- node: The ancestor node.
Returns:
A list of descendant nodes sorted in topological order.
187 def lineage(self, node: T) -> DAG[T]: 188 """Get a dag of the node and its upstream dependencies and downstream dependents. 189 190 Args: 191 node: The node used to determine lineage. 192 193 Returns: 194 A new dag consisting of the dependent and descendant nodes. 195 """ 196 return self.subdag(node, *self.downstream(node))
Get a dag of the node and its upstream dependencies and downstream dependents.
Arguments:
- node: The node used to determine lineage.
Returns:
A new dag consisting of the node, its downstream dependents, and every upstream dependency of those nodes.