sqlmesh.utils.jinja
1from __future__ import annotations 2 3import importlib 4import json 5import re 6import typing as t 7import zlib 8from collections import defaultdict 9from enum import Enum 10from sys import exc_info 11from traceback import walk_tb 12 13from jinja2 import Environment, Template, nodes, UndefinedError 14from jinja2.runtime import Macro 15from sqlglot import Dialect, Parser, TokenType 16from sqlglot.expressions import Expression 17 18from sqlmesh.core import constants as c 19from sqlmesh.core import dialect as d 20from sqlmesh.utils import AttributeDict 21from sqlmesh.utils.pydantic import PRIVATE_FIELDS, PydanticModel, field_serializer, field_validator 22from sqlmesh.utils.metaprogramming import SqlValue 23 24 25if t.TYPE_CHECKING: 26 CallNames = t.Tuple[t.Tuple[str, ...], t.Union[nodes.Call, nodes.Getattr]] 27 28SQLMESH_JINJA_PACKAGE = "sqlmesh.utils.jinja" 29 30 31def environment(**kwargs: t.Any) -> Environment: 32 extensions = kwargs.pop("extensions", []) 33 extensions.append("jinja2.ext.do") 34 extensions.append("jinja2.ext.loopcontrols") 35 return Environment(extensions=extensions, **kwargs) 36 37 38ENVIRONMENT = environment() 39 40 41class MacroReference(PydanticModel, frozen=True): 42 package: t.Optional[str] = None 43 name: str 44 45 @property 46 def reference(self) -> str: 47 if self.package is None: 48 return self.name 49 return ".".join((self.package, self.name)) 50 51 def __str__(self) -> str: 52 return self.reference 53 54 55class MacroInfo(PydanticModel): 56 """Class to hold macro and its calls""" 57 58 definition: str 59 depends_on: t.List[MacroReference] 60 is_top_level: bool = False 61 62 63class MacroReturnVal(Exception): 64 def __init__(self, val: t.Any): 65 self.value = val 66 67 68class MacroExtractor(Parser): 69 def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]: 70 """Extract a dictionary of macro definitions from a jinja string. 71 72 Args: 73 jinja: The jinja string to extract from. 74 dialect: The dialect of SQL. 
75 76 Returns: 77 A dictionary of macro name to macro definition. 78 """ 79 self.reset() 80 self.sql = jinja 81 self._tokens = Dialect.get_or_raise(dialect).tokenize(jinja) 82 83 # guard for older sqlglot versions (before 30.0.3) 84 if hasattr(self, "_tokens_size"): 85 # keep the cached length in sync 86 self._tokens_size = len(self._tokens) 87 self._index = -1 88 self._advance() 89 90 macros: t.Dict[str, MacroInfo] = {} 91 92 while self._curr: 93 if self._curr.token_type == TokenType.BLOCK_START: 94 macro_start = self._curr 95 elif self._tag == "MACRO" and self._next: 96 name = self._next.text 97 while self._curr and self._curr.token_type != TokenType.BLOCK_END: 98 self._advance() 99 100 while self._curr and self._tag != "ENDMACRO": 101 self._advance() 102 103 macro_str = self._find_sql(macro_start, self._next) 104 macros[name] = MacroInfo( 105 definition=macro_str, 106 depends_on=list(extract_macro_references_and_variables(macro_str)[0]), 107 ) 108 109 self._advance() 110 111 return macros 112 113 def _advance(self, times: int = 1) -> None: 114 super()._advance(times) 115 self._tag = ( 116 self._curr.text.upper() 117 if self._curr and self._prev and self._prev.token_type == TokenType.BLOCK_START 118 else "" 119 ) 120 121 122def call_name(node: nodes.Expr) -> t.Tuple[str, ...]: 123 if isinstance(node, nodes.Name): 124 return (node.name,) 125 if isinstance(node, nodes.Const): 126 return (f"'{node.value}'",) 127 if isinstance(node, nodes.Getattr): 128 return call_name(node.node) + (node.attr,) 129 if isinstance(node, (nodes.Getitem, nodes.Call)): 130 return call_name(node.node) 131 return () 132 133 134def render_jinja(query: str, methods: t.Optional[t.Dict[str, t.Any]] = None) -> str: 135 return ENVIRONMENT.from_string(query).render(methods or {}) 136 137 138def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[CallNames]: 139 vars_in_scope = vars_in_scope.copy() 140 for child_node in node.iter_child_nodes(): 141 if "target" in 
child_node.fields: 142 # For nodes with assignment targets (Assign, AssignBlock, For, Import), 143 # the target name could shadow a reference in the right hand side. 144 # So we need to process the RHS before adding the target to scope. 145 # For example: {% set model = model.path %} should track model.path. 146 yield from find_call_names(child_node, vars_in_scope) 147 148 target = getattr(child_node, "target") 149 if isinstance(target, nodes.Name): 150 vars_in_scope.add(target.name) 151 elif isinstance(target, nodes.Tuple): 152 for item in target.items: 153 if isinstance(item, nodes.Name): 154 vars_in_scope.add(item.name) 155 elif isinstance(child_node, nodes.Macro): 156 for arg in child_node.args: 157 vars_in_scope.add(arg.name) 158 elif isinstance(child_node, nodes.Call) or ( 159 isinstance(child_node, nodes.Getattr) and not isinstance(child_node.node, nodes.Getattr) 160 ): 161 name = call_name(child_node) 162 if name[0][0] != "'" and name[0] not in vars_in_scope: 163 yield (name, child_node) 164 165 if "target" not in child_node.fields: 166 yield from find_call_names(child_node, vars_in_scope) 167 168 169def extract_call_names( 170 jinja_str: str, cache: t.Optional[t.Dict[str, t.Tuple[t.List[CallNames], bool]]] = None 171) -> t.List[CallNames]: 172 def parse() -> t.List[CallNames]: 173 return list(find_call_names(ENVIRONMENT.parse(jinja_str), set())) 174 175 if cache is not None: 176 key = str(zlib.crc32(jinja_str.encode("utf-8"))) 177 if key in cache: 178 names = cache[key][0] 179 else: 180 names = parse() 181 cache[key] = (names, True) 182 return names 183 return parse() 184 185 186def is_variable_node(n: nodes.Node) -> bool: 187 return ( 188 isinstance(n, nodes.Call) 189 and isinstance(n.node, nodes.Name) 190 and n.node.name in (c.VAR, c.BLUEPRINT_VAR) 191 ) 192 193 194def extract_macro_references_and_variables( 195 *jinja_strs: str, 196) -> t.Tuple[t.Set[MacroReference], t.Set[str]]: 197 macro_references = set() 198 variables = set() 199 for jinja_str in 
jinja_strs: 200 for call_name, node in extract_call_names(jinja_str): 201 if call_name[0] in (c.VAR, c.BLUEPRINT_VAR): 202 if not is_variable_node(node): 203 # Find the variable node which could be nested 204 for n in node.find_all(nodes.Call): 205 if is_variable_node(n): 206 node = n 207 break 208 else: 209 raise ValueError(f"Could not find variable name in {jinja_str}") 210 node = t.cast(nodes.Call, node) 211 args = [jinja_call_arg_name(arg) for arg in node.args] 212 if args and args[0]: 213 variables.add(args[0].lower()) 214 elif call_name[0] == c.GATEWAY: 215 variables.add(c.GATEWAY) 216 elif len(call_name) == 1: 217 macro_references.add(MacroReference(name=call_name[0])) 218 elif len(call_name) == 2: 219 macro_references.add(MacroReference(package=call_name[0], name=call_name[1])) 220 return macro_references, variables 221 222 223def sort_dict_recursive( 224 item: t.Dict[str, t.Any], 225) -> t.Dict[str, t.Any]: 226 sorted_dict: t.Dict[str, t.Any] = {} 227 for k, v in sorted(item.items()): 228 if isinstance(v, list): 229 sorted_dict[k] = sorted(v) 230 elif isinstance(v, dict): 231 sorted_dict[k] = sort_dict_recursive(v) 232 else: 233 sorted_dict[k] = v 234 return sorted_dict 235 236 237JinjaGlobalAttribute = t.Union[str, int, float, bool, AttributeDict] 238 239 240class JinjaMacroRegistry(PydanticModel): 241 """Registry for Jinja macros. 242 243 Args: 244 packages: The mapping from package name to a collection of macro definitions. 245 root_macros: The collection of top-level macro definitions. 246 global_objs: The global objects. 247 create_builtins_module: The name of a module which defines the `create_builtins` factory 248 function that will be used to construct builtin variables and functions. 249 root_package_name: The name of the root package. If specified root macros will be available 250 as both `root_package_name.macro_name` and `macro_name`. 251 top_level_packages: The list of top-level packages. 
Macros in this packages will be available 252 as both `package_name.macro_name` and `macro_name`. 253 """ 254 255 packages: t.Dict[str, t.Dict[str, MacroInfo]] = {} 256 root_macros: t.Dict[str, MacroInfo] = {} 257 global_objs: t.Dict[str, JinjaGlobalAttribute] = {} 258 create_builtins_module: t.Optional[str] = SQLMESH_JINJA_PACKAGE 259 root_package_name: t.Optional[str] = None 260 top_level_packages: t.List[str] = [] 261 262 _parser_cache: t.Dict[t.Tuple[t.Optional[str], str], Template] = {} 263 _trimmed: bool = False 264 __environment: t.Optional[Environment] = None 265 266 def __getstate__(self) -> t.Dict[t.Any, t.Any]: 267 state = super().__getstate__() 268 private = state[PRIVATE_FIELDS] 269 private["_parser_cache"] = {} 270 private["_JinjaMacroRegistry__environment"] = None 271 return state 272 273 @field_validator("global_objs", mode="before") 274 @classmethod 275 def _validate_global_objs(cls, value: t.Any) -> t.Any: 276 def _normalize(val: t.Any) -> t.Any: 277 if isinstance(val, dict): 278 return AttributeDict({k: _normalize(v) for k, v in val.items()}) 279 if isinstance(val, list): 280 return [_normalize(v) for v in val] 281 if isinstance(val, set): 282 return [_normalize(v) for v in sorted(val)] 283 if isinstance(val, Enum): 284 return val.value 285 return val 286 287 return _normalize(value) 288 289 @field_serializer("global_objs") 290 def _serialize_attribute_dict( 291 self, value: t.Dict[str, JinjaGlobalAttribute] 292 ) -> t.Dict[str, t.Any]: 293 # NOTE: This is called only when used with Pydantic V2. 
294 def _convert( 295 val: t.Union[t.Dict[str, JinjaGlobalAttribute], t.Dict[str, t.Any]], 296 ) -> t.Dict[str, t.Any]: 297 return {k: _convert(v) if isinstance(v, AttributeDict) else v for k, v in val.items()} 298 299 return _convert(value) 300 301 @property 302 def trimmed(self) -> bool: 303 return self._trimmed 304 305 def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None: 306 """Adds macros to the target package. 307 308 Args: 309 macros: Macros that should be added. 310 package: The name of the package the given macros belong to. If not specified, the provided 311 macros will be added to the root namespace. 312 """ 313 314 if package is not None: 315 package_macros = self.packages.get(package, {}) 316 package_macros.update(macros) 317 self.packages[package] = package_macros 318 else: 319 self.root_macros.update(macros) 320 321 def add_globals(self, globals: t.Dict[str, JinjaGlobalAttribute]) -> None: 322 """Adds global objects to the registry. 323 324 Args: 325 globals: The global objects that should be added. 326 """ 327 # Keep the registry lightweight when the graph is not needed 328 if not "graph" in self.packages: 329 globals.pop("flat_graph", None) 330 self.global_objs.update(**self._validate_global_objs(globals)) 331 332 def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]: 333 """Builds a Python callable for a macro with the given reference. 334 335 Args: 336 reference: The macro reference. 337 Returns: 338 The macro as a Python callable or None if not found. 
339 """ 340 env: Environment = self.build_environment(**kwargs) 341 if reference.package is not None: 342 package = env.globals.get(reference.package, {}) 343 return package.get(reference.name) # type: ignore 344 return env.globals.get(reference.name) # type: ignore 345 346 def build_environment(self, **kwargs: t.Any) -> Environment: 347 """Builds a new Jinja environment based on this registry.""" 348 349 context: t.Dict[str, t.Any] = {} 350 351 root_macros = { 352 name: self._MacroWrapper(name, None, self, context) 353 for name, macro in self.root_macros.items() 354 } 355 356 package_macros: t.Dict[str, t.Any] = defaultdict(AttributeDict) 357 for package_name, macros in self.packages.items(): 358 for macro_name, macro in macros.items(): 359 macro_wrapper = self._MacroWrapper(macro_name, package_name, self, context) 360 package_macros[package_name][macro_name] = macro_wrapper 361 if macro.is_top_level and macro_name not in root_macros: 362 root_macros[macro_name] = macro_wrapper 363 364 top_level_packages = self.top_level_packages.copy() 365 366 if self.root_package_name is not None: 367 package_macros[self.root_package_name].update(root_macros) 368 top_level_packages.append(self.root_package_name) 369 370 env = environment() 371 372 builtin_globals = self._create_builtin_globals(kwargs) 373 for top_level_package_name in top_level_packages: 374 # Make sure that the top-level package doesn't fully override the same builtin package. 
375 package_macros[top_level_package_name] = AttributeDict( 376 { 377 **(builtin_globals.pop(top_level_package_name, None) or {}), 378 **(package_macros.get(top_level_package_name) or {}), 379 } 380 ) 381 root_macros.update(package_macros[top_level_package_name]) 382 383 context.update(builtin_globals) 384 context.update(root_macros) 385 context.update(package_macros) 386 context["render"] = lambda input: env.from_string(input).render() 387 388 env.globals.update(context) 389 env.filters.update(self._environment.filters) 390 return env 391 392 def trim( 393 self, dependencies: t.Iterable[MacroReference], package: t.Optional[str] = None 394 ) -> JinjaMacroRegistry: 395 """Trims the registry by keeping only macros with given references and their transitive dependencies. 396 397 Args: 398 dependencies: References to macros that should be kept. 399 package: The name of the package in the context of which the trimming should be performed. 400 401 Returns: 402 A new trimmed registry. 403 """ 404 dependencies_by_package: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set) 405 for dep in dependencies: 406 dependencies_by_package[dep.package or package].add(dep.name) 407 408 top_level_packages = self.top_level_packages.copy() 409 if package is not None: 410 top_level_packages.append(package) 411 412 result = JinjaMacroRegistry( 413 global_objs=self.global_objs.copy(), 414 create_builtins_module=self.create_builtins_module, 415 root_package_name=self.root_package_name, 416 top_level_packages=top_level_packages, 417 ) 418 for package, names in dependencies_by_package.items(): 419 result = result.merge(self._trim_macros(names, package)) 420 421 result._trimmed = True 422 423 return result 424 425 def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry: 426 """Returns a copy of the registry which contains macros from both this and `other` instances. 427 428 Args: 429 other: The other registry instance. 430 431 Returns: 432 A new merged registry. 
433 """ 434 435 root_macros = { 436 **self.root_macros, 437 **other.root_macros, 438 } 439 440 packages = {} 441 for package in {*self.packages, *other.packages}: 442 packages[package] = { 443 **self.packages.get(package, {}), 444 **other.packages.get(package, {}), 445 } 446 447 global_objs = { 448 **self.global_objs, 449 **other.global_objs, 450 } 451 452 return JinjaMacroRegistry( 453 packages=packages, 454 root_macros=root_macros, 455 global_objs=global_objs, 456 create_builtins_module=self.create_builtins_module or other.create_builtins_module, 457 root_package_name=self.root_package_name or other.root_package_name, 458 top_level_packages=[*self.top_level_packages, *other.top_level_packages], 459 ) 460 461 def to_expressions(self) -> t.List[Expression]: 462 output: t.List[Expression] = [] 463 464 filtered_objs = { 465 k: v for k, v in self.global_objs.items() if k in ("refs", "sources", "vars") 466 } 467 if filtered_objs: 468 output.append( 469 d.PythonCode( 470 expressions=[ 471 f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" 472 for k, v in sort_dict_recursive(filtered_objs).items() 473 ] 474 ) 475 ) 476 477 for macro_name, macro_info in sorted(self.root_macros.items()): 478 output.append(d.jinja_statement(macro_info.definition)) 479 480 for _, package in sorted(self.packages.items()): 481 for macro_name, macro_info in sorted(package.items()): 482 output.append(d.jinja_statement(macro_info.definition)) 483 484 return output 485 486 @property 487 def data_hash_values(self) -> t.List[str]: 488 data = [] 489 490 for macro_name, macro in sorted(self.root_macros.items()): 491 data.append(macro_name) 492 data.append(macro.definition) 493 494 for _, package in sorted(self.packages.items()): 495 for macro_name, macro in sorted(package.items()): 496 data.append(macro_name) 497 data.append(macro.definition) 498 499 trimmed_global_objs = { 500 k: self.global_objs[k] for k in ("refs", "sources", "vars") if k in self.global_objs 501 } 502 
data.append(json.dumps(trimmed_global_objs, sort_keys=True)) 503 504 return data 505 506 def __deepcopy__(self, memo: t.Optional[t.Dict[int, t.Any]] = None) -> JinjaMacroRegistry: 507 return JinjaMacroRegistry.parse_obj(self.dict()) 508 509 def _parse_macro(self, name: str, package: t.Optional[str]) -> Template: 510 cache_key = (package, name) 511 if cache_key not in self._parser_cache: 512 macro = self._get_macro(name, package) 513 514 definition: nodes.Template = self._environment.parse(macro.definition) 515 if _is_private_macro(name): 516 # A workaround to expose private jinja macros. 517 definition = self._to_non_private_macro_def(name, definition) 518 519 self._parser_cache[cache_key] = self._environment.from_string(definition) 520 return self._parser_cache[cache_key] 521 522 @property 523 def _environment(self) -> Environment: 524 if self.__environment is None: 525 self.__environment = environment() 526 self.__environment.filters.update(self._create_builtin_filters()) 527 return self.__environment 528 529 def _trim_macros( 530 self, 531 names: t.Set[str], 532 package: t.Optional[str] = None, 533 visited: t.Optional[t.Dict[t.Optional[str], t.Set[str]]] = None, 534 ) -> JinjaMacroRegistry: 535 if visited is None: 536 visited = defaultdict(set) 537 538 macros = self.packages.get(package, {}) if package is not None else self.root_macros 539 trimmed_macros = {} 540 541 dependencies: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set) 542 543 for name in names: 544 if name in macros and name not in visited[package]: 545 macro = macros[name] 546 trimmed_macros[name] = macro 547 for dependency in macro.depends_on: 548 dependencies[dependency.package or package].add(dependency.name) 549 visited[package].add(name) 550 551 if package is not None: 552 result = JinjaMacroRegistry(packages={package: trimmed_macros}) 553 else: 554 result = JinjaMacroRegistry(root_macros=trimmed_macros) 555 556 for upstream_package, upstream_names in dependencies.items(): 557 result = 
result.merge( 558 self._trim_macros(upstream_names, upstream_package, visited=visited) 559 ) 560 561 return result 562 563 def _macro_exists(self, name: str, package: t.Optional[str]) -> bool: 564 return ( 565 name in self.packages.get(package, {}) 566 if package is not None 567 else name in self.root_macros 568 ) 569 570 def _get_macro(self, name: str, package: t.Optional[str]) -> MacroInfo: 571 return self.packages[package][name] if package is not None else self.root_macros[name] 572 573 def _to_non_private_macro_def(self, name: str, template: nodes.Template) -> nodes.Template: 574 for node in template.find_all((nodes.Macro, nodes.Call)): 575 if isinstance(node, nodes.Macro): 576 node.name = _non_private_name(name) 577 elif isinstance(node, nodes.Call) and isinstance(node.node, nodes.Name): 578 node.node.name = _non_private_name(name) 579 580 return template 581 582 def _create_builtin_globals(self, global_vars: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: 583 """Creates Jinja builtin globals using a factory function defined in the provided module.""" 584 engine_adapter = global_vars.pop("engine_adapter", None) 585 global_vars = {**self.global_objs, **global_vars} 586 if self.create_builtins_module is not None: 587 module = importlib.import_module(self.create_builtins_module) 588 if hasattr(module, "create_builtin_globals"): 589 return module.create_builtin_globals(self, global_vars, engine_adapter) 590 return global_vars 591 592 def _create_builtin_filters(self) -> t.Dict[str, t.Any]: 593 """Creates Jinja builtin filters using a factory function defined in the provided module.""" 594 if self.create_builtins_module is not None: 595 module = importlib.import_module(self.create_builtins_module) 596 if hasattr(module, "create_builtin_filters"): 597 return module.create_builtin_filters() 598 return {} 599 600 class _MacroWrapper: 601 def __init__( 602 self, 603 name: str, 604 package: t.Optional[str], 605 registry: JinjaMacroRegistry, 606 context: t.Dict[str, t.Any], 
607 ): 608 self.name = name 609 self.package = package 610 self.context = context 611 self.registry = registry 612 613 def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any: 614 context = self.context.copy() 615 if self.package is not None and self.package in context: 616 context.update(context[self.package]) 617 618 template = self.registry._parse_macro(self.name, self.package) 619 macro_callable = getattr( 620 template.make_module(vars=context), _non_private_name(self.name) 621 ) 622 try: 623 return macro_callable(*args, **kwargs) 624 except MacroReturnVal as ret: 625 return ret.value 626 627 628def _is_private_macro(name: str) -> bool: 629 return name.startswith("_") 630 631 632def _non_private_name(name: str) -> str: 633 return name.lstrip("_") 634 635 636JINJA_REGEX = re.compile(r"({{|{%)") 637 638 639def has_jinja(value: str) -> bool: 640 return JINJA_REGEX.search(value) is not None 641 642 643def jinja_call_arg_name(node: nodes.Node) -> str: 644 if isinstance(node, nodes.Const): 645 return node.value 646 return "" 647 648 649def create_var(variables: t.Dict[str, t.Any]) -> t.Callable: 650 def _var(var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: 651 value = variables.get(var_name.lower(), default) 652 if isinstance(value, SqlValue): 653 return value.sql 654 return value 655 656 return _var 657 658 659def create_builtin_globals( 660 jinja_macros: JinjaMacroRegistry, global_vars: t.Dict[str, t.Any], *args: t.Any, **kwargs: t.Any 661) -> t.Dict[str, t.Any]: 662 global_vars.pop(c.GATEWAY, None) 663 variables = global_vars.pop(c.SQLMESH_VARS, None) or {} 664 blueprint_variables = global_vars.pop(c.SQLMESH_BLUEPRINT_VARS, None) or {} 665 return { 666 **global_vars, 667 c.VAR: create_var(variables), 668 c.GATEWAY: lambda: variables.get(c.GATEWAY, None), 669 c.BLUEPRINT_VAR: create_var(blueprint_variables), 670 } 671 672 673def make_jinja_registry( 674 jinja_macros: JinjaMacroRegistry, package_name: str, jinja_references: 
t.Set[MacroReference] 675) -> JinjaMacroRegistry: 676 """ 677 Creates a Jinja macro registry for a specific package. 678 679 This function takes an existing Jinja macro registry and returns a new 680 registry that includes only the macros associated with the specified 681 package and trims the registry to include only the macros referenced 682 in the provided set of macro references. 683 684 Args: 685 jinja_macros: The original Jinja macro registry containing all macros. 686 package_name: The name of the package for which to create the registry. 687 jinja_references: A set of macro references to retain in the new registry. 688 689 Returns: 690 A new JinjaMacroRegistry containing only the macros for the specified 691 package and the referenced macros. 692 """ 693 694 jinja_registry = jinja_macros.copy() 695 jinja_registry.root_macros = jinja_registry.packages.get(package_name) or {} 696 jinja_registry = jinja_registry.trim(jinja_references) 697 698 return jinja_registry 699 700 701def extract_error_details(ex: Exception) -> str: 702 """Extracts a readable message from a Jinja2 error, to include missing name and macro.""" 703 704 error_details = "" 705 if isinstance(ex, UndefinedError): 706 if match := re.search(r"'(\w+)'", str(ex)): 707 error_details += f"\nUndefined macro/variable: '{match.group(1)}'" 708 try: 709 _, _, exc_traceback = exc_info() 710 for frame, _ in walk_tb(exc_traceback): 711 if frame.f_code.co_name == "_invoke": 712 macro = frame.f_locals.get("self") 713 if isinstance(macro, Macro): 714 error_details += f" in macro: '{macro.name}'\n" 715 break 716 except: 717 # to fall back to the generic error message if frame analysis fails 718 pass 719 return error_details or str(ex)
class MacroReference(PydanticModel, frozen=True):
    """Immutable pointer to a macro, optionally scoped to the package that owns it."""

    package: t.Optional[str] = None
    name: str

    @property
    def reference(self) -> str:
        """Return `package.name` for package-scoped macros and the bare `name` otherwise."""
        return self.name if self.package is None else f"{self.package}.{self.name}"

    def __str__(self) -> str:
        return self.reference
Usage documentation: Models
A base class for creating Pydantic models.
Attributes:
- __class_vars__: The names of the class variables defined on the model.
- __private_attributes__: Metadata about the private attributes of the model.
- __signature__: The synthesized `__init__` [`Signature`][inspect.Signature] of the model.
- __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__: The core schema of the model.
- __pydantic_custom_init__: Whether the model has a custom `__init__` function.
- __pydantic_decorators__: Metadata containing the decorators defined on the model.
  This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.
- __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
- __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__: The name of the post-init method for the model, if defined.
- __pydantic_root_model__: Whether the model is a [`RootModel`][pydantic.root_model.RootModel].
- __pydantic_serializer__: The `pydantic-core` `SchemaSerializer` used to dump instances of the model.
- __pydantic_validator__: The `pydantic-core` `SchemaValidator` used to validate instances of the model.
- __pydantic_fields__: A dictionary of field names and their corresponding [`FieldInfo`][pydantic.fields.FieldInfo] objects.
- __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [`ComputedFieldInfo`][pydantic.fields.ComputedFieldInfo] objects.
- __pydantic_extra__: A dictionary containing extra values, if [`extra`][pydantic.config.ConfigDict.extra] is set to `'allow'`.
- __pydantic_fields_set__: The names of fields explicitly set during instantiation.
- __pydantic_private__: Values of private attributes set on the model instance.
Configuration for the model; should be a dictionary conforming to [`ConfigDict`][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class MacroInfo(PydanticModel):
    """A macro's Jinja source together with the macros that source references."""

    # The macro's Jinja source text.
    definition: str
    # References to the macros called from `definition`.
    depends_on: t.List[MacroReference]
    # For package macros: when True the macro is also exposed under its bare name.
    is_top_level: bool = False
Class that holds a macro's definition, the macros it calls, and whether it is a top-level macro.
Configuration for the model; should be a dictionary conforming to [`ConfigDict`][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
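Exception raised inside a macro to return a value to its caller; the registry's macro wrapper catches it and returns its `value`. (Inherits from Exception, the common base class for all non-exit exceptions.)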
Inherited Members
- builtins.BaseException
- with_traceback
- args
class MacroExtractor(Parser):
    """Extracts `{% macro %}` definitions from a Jinja string.

    Uses SQLGlot's `Parser` machinery (`reset`, the token cursor and `_find_sql`) over the
    tokens produced by the dialect's tokenizer.
    """

    def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
        """Extract a dictionary of macro definitions from a jinja string.

        Args:
            jinja: The jinja string to extract from.
            dialect: The dialect of SQL.

        Returns:
            A dictionary of macro name to macro definition.
        """
        # Reset parser state and load the tokens of the new input.
        self.reset()
        self.sql = jinja
        self._tokens = Dialect.get_or_raise(dialect).tokenize(jinja)

        # guard for older sqlglot versions (before 30.0.3)
        if hasattr(self, "_tokens_size"):
            # keep the cached length in sync
            self._tokens_size = len(self._tokens)
        # Position the cursor on the first token.
        self._index = -1
        self._advance()

        macros: t.Dict[str, MacroInfo] = {}

        while self._curr:
            if self._curr.token_type == TokenType.BLOCK_START:
                # Remember where this block opens; if it is a `macro` tag, the definition's
                # source span starts here.
                macro_start = self._curr
            elif self._tag == "MACRO" and self._next:
                # The token after the `macro` keyword is the macro's name.
                name = self._next.text
                # Skip the rest of the `{% macro ... %}` header.
                while self._curr and self._curr.token_type != TokenType.BLOCK_END:
                    self._advance()

                # Skip the macro body up to the `endmacro` tag.
                while self._curr and self._tag != "ENDMACRO":
                    self._advance()

                # The span ends at the token after `endmacro` (presumably the closing `%}`).
                # NOTE(review): if `endmacro` is missing, `self._next` is likely None here.
                macro_str = self._find_sql(macro_start, self._next)
                macros[name] = MacroInfo(
                    definition=macro_str,
                    depends_on=list(extract_macro_references_and_variables(macro_str)[0]),
                )

            self._advance()

        return macros

    def _advance(self, times: int = 1) -> None:
        """Advances the cursor and records the upper-cased tag of the current block.

        `self._tag` is the current token's text, upper-cased, when the previous token opened a
        block (`BLOCK_START`), and an empty string otherwise.
        """
        super()._advance(times)
        self._tag = (
            self._curr.text.upper()
            if self._curr and self._prev and self._prev.token_type == TokenType.BLOCK_START
            else ""
        )
Parser consumes a list of tokens produced by the Tokenizer and produces a parsed syntax tree.
Arguments:
- error_level: The desired error level. Default: ErrorLevel.IMMEDIATE
- error_message_context: The amount of context to capture from a query string when displaying the error message (in number of characters). Default: 100
- max_errors: Maximum number of error messages to include in a raised ParseError. This is only relevant if error_level is ErrorLevel.RAISE. Default: 3
def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
    """Extract a dictionary of macro definitions from a jinja string.

    Args:
        jinja: The jinja string to extract from.
        dialect: The dialect of SQL.

    Returns:
        A dictionary of macro name to macro definition.
    """
    # Reset parser state and load the tokens of the new input.
    self.reset()
    self.sql = jinja
    self._tokens = Dialect.get_or_raise(dialect).tokenize(jinja)

    # guard for older sqlglot versions (before 30.0.3)
    if hasattr(self, "_tokens_size"):
        # keep the cached length in sync
        self._tokens_size = len(self._tokens)
    # Position the cursor on the first token.
    self._index = -1
    self._advance()

    macros: t.Dict[str, MacroInfo] = {}

    while self._curr:
        if self._curr.token_type == TokenType.BLOCK_START:
            # Remember where this block opens; if it is a `macro` tag, the definition's
            # source span starts here.
            macro_start = self._curr
        elif self._tag == "MACRO" and self._next:
            # The token after the `macro` keyword is the macro's name.
            name = self._next.text
            # Skip the rest of the `{% macro ... %}` header.
            while self._curr and self._curr.token_type != TokenType.BLOCK_END:
                self._advance()

            # Skip the macro body up to the `endmacro` tag (`_tag` is maintained by `_advance`).
            while self._curr and self._tag != "ENDMACRO":
                self._advance()

            # The span ends at the token after `endmacro` (presumably the closing `%}`).
            # NOTE(review): if `endmacro` is missing, `self._next` is likely None here.
            macro_str = self._find_sql(macro_start, self._next)
            macros[name] = MacroInfo(
                definition=macro_str,
                depends_on=list(extract_macro_references_and_variables(macro_str)[0]),
            )

        self._advance()

    return macros
Extract a dictionary of macro definitions from a jinja string.
Arguments:
- jinja: The jinja string to extract from.
- dialect: The dialect of SQL.
Returns:
A dictionary of macro name to macro definition.
Inherited Members
- sqlglot.parser.Parser
- Parser
- FUNCTIONS
- NO_PAREN_FUNCTIONS
- STRUCT_TYPE_TOKENS
- NESTED_TYPE_TOKENS
- ENUM_TYPE_TOKENS
- AGGREGATE_TYPE_TOKENS
- TYPE_TOKENS
- SIGNED_TO_UNSIGNED_TYPE_TOKEN
- SUBQUERY_PREDICATES
- SUBQUERY_TOKENS
- RESERVED_TOKENS
- DB_CREATABLES
- CREATABLES
- TRIGGER_EVENTS
- ALTERABLES
- ID_VAR_TOKENS
- TABLE_ALIAS_TOKENS
- ALIAS_TOKENS
- COLON_PLACEHOLDER_TOKENS
- ARRAY_CONSTRUCTORS
- COMMENT_TABLE_ALIAS_TOKENS
- UPDATE_ALIAS_TOKENS
- TRIM_TYPES
- IDENTIFIER_TOKENS
- BRACKETS
- COLUMN_POSTFIX_TOKENS
- TABLE_POSTFIX_TOKENS
- FUNC_TOKENS
- CONJUNCTION
- ASSIGNMENT
- DISJUNCTION
- EQUALITY
- COMPARISON
- BITWISE
- TERM
- FACTOR
- EXPONENT
- TIMES
- TIMESTAMPS
- SET_OPERATIONS
- JOIN_METHODS
- JOIN_SIDES
- JOIN_KINDS
- JOIN_HINTS
- TABLE_TERMINATORS
- LAMBDAS
- TYPED_LAMBDA_ARGS
- LAMBDA_ARG_TERMINATORS
- COLUMN_OPERATORS
- CAST_COLUMN_OPERATORS
- EXPRESSION_PARSERS
- STATEMENT_PARSERS
- UNARY_PARSERS
- STRING_PARSERS
- NUMERIC_PARSERS
- PRIMARY_PARSERS
- PLACEHOLDER_PARSERS
- RANGE_PARSERS
- PIPE_SYNTAX_TRANSFORM_PARSERS
- PROPERTY_PARSERS
- CONSTRAINT_PARSERS
- ALTER_PARSERS
- ALTER_ALTER_PARSERS
- SCHEMA_UNNAMED_CONSTRAINTS
- NO_PAREN_FUNCTION_PARSERS
- INVALID_FUNC_NAME_TOKENS
- FUNCTIONS_WITH_ALIASED_ARGS
- KEY_VALUE_DEFINITIONS
- FUNCTION_PARSERS
- QUERY_MODIFIER_PARSERS
- QUERY_MODIFIER_TOKENS
- SET_PARSERS
- SHOW_PARSERS
- TYPE_LITERAL_PARSERS
- TYPE_CONVERTERS
- DDL_SELECT_TOKENS
- PRE_VOLATILE_TOKENS
- TRANSACTION_KIND
- TRANSACTION_CHARACTERISTICS
- CONFLICT_ACTIONS
- TRIGGER_TIMING
- TRIGGER_DEFERRABLE
- CREATE_SEQUENCE
- ISOLATED_LOADING_OPTIONS
- USABLES
- CAST_ACTIONS
- SCHEMA_BINDING_OPTIONS
- PROCEDURE_OPTIONS
- EXECUTE_AS_OPTIONS
- KEY_CONSTRAINT_OPTIONS
- WINDOW_EXCLUDE_OPTIONS
- INSERT_ALTERNATIVES
- CLONE_KEYWORDS
- HISTORICAL_DATA_PREFIX
- HISTORICAL_DATA_KIND
- OPCLASS_FOLLOW_KEYWORDS
- OPTYPE_FOLLOW_TOKENS
- TABLE_INDEX_HINT_TOKENS
- VIEW_ATTRIBUTES
- WINDOW_ALIAS_TOKENS
- WINDOW_BEFORE_PAREN_TOKENS
- WINDOW_SIDES
- JSON_KEY_VALUE_SEPARATOR_TOKENS
- FETCH_TOKENS
- ADD_CONSTRAINT_TOKENS
- DISTINCT_TOKENS
- UNNEST_OFFSET_ALIAS_TOKENS
- SELECT_START_TOKENS
- COPY_INTO_VARLEN_OPTIONS
- IS_JSON_PREDICATE_KIND
- ODBC_DATETIME_LITERALS
- ON_CONDITION_TOKENS
- PRIVILEGE_FOLLOW_TOKENS
- DESCRIBE_STYLES
- SET_ASSIGNMENT_DELIMITERS
- ANALYZE_STYLES
- ANALYZE_EXPRESSION_PARSERS
- PARTITION_KEYWORDS
- AMBIGUOUS_ALIAS_TOKENS
- OPERATION_MODIFIERS
- RECURSIVE_CTE_SEARCH_KIND
- SECURITY_PROPERTY_KEYWORDS
- MODIFIABLES
- STRICT_CAST
- PREFIXED_PIVOT_COLUMNS
- IDENTIFY_PIVOT_STRINGS
- LOG_DEFAULTS_TO_LN
- TABLESAMPLE_CSV
- DEFAULT_SAMPLING_METHOD
- SET_REQUIRES_ASSIGNMENT_DELIMITER
- TRIM_PATTERN_FIRST
- STRING_ALIASES
- MODIFIERS_ATTACHED_TO_SET_OP
- SET_OP_MODIFIERS
- NO_PAREN_IF_COMMANDS
- JSON_ARROWS_REQUIRE_JSON_TYPE
- COLON_IS_VARIANT_EXTRACT
- VALUES_FOLLOWED_BY_PAREN
- SUPPORTS_IMPLICIT_UNNEST
- INTERVAL_SPANS
- SUPPORTS_PARTITION_SELECTION
- WRAPPED_TRANSFORM_COLUMN_CONSTRAINT
- OPTIONAL_ALIAS_TOKEN_CTE
- ALTER_RENAME_REQUIRES_COLUMN
- ALTER_TABLE_PARTITIONS
- JOINS_HAVE_EQUAL_PRECEDENCE
- ZONE_AWARE_TIMESTAMP_CONSTRUCTOR
- MAP_KEYS_ARE_ARBITRARY_EXPRESSIONS
- JSON_EXTRACT_REQUIRES_JSON_EXPRESSION
- ADD_JOIN_ON_TRUE
- SUPPORTS_OMITTED_INTERVAL_SPAN_UNIT
- SHOW_TRIE
- SET_TRIE
- error_level
- error_message_context
- max_errors
- dialect
- sql
- errors
- reset
- raise_error
- validate_expression
- parse
- parse_into
- check_errors
- expression
- parse_set_operation
- build_cast
def call_name(node: nodes.Expr) -> t.Tuple[str, ...]:
    """Flattens a Jinja call/attribute expression into a tuple of name parts.

    `foo.bar.baz(...)` yields `("foo", "bar", "baz")`. Subscripts and call wrappers are
    looked through. A string-constant root is returned quoted (e.g. `"'abc'"`). Any other
    kind of root expression contributes no part, so only the attribute names remain.
    """
    attribute_names: t.List[str] = []
    current: t.Any = node
    # Walk down towards the root, collecting attribute names from the outside in.
    while True:
        if isinstance(current, nodes.Getattr):
            attribute_names.append(current.attr)
            current = current.node
        elif isinstance(current, (nodes.Getitem, nodes.Call)):
            current = current.node
        else:
            break

    if isinstance(current, nodes.Name):
        root: t.Tuple[str, ...] = (current.name,)
    elif isinstance(current, nodes.Const):
        root = (f"'{current.value}'",)
    else:
        root = ()
    # The attributes were collected outermost-first, so reverse them to get source order.
    return root + tuple(reversed(attribute_names))
def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[CallNames]:
    """Yields `(name_parts, node)` for call and attribute references found below `node`.

    The children of `node` are visited recursively. A `Call` node, or a `Getattr` node
    whose object is not itself a `Getattr`, is reported unless either of these holds:

    - its name cannot be determined;
    - its first name part is a quoted string constant or a name bound in the current scope.

    Names bound by assignment targets join the scope only after the right-hand side has been
    visited. Macro arguments are in scope only inside the macro itself.

    Args:
        node: The Jinja AST node whose descendants are inspected.
        vars_in_scope: Names already bound in the enclosing scope. This set is not modified.
    """
    vars_in_scope = vars_in_scope.copy()
    for child_node in node.iter_child_nodes():
        if "target" in child_node.fields:
            # For nodes with assignment targets (Assign, AssignBlock, For, Import),
            # the target name could shadow a reference in the right hand side.
            # So we need to process the RHS before adding the target to scope.
            # For example: {% set model = model.path %} should track model.path.
            # NOTE(review): for `For` nodes this also visits the loop body before the loop
            # variable is bound, so `{{ col.name }}` inside the loop is reported. Confirm
            # whether that is intended.
            yield from find_call_names(child_node, vars_in_scope)

            target = getattr(child_node, "target")
            if isinstance(target, nodes.Name):
                vars_in_scope.add(target.name)
            elif isinstance(target, nodes.Tuple):
                vars_in_scope.update(
                    item.name for item in target.items if isinstance(item, nodes.Name)
                )
        elif isinstance(child_node, nodes.Macro):
            # Macro arguments are local to the macro. Bind them in a separate scope so they
            # do not hide references made by sibling nodes after the macro definition.
            macro_scope = vars_in_scope | {arg.name for arg in child_node.args}
            yield from find_call_names(child_node, macro_scope)
        else:
            if isinstance(child_node, nodes.Call) or (
                isinstance(child_node, nodes.Getattr)
                and not isinstance(child_node.node, nodes.Getattr)
            ):
                name = call_name(child_node)
                # `call_name` returns () when the callee has no name, e.g. `{{ (a ~ b)() }}`.
                if name and name[0][0] != "'" and name[0] not in vars_in_scope:
                    yield (name, child_node)

            yield from find_call_names(child_node, vars_in_scope)
def extract_call_names(
    jinja_str: str, cache: t.Optional[t.Dict[str, t.Tuple[t.List[CallNames], bool]]] = None
) -> t.List[CallNames]:
    """Returns the call names referenced by a Jinja string, optionally memoized in `cache`.

    Args:
        jinja_str: The Jinja source to parse.
        cache: Optional memo keyed by a CRC32 of `jinja_str`. Each value is
            `(names, flag)`, and every access stores the flag as `True`.

    Returns:
        The list of `(name_parts, node)` pairs produced by `find_call_names`.
    """

    def parse() -> t.List[CallNames]:
        return list(find_call_names(ENVIRONMENT.parse(jinja_str), set()))

    if cache is None:
        return parse()

    # NOTE(review): CRC32 is only 32 bits wide, so distinct strings can collide and share an
    # entry. The key format is kept as-is to stay compatible with existing caches.
    key = str(zlib.crc32(jinja_str.encode("utf-8")))
    cached = cache.get(key)
    names = cached[0] if cached is not None else parse()
    # Refresh the flag on hits as well as misses. Presumably it marks entries used in this
    # run, so hits must be marked too; verify against the cache's owner.
    cache[key] = (names, True)
    return names
def extract_macro_references_and_variables(
    *jinja_strs: str,
) -> t.Tuple[t.Set[MacroReference], t.Set[str]]:
    """Collects macro references and variable names used by the given Jinja strings.

    Each kind of call is recorded as follows:

    - `c.VAR` / `c.BLUEPRINT_VAR` calls record their first argument, lower-cased, as a variable.
    - A `c.GATEWAY` call records `c.GATEWAY` itself.
    - Other names with one or two parts become `MacroReference`s (`name` or `package.name`).

    Raises:
        ValueError: If a variable call does not contain a recognizable variable node.
    """
    macro_references: t.Set[MacroReference] = set()
    variables: t.Set[str] = set()

    for jinja_str in jinja_strs:
        for parts, node in extract_call_names(jinja_str):
            head = parts[0]
            if head in (c.VAR, c.BLUEPRINT_VAR):
                if not is_variable_node(node):
                    # The variable call may be nested inside the matched node.
                    nested = next(
                        (n for n in node.find_all(nodes.Call) if is_variable_node(n)), None
                    )
                    if nested is None:
                        raise ValueError(f"Could not find variable name in {jinja_str}")
                    node = nested
                var_call = t.cast(nodes.Call, node)
                arg_names = [jinja_call_arg_name(arg) for arg in var_call.args]
                if arg_names and arg_names[0]:
                    variables.add(arg_names[0].lower())
            elif head == c.GATEWAY:
                variables.add(c.GATEWAY)
            elif len(parts) == 1:
                macro_references.add(MacroReference(name=head))
            elif len(parts) == 2:
                macro_references.add(MacroReference(package=head, name=parts[1]))

    return macro_references, variables
def sort_dict_recursive(
    item: t.Dict[str, t.Any],
) -> t.Dict[str, t.Any]:
    """Returns a copy of `item` with keys sorted at every nesting level.

    Nested dicts are sorted recursively. List values are sorted when their elements are
    mutually orderable. Lists that cannot be ordered, such as lists of dicts or of mixed
    types, keep their original order, and their dict elements are key-sorted recursively.

    Args:
        item: The mapping to sort. It is not modified.

    Returns:
        A new dict with deterministically ordered contents.
    """
    sorted_dict: t.Dict[str, t.Any] = {}
    for key, value in sorted(item.items(), key=lambda kv: kv[0]):
        if isinstance(value, list):
            try:
                sorted_dict[key] = sorted(value)
            except TypeError:
                # Elements are not comparable (e.g. dicts). The original list order is still
                # deterministic, so keep it and only normalize the nested dicts.
                sorted_dict[key] = [
                    sort_dict_recursive(element) if isinstance(element, dict) else element
                    for element in value
                ]
        elif isinstance(value, dict):
            sorted_dict[key] = sort_dict_recursive(value)
        else:
            sorted_dict[key] = value
    return sorted_dict
class JinjaMacroRegistry(PydanticModel):
    """Registry for Jinja macros.

    Args:
        packages: The mapping from package name to a collection of macro definitions.
        root_macros: The collection of top-level macro definitions.
        global_objs: The global objects.
        create_builtins_module: The name of a module that may define the
            `create_builtin_globals` and `create_builtin_filters` factory functions. These are
            used to construct builtin variables, functions and filters.
        root_package_name: The name of the root package. If specified, root macros will be
            available as both `root_package_name.macro_name` and `macro_name`.
        top_level_packages: The list of top-level packages. Macros in these packages will be
            available as both `package_name.macro_name` and `macro_name`.
    """

    packages: t.Dict[str, t.Dict[str, MacroInfo]] = {}
    root_macros: t.Dict[str, MacroInfo] = {}
    global_objs: t.Dict[str, JinjaGlobalAttribute] = {}
    create_builtins_module: t.Optional[str] = SQLMESH_JINJA_PACKAGE
    root_package_name: t.Optional[str] = None
    top_level_packages: t.List[str] = []

    # Compiled macro templates keyed by (package, macro name); filled by `_parse_macro`.
    _parser_cache: t.Dict[t.Tuple[t.Optional[str], str], Template] = {}
    _trimmed: bool = False
    # Lazily created environment used to parse macros; see the `_environment` property.
    __environment: t.Optional[Environment] = None

    def __getstate__(self) -> t.Dict[t.Any, t.Any]:
        state = super().__getstate__()
        # Drop the caches from the pickled state; they are rebuilt lazily when needed.
        private = state[PRIVATE_FIELDS]
        private["_parser_cache"] = {}
        private["_JinjaMacroRegistry__environment"] = None
        return state

    @field_validator("global_objs", mode="before")
    @classmethod
    def _validate_global_objs(cls, value: t.Any) -> t.Any:
        """Normalizes global objects for use from Jinja templates.

        Dicts become `AttributeDict`s (recursively), sets become sorted lists, and enum
        members are replaced by their values.
        """

        def _normalize(val: t.Any) -> t.Any:
            if isinstance(val, dict):
                return AttributeDict({k: _normalize(v) for k, v in val.items()})
            if isinstance(val, list):
                return [_normalize(v) for v in val]
            if isinstance(val, set):
                return [_normalize(v) for v in sorted(val)]
            if isinstance(val, Enum):
                return val.value
            return val

        return _normalize(value)

    @field_serializer("global_objs")
    def _serialize_attribute_dict(
        self, value: t.Dict[str, JinjaGlobalAttribute]
    ) -> t.Dict[str, t.Any]:
        # NOTE: This is called only when used with Pydantic V2.
        # Nested `AttributeDict` values are converted back into plain dicts.
        def _convert(
            val: t.Union[t.Dict[str, JinjaGlobalAttribute], t.Dict[str, t.Any]],
        ) -> t.Dict[str, t.Any]:
            return {k: _convert(v) if isinstance(v, AttributeDict) else v for k, v in val.items()}

        return _convert(value)

    @property
    def trimmed(self) -> bool:
        """True for registries returned by `trim`."""
        return self._trimmed

    def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None:
        """Adds macros to the target package.

        Args:
            macros: Macros that should be added.
            package: The name of the package the given macros belong to. If not specified, the
                provided macros will be added to the root namespace.
        """

        if package is not None:
            package_macros = self.packages.get(package, {})
            package_macros.update(macros)
            self.packages[package] = package_macros
        else:
            self.root_macros.update(macros)

    def add_globals(self, globals: t.Dict[str, JinjaGlobalAttribute]) -> None:
        """Adds global objects to the registry.

        Args:
            globals: The global objects that should be added. The given mapping is not modified.
        """
        # Keep the registry lightweight when the graph is not needed. Filter into a new dict
        # rather than popping, so the caller's mapping is left intact.
        if "graph" not in self.packages:
            globals = {key: value for key, value in globals.items() if key != "flat_graph"}
        self.global_objs.update(**self._validate_global_objs(globals))

    def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]:
        """Builds a Python callable for a macro with the given reference.

        Args:
            reference: The macro reference.
            **kwargs: Extra globals forwarded to `build_environment`.

        Returns:
            The macro as a Python callable or None if not found.
        """
        env: Environment = self.build_environment(**kwargs)
        if reference.package is not None:
            package = env.globals.get(reference.package, {})
            return package.get(reference.name)  # type: ignore
        return env.globals.get(reference.name)  # type: ignore

    def build_environment(self, **kwargs: t.Any) -> Environment:
        """Builds a new Jinja environment based on this registry.

        Args:
            **kwargs: Extra globals passed to the builtin globals factory. An `engine_adapter`
                entry is removed and handed to the factory separately.

        Returns:
            An environment whose globals expose root macros by name, package macros as
            `package.macro`, the builtin globals, and a `render` helper.
        """

        # Shared by every macro wrapper and filled in below once all globals are known, so
        # that macros see each other (and the builtins) when they are invoked.
        context: t.Dict[str, t.Any] = {}

        root_macros = {
            name: self._MacroWrapper(name, None, self, context) for name in self.root_macros
        }

        package_macros: t.Dict[str, t.Any] = defaultdict(AttributeDict)
        for package_name, macros in self.packages.items():
            for macro_name, macro in macros.items():
                macro_wrapper = self._MacroWrapper(macro_name, package_name, self, context)
                package_macros[package_name][macro_name] = macro_wrapper
                if macro.is_top_level and macro_name not in root_macros:
                    root_macros[macro_name] = macro_wrapper

        top_level_packages = self.top_level_packages.copy()

        if self.root_package_name is not None:
            package_macros[self.root_package_name].update(root_macros)
            top_level_packages.append(self.root_package_name)

        env = environment()

        builtin_globals = self._create_builtin_globals(kwargs)
        for top_level_package_name in top_level_packages:
            # Make sure that the top-level package doesn't fully override the same builtin package.
            package_macros[top_level_package_name] = AttributeDict(
                {
                    **(builtin_globals.pop(top_level_package_name, None) or {}),
                    **(package_macros.get(top_level_package_name) or {}),
                }
            )
            root_macros.update(package_macros[top_level_package_name])

        context.update(builtin_globals)
        context.update(root_macros)
        context.update(package_macros)
        context["render"] = lambda source: env.from_string(source).render()

        env.globals.update(context)
        env.filters.update(self._environment.filters)
        return env

    def trim(
        self, dependencies: t.Iterable[MacroReference], package: t.Optional[str] = None
    ) -> JinjaMacroRegistry:
        """Trims the registry by keeping only macros with given references and their transitive dependencies.

        Args:
            dependencies: References to macros that should be kept.
            package: The name of the package in the context of which the trimming should be performed.

        Returns:
            A new trimmed registry.
        """
        # References without an explicit package resolve against `package`.
        dependencies_by_package: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)
        for dep in dependencies:
            dependencies_by_package[dep.package or package].add(dep.name)

        top_level_packages = self.top_level_packages.copy()
        if package is not None:
            top_level_packages.append(package)

        result = JinjaMacroRegistry(
            global_objs=self.global_objs.copy(),
            create_builtins_module=self.create_builtins_module,
            root_package_name=self.root_package_name,
            top_level_packages=top_level_packages,
        )
        for dep_package, names in dependencies_by_package.items():
            result = result.merge(self._trim_macros(names, dep_package))

        result._trimmed = True

        return result

    def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry:
        """Returns a copy of the registry which contains macros from both this and `other` instances.

        Entries from `other` win on conflicting names.

        Args:
            other: The other registry instance.

        Returns:
            A new merged registry.
        """

        root_macros = {
            **self.root_macros,
            **other.root_macros,
        }

        packages = {}
        for package in {*self.packages, *other.packages}:
            packages[package] = {
                **self.packages.get(package, {}),
                **other.packages.get(package, {}),
            }

        global_objs = {
            **self.global_objs,
            **other.global_objs,
        }

        return JinjaMacroRegistry(
            packages=packages,
            root_macros=root_macros,
            global_objs=global_objs,
            create_builtins_module=self.create_builtins_module or other.create_builtins_module,
            root_package_name=self.root_package_name or other.root_package_name,
            top_level_packages=[*self.top_level_packages, *other.top_level_packages],
        )

    def to_expressions(self) -> t.List[Expression]:
        """Renders the registry as SQLMesh expressions.

        The output starts with a `PythonCode` block assigning the `refs`, `sources` and `vars`
        globals, if any are present. It is followed by one Jinja statement per macro
        definition: root macros first, then package macros, each sorted by name.
        """
        output: t.List[Expression] = []

        filtered_objs = {
            k: v for k, v in self.global_objs.items() if k in ("refs", "sources", "vars")
        }
        if filtered_objs:
            output.append(
                d.PythonCode(
                    expressions=[
                        f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}"
                        for k, v in sort_dict_recursive(filtered_objs).items()
                    ]
                )
            )

        for _, macro_info in sorted(self.root_macros.items()):
            output.append(d.jinja_statement(macro_info.definition))

        for _, package in sorted(self.packages.items()):
            for _, macro_info in sorted(package.items()):
                output.append(d.jinja_statement(macro_info.definition))

        return output

    @property
    def data_hash_values(self) -> t.List[str]:
        """Values that determine this registry's data hash.

        The list contains macro names and definitions (sorted) plus a JSON dump of the
        `refs`, `sources` and `vars` globals.
        """
        data = []

        for macro_name, macro in sorted(self.root_macros.items()):
            data.append(macro_name)
            data.append(macro.definition)

        for _, package in sorted(self.packages.items()):
            for macro_name, macro in sorted(package.items()):
                data.append(macro_name)
                data.append(macro.definition)

        trimmed_global_objs = {
            k: self.global_objs[k] for k in ("refs", "sources", "vars") if k in self.global_objs
        }
        data.append(json.dumps(trimmed_global_objs, sort_keys=True))

        return data

    def __deepcopy__(self, memo: t.Optional[t.Dict[int, t.Any]] = None) -> JinjaMacroRegistry:
        # Round-trips through the public fields; private caches start out fresh in the copy.
        return JinjaMacroRegistry.parse_obj(self.dict())

    def _parse_macro(self, name: str, package: t.Optional[str]) -> Template:
        """Returns the compiled template for a macro, compiling and caching it on first use."""
        cache_key = (package, name)
        if cache_key not in self._parser_cache:
            macro = self._get_macro(name, package)

            definition: nodes.Template = self._environment.parse(macro.definition)
            if _is_private_macro(name):
                # A workaround to expose private jinja macros.
                definition = self._to_non_private_macro_def(name, definition)

            self._parser_cache[cache_key] = self._environment.from_string(definition)
        return self._parser_cache[cache_key]

    @property
    def _environment(self) -> Environment:
        if self.__environment is None:
            self.__environment = environment()
            self.__environment.filters.update(self._create_builtin_filters())
        return self.__environment

    def _trim_macros(
        self,
        names: t.Set[str],
        package: t.Optional[str] = None,
        visited: t.Optional[t.Dict[t.Optional[str], t.Set[str]]] = None,
    ) -> JinjaMacroRegistry:
        """Returns a registry with `names` from `package` plus, transitively, their dependencies.

        `package=None` selects the root namespace. `visited` records macros already
        collected, so each macro is processed once and dependency cycles terminate.
        """
        if visited is None:
            visited = defaultdict(set)

        macros = self.packages.get(package, {}) if package is not None else self.root_macros
        trimmed_macros = {}

        dependencies: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)

        for name in names:
            if name in macros and name not in visited[package]:
                macro = macros[name]
                trimmed_macros[name] = macro
                for dependency in macro.depends_on:
                    dependencies[dependency.package or package].add(dependency.name)
                visited[package].add(name)

        if package is not None:
            result = JinjaMacroRegistry(packages={package: trimmed_macros})
        else:
            result = JinjaMacroRegistry(root_macros=trimmed_macros)

        for upstream_package, upstream_names in dependencies.items():
            result = result.merge(
                self._trim_macros(upstream_names, upstream_package, visited=visited)
            )

        return result

    def _macro_exists(self, name: str, package: t.Optional[str]) -> bool:
        return (
            name in self.packages.get(package, {})
            if package is not None
            else name in self.root_macros
        )

    def _get_macro(self, name: str, package: t.Optional[str]) -> MacroInfo:
        return self.packages[package][name] if package is not None else self.root_macros[name]

    def _to_non_private_macro_def(self, name: str, template: nodes.Template) -> nodes.Template:
        """Renames the private macro `name` and its self-references to the non-private name.

        Jinja treats names starting with an underscore as private and does not export them,
        so the macro definition is renamed before the template is compiled. Only calls to the
        renamed macro are rewritten. Other calls in the body, such as `return(...)` or
        `ref(...)`, keep their original targets.
        """
        public_name = _non_private_name(name)
        renamed: t.Set[str] = {name}
        for macro_node in template.find_all(nodes.Macro):
            renamed.add(macro_node.name)
            macro_node.name = public_name

        for call_node in template.find_all(nodes.Call):
            if isinstance(call_node.node, nodes.Name) and call_node.node.name in renamed:
                call_node.node.name = public_name

        return template

    def _create_builtin_globals(self, global_vars: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
        """Creates Jinja builtin globals using a factory function defined in the provided module."""
        engine_adapter = global_vars.pop("engine_adapter", None)
        global_vars = {**self.global_objs, **global_vars}
        if self.create_builtins_module is not None:
            module = importlib.import_module(self.create_builtins_module)
            if hasattr(module, "create_builtin_globals"):
                return module.create_builtin_globals(self, global_vars, engine_adapter)
        return global_vars

    def _create_builtin_filters(self) -> t.Dict[str, t.Any]:
        """Creates Jinja builtin filters using a factory function defined in the provided module."""
        if self.create_builtins_module is not None:
            module = importlib.import_module(self.create_builtins_module)
            if hasattr(module, "create_builtin_filters"):
                return module.create_builtin_filters()
        return {}

    class _MacroWrapper:
        """Callable proxy for a registry macro, compiled lazily on first invocation.

        The macro is evaluated against a copy of the shared `context`. For package macros,
        that package's own entries are layered on top.
        """

        def __init__(
            self,
            name: str,
            package: t.Optional[str],
            registry: JinjaMacroRegistry,
            context: t.Dict[str, t.Any],
        ):
            self.name = name
            self.package = package
            self.context = context
            self.registry = registry

        def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
            context = self.context.copy()
            if self.package is not None and self.package in context:
                context.update(context[self.package])

            template = self.registry._parse_macro(self.name, self.package)
            macro_callable = getattr(
                template.make_module(vars=context), _non_private_name(self.name)
            )
            try:
                return macro_callable(*args, **kwargs)
            except MacroReturnVal as ret:
                # A MacroReturnVal raised inside the macro carries its return value.
                return ret.value
Registry for Jinja macros.
Arguments:
- packages: The mapping from package name to a collection of macro definitions.
- root_macros: The collection of top-level macro definitions.
- global_objs: The global objects.
- create_builtins_module: The name of a module that may define the `create_builtin_globals` and `create_builtin_filters` factory functions, which are used to construct builtin variables, functions and filters.
- root_package_name: The name of the root package. If specified, root macros will be available as both `root_package_name.macro_name` and `macro_name`.
- top_level_packages: The list of top-level packages. Macros in these packages will be available as both `package_name.macro_name` and `macro_name`.
def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None:
    """Registers `macros` under `package`, or in the root namespace when no package is given.

    Existing macros with the same names are replaced.

    Args:
        macros: Macros that should be added.
        package: The name of the package the given macros belong to. If not specified, the
            provided macros will be added to the root namespace.
    """
    if package is None:
        self.root_macros.update(macros)
        return

    target = self.packages.get(package, {})
    target.update(macros)
    self.packages[package] = target
Adds macros to the target package.
Arguments:
- macros: Macros that should be added.
- package: The name of the package the given macros belong to. If not specified, the provided macros will be added to the root namespace.
def add_globals(self, globals: t.Dict[str, JinjaGlobalAttribute]) -> None:
    """Adds global objects to the registry.

    Args:
        globals: The global objects that should be added. The given mapping is not modified.
    """
    # Keep the registry lightweight when the graph is not needed: `flat_graph` is dropped
    # unless a "graph" package is registered. Filtering into a new dict, instead of popping,
    # leaves the caller's mapping intact.
    if "graph" not in self.packages:
        globals = {key: value for key, value in globals.items() if key != "flat_graph"}
    self.global_objs.update(**self._validate_global_objs(globals))
Adds global objects to the registry.
Arguments:
- globals: The global objects that should be added.
def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]:
    """Looks up the callable for `reference` in a freshly built environment.

    Args:
        reference: The macro reference. When it has a package, the macro is looked up
            inside that package's globals entry.
        **kwargs: Extra globals forwarded to `build_environment`.

    Returns:
        The macro as a Python callable or None if not found.
    """
    env = self.build_environment(**kwargs)
    if reference.package is None:
        namespace = env.globals
    else:
        namespace = env.globals.get(reference.package, {})
    return namespace.get(reference.name)  # type: ignore
Builds a Python callable for a macro with the given reference.
Arguments:
- reference: The macro reference.
Returns:
The macro as a Python callable or None if not found.
def build_environment(self, **kwargs: t.Any) -> Environment:
    """Creates a fresh Jinja environment that exposes this registry's macros and globals.

    Args:
        **kwargs: Extra globals passed to the builtin globals factory. An `engine_adapter`
            entry is removed and handed to the factory separately.

    Returns:
        An environment whose globals hold root macros by name, package macros as
        `package.macro`, the builtin globals, and a `render` helper.
    """
    # One dict shared by every wrapper; it is filled at the end, so macros invoked later
    # can see all other macros and the builtins.
    shared_context: t.Dict[str, t.Any] = {}

    root_level = {
        macro_name: self._MacroWrapper(macro_name, None, self, shared_context)
        for macro_name in self.root_macros
    }

    by_package: t.Dict[str, t.Any] = defaultdict(AttributeDict)
    for pkg_name, pkg_macros in self.packages.items():
        for macro_name, info in pkg_macros.items():
            wrapper = self._MacroWrapper(macro_name, pkg_name, self, shared_context)
            by_package[pkg_name][macro_name] = wrapper
            # Top-level package macros are also reachable unqualified, unless a root
            # macro already claims the name.
            if info.is_top_level and macro_name not in root_level:
                root_level[macro_name] = wrapper

    promoted_packages = list(self.top_level_packages)
    if self.root_package_name is not None:
        by_package[self.root_package_name].update(root_level)
        promoted_packages.append(self.root_package_name)

    env = environment()

    builtin_globals = self._create_builtin_globals(kwargs)
    for pkg_name in promoted_packages:
        # Layer the package's macros over a builtin entry of the same name instead of
        # replacing it wholesale.
        merged = AttributeDict(
            {
                **(builtin_globals.pop(pkg_name, None) or {}),
                **(by_package.get(pkg_name) or {}),
            }
        )
        by_package[pkg_name] = merged
        root_level.update(merged)

    shared_context.update(builtin_globals)
    shared_context.update(root_level)
    shared_context.update(by_package)
    shared_context["render"] = lambda source: env.from_string(source).render()

    env.globals.update(shared_context)
    env.filters.update(self._environment.filters)
    return env
Builds a new Jinja environment based on this registry.
def trim(
    self, dependencies: t.Iterable[MacroReference], package: t.Optional[str] = None
) -> JinjaMacroRegistry:
    """Builds a registry holding only the referenced macros and their transitive dependencies.

    Args:
        dependencies: References to macros that should be kept. References without a
            package resolve against `package`.
        package: The name of the package in the context of which the trimming should be
            performed. It is also added to the result's top-level packages.

    Returns:
        A new registry flagged as trimmed.
    """
    wanted: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)
    for dep in dependencies:
        wanted[dep.package or package].add(dep.name)

    top_level = [*self.top_level_packages]
    if package is not None:
        top_level.append(package)

    trimmed = JinjaMacroRegistry(
        global_objs=self.global_objs.copy(),
        create_builtins_module=self.create_builtins_module,
        root_package_name=self.root_package_name,
        top_level_packages=top_level,
    )
    for dep_package, names in wanted.items():
        trimmed = trimmed.merge(self._trim_macros(names, dep_package))

    trimmed._trimmed = True
    return trimmed
Trims the registry by keeping only macros with given references and their transitive dependencies.
Arguments:
- dependencies: References to macros that should be kept.
- package: The name of the package in the context of which the trimming should be performed.
Returns:
A new trimmed registry.
def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry:
    """Combines this registry with `other` into a new registry.

    Macros and globals from `other` take precedence on name clashes. For the optional
    settings, the value from this registry wins when it is set. Top-level package lists are
    concatenated.

    Args:
        other: The other registry instance.

    Returns:
        A new merged registry.
    """
    merged_root = {**self.root_macros, **other.root_macros}

    merged_packages = {
        pkg: {**self.packages.get(pkg, {}), **other.packages.get(pkg, {})}
        for pkg in {*self.packages, *other.packages}
    }

    merged_globals = {**self.global_objs, **other.global_objs}

    return JinjaMacroRegistry(
        packages=merged_packages,
        root_macros=merged_root,
        global_objs=merged_globals,
        create_builtins_module=self.create_builtins_module or other.create_builtins_module,
        root_package_name=self.root_package_name or other.root_package_name,
        top_level_packages=[*self.top_level_packages, *other.top_level_packages],
    )
Returns a copy of the registry which contains macros from both this instance and the `other` instance.
Arguments:
- other: The other registry instance.
Returns:
A new merged registry.
def to_expressions(self) -> t.List[Expression]:
    """Serializes the registry into SQLMesh expressions.

    If any `refs`, `sources` or `vars` globals exist, a `PythonCode` block with their
    assignments comes first. It is followed by one Jinja statement per macro definition:
    root macros first, then package macros, each sorted by name.
    """
    expressions: t.List[Expression] = []

    python_globals = {
        key: value
        for key, value in self.global_objs.items()
        if key in ("refs", "sources", "vars")
    }
    if python_globals:
        assignments = []
        for k, v in sort_dict_recursive(python_globals).items():
            assignments.append(f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}")
        expressions.append(d.PythonCode(expressions=assignments))

    definitions = [info.definition for _, info in sorted(self.root_macros.items())]
    for _, pkg_macros in sorted(self.packages.items()):
        definitions.extend(info.definition for _, info in sorted(pkg_macros.items()))

    expressions.extend(d.jinja_statement(definition) for definition in definitions)
    return expressions
@property
def data_hash_values(self) -> t.List[str]:
    """Values that feed into this registry's data hash.

    The list contains macro names and definitions (root macros first, then package macros,
    each sorted by name), followed by a key-sorted JSON dump of the `refs`, `sources` and
    `vars` globals.
    """
    values: t.List[str] = []

    for macro_name, macro in sorted(self.root_macros.items()):
        values.extend((macro_name, macro.definition))

    for _, pkg_macros in sorted(self.packages.items()):
        for macro_name, macro in sorted(pkg_macros.items()):
            values.extend((macro_name, macro.definition))

    hashed_globals = {
        key: self.global_objs[key]
        for key in ("refs", "sources", "vars")
        if key in self.global_objs
    }
    values.append(json.dumps(hashed_globals, sort_keys=True))
    return values
Configuration for the model; it should be a dictionary conforming to `pydantic.config.ConfigDict`.
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """Initialise private attributes, behaving like a BaseModel method.

    Nothing happens if `__pydantic_private__` is already set. Otherwise, it is set to a dict
    holding the default of every declared private attribute that has one. The `context`
    argument is accepted because pydantic-core passes it when calling this function.

    Args:
        self: The BaseModel instance.
        context: The context.
    """
    if getattr(self, '__pydantic_private__', None) is not None:
        return

    defaults = {}
    for attr_name, attr in self.__private_attributes__.items():
        value = attr.get_default()
        if value is not PydanticUndefined:
            defaults[attr_name] = value
    object_setattr(self, '__pydantic_private__', defaults)
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Arguments:
- self: The BaseModel instance.
- context: The context.
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
def create_builtin_globals(
    jinja_macros: JinjaMacroRegistry, global_vars: t.Dict[str, t.Any], *args: t.Any, **kwargs: t.Any
) -> t.Dict[str, t.Any]:
    """Builds the default builtin globals from `global_vars`.

    `global_vars` is modified in place: its gateway, SQLMesh-variables and
    blueprint-variables entries are removed. The result contains the remaining globals plus:

    - `c.VAR` and `c.BLUEPRINT_VAR` accessors built with `create_var`;
    - a `c.GATEWAY` function that reads the gateway from the SQLMesh variables.

    `jinja_macros`, `*args` and `**kwargs` are accepted but not used.
    """
    global_vars.pop(c.GATEWAY, None)
    sqlmesh_vars = global_vars.pop(c.SQLMESH_VARS, None) or {}
    blueprint_vars = global_vars.pop(c.SQLMESH_BLUEPRINT_VARS, None) or {}

    builtins = dict(global_vars)
    builtins[c.VAR] = create_var(sqlmesh_vars)
    builtins[c.GATEWAY] = lambda: sqlmesh_vars.get(c.GATEWAY, None)
    builtins[c.BLUEPRINT_VAR] = create_var(blueprint_vars)
    return builtins
def make_jinja_registry(
    jinja_macros: JinjaMacroRegistry, package_name: str, jinja_references: t.Set[MacroReference]
) -> JinjaMacroRegistry:
    """
    Creates a Jinja macro registry for a specific package.

    A copy of `jinja_macros` is made, and `package_name`'s macros become its root macros
    (an empty mapping if the package is unknown). The copy is then trimmed to the given
    references and their transitive dependencies. `jinja_macros` itself is not reassigned.

    Args:
        jinja_macros: The original Jinja macro registry containing all macros.
        package_name: The name of the package for which to create the registry.
        jinja_references: A set of macro references to retain in the new registry.

    Returns:
        A new JinjaMacroRegistry containing only the macros for the specified
        package and the referenced macros.
    """
    scoped = jinja_macros.copy()
    scoped.root_macros = scoped.packages.get(package_name) or {}
    return scoped.trim(jinja_references)
Creates a Jinja macro registry for a specific package.
This function takes an existing Jinja macro registry and returns a new registry that includes only the macros associated with the specified package and trims the registry to include only the macros referenced in the provided set of macro references.
Arguments:
- jinja_macros: The original Jinja macro registry containing all macros.
- package_name: The name of the package for which to create the registry.
- jinja_references: A set of macro references to retain in the new registry.
Returns:
A new JinjaMacroRegistry containing only the macros for the specified package and the referenced macros.
def extract_error_details(ex: Exception) -> str:
    """Extracts a readable message from a Jinja2 error, to include missing name and macro.

    For an `UndefinedError`, the first single-quoted identifier in the message is reported
    as the undefined macro/variable. When the traceback contains an `_invoke` frame of a
    Jinja `Macro`, that macro's name is reported as well. In every other case, or when
    nothing could be extracted, `str(ex)` is returned.

    Args:
        ex: The exception to describe.

    Returns:
        A human-readable description of the error.
    """

    error_details = ""
    if isinstance(ex, UndefinedError):
        if match := re.search(r"'(\w+)'", str(ex)):
            error_details += f"\nUndefined macro/variable: '{match.group(1)}'"
        try:
            # Use the traceback attached to `ex` itself, so the result does not depend on
            # which exception (if any) is currently being handled. Fall back to the active
            # one only when `ex` carries no traceback.
            exc_traceback = ex.__traceback__ or exc_info()[2]
            for frame, _ in walk_tb(exc_traceback):
                if frame.f_code.co_name == "_invoke":
                    macro = frame.f_locals.get("self")
                    if isinstance(macro, Macro):
                        error_details += f" in macro: '{macro.name}'\n"
                        break
        except Exception:
            # Frame analysis is best-effort: fall back to the generic error message, but
            # never swallow KeyboardInterrupt/SystemExit.
            pass
    return error_details or str(ex)
Extracts a readable message from a Jinja2 error, to include missing name and macro.