Edit on GitHub

sqlmesh.dbt.builtin

  1from __future__ import annotations
  2
  3import json
  4import logging
  5import os
  6import typing as t
  7from ast import literal_eval
  8from dataclasses import asdict
  9
 10import agate
 11import jinja2
 12from dbt import version
 13from dbt.adapters.base import BaseRelation, Column
 14from ruamel.yaml import YAMLError
 15from sqlglot import Dialect
 16
 17from sqlmesh.core.console import get_console
 18from sqlmesh.core.engine_adapter import EngineAdapter
 19from sqlmesh.core.model.definition import SqlModel
 20from sqlmesh.core.snapshot.definition import DeployabilityIndex
 21from sqlmesh.dbt.adapter import BaseAdapter, ParsetimeAdapter, RuntimeAdapter
 22from sqlmesh.dbt.common import RAW_CODE_KEY
 23from sqlmesh.dbt.relation import Policy
 24from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS
 25from sqlmesh.dbt.util import DBT_VERSION
 26from sqlmesh.utils import AttributeDict, debug_mode_enabled, yaml
 27from sqlmesh.utils.date import now
 28from sqlmesh.utils.errors import ConfigError
 29from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroReference, MacroReturnVal
 30
# Module-level logger used by the Jinja builtins below (e.g. ``log`` and ``exceptions.warn``).
logger = logging.getLogger(__name__)
 32
 33
 34class Exceptions:
 35    def raise_compiler_error(self, msg: str) -> None:
 36        if DBT_VERSION >= (1, 4, 0):
 37            from dbt.exceptions import CompilationError
 38
 39            raise CompilationError(msg)
 40        else:
 41            from dbt.exceptions import CompilationException  # type: ignore
 42
 43            raise CompilationException(msg)
 44
 45    def raise_not_implemented(self, msg: str) -> None:
 46        raise NotImplementedError(msg)
 47
 48    def warn(self, msg: str) -> str:
 49        logger.warning(msg)
 50        return ""
 51
 52
 53def try_or_compiler_error(
 54    message_if_exception: str, func: t.Callable, *args: t.Any, **kwargs: t.Any
 55) -> t.Any:
 56    try:
 57        return func(*args, **kwargs)
 58    except Exception:
 59        if DBT_VERSION >= (1, 4, 0):
 60            from dbt.exceptions import CompilationError
 61
 62            raise CompilationError(message_if_exception)
 63        else:
 64            from dbt.exceptions import CompilationException  # type: ignore
 65
 66            raise CompilationException(message_if_exception)
 67
 68
 69class Api:
 70    def __init__(self, dialect: t.Optional[str]) -> None:
 71        if dialect:
 72            config_class = TARGET_TYPE_TO_CONFIG_CLASS[
 73                Dialect.get_or_raise(dialect).__class__.__name__.lower()
 74            ]
 75            self.Relation = config_class.relation_class
 76            self.Column = config_class.column_class
 77            self.quote_policy = config_class.quote_policy
 78        else:
 79            self.Relation = BaseRelation
 80            self.Column = Column
 81            self.quote_policy = Policy()
 82
 83
 84class Flags:
 85    def __init__(
 86        self,
 87        full_refresh: t.Optional[str] = None,
 88        store_failures: t.Optional[str] = None,
 89        which: t.Optional[str] = None,
 90    ) -> None:
 91        # Temporary placeholder values for now (these are generally passed from the CLI)
 92        self.FULL_REFRESH = full_refresh
 93        self.STORE_FAILURES = store_failures
 94        self.WHICH = which
 95
 96
 97class Modules:
 98    def __init__(self) -> None:
 99        import datetime
100        import itertools
101        import re
102
103        try:
104            import pytz
105
106            self.pytz = pytz
107        except ImportError:
108            pass
109
110        self.datetime = datetime
111        self.re = re
112        self.itertools = itertools
113
114
class SQLExecution:
    """Implements dbt's statement-execution Jinja builtins on top of an adapter.

    Provides ``statement``, ``run_query``, ``store_result`` and ``load_result``.
    Results are kept per instance, keyed by statement name.
    """

    def __init__(self, adapter: BaseAdapter):
        self.adapter = adapter
        self._results: t.Dict[str, AttributeDict] = {}

    def store_result(self, name: str, response: t.Any, agate_table: t.Optional[agate.Table]) -> str:
        """Store a query result under ``name`` for later ``load_result`` calls.

        A missing table is replaced with an empty one, so ``data`` and ``table`` are
        always populated. Returns ``""`` so the call renders to nothing.
        """
        from sqlmesh.dbt.util import empty_table, as_matrix

        if agate_table is None:
            agate_table = empty_table()

        self._results[name] = AttributeDict(
            {
                "response": response,
                "data": as_matrix(agate_table),
                "table": agate_table,
            }
        )
        return ""

    def load_result(self, name: str) -> t.Optional[AttributeDict]:
        """Return the result stored under ``name``, or ``None`` if nothing was stored."""
        return self._results.get(name)

    def run_query(self, sql: str) -> agate.Table:
        """Execute ``sql`` with result fetching enabled and return the result table."""
        self.statement("run_query_statement", fetch_result=True, auto_begin=False, caller=sql)
        resp = self.load_result("run_query_statement")
        assert resp is not None
        return resp["table"]

    def statement(
        self,
        name: t.Optional[str] = None,
        fetch_result: bool = False,
        auto_begin: bool = True,
        language: str = "sql",
        caller: t.Optional[jinja2.runtime.Macro | str] = None,
    ) -> str:
        """Execute the SQL supplied by ``caller`` and optionally store its result.

        ``caller`` is effectively required. It is declared optional and placed last only
        so that the signature matches the way Jinja2 invokes a macro with a caller.
        ``name`` defaults to ``None`` (as in dbt's ``statement`` macro), which makes
        ``{% call statement() %}`` valid.

        Args:
            name: Key under which the result is stored for ``load_result``. When falsy,
                the SQL is executed but nothing is stored.
            fetch_result: Forwarded to the adapter as ``fetch``.
            auto_begin: Forwarded to the adapter as ``auto_begin``.
            language: Only ``"sql"`` is supported.
            caller: Either the SQL string itself or a callable that returns it.

        Returns:
            ``""``, so the call renders to nothing.

        Raises:
            RuntimeError: If no caller is provided.
            NotImplementedError: If ``language`` is not ``"sql"``.
        """
        if not caller:
            raise RuntimeError(
                "Statement relies on a caller to be set that is the target SQL to be run"
            )
        # Reject unsupported languages before rendering the caller body, so its side
        # effects (and any rendering errors) are not triggered for nothing.
        if language != "sql":
            raise NotImplementedError(
                "SQLMesh's dbt integration only supports SQL statements at this time."
            )
        sql = caller if isinstance(caller, str) else caller()
        res, table = self.adapter.execute(sql, fetch=fetch_result, auto_begin=auto_begin)
        if name:
            self.store_result(name, res, table)
        return ""
172
173
class Var:
    """Callable backing dbt's ``var()`` Jinja function.

    ``var(name, default)`` returns ``name``'s value from the mapping given at
    construction, or ``default`` (``None`` unless supplied) when the name is absent.
    ``has_var(name)`` reports whether the name is present.
    """

    def __init__(self, variables: t.Dict[str, t.Any]) -> None:
        self.variables = variables

    def __call__(self, name: str, default: t.Optional[t.Any] = None) -> t.Any:
        if name not in self.variables:
            return default
        return self.variables[name]

    def has_var(self, name: str) -> bool:
        present = name in self.variables
        return present
183
184
class Config:
    """Backs dbt's ``config`` Jinja object for one model.

    All reads and writes go straight to the dictionary given at construction, so
    inline ``{{ config(...) }}`` calls and ``config.set`` update that dict in place.
    """

    def __init__(self, config_dict: t.Dict[str, t.Any]) -> None:
        self._config = config_dict

    def __call__(self, *args: t.Any, **kwargs: t.Any) -> str:
        """Apply an inline config, given either as one dict or as keyword arguments.

        Returns ``""`` so the call renders to nothing.

        Raises:
            ConfigError: If positional and keyword arguments are mixed, or if the
                positional form is anything other than a single dict.
        """
        if args and kwargs:
            raise ConfigError(
                "Invalid inline model config: cannot mix positional and keyword arguments"
            )

        if args:
            if len(args) == 1 and isinstance(args[0], dict):
                # Single dict argument: config({"materialized": "table"})
                self._config.update(args[0])
            else:
                raise ConfigError(
                    f"Invalid inline model config: expected a single dictionary, got {len(args)} arguments"
                )
        elif kwargs:
            # Keyword arguments: config(materialized="table")
            self._config.update(kwargs)

        return ""

    def set(self, name: str, value: t.Any) -> str:
        """Set ``name`` to ``value`` and return ``""``."""
        self._config.update({name: value})
        return ""

    def _validate(self, name: str, validator: t.Callable, value: t.Optional[t.Any] = None) -> None:
        """Run ``validator(value)``, turning any failure into a ``ConfigError``."""
        try:
            validator(value)
        except Exception as e:
            # Chain the original error so its traceback is preserved for debugging.
            raise ConfigError(f"Config validation failed for '{name}': {e}") from e

    def require(self, name: str, validator: t.Optional[t.Callable] = None) -> t.Any:
        """Return the value of ``name``, optionally validated.

        Raises:
            ConfigError: If ``name`` is missing or the validator fails.
        """
        if name not in self._config:
            raise ConfigError(f"Missing required config: {name}")

        value = self._config[name]

        if validator is not None:
            self._validate(name, validator, value)

        return value

    def get(
        self, name: str, default: t.Any = None, validator: t.Optional[t.Callable] = None
    ) -> t.Any:
        """Return ``name``'s value, or ``default``.

        The validator runs only when the resulting value is not ``None``.
        """
        value = self._config.get(name, default)

        if validator is not None and value is not None:
            self._validate(name, validator, value)

        return value

    def _persist_docs_flag(self, key: str) -> bool:
        """Return ``persist_docs[key]`` (default ``False``), or ``False`` if ``persist_docs`` is not a dict."""
        persist_docs = self.get("persist_docs", default={})
        if not isinstance(persist_docs, dict):
            return False
        return persist_docs.get(key, False)

    def persist_relation_docs(self) -> bool:
        """Whether relation-level docs should be persisted (``persist_docs.relation``)."""
        return self._persist_docs_flag("relation")

    def persist_column_docs(self) -> bool:
        """Whether column-level docs should be persisted (``persist_docs.columns``)."""
        return self._persist_docs_flag("columns")
251
252
def env_var(name: str, default: t.Optional[str] = None) -> t.Optional[str]:
    """Jinja ``env_var()``: read ``name`` from the process environment.

    Returns the variable's value when it is set, otherwise ``default``.

    Raises:
        ConfigError: If the variable is unset and no default was supplied.
    """
    value = os.environ.get(name)
    if value is not None:
        return value
    if default is None:
        raise ConfigError(f"Missing environment variable '{name}'")
    return default
257
258
def log(msg: str, info: bool = False) -> str:
    """Jinja ``log()``/``print()``: record ``msg`` and render nothing.

    With ``info=True`` the message is logged at INFO level and also sent to the
    console as a status update. Otherwise it is only logged at DEBUG level.
    """
    if not info:
        logger.debug(msg)
        return ""

    # Surface the message both in the log and on the console.
    logger.info(msg)
    get_console().log_status_update(msg)
    return ""
268
269
270def generate_ref(refs: t.Dict[str, t.Any], api: Api) -> t.Callable:
271    def ref(
272        package: str, name: t.Optional[str] = None, **kwargs: t.Any
273    ) -> t.Optional[BaseRelation]:
274        version = kwargs.get("version", kwargs.get("v"))
275        ref_name = f"{package}.{name}" if name else package
276
277        if version is not None:
278            relation_info = refs.get(f"{ref_name}_v{version}")
279            if relation_info is None:
280                logger.warning(
281                    "Could not resolve ref '%s' with version '%s'. Falling back to unversioned reference",
282                    ref_name,
283                    version,
284                )
285                relation_info = refs.get(ref_name)
286        else:
287            relation_info = refs.get(ref_name)
288            if not relation_info:
289                versioned_infos = sorted(
290                    [(r, info) for r, info in refs.items() if r.startswith(f"{ref_name}_v")],
291                    key=lambda i: i[0],
292                )
293                if versioned_infos:
294                    relation_info = versioned_infos[-1][1]
295
296        if relation_info is None:
297            logger.debug("Could not resolve ref '%s', version '%s'", ref_name, version)
298            return None
299
300        return _relation_info_to_relation(relation_info, api.Relation, api.quote_policy)
301
302    return ref
303
304
def generate_source(sources: t.Dict[str, t.Any], api: Api) -> t.Callable:
    """Build the ``source(package, name)`` Jinja function that resolves against ``sources``.

    Entries are looked up by ``"<package>.<name>"``. The entries in ``sources`` are
    never modified.
    """

    def source(package: str, name: str) -> t.Optional[BaseRelation]:
        relation_info = sources.get(f"{package}.{name}")
        if relation_info is None:
            logger.debug("Could not resolve source package='%s' name='%s'", package, name)
            return None

        # Clickhouse uses a 2-level schema.table naming scheme, where the second level is called
        # a "database" (instead of "schema" as one would reasonably assume). This can lead to confusion
        # because it is not clear how Clickhouse identifiers map onto dbt's "database" and "schema" fields.
        #
        # This confusion can occur in source resolution. If a source's `schema` is not explicitly specified,
        # the source name is used as the schema by default.
        #
        # If a source specified the `database` field and the schema has defaulted to the source name,
        # we follow dbt-clickhouse in assuming that the user intended for the `database` field to be the
        # second level identifier.
        # https://github.com/ClickHouse/dbt-clickhouse/blob/065f3a724fa09205446ecadac7a00d92b2d8c646/dbt/adapters/clickhouse/relation.py#L112
        #
        # NOTE: determining relation class based on name so we don't introduce a dependency on dbt-clickhouse
        if (
            api.Relation.__name__ == "ClickHouseRelation"
            and relation_info.schema == package
            and relation_info.database
        ):
            # Build a remapped copy instead of mutating the entry in ``sources``. That
            # mapping is shared by every call of this function and by whoever supplied it.
            relation_info = {
                **relation_info,
                "schema": relation_info["database"],
                "database": "",
            }

        return _relation_info_to_relation(relation_info, api.Relation, api.quote_policy)

    return source
336
337
def return_val(val: t.Any) -> None:
    """Jinja ``return()``: hand ``val`` back as the macro's result.

    Works by raising ``MacroReturnVal`` carrying ``val``. The macro-rendering code is
    presumably responsible for catching it (confirm against ``sqlmesh.utils.jinja``).
    """
    signal = MacroReturnVal(val)
    raise signal
340
341
def to_set(value: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
    """Jinja ``set()``: convert ``value`` to a set.

    Returns ``default`` when the conversion raises ``TypeError``, e.g. for a
    non-iterable value or unhashable items.
    """
    try:
        converted = set(value)
    except TypeError:
        converted = default
    return converted
347
348
def to_json(
    value: t.Any, default: t.Optional[t.Any] = None, sort_keys: bool = False
) -> t.Optional[t.Any]:
    """Jinja ``tojson()``: serialize ``value`` to a JSON string.

    Args:
        value: The object to serialize.
        default: Returned when ``value`` cannot be serialized.
        sort_keys: Forwarded to ``json.dumps`` (mirrors dbt's ``tojson`` parameter).

    Returns:
        The JSON string, or ``default`` on failure. ``json.dumps`` raises
        ``TypeError`` for unsupported types and ``ValueError`` for circular references.
    """
    try:
        return json.dumps(value, sort_keys=sort_keys)
    except (TypeError, ValueError):
        return default
354
355
def from_json(value: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
    """Jinja ``fromjson()``: parse ``value`` as JSON.

    Returns ``default`` if ``value`` is not valid JSON or is of a type ``json.loads``
    does not accept.
    """
    try:
        parsed = json.loads(value)
    except (TypeError, json.JSONDecodeError):
        return default
    return parsed
361
362
def to_yaml(value: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
    """Jinja ``toyaml()``: dump ``value`` to a YAML string.

    Returns ``default`` if dumping raises ``TypeError`` or ``YAMLError``.
    """
    try:
        dumped = yaml.dump(value)
    except (TypeError, YAMLError):
        return default
    return dumped
368
369
def from_yaml(value: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
    """Jinja ``fromyaml()``: parse ``value`` as a YAML document.

    Mappings are returned as plain ``dict``. Sequences and scalars are returned as
    parsed, as dbt's ``fromyaml`` does. Returns ``default`` when the loader returns
    ``None`` or raises ``TypeError``/``YAMLError``.
    """
    from collections.abc import Mapping

    try:
        loaded = yaml.load(value, raise_if_empty=False, render_jinja=False)
    except (TypeError, YAMLError):
        return default

    if loaded is None:
        return default
    # Only mappings are converted. ``dict()`` on a list of two-character strings would
    # silently pair them up (``["ab", "cd"]`` -> ``{"a": "b", "c": "d"}``), and on a
    # plain string it raises an uncaught ValueError.
    if isinstance(loaded, Mapping):
        return dict(loaded)
    return loaded
375
376
def do_zip(*args: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
    """Jinja ``zip()``: zip the given iterables into a list of tuples.

    Returns ``default`` if zipping or iterating raises ``TypeError`` (e.g. a
    non-iterable argument).
    """
    try:
        pairs = [group for group in zip(*args)]
    except TypeError:
        return default
    return pairs
382
383
def as_bool(value: t.Any) -> t.Any:
    """Jinja ``as_bool`` filter: returns ``value`` unchanged.

    This mirrors dbt's TEXT_FILTERS, which pass the input through as is:
    https://github.com/dbt-labs/dbt-common/blob/main/dbt_common/clients/jinja.py#L559
    """
    passthrough = value
    return passthrough
388
389
def as_number(value: str) -> t.Any:
    """Jinja ``as_number`` filter: returns ``value`` unchanged.

    This mirrors dbt's TEXT_FILTERS, which pass the input through as is:
    https://github.com/dbt-labs/dbt-common/blob/main/dbt_common/clients/jinja.py#L559
    """
    passthrough = value
    return passthrough
394
395
396def _try_literal_eval(value: str) -> t.Any:
397    try:
398        return literal_eval(value)
399    except (ValueError, SyntaxError, MemoryError):
400        return value
401
402
def debug() -> str:
    """Jinja ``debug()``: open an ``ipdb`` session and render nothing.

    The debugger is attached to the frame three levels above this call
    (``sys._getframe(3)``). That is presumably the template code that invoked
    ``debug()`` through Jinja's call machinery; TODO confirm the depth against the
    Jinja version in use.
    """
    import sys

    import ipdb  # type: ignore

    target_frame = sys._getframe(3)
    ipdb.set_trace(target_frame)
    return ""
410
411
# Jinja globals that do not depend on the model being rendered. ``create_builtin_globals``
# copies this mapping and layers the per-render entries on top.
BUILTIN_GLOBALS = {
    "dbt_version": version.__version__,
    "env_var": env_var,
    "exceptions": Exceptions(),
    "fromjson": from_json,
    "fromyaml": from_yaml,
    "log": log,
    "modules": Modules(),
    # ``print`` is an alias of ``log``.
    "print": log,
    "return": return_val,
    # ``set`` returns a default for invalid input; ``set_strict`` is the plain builtin
    # and raises instead.
    "set": to_set,
    "set_strict": set,
    # Constant flags; presumably they let templates detect rendering under SQLMesh (confirm with macro authors).
    "sqlmesh": True,
    "sqlmesh_incremental": True,
    "tojson": to_json,
    "toyaml": to_yaml,
    "try_or_compiler_error": try_or_compiler_error,
    # ``zip`` returns a default for invalid input; ``zip_strict`` lets the TypeError propagate.
    "zip": do_zip,
    "zip_strict": lambda *args: list(zip(*args)),
}

# Expose ``debug()`` only when macro debugging is enabled, either through dbt's
# ``DBT_MACRO_DEBUGGING`` environment variable or through SQLMesh's debug mode.
if os.environ.get("DBT_MACRO_DEBUGGING") or debug_mode_enabled():
    BUILTIN_GLOBALS["debug"] = debug

# Jinja filters returned by ``create_builtin_filters``.
BUILTIN_FILTERS = {
    "as_bool": as_bool,
    "as_native": _try_literal_eval,
    "as_number": as_number,
    "as_text": lambda v: "" if v is None else str(v),
}

# References to dbt's ``is_incremental`` macro (package-qualified and bare).
# NOTE(review): presumably these are overridden because ``create_builtin_globals``
# supplies its own ``is_incremental``; confirm against where this set is consumed.
OVERRIDDEN_MACROS = {
    MacroReference(package="dbt", name="is_incremental"),
    MacroReference(name="is_incremental"),
}
448
449
def create_builtin_globals(
    jinja_macros: JinjaMacroRegistry,
    jinja_globals: t.Dict[str, t.Any],
    engine_adapter: t.Optional[EngineAdapter],
) -> t.Dict[str, t.Any]:
    """Build the full Jinja global namespace used to render a dbt template.

    Starts from a copy of ``BUILTIN_GLOBALS`` and derives per-render objects from
    ``jinja_globals``: ``api``, ``this``, ``ref``, ``source``, ``var``, ``config``,
    ``is_incremental``, ``model``, ``flags``, ``adapter`` and the statement helpers.

    Args:
        jinja_macros: Macro registry handed to the adapter.
        jinja_globals: Caller-supplied globals. The dict is shallow-copied, so the
            caller's dict itself is not modified. The keys ``dialect``, ``this``,
            ``sources``, ``refs``, ``vars``, ``config``, ``snapshot``, ``model`` and
            (when ``model`` is present) ``model_instance`` are removed from the copy.
            Everything left over is merged into the result. Because the copy is
            shallow, the ``config`` dict is wrapped by reference, and inline
            ``config(...)`` calls update the caller's dict.
        engine_adapter: When given, a ``RuntimeAdapter`` is built and ``flags.WHICH``
            is ``"run"``. When ``None``, a ``ParsetimeAdapter`` is used and
            ``flags.WHICH`` is ``"parse"``.

    Returns:
        The combined globals. Remaining ``jinja_globals`` entries override builtins
        that have the same name.
    """
    builtin_globals = BUILTIN_GLOBALS.copy()
    jinja_globals = jinja_globals.copy()

    # An explicit ``dialect`` wins; otherwise fall back to the target's ``type``.
    target: t.Optional[AttributeDict] = jinja_globals.get("target", None)
    project_dialect = jinja_globals.pop("dialect", None) or (target.get("type") if target else None)
    api = Api(project_dialect)

    builtin_globals["api"] = api

    # ``this`` may be passed as a ready relation or as a dict of relation fields.
    this = jinja_globals.pop("this", None)
    if this is not None:
        if not isinstance(this, api.Relation):
            builtin_globals["this"] = api.Relation.create(**this)
        else:
            builtin_globals["this"] = this

    sources = jinja_globals.pop("sources", None)
    if sources is not None:
        builtin_globals["source"] = generate_source(sources, api)

    refs = jinja_globals.pop("refs", None)
    if refs is not None:
        builtin_globals["ref"] = generate_ref(refs, api)

    variables = jinja_globals.pop("vars", None)
    if variables is not None:
        builtin_globals["var"] = Var(variables)

    builtin_globals["config"] = Config(jinja_globals.pop("config", {"tags": []}))

    # Read with ``get`` (not ``pop``), so ``deployability_index`` also remains in the
    # adapter globals and in the returned mapping.
    deployability_index = (
        jinja_globals.get("deployability_index") or DeployabilityIndex.all_deployable()
    )
    snapshot = jinja_globals.pop("snapshot", None)

    # Incremental only when the snapshot is incremental and already has intervals.
    # Production intervals are used when the snapshot is deployable; dev intervals
    # are used otherwise.
    if snapshot and snapshot.is_incremental:
        intervals = (
            snapshot.intervals
            if deployability_index.is_deployable(snapshot)
            else snapshot.dev_intervals
        )
        is_incremental = bool(intervals)

        snapshot_table_exists = jinja_globals.get("snapshot_table_exists")
        if is_incremental and snapshot_table_exists is not None:
            # If we know the information about table existence, we can use it to correctly
            # set the flag
            is_incremental &= snapshot_table_exists
    else:
        is_incremental = False

    builtin_globals["is_incremental"] = lambda: is_incremental

    # ``builtins`` captures ``ref``/``source``/``config``/``var`` as set above
    # (``None`` for any that were not set).
    builtin_globals["builtins"] = AttributeDict(
        {k: builtin_globals.get(k) for k in ("ref", "source", "config", "var")}
    )

    # For SQL models, the raw-code entry is replaced with ``model_instance.query.name``.
    # NOTE(review): ``model_instance`` is only popped when ``model`` is present.
    # Otherwise it stays in ``jinja_globals`` and is returned as a global.
    if (model := jinja_globals.pop("model", None)) is not None:
        if isinstance(model_instance := jinja_globals.pop("model_instance", None), SqlModel):
            builtin_globals["model"] = AttributeDict(
                {**model, RAW_CODE_KEY: model_instance.query.name}
            )
        else:
            builtin_globals["model"] = AttributeDict(model.copy())

    builtin_globals["flags"] = (
        Flags(which="run") if engine_adapter is not None else Flags(which="parse")
    )
    # Lower-cased view of the flag attributes (e.g. ``which``).
    builtin_globals["invocation_args_dict"] = {
        k.lower(): v for k, v in builtin_globals["flags"].__dict__.items()
    }

    # The adapter receives the builtins collected so far plus the remaining caller globals.
    if engine_adapter is not None:
        adapter: BaseAdapter = RuntimeAdapter(
            engine_adapter,
            jinja_macros,
            jinja_globals={
                **builtin_globals,
                **jinja_globals,
                "engine_adapter": engine_adapter,
            },
            relation_type=api.Relation,
            column_type=api.Column,
            quote_policy=api.quote_policy,
            snapshots=jinja_globals.get("snapshots", {}),
            table_mapping=jinja_globals.get("table_mapping", {}),
            deployability_index=deployability_index,
            project_dialect=project_dialect,
        )
    else:
        adapter = ParsetimeAdapter(
            jinja_macros,
            jinja_globals={**builtin_globals, **jinja_globals},
            project_dialect=project_dialect,
            quote_policy=api.quote_policy,
        )

    sql_execution = SQLExecution(adapter)
    builtin_globals.update(
        {
            "adapter": adapter,
            "execute": True,
            "load_relation": adapter.load_relation,
            "store_result": sql_execution.store_result,
            "load_result": sql_execution.load_result,
            "run_query": sql_execution.run_query,
            "statement": sql_execution.statement,
            "graph": adapter.graph,
            "selected_resources": list(jinja_globals.get("selected_models") or []),
            "write": lambda input: None,  # We don't support writing yet
        }
    )

    builtin_globals["run_started_at"] = jinja_globals.get("execution_dt") or now()
    # ``dbt`` and ``context`` are the same snapshot of the builtins taken here. They do
    # not include the leftover ``jinja_globals`` merged into the return value below.
    builtin_globals["dbt"] = AttributeDict(builtin_globals)
    builtin_globals["context"] = builtin_globals["dbt"]

    return {**builtin_globals, **jinja_globals}
573
574
def create_builtin_filters() -> t.Dict[str, t.Callable]:
    """Return the Jinja filters for dbt templates.

    A fresh shallow copy of ``BUILTIN_FILTERS`` is returned, so callers can add or
    override filters without mutating the module-level table that every render shares.
    This matches how ``create_builtin_globals`` copies ``BUILTIN_GLOBALS``.
    """
    return dict(BUILTIN_FILTERS)
577
578
def _relation_info_to_relation(
    relation_info: t.Dict[str, t.Any],
    relation_type: t.Type[BaseRelation],
    target_quote_policy: Policy,
) -> BaseRelation:
    """Create a ``relation_type`` instance from a relation-info mapping.

    The quote policy starts from ``target_quote_policy``. Non-``None`` entries of
    ``relation_info["quote_policy"]`` override it. All other keys are passed to
    ``relation_type.create``. ``relation_info`` itself is not modified.
    """
    relation_info = relation_info.copy()
    # ``or {}`` also covers an explicit ``quote_policy: None`` entry, which would
    # otherwise fail on ``.items()``.
    overrides = relation_info.pop("quote_policy", None) or {}
    quote_policy = Policy(
        **{
            **asdict(target_quote_policy),
            **{k: v for k, v in overrides.items() if v is not None},
        }
    )
    return relation_type.create(**relation_info, quote_policy=quote_policy)
logger = <Logger sqlmesh.dbt.builtin (WARNING)>
class Exceptions:
35class Exceptions:
36    def raise_compiler_error(self, msg: str) -> None:
37        if DBT_VERSION >= (1, 4, 0):
38            from dbt.exceptions import CompilationError
39
40            raise CompilationError(msg)
41        else:
42            from dbt.exceptions import CompilationException  # type: ignore
43
44            raise CompilationException(msg)
45
46    def raise_not_implemented(self, msg: str) -> None:
47        raise NotImplementedError(msg)
48
49    def warn(self, msg: str) -> str:
50        logger.warning(msg)
51        return ""
def raise_compiler_error(self, msg: str) -> None:
36    def raise_compiler_error(self, msg: str) -> None:
37        if DBT_VERSION >= (1, 4, 0):
38            from dbt.exceptions import CompilationError
39
40            raise CompilationError(msg)
41        else:
42            from dbt.exceptions import CompilationException  # type: ignore
43
44            raise CompilationException(msg)
def raise_not_implemented(self, msg: str) -> None:
46    def raise_not_implemented(self, msg: str) -> None:
47        raise NotImplementedError(msg)
def warn(self, msg: str) -> str:
49    def warn(self, msg: str) -> str:
50        logger.warning(msg)
51        return ""
def try_or_compiler_error( message_if_exception: str, func: Callable, *args: Any, **kwargs: Any) -> Any:
54def try_or_compiler_error(
55    message_if_exception: str, func: t.Callable, *args: t.Any, **kwargs: t.Any
56) -> t.Any:
57    try:
58        return func(*args, **kwargs)
59    except Exception:
60        if DBT_VERSION >= (1, 4, 0):
61            from dbt.exceptions import CompilationError
62
63            raise CompilationError(message_if_exception)
64        else:
65            from dbt.exceptions import CompilationException  # type: ignore
66
67            raise CompilationException(message_if_exception)
class Api:
70class Api:
71    def __init__(self, dialect: t.Optional[str]) -> None:
72        if dialect:
73            config_class = TARGET_TYPE_TO_CONFIG_CLASS[
74                Dialect.get_or_raise(dialect).__class__.__name__.lower()
75            ]
76            self.Relation = config_class.relation_class
77            self.Column = config_class.column_class
78            self.quote_policy = config_class.quote_policy
79        else:
80            self.Relation = BaseRelation
81            self.Column = Column
82            self.quote_policy = Policy()
Api(dialect: Optional[str])
71    def __init__(self, dialect: t.Optional[str]) -> None:
72        if dialect:
73            config_class = TARGET_TYPE_TO_CONFIG_CLASS[
74                Dialect.get_or_raise(dialect).__class__.__name__.lower()
75            ]
76            self.Relation = config_class.relation_class
77            self.Column = config_class.column_class
78            self.quote_policy = config_class.quote_policy
79        else:
80            self.Relation = BaseRelation
81            self.Column = Column
82            self.quote_policy = Policy()
class Flags:
85class Flags:
86    def __init__(
87        self,
88        full_refresh: t.Optional[str] = None,
89        store_failures: t.Optional[str] = None,
90        which: t.Optional[str] = None,
91    ) -> None:
92        # Temporary placeholder values for now (these are generally passed from the CLI)
93        self.FULL_REFRESH = full_refresh
94        self.STORE_FAILURES = store_failures
95        self.WHICH = which
Flags( full_refresh: Optional[str] = None, store_failures: Optional[str] = None, which: Optional[str] = None)
86    def __init__(
87        self,
88        full_refresh: t.Optional[str] = None,
89        store_failures: t.Optional[str] = None,
90        which: t.Optional[str] = None,
91    ) -> None:
92        # Temporary placeholder values for now (these are generally passed from the CLI)
93        self.FULL_REFRESH = full_refresh
94        self.STORE_FAILURES = store_failures
95        self.WHICH = which
FULL_REFRESH
STORE_FAILURES
WHICH
class Modules:
 98class Modules:
 99    def __init__(self) -> None:
100        import datetime
101        import itertools
102        import re
103
104        try:
105            import pytz
106
107            self.pytz = pytz
108        except ImportError:
109            pass
110
111        self.datetime = datetime
112        self.re = re
113        self.itertools = itertools
datetime
re
itertools
class SQLExecution:
116class SQLExecution:
117    def __init__(self, adapter: BaseAdapter):
118        self.adapter = adapter
119        self._results: t.Dict[str, AttributeDict] = {}
120
121    def store_result(self, name: str, response: t.Any, agate_table: t.Optional[agate.Table]) -> str:
122        from sqlmesh.dbt.util import empty_table, as_matrix
123
124        if agate_table is None:
125            agate_table = empty_table()
126
127        self._results[name] = AttributeDict(
128            {
129                "response": response,
130                "data": as_matrix(agate_table),
131                "table": agate_table,
132            }
133        )
134        return ""
135
136    def load_result(self, name: str) -> t.Optional[AttributeDict]:
137        return self._results.get(name)
138
139    def run_query(self, sql: str) -> agate.Table:
140        self.statement("run_query_statement", fetch_result=True, auto_begin=False, caller=sql)
141        resp = self.load_result("run_query_statement")
142        assert resp is not None
143        return resp["table"]
144
145    def statement(
146        self,
147        name: t.Optional[str],
148        fetch_result: bool = False,
149        auto_begin: bool = True,
150        language: str = "sql",
151        caller: t.Optional[jinja2.runtime.Macro | str] = None,
152    ) -> str:
153        """
154        Executes the SQL that is defined within the context of the caller. Therefore caller really isn't optional
155        but we make it optional and at the end because we need to match the signature of the jinja2 macro.
156
157        Name is the name that we store the results to which can be retrieved with `load_result`. If name is not
158        provided then the SQL is executed but the results are not stored.
159        """
160        if not caller:
161            raise RuntimeError(
162                "Statement relies on a caller to be set that is the target SQL to be run"
163            )
164        sql = caller if isinstance(caller, str) else caller()
165        if language != "sql":
166            raise NotImplementedError(
167                "SQLMesh's dbt integration only supports SQL statements at this time."
168            )
169        res, table = self.adapter.execute(sql, fetch=fetch_result, auto_begin=auto_begin)
170        if name:
171            self.store_result(name, res, table)
172        return ""
SQLExecution(adapter: sqlmesh.dbt.adapter.BaseAdapter)
117    def __init__(self, adapter: BaseAdapter):
118        self.adapter = adapter
119        self._results: t.Dict[str, AttributeDict] = {}
adapter
def store_result( self, name: str, response: Any, agate_table: Optional[agate.table.Table]) -> str:
121    def store_result(self, name: str, response: t.Any, agate_table: t.Optional[agate.Table]) -> str:
122        from sqlmesh.dbt.util import empty_table, as_matrix
123
124        if agate_table is None:
125            agate_table = empty_table()
126
127        self._results[name] = AttributeDict(
128            {
129                "response": response,
130                "data": as_matrix(agate_table),
131                "table": agate_table,
132            }
133        )
134        return ""
def load_result(self, name: str) -> Optional[sqlmesh.utils.AttributeDict]:
136    def load_result(self, name: str) -> t.Optional[AttributeDict]:
137        return self._results.get(name)
def run_query(self, sql: str) -> agate.table.Table:
139    def run_query(self, sql: str) -> agate.Table:
140        self.statement("run_query_statement", fetch_result=True, auto_begin=False, caller=sql)
141        resp = self.load_result("run_query_statement")
142        assert resp is not None
143        return resp["table"]
def statement( self, name: Optional[str], fetch_result: bool = False, auto_begin: bool = True, language: str = 'sql', caller: Union[jinja2.runtime.Macro, str, NoneType] = None) -> str:
145    def statement(
146        self,
147        name: t.Optional[str],
148        fetch_result: bool = False,
149        auto_begin: bool = True,
150        language: str = "sql",
151        caller: t.Optional[jinja2.runtime.Macro | str] = None,
152    ) -> str:
153        """
154        Executes the SQL that is defined within the context of the caller. Therefore caller really isn't optional
155        but we make it optional and at the end because we need to match the signature of the jinja2 macro.
156
157        Name is the name that we store the results to which can be retrieved with `load_result`. If name is not
158        provided then the SQL is executed but the results are not stored.
159        """
160        if not caller:
161            raise RuntimeError(
162                "Statement relies on a caller to be set that is the target SQL to be run"
163            )
164        sql = caller if isinstance(caller, str) else caller()
165        if language != "sql":
166            raise NotImplementedError(
167                "SQLMesh's dbt integration only supports SQL statements at this time."
168            )
169        res, table = self.adapter.execute(sql, fetch=fetch_result, auto_begin=auto_begin)
170        if name:
171            self.store_result(name, res, table)
172        return ""

Executes the SQL provided by the caller. The caller is therefore effectively required; it is declared optional and placed last only so that the signature matches the way Jinja2 invokes a macro with a caller.

`name` is the key under which the results are stored; they can be retrieved with `load_result`. If `name` is not provided, the SQL is executed but its results are not stored.

class Var:
175class Var:
176    def __init__(self, variables: t.Dict[str, t.Any]) -> None:
177        self.variables = variables
178
179    def __call__(self, name: str, default: t.Optional[t.Any] = None) -> t.Any:
180        return self.variables.get(name, default)
181
182    def has_var(self, name: str) -> bool:
183        return name in self.variables
Var(variables: Dict[str, Any])
176    def __init__(self, variables: t.Dict[str, t.Any]) -> None:
177        self.variables = variables
variables
def has_var(self, name: str) -> bool:
def has_var(self, name: str) -> bool:
    """Return True when ``name`` is a key of ``self.variables`` (even if its value is None)."""
    variables = self.variables
    return name in variables
class Config:
class Config:
    """dbt ``config`` object: collects inline model config and exposes lookup helpers.

    The wrapped dict is updated in place, so changes made through ``config(...)`` or
    ``config.set(...)`` are visible to whoever supplied ``config_dict``.
    """

    def __init__(self, config_dict: t.Dict[str, t.Any]) -> None:
        self._config = config_dict

    def __call__(self, *args: t.Any, **kwargs: t.Any) -> str:
        """Merge inline config given either as a single dict argument or as keyword arguments.

        Returns:
            An empty string.

        Raises:
            ConfigError: If positional and keyword arguments are mixed, or if the positional
                arguments are anything other than a single dict.
        """
        if args and kwargs:
            raise ConfigError(
                "Invalid inline model config: cannot mix positional and keyword arguments"
            )

        if args:
            if len(args) == 1 and isinstance(args[0], dict):
                # Single dict argument: config({"materialized": "table"})
                self._config.update(args[0])
            else:
                raise ConfigError(
                    f"Invalid inline model config: expected a single dictionary, got {len(args)} arguments"
                )
        elif kwargs:
            # Keyword arguments: config(materialized="table")
            self._config.update(kwargs)

        return ""

    def set(self, name: str, value: t.Any) -> str:
        """Store ``value`` under ``name``; returns an empty string."""
        self._config.update({name: value})
        return ""

    def _validate(self, name: str, validator: t.Callable, value: t.Optional[t.Any] = None) -> None:
        """Run ``validator(value)`` and convert any failure into a ``ConfigError``.

        The original exception is chained as the explicit cause so its traceback is preserved.
        """
        try:
            validator(value)
        except Exception as e:
            raise ConfigError(f"Config validation failed for '{name}': {e}") from e

    def require(self, name: str, validator: t.Optional[t.Callable] = None) -> t.Any:
        """Return config value ``name``, validating it when ``validator`` is given.

        Raises:
            ConfigError: If ``name`` is missing or the validator fails.
        """
        if name not in self._config:
            raise ConfigError(f"Missing required config: {name}")

        value = self._config[name]

        if validator is not None:
            self._validate(name, validator, value)

        return value

    def get(
        self, name: str, default: t.Any = None, validator: t.Optional[t.Callable] = None
    ) -> t.Any:
        """Return config value ``name`` or ``default``.

        The validator only runs when both it and the resolved value are not None.
        """
        value = self._config.get(name, default)

        if validator is not None and value is not None:
            self._validate(name, validator, value)

        return value

    def _persist_docs_enabled(self, key: str) -> bool:
        # ``persist_docs`` is read as a dict keyed by "relation" / "columns"; any other shape
        # (including a missing entry) means the option is not enabled.
        persist_docs = self.get("persist_docs", default={})
        if not isinstance(persist_docs, dict):
            return False
        return persist_docs.get(key, False)

    def persist_relation_docs(self) -> bool:
        """Whether ``persist_docs.relation`` is enabled."""
        return self._persist_docs_enabled("relation")

    def persist_column_docs(self) -> bool:
        """Whether ``persist_docs.columns`` is enabled."""
        return self._persist_docs_enabled("columns")
Config(config_dict: Dict[str, Any])
def __init__(self, config_dict: t.Dict[str, t.Any]) -> None:
    """Wrap ``config_dict`` by reference; later updates through this object mutate it."""
    self._config = config_dict
def set(self, name: str, value: Any) -> str:
def set(self, name: str, value: t.Any) -> str:
    """Record ``value`` under ``name`` in the wrapped config dict and return an empty string."""
    self._config.update([(name, value)])
    return ""
def require(self, name: str, validator: Optional[Callable] = None) -> Any:
def require(self, name: str, validator: t.Optional[t.Callable] = None) -> t.Any:
    """Fetch the mandatory config entry ``name``.

    When ``validator`` is given it is applied via ``self._validate`` even if the value is None.

    Raises:
        ConfigError: If ``name`` is absent from the config.
    """
    config = self._config
    if name not in config:
        raise ConfigError(f"Missing required config: {name}")

    result = config[name]
    if validator is None:
        return result

    self._validate(name, validator, result)
    return result
def get( self, name: str, default: Any = None, validator: Optional[Callable] = None) -> Any:
def get(
    self, name: str, default: t.Any = None, validator: t.Optional[t.Callable] = None
) -> t.Any:
    """Look up ``name`` in the config, falling back to ``default``.

    Validation is skipped when no validator is given or when the resolved value is None.
    """
    result = self._config.get(name, default)
    if validator is None or result is None:
        return result

    self._validate(name, validator, result)
    return result
def persist_relation_docs(self) -> bool:
def persist_relation_docs(self) -> bool:
    """Return the ``relation`` flag from the ``persist_docs`` config (False if not a dict)."""
    settings = self.get("persist_docs", default={})
    return settings.get("relation", False) if isinstance(settings, dict) else False
def persist_column_docs(self) -> bool:
def persist_column_docs(self) -> bool:
    """Return the ``columns`` flag from the ``persist_docs`` config (False if not a dict)."""
    settings = self.get("persist_docs", default={})
    return settings.get("columns", False) if isinstance(settings, dict) else False
def env_var(name: str, default: Optional[str] = None) -> Optional[str]:
def env_var(name: str, default: t.Optional[str] = None) -> t.Optional[str]:
    """Return environment variable ``name``, or ``default`` when it is unset.

    Raises:
        ConfigError: If the variable is unset and no default was supplied.
    """
    # os.environ values are always strings, so None here means "unset and no default".
    resolved = os.environ.get(name, default)
    if resolved is None:
        raise ConfigError(f"Missing environment variable '{name}'")
    return resolved
def log(msg: str, info: bool = False) -> str:
def log(msg: str, info: bool = False) -> str:
    """dbt ``log`` builtin.

    With ``info=True`` the message goes to the logger at INFO level and to the console as a
    status update; otherwise it is only logged at DEBUG level. Always returns an empty string.
    """
    if not info:
        logger.debug(msg)
        return ""

    # Write to both log file and stdout
    logger.info(msg)
    get_console().log_status_update(msg)
    return ""
def generate_ref(refs: Dict[str, Any], api: Api) -> Callable:
def generate_ref(refs: t.Dict[str, t.Any], api: Api) -> t.Callable:
    """Create the dbt ``ref`` builtin backed by ``refs``.

    Args:
        refs: Maps reference names (``"<name>"`` or ``"<package>.<name>"``, optionally suffixed
            with ``"_v<version>"`` for versioned models) to relation info.
        api: Supplies the ``Relation`` class and quote policy used to build the returned relation.

    Returns:
        A ``ref(package, name=None, version=None | v=None)`` callable that returns the resolved
        relation, or ``None`` when the reference cannot be resolved.
    """

    def _version_sort_key(versioned_name: str, prefix: str) -> t.Tuple[int, float, str]:
        # Compare version suffixes numerically when they are numbers, so "_v10" ranks above "_v9"
        # (a plain string sort would place "_v10" before "_v2"). Non-numeric suffixes keep string
        # ordering among themselves and rank below numeric ones.
        suffix = versioned_name[len(prefix) :]
        try:
            return (1, float(suffix), suffix)
        except ValueError:
            return (0, 0.0, suffix)

    def ref(
        package: str, name: t.Optional[str] = None, **kwargs: t.Any
    ) -> t.Optional[BaseRelation]:
        # Both ``version=`` and the shorthand ``v=`` are accepted.
        version = kwargs.get("version", kwargs.get("v"))
        ref_name = f"{package}.{name}" if name else package

        if version is not None:
            relation_info = refs.get(f"{ref_name}_v{version}")
            if relation_info is None:
                logger.warning(
                    "Could not resolve ref '%s' with version '%s'. Falling back to unversioned reference",
                    ref_name,
                    version,
                )
                relation_info = refs.get(ref_name)
        else:
            relation_info = refs.get(ref_name)
            if not relation_info:
                # No usable unversioned entry: fall back to the highest available version.
                prefix = f"{ref_name}_v"
                versioned_names = [r for r in refs if r.startswith(prefix)]
                if versioned_names:
                    latest = max(versioned_names, key=lambda r: _version_sort_key(r, prefix))
                    relation_info = refs[latest]

        if relation_info is None:
            logger.debug("Could not resolve ref '%s', version '%s'", ref_name, version)
            return None

        return _relation_info_to_relation(relation_info, api.Relation, api.quote_policy)

    return ref
def generate_source(sources: Dict[str, Any], api: Api) -> Callable:
def generate_source(sources: t.Dict[str, t.Any], api: Api) -> t.Callable:
    """Create the dbt ``source`` builtin backed by ``sources``.

    Args:
        sources: Maps ``"<source_name>.<table_name>"`` to relation info.
        api: Supplies the ``Relation`` class and quote policy used to build the returned relation.

    Returns:
        A ``source(package, name)`` callable returning the resolved relation, or ``None`` when
        the source is unknown. Resolving a source never modifies ``sources``.
    """

    def source(package: str, name: str) -> t.Optional[BaseRelation]:
        relation_info = sources.get(f"{package}.{name}")
        if relation_info is None:
            logger.debug("Could not resolve source package='%s' name='%s'", package, name)
            return None

        # Clickhouse uses a 2-level schema.table naming scheme, where the second level is called
        # a "database" (instead of "schema" as one would reasonably assume). This can lead to confusion
        # because it is not clear how Clickhouse identifiers map onto dbt's "database" and "schema" fields.
        #
        # This confusion can occur in source resolution. If a source's `schema` is not explicitly specified,
        # the source name is used as the schema by default.
        #
        # If a source specified the `database` field and the schema has defaulted to the source name,
        # we follow dbt-clickhouse in assuming that the user intended for the `database` field to be the
        # second level identifier.
        # https://github.com/ClickHouse/dbt-clickhouse/blob/065f3a724fa09205446ecadac7a00d92b2d8c646/dbt/adapters/clickhouse/relation.py#L112
        #
        # NOTE: determining relation class based on name so we don't introduce a dependency on dbt-clickhouse
        if (
            api.Relation.__name__ == "ClickHouseRelation"
            and relation_info.schema == package
            and relation_info.database
        ):
            # Build an adjusted copy rather than mutating the entry stored in ``sources``, which
            # belongs to the caller and may be shared with other consumers.
            relation_info = AttributeDict(
                {**relation_info, "schema": relation_info["database"], "database": ""}
            )

        return _relation_info_to_relation(relation_info, api.Relation, api.quote_policy)

    return source
def return_val(val: Any) -> None:
def return_val(val: t.Any) -> None:
    """Raise ``MacroReturnVal`` carrying ``val``; never returns normally.

    Registered as the ``return`` Jinja global. Presumably the macro machinery catches the exception
    to produce the macro's return value (NOTE(review): confirm in sqlmesh.utils.jinja).
    """
    raise MacroReturnVal(val)
def to_set(value: Any, default: Optional[Any] = None) -> Optional[Any]:
def to_set(value: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
    """dbt ``set`` builtin: convert ``value`` to a set, or return ``default`` if it is not iterable/hashable."""
    try:
        converted = set(value)
    except TypeError:
        converted = default
    return converted
def to_json(value: Any, default: Optional[Any] = None) -> Optional[Any]:
def to_json(
    value: t.Any, default: t.Optional[t.Any] = None, sort_keys: bool = False
) -> t.Optional[t.Any]:
    """Serialize ``value`` to a JSON string (dbt's ``tojson``).

    Args:
        value: The object to serialize.
        default: Returned instead of raising when ``value`` cannot be serialized.
        sort_keys: Emit object keys in sorted order, matching dbt's ``tojson(..., sort_keys=...)``.

    Returns:
        The JSON string, or ``default`` on failure.
    """
    try:
        return json.dumps(value, sort_keys=sort_keys)
    except (TypeError, ValueError):
        # TypeError: unsupported types (e.g. sets, arbitrary objects).
        # ValueError: e.g. a circular reference detected by the encoder.
        return default
def from_json(value: str, default: Optional[Any] = None) -> Optional[Any]:
def from_json(value: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
    """dbt ``fromjson`` builtin: parse a JSON string, returning ``default`` on bad input."""
    try:
        parsed = json.loads(value)
    except (TypeError, json.JSONDecodeError):
        parsed = default
    return parsed
def to_yaml(value: Any, default: Optional[Any] = None) -> Optional[Any]:
def to_yaml(value: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
    """dbt ``toyaml`` builtin: dump ``value`` as YAML, or return ``default`` if that fails."""
    try:
        dumped = yaml.dump(value)
    except (TypeError, YAMLError):
        dumped = default
    return dumped
def from_yaml(value: Any, default: Optional[Any] = None) -> Optional[Any]:
def from_yaml(value: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
    """dbt ``fromyaml`` builtin: parse YAML text into a dict.

    Jinja rendering of the text is disabled and an empty document does not raise inside the
    loader. Any failure to parse, or to convert the parsed result into a dict, returns ``default``.
    """
    try:
        return dict(yaml.load(value, raise_if_empty=False, render_jinja=False))
    except (TypeError, ValueError, YAMLError):
        # TypeError: the parsed result is not convertible to a dict (e.g. a scalar).
        # ValueError: e.g. a sequence whose items are not key/value pairs (dict(["abc"])).
        return default
def do_zip(*args: Any, default: Optional[Any] = None) -> Optional[Any]:
def do_zip(*args: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
    """dbt ``zip`` builtin: zip ``args`` into a list of tuples, or return ``default`` on a TypeError."""
    try:
        pairs = list(zip(*args))
    except TypeError:
        pairs = default
    return pairs
def as_bool(value: Any) -> Any:
def as_bool(value: t.Any) -> t.Any:
    """Jinja ``as_bool`` filter: deliberately a pass-through.

    This mirrors dbt, whose Jinja TEXT_FILTERS return their input unchanged:
    https://github.com/dbt-labs/dbt-common/blob/main/dbt_common/clients/jinja.py#L559
    """
    unchanged = value
    return unchanged
def as_number(value: str) -> Any:
def as_number(value: str) -> t.Any:
    """Jinja ``as_number`` filter: deliberately a pass-through (no numeric conversion).

    This mirrors dbt, whose Jinja TEXT_FILTERS return their input unchanged:
    https://github.com/dbt-labs/dbt-common/blob/main/dbt_common/clients/jinja.py#L559
    """
    unchanged = value
    return unchanged
def debug() -> str:
def debug() -> str:
    """Open an interactive ipdb session on the frame three levels above this function.

    ``ipdb`` must be importable. Returns an empty string, so calling it from a template renders nothing.
    """
    import ipdb  # type: ignore
    import sys

    ipdb.set_trace(sys._getframe(3))
    return ""
BUILTIN_GLOBALS = {'dbt_version': '1.11.7', 'env_var': <function env_var>, 'exceptions': <Exceptions object>, 'fromjson': <function from_json>, 'fromyaml': <function from_yaml>, 'log': <function log>, 'modules': <Modules object>, 'print': <function log>, 'return': <function return_val>, 'set': <function to_set>, 'set_strict': <class 'set'>, 'sqlmesh': True, 'sqlmesh_incremental': True, 'tojson': <function to_json>, 'toyaml': <function to_yaml>, 'try_or_compiler_error': <function try_or_compiler_error>, 'zip': <function do_zip>, 'zip_strict': <function <lambda>>}
BUILTIN_FILTERS = {'as_bool': <function as_bool>, 'as_native': <function _try_literal_eval>, 'as_number': <function as_number>, 'as_text': <function <lambda>>}
OVERRIDDEN_MACROS = {dbt.is_incremental, is_incremental}
def create_builtin_globals( jinja_macros: sqlmesh.utils.jinja.JinjaMacroRegistry, jinja_globals: Dict[str, Any], engine_adapter: Optional[sqlmesh.core.engine_adapter.base.EngineAdapter]) -> Dict[str, Any]:
def create_builtin_globals(
    jinja_macros: JinjaMacroRegistry,
    jinja_globals: t.Dict[str, t.Any],
    engine_adapter: t.Optional[EngineAdapter],
) -> t.Dict[str, t.Any]:
    """Build the Jinja global namespace used to render dbt code in SQLMesh.

    Args:
        jinja_macros: Macro registry handed to the dbt adapter object.
        jinja_globals: Caller-supplied globals. The dict is shallow-copied, not mutated.
            The keys ``dialect``, ``this``, ``sources``, ``refs``, ``vars``, ``config``,
            ``snapshot`` and ``model`` are popped and turned into dbt builtins.
            ``model_instance`` is popped only when ``model`` is present.
            Other keys (e.g. ``target``, ``deployability_index``, ``snapshot_table_exists``,
            ``snapshots``, ``table_mapping``, ``selected_models``, ``execution_dt``) are only
            read and remain in the result.
        engine_adapter: When given, a ``RuntimeAdapter`` is built and ``flags.which`` is
            ``"run"``. When ``None``, a ``ParsetimeAdapter`` is used and ``flags.which`` is
            ``"parse"``.

    Returns:
        The dbt builtins merged with the remaining ``jinja_globals``. On key collisions, the
        entries from ``jinja_globals`` win.
    """
    # Shallow copies so the pops/assignments below don't touch the module-level defaults or the
    # caller's dict.
    builtin_globals = BUILTIN_GLOBALS.copy()
    jinja_globals = jinja_globals.copy()

    target: t.Optional[AttributeDict] = jinja_globals.get("target", None)
    # An explicit "dialect" global takes precedence over the target's "type".
    project_dialect = jinja_globals.pop("dialect", None) or (target.get("type") if target else None)
    api = Api(project_dialect)

    builtin_globals["api"] = api

    this = jinja_globals.pop("this", None)
    if this is not None:
        # `this` may arrive as a mapping of Relation.create kwargs or as a ready-made relation.
        if not isinstance(this, api.Relation):
            builtin_globals["this"] = api.Relation.create(**this)
        else:
            builtin_globals["this"] = this

    sources = jinja_globals.pop("sources", None)
    if sources is not None:
        builtin_globals["source"] = generate_source(sources, api)

    refs = jinja_globals.pop("refs", None)
    if refs is not None:
        builtin_globals["ref"] = generate_ref(refs, api)

    variables = jinja_globals.pop("vars", None)
    if variables is not None:
        builtin_globals["var"] = Var(variables)

    # The default dict literal is created fresh on every call; Config updates the dict it wraps.
    builtin_globals["config"] = Config(jinja_globals.pop("config", {"tags": []}))

    deployability_index = (
        jinja_globals.get("deployability_index") or DeployabilityIndex.all_deployable()
    )
    snapshot = jinja_globals.pop("snapshot", None)

    # is_incremental() is True only for an incremental snapshot that already has intervals
    # (its regular intervals if deployable, its dev intervals otherwise) and, when known,
    # whose table exists.
    if snapshot and snapshot.is_incremental:
        intervals = (
            snapshot.intervals
            if deployability_index.is_deployable(snapshot)
            else snapshot.dev_intervals
        )
        is_incremental = bool(intervals)

        snapshot_table_exists = jinja_globals.get("snapshot_table_exists")
        if is_incremental and snapshot_table_exists is not None:
            # If we know the information about table existence, we can use it to correctly
            # set the flag
            is_incremental &= snapshot_table_exists
    else:
        is_incremental = False

    builtin_globals["is_incremental"] = lambda: is_incremental

    # Snapshot of the core builtins as they exist right now; any that were not set above are None.
    builtin_globals["builtins"] = AttributeDict(
        {k: builtin_globals.get(k) for k in ("ref", "source", "config", "var")}
    )

    if (model := jinja_globals.pop("model", None)) is not None:
        # With a SqlModel instance available, RAW_CODE_KEY is overridden with model_instance.query.name.
        if isinstance(model_instance := jinja_globals.pop("model_instance", None), SqlModel):
            builtin_globals["model"] = AttributeDict(
                {**model, RAW_CODE_KEY: model_instance.query.name}
            )
        else:
            builtin_globals["model"] = AttributeDict(model.copy())

    # Whether an engine adapter was supplied decides between "run" and "parse".
    builtin_globals["flags"] = (
        Flags(which="run") if engine_adapter is not None else Flags(which="parse")
    )
    builtin_globals["invocation_args_dict"] = {
        k.lower(): v for k, v in builtin_globals["flags"].__dict__.items()
    }

    # The adapters receive a new dict built from the globals collected so far. Entries added to
    # builtin_globals afterwards (adapter, execute, run_query, dbt, ...) are not part of it.
    if engine_adapter is not None:
        adapter: BaseAdapter = RuntimeAdapter(
            engine_adapter,
            jinja_macros,
            jinja_globals={
                **builtin_globals,
                **jinja_globals,
                "engine_adapter": engine_adapter,
            },
            relation_type=api.Relation,
            column_type=api.Column,
            quote_policy=api.quote_policy,
            snapshots=jinja_globals.get("snapshots", {}),
            table_mapping=jinja_globals.get("table_mapping", {}),
            deployability_index=deployability_index,
            project_dialect=project_dialect,
        )
    else:
        adapter = ParsetimeAdapter(
            jinja_macros,
            jinja_globals={**builtin_globals, **jinja_globals},
            project_dialect=project_dialect,
            quote_policy=api.quote_policy,
        )

    sql_execution = SQLExecution(adapter)
    builtin_globals.update(
        {
            "adapter": adapter,
            # Always True here, for both the parse-time and the run-time adapter.
            "execute": True,
            "load_relation": adapter.load_relation,
            "store_result": sql_execution.store_result,
            "load_result": sql_execution.load_result,
            "run_query": sql_execution.run_query,
            "statement": sql_execution.statement,
            "graph": adapter.graph,
            "selected_resources": list(jinja_globals.get("selected_models") or []),
            "write": lambda input: None,  # We don't support writing yet
        }
    )

    # execution_dt when provided, otherwise the current time.
    builtin_globals["run_started_at"] = jinja_globals.get("execution_dt") or now()
    # `dbt` and `context` expose a copy of the builtins (not the remaining caller globals).
    builtin_globals["dbt"] = AttributeDict(builtin_globals)
    builtin_globals["context"] = builtin_globals["dbt"]

    # Caller globals override builtins on name clashes.
    return {**builtin_globals, **jinja_globals}
def create_builtin_filters() -> Dict[str, Callable]:
def create_builtin_filters() -> t.Dict[str, t.Callable]:
    """Return the dbt-compatible Jinja filters as a fresh dict.

    A shallow copy is returned, the same way create_builtin_globals copies BUILTIN_GLOBALS.
    Callers that add or remove filters therefore cannot alter the shared module-level
    BUILTIN_FILTERS mapping.
    """
    return dict(BUILTIN_FILTERS)