Edit on GitHub

sqlmesh.dbt.loader

  1from __future__ import annotations
  2
  3import logging
  4import sys
  5import typing as t
  6import sqlmesh.core.dialect as d
  7from pathlib import Path
  8from collections import defaultdict
  9from sqlmesh.core.config import (
 10    Config,
 11    ConnectionConfig,
 12    GatewayConfig,
 13    ModelDefaultsConfig,
 14    DbtConfig as RootDbtConfig,
 15)
 16from sqlmesh.core.environment import EnvironmentStatements
 17from sqlmesh.core.loader import CacheBase, LoadedProject, Loader
 18from sqlmesh.core.macros import MacroRegistry, macro
 19from sqlmesh.core.model import Model, ModelCache
 20from sqlmesh.core.signal import signal
 21from sqlmesh.dbt.basemodel import BMC, BaseModelConfig
 22from sqlmesh.dbt.common import Dependencies
 23from sqlmesh.dbt.context import DbtContext
 24from sqlmesh.dbt.model import ModelConfig
 25from sqlmesh.dbt.profile import Profile
 26from sqlmesh.dbt.project import Project
 27from sqlmesh.dbt.target import TargetConfig
 28from sqlmesh.utils import UniqueKeyDict
 29from sqlmesh.utils.errors import ConfigError, MissingModelError, BaseMissingReferenceError
 30from sqlmesh.utils.jinja import (
 31    JinjaMacroRegistry,
 32    make_jinja_registry,
 33)
 34
 35if sys.version_info >= (3, 12):
 36    from importlib import metadata
 37else:
 38    import importlib_metadata as metadata  # type: ignore
 39
 40if t.TYPE_CHECKING:
 41    from sqlmesh.core.audit import Audit, ModelAudit
 42    from sqlmesh.core.context import GenericContext
 43
 44logger = logging.getLogger(__name__)
 45
 46
 47def sqlmesh_config(
 48    project_root: t.Optional[Path] = None,
 49    state_connection: t.Optional[ConnectionConfig] = None,
 50    dbt_profile_name: t.Optional[str] = None,
 51    dbt_target_name: t.Optional[str] = None,
 52    variables: t.Optional[t.Dict[str, t.Any]] = None,
 53    threads: t.Optional[int] = None,
 54    register_comments: t.Optional[bool] = None,
 55    infer_state_schema_name: bool = False,
 56    profiles_dir: t.Optional[Path] = None,
 57    **kwargs: t.Any,
 58) -> Config:
 59    project_root = project_root or Path()
 60    context = DbtContext(
 61        project_root=project_root, profiles_dir=profiles_dir, profile_name=dbt_profile_name
 62    )
 63
 64    # note: Profile.load() is called twice with different DbtContext's:
 65    # - once here with the above DbtContext (to determine connnection / gateway config which has to be set up before everything else)
 66    # - again on the SQLMesh side via GenericContext.load() -> DbtLoader._load_projects() -> Project.load() which constructs a fresh DbtContext and ignores the above one
 67    # it's important to ensure that the DbtContext created within the DbtLoader uses the same project root / profiles dir that we use here
 68    profile = Profile.load(context, target_name=dbt_target_name)
 69    model_defaults = kwargs.pop("model_defaults", ModelDefaultsConfig())
 70    if model_defaults.dialect is None:
 71        model_defaults.dialect = profile.target.dialect
 72
 73    target_to_sqlmesh_args = {
 74        "register_comments": register_comments or False,
 75    }
 76
 77    loader = kwargs.pop("loader", DbtLoader)
 78    if not issubclass(loader, DbtLoader):
 79        raise ConfigError("The loader must be a DbtLoader.")
 80
 81    if threads is not None:
 82        # the to_sqlmesh() function on TargetConfig maps self.threads -> concurrent_tasks
 83        profile.target.threads = threads
 84
 85    gateway_kwargs = {}
 86    if infer_state_schema_name:
 87        profile_name = context.profile_name
 88
 89        # Note: we deliberately isolate state based on the target *schema* and not the target name.
 90        # It is assumed that the project will define a target, eg 'dev', and then in each users own ~/.dbt/profiles.yml the schema
 91        # for the 'dev' target is overriden to something user-specific, rather than making the target name itself user-specific.
 92        # This means that the schema name is the indicator of isolated state, not the target name which may be re-used across multiple schemas.
 93        target_schema = profile.target.schema_
 94
 95        # dbt-core doesnt allow schema to be undefined, but it does allow an empty string, and then just
 96        # fails at runtime when `CREATE SCHEMA ""` doesnt work
 97        if not target_schema:
 98            raise ConfigError(
 99                f"Target '{profile.target_name}' does not specify a schema.\n"
100                "A schema is required in order to infer where to store SQLMesh state"
101            )
102
103        inferred_state_schema_name = f"sqlmesh_state_{profile_name}_{target_schema}"
104        logger.info("Inferring state schema: %s", inferred_state_schema_name)
105        gateway_kwargs["state_schema"] = inferred_state_schema_name
106
107    return Config(
108        loader=loader,
109        loader_kwargs=dict(profiles_dir=profiles_dir),
110        model_defaults=model_defaults,
111        variables=variables or {},
112        dbt=RootDbtConfig(infer_state_schema_name=infer_state_schema_name),
113        **{
114            "default_gateway": profile.target_name if "gateways" not in kwargs else "",
115            "gateways": {
116                profile.target_name: GatewayConfig(
117                    connection=profile.target.to_sqlmesh(**target_to_sqlmesh_args),
118                    state_connection=state_connection,
119                    **gateway_kwargs,
120                )
121            },  # type: ignore
122            **kwargs,
123        },
124    )
125
126
class DbtLoader(Loader):
    """Loader that builds SQLMesh models, audits and environment statements from a dbt project.

    The dbt project is parsed lazily by ``_load_projects`` and reused by the model, audit,
    requirement and environment-statement loaders until ``load()`` resets it.
    """

    def __init__(
        self, context: GenericContext, path: Path, profiles_dir: t.Optional[Path] = None
    ) -> None:
        # Parsed dbt projects; filled lazily by _load_projects() and cleared by load().
        self._projects: t.List[Project] = []
        # Latest mtime among tracked dbt macro files; part of every model cache entry id.
        self._macros_max_mtime: t.Optional[float] = None
        # Passed to the DbtContext built in _load_projects() to locate profiles.yml.
        self._profiles_dir = profiles_dir
        super().__init__(context, path)

    def load(self) -> LoadedProject:
        """Load the project, discarding any dbt projects parsed by a previous call."""
        self._projects = []
        return super().load()

    def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
        """Track dbt macro files for change detection and return empty macro registries."""
        macro_files = list(Path(self.config_path, "macros").glob("**/*.sql"))

        for file in macro_files:
            self._track_file(file)

        # This doesn't do anything, the actual content will be loaded from the manifest
        return (
            macro.get_registry(),
            JinjaMacroRegistry(),
        )

    def _load_models(
        self,
        macros: MacroRegistry,
        jinja_macros: JinjaMacroRegistry,
        gateway: t.Optional[str],
        audits: UniqueKeyDict[str, ModelAudit],
        signals: UniqueKeyDict[str, signal],
    ) -> UniqueKeyDict[str, Model]:
        """Convert every dbt model and seed, plus external models, into SQLMesh models.

        Conversions go through ``DbtLoader._Cache``, grouped by source file. dbt models
        (not seeds) with blank SQL are skipped. ``macros``, ``jinja_macros``, ``gateway``
        and ``signals`` belong to the base ``Loader`` interface and are unused here.
        """
        models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")

        def _to_sqlmesh(config: BMC, context: DbtContext) -> Model:
            logger.debug("Converting '%s' to sqlmesh format", config.canonical_name(context))
            return config.to_sqlmesh(
                context,
                audit_definitions=audits,
                virtual_environment_mode=self.config.virtual_environment_mode,
            )

        for project in self._load_projects():
            macros_max_mtime = self._macros_max_mtime
            yaml_max_mtimes = self._compute_yaml_max_mtime_per_subfolder(
                project.context.project_root
            )
            cache = DbtLoader._Cache(self, project, macros_max_mtime, yaml_max_mtimes)

            logger.debug("Converting models to sqlmesh")
            # Now that config is rendered, create the sqlmesh models
            for package in project.packages.values():
                package_context = project.context.copy()
                package_context.set_and_render_variables(package.variables, package.name)
                package_models: t.Dict[str, BaseModelConfig] = {**package.models, **package.seeds}

                # Group by source file, since cache entries are keyed per file.
                package_models_by_path: t.Dict[Path, t.List[BaseModelConfig]] = defaultdict(list)
                for model in package_models.values():
                    if isinstance(model, ModelConfig) and not model.sql.strip():
                        logger.info(
                            "Skipping empty model '%s' at path '%s'.", model.name, model.path
                        )
                        continue
                    package_models_by_path[model.path].append(model)

                for path, path_models in package_models_by_path.items():
                    sqlmesh_models = cache.get_or_load_models(
                        path,
                        loader=lambda: [
                            _to_sqlmesh(model, package_context) for model in path_models
                        ],
                    )
                    for sqlmesh_model in sqlmesh_models:
                        models[sqlmesh_model.fqn] = sqlmesh_model

            models.update(self._load_external_models(audits, cache))

        return models

    def _load_audits(
        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
    ) -> UniqueKeyDict[str, Audit]:
        """Convert dbt tests into SQLMesh audits keyed by the test's canonical name.

        Tests that reference a missing model or source are skipped with a warning.
        """
        audits: UniqueKeyDict = UniqueKeyDict("audits")

        for project in self._load_projects():
            logger.debug("Converting audits to sqlmesh")
            for package in project.packages.values():
                package_context = project.context.copy()
                package_context.set_and_render_variables(package.variables, package.name)
                for test in package.tests.values():
                    logger.debug("Converting '%s' to sqlmesh format", test.name)
                    try:
                        audits[test.canonical_name] = test.to_sqlmesh(package_context)

                    except BaseMissingReferenceError as e:
                        ref_type = "model" if isinstance(e, MissingModelError) else "source"
                        logger.warning(
                            "Skipping audit '%s' because %s '%s' is not a valid ref",
                            test.name,
                            ref_type,
                            e.ref,
                        )

        return audits

    def _load_projects(self) -> t.List[Project]:
        """Parse the dbt project (once) and register its nodes and macros on its context.

        The parsed project is cached in ``self._projects`` until ``load()`` clears it.
        Also tracks the project's files and records the latest tracked macro mtime.

        Raises:
            ConfigError: If the dbt target database differs from the context's default catalog.
        """
        if not self._projects:
            target_name = self.context.selected_gateway

            self._projects = []

            project = Project.load(
                DbtContext(
                    project_root=self.config_path,
                    profiles_dir=self._profiles_dir,
                    target_name=target_name,
                    sqlmesh_config=self.config,
                ),
                variables=self.config.variables,
            )

            self._projects.append(project)

            context_default_catalog = self.context.default_catalog or ""
            if project.context.target.database != context_default_catalog:
                raise ConfigError(
                    f"Project default catalog ('{project.context.target.database}') does not match context default catalog ('{context_default_catalog}')."
                )
            for path in project.project_files:
                self._track_file(path)

            context = project.context

            macros_mtimes: t.List[float] = []

            for package in project.packages.values():
                context.add_sources(package.sources)
                context.add_seeds(package.seeds)
                context.add_models(package.models)
                macros_mtimes.extend(
                    [
                        self._path_mtimes[m.path]
                        for m in package.macros.values()
                        if m.path in self._path_mtimes
                    ]
                )

            for package_name, macro_infos in context.manifest.all_macros.items():
                context.add_macros(macro_infos, package=package_name)

            self._macros_max_mtime = max(macros_mtimes) if macros_mtimes else None

        return self._projects

    def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]:
        """Add installed versions of ``dbt-core`` and each target's ``dbt-<type>`` adapter.

        Packages already listed or explicitly excluded are left alone; packages that are
        not installed produce a console warning instead of an entry.
        """
        requirements, excluded_requirements = super()._load_requirements()

        target_packages = ["dbt-core"]
        for project in self._load_projects():
            target_packages.append(f"dbt-{project.context.target.type}")

        for target_package in target_packages:
            if target_package in requirements or target_package in excluded_requirements:
                continue
            try:
                requirements[target_package] = metadata.version(target_package)
            except metadata.PackageNotFoundError:
                from sqlmesh.core.console import get_console

                get_console().log_warning(f"dbt package {target_package} is not installed.")

        return requirements, excluded_requirements

    def _load_environment_statements(self, macros: MacroRegistry) -> t.List[EnvironmentStatements]:
        """Loads dbt's on_run_start, on_run_end hooks into sqlmesh's before_all, after_all statements respectively."""

        hooks_by_package_name: t.Dict[str, EnvironmentStatements] = {}
        project_names: t.Set[str] = set()
        dialect = self.config.dialect
        for project in self._load_projects():
            for package_name, package in project.packages.items():
                package_context = project.context.copy()
                package_context.set_and_render_variables(package.variables, package_name)
                # Hooks are emitted in their declared order (by index).
                on_run_start: t.List[str] = [
                    on_run_hook.sql
                    for on_run_hook in sorted(package.on_run_start.values(), key=lambda h: h.index)
                ]
                on_run_end: t.List[str] = [
                    on_run_hook.sql
                    for on_run_hook in sorted(package.on_run_end.values(), key=lambda h: h.index)
                ]

                if on_run_start or on_run_end:
                    dependencies = Dependencies()
                    for hook in [*package.on_run_start.values(), *package.on_run_end.values()]:
                        dependencies = dependencies.union(hook.dependencies)

                    statements_context = package_context.context_for_dependencies(dependencies)
                    jinja_registry = make_jinja_registry(
                        statements_context.jinja_macros, package_name, set(dependencies.macros)
                    )
                    jinja_registry.add_globals(statements_context.jinja_globals)

                    hooks_by_package_name[package_name] = EnvironmentStatements(
                        before_all=[
                            d.jinja_statement(stmt).sql(dialect=dialect) for stmt in on_run_start
                        ],
                        after_all=[
                            d.jinja_statement(stmt).sql(dialect=dialect) for stmt in on_run_end
                        ],
                        python_env={},
                        jinja_macros=jinja_registry,
                        project=package_name,
                    )
                    project_names.add(package_name)

        # NOTE(review): every key of hooks_by_package_name is also added to project_names,
        # so this key is always 0 and the (stable) sort keeps package iteration order.
        # If the intent is to order the root project relative to dependency packages,
        # project_names probably should hold only the root project's name — confirm.
        return [
            statements
            for _, statements in sorted(
                hooks_by_package_name.items(),
                key=lambda item: 0 if item[0] in project_names else 1,
            )
        ]

    def _compute_yaml_max_mtime_per_subfolder(
        self, root: Path, visited: t.Optional[t.Set[Path]] = None
    ) -> t.Dict[Path, float]:
        """Map each resolved directory under ``root`` (inclusive) to its newest tracked YAML mtime.

        Only ``.yaml``/``.yml`` files directly inside a directory count toward that
        directory's entry; directories with no tracked YAML are omitted. Each resolved
        directory is visited at most once, so symlink cycles terminate. Entries that
        raise ``PermissionError`` are skipped.
        """
        root = root.resolve()
        # An explicit None check: a caller-supplied empty set must be shared, not replaced.
        if visited is None:
            visited = set()
        if not root.is_dir() or root in visited:
            return {}

        visited.add(root)

        result: t.Dict[Path, float] = {}
        max_mtime: t.Optional[float] = None

        for nested in root.iterdir():
            try:
                if nested.is_dir():
                    result.update(
                        self._compute_yaml_max_mtime_per_subfolder(nested, visited=visited)
                    )
                elif nested.suffix.lower() in (".yaml", ".yml"):
                    yaml_mtime = self._path_mtimes.get(nested)
                    if yaml_mtime:
                        max_mtime = (
                            max(max_mtime, yaml_mtime) if max_mtime is not None else yaml_mtime
                        )
            except PermissionError:
                pass

        if max_mtime is not None:
            result[root] = max_mtime

        return result

    class _Cache(CacheBase):
        """Per-project cache of converted models.

        Entries live under ``<context cache_dir>/<target name>``. They are named after the
        source file and identified by ``<max relevant mtime>__<config fingerprint>``.
        """

        MAX_ENTRY_NAME_LENGTH = 200

        def __init__(
            self,
            loader: DbtLoader,
            project: Project,
            macros_max_mtime: t.Optional[float],
            yaml_max_mtimes: t.Dict[Path, float],
        ):
            self._loader = loader
            self._project = project
            self._macros_max_mtime = macros_max_mtime
            self._yaml_max_mtimes = yaml_max_mtimes

            target = t.cast(TargetConfig, project.context.target)
            cache_dir = loader.context.cache_dir / target.name
            self._model_cache = ModelCache(cache_dir)

        def get_or_load_models(
            self, target_path: Path, loader: t.Callable[[], t.List[Model]]
        ) -> t.List[Model]:
            """Return the cached models for ``target_path``, building them via ``loader`` if needed."""
            models = self._model_cache.get_or_load(
                self._cache_entry_name(target_path),
                self._cache_entry_id(target_path),
                loader=loader,
            )
            # Point every returned model at the current source path (presumably so that
            # cached models don't carry a stale path — confirm).
            for model in models:
                model._path = target_path

            return models

        def put(self, models: t.List[Model], path: Path) -> bool:
            """Store ``models`` under the entry for ``path``."""
            return self._model_cache.put(
                models,
                self._cache_entry_name(path),
                self._cache_entry_id(path),
            )

        def get(self, path: Path) -> t.List[Model]:
            """Fetch the cached models for ``path``."""
            return self._model_cache.get(
                self._cache_entry_name(path),
                self._cache_entry_id(path),
            )

        def _cache_entry_name(self, target_path: Path) -> str:
            """Entry name: path parts relative to the project root joined by '__', suffix removed.

            Names longer than MAX_ENTRY_NAME_LENGTH keep only their trailing characters.
            """
            try:
                path_for_name = target_path.absolute().relative_to(
                    self._project.context.project_root.absolute()
                )
            except ValueError:
                # NOTE(review): for absolute paths outside the project root the first part
                # is the filesystem anchor (e.g. '/'), so the name can contain a separator.
                path_for_name = target_path
            name = "__".join(path_for_name.parts).replace(path_for_name.suffix, "")
            if len(name) > self.MAX_ENTRY_NAME_LENGTH:
                return name[len(name) - self.MAX_ENTRY_NAME_LENGTH :]
            return name

        def _cache_entry_id(self, target_path: Path) -> str:
            """Entry id: '<int max mtime or "na">__<config fingerprint>'."""
            max_mtime = self._max_mtime_for_path(target_path)
            return "__".join(
                [
                    str(int(max_mtime)) if max_mtime is not None else "na",
                    self._loader.config.fingerprint,
                ]
            )

        def _max_mtime_for_path(self, target_path: Path) -> t.Optional[float]:
            """Newest relevant mtime for a file inside the project root, else None.

            Considers the file itself, the dbt profile, all tracked macros, and the tracked
            YAML in every directory from the file's parent up to the project root.
            """
            project_root = self._project.context.project_root.absolute()
            target_abs = target_path.absolute()

            try:
                target_abs.relative_to(project_root)
            except ValueError:
                return None

            mtimes = [
                self._loader._path_mtimes.get(target_path),
                self._loader._path_mtimes.get(self._project.profile.path),
                # FIXME: take into account which macros are actually referenced in the target model.
                self._macros_max_mtime,
            ]

            # Walk the same absolute forms used by the containment check. Mixing relative and
            # absolute paths never terminates, since Path("/") and Path(".") are their own
            # parents, so a fixed-point guard also stops the walk.
            cursor = target_abs
            while cursor != project_root:
                parent = cursor.parent
                if parent == cursor:
                    break
                cursor = parent
                mtimes.append(self._yaml_max_mtimes.get(cursor))

            non_null_mtimes = [mtime for mtime in mtimes if mtime is not None]
            return max(non_null_mtimes) if non_null_mtimes else None
logger = <Logger sqlmesh.dbt.loader (WARNING)>
def sqlmesh_config( project_root: Optional[pathlib.Path] = None, state_connection: Optional[sqlmesh.core.config.connection.ConnectionConfig] = None, dbt_profile_name: Optional[str] = None, dbt_target_name: Optional[str] = None, variables: Optional[Dict[str, Any]] = None, threads: Optional[int] = None, register_comments: Optional[bool] = None, infer_state_schema_name: bool = False, profiles_dir: Optional[pathlib.Path] = None, **kwargs: Any) -> sqlmesh.core.config.root.Config:
 48def sqlmesh_config(
 49    project_root: t.Optional[Path] = None,
 50    state_connection: t.Optional[ConnectionConfig] = None,
 51    dbt_profile_name: t.Optional[str] = None,
 52    dbt_target_name: t.Optional[str] = None,
 53    variables: t.Optional[t.Dict[str, t.Any]] = None,
 54    threads: t.Optional[int] = None,
 55    register_comments: t.Optional[bool] = None,
 56    infer_state_schema_name: bool = False,
 57    profiles_dir: t.Optional[Path] = None,
 58    **kwargs: t.Any,
 59) -> Config:
 60    project_root = project_root or Path()
 61    context = DbtContext(
 62        project_root=project_root, profiles_dir=profiles_dir, profile_name=dbt_profile_name
 63    )
 64
 65    # note: Profile.load() is called twice with different DbtContext's:
 66    # - once here with the above DbtContext (to determine connnection / gateway config which has to be set up before everything else)
 67    # - again on the SQLMesh side via GenericContext.load() -> DbtLoader._load_projects() -> Project.load() which constructs a fresh DbtContext and ignores the above one
 68    # it's important to ensure that the DbtContext created within the DbtLoader uses the same project root / profiles dir that we use here
 69    profile = Profile.load(context, target_name=dbt_target_name)
 70    model_defaults = kwargs.pop("model_defaults", ModelDefaultsConfig())
 71    if model_defaults.dialect is None:
 72        model_defaults.dialect = profile.target.dialect
 73
 74    target_to_sqlmesh_args = {
 75        "register_comments": register_comments or False,
 76    }
 77
 78    loader = kwargs.pop("loader", DbtLoader)
 79    if not issubclass(loader, DbtLoader):
 80        raise ConfigError("The loader must be a DbtLoader.")
 81
 82    if threads is not None:
 83        # the to_sqlmesh() function on TargetConfig maps self.threads -> concurrent_tasks
 84        profile.target.threads = threads
 85
 86    gateway_kwargs = {}
 87    if infer_state_schema_name:
 88        profile_name = context.profile_name
 89
 90        # Note: we deliberately isolate state based on the target *schema* and not the target name.
 91        # It is assumed that the project will define a target, eg 'dev', and then in each users own ~/.dbt/profiles.yml the schema
 92        # for the 'dev' target is overriden to something user-specific, rather than making the target name itself user-specific.
 93        # This means that the schema name is the indicator of isolated state, not the target name which may be re-used across multiple schemas.
 94        target_schema = profile.target.schema_
 95
 96        # dbt-core doesnt allow schema to be undefined, but it does allow an empty string, and then just
 97        # fails at runtime when `CREATE SCHEMA ""` doesnt work
 98        if not target_schema:
 99            raise ConfigError(
100                f"Target '{profile.target_name}' does not specify a schema.\n"
101                "A schema is required in order to infer where to store SQLMesh state"
102            )
103
104        inferred_state_schema_name = f"sqlmesh_state_{profile_name}_{target_schema}"
105        logger.info("Inferring state schema: %s", inferred_state_schema_name)
106        gateway_kwargs["state_schema"] = inferred_state_schema_name
107
108    return Config(
109        loader=loader,
110        loader_kwargs=dict(profiles_dir=profiles_dir),
111        model_defaults=model_defaults,
112        variables=variables or {},
113        dbt=RootDbtConfig(infer_state_schema_name=infer_state_schema_name),
114        **{
115            "default_gateway": profile.target_name if "gateways" not in kwargs else "",
116            "gateways": {
117                profile.target_name: GatewayConfig(
118                    connection=profile.target.to_sqlmesh(**target_to_sqlmesh_args),
119                    state_connection=state_connection,
120                    **gateway_kwargs,
121                )
122            },  # type: ignore
123            **kwargs,
124        },
125    )
class DbtLoader(sqlmesh.core.loader.Loader):
128class DbtLoader(Loader):
129    def __init__(
130        self, context: GenericContext, path: Path, profiles_dir: t.Optional[Path] = None
131    ) -> None:
132        self._projects: t.List[Project] = []
133        self._macros_max_mtime: t.Optional[float] = None
134        self._profiles_dir = profiles_dir
135        super().__init__(context, path)
136
137    def load(self) -> LoadedProject:
138        self._projects = []
139        return super().load()
140
141    def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
142        macro_files = list(Path(self.config_path, "macros").glob("**/*.sql"))
143
144        for file in macro_files:
145            self._track_file(file)
146
147        # This doesn't do anything, the actual content will be loaded from the manifest
148        return (
149            macro.get_registry(),
150            JinjaMacroRegistry(),
151        )
152
153    def _load_models(
154        self,
155        macros: MacroRegistry,
156        jinja_macros: JinjaMacroRegistry,
157        gateway: t.Optional[str],
158        audits: UniqueKeyDict[str, ModelAudit],
159        signals: UniqueKeyDict[str, signal],
160    ) -> UniqueKeyDict[str, Model]:
161        models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
162
163        def _to_sqlmesh(config: BMC, context: DbtContext) -> Model:
164            logger.debug("Converting '%s' to sqlmesh format", config.canonical_name(context))
165            return config.to_sqlmesh(
166                context,
167                audit_definitions=audits,
168                virtual_environment_mode=self.config.virtual_environment_mode,
169            )
170
171        for project in self._load_projects():
172            macros_max_mtime = self._macros_max_mtime
173            yaml_max_mtimes = self._compute_yaml_max_mtime_per_subfolder(
174                project.context.project_root
175            )
176            cache = DbtLoader._Cache(self, project, macros_max_mtime, yaml_max_mtimes)
177
178            logger.debug("Converting models to sqlmesh")
179            # Now that config is rendered, create the sqlmesh models
180            for package in project.packages.values():
181                package_context = project.context.copy()
182                package_context.set_and_render_variables(package.variables, package.name)
183                package_models: t.Dict[str, BaseModelConfig] = {**package.models, **package.seeds}
184
185                package_models_by_path: t.Dict[Path, t.List[BaseModelConfig]] = defaultdict(list)
186                for model in package_models.values():
187                    if isinstance(model, ModelConfig) and not model.sql.strip():
188                        logger.info(f"Skipping empty model '{model.name}' at path '{model.path}'.")
189                        continue
190                    package_models_by_path[model.path].append(model)
191
192                for path, path_models in package_models_by_path.items():
193                    sqlmesh_models = cache.get_or_load_models(
194                        path,
195                        loader=lambda: [
196                            _to_sqlmesh(model, package_context) for model in path_models
197                        ],
198                    )
199                    for sqlmesh_model in sqlmesh_models:
200                        models[sqlmesh_model.fqn] = sqlmesh_model
201
202            models.update(self._load_external_models(audits, cache))
203
204        return models
205
206    def _load_audits(
207        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
208    ) -> UniqueKeyDict[str, Audit]:
209        audits: UniqueKeyDict = UniqueKeyDict("audits")
210
211        for project in self._load_projects():
212            logger.debug("Converting audits to sqlmesh")
213            for package in project.packages.values():
214                package_context = project.context.copy()
215                package_context.set_and_render_variables(package.variables, package.name)
216                for test in package.tests.values():
217                    logger.debug("Converting '%s' to sqlmesh format", test.name)
218                    try:
219                        audits[test.canonical_name] = test.to_sqlmesh(package_context)
220
221                    except BaseMissingReferenceError as e:
222                        ref_type = "model" if isinstance(e, MissingModelError) else "source"
223                        logger.warning(
224                            "Skipping audit '%s' because %s '%s' is not a valid ref",
225                            test.name,
226                            ref_type,
227                            e.ref,
228                        )
229
230        return audits
231
def _load_projects(self) -> t.List[Project]:
    """Load the dbt project for the selected gateway, caching the result.

    On the first call (or after ``load`` has cleared the cache) this:

    - loads the dbt project rooted at ``self.config_path``, using the selected
      gateway as the dbt target name;
    - verifies that the project's target database matches the context's
      default catalog;
    - tracks the project's files;
    - registers every package's sources, seeds, models and macros on the
      project context.

    Later calls return the cached list unchanged.

    Returns:
        The list of loaded projects (a single project).

    Raises:
        ConfigError: If the project's target database differs from the
            context's default catalog.
    """
    if self._projects:
        return self._projects

    target_name = self.context.selected_gateway

    project = Project.load(
        DbtContext(
            project_root=self.config_path,
            profiles_dir=self._profiles_dir,
            target_name=target_name,
            sqlmesh_config=self.config,
        ),
        variables=self.config.variables,
    )

    context_default_catalog = self.context.default_catalog or ""
    if project.context.target.database != context_default_catalog:
        raise ConfigError(
            f"Project default catalog ('{project.context.target.database}') does not match context default catalog ('{context_default_catalog}')."
        )

    # Files are tracked before macro mtimes are collected below.
    # NOTE(review): presumably ``_track_file`` (base Loader) is what records
    # entries in ``self._path_mtimes`` -- confirm before reordering.
    for path in project.project_files:
        self._track_file(path)

    context = project.context
    macros_mtimes: t.List[float] = []
    for package in project.packages.values():
        context.add_sources(package.sources)
        context.add_seeds(package.seeds)
        context.add_models(package.models)
        macros_mtimes.extend(
            self._path_mtimes[m.path]
            for m in package.macros.values()
            if m.path in self._path_mtimes
        )

    for package_name, macro_infos in context.manifest.all_macros.items():
        context.add_macros(macro_infos, package=package_name)

    self._macros_max_mtime = max(macros_mtimes) if macros_mtimes else None

    # Publish the project only once it is validated and fully initialized.
    # Otherwise a failure above would leave a half-initialized project cached,
    # and later calls would return it without re-running these steps.
    self._projects = [project]
    return self._projects
280
def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]:
    """Extend the base loader's requirements with the installed dbt packages.

    ``dbt-core`` plus one ``dbt-<target type>`` adapter package per loaded
    project are pinned to their installed versions. Packages the base loader
    already lists as a requirement or an exclusion are left alone. A package
    that is not installed produces a console warning instead of an error.

    Returns:
        The ``(requirements, excluded_requirements)`` pair obtained from the
        base loader, with ``requirements`` updated in place.
    """
    requirements, excluded_requirements = super()._load_requirements()

    adapter_packages = [
        f"dbt-{project.context.target.type}" for project in self._load_projects()
    ]
    for package_name in ["dbt-core", *adapter_packages]:
        if package_name in requirements or package_name in excluded_requirements:
            continue
        try:
            installed_version = metadata.version(package_name)
        except metadata.PackageNotFoundError:
            from sqlmesh.core.console import get_console

            get_console().log_warning(f"dbt package {package_name} is not installed.")
        else:
            requirements[package_name] = installed_version

    return requirements, excluded_requirements
299
def _load_environment_statements(self, macros: MacroRegistry) -> t.List[EnvironmentStatements]:
    """Convert dbt ``on-run-start``/``on-run-end`` hooks into environment statements.

    Only packages that define at least one hook are converted. For each such
    package of every loaded project:

    - the hooks are ordered by their ``index``;
    - each hook's SQL is parsed as a Jinja statement and rendered in the
      configured dialect;
    - ``on-run-start`` hooks become ``before_all`` and ``on-run-end`` hooks
      become ``after_all``.

    Each package gets a Jinja registry built from the macros its hooks depend on.

    Args:
        macros: Not used by this implementation.

    Returns:
        One ``EnvironmentStatements`` per package that has hooks.
    """
    dialect = self.config.dialect
    statements_by_package: t.Dict[str, EnvironmentStatements] = {}
    packages_with_hooks: t.Set[str] = set()

    def _hook_sql_in_order(hooks: t.Any) -> t.List[str]:
        # Hooks carry an ``index``; emit their SQL in that order.
        return [hook.sql for hook in sorted(hooks.values(), key=lambda hook: hook.index)]

    for project in self._load_projects():
        for package_name, package in project.packages.items():
            package_context = project.context.copy()
            package_context.set_and_render_variables(package.variables, package_name)

            start_sql = _hook_sql_in_order(package.on_run_start)
            end_sql = _hook_sql_in_order(package.on_run_end)
            if not start_sql and not end_sql:
                continue

            # Union the dependencies of every hook (start and end) in this package.
            dependencies = Dependencies()
            for hook in (*package.on_run_start.values(), *package.on_run_end.values()):
                dependencies = dependencies.union(hook.dependencies)

            statements_context = package_context.context_for_dependencies(dependencies)
            jinja_registry = make_jinja_registry(
                statements_context.jinja_macros, package_name, set(dependencies.macros)
            )
            jinja_registry.add_globals(statements_context.jinja_globals)

            statements_by_package[package_name] = EnvironmentStatements(
                before_all=[d.jinja_statement(sql).sql(dialect=dialect) for sql in start_sql],
                after_all=[d.jinja_statement(sql).sql(dialect=dialect) for sql in end_sql],
                python_env={},
                jinja_macros=jinja_registry,
                project=package_name,
            )
            packages_with_hooks.add(package_name)

    # NOTE(review): every key of ``statements_by_package`` is also added to
    # ``packages_with_hooks``, so this key is always 0 and the (stable) sort
    # keeps insertion order. If the intent is to run the root project's hooks
    # before other packages' hooks, the set should hold only root project
    # names -- confirm.
    ordered = sorted(
        statements_by_package.items(),
        key=lambda item: 0 if item[0] in packages_with_hooks else 1,
    )
    return [statements for _, statements in ordered]
352
def _compute_yaml_max_mtime_per_subfolder(
    self, root: Path, visited: t.Optional[t.Set[Path]] = None
) -> t.Dict[Path, float]:
    """Map directories under ``root`` to the newest tracked mtime of their YAML files.

    ``root`` is resolved and walked recursively. A directory appears in the
    result if it directly contains at least one ``.yaml``/``.yml`` file with an
    entry in ``self._path_mtimes``. Its value is the largest such mtime, keyed by
    the resolved directory path. Each resolved directory is scanned at most
    once, so a symlink that resolves to an already-visited directory is not
    walked again.

    Args:
        root: Directory to scan. A path that is not a directory yields ``{}``.
        visited: Resolved directories that were already scanned. When supplied,
            it is updated in place (even if it starts out empty).

    Returns:
        A mapping of resolved directory path to its maximum YAML mtime.
    """
    root = root.resolve()
    # ``visited or set()`` would silently replace a caller-supplied empty set,
    # leaving the caller's set unpopulated.
    if visited is None:
        visited = set()
    if not root.is_dir() or root in visited:
        return {}

    visited.add(root)

    result: t.Dict[Path, float] = {}
    max_mtime: t.Optional[float] = None

    for nested in root.iterdir():
        try:
            if nested.is_dir():
                result.update(
                    self._compute_yaml_max_mtime_per_subfolder(nested, visited=visited)
                )
            elif nested.suffix.lower() in (".yaml", ".yml"):
                yaml_mtime = self._path_mtimes.get(nested)
                # Compare with None explicitly: 0.0 is a valid mtime.
                if yaml_mtime is not None:
                    max_mtime = yaml_mtime if max_mtime is None else max(max_mtime, yaml_mtime)
        except PermissionError:
            # Best effort: skip entries we are not allowed to inspect.
            pass

    if max_mtime is not None:
        result[root] = max_mtime

    return result
385
class _Cache(CacheBase):
    """Model cache for a single dbt project.

    Entries live in ``loader.context.cache_dir / <target name>``. Each entry is
    named after the model file's path relative to the project root. Its id
    combines the newest relevant mtime with the loader config's fingerprint, so
    touching the file, the profile, any macro, or a YAML file in an enclosing
    directory changes the id.
    """

    MAX_ENTRY_NAME_LENGTH = 200

    def __init__(
        self,
        loader: DbtLoader,
        project: Project,
        macros_max_mtime: t.Optional[float],
        yaml_max_mtimes: t.Dict[Path, float],
    ):
        self._loader = loader
        self._project = project
        self._macros_max_mtime = macros_max_mtime
        # Per-directory max YAML mtimes, keyed by directory path.
        self._yaml_max_mtimes = yaml_max_mtimes

        target = t.cast(TargetConfig, project.context.target)
        cache_dir = loader.context.cache_dir / target.name
        self._model_cache = ModelCache(cache_dir)

    def get_or_load_models(
        self, target_path: Path, loader: t.Callable[[], t.List[Model]]
    ) -> t.List[Model]:
        """Return models for ``target_path`` via ``ModelCache.get_or_load``.

        ``loader`` is handed to the model cache to produce the models; every
        returned model then gets its ``_path`` set to ``target_path``.
        """
        models = self._model_cache.get_or_load(
            self._cache_entry_name(target_path),
            self._cache_entry_id(target_path),
            loader=loader,
        )
        for model in models:
            model._path = target_path

        return models

    def put(self, models: t.List[Model], path: Path) -> bool:
        """Store ``models`` under the entry derived from ``path``."""
        return self._model_cache.put(
            models,
            self._cache_entry_name(path),
            self._cache_entry_id(path),
        )

    def get(self, path: Path) -> t.List[Model]:
        """Fetch the models stored under the entry derived from ``path``."""
        return self._model_cache.get(
            self._cache_entry_name(path),
            self._cache_entry_id(path),
        )

    def _cache_entry_name(self, target_path: Path) -> str:
        """Build the cache entry name for ``target_path``.

        Takes the path relative to the project root (or the path as given if it
        is not under the root), joins its parts with ``__`` and drops the final
        file suffix. Names longer than ``MAX_ENTRY_NAME_LENGTH`` keep only their
        trailing characters.
        """
        try:
            path_for_name = target_path.absolute().relative_to(
                self._project.context.project_root.absolute()
            )
        except ValueError:
            path_for_name = target_path
        name = "__".join(path_for_name.parts)
        # Strip only the trailing suffix. ``str.replace`` would also remove the
        # suffix text from directory names (e.g. ``a.sql_b/x.sql``), letting
        # distinct paths collide on the same entry name.
        suffix = path_for_name.suffix
        if suffix and name.endswith(suffix):
            name = name[: -len(suffix)]
        if len(name) > self.MAX_ENTRY_NAME_LENGTH:
            return name[-self.MAX_ENTRY_NAME_LENGTH :]
        return name

    def _cache_entry_id(self, target_path: Path) -> str:
        """Combine the newest relevant mtime (or ``"na"``) with the config fingerprint."""
        max_mtime = self._max_mtime_for_path(target_path)
        return "__".join(
            [
                str(int(max_mtime)) if max_mtime is not None else "na",
                self._loader.config.fingerprint,
            ]
        )

    def _max_mtime_for_path(self, target_path: Path) -> t.Optional[float]:
        """Return the newest mtime that can affect the models in ``target_path``.

        Considers the file itself, the dbt profile, the macros' max mtime, and
        the precomputed YAML mtimes of every directory from the file's parent up
        to the project root. Returns ``None`` if the path is outside the project
        root or no mtime is known.
        """
        project_root = self._project.context.project_root.absolute()
        absolute_target = target_path.absolute()

        try:
            absolute_target.relative_to(project_root)
        except ValueError:
            return None

        mtimes = [
            self._loader._path_mtimes.get(target_path),
            self._loader._path_mtimes.get(self._project.profile.path),
            # FIXME: take into account which macros are actually referenced in the target model.
            self._macros_max_mtime,
        ]

        # Walk up the absolute path. The containment check above guarantees
        # ``project_root`` is an ancestor, so the loop terminates. Comparing the
        # raw paths never terminates when one is absolute and the other relative,
        # because ``Path("/").parent == Path("/")`` and ``Path(".").parent == Path(".")``.
        cursor = absolute_target
        while cursor != project_root:
            cursor = cursor.parent
            mtimes.append(self._yaml_max_mtimes.get(cursor))

        # Named ``mtime`` rather than ``t`` to avoid shadowing the ``typing`` alias.
        known_mtimes = [mtime for mtime in mtimes if mtime is not None]
        return max(known_mtimes) if known_mtimes else None

Abstract base class to load macros and models for a context

DbtLoader( context: sqlmesh.core.context.GenericContext, path: pathlib.Path, profiles_dir: Optional[pathlib.Path] = None)
def __init__(
    self,
    context: GenericContext,
    path: Path,
    profiles_dir: t.Optional[Path] = None,
) -> None:
    """Set up dbt-specific loader state, then run the base ``Loader`` initializer.

    Args:
        context: The context to load into; forwarded to ``Loader``.
        path: Forwarded to ``Loader``.
        profiles_dir: Optional directory handed to ``DbtContext`` as
            ``profiles_dir`` when the dbt project is loaded.
    """
    # Loader-specific state is assigned before delegating to the base initializer.
    self._profiles_dir = profiles_dir
    self._macros_max_mtime: t.Optional[float] = None
    self._projects: t.List[Project] = []
    super().__init__(context, path)
def load(self) -> sqlmesh.core.loader.LoadedProject:
def load(self) -> LoadedProject:
    """Load all macros and models, re-reading the dbt project.

    The projects cached by ``_load_projects`` are discarded before delegating
    to the base ``Loader.load``, so each call starts from a freshly loaded project.

    Returns:
        The loaded project produced by the base loader.
    """
    self._projects = []  # invalidate projects cached by any earlier load
    loaded_project = super().load()
    return loaded_project

Loads all macros and models in the context's path.

Returns:

A loaded project object.