Edit on GitHub

sqlmesh.utils.jinja

  1from __future__ import annotations
  2
  3import importlib
  4import json
  5import re
  6import typing as t
  7import zlib
  8from collections import defaultdict
  9from enum import Enum
 10from sys import exc_info
 11from traceback import walk_tb
 12
 13from jinja2 import Environment, Template, nodes, UndefinedError
 14from jinja2.runtime import Macro
 15from sqlglot import Dialect, Parser, TokenType
 16from sqlglot.expressions import Expression
 17
 18from sqlmesh.core import constants as c
 19from sqlmesh.core import dialect as d
 20from sqlmesh.utils import AttributeDict
 21from sqlmesh.utils.pydantic import PRIVATE_FIELDS, PydanticModel, field_serializer, field_validator
 22from sqlmesh.utils.metaprogramming import SqlValue
 23
 24
 25if t.TYPE_CHECKING:
 26    CallNames = t.Tuple[t.Tuple[str, ...], t.Union[nodes.Call, nodes.Getattr]]
 27
 28SQLMESH_JINJA_PACKAGE = "sqlmesh.utils.jinja"
 29
 30
 31def environment(**kwargs: t.Any) -> Environment:
 32    extensions = kwargs.pop("extensions", [])
 33    extensions.append("jinja2.ext.do")
 34    extensions.append("jinja2.ext.loopcontrols")
 35    return Environment(extensions=extensions, **kwargs)
 36
 37
 38ENVIRONMENT = environment()
 39
 40
 41class MacroReference(PydanticModel, frozen=True):
 42    package: t.Optional[str] = None
 43    name: str
 44
 45    @property
 46    def reference(self) -> str:
 47        if self.package is None:
 48            return self.name
 49        return ".".join((self.package, self.name))
 50
 51    def __str__(self) -> str:
 52        return self.reference
 53
 54
class MacroInfo(PydanticModel):
    """Holds a macro's source definition together with the macros it references."""

    # Raw Jinja text of the macro, e.g. the `{% macro ... %}...{% endmacro %}` span found by
    # MacroExtractor.extract.
    definition: str
    # Macros referenced from within `definition`. JinjaMacroRegistry._trim_macros follows
    # these to keep transitive dependencies.
    depends_on: t.List[MacroReference]
    # For a macro stored under a package: JinjaMacroRegistry.build_environment also exposes it
    # by its bare name, unless a root macro already uses that name.
    is_top_level: bool = False
 61
 62
 63class MacroReturnVal(Exception):
 64    def __init__(self, val: t.Any):
 65        self.value = val
 66
 67
class MacroExtractor(Parser):
    """Scans Jinja/SQL text with a sqlglot tokenizer to find `{% macro %}` definitions.

    It appears to subclass sqlglot's `Parser` only to reuse its token cursor (`_curr`,
    `_next`, `_prev`, `_advance`) and `_find_sql`. `extract` never parses SQL itself.
    """

    def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
        """Extract a dictionary of macro definitions from a jinja string.

        Args:
            jinja: The jinja string to extract from.
            dialect: The SQL dialect whose tokenizer is used (resolved via Dialect.get_or_raise).

        Returns:
            A dictionary mapping each macro name (the token right after the MACRO keyword) to a
            MacroInfo holding its definition text and the macros it references.
        """
        # Reset parser state and point the token cursor at the tokenized input.
        self.reset()
        self.sql = jinja
        self._tokens = Dialect.get_or_raise(dialect).tokenize(jinja)

        # guard for older sqlglot versions (before 30.0.3)
        if hasattr(self, "_tokens_size"):
            # keep the cached length in sync
            self._tokens_size = len(self._tokens)
        # Start before the first token so the following _advance() moves onto it.
        self._index = -1
        self._advance()

        macros: t.Dict[str, MacroInfo] = {}

        while self._curr:
            if self._curr.token_type == TokenType.BLOCK_START:
                # Remember where this `{%` tag begins. If the tag is a macro, its definition
                # text starts here.
                macro_start = self._curr
            elif self._tag == "MACRO" and self._next:
                # _tag is only non-empty right after a BLOCK_START token, so macro_start was
                # assigned on the previous iteration.
                name = self._next.text
                # Skip to the end of the `{% macro ... %}` header...
                while self._curr and self._curr.token_type != TokenType.BLOCK_END:
                    self._advance()

                # ...then forward to the ENDMACRO keyword.
                while self._curr and self._tag != "ENDMACRO":
                    self._advance()

                # Source text from the opening `{%` through the token after ENDMACRO
                # (presumably the closing `%}`). NOTE(review): without an ENDMACRO, _curr and
                # _next are None here; confirm _find_sql copes with that.
                macro_str = self._find_sql(macro_start, self._next)
                macros[name] = MacroInfo(
                    definition=macro_str,
                    depends_on=list(extract_macro_references_and_variables(macro_str)[0]),
                )

            self._advance()

        return macros

    def _advance(self, times: int = 1) -> None:
        # Move the cursor, then record the upper-cased text of the current token if it directly
        # follows a BLOCK_START (i.e. the tag keyword such as MACRO / ENDMACRO). Otherwise clear it.
        super()._advance(times)
        self._tag = (
            self._curr.text.upper()
            if self._curr and self._prev and self._prev.token_type == TokenType.BLOCK_START
            else ""
        )
120
121
def call_name(node: nodes.Expr) -> t.Tuple[str, ...]:
    """Flatten a Jinja expression into the dotted name path it refers to.

    Attribute accesses contribute their attribute names. Subscripts and calls are transparent:
    only the expression they apply to matters. The innermost expression contributes its name
    (for a variable), its value wrapped in single quotes (for a constant), or nothing.

    Examples: `adapter.get_columns(x)` -> ("adapter", "get_columns"); `'a'.upper()` -> ("'a'", "upper").
    """
    attrs: t.List[str] = []
    current = node
    while True:
        if isinstance(current, nodes.Getattr):
            attrs.append(current.attr)
            current = current.node
        elif isinstance(current, (nodes.Getitem, nodes.Call)):
            current = current.node
        else:
            break

    if isinstance(current, nodes.Name):
        head: t.Tuple[str, ...] = (current.name,)
    elif isinstance(current, nodes.Const):
        head = (f"'{current.value}'",)
    else:
        head = ()
    # Attributes were collected outermost-first, so reverse them to get source order.
    return head + tuple(reversed(attrs))
132
133
def render_jinja(query: str, methods: t.Optional[t.Dict[str, t.Any]] = None) -> str:
    """Render `query` with the shared module ENVIRONMENT, using `methods` as template variables."""
    template = ENVIRONMENT.from_string(query)
    variables = methods or {}
    return template.render(variables)
136
137
def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[CallNames]:
    """Yield every call / attribute-access name path below `node` that is not a local variable.

    A reference is skipped in two cases:
    - its leading name is bound at that point (a macro argument, an assignment target such as
      `{% set x = ... %}`, or a `for` loop variable inside the loop);
    - it starts with a string constant (e.g. `'a,b'.split(',')`).

    Args:
        node: The Jinja AST node whose descendants are scanned.
        vars_in_scope: Names already bound by enclosing constructs. The set is not mutated.

    Yields:
        Tuples of the dotted name path (see `call_name`) and the matching `nodes.Call` or
        `nodes.Getattr` node.
    """
    yield from _find_call_names_in(node.iter_child_nodes(), vars_in_scope)


def _target_names(target: t.Any) -> t.Set[str]:
    """Return the variable names bound by an assignment or loop target (a Name or a Tuple of Names)."""
    if isinstance(target, nodes.Name):
        return {target.name}
    if isinstance(target, nodes.Tuple):
        return {item.name for item in target.items if isinstance(item, nodes.Name)}
    return set()


def _find_call_names_in(
    children: t.Iterable[nodes.Node], vars_in_scope: t.Set[str]
) -> t.Iterator[CallNames]:
    """Scan sibling nodes in order. Names bound by earlier siblings are visible to later ones."""
    vars_in_scope = vars_in_scope.copy()
    for child_node in children:
        if isinstance(child_node, nodes.For):
            loop_vars = _target_names(child_node.target)
            # The iterable is evaluated before the loop variables are bound...
            yield from _find_call_names_in([child_node.iter], vars_in_scope)
            # ...whereas the body and the optional filter (`for x in y if x.z`) see them.
            loop_scope = vars_in_scope | loop_vars
            yield from _find_call_names_in(child_node.body, loop_scope)
            yield from _find_call_names_in(child_node.else_, vars_in_scope)
            if child_node.test is not None:
                yield from _find_call_names_in([child_node.test], loop_scope)
            # Keep the historical behaviour of treating loop variables as bound for the
            # siblings that follow the loop.
            vars_in_scope |= loop_vars
        elif "target" in child_node.fields:
            # For other nodes with assignment targets (Assign, AssignBlock, Import),
            # the target name could shadow a reference in the right hand side.
            # So we need to process the RHS before adding the target to scope.
            # For example: {% set model = model.path %} should track model.path.
            yield from find_call_names(child_node, vars_in_scope)
            vars_in_scope |= _target_names(getattr(child_node, "target"))
        else:
            if isinstance(child_node, nodes.Macro):
                # Macro arguments shadow same-named globals inside the macro body.
                vars_in_scope.update(arg.name for arg in child_node.args)
            elif isinstance(child_node, nodes.Call) or (
                isinstance(child_node, nodes.Getattr)
                and not isinstance(child_node.node, nodes.Getattr)
            ):
                name = call_name(child_node)
                # `name` is empty when the callee is not a name/attribute path, e.g. `(x|first)()`.
                if name and name[0][0] != "'" and name[0] not in vars_in_scope:
                    yield (name, child_node)
            yield from find_call_names(child_node, vars_in_scope)
167
168
def extract_call_names(
    jinja_str: str, cache: t.Optional[t.Dict[str, t.Tuple[t.List[CallNames], bool]]] = None
) -> t.List[CallNames]:
    """Parse a Jinja template string and return its external call names.

    Args:
        jinja_str: The Jinja source, parsed with the module-level ENVIRONMENT.
        cache: Optional memo keyed by the decimal CRC32 of the UTF-8 encoded source. On every
            call (hit or miss) the entry is (re)written with its flag set to True.
            NOTE(review): the flag presumably marks the entry as used so its owner can evict
            stale entries; confirm with callers. CRC32 keys can also collide for different
            templates.

    Returns:
        The list produced by `find_call_names` for the parsed template.
    """
    if cache is None:
        return list(find_call_names(ENVIRONMENT.parse(jinja_str), set()))

    cache_key = str(zlib.crc32(jinja_str.encode("utf-8")))
    if cache_key in cache:
        call_names = cache[cache_key][0]
    else:
        call_names = list(find_call_names(ENVIRONMENT.parse(jinja_str), set()))
    cache[cache_key] = (call_names, True)
    return call_names
184
185
def is_variable_node(n: nodes.Node) -> bool:
    """Return True if `n` is a call whose callee is the bare name `c.VAR` or `c.BLUEPRINT_VAR`."""
    if not isinstance(n, nodes.Call):
        return False
    callee = n.node
    return isinstance(callee, nodes.Name) and callee.name in (c.VAR, c.BLUEPRINT_VAR)
192
193
def extract_macro_references_and_variables(
    *jinja_strs: str,
) -> t.Tuple[t.Set[MacroReference], t.Set[str]]:
    """Collect the macros referenced and the variables read by one or more Jinja sources.

    Args:
        *jinja_strs: Jinja template sources to scan.

    Returns:
        A pair `(macro_references, variables)`:
        - `macro_references` holds one- and two-part call paths (`macro` or `package.macro`);
        - `variables` holds the lower-cased names passed as string literals to `c.VAR` /
          `c.BLUEPRINT_VAR`, plus `c.GATEWAY` when that global is used.

    Raises:
        ValueError: If a reference starts with `c.VAR` / `c.BLUEPRINT_VAR` but no matching call
            node can be found beneath it.
    """
    macro_references: t.Set[MacroReference] = set()
    variables: t.Set[str] = set()
    for jinja_str in jinja_strs:
        for name_parts, node in extract_call_names(jinja_str):
            head = name_parts[0]
            if head in (c.VAR, c.BLUEPRINT_VAR):
                if not is_variable_node(node):
                    # e.g. `var('x').upper()`: the reported node wraps the actual call, so
                    # search below it for the var(...) call itself.
                    found = next(
                        (n for n in node.find_all(nodes.Call) if is_variable_node(n)), None
                    )
                    if found is None:
                        raise ValueError(f"Could not find variable name in {jinja_str}")
                    node = found
                var_call = t.cast(nodes.Call, node)
                var_name = jinja_call_arg_name(var_call.args[0]) if var_call.args else ""
                # Only a non-empty string literal names a variable. A dynamic argument yields "",
                # and a non-string constant such as `var(1)` has no name to lower-case.
                if var_name and isinstance(var_name, str):
                    variables.add(var_name.lower())
            elif head == c.GATEWAY:
                variables.add(c.GATEWAY)
            elif len(name_parts) == 1:
                macro_references.add(MacroReference(name=head))
            elif len(name_parts) == 2:
                macro_references.add(MacroReference(package=head, name=name_parts[1]))
    return macro_references, variables
221
222
def sort_dict_recursive(
    item: t.Dict[str, t.Any],
) -> t.Dict[str, t.Any]:
    """Return a copy of `item` with keys sorted at every nesting level.

    - Nested dicts are sorted recursively, and list values are sorted by value.
    - Dicts found inside lists are normalized as well.
    - A list whose elements cannot be ordered against each other (e.g. dicts, or a mix of
      ints and strings) keeps its original order instead of raising TypeError.

    Args:
        item: The mapping to normalize. It is not modified.

    Returns:
        A new dict suitable for deterministic rendering.
    """
    sorted_dict: t.Dict[str, t.Any] = {}
    for k, v in sorted(item.items()):
        if isinstance(v, list):
            sorted_dict[k] = _sort_list_values(v)
        elif isinstance(v, dict):
            sorted_dict[k] = sort_dict_recursive(v)
        else:
            sorted_dict[k] = v
    return sorted_dict


def _sort_list_values(values: t.List[t.Any]) -> t.List[t.Any]:
    """Sort a list after normalizing nested dicts. Keep the original order if it is unorderable."""
    normalized = [sort_dict_recursive(v) if isinstance(v, dict) else v for v in values]
    try:
        return sorted(normalized)
    except TypeError:
        return normalized
235
236
# Value types allowed in JinjaMacroRegistry.global_objs. NOTE(review): the global_objs
# validator can also produce lists (from input lists and sets), which this union does not name.
JinjaGlobalAttribute = t.Union[str, int, float, bool, AttributeDict]
238
239
class JinjaMacroRegistry(PydanticModel):
    """Registry for Jinja macros.

    Args:
        packages: The mapping from package name to a collection of macro definitions.
        root_macros: The collection of top-level macro definitions.
        global_objs: The global objects.
        create_builtins_module: The name of a module which may define `create_builtin_globals`
            and/or `create_builtin_filters` factory functions. They are used to construct
            builtin variables, functions and filters.
        root_package_name: The name of the root package. If specified root macros will be available
            as both `root_package_name.macro_name` and `macro_name`.
        top_level_packages: The list of top-level packages. Macros in these packages will be available
            as both `package_name.macro_name` and `macro_name`.
    """

    packages: t.Dict[str, t.Dict[str, MacroInfo]] = {}
    root_macros: t.Dict[str, MacroInfo] = {}
    global_objs: t.Dict[str, JinjaGlobalAttribute] = {}
    create_builtins_module: t.Optional[str] = SQLMESH_JINJA_PACKAGE
    root_package_name: t.Optional[str] = None
    top_level_packages: t.List[str] = []

    # Parsed macro templates keyed by (package, macro name). Reset when pickling.
    _parser_cache: t.Dict[t.Tuple[t.Optional[str], str], Template] = {}
    # True for registries returned by trim().
    _trimmed: bool = False
    # Lazily built environment used to parse macros and as the source of builtin filters.
    __environment: t.Optional[Environment] = None

    def __getstate__(self) -> t.Dict[t.Any, t.Any]:
        # Drop parsed templates and the environment from the pickled state. Both are rebuilt
        # lazily on first use (_parse_macro / _environment).
        state = super().__getstate__()
        private = state[PRIVATE_FIELDS]
        private["_parser_cache"] = {}
        private["_JinjaMacroRegistry__environment"] = None
        return state

    @field_validator("global_objs", mode="before")
    @classmethod
    def _validate_global_objs(cls, value: t.Any) -> t.Any:
        """Recursively convert dicts to AttributeDicts, sets to sorted lists and enums to values."""

        def _normalize(val: t.Any) -> t.Any:
            if isinstance(val, dict):
                return AttributeDict({k: _normalize(v) for k, v in val.items()})
            if isinstance(val, list):
                return [_normalize(v) for v in val]
            if isinstance(val, set):
                return [_normalize(v) for v in sorted(val)]
            if isinstance(val, Enum):
                return val.value
            return val

        return _normalize(value)

    @field_serializer("global_objs")
    def _serialize_attribute_dict(
        self, value: t.Dict[str, JinjaGlobalAttribute]
    ) -> t.Dict[str, t.Any]:
        # NOTE: This is called only when used with Pydantic V2.
        def _convert(
            val: t.Union[t.Dict[str, JinjaGlobalAttribute], t.Dict[str, t.Any]],
        ) -> t.Dict[str, t.Any]:
            # Turn nested AttributeDicts back into plain dicts.
            return {k: _convert(v) if isinstance(v, AttributeDict) else v for k, v in val.items()}

        return _convert(value)

    @property
    def trimmed(self) -> bool:
        """Whether this registry was produced by `trim`."""
        return self._trimmed

    def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None:
        """Adds macros to the target package.

        Args:
            macros: Macros that should be added.
            package: The name of the package the given macros belong to. If not specified, the provided
                macros will be added to the root namespace.
        """
        if package is not None:
            self.packages.setdefault(package, {}).update(macros)
        else:
            self.root_macros.update(macros)

    def add_globals(self, globals: t.Dict[str, JinjaGlobalAttribute]) -> None:
        """Adds global objects to the registry.

        Args:
            globals: The global objects that should be added. The mapping itself is not modified.
        """
        # Keep the registry lightweight when the graph is not needed. Filter into a new dict
        # rather than popping, so the caller's mapping is not mutated as a side effect.
        if "graph" not in self.packages:
            globals = {k: v for k, v in globals.items() if k != "flat_graph"}
        self.global_objs.update(self._validate_global_objs(globals))

    def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]:
        """Builds a Python callable for a macro with the given reference.

        Args:
            reference: The macro reference.
            **kwargs: Extra globals forwarded to `build_environment`.

        Returns:
            The macro as a Python callable or None if not found.
        """
        env: Environment = self.build_environment(**kwargs)
        if reference.package is not None:
            package = env.globals.get(reference.package, {})
            return package.get(reference.name)  # type: ignore
        return env.globals.get(reference.name)  # type: ignore

    def build_environment(self, **kwargs: t.Any) -> Environment:
        """Builds a new Jinja environment based on this registry.

        Args:
            **kwargs: Extra globals. An `engine_adapter` entry is passed to the builtins factory
                separately. The remaining entries override `global_objs` before the factory runs.

        Returns:
            A fresh environment whose globals expose builtins, root macros and package macros.
        """
        # Shared by every macro wrapper and filled in below, so wrappers see the final globals.
        context: t.Dict[str, t.Any] = {}

        root_macros = {
            name: self._MacroWrapper(name, None, self, context) for name in self.root_macros
        }

        package_macros: t.Dict[str, t.Any] = defaultdict(AttributeDict)
        for package_name, macros in self.packages.items():
            for macro_name, macro in macros.items():
                macro_wrapper = self._MacroWrapper(macro_name, package_name, self, context)
                package_macros[package_name][macro_name] = macro_wrapper
                # Top-level package macros are also exposed by bare name, unless a root macro
                # already uses that name.
                if macro.is_top_level and macro_name not in root_macros:
                    root_macros[macro_name] = macro_wrapper

        top_level_packages = self.top_level_packages.copy()

        if self.root_package_name is not None:
            package_macros[self.root_package_name].update(root_macros)
            top_level_packages.append(self.root_package_name)

        env = environment()

        builtin_globals = self._create_builtin_globals(kwargs)
        for top_level_package_name in top_level_packages:
            # Make sure that the top-level package doesn't fully override the same builtin package.
            package_macros[top_level_package_name] = AttributeDict(
                {
                    **(builtin_globals.pop(top_level_package_name, None) or {}),
                    **(package_macros.get(top_level_package_name) or {}),
                }
            )
            root_macros.update(package_macros[top_level_package_name])

        # Later updates win: package namespaces override root macros, which override builtins.
        context.update(builtin_globals)
        context.update(root_macros)
        context.update(package_macros)
        context["render"] = lambda input: env.from_string(input).render()

        env.globals.update(context)
        env.filters.update(self._environment.filters)
        return env

    def trim(
        self, dependencies: t.Iterable[MacroReference], package: t.Optional[str] = None
    ) -> JinjaMacroRegistry:
        """Trims the registry by keeping only macros with given references and their transitive dependencies.

        Args:
            dependencies: References to macros that should be kept.
            package: The name of the package in the context of which the trimming should be performed.

        Returns:
            A new trimmed registry.
        """
        dependencies_by_package: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)
        for dep in dependencies:
            dependencies_by_package[dep.package or package].add(dep.name)

        top_level_packages = self.top_level_packages.copy()
        if package is not None:
            top_level_packages.append(package)

        result = JinjaMacroRegistry(
            global_objs=self.global_objs.copy(),
            create_builtins_module=self.create_builtins_module,
            root_package_name=self.root_package_name,
            top_level_packages=top_level_packages,
        )
        for dep_package, names in dependencies_by_package.items():
            result = result.merge(self._trim_macros(names, dep_package))

        result._trimmed = True

        return result

    def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry:
        """Returns a copy of the registry which contains macros from both this and `other` instances.

        On name conflicts, macros and globals from `other` win. The scalar settings
        (`create_builtins_module`, `root_package_name`) are taken from this registry when set.

        Args:
            other: The other registry instance.

        Returns:
            A new merged registry.
        """

        root_macros = {
            **self.root_macros,
            **other.root_macros,
        }

        packages = {}
        for package in {*self.packages, *other.packages}:
            packages[package] = {
                **self.packages.get(package, {}),
                **other.packages.get(package, {}),
            }

        global_objs = {
            **self.global_objs,
            **other.global_objs,
        }

        return JinjaMacroRegistry(
            packages=packages,
            root_macros=root_macros,
            global_objs=global_objs,
            create_builtins_module=self.create_builtins_module or other.create_builtins_module,
            root_package_name=self.root_package_name or other.root_package_name,
            # Deduplicate while keeping first-seen order, so repeated merges don't grow the list.
            top_level_packages=list(
                dict.fromkeys([*self.top_level_packages, *other.top_level_packages])
            ),
        )

    def to_expressions(self) -> t.List[Expression]:
        """Render the registry as expressions.

        The output is a PythonCode node with the `refs`, `sources` and `vars` globals (only when
        any are present), followed by one Jinja statement per root macro and then per package
        macro, each in sorted order.
        """
        output: t.List[Expression] = []

        filtered_objs = {
            k: v for k, v in self.global_objs.items() if k in ("refs", "sources", "vars")
        }
        if filtered_objs:
            output.append(
                d.PythonCode(
                    expressions=[
                        f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}"
                        for k, v in sort_dict_recursive(filtered_objs).items()
                    ]
                )
            )

        for _, macro_info in sorted(self.root_macros.items()):
            output.append(d.jinja_statement(macro_info.definition))

        for _, package in sorted(self.packages.items()):
            for _, macro_info in sorted(package.items()):
                output.append(d.jinja_statement(macro_info.definition))

        return output

    @property
    def data_hash_values(self) -> t.List[str]:
        """Values that identify the registry's content.

        These are the macro names and definitions (root first, then packages, sorted) followed
        by the JSON of the `refs`/`sources`/`vars` globals.
        """
        data = []

        for macro_name, macro in sorted(self.root_macros.items()):
            data.append(macro_name)
            data.append(macro.definition)

        for _, package in sorted(self.packages.items()):
            for macro_name, macro in sorted(package.items()):
                data.append(macro_name)
                data.append(macro.definition)

        trimmed_global_objs = {
            k: self.global_objs[k] for k in ("refs", "sources", "vars") if k in self.global_objs
        }
        data.append(json.dumps(trimmed_global_objs, sort_keys=True))

        return data

    def __deepcopy__(self, memo: t.Optional[t.Dict[int, t.Any]] = None) -> JinjaMacroRegistry:
        # Round-trip through plain data so parsed templates and environments are not copied.
        # Private state is not part of dict(), so carry the trimmed flag over explicitly.
        copied = JinjaMacroRegistry.parse_obj(self.dict())
        copied._trimmed = self._trimmed
        return copied

    def _parse_macro(self, name: str, package: t.Optional[str]) -> Template:
        """Return the compiled template for a macro, parsing and caching it on first use."""
        cache_key = (package, name)
        if cache_key not in self._parser_cache:
            macro = self._get_macro(name, package)

            definition: nodes.Template = self._environment.parse(macro.definition)
            if _is_private_macro(name):
                # A workaround to expose private jinja macros.
                definition = self._to_non_private_macro_def(name, definition)

            self._parser_cache[cache_key] = self._environment.from_string(definition)
        return self._parser_cache[cache_key]

    @property
    def _environment(self) -> Environment:
        """Lazily created environment carrying the builtin filters."""
        if self.__environment is None:
            self.__environment = environment()
            self.__environment.filters.update(self._create_builtin_filters())
        return self.__environment

    def _trim_macros(
        self,
        names: t.Set[str],
        package: t.Optional[str] = None,
        visited: t.Optional[t.Dict[t.Optional[str], t.Set[str]]] = None,
    ) -> JinjaMacroRegistry:
        """Return a registry with the named macros of `package` plus their transitive dependencies."""
        if visited is None:
            visited = defaultdict(set)

        macros = self.packages.get(package, {}) if package is not None else self.root_macros
        trimmed_macros = {}

        dependencies: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)

        for name in names:
            if name in macros and name not in visited[package]:
                macro = macros[name]
                trimmed_macros[name] = macro
                for dependency in macro.depends_on:
                    dependencies[dependency.package or package].add(dependency.name)
                visited[package].add(name)

        # These partial registries are only merged into trim()'s result. Leave
        # create_builtins_module unset so `merge` keeps the trimmed registry's own value, even
        # when that value is None.
        if package is not None:
            result = JinjaMacroRegistry(
                packages={package: trimmed_macros}, create_builtins_module=None
            )
        else:
            result = JinjaMacroRegistry(root_macros=trimmed_macros, create_builtins_module=None)

        for upstream_package, upstream_names in dependencies.items():
            result = result.merge(
                self._trim_macros(upstream_names, upstream_package, visited=visited)
            )

        return result

    def _macro_exists(self, name: str, package: t.Optional[str]) -> bool:
        return (
            name in self.packages.get(package, {})
            if package is not None
            else name in self.root_macros
        )

    def _get_macro(self, name: str, package: t.Optional[str]) -> MacroInfo:
        return self.packages[package][name] if package is not None else self.root_macros[name]

    def _to_non_private_macro_def(self, name: str, template: nodes.Template) -> nodes.Template:
        """Rename the private macro defined in `template` to its non-private form.

        Jinja does not export names starting with an underscore from a template module. So the
        macro is renamed before `_MacroWrapper` looks it up with `getattr`. Calls to the renamed
        macro (recursive self-calls) are renamed as well. Calls to any other name are left alone.
        """
        public_name = _non_private_name(name)
        original_names = {name}
        for macro_node in template.find_all(nodes.Macro):
            original_names.add(macro_node.name)
            macro_node.name = public_name
        for call_node in template.find_all(nodes.Call):
            if isinstance(call_node.node, nodes.Name) and call_node.node.name in original_names:
                call_node.node.name = public_name

        return template

    def _create_builtin_globals(self, global_vars: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
        """Creates Jinja builtin globals using a factory function defined in the provided module."""
        engine_adapter = global_vars.pop("engine_adapter", None)
        global_vars = {**self.global_objs, **global_vars}
        if self.create_builtins_module is not None:
            module = importlib.import_module(self.create_builtins_module)
            if hasattr(module, "create_builtin_globals"):
                return module.create_builtin_globals(self, global_vars, engine_adapter)
        return global_vars

    def _create_builtin_filters(self) -> t.Dict[str, t.Any]:
        """Creates Jinja builtin filters using a factory function defined in the provided module."""
        if self.create_builtins_module is not None:
            module = importlib.import_module(self.create_builtins_module)
            if hasattr(module, "create_builtin_filters"):
                return module.create_builtin_filters()
        return {}

    class _MacroWrapper:
        """Callable that lazily parses one registered macro and invokes it with shared globals."""

        def __init__(
            self,
            name: str,
            package: t.Optional[str],
            registry: JinjaMacroRegistry,
            context: t.Dict[str, t.Any],
        ):
            self.name = name
            self.package = package
            self.context = context
            self.registry = registry

        def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
            context = self.context.copy()
            # Inside a package macro, that package's own macros are callable by bare name.
            if self.package is not None and self.package in context:
                context.update(context[self.package])

            template = self.registry._parse_macro(self.name, self.package)
            macro_callable = getattr(
                template.make_module(vars=context), _non_private_name(self.name)
            )
            try:
                return macro_callable(*args, **kwargs)
            except MacroReturnVal as ret:
                # An early return from inside the macro: surface the carried value.
                return ret.value
626
627
628def _is_private_macro(name: str) -> bool:
629    return name.startswith("_")
630
631
632def _non_private_name(name: str) -> str:
633    return name.lstrip("_")
634
635
# Matches the opening delimiter of a Jinja expression (`{{`) or statement (`{%`).
JINJA_REGEX = re.compile(r"({{|{%)")


def has_jinja(value: str) -> bool:
    """Return True if `value` contains a Jinja `{{` or `{%` opening delimiter.

    A comment delimiter (`{#`) on its own does not count.
    """
    return bool(JINJA_REGEX.search(value))
641
642
def jinja_call_arg_name(node: nodes.Node) -> str:
    """Return the literal value of a constant call argument, or "" for anything non-constant.

    NOTE(review): a Const may hold a non-string value (e.g. `var(1)`), which is returned as is
    despite the `str` annotation.
    """
    return node.value if isinstance(node, nodes.Const) else ""
647
648
def create_var(variables: t.Dict[str, t.Any]) -> t.Callable:
    """Build a `var(name, default=None)` function that reads from `variables`.

    The requested name is lower-cased before the lookup, so keys in `variables` are presumably
    stored lower-case (confirm with callers). SqlValue results are unwrapped to their `.sql`.
    """

    def _var(var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
        resolved = variables.get(var_name.lower(), default)
        return resolved.sql if isinstance(resolved, SqlValue) else resolved

    return _var
657
658
def create_builtin_globals(
    jinja_macros: JinjaMacroRegistry, global_vars: t.Dict[str, t.Any], *args: t.Any, **kwargs: t.Any
) -> t.Dict[str, t.Any]:
    """Default builtins factory used by JinjaMacroRegistry.

    The following entries are popped from `global_vars`, which is mutated in place:
    - `c.GATEWAY` is discarded;
    - `c.SQLMESH_VARS` and `c.SQLMESH_BLUEPRINT_VARS` back the `var` and `blueprint_var`
      lookups.

    The remaining globals are returned alongside those callables and a `gateway()` function.
    `gateway()` reads the gateway from the variables mapping.
    """
    global_vars.pop(c.GATEWAY, None)
    variables = global_vars.pop(c.SQLMESH_VARS, None) or {}
    blueprint_variables = global_vars.pop(c.SQLMESH_BLUEPRINT_VARS, None) or {}

    builtins: t.Dict[str, t.Any] = dict(global_vars)
    builtins[c.VAR] = create_var(variables)
    builtins[c.GATEWAY] = lambda: variables.get(c.GATEWAY, None)
    builtins[c.BLUEPRINT_VAR] = create_var(blueprint_variables)
    return builtins
671
672
def make_jinja_registry(
    jinja_macros: JinjaMacroRegistry, package_name: str, jinja_references: t.Set[MacroReference]
) -> JinjaMacroRegistry:
    """Build a trimmed registry scoped to a single package.

    A copy of `jinja_macros` is taken and its root macros are replaced by the macros of
    `package_name` (or by nothing if the package is unknown). The copy is then trimmed down to
    `jinja_references` and their transitive dependencies. The input registry is left untouched.

    Args:
        jinja_macros: The original Jinja macro registry containing all macros.
        package_name: The package whose macros become the root namespace.
        jinja_references: The macro references to retain.

    Returns:
        A new, trimmed JinjaMacroRegistry.
    """
    package_scoped = jinja_macros.copy()
    package_scoped.root_macros = package_scoped.packages.get(package_name) or {}
    return package_scoped.trim(jinja_references)
699
700
def extract_error_details(ex: Exception) -> str:
    """Extracts a readable message from a Jinja2 error, to include missing name and macro.

    For an `UndefinedError`, the first quoted identifier in its message is reported as the
    undefined name. The traceback is also scanned for the innermost Jinja macro frame
    (`Macro._invoke`) so the enclosing macro can be named.

    Args:
        ex: The exception raised while rendering.

    Returns:
        The extracted details, or `str(ex)` when nothing could be extracted.
    """
    error_details = ""
    if isinstance(ex, UndefinedError):
        if match := re.search(r"'(\w+)'", str(ex)):
            error_details += f"\nUndefined macro/variable: '{match.group(1)}'"
        try:
            # Prefer the traceback attached to the exception itself. It is available even
            # outside an `except` block. Fall back to the exception currently being handled.
            exc_traceback = ex.__traceback__ or exc_info()[2]
            for frame, _ in walk_tb(exc_traceback):
                if frame.f_code.co_name == "_invoke":
                    macro = frame.f_locals.get("self")
                    if isinstance(macro, Macro):
                        error_details += f" in macro: '{macro.name}'\n"
                        break
        except Exception:
            # Frame inspection is best-effort. Fall back to the generic error message if it
            # fails, but never swallow KeyboardInterrupt / SystemExit.
            pass
    return error_details or str(ex)
SQLMESH_JINJA_PACKAGE = 'sqlmesh.utils.jinja'
def environment(**kwargs: Any) -> jinja2.environment.Environment:
32def environment(**kwargs: t.Any) -> Environment:
33    extensions = kwargs.pop("extensions", [])
34    extensions.append("jinja2.ext.do")
35    extensions.append("jinja2.ext.loopcontrols")
36    return Environment(extensions=extensions, **kwargs)
ENVIRONMENT = <jinja2.environment.Environment object>
class MacroReference(sqlmesh.utils.pydantic.PydanticModel):
42class MacroReference(PydanticModel, frozen=True):
43    package: t.Optional[str] = None
44    name: str
45
46    @property
47    def reference(self) -> str:
48        if self.package is None:
49            return self.name
50        return ".".join((self.package, self.name))
51
52    def __str__(self) -> str:
53        return self.reference

!!! abstract "Usage Documentation" Models

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of the class variables defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The core schema of the model.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a [RootModel][pydantic.root_model.RootModel].
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_fields__: A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
  • __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
  • __pydantic_extra__: A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
  • __pydantic_fields_set__: The names of fields explicitly set during instantiation.
  • __pydantic_private__: Values of private attributes set on the model instance.
package: Optional[str]
name: str
reference: str
46    @property
47    def reference(self) -> str:
48        if self.package is None:
49            return self.name
50        return ".".join((self.package, self.name))
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': (), 'frozen': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class MacroInfo(sqlmesh.utils.pydantic.PydanticModel):
56class MacroInfo(PydanticModel):
57    """Class to hold macro and its calls"""
58
59    definition: str
60    depends_on: t.List[MacroReference]
61    is_top_level: bool = False

Class to hold a macro's definition together with references to the macros it calls (its dependencies).

definition: str
depends_on: List[MacroReference]
is_top_level: bool
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model; it should be a dictionary conforming to ConfigDict (pydantic.config.ConfigDict).

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class MacroReturnVal(builtins.Exception):
64class MacroReturnVal(Exception):
65    def __init__(self, val: t.Any):
66        self.value = val

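Exception used to carry a return value out of a Jinja macro call: the registry's macro wrapper catches it and returns its `value` attribute as the macro's result. (Inherits from Exception, the common base class for all non-exit exceptions.)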

MacroReturnVal(val: Any)
65    def __init__(self, val: t.Any):
66        self.value = val
value
Inherited Members
builtins.BaseException
with_traceback
args
class MacroExtractor(sqlglot.parser.Parser):
 69class MacroExtractor(Parser):
 70    def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
 71        """Extract a dictionary of macro definitions from a jinja string.
 72
 73        Args:
 74            jinja: The jinja string to extract from.
 75            dialect: The dialect of SQL.
 76
 77        Returns:
 78            A dictionary of macro name to macro definition.
 79        """
 80        self.reset()
 81        self.sql = jinja
 82        self._tokens = Dialect.get_or_raise(dialect).tokenize(jinja)
 83
 84        # guard for older sqlglot versions (before 30.0.3)
 85        if hasattr(self, "_tokens_size"):
 86            # keep the cached length in sync
 87            self._tokens_size = len(self._tokens)
 88        self._index = -1
 89        self._advance()
 90
 91        macros: t.Dict[str, MacroInfo] = {}
 92
 93        while self._curr:
 94            if self._curr.token_type == TokenType.BLOCK_START:
 95                macro_start = self._curr
 96            elif self._tag == "MACRO" and self._next:
 97                name = self._next.text
 98                while self._curr and self._curr.token_type != TokenType.BLOCK_END:
 99                    self._advance()
100
101                while self._curr and self._tag != "ENDMACRO":
102                    self._advance()
103
104                macro_str = self._find_sql(macro_start, self._next)
105                macros[name] = MacroInfo(
106                    definition=macro_str,
107                    depends_on=list(extract_macro_references_and_variables(macro_str)[0]),
108                )
109
110            self._advance()
111
112        return macros
113
114    def _advance(self, times: int = 1) -> None:
115        super()._advance(times)
116        self._tag = (
117            self._curr.text.upper()
118            if self._curr and self._prev and self._prev.token_type == TokenType.BLOCK_START
119            else ""
120        )

Parser consumes a list of tokens produced by the Tokenizer and produces a parsed syntax tree.

Arguments:
  • error_level: The desired error level. Default: ErrorLevel.IMMEDIATE
  • error_message_context: The amount of context to capture from a query string when displaying the error message (in number of characters). Default: 100
  • max_errors: Maximum number of error messages to include in a raised ParseError. This is only relevant if error_level is ErrorLevel.RAISE. Default: 3
def extract( self, jinja: str, dialect: str = '') -> Dict[str, MacroInfo]:
 70    def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
 71        """Extract a dictionary of macro definitions from a jinja string.
 72
 73        Args:
 74            jinja: The jinja string to extract from.
 75            dialect: The dialect of SQL.
 76
 77        Returns:
 78            A dictionary of macro name to macro definition.
 79        """
 80        self.reset()
 81        self.sql = jinja
 82        self._tokens = Dialect.get_or_raise(dialect).tokenize(jinja)
 83
 84        # guard for older sqlglot versions (before 30.0.3)
 85        if hasattr(self, "_tokens_size"):
 86            # keep the cached length in sync
 87            self._tokens_size = len(self._tokens)
 88        self._index = -1
 89        self._advance()
 90
 91        macros: t.Dict[str, MacroInfo] = {}
 92
 93        while self._curr:
 94            if self._curr.token_type == TokenType.BLOCK_START:
 95                macro_start = self._curr
 96            elif self._tag == "MACRO" and self._next:
 97                name = self._next.text
 98                while self._curr and self._curr.token_type != TokenType.BLOCK_END:
 99                    self._advance()
100
101                while self._curr and self._tag != "ENDMACRO":
102                    self._advance()
103
104                macro_str = self._find_sql(macro_start, self._next)
105                macros[name] = MacroInfo(
106                    definition=macro_str,
107                    depends_on=list(extract_macro_references_and_variables(macro_str)[0]),
108                )
109
110            self._advance()
111
112        return macros

Extract a dictionary of macro definitions from a jinja string.

Arguments:
  • jinja: The jinja string to extract from.
  • dialect: The dialect of SQL.
Returns:

A dictionary of macro name to macro definition.

Inherited Members
sqlglot.parser.Parser
Parser
FUNCTIONS
NO_PAREN_FUNCTIONS
STRUCT_TYPE_TOKENS
NESTED_TYPE_TOKENS
ENUM_TYPE_TOKENS
AGGREGATE_TYPE_TOKENS
TYPE_TOKENS
SIGNED_TO_UNSIGNED_TYPE_TOKEN
SUBQUERY_PREDICATES
SUBQUERY_TOKENS
RESERVED_TOKENS
DB_CREATABLES
CREATABLES
TRIGGER_EVENTS
ALTERABLES
ID_VAR_TOKENS
TABLE_ALIAS_TOKENS
ALIAS_TOKENS
COLON_PLACEHOLDER_TOKENS
ARRAY_CONSTRUCTORS
COMMENT_TABLE_ALIAS_TOKENS
UPDATE_ALIAS_TOKENS
TRIM_TYPES
IDENTIFIER_TOKENS
BRACKETS
COLUMN_POSTFIX_TOKENS
TABLE_POSTFIX_TOKENS
FUNC_TOKENS
CONJUNCTION
ASSIGNMENT
DISJUNCTION
EQUALITY
COMPARISON
BITWISE
TERM
FACTOR
EXPONENT
TIMES
TIMESTAMPS
SET_OPERATIONS
JOIN_METHODS
JOIN_SIDES
JOIN_KINDS
JOIN_HINTS
TABLE_TERMINATORS
LAMBDAS
TYPED_LAMBDA_ARGS
LAMBDA_ARG_TERMINATORS
COLUMN_OPERATORS
CAST_COLUMN_OPERATORS
EXPRESSION_PARSERS
STATEMENT_PARSERS
UNARY_PARSERS
STRING_PARSERS
NUMERIC_PARSERS
PRIMARY_PARSERS
PLACEHOLDER_PARSERS
RANGE_PARSERS
PIPE_SYNTAX_TRANSFORM_PARSERS
PROPERTY_PARSERS
CONSTRAINT_PARSERS
ALTER_PARSERS
ALTER_ALTER_PARSERS
SCHEMA_UNNAMED_CONSTRAINTS
NO_PAREN_FUNCTION_PARSERS
INVALID_FUNC_NAME_TOKENS
FUNCTIONS_WITH_ALIASED_ARGS
KEY_VALUE_DEFINITIONS
FUNCTION_PARSERS
QUERY_MODIFIER_PARSERS
QUERY_MODIFIER_TOKENS
SET_PARSERS
SHOW_PARSERS
TYPE_LITERAL_PARSERS
TYPE_CONVERTERS
DDL_SELECT_TOKENS
PRE_VOLATILE_TOKENS
TRANSACTION_KIND
TRANSACTION_CHARACTERISTICS
CONFLICT_ACTIONS
TRIGGER_TIMING
TRIGGER_DEFERRABLE
CREATE_SEQUENCE
ISOLATED_LOADING_OPTIONS
USABLES
CAST_ACTIONS
SCHEMA_BINDING_OPTIONS
PROCEDURE_OPTIONS
EXECUTE_AS_OPTIONS
KEY_CONSTRAINT_OPTIONS
WINDOW_EXCLUDE_OPTIONS
INSERT_ALTERNATIVES
CLONE_KEYWORDS
HISTORICAL_DATA_PREFIX
HISTORICAL_DATA_KIND
OPCLASS_FOLLOW_KEYWORDS
OPTYPE_FOLLOW_TOKENS
TABLE_INDEX_HINT_TOKENS
VIEW_ATTRIBUTES
WINDOW_ALIAS_TOKENS
WINDOW_BEFORE_PAREN_TOKENS
WINDOW_SIDES
JSON_KEY_VALUE_SEPARATOR_TOKENS
FETCH_TOKENS
ADD_CONSTRAINT_TOKENS
DISTINCT_TOKENS
UNNEST_OFFSET_ALIAS_TOKENS
SELECT_START_TOKENS
COPY_INTO_VARLEN_OPTIONS
IS_JSON_PREDICATE_KIND
ODBC_DATETIME_LITERALS
ON_CONDITION_TOKENS
PRIVILEGE_FOLLOW_TOKENS
DESCRIBE_STYLES
SET_ASSIGNMENT_DELIMITERS
ANALYZE_STYLES
ANALYZE_EXPRESSION_PARSERS
PARTITION_KEYWORDS
AMBIGUOUS_ALIAS_TOKENS
OPERATION_MODIFIERS
RECURSIVE_CTE_SEARCH_KIND
SECURITY_PROPERTY_KEYWORDS
MODIFIABLES
STRICT_CAST
PREFIXED_PIVOT_COLUMNS
IDENTIFY_PIVOT_STRINGS
LOG_DEFAULTS_TO_LN
TABLESAMPLE_CSV
DEFAULT_SAMPLING_METHOD
SET_REQUIRES_ASSIGNMENT_DELIMITER
TRIM_PATTERN_FIRST
STRING_ALIASES
MODIFIERS_ATTACHED_TO_SET_OP
SET_OP_MODIFIERS
NO_PAREN_IF_COMMANDS
JSON_ARROWS_REQUIRE_JSON_TYPE
COLON_IS_VARIANT_EXTRACT
VALUES_FOLLOWED_BY_PAREN
SUPPORTS_IMPLICIT_UNNEST
INTERVAL_SPANS
SUPPORTS_PARTITION_SELECTION
WRAPPED_TRANSFORM_COLUMN_CONSTRAINT
OPTIONAL_ALIAS_TOKEN_CTE
ALTER_RENAME_REQUIRES_COLUMN
ALTER_TABLE_PARTITIONS
JOINS_HAVE_EQUAL_PRECEDENCE
ZONE_AWARE_TIMESTAMP_CONSTRUCTOR
MAP_KEYS_ARE_ARBITRARY_EXPRESSIONS
JSON_EXTRACT_REQUIRES_JSON_EXPRESSION
ADD_JOIN_ON_TRUE
SUPPORTS_OMITTED_INTERVAL_SPAN_UNIT
SHOW_TRIE
SET_TRIE
error_level
error_message_context
max_errors
dialect
sql
errors
reset
raise_error
validate_expression
parse
parse_into
check_errors
expression
parse_set_operation
build_cast
def call_name(node: jinja2.nodes.Expr) -> Tuple[str, ...]:
123def call_name(node: nodes.Expr) -> t.Tuple[str, ...]:
124    if isinstance(node, nodes.Name):
125        return (node.name,)
126    if isinstance(node, nodes.Const):
127        return (f"'{node.value}'",)
128    if isinstance(node, nodes.Getattr):
129        return call_name(node.node) + (node.attr,)
130    if isinstance(node, (nodes.Getitem, nodes.Call)):
131        return call_name(node.node)
132    return ()
def render_jinja(query: str, methods: Optional[Dict[str, Any]] = None) -> str:
135def render_jinja(query: str, methods: t.Optional[t.Dict[str, t.Any]] = None) -> str:
136    return ENVIRONMENT.from_string(query).render(methods or {})
def find_call_names( node: jinja2.nodes.Node, vars_in_scope: Set[str]) -> Iterator[Tuple[Tuple[str, ...], Union[jinja2.nodes.Call, jinja2.nodes.Getattr]]]:
139def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[CallNames]:
140    vars_in_scope = vars_in_scope.copy()
141    for child_node in node.iter_child_nodes():
142        if "target" in child_node.fields:
143            # For nodes with assignment targets (Assign, AssignBlock, For, Import),
144            # the target name could shadow a reference in the right hand side.
145            # So we need to process the RHS before adding the target to scope.
146            # For example: {% set model = model.path %} should track model.path.
147            yield from find_call_names(child_node, vars_in_scope)
148
149            target = getattr(child_node, "target")
150            if isinstance(target, nodes.Name):
151                vars_in_scope.add(target.name)
152            elif isinstance(target, nodes.Tuple):
153                for item in target.items:
154                    if isinstance(item, nodes.Name):
155                        vars_in_scope.add(item.name)
156        elif isinstance(child_node, nodes.Macro):
157            for arg in child_node.args:
158                vars_in_scope.add(arg.name)
159        elif isinstance(child_node, nodes.Call) or (
160            isinstance(child_node, nodes.Getattr) and not isinstance(child_node.node, nodes.Getattr)
161        ):
162            name = call_name(child_node)
163            if name[0][0] != "'" and name[0] not in vars_in_scope:
164                yield (name, child_node)
165
166        if "target" not in child_node.fields:
167            yield from find_call_names(child_node, vars_in_scope)
def extract_call_names( jinja_str: str, cache: Optional[Dict[str, Tuple[List[Tuple[Tuple[str, ...], Union[jinja2.nodes.Call, jinja2.nodes.Getattr]]], bool]]] = None) -> List[Tuple[Tuple[str, ...], Union[jinja2.nodes.Call, jinja2.nodes.Getattr]]]:
170def extract_call_names(
171    jinja_str: str, cache: t.Optional[t.Dict[str, t.Tuple[t.List[CallNames], bool]]] = None
172) -> t.List[CallNames]:
173    def parse() -> t.List[CallNames]:
174        return list(find_call_names(ENVIRONMENT.parse(jinja_str), set()))
175
176    if cache is not None:
177        key = str(zlib.crc32(jinja_str.encode("utf-8")))
178        if key in cache:
179            names = cache[key][0]
180        else:
181            names = parse()
182        cache[key] = (names, True)
183        return names
184    return parse()
def is_variable_node(n: jinja2.nodes.Node) -> bool:
187def is_variable_node(n: nodes.Node) -> bool:
188    return (
189        isinstance(n, nodes.Call)
190        and isinstance(n.node, nodes.Name)
191        and n.node.name in (c.VAR, c.BLUEPRINT_VAR)
192    )
def extract_macro_references_and_variables( *jinja_strs: str) -> Tuple[Set[MacroReference], Set[str]]:
195def extract_macro_references_and_variables(
196    *jinja_strs: str,
197) -> t.Tuple[t.Set[MacroReference], t.Set[str]]:
198    macro_references = set()
199    variables = set()
200    for jinja_str in jinja_strs:
201        for call_name, node in extract_call_names(jinja_str):
202            if call_name[0] in (c.VAR, c.BLUEPRINT_VAR):
203                if not is_variable_node(node):
204                    # Find the variable node which could be nested
205                    for n in node.find_all(nodes.Call):
206                        if is_variable_node(n):
207                            node = n
208                            break
209                    else:
210                        raise ValueError(f"Could not find variable name in {jinja_str}")
211                node = t.cast(nodes.Call, node)
212                args = [jinja_call_arg_name(arg) for arg in node.args]
213                if args and args[0]:
214                    variables.add(args[0].lower())
215            elif call_name[0] == c.GATEWAY:
216                variables.add(c.GATEWAY)
217            elif len(call_name) == 1:
218                macro_references.add(MacroReference(name=call_name[0]))
219            elif len(call_name) == 2:
220                macro_references.add(MacroReference(package=call_name[0], name=call_name[1]))
221    return macro_references, variables
def sort_dict_recursive(item: Dict[str, Any]) -> Dict[str, Any]:
224def sort_dict_recursive(
225    item: t.Dict[str, t.Any],
226) -> t.Dict[str, t.Any]:
227    sorted_dict: t.Dict[str, t.Any] = {}
228    for k, v in sorted(item.items()):
229        if isinstance(v, list):
230            sorted_dict[k] = sorted(v)
231        elif isinstance(v, dict):
232            sorted_dict[k] = sort_dict_recursive(v)
233        else:
234            sorted_dict[k] = v
235    return sorted_dict
JinjaGlobalAttribute = typing.Union[str, int, float, bool, sqlmesh.utils.AttributeDict]
class JinjaMacroRegistry(sqlmesh.utils.pydantic.PydanticModel):
241class JinjaMacroRegistry(PydanticModel):
242    """Registry for Jinja macros.
243
244    Args:
245        packages: The mapping from package name to a collection of macro definitions.
246        root_macros: The collection of top-level macro definitions.
247        global_objs: The global objects.
248        create_builtins_module: The name of a module which defines the `create_builtins` factory
249            function that will be used to construct builtin variables and functions.
250        root_package_name: The name of the root package. If specified root macros will be available
251            as both `root_package_name.macro_name` and `macro_name`.
252        top_level_packages: The list of top-level packages. Macros in this packages will be available
253            as both `package_name.macro_name` and `macro_name`.
254    """
255
256    packages: t.Dict[str, t.Dict[str, MacroInfo]] = {}
257    root_macros: t.Dict[str, MacroInfo] = {}
258    global_objs: t.Dict[str, JinjaGlobalAttribute] = {}
259    create_builtins_module: t.Optional[str] = SQLMESH_JINJA_PACKAGE
260    root_package_name: t.Optional[str] = None
261    top_level_packages: t.List[str] = []
262
263    _parser_cache: t.Dict[t.Tuple[t.Optional[str], str], Template] = {}
264    _trimmed: bool = False
265    __environment: t.Optional[Environment] = None
266
267    def __getstate__(self) -> t.Dict[t.Any, t.Any]:
268        state = super().__getstate__()
269        private = state[PRIVATE_FIELDS]
270        private["_parser_cache"] = {}
271        private["_JinjaMacroRegistry__environment"] = None
272        return state
273
274    @field_validator("global_objs", mode="before")
275    @classmethod
276    def _validate_global_objs(cls, value: t.Any) -> t.Any:
277        def _normalize(val: t.Any) -> t.Any:
278            if isinstance(val, dict):
279                return AttributeDict({k: _normalize(v) for k, v in val.items()})
280            if isinstance(val, list):
281                return [_normalize(v) for v in val]
282            if isinstance(val, set):
283                return [_normalize(v) for v in sorted(val)]
284            if isinstance(val, Enum):
285                return val.value
286            return val
287
288        return _normalize(value)
289
290    @field_serializer("global_objs")
291    def _serialize_attribute_dict(
292        self, value: t.Dict[str, JinjaGlobalAttribute]
293    ) -> t.Dict[str, t.Any]:
294        # NOTE: This is called only when used with Pydantic V2.
295        def _convert(
296            val: t.Union[t.Dict[str, JinjaGlobalAttribute], t.Dict[str, t.Any]],
297        ) -> t.Dict[str, t.Any]:
298            return {k: _convert(v) if isinstance(v, AttributeDict) else v for k, v in val.items()}
299
300        return _convert(value)
301
302    @property
303    def trimmed(self) -> bool:
304        return self._trimmed
305
306    def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None:
307        """Adds macros to the target package.
308
309        Args:
310            macros: Macros that should be added.
311            package: The name of the package the given macros belong to. If not specified, the provided
312            macros will be added to the root namespace.
313        """
314
315        if package is not None:
316            package_macros = self.packages.get(package, {})
317            package_macros.update(macros)
318            self.packages[package] = package_macros
319        else:
320            self.root_macros.update(macros)
321
322    def add_globals(self, globals: t.Dict[str, JinjaGlobalAttribute]) -> None:
323        """Adds global objects to the registry.
324
325        Args:
326            globals: The global objects that should be added.
327        """
328        # Keep the registry lightweight when the graph is not needed
329        if not "graph" in self.packages:
330            globals.pop("flat_graph", None)
331        self.global_objs.update(**self._validate_global_objs(globals))
332
333    def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]:
334        """Builds a Python callable for a macro with the given reference.
335
336        Args:
337            reference: The macro reference.
338        Returns:
339            The macro as a Python callable or None if not found.
340        """
341        env: Environment = self.build_environment(**kwargs)
342        if reference.package is not None:
343            package = env.globals.get(reference.package, {})
344            return package.get(reference.name)  # type: ignore
345        return env.globals.get(reference.name)  # type: ignore
346
347    def build_environment(self, **kwargs: t.Any) -> Environment:
348        """Builds a new Jinja environment based on this registry."""
349
350        context: t.Dict[str, t.Any] = {}
351
352        root_macros = {
353            name: self._MacroWrapper(name, None, self, context)
354            for name, macro in self.root_macros.items()
355        }
356
357        package_macros: t.Dict[str, t.Any] = defaultdict(AttributeDict)
358        for package_name, macros in self.packages.items():
359            for macro_name, macro in macros.items():
360                macro_wrapper = self._MacroWrapper(macro_name, package_name, self, context)
361                package_macros[package_name][macro_name] = macro_wrapper
362                if macro.is_top_level and macro_name not in root_macros:
363                    root_macros[macro_name] = macro_wrapper
364
365        top_level_packages = self.top_level_packages.copy()
366
367        if self.root_package_name is not None:
368            package_macros[self.root_package_name].update(root_macros)
369            top_level_packages.append(self.root_package_name)
370
371        env = environment()
372
373        builtin_globals = self._create_builtin_globals(kwargs)
374        for top_level_package_name in top_level_packages:
375            # Make sure that the top-level package doesn't fully override the same builtin package.
376            package_macros[top_level_package_name] = AttributeDict(
377                {
378                    **(builtin_globals.pop(top_level_package_name, None) or {}),
379                    **(package_macros.get(top_level_package_name) or {}),
380                }
381            )
382            root_macros.update(package_macros[top_level_package_name])
383
384        context.update(builtin_globals)
385        context.update(root_macros)
386        context.update(package_macros)
387        context["render"] = lambda input: env.from_string(input).render()
388
389        env.globals.update(context)
390        env.filters.update(self._environment.filters)
391        return env
392
393    def trim(
394        self, dependencies: t.Iterable[MacroReference], package: t.Optional[str] = None
395    ) -> JinjaMacroRegistry:
396        """Trims the registry by keeping only macros with given references and their transitive dependencies.
397
398        Args:
399            dependencies: References to macros that should be kept.
400            package: The name of the package in the context of which the trimming should be performed.
401
402        Returns:
403            A new trimmed registry.
404        """
405        dependencies_by_package: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)
406        for dep in dependencies:
407            dependencies_by_package[dep.package or package].add(dep.name)
408
409        top_level_packages = self.top_level_packages.copy()
410        if package is not None:
411            top_level_packages.append(package)
412
413        result = JinjaMacroRegistry(
414            global_objs=self.global_objs.copy(),
415            create_builtins_module=self.create_builtins_module,
416            root_package_name=self.root_package_name,
417            top_level_packages=top_level_packages,
418        )
419        for package, names in dependencies_by_package.items():
420            result = result.merge(self._trim_macros(names, package))
421
422        result._trimmed = True
423
424        return result
425
426    def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry:
427        """Returns a copy of the registry which contains macros from both this and `other` instances.
428
429        Args:
430            other: The other registry instance.
431
432        Returns:
433            A new merged registry.
434        """
435
436        root_macros = {
437            **self.root_macros,
438            **other.root_macros,
439        }
440
441        packages = {}
442        for package in {*self.packages, *other.packages}:
443            packages[package] = {
444                **self.packages.get(package, {}),
445                **other.packages.get(package, {}),
446            }
447
448        global_objs = {
449            **self.global_objs,
450            **other.global_objs,
451        }
452
453        return JinjaMacroRegistry(
454            packages=packages,
455            root_macros=root_macros,
456            global_objs=global_objs,
457            create_builtins_module=self.create_builtins_module or other.create_builtins_module,
458            root_package_name=self.root_package_name or other.root_package_name,
459            top_level_packages=[*self.top_level_packages, *other.top_level_packages],
460        )
461
462    def to_expressions(self) -> t.List[Expression]:
463        output: t.List[Expression] = []
464
465        filtered_objs = {
466            k: v for k, v in self.global_objs.items() if k in ("refs", "sources", "vars")
467        }
468        if filtered_objs:
469            output.append(
470                d.PythonCode(
471                    expressions=[
472                        f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}"
473                        for k, v in sort_dict_recursive(filtered_objs).items()
474                    ]
475                )
476            )
477
478        for macro_name, macro_info in sorted(self.root_macros.items()):
479            output.append(d.jinja_statement(macro_info.definition))
480
481        for _, package in sorted(self.packages.items()):
482            for macro_name, macro_info in sorted(package.items()):
483                output.append(d.jinja_statement(macro_info.definition))
484
485        return output
486
487    @property
488    def data_hash_values(self) -> t.List[str]:
489        data = []
490
491        for macro_name, macro in sorted(self.root_macros.items()):
492            data.append(macro_name)
493            data.append(macro.definition)
494
495        for _, package in sorted(self.packages.items()):
496            for macro_name, macro in sorted(package.items()):
497                data.append(macro_name)
498                data.append(macro.definition)
499
500        trimmed_global_objs = {
501            k: self.global_objs[k] for k in ("refs", "sources", "vars") if k in self.global_objs
502        }
503        data.append(json.dumps(trimmed_global_objs, sort_keys=True))
504
505        return data
506
507    def __deepcopy__(self, memo: t.Optional[t.Dict[int, t.Any]] = None) -> JinjaMacroRegistry:
508        return JinjaMacroRegistry.parse_obj(self.dict())
509
510    def _parse_macro(self, name: str, package: t.Optional[str]) -> Template:
511        cache_key = (package, name)
512        if cache_key not in self._parser_cache:
513            macro = self._get_macro(name, package)
514
515            definition: nodes.Template = self._environment.parse(macro.definition)
516            if _is_private_macro(name):
517                # A workaround to expose private jinja macros.
518                definition = self._to_non_private_macro_def(name, definition)
519
520            self._parser_cache[cache_key] = self._environment.from_string(definition)
521        return self._parser_cache[cache_key]
522
523    @property
524    def _environment(self) -> Environment:
525        if self.__environment is None:
526            self.__environment = environment()
527            self.__environment.filters.update(self._create_builtin_filters())
528        return self.__environment
529
530    def _trim_macros(
531        self,
532        names: t.Set[str],
533        package: t.Optional[str] = None,
534        visited: t.Optional[t.Dict[t.Optional[str], t.Set[str]]] = None,
535    ) -> JinjaMacroRegistry:
536        if visited is None:
537            visited = defaultdict(set)
538
539        macros = self.packages.get(package, {}) if package is not None else self.root_macros
540        trimmed_macros = {}
541
542        dependencies: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)
543
544        for name in names:
545            if name in macros and name not in visited[package]:
546                macro = macros[name]
547                trimmed_macros[name] = macro
548                for dependency in macro.depends_on:
549                    dependencies[dependency.package or package].add(dependency.name)
550                visited[package].add(name)
551
552        if package is not None:
553            result = JinjaMacroRegistry(packages={package: trimmed_macros})
554        else:
555            result = JinjaMacroRegistry(root_macros=trimmed_macros)
556
557        for upstream_package, upstream_names in dependencies.items():
558            result = result.merge(
559                self._trim_macros(upstream_names, upstream_package, visited=visited)
560            )
561
562        return result
563
564    def _macro_exists(self, name: str, package: t.Optional[str]) -> bool:
565        return (
566            name in self.packages.get(package, {})
567            if package is not None
568            else name in self.root_macros
569        )
570
571    def _get_macro(self, name: str, package: t.Optional[str]) -> MacroInfo:
572        return self.packages[package][name] if package is not None else self.root_macros[name]
573
574    def _to_non_private_macro_def(self, name: str, template: nodes.Template) -> nodes.Template:
575        for node in template.find_all((nodes.Macro, nodes.Call)):
576            if isinstance(node, nodes.Macro):
577                node.name = _non_private_name(name)
578            elif isinstance(node, nodes.Call) and isinstance(node.node, nodes.Name):
579                node.node.name = _non_private_name(name)
580
581        return template
582
583    def _create_builtin_globals(self, global_vars: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
584        """Creates Jinja builtin globals using a factory function defined in the provided module."""
585        engine_adapter = global_vars.pop("engine_adapter", None)
586        global_vars = {**self.global_objs, **global_vars}
587        if self.create_builtins_module is not None:
588            module = importlib.import_module(self.create_builtins_module)
589            if hasattr(module, "create_builtin_globals"):
590                return module.create_builtin_globals(self, global_vars, engine_adapter)
591        return global_vars
592
593    def _create_builtin_filters(self) -> t.Dict[str, t.Any]:
594        """Creates Jinja builtin filters using a factory function defined in the provided module."""
595        if self.create_builtins_module is not None:
596            module = importlib.import_module(self.create_builtins_module)
597            if hasattr(module, "create_builtin_filters"):
598                return module.create_builtin_filters()
599        return {}
600
601    class _MacroWrapper:
602        def __init__(
603            self,
604            name: str,
605            package: t.Optional[str],
606            registry: JinjaMacroRegistry,
607            context: t.Dict[str, t.Any],
608        ):
609            self.name = name
610            self.package = package
611            self.context = context
612            self.registry = registry
613
614        def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
615            context = self.context.copy()
616            if self.package is not None and self.package in context:
617                context.update(context[self.package])
618
619            template = self.registry._parse_macro(self.name, self.package)
620            macro_callable = getattr(
621                template.make_module(vars=context), _non_private_name(self.name)
622            )
623            try:
624                return macro_callable(*args, **kwargs)
625            except MacroReturnVal as ret:
626                return ret.value

Registry for Jinja macros.

Arguments:
  • packages: The mapping from package name to a collection of macro definitions.
  • root_macros: The collection of top-level macro definitions.
  • global_objs: The global objects.
  • create_builtins_module: The name of a module that defines the factory functions (create_builtin_globals and/or create_builtin_filters) used to construct builtin variables, functions, and filters.
  • root_package_name: The name of the root package. If specified root macros will be available as both root_package_name.macro_name and macro_name.
  • top_level_packages: The list of top-level packages. Macros in these packages will be available as both package_name.macro_name and macro_name.
packages: Dict[str, Dict[str, MacroInfo]]
root_macros: Dict[str, MacroInfo]
global_objs: Dict[str, Union[str, int, float, bool, sqlmesh.utils.AttributeDict]]
create_builtins_module: Optional[str]
root_package_name: Optional[str]
top_level_packages: List[str]
trimmed: bool
302    @property
303    def trimmed(self) -> bool:
304        return self._trimmed
def add_macros( self, macros: Dict[str, MacroInfo], package: Optional[str] = None) -> None:
306    def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None:
307        """Adds macros to the target package.
308
309        Args:
310            macros: Macros that should be added.
311            package: The name of the package the given macros belong to. If not specified, the provided
312            macros will be added to the root namespace.
313        """
314
315        if package is not None:
316            package_macros = self.packages.get(package, {})
317            package_macros.update(macros)
318            self.packages[package] = package_macros
319        else:
320            self.root_macros.update(macros)

Adds macros to the target package.

Arguments:
  • macros: Macros that should be added.
  • package: The name of the package the given macros belong to. If not specified, the provided macros will be added to the root namespace.
def add_globals( self, globals: Dict[str, Union[str, int, float, bool, sqlmesh.utils.AttributeDict]]) -> None:
322    def add_globals(self, globals: t.Dict[str, JinjaGlobalAttribute]) -> None:
323        """Adds global objects to the registry.
324
325        Args:
326            globals: The global objects that should be added.
327        """
328        # Keep the registry lightweight when the graph is not needed
329        if not "graph" in self.packages:
330            globals.pop("flat_graph", None)
331        self.global_objs.update(**self._validate_global_objs(globals))

Adds global objects to the registry.

Arguments:
  • globals: The global objects that should be added.
def build_macro( self, reference: MacroReference, **kwargs: Any) -> Optional[Callable]:
333    def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]:
334        """Builds a Python callable for a macro with the given reference.
335
336        Args:
337            reference: The macro reference.
338        Returns:
339            The macro as a Python callable or None if not found.
340        """
341        env: Environment = self.build_environment(**kwargs)
342        if reference.package is not None:
343            package = env.globals.get(reference.package, {})
344            return package.get(reference.name)  # type: ignore
345        return env.globals.get(reference.name)  # type: ignore

Builds a Python callable for a macro with the given reference.

Arguments:
  • reference: The macro reference.
Returns:

The macro as a Python callable or None if not found.

def build_environment(self, **kwargs: Any) -> jinja2.environment.Environment:
347    def build_environment(self, **kwargs: t.Any) -> Environment:
348        """Builds a new Jinja environment based on this registry."""
349
350        context: t.Dict[str, t.Any] = {}
351
352        root_macros = {
353            name: self._MacroWrapper(name, None, self, context)
354            for name, macro in self.root_macros.items()
355        }
356
357        package_macros: t.Dict[str, t.Any] = defaultdict(AttributeDict)
358        for package_name, macros in self.packages.items():
359            for macro_name, macro in macros.items():
360                macro_wrapper = self._MacroWrapper(macro_name, package_name, self, context)
361                package_macros[package_name][macro_name] = macro_wrapper
362                if macro.is_top_level and macro_name not in root_macros:
363                    root_macros[macro_name] = macro_wrapper
364
365        top_level_packages = self.top_level_packages.copy()
366
367        if self.root_package_name is not None:
368            package_macros[self.root_package_name].update(root_macros)
369            top_level_packages.append(self.root_package_name)
370
371        env = environment()
372
373        builtin_globals = self._create_builtin_globals(kwargs)
374        for top_level_package_name in top_level_packages:
375            # Make sure that the top-level package doesn't fully override the same builtin package.
376            package_macros[top_level_package_name] = AttributeDict(
377                {
378                    **(builtin_globals.pop(top_level_package_name, None) or {}),
379                    **(package_macros.get(top_level_package_name) or {}),
380                }
381            )
382            root_macros.update(package_macros[top_level_package_name])
383
384        context.update(builtin_globals)
385        context.update(root_macros)
386        context.update(package_macros)
387        context["render"] = lambda input: env.from_string(input).render()
388
389        env.globals.update(context)
390        env.filters.update(self._environment.filters)
391        return env

Builds a new Jinja environment based on this registry.

def trim( self, dependencies: Iterable[MacroReference], package: Optional[str] = None) -> JinjaMacroRegistry:
393    def trim(
394        self, dependencies: t.Iterable[MacroReference], package: t.Optional[str] = None
395    ) -> JinjaMacroRegistry:
396        """Trims the registry by keeping only macros with given references and their transitive dependencies.
397
398        Args:
399            dependencies: References to macros that should be kept.
400            package: The name of the package in the context of which the trimming should be performed.
401
402        Returns:
403            A new trimmed registry.
404        """
405        dependencies_by_package: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)
406        for dep in dependencies:
407            dependencies_by_package[dep.package or package].add(dep.name)
408
409        top_level_packages = self.top_level_packages.copy()
410        if package is not None:
411            top_level_packages.append(package)
412
413        result = JinjaMacroRegistry(
414            global_objs=self.global_objs.copy(),
415            create_builtins_module=self.create_builtins_module,
416            root_package_name=self.root_package_name,
417            top_level_packages=top_level_packages,
418        )
419        for package, names in dependencies_by_package.items():
420            result = result.merge(self._trim_macros(names, package))
421
422        result._trimmed = True
423
424        return result

Trims the registry by keeping only macros with given references and their transitive dependencies.

Arguments:
  • dependencies: References to macros that should be kept.
  • package: The name of the package in the context of which the trimming should be performed.
Returns:

A new trimmed registry.

def merge( self, other: JinjaMacroRegistry) -> JinjaMacroRegistry:
426    def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry:
427        """Returns a copy of the registry which contains macros from both this and `other` instances.
428
429        Args:
430            other: The other registry instance.
431
432        Returns:
433            A new merged registry.
434        """
435
436        root_macros = {
437            **self.root_macros,
438            **other.root_macros,
439        }
440
441        packages = {}
442        for package in {*self.packages, *other.packages}:
443            packages[package] = {
444                **self.packages.get(package, {}),
445                **other.packages.get(package, {}),
446            }
447
448        global_objs = {
449            **self.global_objs,
450            **other.global_objs,
451        }
452
453        return JinjaMacroRegistry(
454            packages=packages,
455            root_macros=root_macros,
456            global_objs=global_objs,
457            create_builtins_module=self.create_builtins_module or other.create_builtins_module,
458            root_package_name=self.root_package_name or other.root_package_name,
459            top_level_packages=[*self.top_level_packages, *other.top_level_packages],
460        )

Returns a copy of the registry which contains macros from both this instance and the other instance.

Arguments:
  • other: The other registry instance.
Returns:

A new merged registry.

def to_expressions(self) -> List[sqlglot.expressions.core.Expression]:
462    def to_expressions(self) -> t.List[Expression]:
463        output: t.List[Expression] = []
464
465        filtered_objs = {
466            k: v for k, v in self.global_objs.items() if k in ("refs", "sources", "vars")
467        }
468        if filtered_objs:
469            output.append(
470                d.PythonCode(
471                    expressions=[
472                        f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}"
473                        for k, v in sort_dict_recursive(filtered_objs).items()
474                    ]
475                )
476            )
477
478        for macro_name, macro_info in sorted(self.root_macros.items()):
479            output.append(d.jinja_statement(macro_info.definition))
480
481        for _, package in sorted(self.packages.items()):
482            for macro_name, macro_info in sorted(package.items()):
483                output.append(d.jinja_statement(macro_info.definition))
484
485        return output
data_hash_values: List[str]
487    @property
488    def data_hash_values(self) -> t.List[str]:
489        data = []
490
491        for macro_name, macro in sorted(self.root_macros.items()):
492            data.append(macro_name)
493            data.append(macro.definition)
494
495        for _, package in sorted(self.packages.items()):
496            for macro_name, macro in sorted(package.items()):
497                data.append(macro_name)
498                data.append(macro.definition)
499
500        trimmed_global_objs = {
501            k: self.global_objs[k] for k in ("refs", "sources", "vars") if k in self.global_objs
502        }
503        data.append(json.dumps(trimmed_global_objs, sort_keys=True))
504
505        return data
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

def model_post_init(self: pydantic.main.BaseModel, context: Any, /) -> None:
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """Fill ``__pydantic_private__`` with the defaults of the model's private attributes.

    Meant to behave like a ``BaseModel`` method. ``context`` is accepted only because
    pydantic-core passes it; it is not used. Nothing happens if the instance already has a
    non-None ``__pydantic_private__``.

    Args:
        self: The BaseModel instance.
        context: The context (unused).
    """
    if getattr(self, '__pydantic_private__', None) is not None:
        return
    # Attributes without a default (PydanticUndefined) are left out entirely.
    defaults = {
        name: default
        for name, private_attr in self.__private_attributes__.items()
        if (default := private_attr.get_default()) is not PydanticUndefined
    }
    object_setattr(self, '__pydantic_private__', defaults)

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that's what pydantic-core passes when calling it.

Arguments:
  • self: The BaseModel instance.
  • context: The context.
Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
# Matches the opening delimiter of a Jinja expression ("{{") or statement ("{%").
JINJA_REGEX = re.compile(r"({{|{%)")
def has_jinja(value: str) -> bool:
def has_jinja(value: str) -> bool:
    """Return True if ``value`` contains a Jinja expression (``{{``) or statement (``{%``) opener."""
    first_match = JINJA_REGEX.search(value)
    return first_match is not None
def jinja_call_arg_name(node: jinja2.nodes.Node) -> str:
def jinja_call_arg_name(node: nodes.Node) -> str:
    """Return the literal value of ``node`` if it is a Jinja constant, otherwise ``""``."""
    return node.value if isinstance(node, nodes.Const) else ""
def create_var(variables: Dict[str, Any]) -> Callable:
def create_var(variables: t.Dict[str, t.Any]) -> t.Callable:
    """Build a ``var(var_name, default=None)`` accessor over ``variables``.

    The name is lowercased before lookup, and ``default`` is returned when it is missing.
    ``SqlValue`` entries are unwrapped to their ``sql`` attribute.
    """

    def _var(var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
        found = variables.get(var_name.lower(), default)
        return found.sql if isinstance(found, SqlValue) else found

    return _var
def create_builtin_globals( jinja_macros: JinjaMacroRegistry, global_vars: Dict[str, Any], *args: Any, **kwargs: Any) -> Dict[str, Any]:
def create_builtin_globals(
    jinja_macros: JinjaMacroRegistry, global_vars: t.Dict[str, t.Any], *args: t.Any, **kwargs: t.Any
) -> t.Dict[str, t.Any]:
    """Build the builtin Jinja globals from ``global_vars``.

    Note: ``global_vars`` is modified in place. Its ``c.GATEWAY``, ``c.SQLMESH_VARS`` and
    ``c.SQLMESH_BLUEPRINT_VARS`` entries are removed. The remaining entries are passed through,
    and the following accessors are added:

    - ``c.VAR``: reads the SQLMesh variables.
    - ``c.GATEWAY``: returns the SQLMesh variables' ``c.GATEWAY`` entry.
    - ``c.BLUEPRINT_VAR``: reads the blueprint variables.

    ``jinja_macros``, ``args`` and ``kwargs`` are not used here.
    """
    global_vars.pop(c.GATEWAY, None)
    sqlmesh_vars = global_vars.pop(c.SQLMESH_VARS, None) or {}
    blueprint_vars = global_vars.pop(c.SQLMESH_BLUEPRINT_VARS, None) or {}

    builtins_map = {**global_vars}
    builtins_map[c.VAR] = create_var(sqlmesh_vars)
    builtins_map[c.GATEWAY] = lambda: sqlmesh_vars.get(c.GATEWAY, None)
    builtins_map[c.BLUEPRINT_VAR] = create_var(blueprint_vars)
    return builtins_map
def make_jinja_registry( jinja_macros: JinjaMacroRegistry, package_name: str, jinja_references: Set[MacroReference]) -> JinjaMacroRegistry:
def make_jinja_registry(
    jinja_macros: JinjaMacroRegistry, package_name: str, jinja_references: t.Set[MacroReference]
) -> JinjaMacroRegistry:
    """
    Creates a Jinja macro registry scoped to a specific package.

    A copy of ``jinja_macros`` is taken, and its root macros are replaced by the macros of
    ``package_name`` (or an empty mapping if that package is unknown). The copy is then trimmed
    down to the macros in ``jinja_references`` and their dependencies.

    Args:
        jinja_macros: The original Jinja macro registry containing all macros.
        package_name: The name of the package for which to create the registry.
        jinja_references: A set of macro references to retain in the new registry.

    Returns:
        A new, trimmed JinjaMacroRegistry whose root macros come from the given package.
    """
    registry_copy = jinja_macros.copy()
    # Promote the package's macros to the root namespace of the copy, so package-less references
    # presumably resolve against them during trimming.
    registry_copy.root_macros = registry_copy.packages.get(package_name) or {}
    return registry_copy.trim(jinja_references)

Creates a Jinja macro registry for a specific package.

This function takes an existing Jinja macro registry and returns a new registry that includes only the macros associated with the specified package and trims the registry to include only the macros referenced in the provided set of macro references.

Arguments:
  • jinja_macros: The original Jinja macro registry containing all macros.
  • package_name: The name of the package for which to create the registry.
  • jinja_references: A set of macro references to retain in the new registry.
Returns:

A new JinjaMacroRegistry containing only the macros for the specified package and the referenced macros.

def extract_error_details(ex: Exception) -> str:
def extract_error_details(ex: Exception) -> str:
    """Extracts a readable message from a Jinja2 error, to include missing name and macro.

    For an ``UndefinedError``, the message names the undefined identifier (the first single-quoted
    word in the error text). When the traceback shows which Jinja macro was executing, it names
    that macro too. For any other exception, or when nothing could be extracted, ``str(ex)`` is
    returned.

    Args:
        ex: The exception to describe.

    Returns:
        The extracted details, or ``str(ex)`` as a fallback.
    """
    error_details = ""
    if isinstance(ex, UndefinedError):
        if match := re.search(r"'(\w+)'", str(ex)):
            error_details += f"\nUndefined macro/variable: '{match.group(1)}'"
        try:
            # Use the exception's own traceback: ``sys.exc_info()`` only reflects the exception
            # currently being handled, which need not be ``ex`` (or may be nothing at all).
            for frame, _ in walk_tb(ex.__traceback__):
                # A ``_invoke`` frame whose ``self`` is a Jinja ``Macro`` identifies the macro.
                if frame.f_code.co_name == "_invoke":
                    macro = frame.f_locals.get("self")
                    if isinstance(macro, Macro):
                        error_details += f" in macro: '{macro.name}'\n"
                        break
        except Exception:
            # Best effort: fall back to the generic error message if frame analysis fails,
            # without swallowing KeyboardInterrupt / SystemExit.
            pass
    return error_details or str(ex)

Extracts a readable message from a Jinja2 error, to include missing name and macro.