Edit on GitHub

sqlmesh.core.reference

  1from __future__ import annotations
  2
  3import typing as t
  4from collections import deque
  5
  6from sqlglot import exp
  7
  8from sqlmesh.utils.errors import ConfigError, SQLMeshError
  9from sqlmesh.utils.pydantic import PydanticModel
 10
 11if t.TYPE_CHECKING:
 12    from sqlmesh.core.model import Model
 13
 14
 15class Reference(PydanticModel, frozen=True):
 16    model_name: str
 17    expression: exp.Expr
 18    unique: bool = False
 19    _name: str = ""
 20
 21    @property
 22    def columns(self) -> t.List[str]:
 23        expression = self.expression
 24        if isinstance(expression, exp.Alias):
 25            expression = expression.this
 26        expression = expression.unnest()
 27        if isinstance(expression, (exp.Array, exp.Tuple)):
 28            return [e.output_name for e in expression.expressions]
 29        return [expression.output_name]
 30
 31    @property
 32    def name(self) -> str:
 33        if not self._name:
 34            keys = []
 35
 36            if isinstance(self.expression, (exp.Tuple, exp.Array)):
 37                for e in self.expression.expressions:
 38                    if not e.output_name:
 39                        raise ConfigError(
 40                            f"Reference '{e}' must have an inferrable name or explicit alias."
 41                        )
 42                    keys.append(e.output_name)
 43            elif self.expression.output_name:
 44                keys.append(self.expression.output_name)
 45            else:
 46                raise ConfigError(
 47                    f"Reference '{self.expression}' must have an inferrable name or explicit alias."
 48                )
 49
 50            self._name = "__".join(keys)
 51        return self._name
 52
 53
 54class ReferenceGraph:
 55    def __init__(self, models: t.Iterable[Model]):
 56        self._model_refs: t.Dict[str, t.Dict[str, Reference]] = {}
 57        self._ref_models: t.Dict[str, t.Set[str]] = {}
 58        self._dim_models: t.Dict[str, t.Set[str]] = {}
 59
 60        for model in models:
 61            self.add_model(model)
 62
 63    def add_model(self, model: Model) -> None:
 64        """Add a model and its references to the graph.
 65
 66        Args:
 67            model: the model to add.
 68        """
 69        for column in model.columns_to_types or {}:
 70            self._dim_models.setdefault(column, set())
 71            self._dim_models[column].add(model.name)
 72
 73        for ref in model.all_references:
 74            self._model_refs.setdefault(model.name, {})
 75            self._model_refs[model.name][ref.name] = ref
 76            self._ref_models.setdefault(ref.name, set())
 77            self._ref_models[ref.name].add(model.name)
 78
 79    def models_for_column(self, source: str, column: str, max_depth: int = 3) -> t.List[str]:
 80        """Find all the models with a column that join to a source within max_depth.
 81
 82        Args:
 83            source: The source model.
 84            column: The column to look for.
 85            max_depth: The maximum number of models to join to find a path.
 86
 87        Returns:
 88            The list of models that fit the criteria of the search.
 89        """
 90        models = []
 91
 92        for model in self._dim_models[column]:
 93            try:
 94                if model != source:
 95                    self.find_path(source, model)
 96                models.append(model)
 97            except SQLMeshError:
 98                pass
 99
100        return sorted(models)
101
102    def find_path(self, source: str, target: str, max_depth: int = 3) -> t.List[Reference]:
103        """Find a path from source model to target model with max depth.
104
105        Args:
106            source: The source model.
107            target: The target model.
108            max_depth: The maximum number of models to join to find a path.
109
110        Returns:
111            The list of references representing the join path of source to target.
112        """
113        if source not in self._model_refs:
114            return []
115
116        queue = deque(([ref] for ref in self._model_refs[source].values()))
117
118        while queue:
119            path = queue.popleft()
120            visited = set()
121            many = False
122
123            for ref in path:
124                visited.add(ref.model_name)
125                many = many or not ref.unique
126
127            ref_name = path[-1].name
128
129            for model_name in sorted(self._ref_models[ref_name]):
130                for ref in self._model_refs[model_name].values():
131                    # paths cannot have loops or contain many to many refs
132                    if model_name in visited or (many and not ref.unique):
133                        continue
134
135                    new_path = path + [ref]
136
137                    if model_name == target:
138                        return new_path
139
140                    if len(new_path) < max_depth:
141                        queue.append(new_path)
142
143        raise SQLMeshError(
144            f"Cannot find path between '{source}' and '{target}'. Make sure that references/grains are configured and that a many to many join is not occurring."
145        )
class Reference(sqlmesh.utils.pydantic.PydanticModel):
class Reference(PydanticModel, frozen=True):
    """A reference (join key) declared on a model.

    Attributes:
        model_name: Name of the model that declares this reference.
        expression: The SQL expression forming the key; a tuple/array expression
            represents a composite key with one element per column.
        unique: Whether the reference is unique. Non-unique references count as the
            "many" side of a join when ReferenceGraph searches for join paths.
    """

    model_name: str
    expression: exp.Expr
    unique: bool = False
    # Cache for the `name` property; empty until first computed.
    _name: str = ""

    @property
    def columns(self) -> t.List[str]:
        """Output names of the columns that make up this reference.

        An outer alias is skipped and the remaining expression is unwrapped with
        `unnest()`; a tuple/array yields one name per element.
        """
        node = self.expression
        if isinstance(node, exp.Alias):
            node = node.this
        node = node.unnest()
        if not isinstance(node, (exp.Array, exp.Tuple)):
            return [node.output_name]
        return [element.output_name for element in node.expressions]

    @property
    def name(self) -> str:
        """Identifier of this reference: element output names joined with "__".

        Computed on first access and cached in `_name`.

        Raises:
            ConfigError: If the expression (or any tuple/array element) has no output name.
        """
        if self._name:
            return self._name

        if isinstance(self.expression, (exp.Tuple, exp.Array)):
            parts = self.expression.expressions
        else:
            parts = [self.expression]

        keys = []
        for part in parts:
            if not part.output_name:
                raise ConfigError(
                    f"Reference '{part}' must have an inferrable name or explicit alias."
                )
            keys.append(part.output_name)

        self._name = "__".join(keys)
        return self._name

!!! abstract "Usage Documentation"
    Models

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of the class variables defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The synthesized `__init__` signature (`inspect.Signature`) of the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The core schema of the model.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a `RootModel` (`pydantic.root_model.RootModel`).
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_fields__: A dictionary of field names and their corresponding `FieldInfo` (`pydantic.fields.FieldInfo`) objects.
  • __pydantic_computed_fields__: A dictionary of computed field names and their corresponding `ComputedFieldInfo` (`pydantic.fields.ComputedFieldInfo`) objects.
  • __pydantic_extra__: A dictionary containing extra values, if `extra` (`pydantic.config.ConfigDict.extra`) is set to 'allow'.
  • __pydantic_fields_set__: The names of fields explicitly set during instantiation.
  • __pydantic_private__: Values of private attributes set on the model instance.
model_name: str
expression: sqlglot.expressions.core.Expr
unique: bool
columns: List[str]
22    @property
23    def columns(self) -> t.List[str]:
24        expression = self.expression
25        if isinstance(expression, exp.Alias):
26            expression = expression.this
27        expression = expression.unnest()
28        if isinstance(expression, (exp.Array, exp.Tuple)):
29            return [e.output_name for e in expression.expressions]
30        return [expression.output_name]
name: str
32    @property
33    def name(self) -> str:
34        if not self._name:
35            keys = []
36
37            if isinstance(self.expression, (exp.Tuple, exp.Array)):
38                for e in self.expression.expressions:
39                    if not e.output_name:
40                        raise ConfigError(
41                            f"Reference '{e}' must have an inferrable name or explicit alias."
42                        )
43                    keys.append(e.output_name)
44            elif self.expression.output_name:
45                keys.append(self.expression.output_name)
46            else:
47                raise ConfigError(
48                    f"Reference '{self.expression}' must have an inferrable name or explicit alias."
49                )
50
51            self._name = "__".join(keys)
52        return self._name
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': (), 'frozen': True}

Configuration for the model; it should be a dictionary conforming to `ConfigDict` (`pydantic.config.ConfigDict`).

def model_post_init(self: pydantic.main.BaseModel, context: Any, /) -> None:
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """Initialise a model instance's private attributes, behaving like a BaseModel method.

    The signature accepts ``context`` because pydantic-core passes one when calling
    this function; the body does not use it.

    Args:
        self: The BaseModel instance.
        context: The context (unused).
    """
    # Already initialised: leave the existing private storage untouched.
    if getattr(self, '__pydantic_private__', None) is not None:
        return

    # Keep only private attributes that actually declare a default value.
    defaults = {
        attr_name: default
        for attr_name, private_attr in self.__private_attributes__.items()
        if (default := private_attr.get_default()) is not PydanticUndefined
    }
    object_setattr(self, '__pydantic_private__', defaults)

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that's what pydantic-core passes when calling it.

Arguments:
  • self: The BaseModel instance.
  • context: The context.
Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class ReferenceGraph:
 55class ReferenceGraph:
 56    def __init__(self, models: t.Iterable[Model]):
 57        self._model_refs: t.Dict[str, t.Dict[str, Reference]] = {}
 58        self._ref_models: t.Dict[str, t.Set[str]] = {}
 59        self._dim_models: t.Dict[str, t.Set[str]] = {}
 60
 61        for model in models:
 62            self.add_model(model)
 63
 64    def add_model(self, model: Model) -> None:
 65        """Add a model and its references to the graph.
 66
 67        Args:
 68            model: the model to add.
 69        """
 70        for column in model.columns_to_types or {}:
 71            self._dim_models.setdefault(column, set())
 72            self._dim_models[column].add(model.name)
 73
 74        for ref in model.all_references:
 75            self._model_refs.setdefault(model.name, {})
 76            self._model_refs[model.name][ref.name] = ref
 77            self._ref_models.setdefault(ref.name, set())
 78            self._ref_models[ref.name].add(model.name)
 79
 80    def models_for_column(self, source: str, column: str, max_depth: int = 3) -> t.List[str]:
 81        """Find all the models with a column that join to a source within max_depth.
 82
 83        Args:
 84            source: The source model.
 85            column: The column to look for.
 86            max_depth: The maximum number of models to join to find a path.
 87
 88        Returns:
 89            The list of models that fit the criteria of the search.
 90        """
 91        models = []
 92
 93        for model in self._dim_models[column]:
 94            try:
 95                if model != source:
 96                    self.find_path(source, model)
 97                models.append(model)
 98            except SQLMeshError:
 99                pass
100
101        return sorted(models)
102
103    def find_path(self, source: str, target: str, max_depth: int = 3) -> t.List[Reference]:
104        """Find a path from source model to target model with max depth.
105
106        Args:
107            source: The source model.
108            target: The target model.
109            max_depth: The maximum number of models to join to find a path.
110
111        Returns:
112            The list of references representing the join path of source to target.
113        """
114        if source not in self._model_refs:
115            return []
116
117        queue = deque(([ref] for ref in self._model_refs[source].values()))
118
119        while queue:
120            path = queue.popleft()
121            visited = set()
122            many = False
123
124            for ref in path:
125                visited.add(ref.model_name)
126                many = many or not ref.unique
127
128            ref_name = path[-1].name
129
130            for model_name in sorted(self._ref_models[ref_name]):
131                for ref in self._model_refs[model_name].values():
132                    # paths cannot have loops or contain many to many refs
133                    if model_name in visited or (many and not ref.unique):
134                        continue
135
136                    new_path = path + [ref]
137
138                    if model_name == target:
139                        return new_path
140
141                    if len(new_path) < max_depth:
142                        queue.append(new_path)
143
144        raise SQLMeshError(
145            f"Cannot find path between '{source}' and '{target}'. Make sure that references/grains are configured and that a many to many join is not occurring."
146        )
56    def __init__(self, models: t.Iterable[Model]):
57        self._model_refs: t.Dict[str, t.Dict[str, Reference]] = {}
58        self._ref_models: t.Dict[str, t.Set[str]] = {}
59        self._dim_models: t.Dict[str, t.Set[str]] = {}
60
61        for model in models:
62            self.add_model(model)
64    def add_model(self, model: Model) -> None:
65        """Add a model and its references to the graph.
66
67        Args:
68            model: the model to add.
69        """
70        for column in model.columns_to_types or {}:
71            self._dim_models.setdefault(column, set())
72            self._dim_models[column].add(model.name)
73
74        for ref in model.all_references:
75            self._model_refs.setdefault(model.name, {})
76            self._model_refs[model.name][ref.name] = ref
77            self._ref_models.setdefault(ref.name, set())
78            self._ref_models[ref.name].add(model.name)

Add a model and its references to the graph.

Arguments:
  • model: the model to add.
def models_for_column(self, source: str, column: str, max_depth: int = 3) -> List[str]:
 80    def models_for_column(self, source: str, column: str, max_depth: int = 3) -> t.List[str]:
 81        """Find all the models with a column that join to a source within max_depth.
 82
 83        Args:
 84            source: The source model.
 85            column: The column to look for.
 86            max_depth: The maximum number of models to join to find a path.
 87
 88        Returns:
 89            The list of models that fit the criteria of the search.
 90        """
 91        models = []
 92
 93        for model in self._dim_models[column]:
 94            try:
 95                if model != source:
 96                    self.find_path(source, model)
 97                models.append(model)
 98            except SQLMeshError:
 99                pass
100
101        return sorted(models)

Find all models that contain the given column and can be joined to the source model within max_depth.

Arguments:
  • source: The source model.
  • column: The column to look for.
  • max_depth: The maximum number of models to join to find a path.
Returns:

The list of models that fit the criteria of the search.

def find_path( self, source: str, target: str, max_depth: int = 3) -> List[Reference]:
103    def find_path(self, source: str, target: str, max_depth: int = 3) -> t.List[Reference]:
104        """Find a path from source model to target model with max depth.
105
106        Args:
107            source: The source model.
108            target: The target model.
109            max_depth: The maximum number of models to join to find a path.
110
111        Returns:
112            The list of references representing the join path of source to target.
113        """
114        if source not in self._model_refs:
115            return []
116
117        queue = deque(([ref] for ref in self._model_refs[source].values()))
118
119        while queue:
120            path = queue.popleft()
121            visited = set()
122            many = False
123
124            for ref in path:
125                visited.add(ref.model_name)
126                many = many or not ref.unique
127
128            ref_name = path[-1].name
129
130            for model_name in sorted(self._ref_models[ref_name]):
131                for ref in self._model_refs[model_name].values():
132                    # paths cannot have loops or contain many to many refs
133                    if model_name in visited or (many and not ref.unique):
134                        continue
135
136                    new_path = path + [ref]
137
138                    if model_name == target:
139                        return new_path
140
141                    if len(new_path) < max_depth:
142                        queue.append(new_path)
143
144        raise SQLMeshError(
145            f"Cannot find path between '{source}' and '{target}'. Make sure that references/grains are configured and that a many to many join is not occurring."
146        )

Find a join path from the source model to the target model within max_depth.

Arguments:
  • source: The source model.
  • target: The target model.
  • max_depth: The maximum number of models to join to find a path.
Returns:

The list of references representing the join path of source to target.