sqlmesh.utils.jinja
1from __future__ import annotations 2 3import importlib 4import json 5import re 6import typing as t 7import zlib 8from collections import defaultdict 9from enum import Enum 10from sys import exc_info 11from traceback import walk_tb 12 13from jinja2 import Environment, Template, nodes, UndefinedError 14from jinja2.runtime import Macro 15from sqlglot import Dialect, Expression, Parser, TokenType 16 17from sqlmesh.core import constants as c 18from sqlmesh.core import dialect as d 19from sqlmesh.utils import AttributeDict 20from sqlmesh.utils.pydantic import PRIVATE_FIELDS, PydanticModel, field_serializer, field_validator 21from sqlmesh.utils.metaprogramming import SqlValue 22 23 24if t.TYPE_CHECKING: 25 CallNames = t.Tuple[t.Tuple[str, ...], t.Union[nodes.Call, nodes.Getattr]] 26 27SQLMESH_JINJA_PACKAGE = "sqlmesh.utils.jinja" 28 29 30def environment(**kwargs: t.Any) -> Environment: 31 extensions = kwargs.pop("extensions", []) 32 extensions.append("jinja2.ext.do") 33 extensions.append("jinja2.ext.loopcontrols") 34 return Environment(extensions=extensions, **kwargs) 35 36 37ENVIRONMENT = environment() 38 39 40class MacroReference(PydanticModel, frozen=True): 41 package: t.Optional[str] = None 42 name: str 43 44 @property 45 def reference(self) -> str: 46 if self.package is None: 47 return self.name 48 return ".".join((self.package, self.name)) 49 50 def __str__(self) -> str: 51 return self.reference 52 53 54class MacroInfo(PydanticModel): 55 """Class to hold macro and its calls""" 56 57 definition: str 58 depends_on: t.List[MacroReference] 59 is_top_level: bool = False 60 61 62class MacroReturnVal(Exception): 63 def __init__(self, val: t.Any): 64 self.value = val 65 66 67class MacroExtractor(Parser): 68 def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]: 69 """Extract a dictionary of macro definitions from a jinja string. 70 71 Args: 72 jinja: The jinja string to extract from. 73 dialect: The dialect of SQL. 
74 75 Returns: 76 A dictionary of macro name to macro definition. 77 """ 78 self.reset() 79 self.sql = jinja 80 self._tokens = Dialect.get_or_raise(dialect).tokenize(jinja) 81 self._index = -1 82 self._advance() 83 84 macros: t.Dict[str, MacroInfo] = {} 85 86 while self._curr: 87 if self._curr.token_type == TokenType.BLOCK_START: 88 macro_start = self._curr 89 elif self._tag == "MACRO" and self._next: 90 name = self._next.text 91 while self._curr and self._curr.token_type != TokenType.BLOCK_END: 92 self._advance() 93 94 while self._curr and self._tag != "ENDMACRO": 95 self._advance() 96 97 macro_str = self._find_sql(macro_start, self._next) 98 macros[name] = MacroInfo( 99 definition=macro_str, 100 depends_on=list(extract_macro_references_and_variables(macro_str)[0]), 101 ) 102 103 self._advance() 104 105 return macros 106 107 def _advance(self, times: int = 1) -> None: 108 super()._advance(times) 109 self._tag = ( 110 self._curr.text.upper() 111 if self._curr and self._prev and self._prev.token_type == TokenType.BLOCK_START 112 else "" 113 ) 114 115 116def call_name(node: nodes.Expr) -> t.Tuple[str, ...]: 117 if isinstance(node, nodes.Name): 118 return (node.name,) 119 if isinstance(node, nodes.Const): 120 return (f"'{node.value}'",) 121 if isinstance(node, nodes.Getattr): 122 return call_name(node.node) + (node.attr,) 123 if isinstance(node, (nodes.Getitem, nodes.Call)): 124 return call_name(node.node) 125 return () 126 127 128def render_jinja(query: str, methods: t.Optional[t.Dict[str, t.Any]] = None) -> str: 129 return ENVIRONMENT.from_string(query).render(methods or {}) 130 131 132def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[CallNames]: 133 vars_in_scope = vars_in_scope.copy() 134 for child_node in node.iter_child_nodes(): 135 if "target" in child_node.fields: 136 # For nodes with assignment targets (Assign, AssignBlock, For, Import), 137 # the target name could shadow a reference in the right hand side. 
138 # So we need to process the RHS before adding the target to scope. 139 # For example: {% set model = model.path %} should track model.path. 140 yield from find_call_names(child_node, vars_in_scope) 141 142 target = getattr(child_node, "target") 143 if isinstance(target, nodes.Name): 144 vars_in_scope.add(target.name) 145 elif isinstance(target, nodes.Tuple): 146 for item in target.items: 147 if isinstance(item, nodes.Name): 148 vars_in_scope.add(item.name) 149 elif isinstance(child_node, nodes.Macro): 150 for arg in child_node.args: 151 vars_in_scope.add(arg.name) 152 elif isinstance(child_node, nodes.Call) or ( 153 isinstance(child_node, nodes.Getattr) and not isinstance(child_node.node, nodes.Getattr) 154 ): 155 name = call_name(child_node) 156 if name[0][0] != "'" and name[0] not in vars_in_scope: 157 yield (name, child_node) 158 159 if "target" not in child_node.fields: 160 yield from find_call_names(child_node, vars_in_scope) 161 162 163def extract_call_names( 164 jinja_str: str, cache: t.Optional[t.Dict[str, t.Tuple[t.List[CallNames], bool]]] = None 165) -> t.List[CallNames]: 166 def parse() -> t.List[CallNames]: 167 return list(find_call_names(ENVIRONMENT.parse(jinja_str), set())) 168 169 if cache is not None: 170 key = str(zlib.crc32(jinja_str.encode("utf-8"))) 171 if key in cache: 172 names = cache[key][0] 173 else: 174 names = parse() 175 cache[key] = (names, True) 176 return names 177 return parse() 178 179 180def is_variable_node(n: nodes.Node) -> bool: 181 return ( 182 isinstance(n, nodes.Call) 183 and isinstance(n.node, nodes.Name) 184 and n.node.name in (c.VAR, c.BLUEPRINT_VAR) 185 ) 186 187 188def extract_macro_references_and_variables( 189 *jinja_strs: str, 190) -> t.Tuple[t.Set[MacroReference], t.Set[str]]: 191 macro_references = set() 192 variables = set() 193 for jinja_str in jinja_strs: 194 for call_name, node in extract_call_names(jinja_str): 195 if call_name[0] in (c.VAR, c.BLUEPRINT_VAR): 196 if not is_variable_node(node): 197 # Find the 
variable node which could be nested 198 for n in node.find_all(nodes.Call): 199 if is_variable_node(n): 200 node = n 201 break 202 else: 203 raise ValueError(f"Could not find variable name in {jinja_str}") 204 node = t.cast(nodes.Call, node) 205 args = [jinja_call_arg_name(arg) for arg in node.args] 206 if args and args[0]: 207 variables.add(args[0].lower()) 208 elif call_name[0] == c.GATEWAY: 209 variables.add(c.GATEWAY) 210 elif len(call_name) == 1: 211 macro_references.add(MacroReference(name=call_name[0])) 212 elif len(call_name) == 2: 213 macro_references.add(MacroReference(package=call_name[0], name=call_name[1])) 214 return macro_references, variables 215 216 217def sort_dict_recursive( 218 item: t.Dict[str, t.Any], 219) -> t.Dict[str, t.Any]: 220 sorted_dict: t.Dict[str, t.Any] = {} 221 for k, v in sorted(item.items()): 222 if isinstance(v, list): 223 sorted_dict[k] = sorted(v) 224 elif isinstance(v, dict): 225 sorted_dict[k] = sort_dict_recursive(v) 226 else: 227 sorted_dict[k] = v 228 return sorted_dict 229 230 231JinjaGlobalAttribute = t.Union[str, int, float, bool, AttributeDict] 232 233 234class JinjaMacroRegistry(PydanticModel): 235 """Registry for Jinja macros. 236 237 Args: 238 packages: The mapping from package name to a collection of macro definitions. 239 root_macros: The collection of top-level macro definitions. 240 global_objs: The global objects. 241 create_builtins_module: The name of a module which defines the `create_builtins` factory 242 function that will be used to construct builtin variables and functions. 243 root_package_name: The name of the root package. If specified root macros will be available 244 as both `root_package_name.macro_name` and `macro_name`. 245 top_level_packages: The list of top-level packages. Macros in this packages will be available 246 as both `package_name.macro_name` and `macro_name`. 
247 """ 248 249 packages: t.Dict[str, t.Dict[str, MacroInfo]] = {} 250 root_macros: t.Dict[str, MacroInfo] = {} 251 global_objs: t.Dict[str, JinjaGlobalAttribute] = {} 252 create_builtins_module: t.Optional[str] = SQLMESH_JINJA_PACKAGE 253 root_package_name: t.Optional[str] = None 254 top_level_packages: t.List[str] = [] 255 256 _parser_cache: t.Dict[t.Tuple[t.Optional[str], str], Template] = {} 257 _trimmed: bool = False 258 __environment: t.Optional[Environment] = None 259 260 def __getstate__(self) -> t.Dict[t.Any, t.Any]: 261 state = super().__getstate__() 262 private = state[PRIVATE_FIELDS] 263 private["_parser_cache"] = {} 264 private["_JinjaMacroRegistry__environment"] = None 265 return state 266 267 @field_validator("global_objs", mode="before") 268 @classmethod 269 def _validate_global_objs(cls, value: t.Any) -> t.Any: 270 def _normalize(val: t.Any) -> t.Any: 271 if isinstance(val, dict): 272 return AttributeDict({k: _normalize(v) for k, v in val.items()}) 273 if isinstance(val, list): 274 return [_normalize(v) for v in val] 275 if isinstance(val, set): 276 return [_normalize(v) for v in sorted(val)] 277 if isinstance(val, Enum): 278 return val.value 279 return val 280 281 return _normalize(value) 282 283 @field_serializer("global_objs") 284 def _serialize_attribute_dict( 285 self, value: t.Dict[str, JinjaGlobalAttribute] 286 ) -> t.Dict[str, t.Any]: 287 # NOTE: This is called only when used with Pydantic V2. 288 def _convert( 289 val: t.Union[t.Dict[str, JinjaGlobalAttribute], t.Dict[str, t.Any]], 290 ) -> t.Dict[str, t.Any]: 291 return {k: _convert(v) if isinstance(v, AttributeDict) else v for k, v in val.items()} 292 293 return _convert(value) 294 295 @property 296 def trimmed(self) -> bool: 297 return self._trimmed 298 299 def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None: 300 """Adds macros to the target package. 301 302 Args: 303 macros: Macros that should be added. 
304 package: The name of the package the given macros belong to. If not specified, the provided 305 macros will be added to the root namespace. 306 """ 307 308 if package is not None: 309 package_macros = self.packages.get(package, {}) 310 package_macros.update(macros) 311 self.packages[package] = package_macros 312 else: 313 self.root_macros.update(macros) 314 315 def add_globals(self, globals: t.Dict[str, JinjaGlobalAttribute]) -> None: 316 """Adds global objects to the registry. 317 318 Args: 319 globals: The global objects that should be added. 320 """ 321 # Keep the registry lightweight when the graph is not needed 322 if not "graph" in self.packages: 323 globals.pop("flat_graph", None) 324 self.global_objs.update(**self._validate_global_objs(globals)) 325 326 def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]: 327 """Builds a Python callable for a macro with the given reference. 328 329 Args: 330 reference: The macro reference. 331 Returns: 332 The macro as a Python callable or None if not found. 
333 """ 334 env: Environment = self.build_environment(**kwargs) 335 if reference.package is not None: 336 package = env.globals.get(reference.package, {}) 337 return package.get(reference.name) # type: ignore 338 return env.globals.get(reference.name) # type: ignore 339 340 def build_environment(self, **kwargs: t.Any) -> Environment: 341 """Builds a new Jinja environment based on this registry.""" 342 343 context: t.Dict[str, t.Any] = {} 344 345 root_macros = { 346 name: self._MacroWrapper(name, None, self, context) 347 for name, macro in self.root_macros.items() 348 } 349 350 package_macros: t.Dict[str, t.Any] = defaultdict(AttributeDict) 351 for package_name, macros in self.packages.items(): 352 for macro_name, macro in macros.items(): 353 macro_wrapper = self._MacroWrapper(macro_name, package_name, self, context) 354 package_macros[package_name][macro_name] = macro_wrapper 355 if macro.is_top_level and macro_name not in root_macros: 356 root_macros[macro_name] = macro_wrapper 357 358 top_level_packages = self.top_level_packages.copy() 359 360 if self.root_package_name is not None: 361 package_macros[self.root_package_name].update(root_macros) 362 top_level_packages.append(self.root_package_name) 363 364 env = environment() 365 366 builtin_globals = self._create_builtin_globals(kwargs) 367 for top_level_package_name in top_level_packages: 368 # Make sure that the top-level package doesn't fully override the same builtin package. 
369 package_macros[top_level_package_name] = AttributeDict( 370 { 371 **(builtin_globals.pop(top_level_package_name, None) or {}), 372 **(package_macros.get(top_level_package_name) or {}), 373 } 374 ) 375 root_macros.update(package_macros[top_level_package_name]) 376 377 context.update(builtin_globals) 378 context.update(root_macros) 379 context.update(package_macros) 380 context["render"] = lambda input: env.from_string(input).render() 381 382 env.globals.update(context) 383 env.filters.update(self._environment.filters) 384 return env 385 386 def trim( 387 self, dependencies: t.Iterable[MacroReference], package: t.Optional[str] = None 388 ) -> JinjaMacroRegistry: 389 """Trims the registry by keeping only macros with given references and their transitive dependencies. 390 391 Args: 392 dependencies: References to macros that should be kept. 393 package: The name of the package in the context of which the trimming should be performed. 394 395 Returns: 396 A new trimmed registry. 397 """ 398 dependencies_by_package: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set) 399 for dep in dependencies: 400 dependencies_by_package[dep.package or package].add(dep.name) 401 402 top_level_packages = self.top_level_packages.copy() 403 if package is not None: 404 top_level_packages.append(package) 405 406 result = JinjaMacroRegistry( 407 global_objs=self.global_objs.copy(), 408 create_builtins_module=self.create_builtins_module, 409 root_package_name=self.root_package_name, 410 top_level_packages=top_level_packages, 411 ) 412 for package, names in dependencies_by_package.items(): 413 result = result.merge(self._trim_macros(names, package)) 414 415 result._trimmed = True 416 417 return result 418 419 def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry: 420 """Returns a copy of the registry which contains macros from both this and `other` instances. 421 422 Args: 423 other: The other registry instance. 424 425 Returns: 426 A new merged registry. 
427 """ 428 429 root_macros = { 430 **self.root_macros, 431 **other.root_macros, 432 } 433 434 packages = {} 435 for package in {*self.packages, *other.packages}: 436 packages[package] = { 437 **self.packages.get(package, {}), 438 **other.packages.get(package, {}), 439 } 440 441 global_objs = { 442 **self.global_objs, 443 **other.global_objs, 444 } 445 446 return JinjaMacroRegistry( 447 packages=packages, 448 root_macros=root_macros, 449 global_objs=global_objs, 450 create_builtins_module=self.create_builtins_module or other.create_builtins_module, 451 root_package_name=self.root_package_name or other.root_package_name, 452 top_level_packages=[*self.top_level_packages, *other.top_level_packages], 453 ) 454 455 def to_expressions(self) -> t.List[Expression]: 456 output: t.List[Expression] = [] 457 458 filtered_objs = { 459 k: v for k, v in self.global_objs.items() if k in ("refs", "sources", "vars") 460 } 461 if filtered_objs: 462 output.append( 463 d.PythonCode( 464 expressions=[ 465 f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" 466 for k, v in sort_dict_recursive(filtered_objs).items() 467 ] 468 ) 469 ) 470 471 for macro_name, macro_info in sorted(self.root_macros.items()): 472 output.append(d.jinja_statement(macro_info.definition)) 473 474 for _, package in sorted(self.packages.items()): 475 for macro_name, macro_info in sorted(package.items()): 476 output.append(d.jinja_statement(macro_info.definition)) 477 478 return output 479 480 @property 481 def data_hash_values(self) -> t.List[str]: 482 data = [] 483 484 for macro_name, macro in sorted(self.root_macros.items()): 485 data.append(macro_name) 486 data.append(macro.definition) 487 488 for _, package in sorted(self.packages.items()): 489 for macro_name, macro in sorted(package.items()): 490 data.append(macro_name) 491 data.append(macro.definition) 492 493 trimmed_global_objs = { 494 k: self.global_objs[k] for k in ("refs", "sources", "vars") if k in self.global_objs 495 } 496 
data.append(json.dumps(trimmed_global_objs, sort_keys=True)) 497 498 return data 499 500 def __deepcopy__(self, memo: t.Optional[t.Dict[int, t.Any]] = None) -> JinjaMacroRegistry: 501 return JinjaMacroRegistry.parse_obj(self.dict()) 502 503 def _parse_macro(self, name: str, package: t.Optional[str]) -> Template: 504 cache_key = (package, name) 505 if cache_key not in self._parser_cache: 506 macro = self._get_macro(name, package) 507 508 definition: nodes.Template = self._environment.parse(macro.definition) 509 if _is_private_macro(name): 510 # A workaround to expose private jinja macros. 511 definition = self._to_non_private_macro_def(name, definition) 512 513 self._parser_cache[cache_key] = self._environment.from_string(definition) 514 return self._parser_cache[cache_key] 515 516 @property 517 def _environment(self) -> Environment: 518 if self.__environment is None: 519 self.__environment = environment() 520 self.__environment.filters.update(self._create_builtin_filters()) 521 return self.__environment 522 523 def _trim_macros( 524 self, 525 names: t.Set[str], 526 package: t.Optional[str] = None, 527 visited: t.Optional[t.Dict[t.Optional[str], t.Set[str]]] = None, 528 ) -> JinjaMacroRegistry: 529 if visited is None: 530 visited = defaultdict(set) 531 532 macros = self.packages.get(package, {}) if package is not None else self.root_macros 533 trimmed_macros = {} 534 535 dependencies: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set) 536 537 for name in names: 538 if name in macros and name not in visited[package]: 539 macro = macros[name] 540 trimmed_macros[name] = macro 541 for dependency in macro.depends_on: 542 dependencies[dependency.package or package].add(dependency.name) 543 visited[package].add(name) 544 545 if package is not None: 546 result = JinjaMacroRegistry(packages={package: trimmed_macros}) 547 else: 548 result = JinjaMacroRegistry(root_macros=trimmed_macros) 549 550 for upstream_package, upstream_names in dependencies.items(): 551 result = 
result.merge( 552 self._trim_macros(upstream_names, upstream_package, visited=visited) 553 ) 554 555 return result 556 557 def _macro_exists(self, name: str, package: t.Optional[str]) -> bool: 558 return ( 559 name in self.packages.get(package, {}) 560 if package is not None 561 else name in self.root_macros 562 ) 563 564 def _get_macro(self, name: str, package: t.Optional[str]) -> MacroInfo: 565 return self.packages[package][name] if package is not None else self.root_macros[name] 566 567 def _to_non_private_macro_def(self, name: str, template: nodes.Template) -> nodes.Template: 568 for node in template.find_all((nodes.Macro, nodes.Call)): 569 if isinstance(node, nodes.Macro): 570 node.name = _non_private_name(name) 571 elif isinstance(node, nodes.Call) and isinstance(node.node, nodes.Name): 572 node.node.name = _non_private_name(name) 573 574 return template 575 576 def _create_builtin_globals(self, global_vars: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: 577 """Creates Jinja builtin globals using a factory function defined in the provided module.""" 578 engine_adapter = global_vars.pop("engine_adapter", None) 579 global_vars = {**self.global_objs, **global_vars} 580 if self.create_builtins_module is not None: 581 module = importlib.import_module(self.create_builtins_module) 582 if hasattr(module, "create_builtin_globals"): 583 return module.create_builtin_globals(self, global_vars, engine_adapter) 584 return global_vars 585 586 def _create_builtin_filters(self) -> t.Dict[str, t.Any]: 587 """Creates Jinja builtin filters using a factory function defined in the provided module.""" 588 if self.create_builtins_module is not None: 589 module = importlib.import_module(self.create_builtins_module) 590 if hasattr(module, "create_builtin_filters"): 591 return module.create_builtin_filters() 592 return {} 593 594 class _MacroWrapper: 595 def __init__( 596 self, 597 name: str, 598 package: t.Optional[str], 599 registry: JinjaMacroRegistry, 600 context: t.Dict[str, t.Any], 
601 ): 602 self.name = name 603 self.package = package 604 self.context = context 605 self.registry = registry 606 607 def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any: 608 context = self.context.copy() 609 if self.package is not None and self.package in context: 610 context.update(context[self.package]) 611 612 template = self.registry._parse_macro(self.name, self.package) 613 macro_callable = getattr( 614 template.make_module(vars=context), _non_private_name(self.name) 615 ) 616 try: 617 return macro_callable(*args, **kwargs) 618 except MacroReturnVal as ret: 619 return ret.value 620 621 622def _is_private_macro(name: str) -> bool: 623 return name.startswith("_") 624 625 626def _non_private_name(name: str) -> str: 627 return name.lstrip("_") 628 629 630JINJA_REGEX = re.compile(r"({{|{%)") 631 632 633def has_jinja(value: str) -> bool: 634 return JINJA_REGEX.search(value) is not None 635 636 637def jinja_call_arg_name(node: nodes.Node) -> str: 638 if isinstance(node, nodes.Const): 639 return node.value 640 return "" 641 642 643def create_var(variables: t.Dict[str, t.Any]) -> t.Callable: 644 def _var(var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: 645 value = variables.get(var_name.lower(), default) 646 if isinstance(value, SqlValue): 647 return value.sql 648 return value 649 650 return _var 651 652 653def create_builtin_globals( 654 jinja_macros: JinjaMacroRegistry, global_vars: t.Dict[str, t.Any], *args: t.Any, **kwargs: t.Any 655) -> t.Dict[str, t.Any]: 656 global_vars.pop(c.GATEWAY, None) 657 variables = global_vars.pop(c.SQLMESH_VARS, None) or {} 658 blueprint_variables = global_vars.pop(c.SQLMESH_BLUEPRINT_VARS, None) or {} 659 return { 660 **global_vars, 661 c.VAR: create_var(variables), 662 c.GATEWAY: lambda: variables.get(c.GATEWAY, None), 663 c.BLUEPRINT_VAR: create_var(blueprint_variables), 664 } 665 666 667def make_jinja_registry( 668 jinja_macros: JinjaMacroRegistry, package_name: str, jinja_references: 
t.Set[MacroReference] 669) -> JinjaMacroRegistry: 670 """ 671 Creates a Jinja macro registry for a specific package. 672 673 This function takes an existing Jinja macro registry and returns a new 674 registry that includes only the macros associated with the specified 675 package and trims the registry to include only the macros referenced 676 in the provided set of macro references. 677 678 Args: 679 jinja_macros: The original Jinja macro registry containing all macros. 680 package_name: The name of the package for which to create the registry. 681 jinja_references: A set of macro references to retain in the new registry. 682 683 Returns: 684 A new JinjaMacroRegistry containing only the macros for the specified 685 package and the referenced macros. 686 """ 687 688 jinja_registry = jinja_macros.copy() 689 jinja_registry.root_macros = jinja_registry.packages.get(package_name) or {} 690 jinja_registry = jinja_registry.trim(jinja_references) 691 692 return jinja_registry 693 694 695def extract_error_details(ex: Exception) -> str: 696 """Extracts a readable message from a Jinja2 error, to include missing name and macro.""" 697 698 error_details = "" 699 if isinstance(ex, UndefinedError): 700 if match := re.search(r"'(\w+)'", str(ex)): 701 error_details += f"\nUndefined macro/variable: '{match.group(1)}'" 702 try: 703 _, _, exc_traceback = exc_info() 704 for frame, _ in walk_tb(exc_traceback): 705 if frame.f_code.co_name == "_invoke": 706 macro = frame.f_locals.get("self") 707 if isinstance(macro, Macro): 708 error_details += f" in macro: '{macro.name}'\n" 709 break 710 except: 711 # to fall back to the generic error message if frame analysis fails 712 pass 713 return error_details or str(ex)
class MacroReference(PydanticModel, frozen=True):
    """Frozen model identifying a macro by name and, optionally, by the package that holds it."""

    package: t.Optional[str] = None
    name: str

    @property
    def reference(self) -> str:
        """The dotted form `package.name`, or just `name` when no package is set."""
        parts = (self.name,) if self.package is None else (self.package, self.name)
        return ".".join(parts)

    def __str__(self) -> str:
        return self.reference
Usage Documentation: Models
A base class for creating Pydantic models.
Attributes:
- __class_vars__: The names of the class variables defined on the model.
- __private_attributes__: Metadata about the private attributes of the model.
- __signature__: The synthesized `__init__` [`Signature`][inspect.Signature] of the model.
- __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__: The core schema of the model.
- __pydantic_custom_init__: Whether the model has a custom `__init__` function.
- __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.
- __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
- __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__: The name of the post-init method for the model, if defined.
- __pydantic_root_model__: Whether the model is a [`RootModel`][pydantic.root_model.RootModel].
- __pydantic_serializer__: The `pydantic-core` `SchemaSerializer` used to dump instances of the model.
- __pydantic_validator__: The `pydantic-core` `SchemaValidator` used to validate instances of the model.
- __pydantic_fields__: A dictionary of field names and their corresponding [`FieldInfo`][pydantic.fields.FieldInfo] objects.
- __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [`ComputedFieldInfo`][pydantic.fields.ComputedFieldInfo] objects.
- __pydantic_extra__: A dictionary containing extra values, if [`extra`][pydantic.config.ConfigDict.extra] is set to `'allow'`.
- __pydantic_fields_set__: The names of fields explicitly set during instantiation.
- __pydantic_private__: Values of private attributes set on the model instance.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class MacroInfo(PydanticModel):
    """Holds a macro's definition text together with the macros it references."""

    # The text of the macro definition.
    definition: str
    # References to other macros used by this definition.
    depends_on: t.List[MacroReference]
    # Defaults to False.
    is_top_level: bool = False
Class to hold macro and its calls
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
Common base class for all non-exit exceptions.
Inherited Members
- builtins.BaseException
- with_traceback
- args
class MacroExtractor(Parser):
    """Parser subclass that scans a token stream for jinja macro blocks."""

    def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
        """Collect every macro definition found in a jinja string.

        Args:
            jinja: The jinja string to scan.
            dialect: The SQL dialect whose tokenizer is used.

        Returns:
            A mapping from macro name to its `MacroInfo`.
        """
        self.reset()
        self.sql = jinja
        self._tokens = Dialect.get_or_raise(dialect).tokenize(jinja)
        self._index = -1
        self._advance()

        found: t.Dict[str, MacroInfo] = {}

        while self._curr:
            if self._curr.token_type == TokenType.BLOCK_START:
                # The latest block opener is where a macro's text would start.
                block_start = self._curr
            elif self._tag == "MACRO" and self._next:
                macro_name = self._next.text

                # Move past the opening tag ...
                while self._curr and self._curr.token_type != TokenType.BLOCK_END:
                    self._advance()

                # ... and then past the body, stopping on the closing tag name.
                while self._curr and self._tag != "ENDMACRO":
                    self._advance()

                source = self._find_sql(block_start, self._next)
                references = extract_macro_references_and_variables(source)[0]
                found[macro_name] = MacroInfo(definition=source, depends_on=list(references))

            self._advance()

        return found

    def _advance(self, times: int = 1) -> None:
        """Advance the cursor and track the upper-cased text of a token that follows a block start."""
        super()._advance(times)
        if self._curr and self._prev and self._prev.token_type == TokenType.BLOCK_START:
            self._tag = self._curr.text.upper()
        else:
            self._tag = ""
Parser consumes a list of tokens produced by the Tokenizer and produces a parsed syntax tree.
Arguments:
- error_level: The desired error level. Default: ErrorLevel.IMMEDIATE
- error_message_context: The amount of context to capture from a query string when displaying the error message (in number of characters). Default: 100
- max_errors: Maximum number of error messages to include in a raised ParseError. This is only relevant if error_level is ErrorLevel.RAISE. Default: 3
def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
    """Collect every macro definition found in a jinja string.

    Args:
        jinja: The jinja string to scan.
        dialect: The SQL dialect whose tokenizer is used.

    Returns:
        A mapping from macro name to its `MacroInfo`.
    """
    self.reset()
    self.sql = jinja
    self._tokens = Dialect.get_or_raise(dialect).tokenize(jinja)
    self._index = -1
    self._advance()

    found: t.Dict[str, MacroInfo] = {}

    while self._curr:
        if self._curr.token_type == TokenType.BLOCK_START:
            # The latest block opener is where a macro's text would start.
            block_start = self._curr
        elif self._tag == "MACRO" and self._next:
            macro_name = self._next.text

            # Move past the opening tag ...
            while self._curr and self._curr.token_type != TokenType.BLOCK_END:
                self._advance()

            # ... and then past the body, stopping on the closing tag name.
            while self._curr and self._tag != "ENDMACRO":
                self._advance()

            source = self._find_sql(block_start, self._next)
            references = extract_macro_references_and_variables(source)[0]
            found[macro_name] = MacroInfo(definition=source, depends_on=list(references))

        self._advance()

    return found
Extract a dictionary of macro definitions from a jinja string.
Arguments:
- jinja: The jinja string to extract from.
- dialect: The dialect of SQL.
Returns:
A dictionary of macro name to macro definition.
Inherited Members
- sqlglot.parser.Parser
- Parser
- FUNCTIONS
- NO_PAREN_FUNCTIONS
- STRUCT_TYPE_TOKENS
- NESTED_TYPE_TOKENS
- ENUM_TYPE_TOKENS
- AGGREGATE_TYPE_TOKENS
- TYPE_TOKENS
- SIGNED_TO_UNSIGNED_TYPE_TOKEN
- SUBQUERY_PREDICATES
- RESERVED_TOKENS
- DB_CREATABLES
- CREATABLES
- ALTERABLES
- ID_VAR_TOKENS
- TABLE_ALIAS_TOKENS
- ALIAS_TOKENS
- COLON_PLACEHOLDER_TOKENS
- ARRAY_CONSTRUCTORS
- COMMENT_TABLE_ALIAS_TOKENS
- UPDATE_ALIAS_TOKENS
- TRIM_TYPES
- FUNC_TOKENS
- CONJUNCTION
- ASSIGNMENT
- DISJUNCTION
- EQUALITY
- COMPARISON
- BITWISE
- TERM
- FACTOR
- EXPONENT
- TIMES
- TIMESTAMPS
- SET_OPERATIONS
- JOIN_METHODS
- JOIN_SIDES
- JOIN_KINDS
- JOIN_HINTS
- LAMBDAS
- COLUMN_OPERATORS
- CAST_COLUMN_OPERATORS
- EXPRESSION_PARSERS
- STATEMENT_PARSERS
- UNARY_PARSERS
- STRING_PARSERS
- NUMERIC_PARSERS
- PRIMARY_PARSERS
- PLACEHOLDER_PARSERS
- RANGE_PARSERS
- PIPE_SYNTAX_TRANSFORM_PARSERS
- PROPERTY_PARSERS
- CONSTRAINT_PARSERS
- ALTER_PARSERS
- ALTER_ALTER_PARSERS
- SCHEMA_UNNAMED_CONSTRAINTS
- NO_PAREN_FUNCTION_PARSERS
- INVALID_FUNC_NAME_TOKENS
- FUNCTIONS_WITH_ALIASED_ARGS
- KEY_VALUE_DEFINITIONS
- FUNCTION_PARSERS
- QUERY_MODIFIER_PARSERS
- QUERY_MODIFIER_TOKENS
- SET_PARSERS
- SHOW_PARSERS
- TYPE_LITERAL_PARSERS
- TYPE_CONVERTERS
- DDL_SELECT_TOKENS
- PRE_VOLATILE_TOKENS
- TRANSACTION_KIND
- TRANSACTION_CHARACTERISTICS
- CONFLICT_ACTIONS
- CREATE_SEQUENCE
- ISOLATED_LOADING_OPTIONS
- USABLES
- CAST_ACTIONS
- SCHEMA_BINDING_OPTIONS
- PROCEDURE_OPTIONS
- EXECUTE_AS_OPTIONS
- KEY_CONSTRAINT_OPTIONS
- WINDOW_EXCLUDE_OPTIONS
- INSERT_ALTERNATIVES
- CLONE_KEYWORDS
- HISTORICAL_DATA_PREFIX
- HISTORICAL_DATA_KIND
- OPCLASS_FOLLOW_KEYWORDS
- OPTYPE_FOLLOW_TOKENS
- TABLE_INDEX_HINT_TOKENS
- VIEW_ATTRIBUTES
- WINDOW_ALIAS_TOKENS
- WINDOW_BEFORE_PAREN_TOKENS
- WINDOW_SIDES
- JSON_KEY_VALUE_SEPARATOR_TOKENS
- FETCH_TOKENS
- ADD_CONSTRAINT_TOKENS
- DISTINCT_TOKENS
- UNNEST_OFFSET_ALIAS_TOKENS
- SELECT_START_TOKENS
- COPY_INTO_VARLEN_OPTIONS
- IS_JSON_PREDICATE_KIND
- ODBC_DATETIME_LITERALS
- ON_CONDITION_TOKENS
- PRIVILEGE_FOLLOW_TOKENS
- DESCRIBE_STYLES
- SET_ASSIGNMENT_DELIMITERS
- ANALYZE_STYLES
- ANALYZE_EXPRESSION_PARSERS
- PARTITION_KEYWORDS
- AMBIGUOUS_ALIAS_TOKENS
- OPERATION_MODIFIERS
- RECURSIVE_CTE_SEARCH_KIND
- MODIFIABLES
- STRICT_CAST
- PREFIXED_PIVOT_COLUMNS
- IDENTIFY_PIVOT_STRINGS
- LOG_DEFAULTS_TO_LN
- TABLESAMPLE_CSV
- DEFAULT_SAMPLING_METHOD
- SET_REQUIRES_ASSIGNMENT_DELIMITER
- TRIM_PATTERN_FIRST
- STRING_ALIASES
- MODIFIERS_ATTACHED_TO_SET_OP
- SET_OP_MODIFIERS
- NO_PAREN_IF_COMMANDS
- JSON_ARROWS_REQUIRE_JSON_TYPE
- COLON_IS_VARIANT_EXTRACT
- VALUES_FOLLOWED_BY_PAREN
- SUPPORTS_IMPLICIT_UNNEST
- INTERVAL_SPANS
- SUPPORTS_PARTITION_SELECTION
- WRAPPED_TRANSFORM_COLUMN_CONSTRAINT
- OPTIONAL_ALIAS_TOKEN_CTE
- ALTER_RENAME_REQUIRES_COLUMN
- ALTER_TABLE_PARTITIONS
- JOINS_HAVE_EQUAL_PRECEDENCE
- ZONE_AWARE_TIMESTAMP_CONSTRUCTOR
- MAP_KEYS_ARE_ARBITRARY_EXPRESSIONS
- JSON_EXTRACT_REQUIRES_JSON_EXPRESSION
- ADD_JOIN_ON_TRUE
- SUPPORTS_OMITTED_INTERVAL_SPAN_UNIT
- error_level
- error_message_context
- max_errors
- dialect
- reset
- parse
- parse_into
- check_errors
- raise_error
- expression
- validate_expression
- parse_set_operation
- build_cast
- errors
- sql
def call_name(node: nodes.Expr) -> t.Tuple[str, ...]:
    """Return the name parts referenced by a Jinja expression node.

    Calls and subscripts resolve to the name of the expression they are applied
    to, attribute lookups append the attribute name, constants are rendered
    wrapped in single quotes and plain names yield a one-element tuple. Any
    other node type produces an empty tuple.
    """
    # Calls and subscripts are transparent: only the underlying target matters.
    if isinstance(node, (nodes.Call, nodes.Getitem)):
        return call_name(node.node)
    if isinstance(node, nodes.Getattr):
        return (*call_name(node.node), node.attr)
    if isinstance(node, nodes.Const):
        return ("'{}'".format(node.value),)
    if isinstance(node, nodes.Name):
        return (node.name,)
    return ()
def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[CallNames]:
    """Yield the names of calls and attribute accesses found beneath ``node``.

    The tree is walked depth-first. Names bound by assignment-like nodes (nodes
    with a ``target`` field) and macro arguments are tracked as local variables
    and references whose first name part is such a variable are not reported.

    Args:
        node: The Jinja AST node whose descendants should be inspected.
        vars_in_scope: Names that are already bound in the enclosing scope.

    Yields:
        Tuples of (name parts, node) for every call or attribute access whose
        first name part is neither a constant nor a variable in scope.
    """
    # Work on a copy so that names bound in this subtree don't leak into the
    # caller's scope.
    vars_in_scope = vars_in_scope.copy()
    for child_node in node.iter_child_nodes():
        has_target = "target" in child_node.fields
        if has_target:
            # For nodes with assignment targets (Assign, AssignBlock, For, Import),
            # the target name could shadow a reference in the right hand side.
            # So we need to process the RHS before adding the target to scope.
            # For example: {% set model = model.path %} should track model.path.
            yield from find_call_names(child_node, vars_in_scope)

            target = getattr(child_node, "target")
            if isinstance(target, nodes.Name):
                vars_in_scope.add(target.name)
            elif isinstance(target, nodes.Tuple):
                for item in target.items:
                    if isinstance(item, nodes.Name):
                        vars_in_scope.add(item.name)
        elif isinstance(child_node, nodes.Macro):
            for arg in child_node.args:
                vars_in_scope.add(arg.name)
        elif isinstance(child_node, nodes.Call) or (
            isinstance(child_node, nodes.Getattr) and not isinstance(child_node.node, nodes.Getattr)
        ):
            name = call_name(child_node)
            # call_name returns an empty tuple when the callee is not a name,
            # attribute, subscript or call (e.g. `(a ~ b)()`); there is nothing
            # to report in that case and indexing into it would raise.
            if name and name[0][0] != "'" and name[0] not in vars_in_scope:
                yield (name, child_node)

        if not has_target:
            yield from find_call_names(child_node, vars_in_scope)
def extract_call_names(
    jinja_str: str, cache: t.Optional[t.Dict[str, t.Tuple[t.List[CallNames], bool]]] = None
) -> t.List[CallNames]:
    """Parse a Jinja string and return the call names found in it.

    Args:
        jinja_str: The Jinja source to inspect.
        cache: An optional mapping used to memoize results. Entries are keyed by
            the CRC32 checksum of the UTF-8 encoded source and hold the call
            names together with a flag that is set to True for entries added here.

    Returns:
        The list of (name parts, node) tuples found in the source.
    """
    if cache is None:
        return list(find_call_names(ENVIRONMENT.parse(jinja_str), set()))

    cache_key = str(zlib.crc32(jinja_str.encode("utf-8")))
    if cache_key not in cache:
        parsed = list(find_call_names(ENVIRONMENT.parse(jinja_str), set()))
        cache[cache_key] = (parsed, True)
    return cache[cache_key][0]
def extract_macro_references_and_variables(
    *jinja_strs: str,
) -> t.Tuple[t.Set[MacroReference], t.Set[str]]:
    """Collect macro references and variable names used in the given Jinja strings.

    Args:
        jinja_strs: The Jinja sources to inspect.

    Returns:
        A tuple of (macro references, variable names). Variable names are
        lower-cased; a gateway call contributes the gateway name itself.

    Raises:
        ValueError: If a variable call is detected but no node holding the
            variable name can be found.
    """
    references: t.Set[MacroReference] = set()
    variable_names: t.Set[str] = set()

    for source in jinja_strs:
        for name_parts, node in extract_call_names(source):
            head = name_parts[0]
            if head in (c.VAR, c.BLUEPRINT_VAR):
                if not is_variable_node(node):
                    # The variable call may be nested inside the reported node.
                    nested = next(
                        (n for n in node.find_all(nodes.Call) if is_variable_node(n)), None
                    )
                    if nested is None:
                        raise ValueError(f"Could not find variable name in {source}")
                    node = nested
                node = t.cast(nodes.Call, node)
                arg_names = [jinja_call_arg_name(arg) for arg in node.args]
                if arg_names and arg_names[0]:
                    variable_names.add(arg_names[0].lower())
            elif head == c.GATEWAY:
                variable_names.add(c.GATEWAY)
            elif len(name_parts) == 1:
                references.add(MacroReference(name=head))
            elif len(name_parts) == 2:
                references.add(MacroReference(package=head, name=name_parts[1]))

    return references, variable_names
def sort_dict_recursive(
    item: t.Dict[str, t.Any],
) -> t.Dict[str, t.Any]:
    """Return a copy of ``item`` whose keys are inserted in sorted order, recursively.

    List values are sorted when their elements can be ordered. Lists whose
    elements cannot be compared with each other (e.g. lists of dictionaries or
    of mixed types) keep their original order, with dictionary elements being
    normalized recursively.

    Args:
        item: The dictionary to normalize.

    Returns:
        A new dictionary with deterministic key ordering.
    """
    sorted_dict: t.Dict[str, t.Any] = {}
    for k in sorted(item):
        v = item[k]
        if isinstance(v, list):
            try:
                sorted_dict[k] = sorted(v)
            except TypeError:
                # The elements are not mutually orderable; preserve the given
                # order instead of failing, but still normalize nested dicts.
                sorted_dict[k] = [
                    sort_dict_recursive(e) if isinstance(e, dict) else e for e in v
                ]
        elif isinstance(v, dict):
            sorted_dict[k] = sort_dict_recursive(v)
        else:
            sorted_dict[k] = v
    return sorted_dict
class JinjaMacroRegistry(PydanticModel):
    """Registry for Jinja macros.

    Args:
        packages: The mapping from package name to a collection of macro definitions.
        root_macros: The collection of top-level macro definitions.
        global_objs: The global objects.
        create_builtins_module: The name of a module which defines the `create_builtins` factory
            function that will be used to construct builtin variables and functions.
        root_package_name: The name of the root package. If specified root macros will be available
            as both `root_package_name.macro_name` and `macro_name`.
        top_level_packages: The list of top-level packages. Macros in this packages will be available
            as both `package_name.macro_name` and `macro_name`.
    """

    packages: t.Dict[str, t.Dict[str, MacroInfo]] = {}
    root_macros: t.Dict[str, MacroInfo] = {}
    global_objs: t.Dict[str, JinjaGlobalAttribute] = {}
    create_builtins_module: t.Optional[str] = SQLMESH_JINJA_PACKAGE
    root_package_name: t.Optional[str] = None
    top_level_packages: t.List[str] = []

    # Parsed templates keyed by (package, macro name); filled lazily by _parse_macro.
    _parser_cache: t.Dict[t.Tuple[t.Optional[str], str], Template] = {}
    # Set to True on the registries returned by trim().
    _trimmed: bool = False
    # Created on first access of the _environment property.
    __environment: t.Optional[Environment] = None

    def __getstate__(self) -> t.Dict[t.Any, t.Any]:
        """Return the object state with the parser cache and the cached environment reset."""
        state = super().__getstate__()
        private = state[PRIVATE_FIELDS]
        private["_parser_cache"] = {}
        # The attribute name is spelled out in its name-mangled form.
        private["_JinjaMacroRegistry__environment"] = None
        return state

    @field_validator("global_objs", mode="before")
    @classmethod
    def _validate_global_objs(cls, value: t.Any) -> t.Any:
        """Normalize global objects before validation.

        Dictionaries become AttributeDict instances, lists are normalized element
        by element, sets become sorted lists and enum members are replaced with
        their values. Everything else is returned unchanged.
        """

        def _normalize(val: t.Any) -> t.Any:
            if isinstance(val, dict):
                return AttributeDict({k: _normalize(v) for k, v in val.items()})
            if isinstance(val, list):
                return [_normalize(v) for v in val]
            if isinstance(val, set):
                return [_normalize(v) for v in sorted(val)]
            if isinstance(val, Enum):
                return val.value
            return val

        return _normalize(value)

    @field_serializer("global_objs")
    def _serialize_attribute_dict(
        self, value: t.Dict[str, JinjaGlobalAttribute]
    ) -> t.Dict[str, t.Any]:
        """Serialize global objects, rebuilding nested AttributeDict values as plain dicts."""

        # NOTE: This is called only when used with Pydantic V2.
        def _convert(
            val: t.Union[t.Dict[str, JinjaGlobalAttribute], t.Dict[str, t.Any]],
        ) -> t.Dict[str, t.Any]:
            return {k: _convert(v) if isinstance(v, AttributeDict) else v for k, v in val.items()}

        return _convert(value)

    @property
    def trimmed(self) -> bool:
        """Whether this registry has been marked as trimmed."""
        return self._trimmed

    def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None:
        """Adds macros to the target package.

        Args:
            macros: Macros that should be added.
            package: The name of the package the given macros belong to. If not specified, the provided
                macros will be added to the root namespace.
        """

        if package is not None:
            package_macros = self.packages.get(package, {})
            package_macros.update(macros)
            self.packages[package] = package_macros
        else:
            self.root_macros.update(macros)

    def add_globals(self, globals: t.Dict[str, JinjaGlobalAttribute]) -> None:
        """Adds global objects to the registry.

        Args:
            globals: The global objects that should be added.
        """
        # Keep the registry lightweight when the graph is not needed
        # NOTE(review): the pop below removes "flat_graph" from the dictionary
        # passed in by the caller, not from a copy.
        if not "graph" in self.packages:
            globals.pop("flat_graph", None)
        self.global_objs.update(**self._validate_global_objs(globals))

    def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]:
        """Builds a Python callable for a macro with the given reference.

        Args:
            reference: The macro reference.
        Returns:
            The macro as a Python callable or None if not found.
        """
        env: Environment = self.build_environment(**kwargs)
        if reference.package is not None:
            package = env.globals.get(reference.package, {})
            return package.get(reference.name)  # type: ignore
        return env.globals.get(reference.name)  # type: ignore

    def build_environment(self, **kwargs: t.Any) -> Environment:
        """Builds a new Jinja environment based on this registry."""

        # Shared with every macro wrapper created below; it is populated at the
        # end of this method and read by the wrappers when they are called.
        context: t.Dict[str, t.Any] = {}

        root_macros = {
            name: self._MacroWrapper(name, None, self, context)
            for name, macro in self.root_macros.items()
        }

        package_macros: t.Dict[str, t.Any] = defaultdict(AttributeDict)
        for package_name, macros in self.packages.items():
            for macro_name, macro in macros.items():
                macro_wrapper = self._MacroWrapper(macro_name, package_name, self, context)
                package_macros[package_name][macro_name] = macro_wrapper
                # Top-level package macros are also exposed without the package
                # prefix unless a root macro with the same name exists.
                if macro.is_top_level and macro_name not in root_macros:
                    root_macros[macro_name] = macro_wrapper

        top_level_packages = self.top_level_packages.copy()

        if self.root_package_name is not None:
            package_macros[self.root_package_name].update(root_macros)
            top_level_packages.append(self.root_package_name)

        env = environment()

        builtin_globals = self._create_builtin_globals(kwargs)
        for top_level_package_name in top_level_packages:
            # Make sure that the top-level package doesn't fully override the same builtin package.
            package_macros[top_level_package_name] = AttributeDict(
                {
                    **(builtin_globals.pop(top_level_package_name, None) or {}),
                    **(package_macros.get(top_level_package_name) or {}),
                }
            )
            root_macros.update(package_macros[top_level_package_name])

        # Later updates win on name clashes: packages over root macros over builtins.
        context.update(builtin_globals)
        context.update(root_macros)
        context.update(package_macros)
        context["render"] = lambda input: env.from_string(input).render()

        env.globals.update(context)
        env.filters.update(self._environment.filters)
        return env

    def trim(
        self, dependencies: t.Iterable[MacroReference], package: t.Optional[str] = None
    ) -> JinjaMacroRegistry:
        """Trims the registry by keeping only macros with given references and their transitive dependencies.

        Args:
            dependencies: References to macros that should be kept.
            package: The name of the package in the context of which the trimming should be performed.

        Returns:
            A new trimmed registry.
        """
        dependencies_by_package: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)
        for dep in dependencies:
            # References without a package are resolved against the given package.
            dependencies_by_package[dep.package or package].add(dep.name)

        top_level_packages = self.top_level_packages.copy()
        if package is not None:
            top_level_packages.append(package)

        result = JinjaMacroRegistry(
            global_objs=self.global_objs.copy(),
            create_builtins_module=self.create_builtins_module,
            root_package_name=self.root_package_name,
            top_level_packages=top_level_packages,
        )
        for package, names in dependencies_by_package.items():
            result = result.merge(self._trim_macros(names, package))

        result._trimmed = True

        return result

    def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry:
        """Returns a copy of the registry which contains macros from both this and `other` instances.

        Args:
            other: The other registry instance.

        Returns:
            A new merged registry.
        """

        # On key clashes the entries of `other` take precedence.
        root_macros = {
            **self.root_macros,
            **other.root_macros,
        }

        packages = {}
        for package in {*self.packages, *other.packages}:
            packages[package] = {
                **self.packages.get(package, {}),
                **other.packages.get(package, {}),
            }

        global_objs = {
            **self.global_objs,
            **other.global_objs,
        }

        return JinjaMacroRegistry(
            packages=packages,
            root_macros=root_macros,
            global_objs=global_objs,
            create_builtins_module=self.create_builtins_module or other.create_builtins_module,
            root_package_name=self.root_package_name or other.root_package_name,
            top_level_packages=[*self.top_level_packages, *other.top_level_packages],
        )

    def to_expressions(self) -> t.List[Expression]:
        """Return the registry as a list of expressions.

        The output starts with a PythonCode expression holding the "refs",
        "sources" and "vars" global objects (when any of them is present),
        followed by the root macro definitions and the package macro definitions.
        """
        output: t.List[Expression] = []

        filtered_objs = {
            k: v for k, v in self.global_objs.items() if k in ("refs", "sources", "vars")
        }
        if filtered_objs:
            output.append(
                d.PythonCode(
                    expressions=[
                        f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}"
                        for k, v in sort_dict_recursive(filtered_objs).items()
                    ]
                )
            )

        for macro_name, macro_info in sorted(self.root_macros.items()):
            output.append(d.jinja_statement(macro_info.definition))

        for _, package in sorted(self.packages.items()):
            for macro_name, macro_info in sorted(package.items()):
                output.append(d.jinja_statement(macro_info.definition))

        return output

    @property
    def data_hash_values(self) -> t.List[str]:
        """The macro names and definitions followed by the JSON dump of the "refs", "sources" and "vars" globals."""
        data = []

        for macro_name, macro in sorted(self.root_macros.items()):
            data.append(macro_name)
            data.append(macro.definition)

        for _, package in sorted(self.packages.items()):
            for macro_name, macro in sorted(package.items()):
                data.append(macro_name)
                data.append(macro.definition)

        trimmed_global_objs = {
            k: self.global_objs[k] for k in ("refs", "sources", "vars") if k in self.global_objs
        }
        data.append(json.dumps(trimmed_global_objs, sort_keys=True))

        return data

    def __deepcopy__(self, memo: t.Optional[t.Dict[int, t.Any]] = None) -> JinjaMacroRegistry:
        """Copy the registry by dumping it to a dictionary and parsing that dictionary again."""
        return JinjaMacroRegistry.parse_obj(self.dict())

    def _parse_macro(self, name: str, package: t.Optional[str]) -> Template:
        """Return the template for the given macro, parsing and caching it on first use."""
        cache_key = (package, name)
        if cache_key not in self._parser_cache:
            macro = self._get_macro(name, package)

            definition: nodes.Template = self._environment.parse(macro.definition)
            if _is_private_macro(name):
                # A workaround to expose private jinja macros.
                definition = self._to_non_private_macro_def(name, definition)

            self._parser_cache[cache_key] = self._environment.from_string(definition)
        return self._parser_cache[cache_key]

    @property
    def _environment(self) -> Environment:
        """The environment used for parsing macros, created with the builtin filters on first access."""
        if self.__environment is None:
            self.__environment = environment()
            self.__environment.filters.update(self._create_builtin_filters())
        return self.__environment

    def _trim_macros(
        self,
        names: t.Set[str],
        package: t.Optional[str] = None,
        visited: t.Optional[t.Dict[t.Optional[str], t.Set[str]]] = None,
    ) -> JinjaMacroRegistry:
        """Build a registry with the named macros of a package and, recursively, their dependencies.

        Args:
            names: The names of the macros to keep.
            package: The package to look the macros up in. The root macros are used when None.
            visited: The macro names already processed, grouped by package.

        Returns:
            A new registry containing the macros that were found.
        """
        if visited is None:
            visited = defaultdict(set)

        macros = self.packages.get(package, {}) if package is not None else self.root_macros
        trimmed_macros = {}

        dependencies: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)

        for name in names:
            # Unknown names are ignored; visited names are skipped so that each
            # macro is processed once per package.
            if name in macros and name not in visited[package]:
                macro = macros[name]
                trimmed_macros[name] = macro
                for dependency in macro.depends_on:
                    dependencies[dependency.package or package].add(dependency.name)
                visited[package].add(name)

        if package is not None:
            result = JinjaMacroRegistry(packages={package: trimmed_macros})
        else:
            result = JinjaMacroRegistry(root_macros=trimmed_macros)

        for upstream_package, upstream_names in dependencies.items():
            result = result.merge(
                self._trim_macros(upstream_names, upstream_package, visited=visited)
            )

        return result

    def _macro_exists(self, name: str, package: t.Optional[str]) -> bool:
        """Whether a macro with the given name exists in the package (or among the root macros)."""
        return (
            name in self.packages.get(package, {})
            if package is not None
            else name in self.root_macros
        )

    def _get_macro(self, name: str, package: t.Optional[str]) -> MacroInfo:
        """Look up a macro by name in the package (or among the root macros); raises KeyError if missing."""
        return self.packages[package][name] if package is not None else self.root_macros[name]

    def _to_non_private_macro_def(self, name: str, template: nodes.Template) -> nodes.Template:
        """Rename the macro definitions and name-based calls in the template to the non-private name."""
        for node in template.find_all((nodes.Macro, nodes.Call)):
            if isinstance(node, nodes.Macro):
                node.name = _non_private_name(name)
            elif isinstance(node, nodes.Call) and isinstance(node.node, nodes.Name):
                node.node.name = _non_private_name(name)

        return template

    def _create_builtin_globals(self, global_vars: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
        """Creates Jinja builtin globals using a factory function defined in the provided module."""
        # The engine adapter is removed from the given dictionary and handed to
        # the factory separately.
        engine_adapter = global_vars.pop("engine_adapter", None)
        global_vars = {**self.global_objs, **global_vars}
        if self.create_builtins_module is not None:
            module = importlib.import_module(self.create_builtins_module)
            if hasattr(module, "create_builtin_globals"):
                return module.create_builtin_globals(self, global_vars, engine_adapter)
        return global_vars

    def _create_builtin_filters(self) -> t.Dict[str, t.Any]:
        """Creates Jinja builtin filters using a factory function defined in the provided module."""
        if self.create_builtins_module is not None:
            module = importlib.import_module(self.create_builtins_module)
            if hasattr(module, "create_builtin_filters"):
                return module.create_builtin_filters()
        return {}

    class _MacroWrapper:
        """Callable that resolves a macro from the registry when it is invoked."""

        def __init__(
            self,
            name: str,
            package: t.Optional[str],
            registry: JinjaMacroRegistry,
            context: t.Dict[str, t.Any],
        ):
            self.name = name
            self.package = package
            self.context = context
            self.registry = registry

        def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
            context = self.context.copy()
            # Expose the macros of the wrapper's own package without a prefix.
            if self.package is not None and self.package in context:
                context.update(context[self.package])

            template = self.registry._parse_macro(self.name, self.package)
            macro_callable = getattr(
                template.make_module(vars=context), _non_private_name(self.name)
            )
            try:
                return macro_callable(*args, **kwargs)
            except MacroReturnVal as ret:
                # MacroReturnVal carries a value out of the macro call.
                return ret.value
Registry for Jinja macros.
Arguments:
- packages: The mapping from package name to a collection of macro definitions.
- root_macros: The collection of top-level macro definitions.
- global_objs: The global objects.
- create_builtins_module: The name of a module which defines the `create_builtins` factory function that will be used to construct builtin variables and functions.
- root_package_name: The name of the root package. If specified, root macros will be available as both `root_package_name.macro_name` and `macro_name`.
- top_level_packages: The list of top-level packages. Macros in these packages will be available as both `package_name.macro_name` and `macro_name`.
def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None:
    """Register the given macros in a package or in the root namespace.

    Args:
        macros: The macros to add, keyed by name.
        package: The package that owns the macros. The macros are added to the
            root namespace when no package is given.
    """
    if package is None:
        self.root_macros.update(macros)
        return

    # Create the package entry on demand and merge the new macros into it.
    self.packages.setdefault(package, {}).update(macros)
Adds macros to the target package.
Arguments:
- macros: Macros that should be added.
- package: The name of the package the given macros belong to. If not specified, the provided macros will be added to the root namespace.
def add_globals(self, globals: t.Dict[str, JinjaGlobalAttribute]) -> None:
    """Adds global objects to the registry.

    The "flat_graph" entry is left out unless a "graph" package is registered.
    The dictionary passed in by the caller is never modified.

    Args:
        globals: The global objects that should be added.
    """
    # Keep the registry lightweight when the graph is not needed
    if "graph" not in self.packages:
        # Filter into a new dictionary instead of popping from the caller's one.
        globals = {k: v for k, v in globals.items() if k != "flat_graph"}
    self.global_objs.update(**self._validate_global_objs(globals))
Adds global objects to the registry.
Arguments:
- globals: The global objects that should be added.
def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]:
    """Resolve a macro reference to a Python callable.

    Args:
        reference: The reference of the macro to resolve.
        kwargs: Passed through to `build_environment`.

    Returns:
        The callable registered for the reference in a freshly built
        environment, or None when there is none.
    """
    env_globals = self.build_environment(**kwargs).globals
    if reference.package is None:
        return env_globals.get(reference.name)  # type: ignore
    # Packaged macros live in a mapping stored under the package name.
    return env_globals.get(reference.package, {}).get(reference.name)  # type: ignore
Builds a Python callable for a macro with the given reference.
Arguments:
- reference: The macro reference.
Returns:
The macro as a Python callable or None if not found.
def build_environment(self, **kwargs: t.Any) -> Environment:
    """Builds a new Jinja environment based on this registry.

    Args:
        kwargs: Additional global variables, forwarded to the builtin globals factory.

    Returns:
        A new environment whose globals contain the builtin globals, the root
        macros, the package macros and a `render` helper, and whose filters
        include the filters of the registry's own environment.
    """

    # Shared with every macro wrapper created below; it is populated at the
    # end of this method and read by the wrappers when they are called.
    context: t.Dict[str, t.Any] = {}

    root_macros = {
        name: self._MacroWrapper(name, None, self, context)
        for name, macro in self.root_macros.items()
    }

    package_macros: t.Dict[str, t.Any] = defaultdict(AttributeDict)
    for package_name, macros in self.packages.items():
        for macro_name, macro in macros.items():
            macro_wrapper = self._MacroWrapper(macro_name, package_name, self, context)
            package_macros[package_name][macro_name] = macro_wrapper
            # Top-level package macros are also exposed without the package
            # prefix unless a root macro with the same name exists.
            if macro.is_top_level and macro_name not in root_macros:
                root_macros[macro_name] = macro_wrapper

    top_level_packages = self.top_level_packages.copy()

    if self.root_package_name is not None:
        # Root macros are additionally reachable through the root package name.
        package_macros[self.root_package_name].update(root_macros)
        top_level_packages.append(self.root_package_name)

    env = environment()

    builtin_globals = self._create_builtin_globals(kwargs)
    for top_level_package_name in top_level_packages:
        # Make sure that the top-level package doesn't fully override the same builtin package.
        package_macros[top_level_package_name] = AttributeDict(
            {
                **(builtin_globals.pop(top_level_package_name, None) or {}),
                **(package_macros.get(top_level_package_name) or {}),
            }
        )
        root_macros.update(package_macros[top_level_package_name])

    # Later updates win on name clashes: packages over root macros over builtins.
    context.update(builtin_globals)
    context.update(root_macros)
    context.update(package_macros)
    context["render"] = lambda input: env.from_string(input).render()

    env.globals.update(context)
    env.filters.update(self._environment.filters)
    return env
Builds a new Jinja environment based on this registry.
def trim(
    self, dependencies: t.Iterable[MacroReference], package: t.Optional[str] = None
) -> JinjaMacroRegistry:
    """Build a registry restricted to the referenced macros and their transitive dependencies.

    Args:
        dependencies: References to the macros that must be kept.
        package: The package in whose context the trimming is performed.
            References without a package are resolved against it.

    Returns:
        A new registry that is marked as trimmed.
    """
    names_by_package: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)
    for reference in dependencies:
        names_by_package[reference.package or package].add(reference.name)

    top_level = [*self.top_level_packages]
    if package is not None:
        top_level.append(package)

    trimmed = JinjaMacroRegistry(
        global_objs=self.global_objs.copy(),
        create_builtins_module=self.create_builtins_module,
        root_package_name=self.root_package_name,
        top_level_packages=top_level,
    )
    for owner, macro_names in names_by_package.items():
        trimmed = trimmed.merge(self._trim_macros(macro_names, owner))

    trimmed._trimmed = True
    return trimmed
Trims the registry by keeping only macros with given references and their transitive dependencies.
Arguments:
- dependencies: References to macros that should be kept.
- package: The name of the package in the context of which the trimming should be performed.
Returns:
A new trimmed registry.
def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry:
    """Combine this registry with another one into a new registry.

    Entries of `other` take precedence when macros or global objects clash.
    For the builtins module and the root package name, the value of this
    registry is used unless it is falsy.

    Args:
        other: The registry to merge with.

    Returns:
        A new registry holding the contents of both instances.
    """
    merged_packages = {
        name: {**self.packages.get(name, {}), **other.packages.get(name, {})}
        for name in {*self.packages, *other.packages}
    }

    return JinjaMacroRegistry(
        packages=merged_packages,
        root_macros={**self.root_macros, **other.root_macros},
        global_objs={**self.global_objs, **other.global_objs},
        create_builtins_module=self.create_builtins_module or other.create_builtins_module,
        root_package_name=self.root_package_name or other.root_package_name,
        top_level_packages=[*self.top_level_packages, *other.top_level_packages],
    )
Returns a copy of the registry which contains macros from both this and other instances.
Arguments:
- other: The other registry instance.
Returns:
A new merged registry.
def to_expressions(self) -> t.List[Expression]:
    """Return the registry as a list of expressions.

    The list starts with a PythonCode expression that assigns the "refs",
    "sources" and "vars" global objects (only when at least one is present),
    followed by the root macro definitions sorted by name and the package
    macro definitions sorted by package and macro name.
    """
    expressions: t.List[Expression] = []

    exported_globals = {
        key: value
        for key, value in self.global_objs.items()
        if key in ("refs", "sources", "vars")
    }
    if exported_globals:
        assignments = []
        for key, value in sort_dict_recursive(exported_globals).items():
            if isinstance(value, str):
                assignments.append(f"{key} = '{value}'")
            else:
                assignments.append(f"{key} = {value}")
        expressions.append(d.PythonCode(expressions=assignments))

    namespaces = [self.root_macros]
    namespaces.extend(self.packages[name] for name in sorted(self.packages))
    for namespace in namespaces:
        for macro_name in sorted(namespace):
            expressions.append(d.jinja_statement(namespace[macro_name].definition))

    return expressions
@property
def data_hash_values(self) -> t.List[str]:
    """The values that describe the registry's data.

    Contains the name and definition of every root macro and package macro
    (sorted by name, packages sorted by package name), followed by the JSON
    dump of the "refs", "sources" and "vars" global objects.
    """
    values: t.List[str] = []

    namespaces = [self.root_macros]
    namespaces.extend(self.packages[name] for name in sorted(self.packages))
    for namespace in namespaces:
        for macro_name in sorted(namespace):
            values.extend((macro_name, namespace[macro_name].definition))

    hashed_globals = {
        key: self.global_objs[key]
        for key in ("refs", "sources", "vars")
        if key in self.global_objs
    }
    values.append(json.dumps(hashed_globals, sort_keys=True))

    return values
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """Initialise the private attributes of a model instance.

    Written as a free function that behaves like a BaseModel method. The
    context argument is accepted because pydantic-core passes it on the call;
    it is not used here.

    Args:
        self: The BaseModel instance.
        context: The context.
    """
    # Nothing to do when the private attribute storage has already been set up.
    if getattr(self, '__pydantic_private__', None) is not None:
        return

    defaults = (
        (name, private_attr.get_default())
        for name, private_attr in self.__private_attributes__.items()
    )
    # Attributes without a default are left out of the storage.
    private_values = {
        name: default for name, default in defaults if default is not PydanticUndefined
    }
    object_setattr(self, '__pydantic_private__', private_values)
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Arguments:
- self: The BaseModel instance.
- context: The context.
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
def create_builtin_globals(
    jinja_macros: JinjaMacroRegistry, global_vars: t.Dict[str, t.Any], *args: t.Any, **kwargs: t.Any
) -> t.Dict[str, t.Any]:
    """Create the builtin globals from the given global variables.

    The gateway entry and the variable collections are removed from
    `global_vars` (which is modified in place) and replaced with accessor
    callables in the returned dictionary.

    Args:
        jinja_macros: The macro registry (not used by this factory).
        global_vars: The global variables to start from.

    Returns:
        The remaining global variables plus the variable, gateway and blueprint
        variable accessors.
    """
    global_vars.pop(c.GATEWAY, None)
    variables = global_vars.pop(c.SQLMESH_VARS, None) or {}
    blueprint_variables = global_vars.pop(c.SQLMESH_BLUEPRINT_VARS, None) or {}

    builtin_globals = dict(global_vars)
    builtin_globals[c.VAR] = create_var(variables)
    # The gateway is looked up in the variables at call time.
    builtin_globals[c.GATEWAY] = lambda: variables.get(c.GATEWAY, None)
    builtin_globals[c.BLUEPRINT_VAR] = create_var(blueprint_variables)
    return builtin_globals
def make_jinja_registry(
    jinja_macros: JinjaMacroRegistry, package_name: str, jinja_references: t.Set[MacroReference]
) -> JinjaMacroRegistry:
    """
    Creates a Jinja macro registry for a specific package.

    A copy of the given registry is made, the macros of the specified package
    become the root macros of that copy, and the copy is then trimmed to the
    provided macro references.

    Args:
        jinja_macros: The original Jinja macro registry containing all macros.
        package_name: The name of the package for which to create the registry.
        jinja_references: A set of macro references to retain in the new registry.

    Returns:
        The trimmed registry.
    """
    package_registry = jinja_macros.copy()
    package_root_macros = package_registry.packages.get(package_name)
    # A missing (or empty) package results in an empty root namespace.
    package_registry.root_macros = package_root_macros if package_root_macros else {}
    return package_registry.trim(jinja_references)
Creates a Jinja macro registry for a specific package.
This function takes an existing Jinja macro registry and returns a new registry that includes only the macros associated with the specified package and trims the registry to include only the macros referenced in the provided set of macro references.
Arguments:
- jinja_macros: The original Jinja macro registry containing all macros.
- package_name: The name of the package for which to create the registry.
- jinja_references: A set of macro references to retain in the new registry.
Returns:
A new JinjaMacroRegistry containing only the macros for the specified package and the referenced macros.
def extract_error_details(ex: Exception) -> str:
    """Extracts a readable message from a Jinja2 error, to include missing name and macro.

    For an UndefinedError the first single-quoted word of the error message is
    reported as the undefined name and, when the traceback of the exception
    currently being handled contains a macro invocation frame, the name of that
    macro is appended.

    Args:
        ex: The exception to describe.

    Returns:
        The extracted details, or `str(ex)` when no details could be extracted.
    """

    error_details = ""
    if isinstance(ex, UndefinedError):
        if match := re.search(r"'(\w+)'", str(ex)):
            error_details += f"\nUndefined macro/variable: '{match.group(1)}'"
        try:
            _, _, exc_traceback = exc_info()
            for frame, _ in walk_tb(exc_traceback):
                if frame.f_code.co_name == "_invoke":
                    macro = frame.f_locals.get("self")
                    if isinstance(macro, Macro):
                        error_details += f" in macro: '{macro.name}'\n"
                        break
        except Exception:
            # Frame analysis is best effort: fall back to the generic error
            # message if it fails. Only Exception is caught so that
            # KeyboardInterrupt and SystemExit still propagate.
            pass
    return error_details or str(ex)
Extracts a readable message from a Jinja2 error, to include missing name and macro.