sqlmesh.dbt.builtin
1from __future__ import annotations 2 3import json 4import logging 5import os 6import typing as t 7from ast import literal_eval 8from dataclasses import asdict 9 10import agate 11import jinja2 12from dbt import version 13from dbt.adapters.base import BaseRelation, Column 14from ruamel.yaml import YAMLError 15from sqlglot import Dialect 16 17from sqlmesh.core.console import get_console 18from sqlmesh.core.engine_adapter import EngineAdapter 19from sqlmesh.core.model.definition import SqlModel 20from sqlmesh.core.snapshot.definition import DeployabilityIndex 21from sqlmesh.dbt.adapter import BaseAdapter, ParsetimeAdapter, RuntimeAdapter 22from sqlmesh.dbt.common import RAW_CODE_KEY 23from sqlmesh.dbt.relation import Policy 24from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS 25from sqlmesh.dbt.util import DBT_VERSION 26from sqlmesh.utils import AttributeDict, debug_mode_enabled, yaml 27from sqlmesh.utils.date import now 28from sqlmesh.utils.errors import ConfigError 29from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroReference, MacroReturnVal 30 31logger = logging.getLogger(__name__) 32 33 34class Exceptions: 35 def raise_compiler_error(self, msg: str) -> None: 36 if DBT_VERSION >= (1, 4, 0): 37 from dbt.exceptions import CompilationError 38 39 raise CompilationError(msg) 40 else: 41 from dbt.exceptions import CompilationException # type: ignore 42 43 raise CompilationException(msg) 44 45 def raise_not_implemented(self, msg: str) -> None: 46 raise NotImplementedError(msg) 47 48 def warn(self, msg: str) -> str: 49 logger.warning(msg) 50 return "" 51 52 53def try_or_compiler_error( 54 message_if_exception: str, func: t.Callable, *args: t.Any, **kwargs: t.Any 55) -> t.Any: 56 try: 57 return func(*args, **kwargs) 58 except Exception: 59 if DBT_VERSION >= (1, 4, 0): 60 from dbt.exceptions import CompilationError 61 62 raise CompilationError(message_if_exception) 63 else: 64 from dbt.exceptions import CompilationException # type: ignore 65 66 raise CompilationException(message_if_exception) 67 68 69class Api: 70 def __init__(self, dialect: t.Optional[str]) -> None: 71 if dialect: 72 config_class = TARGET_TYPE_TO_CONFIG_CLASS[ 73 Dialect.get_or_raise(dialect).__class__.__name__.lower() 74 ] 75 self.Relation = config_class.relation_class 76 self.Column = config_class.column_class 77 self.quote_policy = config_class.quote_policy 78 else: 79 self.Relation = BaseRelation 80 self.Column = Column 81 self.quote_policy = Policy() 82 83 84class Flags: 85 def __init__( 86 self, 87 full_refresh: t.Optional[str] = None, 88 store_failures: t.Optional[str] = None, 89 which: t.Optional[str] = None, 90 ) -> None: 91 # Temporary placeholder values for now (these are generally passed from the CLI) 92 self.FULL_REFRESH = full_refresh 93 self.STORE_FAILURES = store_failures 94 self.WHICH = which 95 96 97class Modules: 98 def __init__(self) -> None: 99 import datetime 100 import itertools 101 import re 102 103 try: 104 import pytz 105 106 self.pytz = pytz 107 except ImportError: 108 pass 109 110 self.datetime = datetime 111 self.re = re 112 self.itertools = itertools 113 114 115class SQLExecution: 116 def __init__(self, adapter: BaseAdapter): 117 self.adapter = adapter 118 self._results: t.Dict[str, AttributeDict] = {} 119 120 def store_result(self, name: str, response: t.Any, agate_table: t.Optional[agate.Table]) -> str: 121 from sqlmesh.dbt.util import empty_table, as_matrix 122 123 if agate_table is None: 124 agate_table = empty_table() 125 126 self._results[name] = AttributeDict( 127 { 128 "response": response, 129 "data": as_matrix(agate_table), 130 "table": agate_table, 131 } 132 ) 133 return "" 134 135 def load_result(self, name: str) -> t.Optional[AttributeDict]: 136 return self._results.get(name) 137 138 def run_query(self, sql: str) -> agate.Table: 139 self.statement("run_query_statement", fetch_result=True, auto_begin=False, caller=sql) 140 resp = self.load_result("run_query_statement") 141 assert resp is not None 142 return resp["table"] 143 144 def statement( 145 self, 146 name: t.Optional[str], 147 fetch_result: bool = False, 148 auto_begin: bool = True, 149 language: str = "sql", 150 caller: t.Optional[jinja2.runtime.Macro | str] = None, 151 ) -> str: 152 """ 153 Executes the SQL that is defined within the context of the caller. Therefore caller really isn't optional 154 but we make it optional and at the end because we need to match the signature of the jinja2 macro. 155 156 Name is the name that we store the results to which can be retrieved with `load_result`. If name is not 157 provided then the SQL is executed but the results are not stored. 158 """ 159 if not caller: 160 raise RuntimeError( 161 "Statement relies on a caller to be set that is the target SQL to be run" 162 ) 163 sql = caller if isinstance(caller, str) else caller() 164 if language != "sql": 165 raise NotImplementedError( 166 "SQLMesh's dbt integration only supports SQL statements at this time." 167 ) 168 res, table = self.adapter.execute(sql, fetch=fetch_result, auto_begin=auto_begin) 169 if name: 170 self.store_result(name, res, table) 171 return "" 172 173 174class Var: 175 def __init__(self, variables: t.Dict[str, t.Any]) -> None: 176 self.variables = variables 177 178 def __call__(self, name: str, default: t.Optional[t.Any] = None) -> t.Any: 179 return self.variables.get(name, default) 180 181 def has_var(self, name: str) -> bool: 182 return name in self.variables 183 184 185class Config: 186 def __init__(self, config_dict: t.Dict[str, t.Any]) -> None: 187 self._config = config_dict 188 189 def __call__(self, *args: t.Any, **kwargs: t.Any) -> str: 190 if args and kwargs: 191 raise ConfigError( 192 "Invalid inline model config: cannot mix positional and keyword arguments" 193 ) 194 195 if args: 196 if len(args) == 1 and isinstance(args[0], dict): 197 # Single dict argument: config({"materialized": "table"}) 198 self._config.update(args[0]) 199 else: 200 raise ConfigError( 201 f"Invalid inline model config: expected a single dictionary, got {len(args)} arguments" 202 ) 203 elif kwargs: 204 # Keyword arguments: config(materialized="table") 205 self._config.update(kwargs) 206 207 return "" 208 209 def set(self, name: str, value: t.Any) -> str: 210 self._config.update({name: value}) 211 return "" 212 213 def _validate(self, name: str, validator: t.Callable, value: t.Optional[t.Any] = None) -> None: 214 try: 215 validator(value) 216 except Exception as e: 217 raise ConfigError(f"Config validation failed for '{name}': {e}") 218 219 def require(self, name: str, validator: t.Optional[t.Callable] = None) -> t.Any: 220 if name not in self._config: 221 raise ConfigError(f"Missing required config: {name}") 222 223 value = self._config[name] 224 225 if validator is not None: 226 self._validate(name, validator, value) 227 228 return value 229 230 def get( 231 self, name: str, default: t.Any = None, validator: t.Optional[t.Callable] = None 232 ) -> t.Any: 233 value = self._config.get(name, default) 234 235 if validator is not None and value is not None: 236 self._validate(name, validator, value) 237 238 return value 239 240 def persist_relation_docs(self) -> bool: 241 persist_docs = self.get("persist_docs", default={}) 242 if not isinstance(persist_docs, dict): 243 return False 244 return persist_docs.get("relation", False) 245 246 def persist_column_docs(self) -> bool: 247 persist_docs = self.get("persist_docs", default={}) 248 if not isinstance(persist_docs, dict): 249 return False 250 return persist_docs.get("columns", False) 251 252 253def env_var(name: str, default: t.Optional[str] = None) -> t.Optional[str]: 254 if name not in os.environ and default is None: 255 raise ConfigError(f"Missing environment variable '{name}'") 256 return os.environ.get(name, default) 257 258 259def log(msg: str, info: bool = False) -> str: 260 if info: 261 # Write to both log file and stdout 262 logger.info(msg) 263 get_console().log_status_update(msg) 264 else: 265 logger.debug(msg) 266 267 return "" 268 269 270def generate_ref(refs: t.Dict[str, t.Any], api: Api) -> t.Callable: 271 def ref( 272 package: str, name: t.Optional[str] = None, **kwargs: t.Any 273 ) -> t.Optional[BaseRelation]: 274 version = kwargs.get("version", kwargs.get("v")) 275 ref_name = f"{package}.{name}" if name else package 276 277 if version is not None: 278 relation_info = refs.get(f"{ref_name}_v{version}") 279 if relation_info is None: 280 logger.warning( 281 "Could not resolve ref '%s' with version '%s'. Falling back to unversioned reference", 282 ref_name, 283 version, 284 ) 285 relation_info = refs.get(ref_name) 286 else: 287 relation_info = refs.get(ref_name) 288 if not relation_info: 289 versioned_infos = sorted( 290 [(r, info) for r, info in refs.items() if r.startswith(f"{ref_name}_v")], 291 key=lambda i: i[0], 292 ) 293 if versioned_infos: 294 relation_info = versioned_infos[-1][1] 295 296 if relation_info is None: 297 logger.debug("Could not resolve ref '%s', version '%s'", ref_name, version) 298 return None 299 300 return _relation_info_to_relation(relation_info, api.Relation, api.quote_policy) 301 302 return ref 303 304 305def generate_source(sources: t.Dict[str, t.Any], api: Api) -> t.Callable: 306 def source(package: str, name: str) -> t.Optional[BaseRelation]: 307 relation_info = sources.get(f"{package}.{name}") 308 if relation_info is None: 309 logger.debug("Could not resolve source package='%s' name='%s'", package, name) 310 return None 311 312 # Clickhouse uses a 2-level schema.table naming scheme, where the second level is called 313 # a "database" (instead of "schema" as one would reasonably assume). This can lead to confusion 314 # because it is not clear how Clickhouse identifiers map onto dbt's "database" and "schema" fields. 315 # 316 # This confusion can occur in source resolution. If a source's `schema` is not explicitly specified, 317 # the source name is used as the schema by default. 318 # 319 # If a source specified the `database` field and the schema has defaulted to the source name, 320 # we follow dbt-clickhouse in assuming that the user intended for the `database` field to be the 321 # second level identifier. 322 # https://github.com/ClickHouse/dbt-clickhouse/blob/065f3a724fa09205446ecadac7a00d92b2d8c646/dbt/adapters/clickhouse/relation.py#L112 323 # 324 # NOTE: determining relation class based on name so we don't introduce a dependency on dbt-clickhouse 325 if ( 326 api.Relation.__name__ == "ClickHouseRelation" 327 and relation_info.schema == package 328 and relation_info.database 329 ): 330 relation_info["schema"] = relation_info["database"] 331 relation_info["database"] = "" 332 333 return _relation_info_to_relation(relation_info, api.Relation, api.quote_policy) 334 335 return source 336 337 338def return_val(val: t.Any) -> None: 339 raise MacroReturnVal(val) 340 341 342def to_set(value: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: 343 try: 344 return set(value) 345 except TypeError: 346 return default 347 348 349def to_json(value: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: 350 try: 351 return json.dumps(value) 352 except TypeError: 353 return default 354 355 356def from_json(value: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: 357 try: 358 return json.loads(value) 359 except (TypeError, json.JSONDecodeError): 360 return default 361 362 363def to_yaml(value: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: 364 try: 365 return yaml.dump(value) 366 except (TypeError, YAMLError): 367 return default 368 369 370def from_yaml(value: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: 371 try: 372 return dict(yaml.load(value, raise_if_empty=False, render_jinja=False)) 373 except (TypeError, YAMLError): 374 return default 375 376 377def do_zip(*args: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: 378 try: 379 return list(zip(*args)) 380 except TypeError: 381 return default 382 383 384def as_bool(value: t.Any) -> t.Any: 385 # dbt's jinja TEXT_FILTERS just return the input value as is 386 # https://github.com/dbt-labs/dbt-common/blob/main/dbt_common/clients/jinja.py#L559 387 return value 388 389 390def as_number(value: str) -> t.Any: 391 # dbt's jinja TEXT_FILTERS just return the input value as is 392 # https://github.com/dbt-labs/dbt-common/blob/main/dbt_common/clients/jinja.py#L559 393 return value 394 395 396def _try_literal_eval(value: str) -> t.Any: 397 try: 398 return literal_eval(value) 399 except (ValueError, SyntaxError, MemoryError): 400 return value 401 402 403def debug() -> str: 404 import sys 405 import ipdb # type: ignore 406 407 frame = sys._getframe(3) 408 ipdb.set_trace(frame) 409 return "" 410 411 412BUILTIN_GLOBALS = { 413 "dbt_version": version.__version__, 414 "env_var": env_var, 415 "exceptions": Exceptions(), 416 "fromjson": from_json, 417 "fromyaml": from_yaml, 418 "log": log, 419 "modules": Modules(), 420 "print": log, 421 "return": return_val, 422 "set": to_set, 423 "set_strict": set, 424 "sqlmesh": True, 425 "sqlmesh_incremental": True, 426 "tojson": to_json, 427 "toyaml": to_yaml, 428 "try_or_compiler_error": try_or_compiler_error, 429 "zip": do_zip, 430 "zip_strict": lambda *args: list(zip(*args)), 431} 432 433# Add debug function conditionally both with dbt or sqlmesh equivalent flag 434if os.environ.get("DBT_MACRO_DEBUGGING") or debug_mode_enabled(): 435 BUILTIN_GLOBALS["debug"] = debug 436 437BUILTIN_FILTERS = { 438 "as_bool": as_bool, 439 "as_native": _try_literal_eval, 440 "as_number": as_number, 441 "as_text": lambda v: "" if v is None else str(v), 442} 443 444OVERRIDDEN_MACROS = { 445 MacroReference(package="dbt", name="is_incremental"), 446 MacroReference(name="is_incremental"), 447} 448 449 450def create_builtin_globals( 451 jinja_macros: JinjaMacroRegistry, 452 jinja_globals: t.Dict[str, t.Any], 453 engine_adapter: t.Optional[EngineAdapter], 454) -> t.Dict[str, t.Any]: 455 builtin_globals = BUILTIN_GLOBALS.copy() 456 jinja_globals = jinja_globals.copy() 457 458 target: t.Optional[AttributeDict] = jinja_globals.get("target", None) 459 project_dialect = jinja_globals.pop("dialect", None) or (target.get("type") if target else None) 460 api = Api(project_dialect) 461 462 builtin_globals["api"] = api 463 464 this = jinja_globals.pop("this", None) 465 if this is not None: 466 if not isinstance(this, api.Relation): 467 builtin_globals["this"] = api.Relation.create(**this) 468 else: 469 builtin_globals["this"] = this 470 471 sources = jinja_globals.pop("sources", None) 472 if sources is not None: 473 builtin_globals["source"] = generate_source(sources, api) 474 475 refs = jinja_globals.pop("refs", None) 476 if refs is not None: 477 builtin_globals["ref"] = generate_ref(refs, api) 478 479 variables = jinja_globals.pop("vars", None) 480 if variables is not None: 481 builtin_globals["var"] = Var(variables) 482 483 builtin_globals["config"] = Config(jinja_globals.pop("config", {"tags": []})) 484 485 deployability_index = ( 486 jinja_globals.get("deployability_index") or DeployabilityIndex.all_deployable() 487 ) 488 snapshot = jinja_globals.pop("snapshot", None) 489 490 if snapshot and snapshot.is_incremental: 491 intervals = ( 492 snapshot.intervals 493 if deployability_index.is_deployable(snapshot) 494 else snapshot.dev_intervals 495 ) 496 is_incremental = bool(intervals) 497 498 snapshot_table_exists = jinja_globals.get("snapshot_table_exists") 499 if is_incremental and snapshot_table_exists is not None: 500 # If we know the information about table existence, we can use it to correctly 501 # set the flag 502 is_incremental &= snapshot_table_exists 503 else: 504 is_incremental = False 505 506 builtin_globals["is_incremental"] = lambda: is_incremental 507 508 builtin_globals["builtins"] = AttributeDict( 509 {k: builtin_globals.get(k) for k in ("ref", "source", "config", "var")} 510 ) 511 512 if (model := jinja_globals.pop("model", None)) is not None: 513 if isinstance(model_instance := jinja_globals.pop("model_instance", None), SqlModel): 514 builtin_globals["model"] = AttributeDict( 515 {**model, RAW_CODE_KEY: model_instance.query.name} 516 ) 517 else: 518 builtin_globals["model"] = AttributeDict(model.copy()) 519 520 builtin_globals["flags"] = ( 521 Flags(which="run") if engine_adapter is not None else Flags(which="parse") 522 ) 523 builtin_globals["invocation_args_dict"] = { 524 k.lower(): v for k, v in builtin_globals["flags"].__dict__.items() 525 } 526 527 if engine_adapter is not None: 528 adapter: BaseAdapter = RuntimeAdapter( 529 engine_adapter, 530 jinja_macros, 531 jinja_globals={ 532 **builtin_globals, 533 **jinja_globals, 534 "engine_adapter": engine_adapter, 535 }, 536 relation_type=api.Relation, 537 column_type=api.Column, 538 quote_policy=api.quote_policy, 539 snapshots=jinja_globals.get("snapshots", {}), 540 table_mapping=jinja_globals.get("table_mapping", {}), 541 deployability_index=deployability_index, 542 project_dialect=project_dialect, 543 ) 544 else: 545 adapter = ParsetimeAdapter( 546 jinja_macros, 547 jinja_globals={**builtin_globals, **jinja_globals}, 548 project_dialect=project_dialect, 549 quote_policy=api.quote_policy, 550 ) 551 552 sql_execution = SQLExecution(adapter) 553 builtin_globals.update( 554 { 555 "adapter": adapter, 556 "execute": True, 557 "load_relation": adapter.load_relation, 558 "store_result": sql_execution.store_result, 559 "load_result": sql_execution.load_result, 560 "run_query": sql_execution.run_query, 561 "statement": sql_execution.statement, 562 "graph": adapter.graph, 563 "selected_resources": list(jinja_globals.get("selected_models") or []), 564 "write": lambda input: None, # We don't support writing yet 565 } 566 ) 567 568 builtin_globals["run_started_at"] = jinja_globals.get("execution_dt") or now() 569 builtin_globals["dbt"] = AttributeDict(builtin_globals) 570 builtin_globals["context"] = builtin_globals["dbt"] 571 572 return {**builtin_globals, **jinja_globals} 573 574 575def create_builtin_filters() -> t.Dict[str, t.Callable]: 576 return BUILTIN_FILTERS 577 578 579def _relation_info_to_relation( 580 relation_info: t.Dict[str, t.Any], 581 relation_type: t.Type[BaseRelation], 582 target_quote_policy: Policy, 583) -> BaseRelation: 584 relation_info = relation_info.copy() 585 quote_policy = Policy( 586 **{ 587 **asdict(target_quote_policy), 588 **{k: v for k, v in relation_info.pop("quote_policy", {}).items() if v is not None}, 589 } 590 ) 591 return relation_type.create(**relation_info, quote_policy=quote_policy)
logger =
<Logger sqlmesh.dbt.builtin (WARNING)>
class
Exceptions:
35class Exceptions: 36 def raise_compiler_error(self, msg: str) -> None: 37 if DBT_VERSION >= (1, 4, 0): 38 from dbt.exceptions import CompilationError 39 40 raise CompilationError(msg) 41 else: 42 from dbt.exceptions import CompilationException # type: ignore 43 44 raise CompilationException(msg) 45 46 def raise_not_implemented(self, msg: str) -> None: 47 raise NotImplementedError(msg) 48 49 def warn(self, msg: str) -> str: 50 logger.warning(msg) 51 return ""
def
try_or_compiler_error( message_if_exception: str, func: Callable, *args: Any, **kwargs: Any) -> Any:
54def try_or_compiler_error( 55 message_if_exception: str, func: t.Callable, *args: t.Any, **kwargs: t.Any 56) -> t.Any: 57 try: 58 return func(*args, **kwargs) 59 except Exception: 60 if DBT_VERSION >= (1, 4, 0): 61 from dbt.exceptions import CompilationError 62 63 raise CompilationError(message_if_exception) 64 else: 65 from dbt.exceptions import CompilationException # type: ignore 66 67 raise CompilationException(message_if_exception)
class
Api:
70class Api: 71 def __init__(self, dialect: t.Optional[str]) -> None: 72 if dialect: 73 config_class = TARGET_TYPE_TO_CONFIG_CLASS[ 74 Dialect.get_or_raise(dialect).__class__.__name__.lower() 75 ] 76 self.Relation = config_class.relation_class 77 self.Column = config_class.column_class 78 self.quote_policy = config_class.quote_policy 79 else: 80 self.Relation = BaseRelation 81 self.Column = Column 82 self.quote_policy = Policy()
Api(dialect: Optional[str])
71 def __init__(self, dialect: t.Optional[str]) -> None: 72 if dialect: 73 config_class = TARGET_TYPE_TO_CONFIG_CLASS[ 74 Dialect.get_or_raise(dialect).__class__.__name__.lower() 75 ] 76 self.Relation = config_class.relation_class 77 self.Column = config_class.column_class 78 self.quote_policy = config_class.quote_policy 79 else: 80 self.Relation = BaseRelation 81 self.Column = Column 82 self.quote_policy = Policy()
class
Flags:
85class Flags: 86 def __init__( 87 self, 88 full_refresh: t.Optional[str] = None, 89 store_failures: t.Optional[str] = None, 90 which: t.Optional[str] = None, 91 ) -> None: 92 # Temporary placeholder values for now (these are generally passed from the CLI) 93 self.FULL_REFRESH = full_refresh 94 self.STORE_FAILURES = store_failures 95 self.WHICH = which
Flags( full_refresh: Optional[str] = None, store_failures: Optional[str] = None, which: Optional[str] = None)
86 def __init__( 87 self, 88 full_refresh: t.Optional[str] = None, 89 store_failures: t.Optional[str] = None, 90 which: t.Optional[str] = None, 91 ) -> None: 92 # Temporary placeholder values for now (these are generally passed from the CLI) 93 self.FULL_REFRESH = full_refresh 94 self.STORE_FAILURES = store_failures 95 self.WHICH = which
class
Modules:
class
SQLExecution:
116class SQLExecution: 117 def __init__(self, adapter: BaseAdapter): 118 self.adapter = adapter 119 self._results: t.Dict[str, AttributeDict] = {} 120 121 def store_result(self, name: str, response: t.Any, agate_table: t.Optional[agate.Table]) -> str: 122 from sqlmesh.dbt.util import empty_table, as_matrix 123 124 if agate_table is None: 125 agate_table = empty_table() 126 127 self._results[name] = AttributeDict( 128 { 129 "response": response, 130 "data": as_matrix(agate_table), 131 "table": agate_table, 132 } 133 ) 134 return "" 135 136 def load_result(self, name: str) -> t.Optional[AttributeDict]: 137 return self._results.get(name) 138 139 def run_query(self, sql: str) -> agate.Table: 140 self.statement("run_query_statement", fetch_result=True, auto_begin=False, caller=sql) 141 resp = self.load_result("run_query_statement") 142 assert resp is not None 143 return resp["table"] 144 145 def statement( 146 self, 147 name: t.Optional[str], 148 fetch_result: bool = False, 149 auto_begin: bool = True, 150 language: str = "sql", 151 caller: t.Optional[jinja2.runtime.Macro | str] = None, 152 ) -> str: 153 """ 154 Executes the SQL that is defined within the context of the caller. Therefore caller really isn't optional 155 but we make it optional and at the end because we need to match the signature of the jinja2 macro. 156 157 Name is the name that we store the results to which can be retrieved with `load_result`. If name is not 158 provided then the SQL is executed but the results are not stored. 159 """ 160 if not caller: 161 raise RuntimeError( 162 "Statement relies on a caller to be set that is the target SQL to be run" 163 ) 164 sql = caller if isinstance(caller, str) else caller() 165 if language != "sql": 166 raise NotImplementedError( 167 "SQLMesh's dbt integration only supports SQL statements at this time." 168 ) 169 res, table = self.adapter.execute(sql, fetch=fetch_result, auto_begin=auto_begin) 170 if name: 171 self.store_result(name, res, table) 172 return ""
SQLExecution(adapter: sqlmesh.dbt.adapter.BaseAdapter)
def
store_result( self, name: str, response: Any, agate_table: Optional[agate.table.Table]) -> str:
121 def store_result(self, name: str, response: t.Any, agate_table: t.Optional[agate.Table]) -> str: 122 from sqlmesh.dbt.util import empty_table, as_matrix 123 124 if agate_table is None: 125 agate_table = empty_table() 126 127 self._results[name] = AttributeDict( 128 { 129 "response": response, 130 "data": as_matrix(agate_table), 131 "table": agate_table, 132 } 133 ) 134 return ""
def
statement( self, name: Optional[str], fetch_result: bool = False, auto_begin: bool = True, language: str = 'sql', caller: Union[jinja2.runtime.Macro, str, NoneType] = None) -> str:
145 def statement( 146 self, 147 name: t.Optional[str], 148 fetch_result: bool = False, 149 auto_begin: bool = True, 150 language: str = "sql", 151 caller: t.Optional[jinja2.runtime.Macro | str] = None, 152 ) -> str: 153 """ 154 Executes the SQL that is defined within the context of the caller. Therefore caller really isn't optional 155 but we make it optional and at the end because we need to match the signature of the jinja2 macro. 156 157 Name is the name that we store the results to which can be retrieved with `load_result`. If name is not 158 provided then the SQL is executed but the results are not stored. 159 """ 160 if not caller: 161 raise RuntimeError( 162 "Statement relies on a caller to be set that is the target SQL to be run" 163 ) 164 sql = caller if isinstance(caller, str) else caller() 165 if language != "sql": 166 raise NotImplementedError( 167 "SQLMesh's dbt integration only supports SQL statements at this time." 168 ) 169 res, table = self.adapter.execute(sql, fetch=fetch_result, auto_begin=auto_begin) 170 if name: 171 self.store_result(name, res, table) 172 return ""
Executes the SQL that is defined within the context of the caller. Therefore caller really isn't optional but we make it optional and at the end because we need to match the signature of the jinja2 macro.
Name is the name that we store the results to which can be retrieved with load_result. If name is not
provided then the SQL is executed but the results are not stored.
class
Var:
175class Var: 176 def __init__(self, variables: t.Dict[str, t.Any]) -> None: 177 self.variables = variables 178 179 def __call__(self, name: str, default: t.Optional[t.Any] = None) -> t.Any: 180 return self.variables.get(name, default) 181 182 def has_var(self, name: str) -> bool: 183 return name in self.variables
class
Config:
186class Config: 187 def __init__(self, config_dict: t.Dict[str, t.Any]) -> None: 188 self._config = config_dict 189 190 def __call__(self, *args: t.Any, **kwargs: t.Any) -> str: 191 if args and kwargs: 192 raise ConfigError( 193 "Invalid inline model config: cannot mix positional and keyword arguments" 194 ) 195 196 if args: 197 if len(args) == 1 and isinstance(args[0], dict): 198 # Single dict argument: config({"materialized": "table"}) 199 self._config.update(args[0]) 200 else: 201 raise ConfigError( 202 f"Invalid inline model config: expected a single dictionary, got {len(args)} arguments" 203 ) 204 elif kwargs: 205 # Keyword arguments: config(materialized="table") 206 self._config.update(kwargs) 207 208 return "" 209 210 def set(self, name: str, value: t.Any) -> str: 211 self._config.update({name: value}) 212 return "" 213 214 def _validate(self, name: str, validator: t.Callable, value: t.Optional[t.Any] = None) -> None: 215 try: 216 validator(value) 217 except Exception as e: 218 raise ConfigError(f"Config validation failed for '{name}': {e}") 219 220 def require(self, name: str, validator: t.Optional[t.Callable] = None) -> t.Any: 221 if name not in self._config: 222 raise ConfigError(f"Missing required config: {name}") 223 224 value = self._config[name] 225 226 if validator is not None: 227 self._validate(name, validator, value) 228 229 return value 230 231 def get( 232 self, name: str, default: t.Any = None, validator: t.Optional[t.Callable] = None 233 ) -> t.Any: 234 value = self._config.get(name, default) 235 236 if validator is not None and value is not None: 237 self._validate(name, validator, value) 238 239 return value 240 241 def persist_relation_docs(self) -> bool: 242 persist_docs = self.get("persist_docs", default={}) 243 if not isinstance(persist_docs, dict): 244 return False 245 return persist_docs.get("relation", False) 246 247 def persist_column_docs(self) -> bool: 248 persist_docs = self.get("persist_docs", default={}) 249 if not isinstance(persist_docs, dict): 250 return False 251 return persist_docs.get("columns", False)
def
require(self, name: str, validator: Optional[Callable] = None) -> Any:
220 def require(self, name: str, validator: t.Optional[t.Callable] = None) -> t.Any: 221 if name not in self._config: 222 raise ConfigError(f"Missing required config: {name}") 223 224 value = self._config[name] 225 226 if validator is not None: 227 self._validate(name, validator, value) 228 229 return value
def
env_var(name: str, default: Optional[str] = None) -> Optional[str]:
def
log(msg: str, info: bool = False) -> str:
271def generate_ref(refs: t.Dict[str, t.Any], api: Api) -> t.Callable: 272 def ref( 273 package: str, name: t.Optional[str] = None, **kwargs: t.Any 274 ) -> t.Optional[BaseRelation]: 275 version = kwargs.get("version", kwargs.get("v")) 276 ref_name = f"{package}.{name}" if name else package 277 278 if version is not None: 279 relation_info = refs.get(f"{ref_name}_v{version}") 280 if relation_info is None: 281 logger.warning( 282 "Could not resolve ref '%s' with version '%s'. Falling back to unversioned reference", 283 ref_name, 284 version, 285 ) 286 relation_info = refs.get(ref_name) 287 else: 288 relation_info = refs.get(ref_name) 289 if not relation_info: 290 versioned_infos = sorted( 291 [(r, info) for r, info in refs.items() if r.startswith(f"{ref_name}_v")], 292 key=lambda i: i[0], 293 ) 294 if versioned_infos: 295 relation_info = versioned_infos[-1][1] 296 297 if relation_info is None: 298 logger.debug("Could not resolve ref '%s', version '%s'", ref_name, version) 299 return None 300 301 return _relation_info_to_relation(relation_info, api.Relation, api.quote_policy) 302 303 return ref
306def generate_source(sources: t.Dict[str, t.Any], api: Api) -> t.Callable: 307 def source(package: str, name: str) -> t.Optional[BaseRelation]: 308 relation_info = sources.get(f"{package}.{name}") 309 if relation_info is None: 310 logger.debug("Could not resolve source package='%s' name='%s'", package, name) 311 return None 312 313 # Clickhouse uses a 2-level schema.table naming scheme, where the second level is called 314 # a "database" (instead of "schema" as one would reasonably assume). This can lead to confusion 315 # because it is not clear how Clickhouse identifiers map onto dbt's "database" and "schema" fields. 316 # 317 # This confusion can occur in source resolution. If a source's `schema` is not explicitly specified, 318 # the source name is used as the schema by default. 319 # 320 # If a source specified the `database` field and the schema has defaulted to the source name, 321 # we follow dbt-clickhouse in assuming that the user intended for the `database` field to be the 322 # second level identifier. 323 # https://github.com/ClickHouse/dbt-clickhouse/blob/065f3a724fa09205446ecadac7a00d92b2d8c646/dbt/adapters/clickhouse/relation.py#L112 324 # 325 # NOTE: determining relation class based on name so we don't introduce a dependency on dbt-clickhouse 326 if ( 327 api.Relation.__name__ == "ClickHouseRelation" 328 and relation_info.schema == package 329 and relation_info.database 330 ): 331 relation_info["schema"] = relation_info["database"] 332 relation_info["database"] = "" 333 334 return _relation_info_to_relation(relation_info, api.Relation, api.quote_policy) 335 336 return source
def
return_val(val: Any) -> None:
def
to_set(value: Any, default: Optional[Any] = None) -> Optional[Any]:
def
to_json(value: Any, default: Optional[Any] = None) -> Optional[Any]:
def
from_json(value: str, default: Optional[Any] = None) -> Optional[Any]:
def
to_yaml(value: Any, default: Optional[Any] = None) -> Optional[Any]:
def
from_yaml(value: Any, default: Optional[Any] = None) -> Optional[Any]:
def
do_zip(*args: Any, default: Optional[Any] = None) -> Optional[Any]:
def
as_bool(value: Any) -> Any:
def
as_number(value: str) -> Any:
def
debug() -> str:
BUILTIN_GLOBALS =
{'dbt_version': '1.11.7', 'env_var': <function env_var>, 'exceptions': <Exceptions object>, 'fromjson': <function from_json>, 'fromyaml': <function from_yaml>, 'log': <function log>, 'modules': <Modules object>, 'print': <function log>, 'return': <function return_val>, 'set': <function to_set>, 'set_strict': <class 'set'>, 'sqlmesh': True, 'sqlmesh_incremental': True, 'tojson': <function to_json>, 'toyaml': <function to_yaml>, 'try_or_compiler_error': <function try_or_compiler_error>, 'zip': <function do_zip>, 'zip_strict': <function <lambda>>}
BUILTIN_FILTERS =
{'as_bool': <function as_bool>, 'as_native': <function _try_literal_eval>, 'as_number': <function as_number>, 'as_text': <function <lambda>>}
OVERRIDDEN_MACROS =
{dbt.is_incremental, is_incremental}
def
create_builtin_globals( jinja_macros: sqlmesh.utils.jinja.JinjaMacroRegistry, jinja_globals: Dict[str, Any], engine_adapter: Optional[sqlmesh.core.engine_adapter.base.EngineAdapter]) -> Dict[str, Any]:
451def create_builtin_globals( 452 jinja_macros: JinjaMacroRegistry, 453 jinja_globals: t.Dict[str, t.Any], 454 engine_adapter: t.Optional[EngineAdapter], 455) -> t.Dict[str, t.Any]: 456 builtin_globals = BUILTIN_GLOBALS.copy() 457 jinja_globals = jinja_globals.copy() 458 459 target: t.Optional[AttributeDict] = jinja_globals.get("target", None) 460 project_dialect = jinja_globals.pop("dialect", None) or (target.get("type") if target else None) 461 api = Api(project_dialect) 462 463 builtin_globals["api"] = api 464 465 this = jinja_globals.pop("this", None) 466 if this is not None: 467 if not isinstance(this, api.Relation): 468 builtin_globals["this"] = api.Relation.create(**this) 469 else: 470 builtin_globals["this"] = this 471 472 sources = jinja_globals.pop("sources", None) 473 if sources is not None: 474 builtin_globals["source"] = generate_source(sources, api) 475 476 refs = jinja_globals.pop("refs", None) 477 if refs is not None: 478 builtin_globals["ref"] = generate_ref(refs, api) 479 480 variables = jinja_globals.pop("vars", None) 481 if variables is not None: 482 builtin_globals["var"] = Var(variables) 483 484 builtin_globals["config"] = Config(jinja_globals.pop("config", {"tags": []})) 485 486 deployability_index = ( 487 jinja_globals.get("deployability_index") or DeployabilityIndex.all_deployable() 488 ) 489 snapshot = jinja_globals.pop("snapshot", None) 490 491 if snapshot and snapshot.is_incremental: 492 intervals = ( 493 snapshot.intervals 494 if deployability_index.is_deployable(snapshot) 495 else snapshot.dev_intervals 496 ) 497 is_incremental = bool(intervals) 498 499 snapshot_table_exists = jinja_globals.get("snapshot_table_exists") 500 if is_incremental and snapshot_table_exists is not None: 501 # If we know the information about table existence, we can use it to correctly 502 # set the flag 503 is_incremental &= snapshot_table_exists 504 else: 505 is_incremental = False 506 507 builtin_globals["is_incremental"] = lambda: is_incremental 508 509 builtin_globals["builtins"] = AttributeDict( 510 {k: builtin_globals.get(k) for k in ("ref", "source", "config", "var")} 511 ) 512 513 if (model := jinja_globals.pop("model", None)) is not None: 514 if isinstance(model_instance := jinja_globals.pop("model_instance", None), SqlModel): 515 builtin_globals["model"] = AttributeDict( 516 {**model, RAW_CODE_KEY: model_instance.query.name} 517 ) 518 else: 519 builtin_globals["model"] = AttributeDict(model.copy()) 520 521 builtin_globals["flags"] = ( 522 Flags(which="run") if engine_adapter is not None else Flags(which="parse") 523 ) 524 builtin_globals["invocation_args_dict"] = { 525 k.lower(): v for k, v in builtin_globals["flags"].__dict__.items() 526 } 527 528 if engine_adapter is not None: 529 adapter: BaseAdapter = RuntimeAdapter( 530 engine_adapter, 531 jinja_macros, 532 jinja_globals={ 533 **builtin_globals, 534 **jinja_globals, 535 "engine_adapter": engine_adapter, 536 }, 537 relation_type=api.Relation, 538 column_type=api.Column, 539 quote_policy=api.quote_policy, 540 snapshots=jinja_globals.get("snapshots", {}), 541 table_mapping=jinja_globals.get("table_mapping", {}), 542 deployability_index=deployability_index, 543 project_dialect=project_dialect, 544 ) 545 else: 546 adapter = ParsetimeAdapter( 547 jinja_macros, 548 jinja_globals={**builtin_globals, **jinja_globals}, 549 project_dialect=project_dialect, 550 quote_policy=api.quote_policy, 551 ) 552 553 sql_execution = SQLExecution(adapter) 554 builtin_globals.update( 555 { 556 "adapter": adapter, 557 "execute": True, 558 "load_relation": adapter.load_relation, 559 "store_result": sql_execution.store_result, 560 "load_result": sql_execution.load_result, 561 "run_query": sql_execution.run_query, 562 "statement": sql_execution.statement, 563 "graph": adapter.graph, 564 "selected_resources": list(jinja_globals.get("selected_models") or []), 565 "write": lambda input: None, # We don't support writing yet 566 } 567 ) 568 569 builtin_globals["run_started_at"] = jinja_globals.get("execution_dt") or now() 570 builtin_globals["dbt"] = AttributeDict(builtin_globals) 571 builtin_globals["context"] = builtin_globals["dbt"] 572 573 return {**builtin_globals, **jinja_globals}
def
create_builtin_filters() -> Dict[str, Callable]: