DAG
A DAG, or directed acyclic graph, is a graph where the edges are directional and there are no cycles with all the edges pointing in the same direction. SQLMesh uses a DAG to keep track of a project's models. This allows SQLMesh to easily determine a model's lineage and to identify upstream and downstream dependencies.
1""" 2# DAG 3 4A DAG, or directed acyclic graph, is a graph where the edges are directional and there are no cycles with 5all the edges pointing in the same direction. SQLMesh uses a DAG to keep track of a project's models. This 6allows SQLMesh to easily determine a model's lineage and to identify upstream and downstream dependencies. 7""" 8 9from __future__ import annotations 10 11import typing as t 12 13from sqlmesh.utils.errors import SQLMeshError 14 15T = t.TypeVar("T", bound=t.Hashable) 16 17 18class DAG(t.Generic[T]): 19 def __init__(self, graph: t.Optional[t.Dict[T, t.Set[T]]] = None): 20 self._dag: t.Dict[T, t.Set[T]] = {} 21 self._sorted: t.Optional[t.List[T]] = None 22 self._upstream: t.Dict[T, t.Set[T]] = {} 23 24 for node, dependencies in (graph or {}).items(): 25 self.add(node, dependencies) 26 27 def add(self, node: T, dependencies: t.Optional[t.Iterable[T]] = None) -> None: 28 """Add a node to the graph with an optional upstream dependency. 29 30 Args: 31 node: The node to add. 32 dependencies: Optional dependencies to add to the node. 33 """ 34 self._sorted = None 35 self._upstream.clear() 36 if node not in self._dag: 37 self._dag[node] = set() 38 if dependencies: 39 self._dag[node].update(dependencies) 40 for d in dependencies: 41 self.add(d) 42 43 @property 44 def reversed(self) -> DAG[T]: 45 """Returns a copy of this DAG with all its edges reversed.""" 46 result = DAG[T]() 47 48 for node, deps in self._dag.items(): 49 result.add(node) 50 for dep in deps: 51 result.add(dep, [node]) 52 53 return result 54 55 def subdag(self, *nodes: T) -> DAG[T]: 56 """Create a new subdag given node(s). 57 58 Args: 59 nodes: The nodes of the new subdag. 60 61 Returns: 62 A new dag consisting of the specified nodes and upstream. 
63 """ 64 queue = set(nodes) 65 dag: DAG[T] = DAG() 66 67 while queue: 68 node = queue.pop() 69 deps = self._dag.get(node, set()) 70 dag.add(node, deps) 71 queue.update(deps) 72 73 return dag 74 75 def prune(self, *nodes: T) -> DAG[T]: 76 """Create a dag keeping only the included nodes. 77 78 Args: 79 nodes: The nodes of the new pruned dag. 80 81 Returns: 82 A new dag consisting of the specified nodes. 83 """ 84 dag: DAG[T] = DAG() 85 86 for node, deps in self._dag.items(): 87 if node in nodes: 88 dag.add(node, (dep for dep in deps if dep in nodes)) 89 90 return dag 91 92 def upstream(self, node: T) -> t.Set[T]: 93 """Returns all upstream dependencies.""" 94 if node not in self._upstream: 95 deps = self._dag.get(node, set()) 96 self._upstream[node] = { 97 upstream for dep in deps for upstream in self.upstream(dep) 98 } | deps 99 100 return self._upstream[node] 101 102 def _find_cycle_path(self, nodes_in_cycle: t.Dict[T, t.Set[T]]) -> t.Optional[t.List[T]]: 103 """Find the exact cycle path using DFS when a cycle is detected. 
104 105 Args: 106 nodes_in_cycle: Dictionary of nodes that are part of the cycle and their dependencies 107 108 Returns: 109 List of nodes forming the cycle path, or None if no cycle found 110 """ 111 if not nodes_in_cycle: 112 return None 113 114 # Use DFS to find a cycle path 115 visited: t.Set[T] = set() 116 path: t.List[T] = [] 117 118 def dfs(node: T) -> t.Optional[t.List[T]]: 119 if node in path: 120 # Found a cycle - extract the cycle path 121 cycle_start = path.index(node) 122 return path[cycle_start:] + [node] 123 124 if node in visited: 125 return None 126 127 visited.add(node) 128 path.append(node) 129 130 # Only follow edges to nodes that are still in the unprocessed set 131 for neighbor in nodes_in_cycle.get(node, set()): 132 if neighbor in nodes_in_cycle: 133 cycle = dfs(neighbor) 134 if cycle: 135 return cycle 136 137 path.pop() 138 return None 139 140 # Try starting DFS from each unvisited node 141 for start_node in nodes_in_cycle: 142 if start_node not in visited: 143 cycle = dfs(start_node) 144 if cycle: 145 return cycle[:-1] # Remove the duplicate node at the end 146 147 return None 148 149 @property 150 def roots(self) -> t.Set[T]: 151 """Returns all nodes in the graph without any upstream dependencies.""" 152 return {node for node, deps in self._dag.items() if not deps} 153 154 @property 155 def graph(self) -> t.Dict[T, t.Set[T]]: 156 graph = {} 157 for node, deps in self._dag.items(): 158 graph[node] = deps.copy() 159 return graph 160 161 @property 162 def sorted(self) -> t.List[T]: 163 """Returns a list of nodes sorted in topological order.""" 164 if self._sorted is None: 165 self._sorted = [] 166 unprocessed_nodes = self.graph 167 168 last_processed_nodes: t.Set[T] = set() 169 cycle_candidates: t.Collection = unprocessed_nodes 170 171 while unprocessed_nodes: 172 next_nodes = {node for node, deps in unprocessed_nodes.items() if not deps} 173 174 if not next_nodes: 175 # A cycle was detected - find the exact cycle path 176 cycle_path = 
self._find_cycle_path(unprocessed_nodes) 177 178 last_processed_msg = "" 179 if cycle_path: 180 node_output = " ->\n".join( 181 str(node) for node in (cycle_path + [cycle_path[0]]) 182 ) 183 cycle_msg = f"\nCycle:\n{node_output}" 184 else: 185 # Fallback message in case a cycle can't be found 186 cycle_candidates_msg = ( 187 "\nPossible candidates to check for circular references: " 188 + ", ".join(str(node) for node in sorted(cycle_candidates)) 189 ) 190 cycle_msg = cycle_candidates_msg 191 if last_processed_nodes: 192 last_processed_msg = "\nLast nodes added to the DAG: " + ", ".join( 193 str(node) for node in last_processed_nodes 194 ) 195 196 raise SQLMeshError( 197 "Detected a cycle in the DAG. " 198 "Please make sure there are no circular references between nodes." 199 f"{last_processed_msg}{cycle_msg}" 200 ) 201 202 for node in next_nodes: 203 unprocessed_nodes.pop(node) 204 205 nodes_with_unaffected_deps: t.Set[T] = set() 206 for node, deps in unprocessed_nodes.items(): 207 deps_before_subtraction = deps 208 209 deps -= next_nodes 210 if deps_before_subtraction == deps: 211 nodes_with_unaffected_deps.add(node) 212 213 cycle_candidates = nodes_with_unaffected_deps or unprocessed_nodes 214 215 # Sort to make the order deterministic 216 # TODO: Make protocol that makes the type var both hashable and sortable once we are on Python 3.8+ 217 last_processed_nodes = sorted(next_nodes) # type: ignore 218 self._sorted.extend(last_processed_nodes) 219 220 return self._sorted 221 222 def downstream(self, node: T) -> t.List[T]: 223 """Get all nodes that have the input node as an upstream dependency. 224 225 Args: 226 node: The ancestor node. 227 228 Returns: 229 A list of descendant nodes sorted in topological order. 
230 """ 231 sorted_nodes = self.sorted 232 try: 233 node_index = sorted_nodes.index(node) 234 except ValueError: 235 return [] 236 237 def visit() -> t.Iterator[T]: 238 """Visit topologically sorted nodes after input node and yield downstream dependants.""" 239 downstream = {node} 240 for current_node in sorted_nodes[node_index + 1 :]: 241 upstream = self._dag.get(current_node, set()) 242 if not upstream.isdisjoint(downstream): 243 downstream.add(current_node) 244 yield current_node 245 246 return list(visit()) 247 248 def lineage(self, node: T) -> DAG[T]: 249 """Get a dag of the node and its upstream dependencies and downstream dependents. 250 251 Args: 252 node: The node used to determine lineage. 253 254 Returns: 255 A new dag consisting of the dependent and descendant nodes. 256 """ 257 return self.subdag(node, *self.downstream(node)) 258 259 def __contains__(self, item: T) -> bool: 260 return item in self.graph 261 262 def __iter__(self) -> t.Iterator[T]: 263 for node in self.sorted: 264 yield node
class DAG(t.Generic[T]):
    """A directed graph stored as a mapping from each node to its direct upstream dependencies."""

    def __init__(self, graph: t.Optional[t.Dict[T, t.Set[T]]] = None):
        """Create a DAG, optionally populated from ``graph``.

        Args:
            graph: Optional mapping of each node to the set of nodes it depends on.
        """
        # node -> set of direct upstream dependencies
        self._dag: t.Dict[T, t.Set[T]] = {}
        # Cached topological order; reset to None by every ``add``.
        self._sorted: t.Optional[t.List[T]] = None
        # Memoized transitive upstream sets; cleared by every ``add``.
        self._upstream: t.Dict[T, t.Set[T]] = {}

        for node, dependencies in (graph or {}).items():
            self.add(node, dependencies)

    def add(self, node: T, dependencies: t.Optional[t.Iterable[T]] = None) -> None:
        """Add a node to the graph with optional upstream dependencies.

        Every dependency is also registered as a node of the graph. ``dependencies``
        may be any iterable, including a one-shot generator; it is consumed once.

        Args:
            node: The node to add.
            dependencies: Optional dependencies to add to the node.
        """
        # Any mutation invalidates the cached sort order and upstream sets.
        self._sorted = None
        self._upstream.clear()
        node_deps = self._dag.setdefault(node, set())
        # Single pass: the previous update-then-loop pattern exhausted generators in
        # ``update`` so their items were never registered as nodes.
        for dependency in dependencies or ():
            node_deps.add(dependency)
            self.add(dependency)

    @property
    def reversed(self) -> DAG[T]:
        """Returns a copy of this DAG with all its edges reversed."""
        result = DAG[T]()

        for node, deps in self._dag.items():
            result.add(node)
            for dep in deps:
                result.add(dep, [node])

        return result

    def subdag(self, *nodes: T) -> DAG[T]:
        """Create a new subdag given node(s).

        Args:
            nodes: The nodes of the new subdag.

        Returns:
            A new dag consisting of the specified nodes and their upstream dependencies.
        """
        queue = set(nodes)
        visited: t.Set[T] = set()
        dag: DAG[T] = DAG()

        while queue:
            node = queue.pop()
            visited.add(node)
            deps = self._dag.get(node, set())
            dag.add(node, deps)
            # Never re-enqueue expanded nodes: avoids redundant work on shared
            # ancestors and prevents an endless loop if the graph has a cycle.
            queue.update(deps - visited)

        return dag

    def prune(self, *nodes: T) -> DAG[T]:
        """Create a dag keeping only the included nodes.

        Args:
            nodes: The nodes of the new pruned dag.

        Returns:
            A new dag consisting of the specified nodes.
        """
        # Set membership keeps each lookup O(1) instead of scanning the ``nodes`` tuple.
        keep = set(nodes)
        dag: DAG[T] = DAG()

        for node, deps in self._dag.items():
            if node in keep:
                dag.add(node, (dep for dep in deps if dep in keep))

        return dag

    def upstream(self, node: T) -> t.Set[T]:
        """Returns all upstream dependencies of ``node``, direct and transitive.

        Results are memoized in ``self._upstream`` until the graph is next modified.
        A node that is not in the graph has no upstream dependencies.
        """
        if node not in self._upstream:
            deps = self._dag.get(node, set())
            self._upstream[node] = {
                upstream for dep in deps for upstream in self.upstream(dep)
            } | deps

        return self._upstream[node]

    def _find_cycle_path(self, nodes_in_cycle: t.Dict[T, t.Set[T]]) -> t.Optional[t.List[T]]:
        """Find the exact cycle path using DFS when a cycle is detected.

        Args:
            nodes_in_cycle: Mapping of the still-unprocessed nodes (which contain a cycle)
                to their remaining dependencies.

        Returns:
            List of nodes forming the cycle path, or None if no cycle found
        """
        if not nodes_in_cycle:
            return None

        # Use DFS to find a cycle path
        visited: t.Set[T] = set()
        path: t.List[T] = []

        def dfs(node: T) -> t.Optional[t.List[T]]:
            if node in path:
                # Found a cycle - extract the cycle path
                cycle_start = path.index(node)
                return path[cycle_start:] + [node]

            if node in visited:
                return None

            visited.add(node)
            path.append(node)

            # Only follow edges to nodes that are still in the unprocessed set
            for neighbor in nodes_in_cycle.get(node, set()):
                if neighbor in nodes_in_cycle:
                    cycle = dfs(neighbor)
                    if cycle:
                        return cycle

            path.pop()
            return None

        # Try starting DFS from each unvisited node
        for start_node in nodes_in_cycle:
            if start_node not in visited:
                cycle = dfs(start_node)
                if cycle:
                    return cycle[:-1]  # Remove the duplicate node at the end

        return None

    @property
    def roots(self) -> t.Set[T]:
        """Returns all nodes in the graph without any upstream dependencies."""
        return {node for node, deps in self._dag.items() if not deps}

    @property
    def graph(self) -> t.Dict[T, t.Set[T]]:
        """Returns a copy of the adjacency mapping; each dependency set is copied too."""
        return {node: deps.copy() for node, deps in self._dag.items()}

    @property
    def sorted(self) -> t.List[T]:
        """Returns a list of nodes sorted in topological order.

        Nodes are emitted in batches: each batch holds every remaining node whose
        dependencies have all been emitted, and each batch is sorted so the order is
        deterministic. The result is cached until the graph is next modified.

        Raises:
            SQLMeshError: If the graph contains a cycle.
        """
        if self._sorted is None:
            # Build locally and publish only on success, so a cycle error is raised on
            # every access instead of leaving a partial order cached.
            result: t.List[T] = []
            # ``self.graph`` copies the dependency sets, so mutating them below is safe.
            unprocessed_nodes = self.graph

            last_processed_nodes: t.List[T] = []
            cycle_candidates: t.Collection = unprocessed_nodes

            while unprocessed_nodes:
                next_nodes = {node for node, deps in unprocessed_nodes.items() if not deps}

                if not next_nodes:
                    # A cycle was detected - find the exact cycle path
                    cycle_path = self._find_cycle_path(unprocessed_nodes)

                    last_processed_msg = ""
                    if cycle_path:
                        node_output = " ->\n".join(
                            str(node) for node in (cycle_path + [cycle_path[0]])
                        )
                        cycle_msg = f"\nCycle:\n{node_output}"
                    else:
                        # Fallback message in case a cycle can't be found
                        cycle_candidates_msg = (
                            "\nPossible candidates to check for circular references: "
                            + ", ".join(str(node) for node in sorted(cycle_candidates))
                        )
                        cycle_msg = cycle_candidates_msg
                        if last_processed_nodes:
                            last_processed_msg = "\nLast nodes added to the DAG: " + ", ".join(
                                str(node) for node in last_processed_nodes
                            )

                    raise SQLMeshError(
                        "Detected a cycle in the DAG. "
                        "Please make sure there are no circular references between nodes."
                        f"{last_processed_msg}{cycle_msg}"
                    )

                for node in next_nodes:
                    unprocessed_nodes.pop(node)

                nodes_with_unaffected_deps: t.Set[T] = set()
                for node, deps in unprocessed_nodes.items():
                    # Test before subtracting: ``deps -= next_nodes`` mutates the set in
                    # place, so comparing it against a saved reference to the same
                    # object is always equal.
                    if deps.isdisjoint(next_nodes):
                        nodes_with_unaffected_deps.add(node)
                    else:
                        deps -= next_nodes

                cycle_candidates = nodes_with_unaffected_deps or unprocessed_nodes

                # Sort to make the order deterministic
                # TODO: Make protocol that makes the type var both hashable and sortable once we are on Python 3.8+
                last_processed_nodes = sorted(next_nodes)  # type: ignore
                result.extend(last_processed_nodes)

            self._sorted = result

        return self._sorted

    def downstream(self, node: T) -> t.List[T]:
        """Get all nodes that have the input node as an upstream dependency.

        Args:
            node: The ancestor node.

        Returns:
            A list of descendant nodes sorted in topological order, or an empty list if
            ``node`` is not in the graph.
        """
        sorted_nodes = self.sorted
        try:
            node_index = sorted_nodes.index(node)
        except ValueError:
            return []

        def visit() -> t.Iterator[T]:
            """Visit topologically sorted nodes after input node and yield downstream dependants."""
            # Topological order guarantees upstream nodes are seen before their
            # dependents, so one forward pass finds every transitive descendant.
            downstream = {node}
            for current_node in sorted_nodes[node_index + 1 :]:
                upstream = self._dag.get(current_node, set())
                if not upstream.isdisjoint(downstream):
                    downstream.add(current_node)
                    yield current_node

        return list(visit())

    def lineage(self, node: T) -> DAG[T]:
        """Get a dag of the node and its upstream dependencies and downstream dependents.

        Args:
            node: The node used to determine lineage.

        Returns:
            A new dag consisting of the node, its downstream dependents, and all of their
            upstream dependencies.
        """
        return self.subdag(node, *self.downstream(node))

    def __contains__(self, item: T) -> bool:
        # Look up the adjacency dict directly; ``self.graph`` would copy the entire
        # graph just to answer a membership test.
        return item in self._dag

    def __iter__(self) -> t.Iterator[T]:
        """Iterate over the nodes in topological order."""
        yield from self.sorted
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]):
    def __getitem__(self, key: KT) -> VT:
        ...
    # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
    try:
        return mapping[key]
    except KeyError:
        return default
def add(self, node: T, dependencies: t.Optional[t.Iterable[T]] = None) -> None:
    """Add a node to the graph with optional upstream dependencies.

    Every dependency is also registered as a node of the graph. ``dependencies``
    may be any iterable, including a one-shot generator; it is consumed once.

    Args:
        node: The node to add.
        dependencies: Optional dependencies to add to the node.
    """
    # Any mutation invalidates the cached sort order and upstream sets.
    self._sorted = None
    self._upstream.clear()
    node_deps = self._dag.setdefault(node, set())
    # Single pass: the previous update-then-loop pattern exhausted generators in
    # ``update`` so their items were never registered as nodes.
    for dependency in dependencies or ():
        node_deps.add(dependency)
        self.add(dependency)
Add a node to the graph with optional upstream dependencies.
Arguments:
- node: The node to add.
- dependencies: Optional dependencies to add to the node.
@property
def reversed(self) -> DAG[T]:
    """Build a new DAG in which every edge points the opposite way.

    Every node of this DAG is present in the result; a node that depended on
    ``parent`` here becomes a dependency of ``parent`` there.
    """
    flipped = DAG[T]()
    for child, parents in self._dag.items():
        flipped.add(child)
        for parent in parents:
            flipped.add(parent, [child])
    return flipped
Returns a copy of this DAG with all its edges reversed.
def subdag(self, *nodes: T) -> DAG[T]:
    """Create a new subdag given node(s).

    Args:
        nodes: The nodes of the new subdag.

    Returns:
        A new dag consisting of the specified nodes and their upstream dependencies.
    """
    queue = set(nodes)
    visited: t.Set[T] = set()
    dag: DAG[T] = DAG()

    while queue:
        node = queue.pop()
        visited.add(node)
        deps = self._dag.get(node, set())
        dag.add(node, deps)
        # Never re-enqueue expanded nodes: avoids redundant work on shared
        # ancestors and prevents an endless loop if the graph has a cycle.
        queue.update(deps - visited)

    return dag
Create a new subdag given node(s).
Arguments:
- nodes: The nodes of the new subdag.
Returns:
A new dag consisting of the specified nodes and their upstream dependencies.
def prune(self, *nodes: T) -> DAG[T]:
    """Create a dag keeping only the included nodes.

    Edges are kept only when both endpoints are among ``nodes``; requested nodes
    that are not in this graph are ignored.

    Args:
        nodes: The nodes of the new pruned dag.

    Returns:
        A new dag consisting of the specified nodes.
    """
    # Set membership keeps each lookup O(1) instead of scanning the ``nodes`` tuple.
    keep = set(nodes)
    dag: DAG[T] = DAG()

    for node, deps in self._dag.items():
        if node in keep:
            dag.add(node, (dep for dep in deps if dep in keep))

    return dag
Create a dag keeping only the included nodes.
Arguments:
- nodes: The nodes of the new pruned dag.
Returns:
A new dag consisting of the specified nodes.
def upstream(self, node: T) -> t.Set[T]:
    """Return every direct and transitive upstream dependency of ``node``.

    The answer is memoized in ``self._upstream``; a node absent from the graph
    yields an empty set.
    """
    cached = self._upstream.get(node)
    if cached is None:
        direct = self._dag.get(node, set())
        collected: t.Set[T] = set()
        for parent in direct:
            collected.update(self.upstream(parent))
        cached = collected | direct
        self._upstream[node] = cached
    return cached
Returns all upstream dependencies.
@property
def roots(self) -> t.Set[T]:
    """Collect the nodes that depend on nothing, i.e. the sources of the graph."""
    return set(filter(lambda candidate: not self._dag[candidate], self._dag))
Returns all nodes in the graph without any upstream dependencies.
@property
def sorted(self) -> t.List[T]:
    """Returns a list of nodes sorted in topological order.

    Nodes are emitted in batches: each batch holds every remaining node whose
    dependencies have all been emitted, and each batch is sorted so the order is
    deterministic. The result is cached until the graph is next modified.

    Raises:
        SQLMeshError: If the graph contains a cycle.
    """
    if self._sorted is None:
        # Build locally and publish only on success, so a cycle error is raised on
        # every access instead of leaving a partial order cached.
        result: t.List[T] = []
        # ``self.graph`` copies the dependency sets, so mutating them below is safe.
        unprocessed_nodes = self.graph

        last_processed_nodes: t.List[T] = []
        cycle_candidates: t.Collection = unprocessed_nodes

        while unprocessed_nodes:
            next_nodes = {node for node, deps in unprocessed_nodes.items() if not deps}

            if not next_nodes:
                # A cycle was detected - find the exact cycle path
                cycle_path = self._find_cycle_path(unprocessed_nodes)

                last_processed_msg = ""
                if cycle_path:
                    node_output = " ->\n".join(
                        str(node) for node in (cycle_path + [cycle_path[0]])
                    )
                    cycle_msg = f"\nCycle:\n{node_output}"
                else:
                    # Fallback message in case a cycle can't be found
                    cycle_candidates_msg = (
                        "\nPossible candidates to check for circular references: "
                        + ", ".join(str(node) for node in sorted(cycle_candidates))
                    )
                    cycle_msg = cycle_candidates_msg
                    if last_processed_nodes:
                        last_processed_msg = "\nLast nodes added to the DAG: " + ", ".join(
                            str(node) for node in last_processed_nodes
                        )

                raise SQLMeshError(
                    "Detected a cycle in the DAG. "
                    "Please make sure there are no circular references between nodes."
                    f"{last_processed_msg}{cycle_msg}"
                )

            for node in next_nodes:
                unprocessed_nodes.pop(node)

            nodes_with_unaffected_deps: t.Set[T] = set()
            for node, deps in unprocessed_nodes.items():
                # Test before subtracting: ``deps -= next_nodes`` mutates the set in
                # place, so comparing it against a saved reference to the same
                # object is always equal.
                if deps.isdisjoint(next_nodes):
                    nodes_with_unaffected_deps.add(node)
                else:
                    deps -= next_nodes

            cycle_candidates = nodes_with_unaffected_deps or unprocessed_nodes

            # Sort to make the order deterministic
            # TODO: Make protocol that makes the type var both hashable and sortable once we are on Python 3.8+
            last_processed_nodes = sorted(next_nodes)  # type: ignore
            result.extend(last_processed_nodes)

        self._sorted = result

    return self._sorted
Returns a list of nodes sorted in topological order.
def downstream(self, node: T) -> t.List[T]:
    """Collect every node that directly or transitively depends on ``node``.

    Args:
        node: The ancestor node.

    Returns:
        The descendant nodes in topological order; an empty list when ``node`` is
        not part of the graph.
    """
    order = self.sorted
    if node not in order:
        return []

    # One forward sweep over the topological order suffices: a node whose direct
    # upstream overlaps the reached set is itself a descendant.
    reached = {node}
    descendants: t.List[T] = []
    for candidate in order[order.index(node) + 1 :]:
        if reached & self._dag.get(candidate, set()):
            reached.add(candidate)
            descendants.append(candidate)
    return descendants
Get all nodes that have the input node as an upstream dependency.
Arguments:
- node: The ancestor node.
Returns:
A list of descendant nodes sorted in topological order.
def lineage(self, node: T) -> DAG[T]:
    """Build the dag of everything connected to ``node`` through dependencies.

    Args:
        node: The node used to determine lineage.

    Returns:
        A new dag consisting of the node, its downstream dependents, and all of their
        upstream dependencies.
    """
    dependents = self.downstream(node)
    return self.subdag(node, *dependents)
Get a dag of the node and its upstream dependencies and downstream dependents.
Arguments:
- node: The node used to determine lineage.
Returns:
A new dag consisting of the node, its downstream dependents, and all of their upstream dependencies.