Edit on GitHub

sqlmesh.core.reference

  1from __future__ import annotations
  2
  3import typing as t
  4from collections import deque
  5
  6from sqlglot import exp
  7
  8from sqlmesh.utils.errors import ConfigError, SQLMeshError
  9from sqlmesh.utils.pydantic import PydanticModel
 10
 11if t.TYPE_CHECKING:
 12    from sqlmesh.core.model import Model
 13
 14
 15class Reference(PydanticModel, frozen=True):
 16    model_name: str
 17    expression: exp.Expression
 18    unique: bool = False
 19    _name: str = ""
 20
 21    @property
 22    def columns(self) -> t.List[str]:
 23        expression = self.expression
 24        if isinstance(expression, exp.Alias):
 25            expression = expression.this
 26        expression = expression.unnest()
 27        if isinstance(expression, (exp.Array, exp.Tuple)):
 28            return [e.output_name for e in expression.expressions]
 29        return [expression.output_name]
 30
 31    @property
 32    def name(self) -> str:
 33        if not self._name:
 34            keys = []
 35
 36            if isinstance(self.expression, (exp.Tuple, exp.Array)):
 37                for e in self.expression.expressions:
 38                    if not e.output_name:
 39                        raise ConfigError(
 40                            f"Reference '{e}' must have an inferrable name or explicit alias."
 41                        )
 42                    keys.append(e.output_name)
 43            elif self.expression.output_name:
 44                keys.append(self.expression.output_name)
 45            else:
 46                raise ConfigError(
 47                    f"Reference '{self.expression}' must have an inferrable name or explicit alias."
 48                )
 49
 50            self._name = "__".join(keys)
 51        return self._name
 52
 53
 54class ReferenceGraph:
 55    def __init__(self, models: t.Iterable[Model]):
 56        self._model_refs: t.Dict[str, t.Dict[str, Reference]] = {}
 57        self._ref_models: t.Dict[str, t.Set[str]] = {}
 58        self._dim_models: t.Dict[str, t.Set[str]] = {}
 59
 60        for model in models:
 61            self.add_model(model)
 62
 63    def add_model(self, model: Model) -> None:
 64        """Add a model and its references to the graph.
 65
 66        Args:
 67            model: the model to add.
 68        """
 69        for column in model.columns_to_types or {}:
 70            self._dim_models.setdefault(column, set())
 71            self._dim_models[column].add(model.name)
 72
 73        for ref in model.all_references:
 74            self._model_refs.setdefault(model.name, {})
 75            self._model_refs[model.name][ref.name] = ref
 76            self._ref_models.setdefault(ref.name, set())
 77            self._ref_models[ref.name].add(model.name)
 78
 79    def models_for_column(self, source: str, column: str, max_depth: int = 3) -> t.List[str]:
 80        """Find all the models with a column that join to a source within max_depth.
 81
 82        Args:
 83            source: The source model.
 84            column: The column to look for.
 85            max_depth: The maximum number of models to join to find a path.
 86
 87        Returns:
 88            The list of models that fit the criteria of the search.
 89        """
 90        models = []
 91
 92        for model in self._dim_models[column]:
 93            try:
 94                if model != source:
 95                    self.find_path(source, model)
 96                models.append(model)
 97            except SQLMeshError:
 98                pass
 99
100        return sorted(models)
101
102    def find_path(self, source: str, target: str, max_depth: int = 3) -> t.List[Reference]:
103        """Find a path from source model to target model with max depth.
104
105        Args:
106            source: The source model.
107            target: The target model.
108            max_depth: The maximum number of models to join to find a path.
109
110        Returns:
111            The list of references representing the join path of source to target.
112        """
113        if source not in self._model_refs:
114            return []
115
116        queue = deque(([ref] for ref in self._model_refs[source].values()))
117
118        while queue:
119            path = queue.popleft()
120            visited = set()
121            many = False
122
123            for ref in path:
124                visited.add(ref.model_name)
125                many = many or not ref.unique
126
127            ref_name = path[-1].name
128
129            for model_name in sorted(self._ref_models[ref_name]):
130                for ref in self._model_refs[model_name].values():
131                    # paths cannot have loops or contain many to many refs
132                    if model_name in visited or (many and not ref.unique):
133                        continue
134
135                    new_path = path + [ref]
136
137                    if model_name == target:
138                        return new_path
139
140                    if len(new_path) < max_depth:
141                        queue.append(new_path)
142
143        raise SQLMeshError(
144            f"Cannot find path between '{source}' and '{target}'. Make sure that references/grains are configured and that a many to many join is not occurring."
145        )
class Reference(sqlmesh.utils.pydantic.PydanticModel):
class Reference(PydanticModel, frozen=True):
    """A reference (join key) declared for a model.

    ``columns`` lists the referenced column names, while ``name`` is the key that
    ``ReferenceGraph`` uses to match references across models.
    """

    # Name of the model the reference belongs to (presumably the declaring model -- confirm
    # against Model.all_references).
    model_name: str
    # The referenced column(s): a single column, or a tuple/array of columns, optionally aliased.
    expression: exp.Expression
    # Whether the referenced column(s) are unique; ReferenceGraph.find_path treats non-unique
    # references as the "many" side of a join.
    unique: bool = False
    # Private cache for the ``name`` property; empty until first computed.
    _name: str = ""

    @property
    def columns(self) -> t.List[str]:
        """Output names of the column(s) this reference covers.

        An outer alias is ignored and the expression is unwrapped with ``unnest()``;
        tuples and arrays contribute one name per element.
        """
        node = self.expression.this if isinstance(self.expression, exp.Alias) else self.expression
        node = node.unnest()
        members = node.expressions if isinstance(node, (exp.Array, exp.Tuple)) else [node]
        return [member.output_name for member in members]

    @property
    def name(self) -> str:
        """The reference's key, computed once and cached in ``_name``.

        Tuples and arrays yield their elements' output names joined with ``"__"``; any
        other expression yields its own output name.

        Raises:
            ConfigError: If an element (or the expression itself) has no output name.
        """
        if self._name:
            return self._name

        expression = self.expression
        parts = (
            expression.expressions
            if isinstance(expression, (exp.Tuple, exp.Array))
            else [expression]
        )

        keys: t.List[str] = []
        for part in parts:
            key = part.output_name
            if not key:
                raise ConfigError(
                    f"Reference '{part}' must have an inferrable name or explicit alias."
                )
            keys.append(key)

        self._name = "__".join(keys)
        return self._name

Usage docs: https://docs.pydantic.dev/2.7/concepts/models/

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of classvars defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The signature for instantiating the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The pydantic-core schema used to build the SchemaValidator and SchemaSerializer.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a RootModel.
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_extra__: An instance attribute with the values of extra fields from validation when model_config['extra'] == 'allow'.
  • __pydantic_fields_set__: An instance attribute with the names of fields explicitly set.
  • __pydantic_private__: Instance attribute with the values of private attributes set on the model instance.
def model_post_init(self: pydantic.main.BaseModel, _ModelMetaclass__context: Any) -> None:
102                    def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
103                        """We need to both initialize private attributes and call the user-defined model_post_init
104                        method.
105                        """
106                        init_private_attributes(self, __context)
107                        original_model_post_init(self, __context)

Override this method to perform additional initialization after __init__ and model_construct. This is useful if you want to do some validation that requires the entire model to be initialized.

Inherited Members
pydantic.main.BaseModel
BaseModel
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class ReferenceGraph:
 55class ReferenceGraph:
 56    def __init__(self, models: t.Iterable[Model]):
 57        self._model_refs: t.Dict[str, t.Dict[str, Reference]] = {}
 58        self._ref_models: t.Dict[str, t.Set[str]] = {}
 59        self._dim_models: t.Dict[str, t.Set[str]] = {}
 60
 61        for model in models:
 62            self.add_model(model)
 63
 64    def add_model(self, model: Model) -> None:
 65        """Add a model and its references to the graph.
 66
 67        Args:
 68            model: the model to add.
 69        """
 70        for column in model.columns_to_types or {}:
 71            self._dim_models.setdefault(column, set())
 72            self._dim_models[column].add(model.name)
 73
 74        for ref in model.all_references:
 75            self._model_refs.setdefault(model.name, {})
 76            self._model_refs[model.name][ref.name] = ref
 77            self._ref_models.setdefault(ref.name, set())
 78            self._ref_models[ref.name].add(model.name)
 79
 80    def models_for_column(self, source: str, column: str, max_depth: int = 3) -> t.List[str]:
 81        """Find all the models with a column that join to a source within max_depth.
 82
 83        Args:
 84            source: The source model.
 85            column: The column to look for.
 86            max_depth: The maximum number of models to join to find a path.
 87
 88        Returns:
 89            The list of models that fit the criteria of the search.
 90        """
 91        models = []
 92
 93        for model in self._dim_models[column]:
 94            try:
 95                if model != source:
 96                    self.find_path(source, model)
 97                models.append(model)
 98            except SQLMeshError:
 99                pass
100
101        return sorted(models)
102
103    def find_path(self, source: str, target: str, max_depth: int = 3) -> t.List[Reference]:
104        """Find a path from source model to target model with max depth.
105
106        Args:
107            source: The source model.
108            target: The target model.
109            max_depth: The maximum number of models to join to find a path.
110
111        Returns:
112            The list of references representing the join path of source to target.
113        """
114        if source not in self._model_refs:
115            return []
116
117        queue = deque(([ref] for ref in self._model_refs[source].values()))
118
119        while queue:
120            path = queue.popleft()
121            visited = set()
122            many = False
123
124            for ref in path:
125                visited.add(ref.model_name)
126                many = many or not ref.unique
127
128            ref_name = path[-1].name
129
130            for model_name in sorted(self._ref_models[ref_name]):
131                for ref in self._model_refs[model_name].values():
132                    # paths cannot have loops or contain many to many refs
133                    if model_name in visited or (many and not ref.unique):
134                        continue
135
136                    new_path = path + [ref]
137
138                    if model_name == target:
139                        return new_path
140
141                    if len(new_path) < max_depth:
142                        queue.append(new_path)
143
144        raise SQLMeshError(
145            f"Cannot find path between '{source}' and '{target}'. Make sure that references/grains are configured and that a many to many join is not occurring."
146        )
def __init__(self, models: t.Iterable[Model]):
    """Create an empty graph, then register every model in ``models`` via ``add_model``."""
    # model name -> {reference name -> Reference declared by that model}
    self._model_refs: t.Dict[str, t.Dict[str, Reference]] = {}
    # reference name -> names of the models declaring a reference with that name
    self._ref_models: t.Dict[str, t.Set[str]] = {}
    # column name -> names of the models whose columns_to_types contain that column
    self._dim_models: t.Dict[str, t.Set[str]] = {}

    register = self.add_model
    for candidate in models:
        register(candidate)
def add_model(self, model: Model) -> None:
    """Register ``model``'s columns and references in the graph's lookup tables.

    Args:
        model: the model to add.
    """
    owner = model.name

    # A model without columns_to_types contributes no columns.
    for column_name in model.columns_to_types or {}:
        self._dim_models.setdefault(column_name, set()).add(owner)

    # Only models that declare at least one reference get an entry in _model_refs.
    for reference in model.all_references:
        self._model_refs.setdefault(owner, {})[reference.name] = reference
        self._ref_models.setdefault(reference.name, set()).add(owner)

Add a model and its references to the graph.

Arguments:
  • model: the model to add.
def models_for_column(self, source: str, column: str, max_depth: int = 3) -> List[str]:
 80    def models_for_column(self, source: str, column: str, max_depth: int = 3) -> t.List[str]:
 81        """Find all the models with a column that join to a source within max_depth.
 82
 83        Args:
 84            source: The source model.
 85            column: The column to look for.
 86            max_depth: The maximum number of models to join to find a path.
 87
 88        Returns:
 89            The list of models that fit the criteria of the search.
 90        """
 91        models = []
 92
 93        for model in self._dim_models[column]:
 94            try:
 95                if model != source:
 96                    self.find_path(source, model)
 97                models.append(model)
 98            except SQLMeshError:
 99                pass
100
101        return sorted(models)

Find all models that have the given column and can be joined to the source model within max_depth.

Arguments:
  • source: The source model.
  • column: The column to look for.
  • max_depth: The maximum number of models to join to find a path.
Returns:

The list of models that fit the criteria of the search.

def find_path( self, source: str, target: str, max_depth: int = 3) -> List[sqlmesh.core.reference.Reference]:
def find_path(self, source: str, target: str, max_depth: int = 3) -> t.List[Reference]:
    """Find a path from source model to target model with max depth.

    Paths are explored breadth-first (FIFO queue), so shorter paths are found before
    longer ones. A path never visits the same model twice, and once it contains a
    non-unique reference every further reference must be unique, so that a many to
    many join never occurs.

    Args:
        source: The source model.
        target: The target model.
        max_depth: The maximum number of models to join to find a path.

    Returns:
        The list of references representing the join path of source to target, or an
        empty list when source and target are the same model.

    Raises:
        SQLMeshError: If no such path exists within max_depth, including when the
            source model declares no references.
    """
    if source == target:
        return []

    # A source without references has no outgoing edges: the queue starts empty and we
    # fall through to the error instead of reporting an empty path to another model.
    queue = deque([ref] for ref in self._model_refs.get(source, {}).values())

    while queue:
        path = queue.popleft()
        visited = {ref.model_name for ref in path}
        many = any(not ref.unique for ref in path)

        for model_name in sorted(self._ref_models[path[-1].name]):
            # paths cannot have loops
            if model_name in visited:
                continue
            for ref in self._model_refs[model_name].values():
                # paths cannot contain many to many refs
                if many and not ref.unique:
                    continue

                new_path = path + [ref]

                if model_name == target:
                    return new_path

                if len(new_path) < max_depth:
                    queue.append(new_path)

    raise SQLMeshError(
        f"Cannot find path between '{source}' and '{target}'. Make sure that references/grains are configured and that a many to many join is not occurring."
    )

Find a join path from the source model to the target model within max_depth.

Arguments:
  • source: The source model.
  • target: The target model.
  • max_depth: The maximum number of models to join to find a path.
Returns:

The list of references representing the join path of source to target.