Edit on GitHub

sqlmesh.core.loader

  1from __future__ import annotations
  2
  3import abc
  4import linecache
  5import logging
  6import os
  7import typing as t
  8from collections import defaultdict
  9from dataclasses import dataclass
 10from pathlib import Path
 11
 12from sqlglot.errors import SchemaError, SqlglotError
 13from sqlglot.schema import MappingSchema
 14
 15from sqlmesh.core import constants as c
 16from sqlmesh.core.audit import Audit, load_multiple_audits
 17from sqlmesh.core.dialect import parse
 18from sqlmesh.core.macros import MacroRegistry, macro
 19from sqlmesh.core.metric import Metric, MetricMeta, expand_metrics, load_metric_ddl
 20from sqlmesh.core.model import (
 21    Model,
 22    ModelCache,
 23    OptimizedQueryCache,
 24    SeedModel,
 25    create_external_model,
 26    load_sql_based_model,
 27)
 28from sqlmesh.core.model import model as model_registry
 29from sqlmesh.utils import UniqueKeyDict
 30from sqlmesh.utils.dag import DAG
 31from sqlmesh.utils.errors import ConfigError
 32from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor
 33from sqlmesh.utils.metaprogramming import import_python_file
 34from sqlmesh.utils.yaml import YAML
 35
 36if t.TYPE_CHECKING:
 37    from sqlmesh.core.config import Config
 38    from sqlmesh.core.context import GenericContext
 39
 40
 41logger = logging.getLogger(__name__)
 42
 43
 44# TODO: consider moving this to context
 45def update_model_schemas(
 46    dag: DAG[str],
 47    models: UniqueKeyDict[str, Model],
 48    context_path: Path,
 49) -> None:
 50    schema = MappingSchema(normalize=False)
 51    optimized_query_cache: OptimizedQueryCache = OptimizedQueryCache(context_path / c.CACHE)
 52
 53    for name in dag.sorted:
 54        model = models.get(name)
 55
 56        # External models don't exist in the context, so we need to skip them
 57        if not model:
 58            continue
 59
 60        try:
 61            model.update_schema(schema)
 62            optimized_query_cache.with_optimized_query(model)
 63
 64            columns_to_types = model.columns_to_types
 65            if columns_to_types is not None:
 66                schema.add_table(
 67                    model.fqn, columns_to_types, dialect=model.dialect, normalize=False
 68                )
 69        except SchemaError as e:
 70            if "nesting level:" in str(e):
 71                logger.error(
 72                    "SQLMesh requires all model names and references to have the same level of nesting."
 73                )
 74            raise
 75
 76
 77@dataclass
 78class LoadedProject:
 79    macros: MacroRegistry
 80    jinja_macros: JinjaMacroRegistry
 81    models: UniqueKeyDict[str, Model]
 82    audits: UniqueKeyDict[str, Audit]
 83    metrics: UniqueKeyDict[str, Metric]
 84    dag: DAG[str]
 85
 86
 87class Loader(abc.ABC):
 88    """Abstract base class to load macros and models for a context"""
 89
 90    def __init__(self) -> None:
 91        self._path_mtimes: t.Dict[Path, float] = {}
 92        self._dag: DAG[str] = DAG()
 93
 94    def load(self, context: GenericContext, update_schemas: bool = True) -> LoadedProject:
 95        """
 96        Loads all macros and models in the context's path.
 97
 98        Args:
 99            context: The context to load macros and models for.
100            update_schemas: Convert star projections to explicit columns.
101        """
102        # python files are cached by the system
103        # need to manually clear here so we can reload macros
104        linecache.clearcache()
105
106        self._context = context
107        self._path_mtimes.clear()
108        self._dag = DAG()
109
110        config_mtimes: t.Dict[Path, t.List[float]] = defaultdict(list)
111        for context_path, config in self._context.configs.items():
112            for config_file in context_path.glob("config.*"):
113                self._track_file(config_file)
114                config_mtimes[context_path].append(self._path_mtimes[config_file])
115
116        for config_file in c.SQLMESH_PATH.glob("config.*"):
117            self._track_file(config_file)
118            config_mtimes[c.SQLMESH_PATH].append(self._path_mtimes[config_file])
119
120        self._config_mtimes = {path: max(mtimes) for path, mtimes in config_mtimes.items()}
121
122        macros, jinja_macros = self._load_scripts()
123        models = self._load_models(macros, jinja_macros)
124
125        for model in models.values():
126            self._add_model_to_dag(model)
127
128        if update_schemas:
129            update_model_schemas(
130                self._dag,
131                models,
132                self._context.path,
133            )
134            for model in models.values():
135                # The model definition can be validated correctly only after the schema is set.
136                model.validate_definition()
137
138        metrics = self._load_metrics()
139
140        project = LoadedProject(
141            macros=macros,
142            jinja_macros=jinja_macros,
143            models=models,
144            audits=self._load_audits(macros=macros, jinja_macros=jinja_macros),
145            metrics=expand_metrics(metrics),
146            dag=self._dag,
147        )
148        return project
149
150    def reload_needed(self) -> bool:
151        """
152        Checks for any modifications to the files the macros and models depend on
153        since the last load.
154
155        Returns:
156            True if a modification is found; False otherwise
157        """
158        return any(
159            not path.exists() or path.stat().st_mtime > initial_mtime
160            for path, initial_mtime in self._path_mtimes.items()
161        )
162
163    @abc.abstractmethod
164    def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
165        """Loads all user defined macros."""
166
167    @abc.abstractmethod
168    def _load_models(
169        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
170    ) -> UniqueKeyDict[str, Model]:
171        """Loads all models."""
172
173    @abc.abstractmethod
174    def _load_audits(
175        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
176    ) -> UniqueKeyDict[str, Audit]:
177        """Loads all audits."""
178
179    def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]:
180        return UniqueKeyDict("metrics")
181
182    def _load_external_models(self) -> UniqueKeyDict[str, Model]:
183        models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
184        for context_path, config in self._context.configs.items():
185            schema_path = Path(context_path / c.SCHEMA_YAML)
186            external_models_path = context_path / c.EXTERNAL_MODELS
187
188            paths_to_load = []
189            if schema_path.exists():
190                paths_to_load.append(schema_path)
191
192            if external_models_path.exists() and external_models_path.is_dir():
193                paths_to_load.extend(external_models_path.glob("*.yaml"))
194
195            for path in paths_to_load:
196                self._track_file(path)
197
198                with open(path, "r", encoding="utf-8") as file:
199                    for row in YAML().load(file.read()):
200                        model = create_external_model(
201                            **row,
202                            dialect=config.model_defaults.dialect,
203                            path=path,
204                            project=config.project,
205                            default_catalog=self._context.default_catalog,
206                        )
207                        models[model.fqn] = model
208        return models
209
210    def _add_model_to_dag(self, model: Model) -> None:
211        self._dag.add(model.fqn, model.depends_on)
212
213    def _track_file(self, path: Path) -> None:
214        """Project file to track for modifications"""
215        self._path_mtimes[path] = path.stat().st_mtime
216
217
class SqlMeshLoader(Loader):
    """Loads macros and models for a context using the SQLMesh file formats"""

    def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
        """Loads all user defined macros.

        Python files under each project's macros directory are imported while the global
        ``macro`` registry is snapshotted. The registry as it stands after the imports is
        returned, and the snapshot is restored. SQL files under the same directory are
        scanned for Jinja macro definitions.

        Also sets ``self._macros_max_mtime`` to the newest mtime among the tracked macro
        files (``None`` when there are none); it is part of the model cache entry id.

        Returns:
            The Python macro registry and the Jinja macro registry.
        """
        # Store a copy of the macro registry so it can be restored after the imports below.
        standard_macros = macro.get_registry()
        jinja_macros = JinjaMacroRegistry()
        extractor = MacroExtractor()

        # Collect every tracked macro mtime and reduce once at the end, rather than keeping
        # a running max guarded by truthiness (which would treat a 0.0 mtime as "unset").
        macro_mtimes: t.List[float] = []

        for context_path, config in self._context.configs.items():
            macros_path = context_path / c.MACROS

            for path in self._glob_paths(macros_path, config=config, extension=".py"):
                # Only files that import_python_file reports as imported are tracked.
                if import_python_file(path, context_path):
                    self._track_file(path)
                    macro_mtimes.append(self._path_mtimes[path])

            for path in self._glob_paths(macros_path, config=config, extension=".sql"):
                self._track_file(path)
                macro_mtimes.append(self._path_mtimes[path])
                with open(path, "r", encoding="utf-8") as file:
                    jinja_macros.add_macros(extractor.extract(file.read()))

        self._macros_max_mtime = max(macro_mtimes) if macro_mtimes else None

        macros = macro.get_registry()
        macro.set_registry(standard_macros)

        return macros, jinja_macros

    def _load_models(
        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
    ) -> UniqueKeyDict[str, Model]:
        """
        Loads SQL models, external models and Python models (in that order) into a single
        dict keyed by model FQN. The DAG itself is built by ``Loader.load``.
        """
        models = self._load_sql_models(macros, jinja_macros)
        models.update(self._load_external_models())
        models.update(self._load_python_models())

        return models

    def _load_sql_models(
        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
    ) -> UniqueKeyDict[str, Model]:
        """Loads the SQL models of every project into a dict keyed by model FQN.

        Empty files are skipped. Each model goes through the project's ``_Cache``; the file
        is parsed by the ``_load`` callback, presumably only when no valid cache entry exists.
        Seed files referenced by seed models are tracked for modifications too.

        Raises:
            ConfigError: If a model file cannot be parsed.
        """
        models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
        for context_path, config in self._context.configs.items():
            cache = SqlMeshLoader._Cache(self, context_path)
            variables = self._variables(config)

            for path in self._glob_paths(context_path / c.MODELS, config=config, extension=".sql"):
                if not os.path.getsize(path):
                    continue

                self._track_file(path)

                # NOTE(review): ``_load`` closes over this iteration's loop variables; it is
                # handed straight to ``get_or_load_model`` below, which presumably calls it
                # synchronously, so late binding is not an issue here.
                def _load() -> Model:
                    with open(path, "r", encoding="utf-8") as file:
                        try:
                            expressions = parse(
                                file.read(), default_dialect=config.model_defaults.dialect
                            )
                        except SqlglotError as ex:
                            raise ConfigError(
                                f"Failed to parse a model definition at '{path}': {ex}."
                            )

                    return load_sql_based_model(
                        expressions,
                        defaults=config.model_defaults.dict(),
                        macros=macros,
                        jinja_macros=jinja_macros,
                        path=Path(path).absolute(),
                        module_path=context_path,
                        dialect=config.model_defaults.dialect,
                        time_column_format=config.time_column_format,
                        physical_schema_override=config.physical_schema_override,
                        project=config.project,
                        default_catalog=self._context.default_catalog,
                        variables=variables,
                    )

                model = cache.get_or_load_model(path, _load)
                models[model.fqn] = model

                if isinstance(model, SeedModel):
                    seed_path = model.seed_path
                    self._track_file(seed_path)

        return models

    def _load_python_models(self) -> UniqueKeyDict[str, Model]:
        """Loads the python models into a Dict.

        The model registry is cleared, then each non-empty Python file under a project's
        models directory is imported; registry entries that appeared since the previous
        import are turned into models with that project's settings.
        """
        models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
        registry = model_registry.registry()
        registry.clear()
        registered: t.Set[str] = set()

        for context_path, config in self._context.configs.items():
            variables = self._variables(config)
            # Project dialect is set on the registry while this project's files are imported
            # and reset afterwards, even on error.
            model_registry._dialect = config.model_defaults.dialect
            try:
                for path in self._glob_paths(
                    context_path / c.MODELS, config=config, extension=".py"
                ):
                    if not os.path.getsize(path):
                        continue

                    self._track_file(path)
                    import_python_file(path, context_path)
                    new = registry.keys() - registered
                    registered |= new
                    for name in new:
                        model = registry[name].model(
                            path=path,
                            module_path=context_path,
                            defaults=config.model_defaults.dict(),
                            dialect=config.model_defaults.dialect,
                            time_column_format=config.time_column_format,
                            physical_schema_override=config.physical_schema_override,
                            project=config.project,
                            default_catalog=self._context.default_catalog,
                            variables=variables,
                        )
                        models[model.fqn] = model
            finally:
                model_registry._dialect = None

        return models

    def _load_audits(
        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
    ) -> UniqueKeyDict[str, Audit]:
        """Loads all the model audits from each project's audits directory, keyed by name."""
        audits_by_name: UniqueKeyDict[str, Audit] = UniqueKeyDict("audits")
        for context_path, config in self._context.configs.items():
            variables = self._variables(config)
            for path in self._glob_paths(context_path / c.AUDITS, config=config, extension=".sql"):
                self._track_file(path)
                with open(path, "r", encoding="utf-8") as file:
                    expressions = parse(file.read(), default_dialect=config.model_defaults.dialect)
                    audits = load_multiple_audits(
                        expressions=expressions,
                        path=path,
                        module_path=context_path,
                        macros=macros,
                        jinja_macros=jinja_macros,
                        dialect=config.model_defaults.dialect,
                        default_catalog=self._context.default_catalog,
                        variables=variables,
                    )
                    for audit in audits:
                        audits_by_name[audit.name] = audit
        return audits_by_name

    def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]:
        """Loads all metrics from the non-empty SQL files in each project's metrics directory.

        Raises:
            ConfigError: If a metrics file cannot be parsed.
        """
        metrics: UniqueKeyDict[str, MetricMeta] = UniqueKeyDict("metrics")

        for context_path, config in self._context.configs.items():
            for path in self._glob_paths(context_path / c.METRICS, config=config, extension=".sql"):
                if not os.path.getsize(path):
                    continue
                self._track_file(path)

                with open(path, "r", encoding="utf-8") as file:
                    dialect = config.model_defaults.dialect
                    try:
                        for expression in parse(file.read(), default_dialect=dialect):
                            metric = load_metric_ddl(expression, path=path, dialect=dialect)
                            metrics[metric.name] = metric
                    except SqlglotError as ex:
                        raise ConfigError(f"Failed to parse metric definitions at '{path}': {ex}.")

        return metrics

    def _glob_paths(
        self, path: Path, config: Config, extension: str
    ) -> t.Generator[Path, None, None]:
        """
        Globs the provided path for the file extension but also removes any filepaths that
        match one of the ``config.ignore_patterns``.

        Args:
            path: The filepath to glob
            config: The project config whose ``ignore_patterns`` filter the matches
            extension: The extension to check for in that path (checks recursively in zero or more subdirectories)

        Yields:
            Matched paths that are not ignored
        """
        for filepath in path.glob(f"**/*{extension}"):
            if not any(filepath.match(pattern) for pattern in config.ignore_patterns):
                yield filepath

    def _variables(self, config: Config) -> t.Dict[str, t.Any]:
        """Builds the variables passed to a project's models and audits.

        Config variables are overridden by the selected gateway's variables, and the gateway
        name is stored under ``c.GATEWAY``. An unknown gateway is logged as a warning and
        contributes no variables.
        """
        gateway_name = self._context.gateway or self._context.config.default_gateway_name
        try:
            gateway = config.get_gateway(gateway_name)
        except ConfigError:
            logger.warning("Gateway '%s' not found in project '%s'", gateway_name, config.project)
            gateway = None
        return {
            **config.variables,
            **(gateway.variables if gateway else {}),
            c.GATEWAY: gateway_name,
        }

    class _Cache:
        """Per-project model cache stored under ``context_path / c.CACHE``."""

        def __init__(self, loader: SqlMeshLoader, context_path: Path):
            self._loader = loader
            self._context_path = context_path
            self._model_cache = ModelCache(self._context_path / c.CACHE)

        def get_or_load_model(self, target_path: Path, loader: t.Callable[[], Model]) -> Model:
            """Fetches the model for ``target_path`` from the cache, falling back to ``loader``
            (presumably on a miss or stale entry), and stamps the model with ``target_path``."""
            model = self._model_cache.get_or_load(
                self._cache_entry_name(target_path),
                self._model_cache_entry_id(target_path),
                loader=loader,
            )
            model._path = target_path
            return model

        def _cache_entry_name(self, target_path: Path) -> str:
            """Builds the entry name from the path relative to the project, with path parts
            joined by ``__`` and the file suffix removed."""
            # Strip only the trailing suffix. A plain ``str.replace`` would also remove the
            # suffix text from directory names, so ``a.sql_x/b.sql`` and ``a_x/b.sql`` would
            # share one cache entry.
            relative = target_path.relative_to(self._context_path)
            return "__".join(relative.with_suffix("").parts)

        def _model_cache_entry_id(self, model_path: Path) -> str:
            """Builds the entry id that invalidates cached models when an input changes."""
            mtimes = [
                self._loader._path_mtimes[model_path],
                self._loader._macros_max_mtime,
                self._loader._config_mtimes.get(self._context_path),
                self._loader._config_mtimes.get(c.SQLMESH_PATH),
            ]
            return "__".join(
                [
                    str(max(m for m in mtimes if m is not None)),
                    self._loader._context.config.fingerprint,
                    # We need to check default catalog since the provided config could not change but the
                    # gateway we are using could change, therefore potentially changing the default catalog
                    # which would then invalidate the cached model definition.
                    self._loader._context.default_catalog or "",
                ]
            )
def update_model_schemas(
    dag: DAG[str],
    models: UniqueKeyDict[str, Model],
    context_path: Path,
) -> None:
    """Resolve each model's schema against the models processed before it.

    Names are visited in ``dag.sorted`` order (presumably topological, so upstream models
    come first). For every name that is a loaded model, the model's schema is updated from
    the accumulated ``MappingSchema``, ``with_optimized_query`` is applied through the
    ``OptimizedQueryCache`` rooted at ``context_path / c.CACHE``, and the model's resolved
    columns (when known) are added to the accumulated schema for downstream models.

    Args:
        dag: Dependency graph over model names.
        models: Loaded models keyed by name; DAG names missing from it are skipped.
        context_path: Project root directory; the optimized-query cache lives beneath it.

    Raises:
        SchemaError: Always re-raised; an extra hint is logged first when the message
            mentions a nesting level mismatch.
    """
    accumulated = MappingSchema(normalize=False)
    query_cache = OptimizedQueryCache(context_path / c.CACHE)

    for fqn in dag.sorted:
        current = models.get(fqn)
        # DAG names that were not loaded as models here (e.g. external dependencies)
        # have no definition to process.
        if not current:
            continue

        try:
            current.update_schema(accumulated)
            query_cache.with_optimized_query(current)

            resolved = current.columns_to_types
            if resolved is None:
                continue
            accumulated.add_table(
                current.fqn, resolved, dialect=current.dialect, normalize=False
            )
        except SchemaError as err:
            if "nesting level:" in str(err):
                logger.error(
                    "SQLMesh requires all model names and references to have the same level of nesting."
                )
            raise
class LoadedProject:
class LoadedProject:
    """Everything produced by a single ``Loader.load`` call."""

    # Macro registry returned by the loader's ``_load_scripts``.
    macros: MacroRegistry
    # Jinja macros extracted from the project's SQL macro files.
    jinja_macros: JinjaMacroRegistry
    # Loaded models keyed by fully-qualified name.
    models: UniqueKeyDict[str, Model]
    # Audits keyed by audit name.
    audits: UniqueKeyDict[str, Audit]
    # Metrics after expansion.
    metrics: UniqueKeyDict[str, Metric]
    # Dependency graph over model names.
    dag: DAG[str]
class Loader(abc.ABC):
 88class Loader(abc.ABC):
 89    """Abstract base class to load macros and models for a context"""
 90
 91    def __init__(self) -> None:
 92        self._path_mtimes: t.Dict[Path, float] = {}
 93        self._dag: DAG[str] = DAG()
 94
 95    def load(self, context: GenericContext, update_schemas: bool = True) -> LoadedProject:
 96        """
 97        Loads all macros and models in the context's path.
 98
 99        Args:
100            context: The context to load macros and models for.
101            update_schemas: Convert star projections to explicit columns.
102        """
103        # python files are cached by the system
104        # need to manually clear here so we can reload macros
105        linecache.clearcache()
106
107        self._context = context
108        self._path_mtimes.clear()
109        self._dag = DAG()
110
111        config_mtimes: t.Dict[Path, t.List[float]] = defaultdict(list)
112        for context_path, config in self._context.configs.items():
113            for config_file in context_path.glob("config.*"):
114                self._track_file(config_file)
115                config_mtimes[context_path].append(self._path_mtimes[config_file])
116
117        for config_file in c.SQLMESH_PATH.glob("config.*"):
118            self._track_file(config_file)
119            config_mtimes[c.SQLMESH_PATH].append(self._path_mtimes[config_file])
120
121        self._config_mtimes = {path: max(mtimes) for path, mtimes in config_mtimes.items()}
122
123        macros, jinja_macros = self._load_scripts()
124        models = self._load_models(macros, jinja_macros)
125
126        for model in models.values():
127            self._add_model_to_dag(model)
128
129        if update_schemas:
130            update_model_schemas(
131                self._dag,
132                models,
133                self._context.path,
134            )
135            for model in models.values():
136                # The model definition can be validated correctly only after the schema is set.
137                model.validate_definition()
138
139        metrics = self._load_metrics()
140
141        project = LoadedProject(
142            macros=macros,
143            jinja_macros=jinja_macros,
144            models=models,
145            audits=self._load_audits(macros=macros, jinja_macros=jinja_macros),
146            metrics=expand_metrics(metrics),
147            dag=self._dag,
148        )
149        return project
150
151    def reload_needed(self) -> bool:
152        """
153        Checks for any modifications to the files the macros and models depend on
154        since the last load.
155
156        Returns:
157            True if a modification is found; False otherwise
158        """
159        return any(
160            not path.exists() or path.stat().st_mtime > initial_mtime
161            for path, initial_mtime in self._path_mtimes.items()
162        )
163
164    @abc.abstractmethod
165    def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
166        """Loads all user defined macros."""
167
168    @abc.abstractmethod
169    def _load_models(
170        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
171    ) -> UniqueKeyDict[str, Model]:
172        """Loads all models."""
173
174    @abc.abstractmethod
175    def _load_audits(
176        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
177    ) -> UniqueKeyDict[str, Audit]:
178        """Loads all audits."""
179
180    def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]:
181        return UniqueKeyDict("metrics")
182
183    def _load_external_models(self) -> UniqueKeyDict[str, Model]:
184        models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
185        for context_path, config in self._context.configs.items():
186            schema_path = Path(context_path / c.SCHEMA_YAML)
187            external_models_path = context_path / c.EXTERNAL_MODELS
188
189            paths_to_load = []
190            if schema_path.exists():
191                paths_to_load.append(schema_path)
192
193            if external_models_path.exists() and external_models_path.is_dir():
194                paths_to_load.extend(external_models_path.glob("*.yaml"))
195
196            for path in paths_to_load:
197                self._track_file(path)
198
199                with open(path, "r", encoding="utf-8") as file:
200                    for row in YAML().load(file.read()):
201                        model = create_external_model(
202                            **row,
203                            dialect=config.model_defaults.dialect,
204                            path=path,
205                            project=config.project,
206                            default_catalog=self._context.default_catalog,
207                        )
208                        models[model.fqn] = model
209        return models
210
211    def _add_model_to_dag(self, model: Model) -> None:
212        self._dag.add(model.fqn, model.depends_on)
213
214    def _track_file(self, path: Path) -> None:
215        """Project file to track for modifications"""
216        self._path_mtimes[path] = path.stat().st_mtime

Abstract base class to load macros and models for a context

def load( self, context: sqlmesh.core.context.GenericContext, update_schemas: bool = True) -> sqlmesh.core.loader.LoadedProject:
 95    def load(self, context: GenericContext, update_schemas: bool = True) -> LoadedProject:
 96        """
 97        Loads all macros and models in the context's path.
 98
 99        Args:
100            context: The context to load macros and models for.
101            update_schemas: Convert star projections to explicit columns.
102        """
103        # python files are cached by the system
104        # need to manually clear here so we can reload macros
105        linecache.clearcache()
106
107        self._context = context
108        self._path_mtimes.clear()
109        self._dag = DAG()
110
111        config_mtimes: t.Dict[Path, t.List[float]] = defaultdict(list)
112        for context_path, config in self._context.configs.items():
113            for config_file in context_path.glob("config.*"):
114                self._track_file(config_file)
115                config_mtimes[context_path].append(self._path_mtimes[config_file])
116
117        for config_file in c.SQLMESH_PATH.glob("config.*"):
118            self._track_file(config_file)
119            config_mtimes[c.SQLMESH_PATH].append(self._path_mtimes[config_file])
120
121        self._config_mtimes = {path: max(mtimes) for path, mtimes in config_mtimes.items()}
122
123        macros, jinja_macros = self._load_scripts()
124        models = self._load_models(macros, jinja_macros)
125
126        for model in models.values():
127            self._add_model_to_dag(model)
128
129        if update_schemas:
130            update_model_schemas(
131                self._dag,
132                models,
133                self._context.path,
134            )
135            for model in models.values():
136                # The model definition can be validated correctly only after the schema is set.
137                model.validate_definition()
138
139        metrics = self._load_metrics()
140
141        project = LoadedProject(
142            macros=macros,
143            jinja_macros=jinja_macros,
144            models=models,
145            audits=self._load_audits(macros=macros, jinja_macros=jinja_macros),
146            metrics=expand_metrics(metrics),
147            dag=self._dag,
148        )
149        return project

Loads all macros and models in the context's path.

Arguments:
  • context: The context to load macros and models for.
  • update_schemas: Convert star projections to explicit columns.
def reload_needed(self) -> bool:
151    def reload_needed(self) -> bool:
152        """
153        Checks for any modifications to the files the macros and models depend on
154        since the last load.
155
156        Returns:
157            True if a modification is found; False otherwise
158        """
159        return any(
160            not path.exists() or path.stat().st_mtime > initial_mtime
161            for path, initial_mtime in self._path_mtimes.items()
162        )

Checks for any modifications to the files the macros and models depend on since the last load.

Returns:

True if a modification is found; False otherwise

class SqlMeshLoader(Loader):
class SqlMeshLoader(Loader):
    """Loads macros and models for a context using the SQLMesh file formats"""

    def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
        """Loads all user defined macros.

        Python files in each project's macros directory are imported so the macros they
        define get registered. SQL files in that directory are scanned for Jinja macro
        definitions. The newest modification time among the tracked macro files is stored
        in ``self._macros_max_mtime``. It feeds into the model cache entry id, so cached
        models are invalidated when a macro file changes.

        Returns:
            A tuple of the Python macro registry (including user macros) and the Jinja
            macro registry.
        """
        # Store a copy of the macro registry so the global registry can be restored
        # once the user macro files have been imported.
        standard_macros = macro.get_registry()
        jinja_macros = JinjaMacroRegistry()
        extractor = MacroExtractor()

        macros_max_mtime: t.Optional[float] = None

        def _track_macro_file(path: Path) -> None:
            """Tracks a macro file and folds its mtime into ``macros_max_mtime``."""
            nonlocal macros_max_mtime
            self._track_file(path)
            file_mtime = self._path_mtimes[path]
            # Compare against None explicitly: an mtime of 0.0 is still a valid value.
            if macros_max_mtime is None or file_mtime > macros_max_mtime:
                macros_max_mtime = file_mtime

        try:
            for context_path, config in self._context.configs.items():
                macros_path = context_path / c.MACROS
                for path in self._glob_paths(macros_path, config=config, extension=".py"):
                    # Only files for which import_python_file returns a truthy value are tracked.
                    if import_python_file(path, context_path):
                        _track_macro_file(path)

                for path in self._glob_paths(macros_path, config=config, extension=".sql"):
                    _track_macro_file(path)
                    with open(path, "r", encoding="utf-8") as file:
                        jinja_macros.add_macros(extractor.extract(file.read()))

            macros = macro.get_registry()
        finally:
            # Restore the standard registry even when importing a user file fails, so user
            # macros never leak into the global registry used by later loads.
            macro.set_registry(standard_macros)

        self._macros_max_mtime = macros_max_mtime
        return macros, jinja_macros

    def _load_models(
        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
    ) -> UniqueKeyDict[str, Model]:
        """
        Loads all of the models within the model directory with their associated
        audits into a Dict and creates the dag
        """
        models = self._load_sql_models(macros, jinja_macros)
        models.update(self._load_external_models())
        models.update(self._load_python_models())

        return models

    def _load_sql_models(
        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
    ) -> UniqueKeyDict[str, Model]:
        """Loads the SQL models into a Dict keyed by model FQN.

        Every ``.sql`` file under each project's models directory is tracked. Empty files
        are then skipped. They are still tracked so that adding content to them later is
        detected by ``reload_needed``. Model definitions are served from the per-project
        model cache when its entry is still valid.

        Raises:
            ConfigError: If a model file cannot be parsed.
        """
        models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
        for context_path, config in self._context.configs.items():
            cache = SqlMeshLoader._Cache(self, context_path)
            variables = self._variables(config)

            for path in self._glob_paths(context_path / c.MODELS, config=config, extension=".sql"):
                self._track_file(path)
                if not os.path.getsize(path):
                    continue

                # Invoked synchronously by the cache within this iteration, so capturing the
                # loop variables `path` and `config` in the closure is safe.
                def _load() -> Model:
                    with open(path, "r", encoding="utf-8") as file:
                        source = file.read()
                    try:
                        expressions = parse(source, default_dialect=config.model_defaults.dialect)
                    except SqlglotError as ex:
                        raise ConfigError(
                            f"Failed to parse a model definition at '{path}': {ex}."
                        ) from ex

                    return load_sql_based_model(
                        expressions,
                        defaults=config.model_defaults.dict(),
                        macros=macros,
                        jinja_macros=jinja_macros,
                        path=Path(path).absolute(),
                        module_path=context_path,
                        dialect=config.model_defaults.dialect,
                        time_column_format=config.time_column_format,
                        physical_schema_override=config.physical_schema_override,
                        project=config.project,
                        default_catalog=self._context.default_catalog,
                        variables=variables,
                    )

                model = cache.get_or_load_model(path, _load)
                models[model.fqn] = model

                # Seed models depend on their seed file too, so track it for reload checks.
                if isinstance(model, SeedModel):
                    self._track_file(model.seed_path)

        return models

    def _load_python_models(self) -> UniqueKeyDict[str, Model]:
        """Loads the Python models into a Dict keyed by model FQN.

        Each ``.py`` file under a project's models directory is imported. Entries that
        appear in the model registry after the import are turned into models using that
        project's config. Empty files are tracked, so later edits are detected, but are
        not imported.
        """
        models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
        registry = model_registry.registry()
        registry.clear()
        registered: t.Set[str] = set()

        for context_path, config in self._context.configs.items():
            variables = self._variables(config)
            # The decorator registry reads this dialect while the project's files are imported.
            model_registry._dialect = config.model_defaults.dialect
            try:
                for path in self._glob_paths(
                    context_path / c.MODELS, config=config, extension=".py"
                ):
                    self._track_file(path)
                    if not os.path.getsize(path):
                        continue

                    import_python_file(path, context_path)
                    # Only the registry entries added by this file belong to it.
                    new = registry.keys() - registered
                    registered |= new
                    for name in new:
                        model = registry[name].model(
                            path=path,
                            module_path=context_path,
                            defaults=config.model_defaults.dict(),
                            dialect=config.model_defaults.dialect,
                            time_column_format=config.time_column_format,
                            physical_schema_override=config.physical_schema_override,
                            project=config.project,
                            default_catalog=self._context.default_catalog,
                            variables=variables,
                        )
                        models[model.fqn] = model
            finally:
                model_registry._dialect = None

        return models

    def _load_audits(
        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
    ) -> UniqueKeyDict[str, Audit]:
        """Loads all the model audits into a Dict keyed by audit name.

        A single ``.sql`` file may define several audits.
        """
        audits_by_name: UniqueKeyDict[str, Audit] = UniqueKeyDict("audits")
        for context_path, config in self._context.configs.items():
            variables = self._variables(config)
            for path in self._glob_paths(context_path / c.AUDITS, config=config, extension=".sql"):
                self._track_file(path)
                with open(path, "r", encoding="utf-8") as file:
                    expressions = parse(file.read(), default_dialect=config.model_defaults.dialect)
                audits = load_multiple_audits(
                    expressions=expressions,
                    path=path,
                    module_path=context_path,
                    macros=macros,
                    jinja_macros=jinja_macros,
                    dialect=config.model_defaults.dialect,
                    default_catalog=self._context.default_catalog,
                    variables=variables,
                )
                for audit in audits:
                    audits_by_name[audit.name] = audit
        return audits_by_name

    def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]:
        """Loads all metrics into a Dict keyed by metric name.

        Empty files are tracked, so later edits are detected, but are otherwise skipped.

        Raises:
            ConfigError: If a metrics file cannot be parsed.
        """
        metrics: UniqueKeyDict[str, MetricMeta] = UniqueKeyDict("metrics")

        for context_path, config in self._context.configs.items():
            for path in self._glob_paths(context_path / c.METRICS, config=config, extension=".sql"):
                self._track_file(path)
                if not os.path.getsize(path):
                    continue

                with open(path, "r", encoding="utf-8") as file:
                    dialect = config.model_defaults.dialect
                    try:
                        for expression in parse(file.read(), default_dialect=dialect):
                            metric = load_metric_ddl(expression, path=path, dialect=dialect)
                            metrics[metric.name] = metric
                    except SqlglotError as ex:
                        raise ConfigError(
                            f"Failed to parse metric definitions at '{path}': {ex}."
                        ) from ex

        return metrics

    def _glob_paths(
        self, path: Path, config: Config, extension: str
    ) -> t.Generator[Path, None, None]:
        """
        Globs the provided path for the file extension but also removes any filepaths that match an ignore
        pattern either set in constants or provided in config

        Args:
            path: The filepath to glob
            config: The project config whose ``ignore_patterns`` are applied
            extension: The extension to check for in that path (checks recursively in zero or more subdirectories)

        Returns:
            Matched paths that are not ignored
        """
        for filepath in path.glob(f"**/*{extension}"):
            if not any(filepath.match(pattern) for pattern in config.ignore_patterns):
                yield filepath

    def _variables(self, config: Config) -> t.Dict[str, t.Any]:
        """Builds the variables exposed to models and audits of one project.

        Project-level variables are overridden by the selected gateway's variables. The
        gateway name itself is added under ``c.GATEWAY``. The gateway is the context's
        explicit gateway, or else the default gateway of the context's config. If the
        gateway is not defined in ``config``, a warning is logged and only the project
        variables are used.
        """
        gateway_name = self._context.gateway or self._context.config.default_gateway_name
        try:
            gateway = config.get_gateway(gateway_name)
        except ConfigError:
            logger.warning("Gateway '%s' not found in project '%s'", gateway_name, config.project)
            gateway = None
        return {
            **config.variables,
            **(gateway.variables if gateway else {}),
            c.GATEWAY: gateway_name,
        }

    class _Cache:
        """Model cache for a single project path.

        Entries are keyed by the model file's path relative to the project. They are
        invalidated when the model file, any macro file, or the project or global config
        changes, as well as when the config fingerprint or the default catalog changes.
        """

        def __init__(self, loader: SqlMeshLoader, context_path: Path):
            self._loader = loader
            self._context_path = context_path
            self._model_cache = ModelCache(self._context_path / c.CACHE)

        def get_or_load_model(self, target_path: Path, loader: t.Callable[[], Model]) -> Model:
            """Returns the cached model for ``target_path``, calling ``loader`` on a miss."""
            model = self._model_cache.get_or_load(
                self._cache_entry_name(target_path),
                self._model_cache_entry_id(target_path),
                loader=loader,
            )
            # Cached models may have been loaded from a different location, so refresh the path.
            model._path = target_path
            return model

        def _cache_entry_name(self, target_path: Path) -> str:
            # Strip only the file's own suffix. A str.replace on the joined string would also
            # strip matching substrings in directory names, so that e.g. "v1.sql/x.sql" and
            # "v1/x.sql" would both collapse to "v1__x" and collide in the cache.
            relative_path = target_path.relative_to(self._context_path)
            return "__".join(relative_path.with_suffix("").parts)

        def _model_cache_entry_id(self, model_path: Path) -> str:
            mtimes = [
                self._loader._path_mtimes[model_path],
                self._loader._macros_max_mtime,
                self._loader._config_mtimes.get(self._context_path),
                self._loader._config_mtimes.get(c.SQLMESH_PATH),
            ]
            return "__".join(
                [
                    str(max(m for m in mtimes if m is not None)),
                    self._loader._context.config.fingerprint,
                    # We need to check default catalog since the provided config could not change but the
                    # gateway we are using could change, therefore potentially changing the default catalog
                    # which would then invalidate the cached model definition.
                    self._loader._context.default_catalog or "",
                ]
            )

Loads macros and models for a context using the SQLMesh file formats

Inherited Members
Loader
load
reload_needed