sqlmesh.dbt.loader
1from __future__ import annotations 2 3import logging 4import sys 5import typing as t 6import sqlmesh.core.dialect as d 7from pathlib import Path 8from collections import defaultdict 9from sqlmesh.core.config import ( 10 Config, 11 ConnectionConfig, 12 GatewayConfig, 13 ModelDefaultsConfig, 14 DbtConfig as RootDbtConfig, 15) 16from sqlmesh.core.environment import EnvironmentStatements 17from sqlmesh.core.loader import CacheBase, LoadedProject, Loader 18from sqlmesh.core.macros import MacroRegistry, macro 19from sqlmesh.core.model import Model, ModelCache 20from sqlmesh.core.signal import signal 21from sqlmesh.dbt.basemodel import BMC, BaseModelConfig 22from sqlmesh.dbt.common import Dependencies 23from sqlmesh.dbt.context import DbtContext 24from sqlmesh.dbt.model import ModelConfig 25from sqlmesh.dbt.profile import Profile 26from sqlmesh.dbt.project import Project 27from sqlmesh.dbt.target import TargetConfig 28from sqlmesh.utils import UniqueKeyDict 29from sqlmesh.utils.errors import ConfigError, MissingModelError, BaseMissingReferenceError 30from sqlmesh.utils.jinja import ( 31 JinjaMacroRegistry, 32 make_jinja_registry, 33) 34 35if sys.version_info >= (3, 12): 36 from importlib import metadata 37else: 38 import importlib_metadata as metadata # type: ignore 39 40if t.TYPE_CHECKING: 41 from sqlmesh.core.audit import Audit, ModelAudit 42 from sqlmesh.core.context import GenericContext 43 44logger = logging.getLogger(__name__) 45 46 47def sqlmesh_config( 48 project_root: t.Optional[Path] = None, 49 state_connection: t.Optional[ConnectionConfig] = None, 50 dbt_profile_name: t.Optional[str] = None, 51 dbt_target_name: t.Optional[str] = None, 52 variables: t.Optional[t.Dict[str, t.Any]] = None, 53 threads: t.Optional[int] = None, 54 register_comments: t.Optional[bool] = None, 55 infer_state_schema_name: bool = False, 56 profiles_dir: t.Optional[Path] = None, 57 **kwargs: t.Any, 58) -> Config: 59 project_root = project_root or Path() 60 context = DbtContext( 61 
project_root=project_root, profiles_dir=profiles_dir, profile_name=dbt_profile_name 62 ) 63 64 # note: Profile.load() is called twice with different DbtContext's: 65 # - once here with the above DbtContext (to determine connnection / gateway config which has to be set up before everything else) 66 # - again on the SQLMesh side via GenericContext.load() -> DbtLoader._load_projects() -> Project.load() which constructs a fresh DbtContext and ignores the above one 67 # it's important to ensure that the DbtContext created within the DbtLoader uses the same project root / profiles dir that we use here 68 profile = Profile.load(context, target_name=dbt_target_name) 69 model_defaults = kwargs.pop("model_defaults", ModelDefaultsConfig()) 70 if model_defaults.dialect is None: 71 model_defaults.dialect = profile.target.dialect 72 73 target_to_sqlmesh_args = { 74 "register_comments": register_comments or False, 75 } 76 77 loader = kwargs.pop("loader", DbtLoader) 78 if not issubclass(loader, DbtLoader): 79 raise ConfigError("The loader must be a DbtLoader.") 80 81 if threads is not None: 82 # the to_sqlmesh() function on TargetConfig maps self.threads -> concurrent_tasks 83 profile.target.threads = threads 84 85 gateway_kwargs = {} 86 if infer_state_schema_name: 87 profile_name = context.profile_name 88 89 # Note: we deliberately isolate state based on the target *schema* and not the target name. 90 # It is assumed that the project will define a target, eg 'dev', and then in each users own ~/.dbt/profiles.yml the schema 91 # for the 'dev' target is overriden to something user-specific, rather than making the target name itself user-specific. 92 # This means that the schema name is the indicator of isolated state, not the target name which may be re-used across multiple schemas. 
93 target_schema = profile.target.schema_ 94 95 # dbt-core doesnt allow schema to be undefined, but it does allow an empty string, and then just 96 # fails at runtime when `CREATE SCHEMA ""` doesnt work 97 if not target_schema: 98 raise ConfigError( 99 f"Target '{profile.target_name}' does not specify a schema.\n" 100 "A schema is required in order to infer where to store SQLMesh state" 101 ) 102 103 inferred_state_schema_name = f"sqlmesh_state_{profile_name}_{target_schema}" 104 logger.info("Inferring state schema: %s", inferred_state_schema_name) 105 gateway_kwargs["state_schema"] = inferred_state_schema_name 106 107 return Config( 108 loader=loader, 109 loader_kwargs=dict(profiles_dir=profiles_dir), 110 model_defaults=model_defaults, 111 variables=variables or {}, 112 dbt=RootDbtConfig(infer_state_schema_name=infer_state_schema_name), 113 **{ 114 "default_gateway": profile.target_name if "gateways" not in kwargs else "", 115 "gateways": { 116 profile.target_name: GatewayConfig( 117 connection=profile.target.to_sqlmesh(**target_to_sqlmesh_args), 118 state_connection=state_connection, 119 **gateway_kwargs, 120 ) 121 }, # type: ignore 122 **kwargs, 123 }, 124 ) 125 126 127class DbtLoader(Loader): 128 def __init__( 129 self, context: GenericContext, path: Path, profiles_dir: t.Optional[Path] = None 130 ) -> None: 131 self._projects: t.List[Project] = [] 132 self._macros_max_mtime: t.Optional[float] = None 133 self._profiles_dir = profiles_dir 134 super().__init__(context, path) 135 136 def load(self) -> LoadedProject: 137 self._projects = [] 138 return super().load() 139 140 def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]: 141 macro_files = list(Path(self.config_path, "macros").glob("**/*.sql")) 142 143 for file in macro_files: 144 self._track_file(file) 145 146 # This doesn't do anything, the actual content will be loaded from the manifest 147 return ( 148 macro.get_registry(), 149 JinjaMacroRegistry(), 150 ) 151 152 def _load_models( 153 
self, 154 macros: MacroRegistry, 155 jinja_macros: JinjaMacroRegistry, 156 gateway: t.Optional[str], 157 audits: UniqueKeyDict[str, ModelAudit], 158 signals: UniqueKeyDict[str, signal], 159 ) -> UniqueKeyDict[str, Model]: 160 models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") 161 162 def _to_sqlmesh(config: BMC, context: DbtContext) -> Model: 163 logger.debug("Converting '%s' to sqlmesh format", config.canonical_name(context)) 164 return config.to_sqlmesh( 165 context, 166 audit_definitions=audits, 167 virtual_environment_mode=self.config.virtual_environment_mode, 168 ) 169 170 for project in self._load_projects(): 171 macros_max_mtime = self._macros_max_mtime 172 yaml_max_mtimes = self._compute_yaml_max_mtime_per_subfolder( 173 project.context.project_root 174 ) 175 cache = DbtLoader._Cache(self, project, macros_max_mtime, yaml_max_mtimes) 176 177 logger.debug("Converting models to sqlmesh") 178 # Now that config is rendered, create the sqlmesh models 179 for package in project.packages.values(): 180 package_context = project.context.copy() 181 package_context.set_and_render_variables(package.variables, package.name) 182 package_models: t.Dict[str, BaseModelConfig] = {**package.models, **package.seeds} 183 184 package_models_by_path: t.Dict[Path, t.List[BaseModelConfig]] = defaultdict(list) 185 for model in package_models.values(): 186 if isinstance(model, ModelConfig) and not model.sql.strip(): 187 logger.info(f"Skipping empty model '{model.name}' at path '{model.path}'.") 188 continue 189 package_models_by_path[model.path].append(model) 190 191 for path, path_models in package_models_by_path.items(): 192 sqlmesh_models = cache.get_or_load_models( 193 path, 194 loader=lambda: [ 195 _to_sqlmesh(model, package_context) for model in path_models 196 ], 197 ) 198 for sqlmesh_model in sqlmesh_models: 199 models[sqlmesh_model.fqn] = sqlmesh_model 200 201 models.update(self._load_external_models(audits, cache)) 202 203 return models 204 205 def _load_audits( 206 
self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry 207 ) -> UniqueKeyDict[str, Audit]: 208 audits: UniqueKeyDict = UniqueKeyDict("audits") 209 210 for project in self._load_projects(): 211 logger.debug("Converting audits to sqlmesh") 212 for package in project.packages.values(): 213 package_context = project.context.copy() 214 package_context.set_and_render_variables(package.variables, package.name) 215 for test in package.tests.values(): 216 logger.debug("Converting '%s' to sqlmesh format", test.name) 217 try: 218 audits[test.canonical_name] = test.to_sqlmesh(package_context) 219 220 except BaseMissingReferenceError as e: 221 ref_type = "model" if isinstance(e, MissingModelError) else "source" 222 logger.warning( 223 "Skipping audit '%s' because %s '%s' is not a valid ref", 224 test.name, 225 ref_type, 226 e.ref, 227 ) 228 229 return audits 230 231 def _load_projects(self) -> t.List[Project]: 232 if not self._projects: 233 target_name = self.context.selected_gateway 234 235 self._projects = [] 236 237 project = Project.load( 238 DbtContext( 239 project_root=self.config_path, 240 profiles_dir=self._profiles_dir, 241 target_name=target_name, 242 sqlmesh_config=self.config, 243 ), 244 variables=self.config.variables, 245 ) 246 247 self._projects.append(project) 248 249 context_default_catalog = self.context.default_catalog or "" 250 if project.context.target.database != context_default_catalog: 251 raise ConfigError( 252 f"Project default catalog ('{project.context.target.database}') does not match context default catalog ('{context_default_catalog}')." 
253 ) 254 for path in project.project_files: 255 self._track_file(path) 256 257 context = project.context 258 259 macros_mtimes: t.List[float] = [] 260 261 for package_name, package in project.packages.items(): 262 context.add_sources(package.sources) 263 context.add_seeds(package.seeds) 264 context.add_models(package.models) 265 macros_mtimes.extend( 266 [ 267 self._path_mtimes[m.path] 268 for m in package.macros.values() 269 if m.path in self._path_mtimes 270 ] 271 ) 272 273 for package_name, macro_infos in context.manifest.all_macros.items(): 274 context.add_macros(macro_infos, package=package_name) 275 276 self._macros_max_mtime = max(macros_mtimes) if macros_mtimes else None 277 278 return self._projects 279 280 def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]: 281 requirements, excluded_requirements = super()._load_requirements() 282 283 target_packages = ["dbt-core"] 284 for project in self._load_projects(): 285 target_packages.append(f"dbt-{project.context.target.type}") 286 287 for target_package in target_packages: 288 if target_package in requirements or target_package in excluded_requirements: 289 continue 290 try: 291 requirements[target_package] = metadata.version(target_package) 292 except metadata.PackageNotFoundError: 293 from sqlmesh.core.console import get_console 294 295 get_console().log_warning(f"dbt package {target_package} is not installed.") 296 297 return requirements, excluded_requirements 298 299 def _load_environment_statements(self, macros: MacroRegistry) -> t.List[EnvironmentStatements]: 300 """Loads dbt's on_run_start, on_run_end hooks into sqlmesh's before_all, after_all statements respectively.""" 301 302 hooks_by_package_name: t.Dict[str, EnvironmentStatements] = {} 303 project_names: t.Set[str] = set() 304 dialect = self.config.dialect 305 for project in self._load_projects(): 306 for package_name, package in project.packages.items(): 307 package_context = project.context.copy() 308 
package_context.set_and_render_variables(package.variables, package_name) 309 on_run_start: t.List[str] = [ 310 on_run_hook.sql 311 for on_run_hook in sorted(package.on_run_start.values(), key=lambda h: h.index) 312 ] 313 on_run_end: t.List[str] = [ 314 on_run_hook.sql 315 for on_run_hook in sorted(package.on_run_end.values(), key=lambda h: h.index) 316 ] 317 318 if on_run_start or on_run_end: 319 dependencies = Dependencies() 320 for hook in [*package.on_run_start.values(), *package.on_run_end.values()]: 321 dependencies = dependencies.union(hook.dependencies) 322 323 statements_context = package_context.context_for_dependencies(dependencies) 324 jinja_registry = make_jinja_registry( 325 statements_context.jinja_macros, package_name, set(dependencies.macros) 326 ) 327 jinja_registry.add_globals(statements_context.jinja_globals) 328 329 hooks_by_package_name[package_name] = EnvironmentStatements( 330 before_all=[ 331 d.jinja_statement(stmt).sql(dialect=dialect) 332 for stmt in on_run_start or [] 333 ], 334 after_all=[ 335 d.jinja_statement(stmt).sql(dialect=dialect) 336 for stmt in on_run_end or [] 337 ], 338 python_env={}, 339 jinja_macros=jinja_registry, 340 project=package_name, 341 ) 342 project_names.add(package_name) 343 344 return [ 345 statements 346 for _, statements in sorted( 347 hooks_by_package_name.items(), 348 key=lambda item: 0 if item[0] in project_names else 1, 349 ) 350 ] 351 352 def _compute_yaml_max_mtime_per_subfolder( 353 self, root: Path, visited: t.Optional[t.Set[Path]] = None 354 ) -> t.Dict[Path, float]: 355 root = root.resolve() 356 visited = visited or set() 357 if not root.is_dir() or root in visited: 358 return {} 359 360 visited.add(root) 361 362 result = {} 363 max_mtime: t.Optional[float] = None 364 365 for nested in root.iterdir(): 366 try: 367 if nested.is_dir(): 368 result.update( 369 self._compute_yaml_max_mtime_per_subfolder(nested, visited=visited) 370 ) 371 elif nested.suffix.lower() in (".yaml", ".yml"): 372 yaml_mtime = 
self._path_mtimes.get(nested) 373 if yaml_mtime: 374 max_mtime = ( 375 max(max_mtime, yaml_mtime) if max_mtime is not None else yaml_mtime 376 ) 377 except PermissionError: 378 pass 379 380 if max_mtime is not None: 381 result[root] = max_mtime 382 383 return result 384 385 class _Cache(CacheBase): 386 MAX_ENTRY_NAME_LENGTH = 200 387 388 def __init__( 389 self, 390 loader: DbtLoader, 391 project: Project, 392 macros_max_mtime: t.Optional[float], 393 yaml_max_mtimes: t.Dict[Path, float], 394 ): 395 self._loader = loader 396 self._project = project 397 self._macros_max_mtime = macros_max_mtime 398 self._yaml_max_mtimes = yaml_max_mtimes 399 400 target = t.cast(TargetConfig, project.context.target) 401 cache_dir = loader.context.cache_dir / target.name 402 self._model_cache = ModelCache(cache_dir) 403 404 def get_or_load_models( 405 self, target_path: Path, loader: t.Callable[[], t.List[Model]] 406 ) -> t.List[Model]: 407 models = self._model_cache.get_or_load( 408 self._cache_entry_name(target_path), 409 self._cache_entry_id(target_path), 410 loader=loader, 411 ) 412 for model in models: 413 model._path = target_path 414 415 return models 416 417 def put(self, models: t.List[Model], path: Path) -> bool: 418 return self._model_cache.put( 419 models, 420 self._cache_entry_name(path), 421 self._cache_entry_id(path), 422 ) 423 424 def get(self, path: Path) -> t.List[Model]: 425 return self._model_cache.get( 426 self._cache_entry_name(path), 427 self._cache_entry_id(path), 428 ) 429 430 def _cache_entry_name(self, target_path: Path) -> str: 431 try: 432 path_for_name = target_path.absolute().relative_to( 433 self._project.context.project_root.absolute() 434 ) 435 except ValueError: 436 path_for_name = target_path 437 name = "__".join(path_for_name.parts).replace(path_for_name.suffix, "") 438 if len(name) > self.MAX_ENTRY_NAME_LENGTH: 439 return name[len(name) - self.MAX_ENTRY_NAME_LENGTH :] 440 return name 441 442 def _cache_entry_id(self, target_path: Path) -> str: 443 
max_mtime = self._max_mtime_for_path(target_path) 444 return "__".join( 445 [ 446 str(int(max_mtime)) if max_mtime is not None else "na", 447 self._loader.config.fingerprint, 448 ] 449 ) 450 451 def _max_mtime_for_path(self, target_path: Path) -> t.Optional[float]: 452 project_root = self._project.context.project_root 453 454 try: 455 target_path.absolute().relative_to(project_root.absolute()) 456 except ValueError: 457 return None 458 459 mtimes = [ 460 self._loader._path_mtimes.get(target_path), 461 self._loader._path_mtimes.get(self._project.profile.path), 462 # FIXME: take into account which macros are actually referenced in the target model. 463 self._macros_max_mtime, 464 ] 465 466 cursor = target_path 467 while cursor != project_root: 468 cursor = cursor.parent 469 mtimes.append(self._yaml_max_mtimes.get(cursor)) 470 471 non_null_mtimes = [t for t in mtimes if t is not None] 472 return max(non_null_mtimes) if non_null_mtimes else None
logger =
<Logger sqlmesh.dbt.loader (WARNING)>
def
sqlmesh_config( project_root: Optional[pathlib.Path] = None, state_connection: Optional[sqlmesh.core.config.connection.ConnectionConfig] = None, dbt_profile_name: Optional[str] = None, dbt_target_name: Optional[str] = None, variables: Optional[Dict[str, Any]] = None, threads: Optional[int] = None, register_comments: Optional[bool] = None, infer_state_schema_name: bool = False, profiles_dir: Optional[pathlib.Path] = None, **kwargs: Any) -> sqlmesh.core.config.root.Config:
def sqlmesh_config(
    project_root: t.Optional[Path] = None,
    state_connection: t.Optional[ConnectionConfig] = None,
    dbt_profile_name: t.Optional[str] = None,
    dbt_target_name: t.Optional[str] = None,
    variables: t.Optional[t.Dict[str, t.Any]] = None,
    threads: t.Optional[int] = None,
    register_comments: t.Optional[bool] = None,
    infer_state_schema_name: bool = False,
    profiles_dir: t.Optional[Path] = None,
    **kwargs: t.Any,
) -> Config:
    """Create a SQLMesh ``Config`` for a dbt project.

    The dbt profile is loaded to derive the connection / gateway configuration, and the
    returned config is set up to load the project through a ``DbtLoader``.

    Args:
        project_root: Root directory of the dbt project. Defaults to ``Path()``.
        state_connection: Optional connection passed to the gateway as ``state_connection``.
        dbt_profile_name: Name of the dbt profile passed to the ``DbtContext``.
        dbt_target_name: Name of the dbt target passed to ``Profile.load``.
        variables: Variables stored on the returned config. Defaults to an empty dict.
        threads: When given, overrides ``threads`` on the profile's target.
        register_comments: Forwarded to the target's ``to_sqlmesh``; ``None`` is treated
            as ``False``.
        infer_state_schema_name: When true, the gateway's ``state_schema`` is set to
            ``sqlmesh_state_<profile name>_<target schema>``.
        profiles_dir: Directory containing the dbt profiles; also forwarded to the loader
            through ``loader_kwargs``.
        **kwargs: Additional ``Config`` arguments. ``model_defaults`` and ``loader`` are
            consumed here; everything else is passed through and takes precedence over
            the generated ``default_gateway`` / ``gateways`` entries.

    Returns:
        The assembled ``Config``.

    Raises:
        ConfigError: If ``loader`` is not a ``DbtLoader`` subclass, or if state schema
            inference is requested and the target does not specify a schema.
    """
    project_root = project_root or Path()
    context = DbtContext(
        project_root=project_root, profiles_dir=profiles_dir, profile_name=dbt_profile_name
    )

    # note: Profile.load() is called twice with different DbtContext's:
    # - once here with the above DbtContext (to determine connection / gateway config which has to be set up before everything else)
    # - again on the SQLMesh side via GenericContext.load() -> DbtLoader._load_projects() -> Project.load() which constructs a fresh DbtContext and ignores the above one
    # it's important to ensure that the DbtContext created within the DbtLoader uses the same project root / profiles dir that we use here
    profile = Profile.load(context, target_name=dbt_target_name)
    model_defaults = kwargs.pop("model_defaults", ModelDefaultsConfig())
    if model_defaults.dialect is None:
        model_defaults.dialect = profile.target.dialect

    target_to_sqlmesh_args = {
        "register_comments": register_comments or False,
    }

    loader = kwargs.pop("loader", DbtLoader)
    # issubclass() raises TypeError for non-class arguments, so check for a class first in
    # order to report every invalid loader through the same ConfigError.
    if not (isinstance(loader, type) and issubclass(loader, DbtLoader)):
        raise ConfigError("The loader must be a DbtLoader.")

    if threads is not None:
        # the to_sqlmesh() function on TargetConfig maps self.threads -> concurrent_tasks
        profile.target.threads = threads

    gateway_kwargs: t.Dict[str, t.Any] = {}
    if infer_state_schema_name:
        profile_name = context.profile_name

        # Note: we deliberately isolate state based on the target *schema* and not the target name.
        # It is assumed that the project will define a target, eg 'dev', and then in each user's own ~/.dbt/profiles.yml the schema
        # for the 'dev' target is overridden to something user-specific, rather than making the target name itself user-specific.
        # This means that the schema name is the indicator of isolated state, not the target name which may be re-used across multiple schemas.
        target_schema = profile.target.schema_

        # dbt-core doesn't allow schema to be undefined, but it does allow an empty string, and then just
        # fails at runtime when `CREATE SCHEMA ""` doesn't work
        if not target_schema:
            raise ConfigError(
                f"Target '{profile.target_name}' does not specify a schema.\n"
                "A schema is required in order to infer where to store SQLMesh state"
            )

        inferred_state_schema_name = f"sqlmesh_state_{profile_name}_{target_schema}"
        logger.info("Inferring state schema: %s", inferred_state_schema_name)
        gateway_kwargs["state_schema"] = inferred_state_schema_name

    return Config(
        loader=loader,
        loader_kwargs={"profiles_dir": profiles_dir},
        model_defaults=model_defaults,
        variables=variables or {},
        dbt=RootDbtConfig(infer_state_schema_name=infer_state_schema_name),
        **{
            "default_gateway": profile.target_name if "gateways" not in kwargs else "",
            "gateways": {
                profile.target_name: GatewayConfig(
                    connection=profile.target.to_sqlmesh(**target_to_sqlmesh_args),
                    state_connection=state_connection,
                    **gateway_kwargs,
                )
            },  # type: ignore
            **kwargs,
        },
    )
class DbtLoader(Loader):
    """Loader that converts a dbt project into SQLMesh models, audits and environment statements."""

    def __init__(
        self, context: GenericContext, path: Path, profiles_dir: t.Optional[Path] = None
    ) -> None:
        """Store the dbt-specific state, then delegate to the base loader initializer.

        Args:
            context: The SQLMesh context this loader belongs to.
            path: The project path handed to the base loader.
            profiles_dir: Optional directory containing the dbt profiles.
        """
        self._projects: t.List[Project] = []
        self._macros_max_mtime: t.Optional[float] = None
        self._profiles_dir = profiles_dir
        super().__init__(context, path)

    def load(self) -> LoadedProject:
        """Discard previously loaded dbt projects and run the base loader's ``load``."""
        self._projects = []
        return super().load()

    def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
        """Track the SQL files under ``<config_path>/macros`` and return placeholder registries."""
        for file in Path(self.config_path, "macros").glob("**/*.sql"):
            self._track_file(file)

        # This doesn't do anything, the actual content will be loaded from the manifest
        return (
            macro.get_registry(),
            JinjaMacroRegistry(),
        )

    def _load_models(
        self,
        macros: MacroRegistry,
        jinja_macros: JinjaMacroRegistry,
        gateway: t.Optional[str],
        audits: UniqueKeyDict[str, ModelAudit],
        signals: UniqueKeyDict[str, signal],
    ) -> UniqueKeyDict[str, Model]:
        """Convert the dbt models and seeds of every package into SQLMesh models.

        Models are grouped by source file so that each file maps to one cache entry.
        ``macros``, ``jinja_macros``, ``gateway`` and ``signals`` are not used by this
        implementation.

        Args:
            audits: Audit definitions forwarded to each model conversion and to the
                external model loader.

        Returns:
            The converted models keyed by fully qualified name, including external models.
        """
        models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")

        def _to_sqlmesh(config: BMC, context: DbtContext) -> Model:
            logger.debug("Converting '%s' to sqlmesh format", config.canonical_name(context))
            return config.to_sqlmesh(
                context,
                audit_definitions=audits,
                virtual_environment_mode=self.config.virtual_environment_mode,
            )

        for project in self._load_projects():
            # _load_projects() has refreshed self._macros_max_mtime by the time we get here.
            macros_max_mtime = self._macros_max_mtime
            yaml_max_mtimes = self._compute_yaml_max_mtime_per_subfolder(
                project.context.project_root
            )
            cache = DbtLoader._Cache(self, project, macros_max_mtime, yaml_max_mtimes)

            logger.debug("Converting models to sqlmesh")
            # Now that config is rendered, create the sqlmesh models
            for package in project.packages.values():
                package_context = project.context.copy()
                package_context.set_and_render_variables(package.variables, package.name)
                package_models: t.Dict[str, BaseModelConfig] = {**package.models, **package.seeds}

                package_models_by_path: t.Dict[Path, t.List[BaseModelConfig]] = defaultdict(list)
                for model in package_models.values():
                    if isinstance(model, ModelConfig) and not model.sql.strip():
                        logger.info(
                            "Skipping empty model '%s' at path '%s'.", model.name, model.path
                        )
                        continue
                    package_models_by_path[model.path].append(model)

                for path, path_models in package_models_by_path.items():
                    sqlmesh_models = cache.get_or_load_models(
                        path,
                        loader=lambda: [
                            _to_sqlmesh(model, package_context) for model in path_models
                        ],
                    )
                    for sqlmesh_model in sqlmesh_models:
                        models[sqlmesh_model.fqn] = sqlmesh_model

            models.update(self._load_external_models(audits, cache))

        return models

    def _load_audits(
        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
    ) -> UniqueKeyDict[str, Audit]:
        """Convert the dbt tests of every package into SQLMesh audits.

        A test whose conversion raises ``BaseMissingReferenceError`` is skipped with a
        warning. ``macros`` and ``jinja_macros`` are not used by this implementation.

        Returns:
            The converted audits keyed by the test's canonical name.
        """
        audits: UniqueKeyDict = UniqueKeyDict("audits")

        for project in self._load_projects():
            logger.debug("Converting audits to sqlmesh")
            for package in project.packages.values():
                package_context = project.context.copy()
                package_context.set_and_render_variables(package.variables, package.name)
                for test in package.tests.values():
                    logger.debug("Converting '%s' to sqlmesh format", test.name)
                    try:
                        audits[test.canonical_name] = test.to_sqlmesh(package_context)

                    except BaseMissingReferenceError as e:
                        ref_type = "model" if isinstance(e, MissingModelError) else "source"
                        logger.warning(
                            "Skipping audit '%s' because %s '%s' is not a valid ref",
                            test.name,
                            ref_type,
                            e.ref,
                        )

        return audits

    def _load_projects(self) -> t.List[Project]:
        """Load the dbt project once and return the cached list on later calls.

        Loading validates the project's default catalog against the context, tracks the
        project files, registers sources / seeds / models / macros on the project context
        and records the newest tracked macro modification time.

        Raises:
            ConfigError: If the project's target database differs from the context's
                default catalog.
        """
        if not self._projects:
            target_name = self.context.selected_gateway

            project = Project.load(
                DbtContext(
                    project_root=self.config_path,
                    profiles_dir=self._profiles_dir,
                    target_name=target_name,
                    sqlmesh_config=self.config,
                ),
                variables=self.config.variables,
            )

            context_default_catalog = self.context.default_catalog or ""
            if project.context.target.database != context_default_catalog:
                raise ConfigError(
                    f"Project default catalog ('{project.context.target.database}') does not match context default catalog ('{context_default_catalog}')."
                )
            for path in project.project_files:
                self._track_file(path)

            context = project.context

            macros_mtimes: t.List[float] = []

            for package in project.packages.values():
                context.add_sources(package.sources)
                context.add_seeds(package.seeds)
                context.add_models(package.models)
                macros_mtimes.extend(
                    self._path_mtimes[m.path]
                    for m in package.macros.values()
                    if m.path in self._path_mtimes
                )

            for package_name, macro_infos in context.manifest.all_macros.items():
                context.add_macros(macro_infos, package=package_name)

            self._macros_max_mtime = max(macros_mtimes) if macros_mtimes else None

            # Cache the project only once it is validated and fully registered, so that a
            # failure above cannot leave a half-initialised project behind for later calls.
            self._projects = [project]

        return self._projects

    def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]:
        """Extend the base requirements with the installed dbt package versions.

        ``dbt-core`` and ``dbt-<target type>`` are added unless they are already listed as
        a requirement or an excluded requirement. A warning is logged to the console for
        packages that are not installed.
        """
        requirements, excluded_requirements = super()._load_requirements()

        target_packages = ["dbt-core"]
        for project in self._load_projects():
            target_packages.append(f"dbt-{project.context.target.type}")

        for target_package in target_packages:
            if target_package in requirements or target_package in excluded_requirements:
                continue
            try:
                requirements[target_package] = metadata.version(target_package)
            except metadata.PackageNotFoundError:
                from sqlmesh.core.console import get_console

                get_console().log_warning(f"dbt package {target_package} is not installed.")

        return requirements, excluded_requirements

    def _load_environment_statements(self, macros: MacroRegistry) -> t.List[EnvironmentStatements]:
        """Loads dbt's on_run_start, on_run_end hooks into sqlmesh's before_all, after_all statements respectively."""

        hooks_by_package_name: t.Dict[str, EnvironmentStatements] = {}
        project_names: t.Set[str] = set()
        dialect = self.config.dialect
        for project in self._load_projects():
            for package_name, package in project.packages.items():
                package_context = project.context.copy()
                package_context.set_and_render_variables(package.variables, package_name)
                # Hooks are emitted in the order given by their ``index``.
                on_run_start: t.List[str] = [
                    on_run_hook.sql
                    for on_run_hook in sorted(package.on_run_start.values(), key=lambda h: h.index)
                ]
                on_run_end: t.List[str] = [
                    on_run_hook.sql
                    for on_run_hook in sorted(package.on_run_end.values(), key=lambda h: h.index)
                ]

                if on_run_start or on_run_end:
                    dependencies = Dependencies()
                    for hook in [*package.on_run_start.values(), *package.on_run_end.values()]:
                        dependencies = dependencies.union(hook.dependencies)

                    statements_context = package_context.context_for_dependencies(dependencies)
                    jinja_registry = make_jinja_registry(
                        statements_context.jinja_macros, package_name, set(dependencies.macros)
                    )
                    jinja_registry.add_globals(statements_context.jinja_globals)

                    hooks_by_package_name[package_name] = EnvironmentStatements(
                        before_all=[
                            d.jinja_statement(stmt).sql(dialect=dialect) for stmt in on_run_start
                        ],
                        after_all=[
                            d.jinja_statement(stmt).sql(dialect=dialect) for stmt in on_run_end
                        ],
                        python_env={},
                        jinja_macros=jinja_registry,
                        project=package_name,
                    )
                    project_names.add(package_name)

        # NOTE(review): every key of hooks_by_package_name is also added to project_names
        # above, so the sort key is always 0 and this stable sort keeps insertion order.
        # If the intent is to put the root project's hooks first, project_names should
        # probably hold only the root project name - confirm before changing.
        return [
            statements
            for _, statements in sorted(
                hooks_by_package_name.items(),
                key=lambda item: 0 if item[0] in project_names else 1,
            )
        ]

    def _compute_yaml_max_mtime_per_subfolder(
        self, root: Path, visited: t.Optional[t.Set[Path]] = None
    ) -> t.Dict[Path, float]:
        """Map each folder under ``root`` to the newest tracked mtime of its own YAML files.

        Only ``.yaml`` / ``.yml`` files located directly in a folder count towards that
        folder's entry, and only if their modification time is present in
        ``self._path_mtimes``. Folders without such files are omitted.

        Args:
            root: Folder to scan; it is resolved before use.
            visited: Resolved folders that were already scanned. It is updated in place
                and guards against scanning the same folder twice.

        Returns:
            A dict from resolved folder path to the maximum YAML mtime in that folder.
        """
        root = root.resolve()
        # An explicit None check keeps a caller-provided (possibly empty) set shared.
        if visited is None:
            visited = set()
        if not root.is_dir() or root in visited:
            return {}

        visited.add(root)

        result: t.Dict[Path, float] = {}
        max_mtime: t.Optional[float] = None

        for nested in root.iterdir():
            try:
                if nested.is_dir():
                    result.update(
                        self._compute_yaml_max_mtime_per_subfolder(nested, visited=visited)
                    )
                elif nested.suffix.lower() in (".yaml", ".yml"):
                    yaml_mtime = self._path_mtimes.get(nested)
                    # Compare against None so that an mtime of 0.0 is not dropped.
                    if yaml_mtime is not None:
                        max_mtime = (
                            max(max_mtime, yaml_mtime) if max_mtime is not None else yaml_mtime
                        )
            except PermissionError:
                # Best effort: unreadable entries simply don't contribute an mtime.
                logger.debug("Skipping '%s' due to insufficient permissions", nested)

        if max_mtime is not None:
            result[root] = max_mtime

        return result

    class _Cache(CacheBase):
        """Per-project model cache whose entries are keyed by file path and modification times."""

        MAX_ENTRY_NAME_LENGTH = 200

        def __init__(
            self,
            loader: DbtLoader,
            project: Project,
            macros_max_mtime: t.Optional[float],
            yaml_max_mtimes: t.Dict[Path, float],
        ):
            """Create a cache stored under ``<context cache dir>/<target name>``.

            Args:
                loader: The owning loader; its tracked mtimes and config fingerprint are used.
                project: The dbt project whose files are cached.
                macros_max_mtime: Newest macro mtime, or ``None`` if unknown.
                yaml_max_mtimes: Newest YAML mtime per folder.
            """
            self._loader = loader
            self._project = project
            self._macros_max_mtime = macros_max_mtime
            self._yaml_max_mtimes = yaml_max_mtimes

            target = t.cast(TargetConfig, project.context.target)
            cache_dir = loader.context.cache_dir / target.name
            self._model_cache = ModelCache(cache_dir)

        def get_or_load_models(
            self, target_path: Path, loader: t.Callable[[], t.List[Model]]
        ) -> t.List[Model]:
            """Return the models for ``target_path`` via the model cache, using ``loader`` to build them.

            Every returned model gets its ``_path`` set to ``target_path``.
            """
            models = self._model_cache.get_or_load(
                self._cache_entry_name(target_path),
                self._cache_entry_id(target_path),
                loader=loader,
            )
            for model in models:
                model._path = target_path

            return models

        def put(self, models: t.List[Model], path: Path) -> bool:
            """Store ``models`` in the cache entry for ``path``."""
            return self._model_cache.put(
                models,
                self._cache_entry_name(path),
                self._cache_entry_id(path),
            )

        def get(self, path: Path) -> t.List[Model]:
            """Return the models stored in the cache entry for ``path``."""
            return self._model_cache.get(
                self._cache_entry_name(path),
                self._cache_entry_id(path),
            )

        def _cache_entry_name(self, target_path: Path) -> str:
            """Build the cache entry name from the path relative to the project root.

            Path parts are joined with ``__`` and the file extension is dropped. Names
            longer than ``MAX_ENTRY_NAME_LENGTH`` keep only their trailing characters.
            """
            try:
                path_for_name = target_path.absolute().relative_to(
                    self._project.context.project_root.absolute()
                )
            except ValueError:
                path_for_name = target_path
            # Strip only the file's own extension. A plain str.replace() would also remove
            # the same text from directory names and could make distinct paths collide.
            if path_for_name.suffix:
                path_for_name = path_for_name.with_suffix("")
            name = "__".join(path_for_name.parts)
            if len(name) > self.MAX_ENTRY_NAME_LENGTH:
                return name[-self.MAX_ENTRY_NAME_LENGTH :]
            return name

        def _cache_entry_id(self, target_path: Path) -> str:
            """Build the cache entry id from the relevant max mtime and the config fingerprint."""
            max_mtime = self._max_mtime_for_path(target_path)
            return "__".join(
                [
                    str(int(max_mtime)) if max_mtime is not None else "na",
                    self._loader.config.fingerprint,
                ]
            )

        def _max_mtime_for_path(self, target_path: Path) -> t.Optional[float]:
            """Return the newest mtime that affects the model file at ``target_path``.

            Considered are the file itself, the dbt profile file, all macros and the YAML
            files of every folder between the file and the project root (inclusive).

            Returns:
                The maximum known mtime, or ``None`` if the path lies outside the project
                root or no mtime is known.
            """
            abs_root = self._project.context.project_root.absolute()
            abs_target = target_path.absolute()

            try:
                abs_target.relative_to(abs_root)
            except ValueError:
                return None

            mtimes = [
                self._loader._path_mtimes.get(target_path),
                self._loader._path_mtimes.get(self._project.profile.path),
                # FIXME: take into account which macros are actually referenced in the target model.
                self._macros_max_mtime,
            ]

            # Walk up using the same absolute paths that passed the containment check.
            # Comparing the raw paths could never terminate when one of them is relative
            # and the other absolute, since Path(".").parent is Path(".") again.
            cursor = abs_target
            while cursor != abs_root:
                parent = cursor.parent
                if parent == cursor:
                    # Reached the filesystem root without meeting the project root.
                    break
                cursor = parent
                mtimes.append(self._yaml_max_mtimes.get(cursor))

            non_null_mtimes = [mtime for mtime in mtimes if mtime is not None]
            return max(non_null_mtimes) if non_null_mtimes else None
Abstract base class to load macros and models for a context
DbtLoader( context: sqlmesh.core.context.GenericContext, path: pathlib.Path, profiles_dir: Optional[pathlib.Path] = None)