  1from __future__ import annotations
  3import abc
  4import linecache
  5import logging
  6import os
  7import typing as t
  8from collections import defaultdict
  9from dataclasses import dataclass
 10from pathlib import Path
 12from sqlglot.errors import SchemaError, SqlglotError
 13from sqlglot.schema import MappingSchema
 15from sqlmesh.core import constants as c
 16from sqlmesh.core.audit import Audit, load_multiple_audits
 17from sqlmesh.core.dialect import parse
 18from sqlmesh.core.macros import MacroRegistry, macro
 19from sqlmesh.core.metric import Metric, MetricMeta, expand_metrics, load_metric_ddl
 20from sqlmesh.core.model import (
 21    Model,
 22    ModelCache,
 23    OptimizedQueryCache,
 24    SeedModel,
 25    create_external_model,
 26    load_sql_based_model,
 28from sqlmesh.core.model import model as model_registry
 29from sqlmesh.utils import UniqueKeyDict
 30from sqlmesh.utils.dag import DAG
 31from sqlmesh.utils.errors import ConfigError
 32from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor
 33from sqlmesh.utils.metaprogramming import import_python_file
 34from sqlmesh.utils.yaml import YAML
 37    from sqlmesh.core.config import Config
 38    from sqlmesh.core.context import GenericContext
 41logger = logging.getLogger(__name__)
 44# TODO: consider moving this to context
 45def update_model_schemas(
 46    dag: DAG[str],
 47    models: UniqueKeyDict[str, Model],
 48    context_path: Path,
 49) -> None:
 50    schema = MappingSchema(normalize=False)
 51    optimized_query_cache: OptimizedQueryCache = OptimizedQueryCache(context_path / c.CACHE)
 53    for name in dag.sorted:
 54        model = models.get(name)
 56        # External models don't exist in the context, so we need to skip them
 57        if not model:
 58            continue
 60        try:
 61            model.update_schema(schema)
 62            optimized_query_cache.with_optimized_query(model)
 64            columns_to_types = model.columns_to_types
 65            if columns_to_types is not None:
 66                schema.add_table(
 67                    model.fqn, columns_to_types, dialect=model.dialect, normalize=False
 68                )
 69        except SchemaError as e:
 70            if "nesting level:" in str(e):
 71                logger.error(
 72                    "SQLMesh requires all model names and references to have the same level of nesting."
 73                )
 74            raise
 78class LoadedProject:
 79    macros: MacroRegistry
 80    jinja_macros: JinjaMacroRegistry
 81    models: UniqueKeyDict[str, Model]
 82    audits: UniqueKeyDict[str, Audit]
 83    metrics: UniqueKeyDict[str, Metric]
 84    dag: DAG[str]
 87class Loader(abc.ABC):
 88    """Abstract base class to load macros and models for a context"""
 90    def __init__(self) -> None:
 91        self._path_mtimes: t.Dict[Path, float] = {}
 92        self._dag: DAG[str] = DAG()
 94    def load(self, context: GenericContext, update_schemas: bool = True) -> LoadedProject:
 95        """
 96        Loads all macros and models in the context's path.
 98        Args:
 99            context: The context to load macros and models for.
100            update_schemas: Convert star projections to explicit columns.
101        """
102        # python files are cached by the system
103        # need to manually clear here so we can reload macros
104        linecache.clearcache()
106        self._context = context
107        self._path_mtimes.clear()
108        self._dag = DAG()
110        config_mtimes: t.Dict[Path, t.List[float]] = defaultdict(list)
111        for context_path, config in self._context.configs.items():
112            for config_file in context_path.glob("config.*"):
113                self._track_file(config_file)
114                config_mtimes[context_path].append(self._path_mtimes[config_file])
116        for config_file in c.SQLMESH_PATH.glob("config.*"):
117            self._track_file(config_file)
118            config_mtimes[c.SQLMESH_PATH].append(self._path_mtimes[config_file])
120        self._config_mtimes = {path: max(mtimes) for path, mtimes in config_mtimes.items()}
122        macros, jinja_macros = self._load_scripts()
123        models = self._load_models(macros, jinja_macros)
125        for model in models.values():
126            self._add_model_to_dag(model)
128        if update_schemas:
129            update_model_schemas(
130                self._dag,
131                models,
132                self._context.path,
133            )
134            for model in models.values():
135                # The model definition can be validated correctly only after the schema is set.
136                model.validate_definition()
138        metrics = self._load_metrics()
140        project = LoadedProject(
141            macros=macros,
142            jinja_macros=jinja_macros,
143            models=models,
144            audits=self._load_audits(macros=macros, jinja_macros=jinja_macros),
145            metrics=expand_metrics(metrics),
146            dag=self._dag,
147        )
148        return project
150    def reload_needed(self) -> bool:
151        """
152        Checks for any modifications to the files the macros and models depend on
153        since the last load.
155        Returns:
156            True if a modification is found; False otherwise
157        """
158        return any(
159            not path.exists() or path.stat().st_mtime > initial_mtime
160            for path, initial_mtime in self._path_mtimes.items()
161        )
163    @abc.abstractmethod
164    def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
165        """Loads all user defined macros."""
167    @abc.abstractmethod
168    def _load_models(
169        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
170    ) -> UniqueKeyDict[str, Model]:
171        """Loads all models."""
173    @abc.abstractmethod
174    def _load_audits(
175        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
176    ) -> UniqueKeyDict[str, Audit]:
177        """Loads all audits."""
179    def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]:
180        return UniqueKeyDict("metrics")
182    def _load_external_models(self) -> UniqueKeyDict[str, Model]:
183        models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
184        for context_path, config in self._context.configs.items():
185            schema_path = Path(context_path / c.SCHEMA_YAML)
186            external_models_path = context_path / c.EXTERNAL_MODELS
188            paths_to_load = []
189            if schema_path.exists():
190                paths_to_load.append(schema_path)
192            if external_models_path.exists() and external_models_path.is_dir():
193                paths_to_load.extend(external_models_path.glob("*.yaml"))
195            for path in paths_to_load:
196                self._track_file(path)
198                with open(path, "r", encoding="utf-8") as file:
199                    for row in YAML().load(file.read()):
200                        model = create_external_model(
201                            **row,
202                            dialect=config.model_defaults.dialect,
203                            path=path,
204                            project=config.project,
205                            default_catalog=self._context.default_catalog,
206                        )
207                        models[model.fqn] = model
208        return models
210    def _add_model_to_dag(self, model: Model) -> None:
211        self._dag.add(model.fqn, model.depends_on)
213    def _track_file(self, path: Path) -> None:
214        """Project file to track for modifications"""
215        self._path_mtimes[path] = path.stat().st_mtime
class SqlMeshLoader(Loader):
    """Loads macros and models for a context using the SQLMesh file formats"""

    def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
        """
        Loads all user defined macros.

        Python files in each project's macros directory are imported and SQL files in the
        same directory are scanned for Jinja macros. The most recent modification time
        among the loaded macro files is stored for cache invalidation.

        Returns:
            The macro registry as it stands after loading, and the collected Jinja macros.
        """
        # Store a copy of the macro registry
        standard_macros = macro.get_registry()
        jinja_macros = JinjaMacroRegistry()
        extractor = MacroExtractor()

        macro_mtimes: t.List[float] = []

        try:
            for context_path, config in self._context.configs.items():
                for path in self._glob_paths(
                    context_path / c.MACROS, config=config, extension=".py"
                ):
                    if import_python_file(path, context_path):
                        self._track_file(path)
                        macro_mtimes.append(self._path_mtimes[path])

                for path in self._glob_paths(
                    context_path / c.MACROS, config=config, extension=".sql"
                ):
                    self._track_file(path)
                    macro_mtimes.append(self._path_mtimes[path])
                    with open(path, "r", encoding="utf-8") as file:
                        jinja_macros.add_macros(extractor.extract(file.read()))

            # None when the projects contain no macro files at all.
            self._macros_max_mtime: t.Optional[float] = (
                max(macro_mtimes) if macro_mtimes else None
            )
            macros = macro.get_registry()
        finally:
            # Always put the previously stored registry back, even when a macro file fails
            # to load, so a failed load does not leave the global registry modified.
            macro.set_registry(standard_macros)

        return macros, jinja_macros

    def _load_models(
        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
    ) -> UniqueKeyDict[str, Model]:
        """
        Loads all of the models within the model directory: SQL models first, then
        external models, then Python models.
        """
        models = self._load_sql_models(macros, jinja_macros)
        models.update(self._load_external_models())
        models.update(self._load_python_models())

        return models

    def _load_sql_models(
        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
    ) -> UniqueKeyDict[str, Model]:
        """
        Loads the sql models into a Dict

        Args:
            macros: The macro registry handed to every model.
            jinja_macros: The Jinja macro registry handed to every model.

        Returns:
            The SQL models keyed by their fully qualified name.

        Raises:
            ConfigError: If a model file cannot be parsed.
        """
        models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
        for context_path, config in self._context.configs.items():
            cache = SqlMeshLoader._Cache(self, context_path)
            variables = self._variables(config)

            for path in self._glob_paths(context_path / c.MODELS, config=config, extension=".sql"):
                # Empty files are skipped entirely and are not tracked.
                if not os.path.getsize(path):
                    continue

                self._track_file(path)

                # NOTE(review): the closure reads the loop variables when called; this relies
                # on the cache invoking it (if at all) before the next iteration.
                def _load() -> Model:
                    with open(path, "r", encoding="utf-8") as file:
                        try:
                            expressions = parse(
                                file.read(), default_dialect=config.model_defaults.dialect
                            )
                        except SqlglotError as ex:
                            raise ConfigError(
                                f"Failed to parse a model definition at '{path}': {ex}."
                            ) from ex

                    return load_sql_based_model(
                        expressions,
                        defaults=config.model_defaults.dict(),
                        macros=macros,
                        jinja_macros=jinja_macros,
                        path=Path(path).absolute(),
                        module_path=context_path,
                        dialect=config.model_defaults.dialect,
                        time_column_format=config.time_column_format,
                        physical_schema_override=config.physical_schema_override,
                        project=config.project,
                        default_catalog=self._context.default_catalog,
                        variables=variables,
                    )

                model = cache.get_or_load_model(path, _load)
                models[model.fqn] = model

                if isinstance(model, SeedModel):
                    # Changes to the seed file must trigger a reload as well.
                    seed_path = model.seed_path
                    self._track_file(seed_path)

        return models

    def _load_python_models(self) -> UniqueKeyDict[str, Model]:
        """
        Loads the python models into a Dict

        The model registry is cleared, then every non-empty Python file in each project's
        models directory is imported; entries that appear in the registry after an import
        are turned into models for that file.

        Returns:
            The Python models keyed by their fully qualified name.
        """
        models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
        registry = model_registry.registry()
        registry.clear()
        # Registry names that have already been turned into models.
        registered: t.Set[str] = set()

        for context_path, config in self._context.configs.items():
            variables = self._variables(config)
            # The project's default dialect is exposed on the registry while its files are
            # imported and reset afterwards (presumably consumed by the model decorator).
            model_registry._dialect = config.model_defaults.dialect
            try:
                for path in self._glob_paths(
                    context_path / c.MODELS, config=config, extension=".py"
                ):
                    if not os.path.getsize(path):
                        continue

                    self._track_file(path)
                    import_python_file(path, context_path)
                    # Whatever is new in the registry was registered by importing this file.
                    new = registry.keys() - registered
                    registered |= new
                    for name in new:
                        model = registry[name].model(
                            path=path,
                            module_path=context_path,
                            defaults=config.model_defaults.dict(),
                            dialect=config.model_defaults.dialect,
                            time_column_format=config.time_column_format,
                            physical_schema_override=config.physical_schema_override,
                            project=config.project,
                            default_catalog=self._context.default_catalog,
                            variables=variables,
                        )
                        models[model.fqn] = model
            finally:
                model_registry._dialect = None

        return models

    def _load_audits(
        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
    ) -> UniqueKeyDict[str, Audit]:
        """
        Loads all the model audits.

        Every SQL file in each project's audits directory is parsed and may define several
        audits. Parse errors are propagated unchanged.

        Returns:
            The audits keyed by name.
        """
        audits_by_name: UniqueKeyDict[str, Audit] = UniqueKeyDict("audits")
        for context_path, config in self._context.configs.items():
            variables = self._variables(config)
            for path in self._glob_paths(context_path / c.AUDITS, config=config, extension=".sql"):
                self._track_file(path)
                with open(path, "r", encoding="utf-8") as file:
                    expressions = parse(file.read(), default_dialect=config.model_defaults.dialect)
                    audits = load_multiple_audits(
                        expressions=expressions,
                        path=path,
                        module_path=context_path,
                        macros=macros,
                        jinja_macros=jinja_macros,
                        dialect=config.model_defaults.dialect,
                        default_catalog=self._context.default_catalog,
                        variables=variables,
                    )
                    for audit in audits:
                        audits_by_name[audit.name] = audit
        return audits_by_name

    def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]:
        """
        Loads all metrics.

        Returns:
            The metric definitions keyed by name.

        Raises:
            ConfigError: If a metrics file cannot be parsed.
        """
        metrics: UniqueKeyDict[str, MetricMeta] = UniqueKeyDict("metrics")

        for context_path, config in self._context.configs.items():
            for path in self._glob_paths(context_path / c.METRICS, config=config, extension=".sql"):
                if not os.path.getsize(path):
                    continue
                self._track_file(path)

                with open(path, "r", encoding="utf-8") as file:
                    dialect = config.model_defaults.dialect
                    try:
                        for expression in parse(file.read(), default_dialect=dialect):
                            metric = load_metric_ddl(expression, path=path, dialect=dialect)
                            metrics[metric.name] = metric
                    except SqlglotError as ex:
                        raise ConfigError(
                            f"Failed to parse metric definitions at '{path}': {ex}."
                        ) from ex

        return metrics

    def _glob_paths(
        self, path: Path, config: Config, extension: str
    ) -> t.Generator[Path, None, None]:
        """
        Globs the provided path for the file extension but also removes any filepaths that match
        one of the ignore patterns of the provided config

        Args:
            path: The filepath to glob
            config: The config whose ignore patterns are applied
            extension: The extension to check for in that path (checks recursively in zero or more subdirectories)

        Yields:
            Matched paths that are not ignored
        """
        for filepath in path.glob(f"**/*{extension}"):
            if not any(filepath.match(pattern) for pattern in config.ignore_patterns):
                yield filepath

    def _variables(self, config: Config) -> t.Dict[str, t.Any]:
        """
        Builds the variables available to the definitions of a project.

        Gateway variables take precedence over the config's variables, and the gateway name
        itself is always included. A gateway missing from the config is logged and ignored.

        Args:
            config: The project config to take the variables from.

        Returns:
            The merged variables.
        """
        gateway_name = self._context.gateway or self._context.config.default_gateway_name
        try:
            gateway = config.get_gateway(gateway_name)
        except ConfigError:
            logger.warning("Gateway '%s' not found in project '%s'", gateway_name, config.project)
            gateway = None
        return {
            **config.variables,
            **(gateway.variables if gateway else {}),
            c.GATEWAY: gateway_name,
        }

    class _Cache:
        """Model cache of a single project, keyed by file path and invalidated by mtimes."""

        def __init__(self, loader: SqlMeshLoader, context_path: Path):
            self._loader = loader
            self._context_path = context_path
            self._model_cache = ModelCache(self._context_path / c.CACHE)

        def get_or_load_model(self, target_path: Path, loader: t.Callable[[], Model]) -> Model:
            """
            Returns the model for the given file from the model cache, handing it ``loader``
            to produce the model when needed.

            Args:
                target_path: The model file; must already be tracked by the owning loader.
                loader: Callable that loads the model from the file.

            Returns:
                The model with its path set to ``target_path``.
            """
            model = self._model_cache.get_or_load(
                self._cache_entry_name(target_path),
                self._model_cache_entry_id(target_path),
                loader=loader,
            )
            model._path = target_path
            return model

        def _cache_entry_name(self, target_path: Path) -> str:
            """
            Builds the cache entry name: the path relative to the project, without the file
            extension, with its parts joined by a double underscore.
            """
            # Only the trailing extension is dropped. Removing the suffix text with
            # str.replace would also strip it from directory names or the file stem and
            # could make distinct files share one entry name.
            return "__".join(target_path.relative_to(self._context_path).with_suffix("").parts)

        def _model_cache_entry_id(self, model_path: Path) -> str:
            """
            Builds the cache entry id from the most recent relevant modification time, the
            config fingerprint and the default catalog.
            """
            # The model file's own mtime is always present; the others may be None.
            mtimes = [
                self._loader._path_mtimes[model_path],
                self._loader._macros_max_mtime,
                self._loader._config_mtimes.get(self._context_path),
                self._loader._config_mtimes.get(c.SQLMESH_PATH),
            ]
            return "__".join(
                [
                    str(max(m for m in mtimes if m is not None)),
                    self._loader._context.config.fingerprint,
                    # We need to check default catalog since the provided config could not change but the
                    # gateway we are using could change, therefore potentially changing the default catalog
                    # which would then invalidate the cached model definition.
                    self._loader._context.default_catalog or "",
                ]
            )
