sqlmesh.utils.jinja
1from __future__ import annotations 2 3import importlib 4import json 5import re 6import typing as t 7from collections import defaultdict 8from enum import Enum 9 10from jinja2 import Environment, Template, nodes 11from sqlglot import Dialect, Expression, Parser, TokenType 12 13from sqlmesh.core import constants as c 14from sqlmesh.core import dialect as d 15from sqlmesh.utils import AttributeDict 16from sqlmesh.utils.pydantic import PydanticModel, field_serializer, field_validator 17 18SQLMESH_JINJA_PACKAGE = "sqlmesh.utils.jinja" 19 20 21def environment(**kwargs: t.Any) -> Environment: 22 extensions = kwargs.pop("extensions", []) 23 extensions.append("jinja2.ext.do") 24 extensions.append("jinja2.ext.loopcontrols") 25 return Environment(extensions=extensions, **kwargs) 26 27 28ENVIRONMENT = environment() 29 30 31class MacroReference(PydanticModel, frozen=True): 32 package: t.Optional[str] = None 33 name: str 34 35 @property 36 def reference(self) -> str: 37 if self.package is None: 38 return self.name 39 return ".".join((self.package, self.name)) 40 41 def __str__(self) -> str: 42 return self.reference 43 44 45class MacroInfo(PydanticModel): 46 """Class to hold macro and its calls""" 47 48 definition: str 49 depends_on: t.List[MacroReference] 50 51 52class MacroReturnVal(Exception): 53 def __init__(self, val: t.Any): 54 self.value = val 55 56 57class MacroExtractor(Parser): 58 def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]: 59 """Extract a dictionary of macro definitions from a jinja string. 60 61 Args: 62 jinja: The jinja string to extract from. 63 dialect: The dialect of SQL. 64 65 Returns: 66 A dictionary of macro name to macro definition. 
67 """ 68 self.reset() 69 self.sql = jinja 70 self._tokens = Dialect.get_or_raise(dialect).tokenizer.tokenize(jinja) 71 self._index = -1 72 self._advance() 73 74 macros: t.Dict[str, MacroInfo] = {} 75 76 while self._curr: 77 if self._curr.token_type == TokenType.BLOCK_START: 78 macro_start = self._curr 79 elif self._tag == "MACRO" and self._next: 80 name = self._next.text 81 while self._curr and self._curr.token_type != TokenType.BLOCK_END: 82 self._advance() 83 84 while self._curr and self._tag != "ENDMACRO": 85 self._advance() 86 87 macro_str = self._find_sql(macro_start, self._next) 88 macros[name] = MacroInfo( 89 definition=macro_str, 90 depends_on=list(extract_macro_references_and_variables(macro_str)[0]), 91 ) 92 93 self._advance() 94 95 return macros 96 97 def _advance(self, times: int = 1) -> None: 98 super()._advance(times) 99 self._tag = ( 100 self._curr.text.upper() 101 if self._curr and self._prev and self._prev.token_type == TokenType.BLOCK_START 102 else "" 103 ) 104 105 106def call_name(node: nodes.Expr) -> t.Tuple[str, ...]: 107 if isinstance(node, nodes.Name): 108 return (node.name,) 109 if isinstance(node, nodes.Const): 110 return (f"'{node.value}'",) 111 if isinstance(node, nodes.Getattr): 112 return call_name(node.node) + (node.attr,) 113 if isinstance(node, (nodes.Getitem, nodes.Call)): 114 return call_name(node.node) 115 return () 116 117 118def render_jinja(query: str, methods: t.Optional[t.Dict[str, t.Any]] = None) -> str: 119 return ENVIRONMENT.from_string(query).render(methods or {}) 120 121 122def find_call_names( 123 node: nodes.Node, vars_in_scope: t.Set[str] 124) -> t.Iterator[t.Tuple[t.Tuple[str, ...], nodes.Call]]: 125 vars_in_scope = vars_in_scope.copy() 126 for child_node in node.iter_child_nodes(): 127 if "target" in child_node.fields: 128 target = getattr(child_node, "target") 129 if isinstance(target, nodes.Name): 130 vars_in_scope.add(target.name) 131 elif isinstance(target, nodes.Tuple): 132 for item in target.items: 133 if 
isinstance(item, nodes.Name): 134 vars_in_scope.add(item.name) 135 elif isinstance(child_node, nodes.Macro): 136 for arg in child_node.args: 137 vars_in_scope.add(arg.name) 138 elif isinstance(child_node, nodes.Call): 139 name = call_name(child_node) 140 if name[0][0] != "'" and name[0] not in vars_in_scope: 141 yield (name, child_node) 142 yield from find_call_names(child_node, vars_in_scope) 143 144 145def extract_call_names(jinja_str: str) -> t.List[t.Tuple[t.Tuple[str, ...], nodes.Call]]: 146 return list(find_call_names(ENVIRONMENT.parse(jinja_str), set())) 147 148 149def extract_macro_references_and_variables( 150 *jinja_strs: str, 151) -> t.Tuple[t.Set[MacroReference], t.Set[str]]: 152 macro_references = set() 153 variables = set() 154 for jinja_str in jinja_strs: 155 for call_name, node in extract_call_names(jinja_str): 156 if call_name[0] == c.VAR: 157 args = [jinja_call_arg_name(arg) for arg in node.args] 158 if args and args[0]: 159 variables.add(args[0].lower()) 160 elif call_name[0] == c.GATEWAY: 161 variables.add(c.GATEWAY) 162 elif len(call_name) == 1: 163 macro_references.add(MacroReference(name=call_name[0])) 164 elif len(call_name) == 2: 165 macro_references.add(MacroReference(package=call_name[0], name=call_name[1])) 166 return macro_references, variables 167 168 169JinjaGlobalAttribute = t.Union[str, int, float, bool, AttributeDict] 170 171 172class JinjaMacroRegistry(PydanticModel): 173 """Registry for Jinja macros. 174 175 Args: 176 packages: The mapping from package name to a collection of macro definitions. 177 root_macros: The collection of top-level macro definitions. 178 global_objs: The global objects. 179 create_builtins_module: The name of a module which defines the `create_builtins` factory 180 function that will be used to construct builtin variables and functions. 181 root_package_name: The name of the root package. If specified root macros will be available 182 as both `root_package_name.macro_name` and `macro_name`. 
183 top_level_packages: The list of top-level packages. Macros in this packages will be available 184 as both `package_name.macro_name` and `macro_name`. 185 """ 186 187 packages: t.Dict[str, t.Dict[str, MacroInfo]] = {} 188 root_macros: t.Dict[str, MacroInfo] = {} 189 global_objs: t.Dict[str, JinjaGlobalAttribute] = {} 190 create_builtins_module: t.Optional[str] = SQLMESH_JINJA_PACKAGE 191 root_package_name: t.Optional[str] = None 192 top_level_packages: t.List[str] = [] 193 194 _parser_cache: t.Dict[t.Tuple[t.Optional[str], str], Template] = {} 195 __environment: t.Optional[Environment] = None 196 197 @field_validator("global_objs", mode="before") 198 @classmethod 199 def _validate_global_objs(cls, value: t.Any) -> t.Any: 200 def _normalize(val: t.Any) -> t.Any: 201 if isinstance(val, dict): 202 return AttributeDict({k: _normalize(v) for k, v in val.items()}) 203 if isinstance(val, list): 204 return [_normalize(v) for v in val] 205 if isinstance(val, set): 206 return [_normalize(v) for v in sorted(val)] 207 if isinstance(val, Enum): 208 return val.value 209 return val 210 211 return _normalize(value) 212 213 @field_serializer("global_objs") 214 def _serialize_attribute_dict( 215 self, value: t.Dict[str, JinjaGlobalAttribute] 216 ) -> t.Dict[str, t.Any]: 217 # NOTE: This is called only when used with Pydantic V2. 218 def _convert( 219 val: t.Union[t.Dict[str, JinjaGlobalAttribute], t.Dict[str, t.Any]], 220 ) -> t.Dict[str, t.Any]: 221 return {k: _convert(v) if isinstance(v, AttributeDict) else v for k, v in val.items()} 222 223 return _convert(value) 224 225 def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None: 226 """Adds macros to the target package. 227 228 Args: 229 macros: Macros that should be added. 230 package: The name of the package the given macros belong to. If not specified, the provided 231 macros will be added to the root namespace. 
232 """ 233 234 if package is not None: 235 package_macros = self.packages.get(package, {}) 236 package_macros.update(macros) 237 self.packages[package] = package_macros 238 else: 239 self.root_macros.update(macros) 240 241 def add_globals(self, globals: t.Dict[str, JinjaGlobalAttribute]) -> None: 242 """Adds global objects to the registry. 243 244 Args: 245 globals: The global objects that should be added. 246 """ 247 self.global_objs.update(**self._validate_global_objs(globals)) 248 249 def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]: 250 """Builds a Python callable for a macro with the given reference. 251 252 Args: 253 reference: The macro reference. 254 Returns: 255 The macro as a Python callable or None if not found. 256 """ 257 env: Environment = self.build_environment(**kwargs) 258 if reference.package is not None: 259 package = env.globals.get(reference.package, {}) 260 return package.get(reference.name) # type: ignore 261 return env.globals.get(reference.name) # type: ignore 262 263 def build_environment(self, **kwargs: t.Any) -> Environment: 264 """Builds a new Jinja environment based on this registry.""" 265 266 context: t.Dict[str, t.Any] = {} 267 268 root_macros = { 269 name: self._MacroWrapper(name, None, self, context) 270 for name, macro in self.root_macros.items() 271 } 272 273 package_macros: t.Dict[str, t.Any] = defaultdict(AttributeDict) 274 for package_name, macros in self.packages.items(): 275 for macro_name in macros: 276 package_macros[package_name][macro_name] = self._MacroWrapper( 277 macro_name, package_name, self, context 278 ) 279 280 if self.root_package_name is not None: 281 package_macros[self.root_package_name].update(root_macros) 282 283 env = environment() 284 285 builtin_globals = self._create_builtin_globals(kwargs) 286 for top_level_package_name in self.top_level_packages: 287 # Make sure that the top-level package doesn't fully override the same builtin package. 
288 package_macros[top_level_package_name] = AttributeDict( 289 { 290 **(builtin_globals.pop(top_level_package_name, None) or {}), 291 **(package_macros.get(top_level_package_name) or {}), 292 } 293 ) 294 root_macros.update(package_macros[top_level_package_name]) 295 296 context.update(builtin_globals) 297 context.update(root_macros) 298 context.update(package_macros) 299 300 env.globals.update(context) 301 env.filters.update(self._environment.filters) 302 return env 303 304 def trim( 305 self, dependencies: t.Iterable[MacroReference], package: t.Optional[str] = None 306 ) -> JinjaMacroRegistry: 307 """Trims the registry by keeping only macros with given references and their transitive dependencies. 308 309 Args: 310 dependencies: References to macros that should be kept. 311 package: The name of the package in the context of which the trimming should be performed. 312 313 Returns: 314 A new trimmed registry. 315 """ 316 dependencies_by_package: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set) 317 for dep in dependencies: 318 dependencies_by_package[dep.package or package].add(dep.name) 319 320 top_level_packages = self.top_level_packages.copy() 321 if package is not None: 322 top_level_packages.append(package) 323 324 result = JinjaMacroRegistry( 325 global_objs=self.global_objs.copy(), 326 create_builtins_module=self.create_builtins_module, 327 root_package_name=self.root_package_name, 328 top_level_packages=top_level_packages, 329 ) 330 for package, names in dependencies_by_package.items(): 331 result = result.merge(self._trim_macros(names, package)) 332 333 return result 334 335 def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry: 336 """Returns a copy of the registry which contains macros from both this and `other` instances. 337 338 Args: 339 other: The other registry instance. 340 341 Returns: 342 A new merged registry. 
343 """ 344 345 root_macros = { 346 **self.root_macros, 347 **other.root_macros, 348 } 349 350 packages = {} 351 for package in {*self.packages, *other.packages}: 352 packages[package] = { 353 **self.packages.get(package, {}), 354 **other.packages.get(package, {}), 355 } 356 357 global_objs = { 358 **self.global_objs, 359 **other.global_objs, 360 } 361 362 return JinjaMacroRegistry( 363 packages=packages, 364 root_macros=root_macros, 365 global_objs=global_objs, 366 create_builtins_module=self.create_builtins_module or other.create_builtins_module, 367 root_package_name=self.root_package_name or other.root_package_name, 368 top_level_packages=[*self.top_level_packages, *other.top_level_packages], 369 ) 370 371 def to_expressions(self) -> t.List[Expression]: 372 output: t.List[Expression] = [] 373 374 filtered_objs = { 375 k: v for k, v in self.global_objs.items() if k in ("refs", "sources", "vars") 376 } 377 if filtered_objs: 378 output.append( 379 d.PythonCode( 380 expressions=[ 381 f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" 382 for k, v in sorted(filtered_objs.items()) 383 ] 384 ) 385 ) 386 387 for macro_name, macro_info in sorted(self.root_macros.items()): 388 output.append(d.jinja_statement(macro_info.definition)) 389 390 for _, package in sorted(self.packages.items()): 391 for macro_name, macro_info in sorted(package.items()): 392 output.append(d.jinja_statement(macro_info.definition)) 393 394 return output 395 396 @property 397 def data_hash_values(self) -> t.List[str]: 398 data = [] 399 400 for macro_name, macro in sorted(self.root_macros.items()): 401 data.append(macro_name) 402 data.append(macro.definition) 403 404 for _, package in sorted(self.packages.items()): 405 for macro_name, macro in sorted(package.items()): 406 data.append(macro_name) 407 data.append(macro.definition) 408 409 trimmed_global_objs = { 410 k: self.global_objs[k] for k in ("refs", "sources", "vars") if k in self.global_objs 411 } 412 
data.append(json.dumps(trimmed_global_objs, sort_keys=True)) 413 414 return data 415 416 def __deepcopy__(self, memo: t.Optional[t.Dict[int, t.Any]] = None) -> JinjaMacroRegistry: 417 return JinjaMacroRegistry.parse_obj(self.dict()) 418 419 def _parse_macro(self, name: str, package: t.Optional[str]) -> Template: 420 cache_key = (package, name) 421 if cache_key not in self._parser_cache: 422 macro = self._get_macro(name, package) 423 424 definition: nodes.Template = self._environment.parse(macro.definition) 425 if _is_private_macro(name): 426 # A workaround to expose private jinja macros. 427 definition = self._to_non_private_macro_def(name, definition) 428 429 self._parser_cache[cache_key] = self._environment.from_string(definition) 430 return self._parser_cache[cache_key] 431 432 @property 433 def _environment(self) -> Environment: 434 if self.__environment is None: 435 self.__environment = environment() 436 self.__environment.filters.update(self._create_builtin_filters()) 437 return self.__environment 438 439 def _trim_macros( 440 self, 441 names: t.Set[str], 442 package: t.Optional[str] = None, 443 visited: t.Optional[t.Dict[t.Optional[str], t.Set[str]]] = None, 444 ) -> JinjaMacroRegistry: 445 if visited is None: 446 visited = defaultdict(set) 447 448 macros = self.packages.get(package, {}) if package is not None else self.root_macros 449 trimmed_macros = {} 450 451 dependencies: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set) 452 453 for name in names: 454 if name in macros and name not in visited[package]: 455 macro = macros[name] 456 trimmed_macros[name] = macro 457 for dependency in macro.depends_on: 458 dependencies[dependency.package or package].add(dependency.name) 459 visited[package].add(name) 460 461 if package is not None: 462 result = JinjaMacroRegistry(packages={package: trimmed_macros}) 463 else: 464 result = JinjaMacroRegistry(root_macros=trimmed_macros) 465 466 for upstream_package, upstream_names in dependencies.items(): 467 result = 
result.merge( 468 self._trim_macros(upstream_names, upstream_package, visited=visited) 469 ) 470 471 return result 472 473 def _macro_exists(self, name: str, package: t.Optional[str]) -> bool: 474 return ( 475 name in self.packages.get(package, {}) 476 if package is not None 477 else name in self.root_macros 478 ) 479 480 def _get_macro(self, name: str, package: t.Optional[str]) -> MacroInfo: 481 return self.packages[package][name] if package is not None else self.root_macros[name] 482 483 def _to_non_private_macro_def(self, name: str, template: nodes.Template) -> nodes.Template: 484 for node in template.find_all((nodes.Macro, nodes.Call)): 485 if isinstance(node, nodes.Macro): 486 node.name = _non_private_name(name) 487 elif isinstance(node, nodes.Call) and isinstance(node.node, nodes.Name): 488 node.node.name = _non_private_name(name) 489 490 return template 491 492 def _create_builtin_globals(self, global_vars: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: 493 """Creates Jinja builtin globals using a factory function defined in the provided module.""" 494 engine_adapter = global_vars.pop("engine_adapter", None) 495 global_vars = {**self.global_objs, **global_vars} 496 if self.create_builtins_module is not None: 497 module = importlib.import_module(self.create_builtins_module) 498 if hasattr(module, "create_builtin_globals"): 499 return module.create_builtin_globals(self, global_vars, engine_adapter) 500 return global_vars 501 502 def _create_builtin_filters(self) -> t.Dict[str, t.Any]: 503 """Creates Jinja builtin filters using a factory function defined in the provided module.""" 504 if self.create_builtins_module is not None: 505 module = importlib.import_module(self.create_builtins_module) 506 if hasattr(module, "create_builtin_filters"): 507 return module.create_builtin_filters() 508 return {} 509 510 class _MacroWrapper: 511 def __init__( 512 self, 513 name: str, 514 package: t.Optional[str], 515 registry: JinjaMacroRegistry, 516 context: t.Dict[str, t.Any], 
517 ): 518 self.name = name 519 self.package = package 520 self.context = context 521 self.registry = registry 522 523 def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any: 524 context = self.context.copy() 525 if self.package is not None and self.package in context: 526 context.update(context[self.package]) 527 528 template = self.registry._parse_macro(self.name, self.package) 529 macro_callable = getattr( 530 template.make_module(vars=context), _non_private_name(self.name) 531 ) 532 try: 533 return macro_callable(*args, **kwargs) 534 except MacroReturnVal as ret: 535 return ret.value 536 537 538def _is_private_macro(name: str) -> bool: 539 return name.startswith("_") 540 541 542def _non_private_name(name: str) -> str: 543 return name.lstrip("_") 544 545 546JINJA_REGEX = re.compile(r"({{|{%)") 547 548 549def has_jinja(value: str) -> bool: 550 return JINJA_REGEX.search(value) is not None 551 552 553def jinja_call_arg_name(node: nodes.Node) -> str: 554 if isinstance(node, nodes.Const): 555 return node.value 556 return "" 557 558 559def create_var(variables: t.Dict[str, t.Any]) -> t.Callable: 560 def _var(var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: 561 return variables.get(var_name.lower(), default) 562 563 return _var 564 565 566def create_builtin_globals( 567 jinja_macros: JinjaMacroRegistry, global_vars: t.Dict[str, t.Any], *args: t.Any, **kwargs: t.Any 568) -> t.Dict[str, t.Any]: 569 global_vars.pop(c.GATEWAY, None) 570 variables = global_vars.pop(c.SQLMESH_VARS, None) or {} 571 return { 572 c.VAR: create_var(variables), 573 c.GATEWAY: lambda: variables.get(c.GATEWAY, None), 574 **global_vars, 575 }
class MacroReference(PydanticModel, frozen=True):
    """An immutable pointer to a macro, optionally qualified by the package it lives in."""

    # Owning package; None means the macro lives in the root namespace.
    package: t.Optional[str] = None
    # Bare macro name.
    name: str

    @property
    def reference(self) -> str:
        """Dotted `package.name` form, or just `name` when the reference is unqualified."""
        parts = (self.name,) if self.package is None else (self.package, self.name)
        return ".".join(parts)

    def __str__(self) -> str:
        return self.reference
Usage docs: https://docs.pydantic.dev/2.7/concepts/models/
A base class for creating Pydantic models.
Attributes:
- __class_vars__: The names of classvars defined on the model.
- __private_attributes__: Metadata about the private attributes of the model.
- __signature__: The signature for instantiating the model.
- __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__: The pydantic-core schema used to build the SchemaValidator and SchemaSerializer.
- __pydantic_custom_init__: Whether the model has a custom `__init__` function.
- __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.
- __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to `__args__`, `__origin__`, `__parameters__` in typing-module generics. May eventually be replaced by these.
- __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__: The name of the post-init method for the model, if defined.
- __pydantic_root_model__: Whether the model is a `RootModel`.
- __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
- __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
- __pydantic_extra__: An instance attribute with the values of extra fields from validation when `model_config['extra'] == 'allow'`.
- __pydantic_fields_set__: An instance attribute with the names of fields explicitly set.
- __pydantic_private__: Instance attribute with the values of private attributes set on the model instance.
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class MacroInfo(PydanticModel):
    """A macro's source text together with the macros it calls."""

    # Jinja source of the macro definition.
    definition: str
    # References to the macros invoked from within `definition`.
    depends_on: t.List[MacroReference]
Class to hold macro and its calls
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
Common base class for all non-exit exceptions.
Inherited Members
- builtins.BaseException
- with_traceback
class MacroExtractor(Parser):
    """Pulls `{% macro %}` ... `{% endmacro %}` definitions out of a Jinja string.

    A sqlglot dialect tokenizer is reused as the lexer. `_tag` tracks the keyword that
    immediately follows each `{%` token.
    """

    def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
        """Extract a dictionary of macro definitions from a jinja string.

        Args:
            jinja: The jinja string to extract from.
            dialect: The dialect of SQL.

        Returns:
            A dictionary of macro name to macro definition.
        """
        self.reset()
        self.sql = jinja
        tokenizer = Dialect.get_or_raise(dialect).tokenizer
        self._tokens = tokenizer.tokenize(jinja)
        self._index = -1
        self._advance()

        found: t.Dict[str, MacroInfo] = {}

        while self._curr:
            if self._curr.token_type == TokenType.BLOCK_START:
                # Remember where the latest `{%` block opens; if it starts a macro, the
                # definition text is sliced from this token.
                opening = self._curr
            elif self._tag == "MACRO" and self._next:
                macro_name = self._next.text

                # Walk to the `%}` that closes the `{% macro ... %}` header ...
                while self._curr and self._curr.token_type != TokenType.BLOCK_END:
                    self._advance()
                # ... then on to the `endmacro` keyword.
                while self._curr and self._tag != "ENDMACRO":
                    self._advance()

                # For well-formed input `self._next` is the `%}` closing `{% endmacro %}`.
                definition = self._find_sql(opening, self._next)
                references, _ = extract_macro_references_and_variables(definition)
                found[macro_name] = MacroInfo(definition=definition, depends_on=list(references))

            self._advance()

        return found

    def _advance(self, times: int = 1) -> None:
        super()._advance(times)
        # Record the upper-cased keyword right after a `{%`, or "" anywhere else.
        current, previous = self._curr, self._prev
        if current and previous and previous.token_type == TokenType.BLOCK_START:
            self._tag = current.text.upper()
        else:
            self._tag = ""
Parser consumes a list of tokens produced by the Tokenizer and produces a parsed syntax tree.
Arguments:
- error_level: The desired error level. Default: ErrorLevel.IMMEDIATE
- error_message_context: The amount of context to capture from a query string when displaying the error message (in number of characters). Default: 100
- max_errors: Maximum number of error messages to include in a raised ParseError. This is only relevant if error_level is ErrorLevel.RAISE. Default: 3
def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
    """Extract a dictionary of macro definitions from a jinja string.

    Args:
        jinja: The jinja string to extract from.
        dialect: The dialect of SQL.

    Returns:
        A dictionary of macro name to macro definition.
    """
    self.reset()
    self.sql = jinja
    tokenizer = Dialect.get_or_raise(dialect).tokenizer
    self._tokens = tokenizer.tokenize(jinja)
    self._index = -1
    self._advance()

    found: t.Dict[str, MacroInfo] = {}

    while self._curr:
        if self._curr.token_type == TokenType.BLOCK_START:
            # Remember where the latest `{%` block opens; a macro's text starts here.
            opening = self._curr
        elif self._tag == "MACRO" and self._next:
            macro_name = self._next.text

            # Walk to the `%}` closing the macro header, then to the `endmacro` keyword.
            while self._curr and self._curr.token_type != TokenType.BLOCK_END:
                self._advance()
            while self._curr and self._tag != "ENDMACRO":
                self._advance()

            # For well-formed input `self._next` is the `%}` closing `{% endmacro %}`.
            definition = self._find_sql(opening, self._next)
            references, _ = extract_macro_references_and_variables(definition)
            found[macro_name] = MacroInfo(definition=definition, depends_on=list(references))

        self._advance()

    return found
Extract a dictionary of macro definitions from a jinja string.
Arguments:
- jinja: The jinja string to extract from.
- dialect: The dialect of SQL.
Returns:
A dictionary of macro name to macro definition.
Inherited Members
- sqlglot.parser.Parser
- Parser
- reset
- parse
- parse_into
- check_errors
- raise_error
- expression
- validate_expression
def call_name(node: nodes.Expr) -> t.Tuple[str, ...]:
    """Return the dotted name parts of a callee expression.

    The expression is walked from the outside in:
    - `Getitem`/`Call` wrappers are skipped.
    - Each `Getattr` contributes its attribute.
    - The chain ends at a `Name` (its name) or a `Const` (its value in single quotes).
    Any other base contributes nothing, leaving only the collected attributes.
    """
    attrs: t.List[str] = []
    current = node
    while True:
        if isinstance(current, (nodes.Getitem, nodes.Call)):
            current = current.node
        elif isinstance(current, nodes.Getattr):
            attrs.append(current.attr)
            current = current.node
        elif isinstance(current, nodes.Name):
            return (current.name, *reversed(attrs))
        elif isinstance(current, nodes.Const):
            return (f"'{current.value}'", *reversed(attrs))
        else:
            return tuple(reversed(attrs))
def find_call_names(
    node: nodes.Node, vars_in_scope: t.Set[str]
) -> t.Iterator[t.Tuple[t.Tuple[str, ...], nodes.Call]]:
    """Yield `(name_parts, call_node)` for calls under `node` whose callee isn't a local name.

    Names bound by nodes with a `target` field (e.g. `for` loops, `set`) and macro argument
    names are treated as in scope for the binding node and everything visited after it at
    the same level. Calls on quoted constants and calls whose callee has no name parts are
    skipped.

    Args:
        node: The Jinja AST node to search.
        vars_in_scope: Names already bound locally; not mutated.
    """
    vars_in_scope = vars_in_scope.copy()
    for child_node in node.iter_child_nodes():
        if "target" in child_node.fields:
            target = getattr(child_node, "target")
            if isinstance(target, nodes.Name):
                vars_in_scope.add(target.name)
            elif isinstance(target, nodes.Tuple):
                for item in target.items:
                    if isinstance(item, nodes.Name):
                        vars_in_scope.add(item.name)
        elif isinstance(child_node, nodes.Macro):
            for arg in child_node.args:
                vars_in_scope.add(arg.name)
        elif isinstance(child_node, nodes.Call):
            name = call_name(child_node)
            # `call_name` returns () for unnamed callees such as `(x | f)()`; indexing it
            # unguarded raised IndexError.
            if name and not name[0].startswith("'") and name[0] not in vars_in_scope:
                yield (name, child_node)
        yield from find_call_names(child_node, vars_in_scope)
def extract_macro_references_and_variables(
    *jinja_strs: str,
) -> t.Tuple[t.Set[MacroReference], t.Set[str]]:
    """Collect the macro references and variable names used by the given Jinja strings.

    Args:
        *jinja_strs: Jinja templates to inspect.

    Returns:
        A `(macro_references, variables)` tuple.
        - `var("name")` adds the lower-cased literal name to `variables`.
        - `gateway()` adds the gateway variable.
        - One-part calls become root references.
        - `package.macro()` calls become package-qualified references.
        - Longer call chains are ignored.
    """
    macro_references: t.Set[MacroReference] = set()
    variables: t.Set[str] = set()
    for jinja_str in jinja_strs:
        # Named `name_parts` so the module-level `call_name` function is not shadowed.
        for name_parts, node in extract_call_names(jinja_str):
            head = name_parts[0]
            if head == c.VAR:
                # Only a literal string first argument names a variable statically; other
                # constants (e.g. `var(1)`) used to crash on `.lower()`.
                var_name = jinja_call_arg_name(node.args[0]) if node.args else ""
                if isinstance(var_name, str) and var_name:
                    variables.add(var_name.lower())
            elif head == c.GATEWAY:
                variables.add(c.GATEWAY)
            elif len(name_parts) == 1:
                macro_references.add(MacroReference(name=head))
            elif len(name_parts) == 2:
                macro_references.add(MacroReference(package=head, name=name_parts[1]))
    return macro_references, variables
173class JinjaMacroRegistry(PydanticModel): 174 """Registry for Jinja macros. 175 176 Args: 177 packages: The mapping from package name to a collection of macro definitions. 178 root_macros: The collection of top-level macro definitions. 179 global_objs: The global objects. 180 create_builtins_module: The name of a module which defines the `create_builtins` factory 181 function that will be used to construct builtin variables and functions. 182 root_package_name: The name of the root package. If specified root macros will be available 183 as both `root_package_name.macro_name` and `macro_name`. 184 top_level_packages: The list of top-level packages. Macros in this packages will be available 185 as both `package_name.macro_name` and `macro_name`. 186 """ 187 188 packages: t.Dict[str, t.Dict[str, MacroInfo]] = {} 189 root_macros: t.Dict[str, MacroInfo] = {} 190 global_objs: t.Dict[str, JinjaGlobalAttribute] = {} 191 create_builtins_module: t.Optional[str] = SQLMESH_JINJA_PACKAGE 192 root_package_name: t.Optional[str] = None 193 top_level_packages: t.List[str] = [] 194 195 _parser_cache: t.Dict[t.Tuple[t.Optional[str], str], Template] = {} 196 __environment: t.Optional[Environment] = None 197 198 @field_validator("global_objs", mode="before") 199 @classmethod 200 def _validate_global_objs(cls, value: t.Any) -> t.Any: 201 def _normalize(val: t.Any) -> t.Any: 202 if isinstance(val, dict): 203 return AttributeDict({k: _normalize(v) for k, v in val.items()}) 204 if isinstance(val, list): 205 return [_normalize(v) for v in val] 206 if isinstance(val, set): 207 return [_normalize(v) for v in sorted(val)] 208 if isinstance(val, Enum): 209 return val.value 210 return val 211 212 return _normalize(value) 213 214 @field_serializer("global_objs") 215 def _serialize_attribute_dict( 216 self, value: t.Dict[str, JinjaGlobalAttribute] 217 ) -> t.Dict[str, t.Any]: 218 # NOTE: This is called only when used with Pydantic V2. 
219 def _convert( 220 val: t.Union[t.Dict[str, JinjaGlobalAttribute], t.Dict[str, t.Any]], 221 ) -> t.Dict[str, t.Any]: 222 return {k: _convert(v) if isinstance(v, AttributeDict) else v for k, v in val.items()} 223 224 return _convert(value) 225 226 def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None: 227 """Adds macros to the target package. 228 229 Args: 230 macros: Macros that should be added. 231 package: The name of the package the given macros belong to. If not specified, the provided 232 macros will be added to the root namespace. 233 """ 234 235 if package is not None: 236 package_macros = self.packages.get(package, {}) 237 package_macros.update(macros) 238 self.packages[package] = package_macros 239 else: 240 self.root_macros.update(macros) 241 242 def add_globals(self, globals: t.Dict[str, JinjaGlobalAttribute]) -> None: 243 """Adds global objects to the registry. 244 245 Args: 246 globals: The global objects that should be added. 247 """ 248 self.global_objs.update(**self._validate_global_objs(globals)) 249 250 def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]: 251 """Builds a Python callable for a macro with the given reference. 252 253 Args: 254 reference: The macro reference. 255 Returns: 256 The macro as a Python callable or None if not found. 
257 """ 258 env: Environment = self.build_environment(**kwargs) 259 if reference.package is not None: 260 package = env.globals.get(reference.package, {}) 261 return package.get(reference.name) # type: ignore 262 return env.globals.get(reference.name) # type: ignore 263 264 def build_environment(self, **kwargs: t.Any) -> Environment: 265 """Builds a new Jinja environment based on this registry.""" 266 267 context: t.Dict[str, t.Any] = {} 268 269 root_macros = { 270 name: self._MacroWrapper(name, None, self, context) 271 for name, macro in self.root_macros.items() 272 } 273 274 package_macros: t.Dict[str, t.Any] = defaultdict(AttributeDict) 275 for package_name, macros in self.packages.items(): 276 for macro_name in macros: 277 package_macros[package_name][macro_name] = self._MacroWrapper( 278 macro_name, package_name, self, context 279 ) 280 281 if self.root_package_name is not None: 282 package_macros[self.root_package_name].update(root_macros) 283 284 env = environment() 285 286 builtin_globals = self._create_builtin_globals(kwargs) 287 for top_level_package_name in self.top_level_packages: 288 # Make sure that the top-level package doesn't fully override the same builtin package. 289 package_macros[top_level_package_name] = AttributeDict( 290 { 291 **(builtin_globals.pop(top_level_package_name, None) or {}), 292 **(package_macros.get(top_level_package_name) or {}), 293 } 294 ) 295 root_macros.update(package_macros[top_level_package_name]) 296 297 context.update(builtin_globals) 298 context.update(root_macros) 299 context.update(package_macros) 300 301 env.globals.update(context) 302 env.filters.update(self._environment.filters) 303 return env 304 305 def trim( 306 self, dependencies: t.Iterable[MacroReference], package: t.Optional[str] = None 307 ) -> JinjaMacroRegistry: 308 """Trims the registry by keeping only macros with given references and their transitive dependencies. 309 310 Args: 311 dependencies: References to macros that should be kept. 
312 package: The name of the package in the context of which the trimming should be performed. 313 314 Returns: 315 A new trimmed registry. 316 """ 317 dependencies_by_package: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set) 318 for dep in dependencies: 319 dependencies_by_package[dep.package or package].add(dep.name) 320 321 top_level_packages = self.top_level_packages.copy() 322 if package is not None: 323 top_level_packages.append(package) 324 325 result = JinjaMacroRegistry( 326 global_objs=self.global_objs.copy(), 327 create_builtins_module=self.create_builtins_module, 328 root_package_name=self.root_package_name, 329 top_level_packages=top_level_packages, 330 ) 331 for package, names in dependencies_by_package.items(): 332 result = result.merge(self._trim_macros(names, package)) 333 334 return result 335 336 def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry: 337 """Returns a copy of the registry which contains macros from both this and `other` instances. 338 339 Args: 340 other: The other registry instance. 341 342 Returns: 343 A new merged registry. 
344 """ 345 346 root_macros = { 347 **self.root_macros, 348 **other.root_macros, 349 } 350 351 packages = {} 352 for package in {*self.packages, *other.packages}: 353 packages[package] = { 354 **self.packages.get(package, {}), 355 **other.packages.get(package, {}), 356 } 357 358 global_objs = { 359 **self.global_objs, 360 **other.global_objs, 361 } 362 363 return JinjaMacroRegistry( 364 packages=packages, 365 root_macros=root_macros, 366 global_objs=global_objs, 367 create_builtins_module=self.create_builtins_module or other.create_builtins_module, 368 root_package_name=self.root_package_name or other.root_package_name, 369 top_level_packages=[*self.top_level_packages, *other.top_level_packages], 370 ) 371 372 def to_expressions(self) -> t.List[Expression]: 373 output: t.List[Expression] = [] 374 375 filtered_objs = { 376 k: v for k, v in self.global_objs.items() if k in ("refs", "sources", "vars") 377 } 378 if filtered_objs: 379 output.append( 380 d.PythonCode( 381 expressions=[ 382 f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" 383 for k, v in sorted(filtered_objs.items()) 384 ] 385 ) 386 ) 387 388 for macro_name, macro_info in sorted(self.root_macros.items()): 389 output.append(d.jinja_statement(macro_info.definition)) 390 391 for _, package in sorted(self.packages.items()): 392 for macro_name, macro_info in sorted(package.items()): 393 output.append(d.jinja_statement(macro_info.definition)) 394 395 return output 396 397 @property 398 def data_hash_values(self) -> t.List[str]: 399 data = [] 400 401 for macro_name, macro in sorted(self.root_macros.items()): 402 data.append(macro_name) 403 data.append(macro.definition) 404 405 for _, package in sorted(self.packages.items()): 406 for macro_name, macro in sorted(package.items()): 407 data.append(macro_name) 408 data.append(macro.definition) 409 410 trimmed_global_objs = { 411 k: self.global_objs[k] for k in ("refs", "sources", "vars") if k in self.global_objs 412 } 413 
data.append(json.dumps(trimmed_global_objs, sort_keys=True)) 414 415 return data 416 417 def __deepcopy__(self, memo: t.Optional[t.Dict[int, t.Any]] = None) -> JinjaMacroRegistry: 418 return JinjaMacroRegistry.parse_obj(self.dict()) 419 420 def _parse_macro(self, name: str, package: t.Optional[str]) -> Template: 421 cache_key = (package, name) 422 if cache_key not in self._parser_cache: 423 macro = self._get_macro(name, package) 424 425 definition: nodes.Template = self._environment.parse(macro.definition) 426 if _is_private_macro(name): 427 # A workaround to expose private jinja macros. 428 definition = self._to_non_private_macro_def(name, definition) 429 430 self._parser_cache[cache_key] = self._environment.from_string(definition) 431 return self._parser_cache[cache_key] 432 433 @property 434 def _environment(self) -> Environment: 435 if self.__environment is None: 436 self.__environment = environment() 437 self.__environment.filters.update(self._create_builtin_filters()) 438 return self.__environment 439 440 def _trim_macros( 441 self, 442 names: t.Set[str], 443 package: t.Optional[str] = None, 444 visited: t.Optional[t.Dict[t.Optional[str], t.Set[str]]] = None, 445 ) -> JinjaMacroRegistry: 446 if visited is None: 447 visited = defaultdict(set) 448 449 macros = self.packages.get(package, {}) if package is not None else self.root_macros 450 trimmed_macros = {} 451 452 dependencies: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set) 453 454 for name in names: 455 if name in macros and name not in visited[package]: 456 macro = macros[name] 457 trimmed_macros[name] = macro 458 for dependency in macro.depends_on: 459 dependencies[dependency.package or package].add(dependency.name) 460 visited[package].add(name) 461 462 if package is not None: 463 result = JinjaMacroRegistry(packages={package: trimmed_macros}) 464 else: 465 result = JinjaMacroRegistry(root_macros=trimmed_macros) 466 467 for upstream_package, upstream_names in dependencies.items(): 468 result = 
result.merge( 469 self._trim_macros(upstream_names, upstream_package, visited=visited) 470 ) 471 472 return result 473 474 def _macro_exists(self, name: str, package: t.Optional[str]) -> bool: 475 return ( 476 name in self.packages.get(package, {}) 477 if package is not None 478 else name in self.root_macros 479 ) 480 481 def _get_macro(self, name: str, package: t.Optional[str]) -> MacroInfo: 482 return self.packages[package][name] if package is not None else self.root_macros[name] 483 484 def _to_non_private_macro_def(self, name: str, template: nodes.Template) -> nodes.Template: 485 for node in template.find_all((nodes.Macro, nodes.Call)): 486 if isinstance(node, nodes.Macro): 487 node.name = _non_private_name(name) 488 elif isinstance(node, nodes.Call) and isinstance(node.node, nodes.Name): 489 node.node.name = _non_private_name(name) 490 491 return template 492 493 def _create_builtin_globals(self, global_vars: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: 494 """Creates Jinja builtin globals using a factory function defined in the provided module.""" 495 engine_adapter = global_vars.pop("engine_adapter", None) 496 global_vars = {**self.global_objs, **global_vars} 497 if self.create_builtins_module is not None: 498 module = importlib.import_module(self.create_builtins_module) 499 if hasattr(module, "create_builtin_globals"): 500 return module.create_builtin_globals(self, global_vars, engine_adapter) 501 return global_vars 502 503 def _create_builtin_filters(self) -> t.Dict[str, t.Any]: 504 """Creates Jinja builtin filters using a factory function defined in the provided module.""" 505 if self.create_builtins_module is not None: 506 module = importlib.import_module(self.create_builtins_module) 507 if hasattr(module, "create_builtin_filters"): 508 return module.create_builtin_filters() 509 return {} 510 511 class _MacroWrapper: 512 def __init__( 513 self, 514 name: str, 515 package: t.Optional[str], 516 registry: JinjaMacroRegistry, 517 context: t.Dict[str, t.Any], 
518 ): 519 self.name = name 520 self.package = package 521 self.context = context 522 self.registry = registry 523 524 def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any: 525 context = self.context.copy() 526 if self.package is not None and self.package in context: 527 context.update(context[self.package]) 528 529 template = self.registry._parse_macro(self.name, self.package) 530 macro_callable = getattr( 531 template.make_module(vars=context), _non_private_name(self.name) 532 ) 533 try: 534 return macro_callable(*args, **kwargs) 535 except MacroReturnVal as ret: 536 return ret.value
Registry for Jinja macros.
Arguments:
- packages: The mapping from package name to a collection of macro definitions.
- root_macros: The collection of top-level macro definitions.
- global_objs: The global objects.
- create_builtins_module: The name of a module which defines the `create_builtin_globals` and/or `create_builtin_filters` factory functions that will be used to construct builtin variables, functions and filters.
- root_package_name: The name of the root package. If specified, root macros will be available as both `root_package_name.macro_name` and `macro_name`.
- top_level_packages: The list of top-level packages. Macros in these packages will be available as both `package_name.macro_name` and `macro_name`.
def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None:
    """Registers macro definitions under a named package or in the root namespace.

    Args:
        macros: Mapping of macro name to macro definition to register. Entries replace any
            already-registered macros with the same name in the target namespace.
        package: Name of the package the macros belong to. When omitted (None), the macros
            are merged into the root namespace instead.
    """
    if package is None:
        self.root_macros.update(macros)
        return

    # Reuse the package's existing mapping if there is one, otherwise start a new one.
    self.packages.setdefault(package, {}).update(macros)
Adds macros to the target package.
Arguments:
- macros: Macros that should be added.
- package: The name of the package the given macros belong to. If not specified, the provided macros will be added to the root namespace.
def add_globals(self, globals: t.Dict[str, JinjaGlobalAttribute]) -> None:
    """Validates the given global objects and registers them.

    Args:
        globals: Mapping of global name to object. Each entry replaces any existing global
            with the same name.
    """
    validated = self._validate_global_objs(globals)
    self.global_objs.update(**validated)
Adds global objects to the registry.
Arguments:
- globals: The global objects that should be added.
def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]:
    """Returns the Python callable for the macro identified by ``reference``.

    A fresh environment is built with ``build_environment`` and the macro is looked up in its
    globals: directly for unqualified references, or inside the package namespace when
    ``reference.package`` is set.

    Args:
        reference: The macro reference.
        **kwargs: Extra arguments forwarded to ``build_environment``.

    Returns:
        The macro as a Python callable or None if not found.
    """
    env_globals = self.build_environment(**kwargs).globals
    if reference.package is None:
        return env_globals.get(reference.name)  # type: ignore
    return env_globals.get(reference.package, {}).get(reference.name)  # type: ignore
Builds a Python callable for a macro with the given reference.
Arguments:
- reference: The macro reference.
Returns:
The macro as a Python callable or None if not found.
def build_environment(self, **kwargs: t.Any) -> Environment:
    """Builds a new Jinja environment based on this registry.

    Every macro is exposed as a ``_MacroWrapper`` that holds a reference to one shared
    ``context`` dict. The wrappers are created before ``context`` is populated, so when they
    are called they see the final set of globals (builtins, root macros and package
    namespaces) built at the end of this method.

    Args:
        **kwargs: Extra global variables passed to ``_create_builtin_globals``.

    Returns:
        A fresh Jinja environment. Its globals hold the builtins and all macros, and its
        filters are copied from this registry's internal environment.
    """
    # Shared by reference with every wrapper below; filled in at the end of this method.
    context: t.Dict[str, t.Any] = {}

    root_macros = {
        name: self._MacroWrapper(name, None, self, context) for name in self.root_macros
    }

    package_macros: t.Dict[str, t.Any] = defaultdict(AttributeDict)
    for package_name, macros in self.packages.items():
        for macro_name in macros:
            package_macros[package_name][macro_name] = self._MacroWrapper(
                macro_name, package_name, self, context
            )

    # Root macros are also reachable as ``<root_package_name>.<macro_name>``.
    if self.root_package_name is not None:
        package_macros[self.root_package_name].update(root_macros)

    env = environment()

    builtin_globals = self._create_builtin_globals(kwargs)
    for top_level_package_name in self.top_level_packages:
        # Make sure that the top-level package doesn't fully override the same builtin package:
        # builtin members come first and are overridden by the package's own macros. The
        # builtin entry is popped so this merged namespace is the only one under that name.
        package_macros[top_level_package_name] = AttributeDict(
            {
                **(builtin_globals.pop(top_level_package_name, None) or {}),
                **(package_macros.get(top_level_package_name) or {}),
            }
        )
        # Members of top-level packages are also reachable without the package prefix.
        root_macros.update(package_macros[top_level_package_name])

    # Later updates win: package namespaces override root macros, which override builtins.
    context.update(builtin_globals)
    context.update(root_macros)
    context.update(package_macros)

    env.globals.update(context)
    env.filters.update(self._environment.filters)
    return env
Builds a new Jinja environment based on this registry.
def trim(
    self, dependencies: t.Iterable[MacroReference], package: t.Optional[str] = None
) -> JinjaMacroRegistry:
    """Trims the registry by keeping only macros with given references and their transitive dependencies.

    Args:
        dependencies: References to macros that should be kept.
        package: The name of the package in the context of which the trimming should be performed.
            References without an explicit package are resolved relative to it, and it is
            added to the result's top-level packages.

    Returns:
        A new trimmed registry.
    """
    # Group requested macro names by the package they resolve to.
    dependencies_by_package: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)
    for dep in dependencies:
        dependencies_by_package[dep.package or package].add(dep.name)

    # The context package becomes top-level so its macros are also exposed without a prefix.
    top_level_packages = self.top_level_packages.copy()
    if package is not None:
        top_level_packages.append(package)

    result = JinjaMacroRegistry(
        global_objs=self.global_objs.copy(),
        create_builtins_module=self.create_builtins_module,
        root_package_name=self.root_package_name,
        top_level_packages=top_level_packages,
    )
    # A distinct loop variable keeps the ``package`` argument from being shadowed.
    for dep_package, names in dependencies_by_package.items():
        result = result.merge(self._trim_macros(names, dep_package))

    return result
Trims the registry by keeping only macros with given references and their transitive dependencies.
Arguments:
- dependencies: References to macros that should be kept.
- package: The name of the package in the context of which the trimming should be performed.
Returns:
A new trimmed registry.
def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry:
    """Returns a copy of the registry which contains macros from both this and `other` instances.

    Root macros, package macros and global objects from ``other`` take precedence over
    entries with the same name in this registry. The builtins module and root package name
    fall back to ``other``'s values when unset here. Top-level package lists are concatenated.

    Args:
        other: The other registry instance.

    Returns:
        A new merged registry.
    """
    root_macros = {
        **self.root_macros,
        **other.root_macros,
    }

    # Use an insertion-ordered union (this registry's packages first, then new ones from
    # ``other``). Iterating a set here would make the resulting dict order depend on string
    # hash randomization, i.e. differ between processes.
    packages = {
        package: {
            **self.packages.get(package, {}),
            **other.packages.get(package, {}),
        }
        for package in dict.fromkeys([*self.packages, *other.packages])
    }

    global_objs = {
        **self.global_objs,
        **other.global_objs,
    }

    return JinjaMacroRegistry(
        packages=packages,
        root_macros=root_macros,
        global_objs=global_objs,
        create_builtins_module=self.create_builtins_module or other.create_builtins_module,
        root_package_name=self.root_package_name or other.root_package_name,
        top_level_packages=[*self.top_level_packages, *other.top_level_packages],
    )
Returns a copy of the registry which contains macros from both this and `other` instances.
Arguments:
- other: The other registry instance.
Returns:
A new merged registry.
def to_expressions(self) -> t.List[Expression]:
    """Serializes the registry into a list of expressions.

    The ``refs``, ``sources`` and ``vars`` global objects, when present, are emitted first as
    a single ``d.PythonCode`` expression of ``name = value`` assignments. Each macro
    definition then follows as a ``d.jinja_statement``: root macros sorted by name, then
    package macros sorted by package name and macro name.

    Returns:
        The list of expressions.
    """
    output: t.List[Expression] = []

    filtered_objs = {
        k: v for k, v in self.global_objs.items() if k in ("refs", "sources", "vars")
    }
    if filtered_objs:
        output.append(
            d.PythonCode(
                expressions=[
                    # repr() produces a correctly quoted and escaped literal even when the
                    # string contains quotes, backslashes or newlines.
                    f"{k} = {v!r}" if isinstance(v, str) else f"{k} = {v}"
                    for k, v in sorted(filtered_objs.items())
                ]
            )
        )

    for _, macro_info in sorted(self.root_macros.items()):
        output.append(d.jinja_statement(macro_info.definition))

    for _, package in sorted(self.packages.items()):
        for _, macro_info in sorted(package.items()):
            output.append(d.jinja_statement(macro_info.definition))

    return output
def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
    """Initializes private attributes, then runs the user-defined ``model_post_init``.

    Both steps are needed, and the private attributes must be set up before the user hook
    runs.
    """
    for step in (init_private_attributes, original_model_post_init):
        step(self, __context)
Override this method to perform additional initialization after `__init__` and `model_construct`.
This is useful if you want to do some validation that requires the entire model to be initialized.
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
def create_builtin_globals(
    jinja_macros: JinjaMacroRegistry, global_vars: t.Dict[str, t.Any], *args: t.Any, **kwargs: t.Any
) -> t.Dict[str, t.Any]:
    """Creates the default builtin Jinja globals.

    Args:
        jinja_macros: The macro registry the globals are created for (not used here).
        global_vars: Global variables to expose.
            - The ``c.SQLMESH_VARS`` entry is not exposed directly; it backs ``c.VAR`` and
              the ``c.GATEWAY`` accessor.
            - Any ``c.GATEWAY`` entry is dropped.
            - The caller's dict is left unmodified.
        *args: Ignored. Accepted for call compatibility; the registry passes its engine
            adapter positionally.
        **kwargs: Ignored.

    Returns:
        A dict with ``c.VAR`` and ``c.GATEWAY`` entries followed by the remaining global
        variables. A remaining variable with the same key as ``c.VAR`` overrides it.
    """
    # Work on a copy so that popping the special keys does not mutate the caller's dict.
    global_vars = dict(global_vars)
    global_vars.pop(c.GATEWAY, None)
    variables = global_vars.pop(c.SQLMESH_VARS, None) or {}
    return {
        c.VAR: create_var(variables),
        # Resolved lazily from the SQLMesh variables when called from a template.
        c.GATEWAY: lambda: variables.get(c.GATEWAY, None),
        **global_vars,
    }