Edit on GitHub

sqlmesh.core.dialect

   1from __future__ import annotations
   2
   3import functools
   4import logging
   5import re
   6import sys
   7import typing as t
   8from contextlib import contextmanager
   9from difflib import unified_diff
  10from enum import Enum, auto
  11from functools import lru_cache
  12
  13from sqlglot import Dialect, Generator, ParseError, Parser, Tokenizer, TokenType, exp
  14from sqlglot.dialects.dialect import DialectType
  15from sqlglot.dialects import DuckDB, Snowflake
  16import sqlglot.dialects.athena as athena
  17from sqlglot.helper import seq_get
  18from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
  19from sqlglot.optimizer.qualify_columns import quote_identifiers
  20from sqlglot.optimizer.qualify_tables import qualify_tables
  21from sqlglot.optimizer.scope import traverse_scope
  22from sqlglot.schema import MappingSchema
  23from sqlglot.tokens import Token
  24
  25from sqlmesh.core.constants import MAX_MODEL_DEFINITION_SIZE
  26from sqlmesh.utils import get_source_columns_to_types
  27from sqlmesh.utils.errors import SQLMeshError, ConfigError
  28from sqlmesh.utils.pandas import columns_to_types_from_df
  29
  30if t.TYPE_CHECKING:
  31    import pandas as pd
  32
  33    from sqlglot._typing import E
  34
  35
# Sigil that introduces SQLMesh macro syntax in SQL text (macro variables, macro
# function calls and `@{...}` interpolations are all detected by comparing tokens to it).
SQLMESH_MACRO_PREFIX = "@"

# NOTE(review): the name suggests a metadata key for tables known to SQLMesh; its
# consumers are not visible in this part of the file — confirm before relying on it.
TABLES_META = "sqlmesh.tables"

logger = logging.getLogger(__name__)
  41
  42
class Model(exp.Expression):
    """The `MODEL (...)` DDL block.

    `expressions` holds its properties (`exp.Property` nodes, or macro calls that
    generate them), as built by the parser registered for "MODEL" in `PARSERS`.
    """

    arg_types = {"expressions": True}


class Audit(exp.Expression):
    """The `AUDIT (...)` DDL block; `expressions` holds its properties."""

    arg_types = {"expressions": True}


class Metric(exp.Expression):
    """The `METRIC (...)` DDL block; `expressions` holds its properties."""

    arg_types = {"expressions": True}
  53
  54
class Jinja(exp.Func):
    """Base node for Jinja template text; `this` holds the text as a string literal."""

    arg_types = {"this": True}


class JinjaQuery(Jinja):
    """A Jinja-templated query, built by `jinja_query` from the stripped template text."""

    pass


class JinjaStatement(Jinja):
    """A Jinja-templated statement, built by `jinja_statement` from the stripped template text."""

    pass
  65
  66
class VirtualUpdateStatement(exp.Expression):
    """A list of statements (in `expressions`) built by `virtual_statement`.

    NOTE(review): presumably the body of an ON_VIRTUAL_UPDATE_BEGIN / ON_VIRTUAL_UPDATE_END
    section; the code that splits files into such sections is outside this view.
    """

    arg_types = {"expressions": True}


class ModelKind(exp.Expression):
    """A model's `kind` property: `this` is the kind name, `expressions` its optional properties."""

    arg_types = {"this": True, "expressions": False}


class MacroVar(exp.Var):
    """A macro variable reference such as `@foo`; the name is stored without the `@`."""

    pass
  77
  78
class MacroFunc(exp.Func):
    """A macro function call such as `@FOO(a, b)`; `this` is the wrapped call expression."""

    @property
    def name(self) -> str:
        # The macro's name is the name of the call it wraps.
        return self.this.name


class MacroDef(MacroFunc):
    """`@DEF(name, value)`: `this` is the variable being defined, `expression` its value."""

    arg_types = {"this": True, "expression": True}


class MacroSQL(MacroFunc):
    """`@SQL(expr[, into])`: `into` is the lower-cased second argument, if one was given."""

    arg_types = {"this": True, "into": False}


class MacroStrReplace(MacroFunc):
    """A string-template macro such as `@'...'`; `this` is the template string literal."""

    pass
  95
  96
class PythonCode(exp.Expression):
    """Node carrying Python code in `expressions`.

    NOTE(review): where this node is constructed is not visible in this part of the file.
    """

    arg_types = {"expressions": True}


class DColonCast(exp.Cast):
    """A cast that should be rendered with the `x::type` syntax.

    `format_model_expressions` rewrites eligible `CAST(x AS type)` nodes into this class.
    """

    pass


class MetricAgg(exp.AggFunc):
    """Used for computing metrics."""

    arg_types = {"this": True}

    @property
    def output_name(self) -> str:
        # Name the aggregate's output after the expression it aggregates.
        return self.this.name


class StagedFilePath(exp.Expression):
    """Represents paths to "staged files" in Snowflake."""

    # Same arguments as a table, so a parsed exp.Table's args can be moved over verbatim.
    arg_types = exp.Table.arg_types.copy()
 119
 120
def _parse_statement(self: Parser) -> t.Optional[exp.Expression]:
    """Parse one statement, trying SQLMesh DDL blocks (MODEL / AUDIT / METRIC) first.

    If the current token is a key of `PARSERS`, the wrapped DDL body is parsed with the
    matching parser. On failure the position is rewound and the original statement parser
    is used instead. If that also fails, the DDL parse error is re-raised in its place.
    """
    if self._curr is None:
        return None

    parser = PARSERS.get(self._curr.text.upper())
    error_msg = None

    if parser:
        # Capture any available description in the form of a comment
        comments = self._curr.comments

        index = self._index
        try:
            self._advance()
            meta = self._parse_wrapped(lambda: t.cast(t.Callable, parser)(self))
        except ParseError as parse_error:
            # Remember the DDL error and rewind so the generic parser sees the same tokens.
            error_msg = parse_error.args[0]
            self._retreat(index)

        # Only return the DDL expression if we actually managed to parse one. This is
        # done in order to allow parsing standalone identifiers / function calls like
        # "metric", or "model(1, 2, 3)", which collide with SQLMesh's DDL syntax.
        if self._index != index:
            meta.comments = comments
            return meta

    try:
        # `__parse_statement` is presumably the dialect's original method, stashed by `_override`.
        return self.__parse_statement()  # type: ignore
    except ParseError:
        # Prefer the DDL error when the token looked like a DDL keyword.
        if error_msg:
            raise ParseError(error_msg)
        raise
 153
 154
 155def _parse_lambda(self: Parser, alias: bool = False) -> t.Optional[exp.Expression]:
 156    node = self.__parse_lambda(alias=alias)  # type: ignore
 157    if isinstance(node, exp.Lambda):
 158        node.set("this", self._parse_alias(node.this))
 159    return node
 160
 161
def _parse_id_var(
    self: Parser,
    any_token: bool = True,
    tokens: t.Optional[t.Collection[TokenType]] = None,
) -> t.Optional[exp.Expression]:
    """Parse an identifier, stitching connected macro fragments into one identifier.

    Beyond what the original parser accepts, this handles:

    - `@{name}` when the previous token is the macro prefix: the identifier text becomes
      the literal `@{name}`;
    - an unquoted identifier directly followed (per `_is_connected`) by `{...}`, `@...`
      or another non-reserved token, e.g. `foo_@{bar}`: the pieces are concatenated into
      a single identifier.

    Quoted identifiers are never extended.
    """
    if self._prev and self._prev.text == SQLMESH_MACRO_PREFIX and self._match(TokenType.L_BRACE):
        # `@{name}`: parse the inner name and keep the interpolation syntax in the text.
        identifier = self.__parse_id_var(any_token=any_token, tokens=tokens)  # type: ignore
        if not self._match(TokenType.R_BRACE):
            self.raise_error("Expecting }")
        # NOTE(review): `identifier` is assumed non-None here; an empty `@{}` would fail on `.args`.
        identifier.args["this"] = f"@{{{identifier.name}}}"
    else:
        identifier = self.__parse_id_var(any_token=any_token, tokens=tokens)  # type: ignore

    # Keep absorbing fragments that touch the identifier (no separator in between).
    while (
        identifier
        and not identifier.args.get("quoted")
        and self._is_connected()
        and (
            self._match_texts(("{", SQLMESH_MACRO_PREFIX))
            or self._curr.token_type not in self.RESERVED_TOKENS
        )
    ):
        this = identifier.name
        brace = False

        if self._prev.text == "{":
            this += "{"
            brace = True
        else:
            if self._prev.text == SQLMESH_MACRO_PREFIX:
                this += "@"
            if self._match(TokenType.L_BRACE):
                this += "{"
                brace = True

        # Recursive call, so the next fragment may itself contain macro syntax.
        next_id = self._parse_id_var(any_token=False)

        if next_id:
            this += next_id.name
        else:
            # Nothing identifier-like follows: return what was built so far.
            # NOTE(review): a `{` / `@` matched above has already been consumed at this point.
            return identifier

        if brace:
            if self._match(TokenType.R_BRACE):
                this += "}"
            else:
                self.raise_error("Expecting }")

        identifier = self.expression(exp.Identifier, this=this, quoted=identifier.quoted)

    return identifier
 213
 214
def _parse_macro(self: Parser, keyword_macro: str = "") -> t.Optional[exp.Expression]:
    """Parse the construct following an `@` into a SQLMesh macro node.

    Args:
        keyword_macro: The keyword macro (e.g. "WHERE") the calling clause parser expects.
            Other names in `KEYWORD_MACROS` are rejected so that their own clause parser
            can handle them.

    Returns:
        A `MacroDef`, `MacroSQL`, `MacroFunc`, `MacroStrReplace` or `MacroVar` node, an
        already macro-bearing identifier, or None. It returns None (after rewinding) when a
        keyword macro is found that was not requested.
    """
    if self._prev.text != SQLMESH_MACRO_PREFIX:
        # Not a SQLMesh macro: treat the token as a regular parameter.
        return self._parse_parameter()

    comments = self._prev.comments
    index = self._index
    # NOTE(review): `functions={}` presumably stops sqlglot from building typed function
    # nodes from known names, so most calls come back as exp.Anonymous.
    field = self._parse_primary() or self._parse_function(functions={}) or self._parse_id_var()

    def _build_macro(field: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(field, exp.Func):
            macro_name = field.name.upper()
            if macro_name != keyword_macro and macro_name in KEYWORD_MACROS:
                # A clause macro in the wrong place: rewind and let the caller try other parses.
                self._retreat(index)
                return None

            if isinstance(field, exp.Anonymous):
                if macro_name == "DEF":
                    return self.expression(
                        MacroDef,
                        this=field.expressions[0],
                        expression=field.expressions[1],
                        comments=comments,
                    )
                if macro_name == "SQL":
                    into = field.expressions[1].this.lower() if len(field.expressions) > 1 else None
                    return self.expression(
                        MacroSQL, this=field.expressions[0], into=into, comments=comments
                    )
            else:
                # A typed function node: flatten it back into a generic call whose
                # arguments are the node's args, in order.
                field = self.expression(
                    exp.Anonymous,
                    this=field.sql_name(),
                    expressions=list(field.args.values()),
                    comments=comments,
                )

            return self.expression(MacroFunc, this=field, comments=comments)

        if field is None:
            return None

        # `@'...'` / `@"..."`: a string template macro.
        if field.is_string or (isinstance(field, exp.Identifier) and field.quoted):
            return self.expression(
                MacroStrReplace, this=exp.Literal.string(field.this), comments=comments
            )

        # Already contains macro syntax (e.g. a stitched `@{...}` identifier): keep it.
        if "@" in field.this:
            return field
        return self.expression(MacroVar, this=field.this, comments=comments)

    # e.g. `@FOO(x) OVER (...)`: the window / null-handling node wraps the call, so build
    # the macro inside it.
    if isinstance(field, (exp.Window, exp.IgnoreNulls, exp.RespectNulls)):
        field.set("this", _build_macro(field.this))
    else:
        field = _build_macro(field)

    return field
 271
 272
# Macros that stand in for a whole SQL clause keyword (e.g. `@WHERE(...)` in front of a
# WHERE clause). `_parse_macro` only accepts them when the matching clause parser asks
# for them, and `_macro_func_sql` renders them in place of the clause's keyword.
KEYWORD_MACROS = {"WITH", "JOIN", "WHERE", "GROUP_BY", "HAVING", "ORDER_BY", "LIMIT"}
 274
 275
 276def _parse_matching_macro(self: Parser, name: str) -> t.Optional[exp.Expression]:
 277    if not self._match_pair(TokenType.PARAMETER, TokenType.VAR, advance=False) or (
 278        self._next and self._next.text.upper() != name.upper()
 279    ):
 280        return None
 281
 282    self._advance()
 283    return _parse_macro(self, keyword_macro=name)
 284
 285
 286def _parse_body_macro(self: Parser) -> t.Tuple[str, t.Optional[exp.Expression]]:
 287    name = self._next and self._next.text.upper()
 288
 289    if name == "JOIN":
 290        return ("joins", self._parse_join())
 291    if name == "WHERE":
 292        return ("where", self._parse_where())
 293    if name == "GROUP_BY":
 294        return ("group", self._parse_group())
 295    if name == "HAVING":
 296        return ("having", self._parse_having())
 297    if name == "ORDER_BY":
 298        return ("order", self._parse_order())
 299    if name == "LIMIT":
 300        return ("limit", self._parse_limit())
 301    return ("", None)
 302
 303
 304def _parse_with(self: Parser, skip_with_token: bool = False) -> t.Optional[exp.Expression]:
 305    macro = _parse_matching_macro(self, "WITH")
 306    if not macro:
 307        return self.__parse_with(skip_with_token=skip_with_token)  # type: ignore
 308
 309    macro.this.append("expressions", self.__parse_with(skip_with_token=True))  # type: ignore
 310    return macro
 311
 312
 313def _parse_join(
 314    self: Parser, skip_join_token: bool = False, parse_bracket: bool = False
 315) -> t.Optional[exp.Expression]:
 316    index = self._index
 317    method, side, kind = self._parse_join_parts()
 318    macro = _parse_matching_macro(self, "JOIN")
 319    if not macro:
 320        self._retreat(index)
 321        return self.__parse_join(skip_join_token=skip_join_token, parse_bracket=parse_bracket)  # type: ignore
 322
 323    join = self.__parse_join(skip_join_token=True)  # type: ignore
 324    if method:
 325        join.set("method", method.text)
 326    if side:
 327        join.set("side", side.text)
 328    if kind:
 329        join.set("kind", kind.text)
 330
 331    macro.this.append("expressions", join)
 332    return macro
 333
 334
 335def _warn_unsupported(self: Parser) -> None:
 336    from sqlmesh.core.console import get_console
 337
 338    sql = self._find_sql(self._tokens[0], self._tokens[-1])[: self.error_message_context]
 339
 340    get_console().log_warning(
 341        f"'{sql}' could not be semantically understood as it contains unsupported syntax, SQLMesh will treat the command as is. Note that any references to the model's "
 342        "underlying physical table can't be resolved in this case, consider using Jinja as explained here https://sqlmesh.readthedocs.io/en/stable/concepts/macros/macro_variables/#audit-only-variables"
 343    )
 344
 345
 346def _parse_select(
 347    self: Parser,
 348    nested: bool = False,
 349    table: bool = False,
 350    parse_subquery_alias: bool = True,
 351    parse_set_operation: bool = True,
 352    consume_pipe: bool = True,
 353    from_: t.Optional[exp.From] = None,
 354) -> t.Optional[exp.Expression]:
 355    select = self.__parse_select(  # type: ignore
 356        nested=nested,
 357        table=table,
 358        parse_subquery_alias=parse_subquery_alias,
 359        parse_set_operation=parse_set_operation,
 360        consume_pipe=consume_pipe,
 361        from_=from_,
 362    )
 363
 364    if (
 365        not select
 366        and not parse_set_operation
 367        and self._match_pair(TokenType.PARAMETER, TokenType.VAR, advance=False)
 368    ):
 369        self._advance()
 370        return _parse_macro(self)
 371
 372    return select
 373
 374
 375def _parse_where(self: Parser, skip_where_token: bool = False) -> t.Optional[exp.Expression]:
 376    macro = _parse_matching_macro(self, "WHERE")
 377    if not macro:
 378        return self.__parse_where(skip_where_token=skip_where_token)  # type: ignore
 379
 380    macro.this.append("expressions", self.__parse_where(skip_where_token=True))  # type: ignore
 381    return macro
 382
 383
 384def _parse_group(self: Parser, skip_group_by_token: bool = False) -> t.Optional[exp.Expression]:
 385    macro = _parse_matching_macro(self, "GROUP_BY")
 386    if not macro:
 387        return self.__parse_group(skip_group_by_token=skip_group_by_token)  # type: ignore
 388
 389    macro.this.append("expressions", self.__parse_group(skip_group_by_token=True))  # type: ignore
 390    return macro
 391
 392
 393def _parse_having(self: Parser, skip_having_token: bool = False) -> t.Optional[exp.Expression]:
 394    macro = _parse_matching_macro(self, "HAVING")
 395    if not macro:
 396        return self.__parse_having(skip_having_token=skip_having_token)  # type: ignore
 397
 398    macro.this.append("expressions", self.__parse_having(skip_having_token=True))  # type: ignore
 399    return macro
 400
 401
 402def _parse_order(
 403    self: Parser, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False
 404) -> t.Optional[exp.Expression]:
 405    macro = _parse_matching_macro(self, "ORDER_BY")
 406    if not macro:
 407        return self.__parse_order(this, skip_order_token=skip_order_token)  # type: ignore
 408
 409    macro.this.append("expressions", self.__parse_order(this, skip_order_token=True))  # type: ignore
 410    return macro
 411
 412
 413def _parse_limit(
 414    self: Parser,
 415    this: t.Optional[exp.Expression] = None,
 416    top: bool = False,
 417    skip_limit_token: bool = False,
 418) -> t.Optional[exp.Expression]:
 419    macro = _parse_matching_macro(self, "TOP" if top else "LIMIT")
 420    if not macro:
 421        return self.__parse_limit(this, top=top, skip_limit_token=skip_limit_token)  # type: ignore
 422
 423    macro.this.append("expressions", self.__parse_limit(this, top=top, skip_limit_token=True))  # type: ignore
 424    return macro
 425
 426
 427def _parse_value(self: Parser, values: bool = True) -> t.Optional[exp.Expression]:
 428    wrapped = self._match(TokenType.L_PAREN, advance=False)
 429
 430    # The base _parse_value method always constructs a Tuple instance. This is problematic when
 431    # generating values with a macro function, because it's impossible to tell whether the user's
 432    # intention was to construct a row or a column with the VALUES expression. To avoid this, we
 433    # amend the AST such that the Tuple is replaced by the macro function call itself.
 434    expr = self.__parse_value()  # type: ignore
 435    if expr and not wrapped and isinstance(seq_get(expr.expressions, 0), MacroFunc):
 436        return expr.expressions[0]
 437
 438    return expr
 439
 440
 441def _parse_macro_or_clause(self: Parser, parser: t.Callable) -> t.Optional[exp.Expression]:
 442    return _parse_macro(self) if self._match(TokenType.PARAMETER) else parser()
 443
 444
def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
    """Parse one `key value` pair from a model kind's parenthesized property list.

    Returns:
        None if no key could be parsed. Otherwise an `exp.Property` whose name is the
        lower-cased key, with a value parsed according to the key.
    """
    key = self._parse_id_var(any_token=True)
    if not key:
        return None

    name = key.name.lower()
    if name == "time_data_type":
        # TODO: if we make *_data_type a convention to parse things into exp.DataType, we could make this more generic
        value = self._parse_types(schema=True)
    elif name == "when_matched":
        # Parentheses around the WHEN clauses can be used to disambiguate them from other properties
        value = self._parse_wrapped(
            lambda: _parse_macro_or_clause(self, self._parse_when_matched),
            optional=True,
        )
    elif name == "merge_filter":
        value = self._parse_conjunction()
    elif self._match(TokenType.L_PAREN):
        # A parenthesized list of values becomes a Tuple.
        value = self.expression(exp.Tuple, expressions=self._parse_csv(self._parse_equality))
        self._match_r_paren()
    else:
        value = self._parse_bracket(self._parse_field(any_token=True))

    if name == "path" and value:
        # Make sure if we get a windows path that it is converted to posix
        # NOTE(review): assumes `value.this` is a str (true for string literals).
        value = exp.Literal.string(value.this.replace("\\", "/"))  # type: ignore

    return self.expression(exp.Property, this=name, value=value)
 473
 474
 475def _parse_types(
 476    self: Parser,
 477    check_func: bool = False,
 478    schema: bool = False,
 479    allow_identifiers: bool = True,
 480) -> t.Optional[exp.Expression]:
 481    start = self._curr
 482    parsed_type = self.__parse_types(  # type: ignore
 483        check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
 484    )
 485
 486    if schema and parsed_type:
 487        parsed_type.meta["sql"] = self._find_sql(start, self._prev)
 488
 489    return parsed_type
 490
 491
 492# Only needed for Snowflake: its "staged file" syntax (@<path>) clashes with our macro
 493# var syntax. By converting the Var representation to a MacroVar, we should be able to
 494# handle both use cases: if there's no value in the MacroEvaluator's context for that
 495# MacroVar, it'll render into @<path>, so it won't break staged file path references.
 496#
 497# See: https://docs.snowflake.com/en/user-guide/querying-stage
 498def _parse_table_parts(
 499    self: Parser, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False
 500) -> exp.Table | StagedFilePath:
 501    index = self._index
 502    table = self.__parse_table_parts(  # type: ignore
 503        schema=schema, is_db_reference=is_db_reference, wildcard=wildcard
 504    )
 505
 506    table_arg = table.this
 507    name = table_arg.name if isinstance(table_arg, exp.Var) else ""
 508
 509    if name.startswith(SQLMESH_MACRO_PREFIX):
 510        # In these cases, we don't want to produce a `StagedFilePath` node:
 511        #
 512        # - @'...' needs to parsed as a string template
 513        # - @{foo}.bar needs to be parsed as a table with a macro var part
 514        # - @name(arg1 [, arg2 ...]) needs to be parsed as a macro function call
 515        #
 516        # These cases can unambiguously be parsed using the base `_parse_table_parts`, as there
 517        # is no overlap with staged files https://docs.snowflake.com/en/user-guide/querying-stage
 518        if (
 519            self._prev.token_type == TokenType.STRING
 520            or "{" in name
 521            or (
 522                self._curr
 523                and self._prev.token_type in (TokenType.L_PAREN, TokenType.R_PAREN)
 524                and self._curr.text.upper() not in ("FILE_FORMAT", "PATTERN")
 525                and not (table.args.get("format") or table.args.get("pattern"))
 526            )
 527        ):
 528            self._retreat(index)
 529            return Parser._parse_table_parts(self, schema=schema, is_db_reference=is_db_reference)
 530
 531        table_arg.replace(MacroVar(this=name[1:]))
 532        return StagedFilePath(**table.args)
 533
 534    return table
 535
 536
def _parse_if(self: Parser) -> t.Optional[exp.Expression]:
    """Parse IF(...), falling back to the `@IF(condition, statement)` macro form.

    Returns:
        The original parser's IF expression when it succeeds. Otherwise an
        `exp.Anonymous("IF", [cond, stmt])`, where `stmt` is a parsed statement or, if the
        statement parsers cannot consume all remaining tokens, a raw command.
    """
    # If we fail to parse an IF function with expressions as arguments, we then try
    # to parse a statement / command to support the macro @IF(condition, statement)
    index = self._index
    try:
        return self.__parse_if()  # type: ignore
    except ParseError:
        self._retreat(index)
        self._match_l_paren()

        cond = self._parse_conjunction()
        self._match(TokenType.COMMA)

        # Try to parse a known statement, otherwise fall back to parsing a command
        # Since the trailing `)` token is not expected by the statement parsers, we
        # remove it from the token stream before trying to parse the statement.
        # NOTE: this mutates `self._tokens` and assumes the IF call ends the token stream;
        # the popped `)` hands its comments to the token before it.
        last_token = self._tokens[-1]
        if last_token.token_type == TokenType.R_PAREN:
            self._tokens[-2].comments.extend(last_token.comments)
            self._tokens.pop()
        else:
            self.raise_error("Expecting )")

        index = self._index
        stmt = self._parse_statement()
        if self._curr:
            # Tokens are left over: the statement was not fully understood, so rewind and
            # keep the remainder as an opaque command.
            self._retreat(index)
            stmt = self._parse_as_command(self._tokens[index])

        return exp.Anonymous(this="IF", expressions=[cond, stmt])
 567
 568
def _create_parser(expression_type: t.Type[exp.Expression], table_keys: t.List[str]) -> t.Callable:
    """Build the body parser for a SQLMesh DDL block such as `MODEL (...)`.

    Args:
        expression_type: Node class to build (e.g. `Model`, `Audit`, `Metric`).
        table_keys: Property keys whose value must be a table reference written with the
            identifier syntax; a string value for these raises a parse error.

    Returns:
        A `parse(self)` function (used as a `PARSERS` entry). It reads comma-separated
        `key value` properties, or macro calls that generate them, and returns an
        `expression_type` node holding them in `expressions`.
    """

    def parse(self: Parser) -> t.Optional[exp.Expression]:
        # NOTE(review): imported lazily, presumably to avoid an import cycle.
        from sqlmesh.core.model.kind import ModelKindName

        expressions: t.List[exp.Expression] = []

        while True:
            # Every property after the first must be preceded by a comma; passing the
            # previous property lets `_match` associate the comma's comments with it.
            prev_property = seq_get(expressions, -1)
            if not self._match(TokenType.COMMA, expression=prev_property) and expressions:
                break

            key_expression = self._parse_id_var(any_token=True)
            if not key_expression:
                break

            # This allows macro functions that programmaticaly generate the property key-value pair
            if isinstance(key_expression, MacroFunc):
                expressions.append(key_expression)
                continue

            key = key_expression.name.lower()

            start = self._curr
            value: t.Optional[exp.Expression | str]

            if key in table_keys:
                value = self._parse_table_parts()
                if value and self._prev.token_type == TokenType.STRING:
                    self.raise_error(
                        f"'{key}' property cannot be a string value: {value}. "
                        "Please use the identifier syntax instead, e.g. foo.bar instead of 'foo.bar'"
                    )
            elif key == "columns":
                value = self._parse_schema()
            elif key == "kind":
                field = _parse_macro_or_clause(self, lambda: self._parse_id_var(any_token=True))

                if not field or isinstance(field, (MacroVar, MacroFunc)):
                    # Macro-valued kinds are resolved later, so keep the macro node.
                    value = field
                else:
                    try:
                        kind = ModelKindName[field.name.upper()]
                    except KeyError:
                        raise SQLMeshError(
                            f"Model kind specified as '{field.name}', but that is not a valid model kind.\n\nPlease specify one of {', '.join(ModelKindName)}."
                        )

                    # Only these kinds accept a parenthesized property list after the name.
                    if kind in (
                        ModelKindName.INCREMENTAL_BY_TIME_RANGE,
                        ModelKindName.INCREMENTAL_BY_UNIQUE_KEY,
                        ModelKindName.INCREMENTAL_BY_PARTITION,
                        ModelKindName.INCREMENTAL_UNMANAGED,
                        ModelKindName.SEED,
                        ModelKindName.VIEW,
                        ModelKindName.SCD_TYPE_2,
                        ModelKindName.SCD_TYPE_2_BY_TIME,
                        ModelKindName.SCD_TYPE_2_BY_COLUMN,
                        ModelKindName.CUSTOM,
                    ) and self._match(TokenType.L_PAREN, advance=False):
                        props = self._parse_wrapped_csv(functools.partial(_parse_props, self))
                    else:
                        props = None

                    value = self.expression(ModelKind, this=kind.value, expressions=props)
            elif key == "expression":
                value = self._parse_conjunction()
            elif key == "partitioned_by":
                partitioned_by = self._parse_partitioned_by()
                # A parenthesized list becomes a tuple; a single expression is kept as is.
                if isinstance(partitioned_by.this, exp.Schema):
                    value = exp.tuple_(*partitioned_by.this.expressions)
                else:
                    value = partitioned_by.this
            else:
                value = self._parse_bracket(self._parse_field(any_token=True))

            # Keep the value's original SQL text alongside the parsed node.
            if isinstance(value, exp.Expression):
                value.meta["sql"] = self._find_sql(start, self._prev)

            expressions.append(self.expression(exp.Property, this=key, value=value))

        return self.expression(expression_type, expressions=expressions)

    return parse
 652
 653
# Statement keyword -> parser for the SQLMesh DDL block it introduces (consulted by
# `_parse_statement`). The list holds the property keys whose value must be a table
# reference written with the identifier syntax.
PARSERS = {
    "MODEL": _create_parser(Model, ["name"]),
    "AUDIT": _create_parser(Audit, ["model"]),
    "METRIC": _create_parser(Metric, ["name"]),
}
 659
 660
 661def _props_sql(self: Generator, expressions: t.List[exp.Expression]) -> str:
 662    props = []
 663    size = len(expressions)
 664
 665    for i, prop in enumerate(expressions):
 666        if isinstance(prop, MacroFunc):
 667            sql = self.indent(self.sql(prop, comment=False))
 668        else:
 669            sql = self.indent(f"{prop.name} {self.sql(prop, 'value')}")
 670
 671        if i < size - 1:
 672            sql += ","
 673
 674        props.append(self.maybe_comment(sql, expression=prop))
 675
 676    return "\n".join(props)
 677
 678
 679def _on_virtual_update_sql(self: Generator, expressions: t.List[exp.Expression]) -> str:
 680    statements = "\n".join(
 681        self.sql(expression)
 682        if isinstance(expression, JinjaStatement)
 683        else f"{self.sql(expression)};"
 684        for expression in expressions
 685    )
 686    return f"{ON_VIRTUAL_UPDATE_BEGIN};\n{statements}\n{ON_VIRTUAL_UPDATE_END};"
 687
 688
 689def _sqlmesh_ddl_sql(self: Generator, expression: Model | Audit | Metric, name: str) -> str:
 690    return "\n".join([f"{name} (", _props_sql(self, expression.expressions), ")"])
 691
 692
 693def _model_kind_sql(self: Generator, expression: ModelKind) -> str:
 694    props = _props_sql(self, expression.expressions)
 695    if props:
 696        return "\n".join([f"{expression.this} (", props, ")"])
 697    return expression.name.upper()
 698
 699
 700def _macro_keyword_func_sql(self: Generator, expression: exp.Expression) -> str:
 701    name = expression.name
 702    keyword = name.replace("_", " ")
 703    *args, clause = expression.expressions
 704    macro = f"@{name}({self.format_args(*args)})"
 705    return self.sql(clause).replace(keyword, macro, 1)
 706
 707
 708def _macro_func_sql(self: Generator, expression: MacroFunc) -> str:
 709    expression = expression.this
 710    name = expression.name
 711    if name in KEYWORD_MACROS:
 712        sql = _macro_keyword_func_sql(self, expression)
 713    else:
 714        sql = f"@{name}({self.format_args(*expression.expressions)})"
 715    return self.maybe_comment(sql, expression)
 716
 717
 718def _whens_sql(self: Generator, expression: exp.Whens) -> str:
 719    if isinstance(expression.parent, exp.Merge):
 720        return self.whens_sql(expression)
 721
 722    # If the `WHEN` clauses aren't part of a MERGE statement (e.g. they
 723    # appear in the `MODEL` DDL), then we will wrap them with parentheses.
 724    return self.wrap(self.expressions(expression, sep=" ", indent=False))
 725
 726
 727def _override(klass: t.Type[Tokenizer | Parser], func: t.Callable) -> None:
 728    name = func.__name__
 729    setattr(klass, f"_{name}", getattr(klass, name))
 730    setattr(klass, name, func)
 731
 732
def format_model_expressions(
    expressions: t.List[exp.Expression],
    dialect: t.Optional[str] = None,
    rewrite_casts: bool = True,
    **kwargs: t.Any,
) -> str:
    """Format a model's expressions into a standardized format.

    Args:
        expressions: The model's expressions, must be at least model def + query.
        dialect: The dialect to render the expressions as.
        rewrite_casts: Whether to rewrite all casts to use the :: syntax.
        **kwargs: Additional keyword arguments to pass to the sql generator.

    Returns:
        A string representing the formatted model, with expressions separated by `;`
        and a blank line.
    """
    # A lone meta expression is rendered on its own, without cast rewriting.
    # NOTE(review): `is_meta_expression` is defined outside this view.
    if len(expressions) == 1 and is_meta_expression(expressions[0]):
        return expressions[0].sql(pretty=True, dialect=dialect)

    if rewrite_casts:

        def cast_to_colon(node: exp.Expression) -> exp.Expression:
            if isinstance(node, exp.Cast) and not any(
                # Only convert CAST into :: if it doesn't have additional args set, otherwise this
                # conversion could alter the semantics (eg. changing SAFE_CAST in BigQuery to CAST)
                arg
                for name, arg in node.args.items()
                if name not in ("this", "to")
            ):
                this = node.this

                # Binary/unary operands are left as CAST. NOTE(review): presumably because
                # `a + b::t` would bind the cast to `b` only; exp.Paren is let through explicitly,
                # presumably because it subclasses exp.Unary.
                if not isinstance(this, (exp.Binary, exp.Unary)) or isinstance(this, exp.Paren):
                    cast = DColonCast(this=this, to=node.to)
                    cast.comments = node.comments
                    node = cast

            # Recurse into the (possibly replaced) node's children.
            exp.replace_children(node, cast_to_colon)
            return node

        # Rewrite copies so the caller's expressions are not mutated.
        new_expressions = []
        for expression in expressions:
            expression = expression.copy()
            exp.replace_children(expression, cast_to_colon)
            new_expressions.append(expression)

        expressions = new_expressions

    return ";\n\n".join(
        expression.sql(pretty=True, dialect=dialect, **kwargs) for expression in expressions
    ).strip()
 784
 785
 786def text_diff(
 787    a: t.List[exp.Expression],
 788    b: t.List[exp.Expression],
 789    a_dialect: t.Optional[str] = None,
 790    b_dialect: t.Optional[str] = None,
 791) -> str:
 792    """Find the unified text diff between two expressions."""
 793    a_sql = [
 794        line
 795        for expr in a
 796        for line in expr.sql(pretty=True, comments=False, dialect=a_dialect).split("\n")
 797    ]
 798    b_sql = [
 799        line
 800        for expr in b
 801        for line in expr.sql(pretty=True, comments=False, dialect=b_dialect).split("\n")
 802    ]
 803    return "\n".join(unified_diff(a_sql, b_sql))
 804
 805
 806WS_OR_COMMENT = r"(?:\s|--[^\n]*\n|/\*.*?\*/)"
 807HEADER = r"\b(?:model|audit)\b(?=\s*\()"
 808KEY_BOUNDARY = r"(?:\(|,)"  # key is preceded by either '(' or ','
 809DIALECT_VALUE = r"['\"]?(?P<dialect>[a-z][a-z0-9]*)['\"]?"
 810VALUE_BOUNDARY = r"(?=,|\))"  # value is followed by comma or closing paren
 811
 812DIALECT_PATTERN = re.compile(
 813    rf"{HEADER}.*?{KEY_BOUNDARY}{WS_OR_COMMENT}*dialect{WS_OR_COMMENT}+{DIALECT_VALUE}{WS_OR_COMMENT}*{VALUE_BOUNDARY}",
 814    re.IGNORECASE | re.DOTALL,
 815)
 816
 817
 818def _is_command_statement(command: str, tokens: t.List[Token], pos: int) -> bool:
 819    try:
 820        return (
 821            tokens[pos].text.upper() == command.upper()
 822            and tokens[pos + 1].token_type == TokenType.SEMICOLON
 823        )
 824    except IndexError:
 825        return False
 826
 827
# Command markers that delimit special sections of a SQL file. The `_is_*` helpers below
# recognize each one only when it is immediately followed by `;` (e.g. `JINJA_END;`).
JINJA_QUERY_BEGIN = "JINJA_QUERY_BEGIN"
JINJA_STATEMENT_BEGIN = "JINJA_STATEMENT_BEGIN"
JINJA_END = "JINJA_END"
ON_VIRTUAL_UPDATE_BEGIN = "ON_VIRTUAL_UPDATE_BEGIN"
ON_VIRTUAL_UPDATE_END = "ON_VIRTUAL_UPDATE_END"
 833
 834
 835def _is_jinja_statement_begin(tokens: t.List[Token], pos: int) -> bool:
 836    return _is_command_statement(JINJA_STATEMENT_BEGIN, tokens, pos)
 837
 838
 839def _is_jinja_query_begin(tokens: t.List[Token], pos: int) -> bool:
 840    return _is_command_statement(JINJA_QUERY_BEGIN, tokens, pos)
 841
 842
 843def _is_jinja_end(tokens: t.List[Token], pos: int) -> bool:
 844    return _is_command_statement(JINJA_END, tokens, pos)
 845
 846
 847def jinja_query(query: str) -> JinjaQuery:
 848    return JinjaQuery(this=exp.Literal.string(query.strip()))
 849
 850
 851def jinja_statement(statement: str) -> JinjaStatement:
 852    return JinjaStatement(this=exp.Literal.string(statement.strip()))
 853
 854
 855def _is_virtual_statement_begin(tokens: t.List[Token], pos: int) -> bool:
 856    return _is_command_statement(ON_VIRTUAL_UPDATE_BEGIN, tokens, pos)
 857
 858
 859def _is_virtual_statement_end(tokens: t.List[Token], pos: int) -> bool:
 860    return _is_command_statement(ON_VIRTUAL_UPDATE_END, tokens, pos)
 861
 862
 863def virtual_statement(statements: t.List[exp.Expression]) -> VirtualUpdateStatement:
 864    return VirtualUpdateStatement(expressions=statements)
 865
 866
 867class ChunkType(Enum):
 868    JINJA_QUERY = auto()
 869    JINJA_STATEMENT = auto()
 870    SQL = auto()
 871    VIRTUAL_STATEMENT = auto()
 872    VIRTUAL_JINJA_STATEMENT = auto()
 873
 874
 875def parse_one(
 876    sql: str, dialect: t.Optional[str] = None, into: t.Optional[exp.IntoType] = None
 877) -> exp.Expression:
 878    expressions = parse(sql, default_dialect=dialect, match_dialect=False, into=into)
 879    if not expressions:
 880        raise SQLMeshError(f"No expressions found in '{sql}'")
 881    elif len(expressions) > 1:
 882        raise SQLMeshError(f"Multiple expressions found in '{sql}'")
 883    return expressions[0]
 884
 885
 886def parse(
 887    sql: str,
 888    default_dialect: t.Optional[str] = None,
 889    match_dialect: bool = True,
 890    into: t.Optional[exp.IntoType] = None,
 891) -> t.List[exp.Expression]:
 892    """Parse a sql string.
 893
 894    Supports parsing model definition.
 895    If a jinja block is detected, the query is stored as raw string in a Jinja node.
 896
 897    Args:
 898        sql: The sql based definition.
 899        default_dialect: The dialect to use if the model does not specify one.
 900
 901    Returns:
 902        A list of the parsed expressions: [Model, *Statements, Query, *Statements]
 903    """
 904    match = match_dialect and DIALECT_PATTERN.search(sql[:MAX_MODEL_DEFINITION_SIZE])
 905    dialect_str = match.group("dialect") if match else None
 906    dialect = Dialect.get_or_raise(dialect_str or default_dialect)
 907
 908    tokens = dialect.tokenize(sql)
 909    chunks: t.List[t.Tuple[t.List[Token], ChunkType]] = [([], ChunkType.SQL)]
 910    total = len(tokens)
 911
 912    pos = 0
 913    virtual = False
 914    while pos < total:
 915        token = tokens[pos]
 916        if _is_virtual_statement_end(tokens, pos):
 917            chunks[-1][0].append(token)
 918            virtual = False
 919            chunks.append(([], ChunkType.SQL))
 920            pos += 2
 921        elif _is_jinja_end(tokens, pos) or (
 922            chunks[-1][1] == ChunkType.SQL
 923            and token.token_type == TokenType.SEMICOLON
 924            and pos < total - 1
 925        ):
 926            if token.token_type == TokenType.SEMICOLON:
 927                pos += 1
 928            else:
 929                # Jinja end statement
 930                chunks[-1][0].append(token)
 931                pos += 2
 932            chunks.append(
 933                (
 934                    [],
 935                    ChunkType.VIRTUAL_STATEMENT
 936                    if virtual and tokens[pos] != ON_VIRTUAL_UPDATE_END
 937                    else ChunkType.SQL,
 938                )
 939            )
 940        elif _is_jinja_query_begin(tokens, pos):
 941            chunks.append(([token], ChunkType.JINJA_QUERY))
 942            pos += 2
 943        elif _is_jinja_statement_begin(tokens, pos):
 944            chunks.append(([token], ChunkType.JINJA_STATEMENT))
 945            pos += 2
 946        elif _is_virtual_statement_begin(tokens, pos):
 947            chunks.append(([token], ChunkType.VIRTUAL_STATEMENT))
 948            pos += 2
 949            virtual = True
 950        else:
 951            chunks[-1][0].append(token)
 952            pos += 1
 953
 954    parser = dialect.parser()
 955    expressions: t.List[exp.Expression] = []
 956
 957    def parse_sql_chunk(chunk: t.List[Token], meta_sql: bool = True) -> t.List[exp.Expression]:
 958        parsed_expressions: t.List[t.Optional[exp.Expression]] = (
 959            parser.parse(chunk, sql) if into is None else parser.parse_into(into, chunk, sql)
 960        )
 961        expressions = []
 962        for expression in parsed_expressions:
 963            if expression:
 964                if meta_sql:
 965                    expression.meta["sql"] = parser._find_sql(chunk[0], chunk[-1])
 966                expressions.append(expression)
 967        return expressions
 968
 969    def parse_jinja_chunk(chunk: t.List[Token], meta_sql: bool = True) -> exp.Expression:
 970        start, *_, end = chunk
 971        segment = sql[start.end + 2 : end.start - 1]
 972        factory = jinja_query if chunk_type == ChunkType.JINJA_QUERY else jinja_statement
 973        expression = factory(segment.strip())
 974        if meta_sql:
 975            expression.meta["sql"] = sql[start.start : end.end + 1]
 976        return expression
 977
 978    def parse_virtual_statement(
 979        chunks: t.List[t.Tuple[t.List[Token], ChunkType]], pos: int
 980    ) -> t.Tuple[t.List[exp.Expression], int]:
 981        # For virtual statements we need to handle both SQL and Jinja nested blocks within the chunk
 982        virtual_update_statements = []
 983        start = chunks[pos][0][0].start
 984
 985        while (
 986            chunks[pos - 1][0] == [] or chunks[pos - 1][0][-1].text.upper() != ON_VIRTUAL_UPDATE_END
 987        ):
 988            chunk, chunk_type = chunks[pos]
 989            if chunk_type == ChunkType.JINJA_STATEMENT:
 990                virtual_update_statements.append(parse_jinja_chunk(chunk, False))
 991            else:
 992                virtual_update_statements.extend(
 993                    parse_sql_chunk(
 994                        chunk[int(chunk[0].text.upper() == ON_VIRTUAL_UPDATE_BEGIN) : -1], False
 995                    ),
 996                )
 997            pos += 1
 998
 999        if virtual_update_statements:
1000            statements = virtual_statement(virtual_update_statements)
1001            end = chunk[-1].end + 1
1002            statements.meta["sql"] = sql[start:end]
1003            return [statements], pos
1004
1005        return [], pos
1006
1007    pos = 0
1008    total_chunks = len(chunks)
1009    while pos < total_chunks:
1010        chunk, chunk_type = chunks[pos]
1011        if chunk_type == ChunkType.VIRTUAL_STATEMENT:
1012            virtual_expression, pos = parse_virtual_statement(chunks, pos)
1013            expressions.extend(virtual_expression)
1014        elif chunk_type == ChunkType.SQL:
1015            expressions.extend(parse_sql_chunk(chunk))
1016        else:
1017            expressions.append(parse_jinja_chunk(chunk))
1018        pos += 1
1019
1020    return expressions
1021
1022
def extend_sqlglot() -> None:
    """Extend SQLGlot with SQLMesh's custom macro aware dialect.

    Mutates SQLGlot's base Tokenizer / Parser / Generator classes, and those of every
    registered dialect, in place. This lets ``@`` macros, Jinja nodes and SQLMesh's
    MODEL / AUDIT / METRIC expressions be tokenized, parsed and generated. The changes
    are process-wide.

    The tuple extensions below are guarded so that repeated calls do not keep appending
    duplicates. The dict/set updates are naturally idempotent.
    NOTE(review): whether the ``_override`` calls are safe to repeat depends on
    ``_override``, which is defined elsewhere.
    """
    tokenizers = {Tokenizer}
    parsers = {Parser}
    generators = {Generator}

    for dialect in Dialect.classes.values():
        # Athena picks a different Tokenizer / Parser / Generator depending on the query
        # so this ensures that the extra ones it defines are also extended
        if dialect == athena.Athena:
            tokenizers.add(athena._TrinoTokenizer)
            parsers.add(athena._TrinoParser)
            generators.add(athena._TrinoGenerator)
            generators.add(athena._HiveGenerator)

        if hasattr(dialect, "Tokenizer"):
            tokenizers.add(dialect.Tokenizer)
        if hasattr(dialect, "Parser"):
            parsers.add(dialect.Parser)
        if hasattr(dialect, "Generator"):
            generators.add(dialect.Generator)

    # Register the macro prefix character with every tokenizer.
    for tokenizer in tokenizers:
        tokenizer.VAR_SINGLE_TOKENS.update(SQLMESH_MACRO_PREFIX)

    # Route JINJA()/METRIC() calls and `@`-prefixed placeholders / query modifiers to the
    # macro-aware parse functions.
    for parser in parsers:
        parser.FUNCTIONS.update({"JINJA": Jinja.from_arg_list, "METRIC": MetricAgg.from_arg_list})
        parser.PLACEHOLDER_PARSERS.update({TokenType.PARAMETER: _parse_macro})
        parser.QUERY_MODIFIER_PARSERS.update(
            {TokenType.PARAMETER: lambda self: _parse_body_macro(self)}
        )

    for generator in generators:
        # MacroFunc's presence marks a TRANSFORMS mapping that has already been extended.
        if MacroFunc not in generator.TRANSFORMS:
            generator.TRANSFORMS.update(
                {
                    Audit: lambda self, e: _sqlmesh_ddl_sql(self, e, "AUDIT"),
                    DColonCast: lambda self, e: f"{self.sql(e, 'this')}::{self.sql(e, 'to')}",
                    Jinja: lambda self, e: e.name,
                    JinjaQuery: lambda self, e: f"{JINJA_QUERY_BEGIN};\n{e.name}\n{JINJA_END};",
                    JinjaStatement: lambda self,
                    e: f"{JINJA_STATEMENT_BEGIN};\n{e.name}\n{JINJA_END};",
                    VirtualUpdateStatement: lambda self, e: _on_virtual_update_sql(self, e),
                    MacroDef: lambda self, e: f"@DEF({self.sql(e.this)}, {self.sql(e.expression)})",
                    MacroFunc: _macro_func_sql,
                    MacroStrReplace: lambda self, e: f"@{self.sql(e.this)}",
                    MacroSQL: lambda self, e: f"@SQL({self.sql(e.this)})",
                    MacroVar: lambda self, e: f"@{e.name}",
                    Metric: lambda self, e: _sqlmesh_ddl_sql(self, e, "METRIC"),
                    Model: lambda self, e: _sqlmesh_ddl_sql(self, e, "MODEL"),
                    ModelKind: _model_kind_sql,
                    PythonCode: lambda self, e: self.expressions(e, sep="\n", indent=False),
                    StagedFilePath: lambda self, e: self.table_sql(e),
                    exp.Whens: _whens_sql,
                }
            )
        if MacroDef not in generator.WITH_SEPARATED_COMMENTS:
            generator.WITH_SEPARATED_COMMENTS = (
                *generator.WITH_SEPARATED_COMMENTS,
                Model,
                MacroDef,
            )

        # Guarded like the tuple above. An unguarded concatenation would append duplicates
        # on every call, and also whenever a generator inherits this attribute from a base
        # class that was already extended earlier in this loop.
        if MacroVar not in generator.UNWRAPPED_INTERVAL_VALUES:
            generator.UNWRAPPED_INTERVAL_VALUES = (
                *generator.UNWRAPPED_INTERVAL_VALUES,
                MacroStrReplace,
                MacroVar,
            )

    _override(Parser, _parse_select)
    _override(Parser, _parse_statement)
    _override(Parser, _parse_join)
    _override(Parser, _parse_order)
    _override(Parser, _parse_where)
    _override(Parser, _parse_group)
    _override(Parser, _parse_with)
    _override(Parser, _parse_having)
    _override(Parser, _parse_limit)
    _override(Parser, _parse_value)
    _override(Parser, _parse_lambda)
    _override(Parser, _parse_types)
    _override(Parser, _parse_if)
    _override(Parser, _parse_id_var)
    _override(Parser, _warn_unsupported)
    _override(Snowflake.Parser, _parse_table_parts)

    # DuckDB's prefix absolute power operator `@` clashes with the macro syntax
    DuckDB.Parser.NO_PAREN_FUNCTION_PARSERS.pop("@", None)
1111
1112
def select_from_values(
    values: t.List[t.Tuple[t.Any, ...]],
    columns_to_types: t.Dict[str, exp.DataType],
    batch_size: int = 0,
    alias: str = "t",
) -> t.Iterator[exp.Select]:
    """Yield SELECT-over-VALUES expressions that cast ``values`` to their column types.

    Args:
        values: Row tuples to place in the VALUES clauses.
        columns_to_types: Column names mapped to the types the values are cast to.
        batch_size: Maximum number of rows per yielded expression. Values <= 0 mean
            ``sys.maxsize``, i.e. everything in a single batch.
        alias: Alias of the VALUES expression. Defaults to "t".

    Yields:
        One SELECT expression per batch of rows. Nothing is yielded when ``values`` is empty.
    """
    num_rows = len(values)
    step = batch_size if batch_size > 0 else sys.maxsize
    for batch_start in range(0, num_rows, step):
        batch_end = min(batch_start + step, num_rows)
        yield select_from_values_for_batch_range(
            values=values,
            target_columns_to_types=columns_to_types,
            batch_start=batch_start,
            batch_end=batch_end,
            alias=alias,
        )
1141
1142
def select_from_values_for_batch_range(
    values: t.List[t.Tuple[t.Any, ...]],
    target_columns_to_types: t.Dict[str, exp.DataType],
    batch_start: int,
    batch_end: int,
    alias: str = "t",
    source_columns: t.Optional[t.List[str]] = None,
) -> exp.Select:
    """Build a SELECT over a VALUES clause for the rows ``values[batch_start:batch_end]``.

    The result looks like ``SELECT CAST(col AS type) AS col, ... FROM (VALUES ...) AS alias``,
    with one projection per entry of ``target_columns_to_types``. Target columns that are not
    among the source columns are projected as ``CAST(NULL AS type)``.

    Args:
        values: Row tuples, presumably holding one value per source column in order
            (they are zipped with the source column types).
        target_columns_to_types: Output column names mapped to the types they are cast to.
        batch_start: Index of the first row of ``values`` to include.
        batch_end: Index one past the last row to include.
        alias: Alias of the VALUES expression.
        source_columns: Column names of the tuples in ``values``. Defaults to the keys of
            ``target_columns_to_types``.

    Returns:
        The SELECT expression. When the selected batch has no rows, a single all-NULL row is
        emitted together with ``WHERE FALSE``, so the query stays valid but returns nothing.
    """
    source_columns = source_columns or list(target_columns_to_types)
    source_columns_to_types = get_source_columns_to_types(target_columns_to_types, source_columns)

    # Decide on the actual batch rather than on `values` as a whole: an empty slice of a
    # non-empty list must also take the zero-row path. Otherwise `values_exp.expressions[0]`
    # below raises IndexError.
    batch = values[batch_start:batch_end]
    if not batch:
        # Ensures we don't generate an empty VALUES clause & forces a zero-row output
        where = exp.false()
        expressions = [
            tuple(exp.cast(exp.null(), to=kind) for kind in source_columns_to_types.values())
        ]
    else:
        where = None
        expressions = [tuple(transform_values(v, source_columns_to_types)) for v in batch]

    values_exp = exp.values(expressions, alias=alias, columns=source_columns_to_types)
    if batch:
        # BigQuery crashes on `SELECT CAST(x AS TIMESTAMP) FROM UNNEST([NULL]) AS x`, but not
        # on `SELECT CAST(x AS TIMESTAMP) FROM UNNEST([CAST(NULL AS TIMESTAMP)]) AS x`. This
        # ensures nulls under the `Values` expression are cast to avoid similar issues.
        # Only the first row is rewritten.
        for value, kind in zip(
            values_exp.expressions[0].expressions, source_columns_to_types.values()
        ):
            if isinstance(value, exp.Null):
                value.replace(exp.cast(value, to=kind))

    casted_columns = [
        exp.alias_(
            exp.cast(
                exp.column(column) if column in source_columns_to_types else exp.Null(), to=kind
            ),
            column,
            copy=False,
        )
        for column, kind in target_columns_to_types.items()
    ]
    return exp.select(*casted_columns).from_(values_exp, copy=False).where(where, copy=False)
1189
1190
def pandas_to_sql(
    df: pd.DataFrame,
    columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
    batch_size: int = 0,
    alias: str = "t",
) -> t.Iterator[exp.Select]:
    """Convert a pandas dataframe into SELECT-over-VALUES statements.

    Args:
        df: The dataframe whose rows (index excluded) become the VALUES tuples.
        columns_to_types: Column names mapped to the types to cast to. When it is not
            provided (or empty), the types are derived from ``df`` with ``columns_to_types_from_df``.
        batch_size: Maximum number of rows per yielded statement. Values <= 0 put all rows
            in a single batch.
        alias: Alias of the VALUES expression. Defaults to "t".

    Yields:
        One SELECT expression per batch of rows. Nothing is yielded when ``df`` has no rows.
    """
    rows = list(df.itertuples(index=False, name=None))
    types = columns_to_types or columns_to_types_from_df(df)
    yield from select_from_values(
        values=rows,
        columns_to_types=types,
        batch_size=batch_size,
        alias=alias,
    )
1214
1215
def set_default_catalog(
    table: str | exp.Table,
    default_catalog: t.Optional[str],
) -> exp.Table:
    """Convert ``table`` to an ``exp.Table`` and fill in ``default_catalog`` where it is missing.

    The catalog is only set when ``default_catalog`` is non-empty and the table has a db
    (schema) part but no catalog. In every other case the converted table is returned
    without a catalog change.
    NOTE(review): whether a Table argument is copied or modified in place is decided by
    ``exp.to_table``.
    """
    table = exp.to_table(table)

    if not default_catalog or table.catalog or not table.db:
        return table

    table.set("catalog", exp.parse_identifier(default_catalog))
    return table
1226
1227
@lru_cache(maxsize=16384)
def normalize_model_name(
    table: str | exp.Table | exp.Column,
    default_catalog: t.Optional[str],
    dialect: DialectType = None,
) -> str:
    """Return the fully-quoted, dialect-normalized name of a model table.

    Results are memoized with an LRU cache holding up to 16384 entries.

    Args:
        table: A table name, a Table expression, or a Column. A Column is read one level up:
            its name becomes the table, its table qualifier the db, and its db qualifier
            the catalog.
        default_catalog: Catalog to add when the name has a db but no catalog.
        dialect: Dialect used for parsing and for identifier normalization.
    """
    if not isinstance(table, exp.Column):
        # We are relying on sqlglot's flexible parsing here to accept quotes from other dialects.
        # Ex: I have a a normalized name of '"my_table"' but the dialect is spark and therefore we should
        # expect spark quotes to be backticks ('`') instead of double quotes ('"'). sqlglot today is flexible
        # and will still parse this correctly and we rely on that.
        parsed = exp.to_table(table, dialect=dialect)
    else:
        parsed = exp.table_(table.this, db=table.args.get("table"), catalog=table.args.get("db"))

    with_catalog = set_default_catalog(parsed, default_catalog)
    # An alternative way to do this is the following: exp.table_name(table, dialect=dialect, identify=True)
    # This though would result in the names being normalized to the target dialect AND the quotes while the below
    # approach just normalizes the names.
    # By just normalizing names and using sqlglot dialect for quotes this makes it easier for dialects that have
    # compatible normalization strategies but incompatible quoting to still work together without user hassle
    normalized = normalize_identifiers(with_catalog, dialect=dialect)
    return exp.table_name(normalized, identify=True)
1250
1251
def find_tables(
    expression: exp.Expression, default_catalog: t.Optional[str], dialect: DialectType = None
) -> t.Set[str]:
    """Find all tables referenced in a query.

    The result is cached in ``expression.meta`` under the ``TABLES_META`` key
    ("sqlmesh.tables"). Later calls return the cached set without recomputing it,
    regardless of their ``default_catalog`` / ``dialect`` arguments.

    Args:
        expression: The query to find the tables in.
        default_catalog: Catalog added to names that have a db but no catalog.
        dialect: The dialect to use for normalization of table names.

    Returns:
        A Set of all the normalized table names. Nameless tables and names that resolve to
        a CTE of their scope are excluded.
    """
    if TABLES_META in expression.meta:
        return expression.meta[TABLES_META]

    table_names: t.Set[str] = set()
    for scope in traverse_scope(expression):
        for table in scope.tables:
            if not table.name or table.name in scope.cte_sources:
                continue
            table_names.add(
                normalize_model_name(table, default_catalog=default_catalog, dialect=dialect)
            )

    expression.meta[TABLES_META] = table_names
    return table_names
1274
1275
def add_table(node: exp.Expression, table: str) -> exp.Expression:
    """Add a table to all columns in an expression.

    Unqualified columns are re-created with ``table`` as their qualifier. Columns that already
    carry a table qualifier are left as they are. A bare ``exp.Identifier`` that is not part
    of a column (including ``node`` itself) becomes a column qualified with ``table``.

    Args:
        node: The expression to transform.
        table: The table name to qualify columns with.

    Returns:
        The result of ``node.transform``.
    """

    def _transform(node: exp.Expression) -> exp.Expression:
        if isinstance(node, exp.Column) and not node.table:
            return exp.column(node.this, table=table)
        # `transform` keeps descending into nodes returned unchanged. That means the
        # identifiers inside an already-qualified column are visited too. Turning those into
        # columns would nest columns inside the reference and corrupt it, so skip them.
        if isinstance(node, exp.Identifier) and not isinstance(node.parent, exp.Column):
            return exp.column(node, table=table)
        return node

    return node.transform(_transform)
1287
1288
def transform_values(
    values: t.Tuple[t.Any, ...], columns_to_types: t.Dict[str, exp.DataType]
) -> t.Iterator[t.Any]:
    """Perform transformations on values given columns_to_types.

    Yields one SQLGlot expression per value. Values are paired positionally with the types
    of ``columns_to_types``; surplus values or types are ignored, as with ``zip``.

    - A list under an array type with exactly one element type is converted element-wise.
    - A dict under a struct type with as many fields as the dict has items becomes a
      ``STRUCT`` of ``name := value`` pairs. Fields are matched by position.
    - A value of a JSON column becomes ``PARSE_JSON('<text>')``. ``None`` becomes ``NULL``,
      and dicts / lists are serialized with ``json.dumps``.
    - Anything else goes through ``exp.convert``.
    """
    import json

    def _transform_value(value: t.Any, dtype: exp.DataType) -> t.Any:
        if (
            isinstance(value, list)
            and dtype.is_type(*exp.DataType.ARRAY_TYPES)
            and len(dtype.expressions) == 1
        ):
            element_type = dtype.expressions[0]
            return exp.convert([_transform_value(v, element_type) for v in value])

        if (
            isinstance(value, dict)
            and dtype.is_type(*exp.DataType.STRUCT_TYPES)
            and len(value) == len(dtype.expressions)
        ):
            expressions = []
            for (field_name, field_value), field_type in zip(value.items(), dtype.expressions):
                if isinstance(field_type, exp.ColumnDef):
                    field_type = field_type.kind
                else:
                    field_type = exp.DataType.build(exp.DataType.Type.UNKNOWN)

                expressions.append(
                    exp.PropertyEQ(
                        this=exp.to_identifier(field_name),
                        expression=_transform_value(field_value, field_type),
                    )
                )

            return exp.Struct(expressions=expressions)

        if dtype.is_type(exp.DataType.Type.JSON):
            if value is None:
                return exp.null()
            if isinstance(value, (dict, list)):
                # Python's repr of a dict/list is not JSON (single quotes, True/None).
                value = json.dumps(value)
            # Build the string literal node directly, so any quotes in the payload are
            # escaped when SQL is generated. Splicing the text into a SQL snippet that is
            # re-parsed would break on `'`.
            return exp.func("PARSE_JSON", exp.Literal.string(str(value)))

        return exp.convert(value)

    for col_value, col_type in zip(values, columns_to_types.values()):
        yield _transform_value(col_value, col_type)
1331
1332
def to_schema(sql_path: str | exp.Table, dialect: DialectType = None) -> exp.Table:
    """Interpret ``sql_path`` as a schema reference, returned as an ``exp.Table`` without a name.

    Every part moves up one level: the table name becomes the db, the db becomes the catalog,
    and any original catalog is dropped. A Table whose ``this`` is already ``None`` is returned
    as-is (same object). Other Table inputs are copied first, so they are not modified.
    """
    if isinstance(sql_path, exp.Table):
        if sql_path.this is None:
            return sql_path
        source: str | exp.Table = sql_path.copy()
    else:
        source = sql_path

    schema = exp.to_table(source, dialect=dialect)
    db_part, name_part = schema.args.get("db"), schema.args.get("this")
    schema.set("catalog", db_part)
    schema.set("db", name_part)
    schema.set("this", None)
    return schema
1343
1344
def schema_(
    db: exp.Identifier | str,
    catalog: t.Optional[exp.Identifier | str] = None,
    quoted: t.Optional[bool] = None,
) -> exp.Table:
    """Build a Schema, represented as an ``exp.Table`` with no table name.

    Args:
        db: Database name. A falsy value leaves the db unset.
        catalog: Catalog name. A falsy value leaves the catalog unset.
        quoted: Whether to force quotes on the schema's identifiers.

    Returns:
        The new Schema instance.
    """

    def _identifier(name: t.Optional[exp.Identifier | str]) -> t.Optional[exp.Identifier]:
        return exp.to_identifier(name, quoted=quoted) if name else None

    return exp.Table(this=None, db=_identifier(db), catalog=_identifier(catalog))
1365
1366
def normalize_mapping_schema(schema: t.Dict, dialect: DialectType) -> MappingSchema:
    """Build a MappingSchema from a nested mapping whose keys are already normalized.

    Double quotes are stripped from every key, and ``normalize=False`` is passed so sqlglot
    does not normalize them again.
    """
    unquoted = _unquote_schema(schema)
    return MappingSchema(unquoted, dialect=dialect, normalize=False)
1369
1370
1371def _unquote_schema(schema: t.Dict) -> t.Dict:
1372    """SQLGlot schema expects unquoted normalized keys."""
1373    return {
1374        k.strip('"'): _unquote_schema(v) if isinstance(v, dict) else v for k, v in schema.items()
1375    }
1376
1377
@contextmanager
def normalize_and_quote(
    query: E, dialect: DialectType, default_catalog: t.Optional[str], quote: bool = True
) -> t.Iterator[E]:
    """Prepare ``query`` for use inside a ``with`` block and optionally quote it afterwards.

    Before yielding, tables are qualified with ``default_catalog`` and identifiers are
    normalized for ``dialect``. The helpers' return values are discarded, so this relies on
    them updating ``query`` in place; the same object is yielded. When ``quote`` is true,
    identifiers are quoted once the block exits normally. There is no ``try``/``finally``,
    so quoting is skipped if the block raises.
    """
    qualify_tables(query, catalog=default_catalog, dialect=dialect)
    normalize_identifiers(query, dialect=dialect)
    yield query
    if not quote:
        return
    quote_identifiers(query, dialect=dialect)
1387
1388
def interpret_expression(e: exp.Expression) -> exp.Expression | str | int | float | bool:
    """Convert a literal expression into the corresponding Python value.

    Integer literals become ``int`` and other numeric literals ``float``. This includes
    negated numbers, which are not plain Literals. String literals and booleans return their
    ``this`` value. Any other expression is returned unchanged.
    """
    if e.is_int:
        # `to_py` also covers negative numbers, whose `this` is a nested expression rather
        # than the numeric text, so `int(e.this)` would fail for them.
        return int(e.to_py())
    if e.is_number:
        return float(e.to_py())
    if isinstance(e, (exp.Literal, exp.Boolean)):
        return e.this
    return e
1397
1398
def interpret_key_value_pairs(
    e: exp.Tuple,
) -> t.Dict[str, exp.Expression | str | int | float | bool]:
    """Turn a tuple of key/value pairs into a dict of interpreted values.

    Each item (presumably an EQ / PropertyEQ node — confirm against callers) provides its key
    via ``item.this.name``. Its value ``item.expression`` is converted with
    ``interpret_expression``. For duplicate keys, the later item wins.
    """
    pairs: t.Dict[str, exp.Expression | str | int | float | bool] = {}
    for item in e.expressions:
        key = item.this.name
        pairs[key] = interpret_expression(item.expression)
    return pairs
1403
1404
def extract_func_call(
    v: exp.Expression, allow_tuples: bool = False
) -> t.Tuple[str, t.Dict[str, exp.Expression]]:
    """Split a call-like expression into its lowercased name and keyword arguments.

    Handles anonymous and known functions, a parenthesized single argument (empty name) and,
    when ``allow_tuples`` is set, a bare tuple of arguments (empty name). Any other expression
    is treated as a call without arguments named ``v.name``.

    Returns:
        ``(name, kwargs)``, where kwargs maps lowercased argument names to value expressions.

    Raises:
        ConfigError: If a tuple is given without ``allow_tuples``, or if any argument is not
            of the form ``name := value`` / ``name = value``.
    """
    if isinstance(v, exp.Anonymous):
        func, args = v.name, v.expressions
    elif isinstance(v, exp.Func):
        func, args = v.sql_name(), list(v.args.values())
    elif isinstance(v, exp.Paren):
        func, args = "", [v.this]
    elif isinstance(v, exp.Tuple):  # airflow only
        if not allow_tuples:
            raise ConfigError("Audit name is missing (eg. MY_AUDIT())")
        func, args = "", v.expressions
    else:
        return v.name.lower(), {}

    kwargs = {}
    for arg in args:
        if not isinstance(arg, (exp.PropertyEQ, exp.EQ)):
            raise ConfigError(
                f"Function '{func}' must be called with key-value arguments like {func}(arg := value)."
            )
        kwargs[arg.left.name.lower()] = arg.right
    return func.lower(), kwargs
1435
1436
def extract_function_calls(func_calls: t.Any, allow_tuples: bool = False) -> t.Any:
    """Used for extracting function calls for signals or audits.

    Supported inputs:
      - an ``exp.Tuple`` / ``exp.Array`` of calls, an ``exp.Paren`` around one call, or any
        other single expression, each converted with ``extract_func_call``;
      - a Python list whose entries are dicts or ``(name, kwargs)`` pairs. A dict supplies
        its name under the "name" key; with ``allow_tuples`` the name is ``""`` and every key
        is kept as an argument. String argument values are parsed with ``parse_one``. The
        caller's dicts are not modified.

    Any other input is returned as-is when truthy, otherwise ``[]``.

    Returns:
        For the supported inputs, a list of ``(lowercased_name, kwargs)`` tuples.

    Raises:
        ConfigError: If a list entry is neither a dict nor a tuple/list.
        KeyError: If a dict entry has no "name" key and ``allow_tuples`` is false.
    """

    if isinstance(func_calls, (exp.Tuple, exp.Array)):
        return [extract_func_call(i, allow_tuples=allow_tuples) for i in func_calls.expressions]
    if isinstance(func_calls, exp.Paren):
        return [extract_func_call(func_calls.this, allow_tuples=allow_tuples)]
    if isinstance(func_calls, exp.Expression):
        return [extract_func_call(func_calls, allow_tuples=allow_tuples)]
    if isinstance(func_calls, list):
        function_calls = []
        for entry in func_calls:
            if isinstance(entry, dict):
                # Copy before popping "name" so the caller's config dict stays intact and
                # can be processed again.
                args = dict(entry)
                name = "" if allow_tuples else args.pop("name")
            elif isinstance(entry, (tuple, list)):
                name, args = entry
            else:
                raise ConfigError(f"Audit must be a dictionary or named tuple. Got {entry}.")

            function_calls.append(
                (
                    name.lower(),
                    {
                        key: parse_one(value) if isinstance(value, str) else value
                        for key, value in args.items()
                    },
                )
            )

        return function_calls

    return func_calls or []
1470
1471
def is_meta_expression(v: t.Any) -> bool:
    """Return True if ``v`` is a MODEL, AUDIT or METRIC definition node."""
    meta_types = (Audit, Metric, Model)
    return isinstance(v, meta_types)
1474
1475
def replace_merge_table_aliases(
    expression: exp.Expression, dialect: t.Optional[str] = None
) -> exp.Expression:
    """
    Resolves references from the "source" and "target" tables (or their DBT equivalents)
    with the corresponding SQLMesh merge aliases (MERGE_SOURCE_ALIAS and MERGE_TARGET_ALIAS)

    Only qualified columns are rewritten. The outermost qualifier (the first part of the
    column reference) is compared case-insensitively against the known names. On a match it
    is replaced in place by the quoted merge alias. Unqualified columns, even one literally
    named ``source`` or ``target``, and non-column expressions are left untouched.

    Args:
        expression: The expression to inspect; modified in place when it matches.
        dialect: Unused; kept for interface compatibility.

    Returns:
        ``expression`` itself.
    """
    from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS

    if not isinstance(expression, exp.Column):
        return expression

    parts = expression.parts
    # With a single part, that part is the column name itself (or `*`), not a table qualifier.
    if len(parts) < 2:
        return expression

    first_part = parts[0]
    qualifier = first_part.this.lower()
    if qualifier in ("target", "dbt_internal_dest", "__merge_target__"):
        first_part.replace(exp.to_identifier(MERGE_TARGET_ALIAS, quoted=True))
    elif qualifier in ("source", "dbt_internal_source", "__merge_source__"):
        first_part.replace(exp.to_identifier(MERGE_SOURCE_ALIAS, quoted=True))

    return expression
SQLMESH_MACRO_PREFIX = '@'
TABLES_META = 'sqlmesh.tables'
logger = <Logger sqlmesh.core.dialect (WARNING)>
class Model(sqlglot.expressions.Expression):
44class Model(exp.Expression):
45    arg_types = {"expressions": True}
arg_types = {'expressions': True}
key = 'model'
required_args = {'expressions'}
Inherited Members
sqlglot.expressions.Expression
Expression
args
parent
arg_key
index
comments
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
dump
load
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
class Audit(sqlglot.expressions.Expression):
48class Audit(exp.Expression):
49    arg_types = {"expressions": True}
arg_types = {'expressions': True}
key = 'audit'
required_args = {'expressions'}
Inherited Members
sqlglot.expressions.Expression
Expression
args
parent
arg_key
index
comments
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
dump
load
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
class Metric(sqlglot.expressions.Expression):
52class Metric(exp.Expression):
53    arg_types = {"expressions": True}
arg_types = {'expressions': True}
key = 'metric'
required_args = {'expressions'}
Inherited Members
sqlglot.expressions.Expression
Expression
args
parent
arg_key
index
comments
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
dump
load
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
class Jinja(sqlglot.expressions.Func):
56class Jinja(exp.Func):
57    arg_types = {"this": True}
arg_types = {'this': True}
key = 'jinja'
required_args = {'this'}
Inherited Members
sqlglot.expressions.Expression
Expression
args
parent
arg_key
index
comments
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
dump
load
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
sqlglot.expressions.Func
is_var_len_args
from_arg_list
sql_names
sql_name
default_parser_mappings
class JinjaQuery(Jinja):
60class JinjaQuery(Jinja):
61    pass
key = 'jinjaquery'
required_args = {'this'}
Inherited Members
sqlglot.expressions.Expression
Expression
args
parent
arg_key
index
comments
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
dump
load
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
Jinja
arg_types
sqlglot.expressions.Func
is_var_len_args
from_arg_list
sql_names
sql_name
default_parser_mappings
class JinjaStatement(Jinja):
64class JinjaStatement(Jinja):
65    pass
key = 'jinjastatement'
required_args = {'this'}
Inherited Members
sqlglot.expressions.Expression
Expression
args
parent
arg_key
index
comments
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
dump
load
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
Jinja
arg_types
sqlglot.expressions.Func
is_var_len_args
from_arg_list
sql_names
sql_name
default_parser_mappings
class VirtualUpdateStatement(sqlglot.expressions.Expression):
68class VirtualUpdateStatement(exp.Expression):
69    arg_types = {"expressions": True}
arg_types = {'expressions': True}
key = 'virtualupdatestatement'
required_args = {'expressions'}
Inherited Members
sqlglot.expressions.Expression
Expression
args
parent
arg_key
index
comments
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
dump
load
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
class ModelKind(sqlglot.expressions.Expression):
72class ModelKind(exp.Expression):
73    arg_types = {"this": True, "expressions": False}
arg_types = {'this': True, 'expressions': False}
key = 'modelkind'
required_args = {'this'}
Inherited Members
sqlglot.expressions.Expression
Expression
args
parent
arg_key
index
comments
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
dump
load
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
class MacroVar(sqlglot.expressions.Var):
76class MacroVar(exp.Var):
77    pass
key = 'macrovar'
required_args = {'this'}
Inherited Members
sqlglot.expressions.Expression
Expression
arg_types
args
parent
arg_key
index
comments
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
dump
load
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
class MacroFunc(sqlglot.expressions.Func):
80class MacroFunc(exp.Func):
81    @property
82    def name(self) -> str:
83        return self.this.name
name: str
81    @property
82    def name(self) -> str:
83        return self.this.name
key = 'macrofunc'
required_args = {'this'}
Inherited Members
sqlglot.expressions.Expression
Expression
arg_types
args
parent
arg_key
index
comments
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
dump
load
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
sqlglot.expressions.Func
is_var_len_args
from_arg_list
sql_names
sql_name
default_parser_mappings
class MacroDef(MacroFunc):
86class MacroDef(MacroFunc):
87    arg_types = {"this": True, "expression": True}
arg_types = {'this': True, 'expression': True}
key = 'macrodef'
required_args = {'this', 'expression'}
Inherited Members
sqlglot.expressions.Expression
Expression
args
parent
arg_key
index
comments
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
dump
load
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
MacroFunc
name
sqlglot.expressions.Func
is_var_len_args
from_arg_list
sql_names
sql_name
default_parser_mappings
class MacroSQL(MacroFunc):
90class MacroSQL(MacroFunc):
91    arg_types = {"this": True, "into": False}
arg_types = {'this': True, 'into': False}
key = 'macrosql'
required_args = {'this'}
Inherited Members
sqlglot.expressions.Expression
Expression
args
parent
arg_key
index
comments
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
dump
load
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
MacroFunc
name
sqlglot.expressions.Func
is_var_len_args
from_arg_list
sql_names
sql_name
default_parser_mappings
class MacroStrReplace(MacroFunc):
94class MacroStrReplace(MacroFunc):
95    pass
key = 'macrostrreplace'
required_args = {'this'}
Inherited Members
sqlglot.expressions.Expression
Expression
arg_types
args
parent
arg_key
index
comments
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
dump
load
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
MacroFunc
name
sqlglot.expressions.Func
is_var_len_args
from_arg_list
sql_names
sql_name
default_parser_mappings
class PythonCode(sqlglot.expressions.Expression):
98class PythonCode(exp.Expression):
99    arg_types = {"expressions": True}
arg_types = {'expressions': True}
key = 'pythoncode'
required_args = {'expressions'}
Inherited Members
sqlglot.expressions.Expression
Expression
args
parent
arg_key
index
comments
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
dump
load
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
class DColonCast(sqlglot.expressions.Cast):
102class DColonCast(exp.Cast):
103    pass
key = 'dcoloncast'
required_args = {'to', 'this'}
Inherited Members
sqlglot.expressions.Expression
Expression
args
parent
arg_key
index
comments
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
alias_or_name
type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
dump
load
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
sqlglot.expressions.Cast
arg_types
name
to
output_name
is_type
sqlglot.expressions.Func
is_var_len_args
from_arg_list
sql_names
sql_name
default_parser_mappings
class MetricAgg(sqlglot.expressions.AggFunc):
106class MetricAgg(exp.AggFunc):
107    """Used for computing metrics."""
108
109    arg_types = {"this": True}
110
111    @property
112    def output_name(self) -> str:
113        return self.this.name

Used for computing metrics.

arg_types = {'this': True}
output_name: str
111    @property
112    def output_name(self) -> str:
113        return self.this.name

Name of the output column if this expression is a selection.

If the Expression has no output name, an empty string is returned.

Example:
>>> from sqlglot import parse_one
>>> parse_one("SELECT a").expressions[0].output_name
'a'
>>> parse_one("SELECT b AS c").expressions[0].output_name
'c'
>>> parse_one("SELECT 1 + 2").expressions[0].output_name
''
key = 'metricagg'
required_args = {'this'}
Inherited Members
sqlglot.expressions.Expression
Expression
args
parent
arg_key
index
comments
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
dump
load
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
sqlglot.expressions.Func
is_var_len_args
from_arg_list
sql_names
sql_name
default_parser_mappings
class StagedFilePath(sqlglot.expressions.Expression):
116class StagedFilePath(exp.Expression):
117    """Represents paths to "staged files" in Snowflake."""
118
119    arg_types = exp.Table.arg_types.copy()

Represents paths to "staged files" in Snowflake.

arg_types = {'this': False, 'alias': False, 'db': False, 'catalog': False, 'laterals': False, 'joins': False, 'pivots': False, 'hints': False, 'system_time': False, 'version': False, 'format': False, 'pattern': False, 'ordinality': False, 'when': False, 'only': False, 'partition': False, 'changes': False, 'rows_from': False, 'sample': False, 'indexed': False}
key = 'stagedfilepath'
required_args = set()
Inherited Members
sqlglot.expressions.Expression
Expression
args
parent
arg_key
index
comments
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
dump
load
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
KEYWORD_MACROS = {'WITH', 'JOIN', 'LIMIT', 'HAVING', 'ORDER_BY', 'WHERE', 'GROUP_BY'}
PARSERS = {'MODEL': <function _create_parser.<locals>.parse>, 'AUDIT': <function _create_parser.<locals>.parse>, 'METRIC': <function _create_parser.<locals>.parse>}
def format_model_expressions( expressions: List[sqlglot.expressions.Expression], dialect: Optional[str] = None, rewrite_casts: bool = True, **kwargs: Any) -> str:
734def format_model_expressions(
735    expressions: t.List[exp.Expression],
736    dialect: t.Optional[str] = None,
737    rewrite_casts: bool = True,
738    **kwargs: t.Any,
739) -> str:
740    """Format a model's expressions into a standardized format.
741
742    Args:
743        expressions: The model's expressions, must be at least model def + query.
744        dialect: The dialect to render the expressions as.
745        rewrite_casts: Whether to rewrite all casts to use the :: syntax.
746        **kwargs: Additional keyword arguments to pass to the sql generator.
747
748    Returns:
749        A string representing the formatted model.
750    """
751    if len(expressions) == 1 and is_meta_expression(expressions[0]):
752        return expressions[0].sql(pretty=True, dialect=dialect)
753
754    if rewrite_casts:
755
756        def cast_to_colon(node: exp.Expression) -> exp.Expression:
757            if isinstance(node, exp.Cast) and not any(
758                # Only convert CAST into :: if it doesn't have additional args set, otherwise this
759                # conversion could alter the semantics (eg. changing SAFE_CAST in BigQuery to CAST)
760                arg
761                for name, arg in node.args.items()
762                if name not in ("this", "to")
763            ):
764                this = node.this
765
766                if not isinstance(this, (exp.Binary, exp.Unary)) or isinstance(this, exp.Paren):
767                    cast = DColonCast(this=this, to=node.to)
768                    cast.comments = node.comments
769                    node = cast
770
771            exp.replace_children(node, cast_to_colon)
772            return node
773
774        new_expressions = []
775        for expression in expressions:
776            expression = expression.copy()
777            exp.replace_children(expression, cast_to_colon)
778            new_expressions.append(expression)
779
780        expressions = new_expressions
781
782    return ";\n\n".join(
783        expression.sql(pretty=True, dialect=dialect, **kwargs) for expression in expressions
784    ).strip()

Format a model's expressions into a standardized format.

Arguments:
  • expressions: The model's expressions, must be at least model def + query.
  • dialect: The dialect to render the expressions as.
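  • rewrite_casts: Whether to rewrite casts to use the :: syntax. Only casts with no arguments besides the operand and target type are rewritten. Binary or unary operands are skipped unless they are parenthesized.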
  • **kwargs: Additional keyword arguments to pass to the sql generator.
Returns:

A string representing the formatted model.

def text_diff( a: List[sqlglot.expressions.Expression], b: List[sqlglot.expressions.Expression], a_dialect: Optional[str] = None, b_dialect: Optional[str] = None) -> str:
787def text_diff(
788    a: t.List[exp.Expression],
789    b: t.List[exp.Expression],
790    a_dialect: t.Optional[str] = None,
791    b_dialect: t.Optional[str] = None,
792) -> str:
793    """Find the unified text diff between two expressions."""
794    a_sql = [
795        line
796        for expr in a
797        for line in expr.sql(pretty=True, comments=False, dialect=a_dialect).split("\n")
798    ]
799    b_sql = [
800        line
801        for expr in b
802        for line in expr.sql(pretty=True, comments=False, dialect=b_dialect).split("\n")
803    ]
804    return "\n".join(unified_diff(a_sql, b_sql))

Find the unified text diff between two lists of expressions.

WS_OR_COMMENT = '(?:\\s|--[^\\n]*\\n|/\\*.*?\\*/)'
KEY_BOUNDARY = '(?:\\(|,)'
DIALECT_VALUE = '[\'\\"]?(?P<dialect>[a-z][a-z0-9]*)[\'\\"]?'
VALUE_BOUNDARY = '(?=,|\\))'
DIALECT_PATTERN = re.compile('\\b(?:model|audit)\\b(?=\\s*\\().*?(?:\\(|,)(?:\\s|--[^\\n]*\\n|/\\*.*?\\*/)*dialect(?:\\s|--[^\\n]*\\n|/\\*.*?\\*/)+[\'\\"]?(?P<dialect>[a-z][a-z0-9]*)[\'\\"]?(?:\\s|--[^\\n]*\\n|/\\*.*?\\*/)*(?=,|\\))', re.IGNORECASE|re.DOTALL)
JINJA_QUERY_BEGIN = 'JINJA_QUERY_BEGIN'
JINJA_STATEMENT_BEGIN = 'JINJA_STATEMENT_BEGIN'
JINJA_END = 'JINJA_END'
ON_VIRTUAL_UPDATE_BEGIN = 'ON_VIRTUAL_UPDATE_BEGIN'
ON_VIRTUAL_UPDATE_END = 'ON_VIRTUAL_UPDATE_END'
def jinja_query(query: str) -> JinjaQuery:
848def jinja_query(query: str) -> JinjaQuery:
849    return JinjaQuery(this=exp.Literal.string(query.strip()))
def jinja_statement(statement: str) -> JinjaStatement:
852def jinja_statement(statement: str) -> JinjaStatement:
853    return JinjaStatement(this=exp.Literal.string(statement.strip()))
def virtual_statement( statements: List[sqlglot.expressions.Expression]) -> VirtualUpdateStatement:
864def virtual_statement(statements: t.List[exp.Expression]) -> VirtualUpdateStatement:
865    return VirtualUpdateStatement(expressions=statements)
class ChunkType(enum.Enum):
868class ChunkType(Enum):
869    JINJA_QUERY = auto()
870    JINJA_STATEMENT = auto()
871    SQL = auto()
872    VIRTUAL_STATEMENT = auto()
873    VIRTUAL_JINJA_STATEMENT = auto()

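The kinds of token chunks a SQL string is split into while parsing (plain SQL, Jinja query/statement blocks, and virtual-update statements).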

JINJA_QUERY = <ChunkType.JINJA_QUERY: 1>
JINJA_STATEMENT = <ChunkType.JINJA_STATEMENT: 2>
SQL = <ChunkType.SQL: 3>
VIRTUAL_STATEMENT = <ChunkType.VIRTUAL_STATEMENT: 4>
VIRTUAL_JINJA_STATEMENT = <ChunkType.VIRTUAL_JINJA_STATEMENT: 5>
Inherited Members
enum.Enum
name
value
def parse_one( sql: str, dialect: Optional[str] = None, into: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]], NoneType] = None) -> sqlglot.expressions.Expression:
876def parse_one(
877    sql: str, dialect: t.Optional[str] = None, into: t.Optional[exp.IntoType] = None
878) -> exp.Expression:
879    expressions = parse(sql, default_dialect=dialect, match_dialect=False, into=into)
880    if not expressions:
881        raise SQLMeshError(f"No expressions found in '{sql}'")
882    elif len(expressions) > 1:
883        raise SQLMeshError(f"Multiple expressions found in '{sql}'")
884    return expressions[0]
def parse( sql: str, default_dialect: Optional[str] = None, match_dialect: bool = True, into: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]], NoneType] = None) -> List[sqlglot.expressions.Expression]:
 887def parse(
 888    sql: str,
 889    default_dialect: t.Optional[str] = None,
 890    match_dialect: bool = True,
 891    into: t.Optional[exp.IntoType] = None,
 892) -> t.List[exp.Expression]:
 893    """Parse a sql string.
 894
 895    Supports parsing model definition.
 896    If a jinja block is detected, the query is stored as raw string in a Jinja node.
 897
 898    Args:
 899        sql: The sql based definition.
 900        default_dialect: The dialect to use if the model does not specify one.
 901
 902    Returns:
 903        A list of the parsed expressions: [Model, *Statements, Query, *Statements]
 904    """
 905    match = match_dialect and DIALECT_PATTERN.search(sql[:MAX_MODEL_DEFINITION_SIZE])
 906    dialect_str = match.group("dialect") if match else None
 907    dialect = Dialect.get_or_raise(dialect_str or default_dialect)
 908
 909    tokens = dialect.tokenize(sql)
 910    chunks: t.List[t.Tuple[t.List[Token], ChunkType]] = [([], ChunkType.SQL)]
 911    total = len(tokens)
 912
 913    pos = 0
 914    virtual = False
 915    while pos < total:
 916        token = tokens[pos]
 917        if _is_virtual_statement_end(tokens, pos):
 918            chunks[-1][0].append(token)
 919            virtual = False
 920            chunks.append(([], ChunkType.SQL))
 921            pos += 2
 922        elif _is_jinja_end(tokens, pos) or (
 923            chunks[-1][1] == ChunkType.SQL
 924            and token.token_type == TokenType.SEMICOLON
 925            and pos < total - 1
 926        ):
 927            if token.token_type == TokenType.SEMICOLON:
 928                pos += 1
 929            else:
 930                # Jinja end statement
 931                chunks[-1][0].append(token)
 932                pos += 2
 933            chunks.append(
 934                (
 935                    [],
 936                    ChunkType.VIRTUAL_STATEMENT
 937                    if virtual and tokens[pos] != ON_VIRTUAL_UPDATE_END
 938                    else ChunkType.SQL,
 939                )
 940            )
 941        elif _is_jinja_query_begin(tokens, pos):
 942            chunks.append(([token], ChunkType.JINJA_QUERY))
 943            pos += 2
 944        elif _is_jinja_statement_begin(tokens, pos):
 945            chunks.append(([token], ChunkType.JINJA_STATEMENT))
 946            pos += 2
 947        elif _is_virtual_statement_begin(tokens, pos):
 948            chunks.append(([token], ChunkType.VIRTUAL_STATEMENT))
 949            pos += 2
 950            virtual = True
 951        else:
 952            chunks[-1][0].append(token)
 953            pos += 1
 954
 955    parser = dialect.parser()
 956    expressions: t.List[exp.Expression] = []
 957
 958    def parse_sql_chunk(chunk: t.List[Token], meta_sql: bool = True) -> t.List[exp.Expression]:
 959        parsed_expressions: t.List[t.Optional[exp.Expression]] = (
 960            parser.parse(chunk, sql) if into is None else parser.parse_into(into, chunk, sql)
 961        )
 962        expressions = []
 963        for expression in parsed_expressions:
 964            if expression:
 965                if meta_sql:
 966                    expression.meta["sql"] = parser._find_sql(chunk[0], chunk[-1])
 967                expressions.append(expression)
 968        return expressions
 969
 970    def parse_jinja_chunk(chunk: t.List[Token], meta_sql: bool = True) -> exp.Expression:
 971        start, *_, end = chunk
 972        segment = sql[start.end + 2 : end.start - 1]
 973        factory = jinja_query if chunk_type == ChunkType.JINJA_QUERY else jinja_statement
 974        expression = factory(segment.strip())
 975        if meta_sql:
 976            expression.meta["sql"] = sql[start.start : end.end + 1]
 977        return expression
 978
 979    def parse_virtual_statement(
 980        chunks: t.List[t.Tuple[t.List[Token], ChunkType]], pos: int
 981    ) -> t.Tuple[t.List[exp.Expression], int]:
 982        # For virtual statements we need to handle both SQL and Jinja nested blocks within the chunk
 983        virtual_update_statements = []
 984        start = chunks[pos][0][0].start
 985
 986        while (
 987            chunks[pos - 1][0] == [] or chunks[pos - 1][0][-1].text.upper() != ON_VIRTUAL_UPDATE_END
 988        ):
 989            chunk, chunk_type = chunks[pos]
 990            if chunk_type == ChunkType.JINJA_STATEMENT:
 991                virtual_update_statements.append(parse_jinja_chunk(chunk, False))
 992            else:
 993                virtual_update_statements.extend(
 994                    parse_sql_chunk(
 995                        chunk[int(chunk[0].text.upper() == ON_VIRTUAL_UPDATE_BEGIN) : -1], False
 996                    ),
 997                )
 998            pos += 1
 999
1000        if virtual_update_statements:
1001            statements = virtual_statement(virtual_update_statements)
1002            end = chunk[-1].end + 1
1003            statements.meta["sql"] = sql[start:end]
1004            return [statements], pos
1005
1006        return [], pos
1007
1008    pos = 0
1009    total_chunks = len(chunks)
1010    while pos < total_chunks:
1011        chunk, chunk_type = chunks[pos]
1012        if chunk_type == ChunkType.VIRTUAL_STATEMENT:
1013            virtual_expression, pos = parse_virtual_statement(chunks, pos)
1014            expressions.extend(virtual_expression)
1015        elif chunk_type == ChunkType.SQL:
1016            expressions.extend(parse_sql_chunk(chunk))
1017        else:
1018            expressions.append(parse_jinja_chunk(chunk))
1019        pos += 1
1020
1021    return expressions

Parse a sql string.

Supports parsing model definitions. If a Jinja block is detected, the query is stored as a raw string in a Jinja node.

Arguments:
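  • sql: The sql based definition.
  • default_dialect: The dialect to use if the model does not specify one.
  • match_dialect: Whether to detect the dialect from a MODEL/AUDIT definition in the SQL. When False, default_dialect is always used.
  • into: Optional expression type(s) to parse each SQL chunk into.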
Returns:

A list of the parsed expressions: [Model, *Statements, Query, *Statements]

def extend_sqlglot() -> None:
1024def extend_sqlglot() -> None:
1025    """Extend SQLGlot with SQLMesh's custom macro aware dialect."""
1026    tokenizers = {Tokenizer}
1027    parsers = {Parser}
1028    generators = {Generator}
1029
1030    for dialect in Dialect.classes.values():
1031        # Athena picks a different Tokenizer / Parser / Generator depending on the query
1032        # so this ensures that the extra ones it defines are also extended
1033        if dialect == athena.Athena:
1034            tokenizers.add(athena._TrinoTokenizer)
1035            parsers.add(athena._TrinoParser)
1036            generators.add(athena._TrinoGenerator)
1037            generators.add(athena._HiveGenerator)
1038
1039        if hasattr(dialect, "Tokenizer"):
1040            tokenizers.add(dialect.Tokenizer)
1041        if hasattr(dialect, "Parser"):
1042            parsers.add(dialect.Parser)
1043        if hasattr(dialect, "Generator"):
1044            generators.add(dialect.Generator)
1045
1046    for tokenizer in tokenizers:
1047        tokenizer.VAR_SINGLE_TOKENS.update(SQLMESH_MACRO_PREFIX)
1048
1049    for parser in parsers:
1050        parser.FUNCTIONS.update({"JINJA": Jinja.from_arg_list, "METRIC": MetricAgg.from_arg_list})
1051        parser.PLACEHOLDER_PARSERS.update({TokenType.PARAMETER: _parse_macro})
1052        parser.QUERY_MODIFIER_PARSERS.update(
1053            {TokenType.PARAMETER: lambda self: _parse_body_macro(self)}
1054        )
1055
1056    for generator in generators:
1057        if MacroFunc not in generator.TRANSFORMS:
1058            generator.TRANSFORMS.update(
1059                {
1060                    Audit: lambda self, e: _sqlmesh_ddl_sql(self, e, "AUDIT"),
1061                    DColonCast: lambda self, e: f"{self.sql(e, 'this')}::{self.sql(e, 'to')}",
1062                    Jinja: lambda self, e: e.name,
1063                    JinjaQuery: lambda self, e: f"{JINJA_QUERY_BEGIN};\n{e.name}\n{JINJA_END};",
1064                    JinjaStatement: lambda self,
1065                    e: f"{JINJA_STATEMENT_BEGIN};\n{e.name}\n{JINJA_END};",
1066                    VirtualUpdateStatement: lambda self, e: _on_virtual_update_sql(self, e),
1067                    MacroDef: lambda self, e: f"@DEF({self.sql(e.this)}, {self.sql(e.expression)})",
1068                    MacroFunc: _macro_func_sql,
1069                    MacroStrReplace: lambda self, e: f"@{self.sql(e.this)}",
1070                    MacroSQL: lambda self, e: f"@SQL({self.sql(e.this)})",
1071                    MacroVar: lambda self, e: f"@{e.name}",
1072                    Metric: lambda self, e: _sqlmesh_ddl_sql(self, e, "METRIC"),
1073                    Model: lambda self, e: _sqlmesh_ddl_sql(self, e, "MODEL"),
1074                    ModelKind: _model_kind_sql,
1075                    PythonCode: lambda self, e: self.expressions(e, sep="\n", indent=False),
1076                    StagedFilePath: lambda self, e: self.table_sql(e),
1077                    exp.Whens: _whens_sql,
1078                }
1079            )
1080        if MacroDef not in generator.WITH_SEPARATED_COMMENTS:
1081            generator.WITH_SEPARATED_COMMENTS = (
1082                *generator.WITH_SEPARATED_COMMENTS,
1083                Model,
1084                MacroDef,
1085            )
1086
1087        generator.UNWRAPPED_INTERVAL_VALUES = (
1088            *generator.UNWRAPPED_INTERVAL_VALUES,
1089            MacroStrReplace,
1090            MacroVar,
1091        )
1092
1093    _override(Parser, _parse_select)
1094    _override(Parser, _parse_statement)
1095    _override(Parser, _parse_join)
1096    _override(Parser, _parse_order)
1097    _override(Parser, _parse_where)
1098    _override(Parser, _parse_group)
1099    _override(Parser, _parse_with)
1100    _override(Parser, _parse_having)
1101    _override(Parser, _parse_limit)
1102    _override(Parser, _parse_value)
1103    _override(Parser, _parse_lambda)
1104    _override(Parser, _parse_types)
1105    _override(Parser, _parse_if)
1106    _override(Parser, _parse_id_var)
1107    _override(Parser, _warn_unsupported)
1108    _override(Snowflake.Parser, _parse_table_parts)
1109
1110    # DuckDB's prefix absolute power operator `@` clashes with the macro syntax
1111    DuckDB.Parser.NO_PAREN_FUNCTION_PARSERS.pop("@", None)

Extend SQLGlot with SQLMesh's custom macro aware dialect.

def select_from_values( values: List[Tuple[Any, ...]], columns_to_types: Dict[str, sqlglot.expressions.DataType], batch_size: int = 0, alias: str = 't') -> Iterator[sqlglot.expressions.Select]:
1114def select_from_values(
1115    values: t.List[t.Tuple[t.Any, ...]],
1116    columns_to_types: t.Dict[str, exp.DataType],
1117    batch_size: int = 0,
1118    alias: str = "t",
1119) -> t.Iterator[exp.Select]:
1120    """Generate a VALUES expression that has a select wrapped around it to cast the values to their correct types.
1121
1122    Args:
1123        values: List of values to use for the VALUES expression.
1124        columns_to_types: Mapping of column names to types to assign to the values.
1125        batch_size: The maximum number of tuples per batches. Defaults to sys.maxsize if <= 0.
1126        alias: The alias to assign to the values expression. If not provided then will default to "t"
1127
1128    Returns:
1129        This method operates as a generator and yields a VALUES expression.
1130    """
1131    if batch_size <= 0:
1132        batch_size = sys.maxsize
1133    num_rows = len(values)
1134    for i in range(0, num_rows, batch_size):
1135        yield select_from_values_for_batch_range(
1136            values=values,
1137            target_columns_to_types=columns_to_types,
1138            batch_start=i,
1139            batch_end=min(i + batch_size, num_rows),
1140            alias=alias,
1141        )

Generate a VALUES expression that has a select wrapped around it to cast the values to their correct types.

Arguments:
  • values: List of values to use for the VALUES expression.
  • columns_to_types: Mapping of column names to types to assign to the values.
  • batch_size: The maximum number of tuples per batch. Defaults to sys.maxsize if <= 0.
  • alias: The alias to assign to the values expression. Defaults to "t" if not provided.
Returns:

This function is a generator that yields one SELECT expression per batch, each wrapping a VALUES clause.

def select_from_values_for_batch_range( values: List[Tuple[Any, ...]], target_columns_to_types: Dict[str, sqlglot.expressions.DataType], batch_start: int, batch_end: int, alias: str = 't', source_columns: Optional[List[str]] = None) -> sqlglot.expressions.Select:
1144def select_from_values_for_batch_range(
1145    values: t.List[t.Tuple[t.Any, ...]],
1146    target_columns_to_types: t.Dict[str, exp.DataType],
1147    batch_start: int,
1148    batch_end: int,
1149    alias: str = "t",
1150    source_columns: t.Optional[t.List[str]] = None,
1151) -> exp.Select:
1152    source_columns = source_columns or list(target_columns_to_types)
1153    source_columns_to_types = get_source_columns_to_types(target_columns_to_types, source_columns)
1154
1155    if not values:
1156        # Ensures we don't generate an empty VALUES clause & forces a zero-row output
1157        where = exp.false()
1158        expressions = [
1159            tuple(exp.cast(exp.null(), to=kind) for kind in source_columns_to_types.values())
1160        ]
1161    else:
1162        where = None
1163        expressions = [
1164            tuple(transform_values(v, source_columns_to_types))
1165            for v in values[batch_start:batch_end]
1166        ]
1167
1168    values_exp = exp.values(expressions, alias=alias, columns=source_columns_to_types)
1169    if values:
1170        # BigQuery crashes on `SELECT CAST(x AS TIMESTAMP) FROM UNNEST([NULL]) AS x`, but not
1171        # on `SELECT CAST(x AS TIMESTAMP) FROM UNNEST([CAST(NULL AS TIMESTAMP)]) AS x`. This
1172        # ensures nulls under the `Values` expression are cast to avoid similar issues.
1173        for value, kind in zip(
1174            values_exp.expressions[0].expressions, source_columns_to_types.values()
1175        ):
1176            if isinstance(value, exp.Null):
1177                value.replace(exp.cast(value, to=kind))
1178
1179    casted_columns = [
1180        exp.alias_(
1181            exp.cast(
1182                exp.column(column) if column in source_columns_to_types else exp.Null(), to=kind
1183            ),
1184            column,
1185            copy=False,
1186        )
1187        for column, kind in target_columns_to_types.items()
1188    ]
1189    return exp.select(*casted_columns).from_(values_exp, copy=False).where(where, copy=False)
def pandas_to_sql( df: pandas.core.frame.DataFrame, columns_to_types: Optional[Dict[str, sqlglot.expressions.DataType]] = None, batch_size: int = 0, alias: str = 't') -> Iterator[sqlglot.expressions.Select]:
def pandas_to_sql(
    df: pd.DataFrame,
    columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
    batch_size: int = 0,
    alias: str = "t",
) -> t.Iterator[exp.Select]:
    """Render a pandas DataFrame as SELECT statements over VALUES clauses.

    Args:
        df: The DataFrame whose rows (without the index) become the VALUES tuples.
        columns_to_types: Optional mapping of column name to SQL type. When it is falsy
            (None or empty), the types are inferred from ``df`` via ``columns_to_types_from_df``.
        batch_size: The maximum number of tuples per batch. Defaults to sys.maxsize if <= 0.
        alias: The alias to assign to the VALUES expression. Defaults to "t".

    Yields:
        The SELECT expressions produced by ``select_from_values``.
    """
    # Plain tuples (name=None) rather than namedtuples, one per row, index excluded.
    rows = list(df.itertuples(index=False, name=None))
    resolved_types = columns_to_types or columns_to_types_from_df(df)
    yield from select_from_values(
        values=rows,
        columns_to_types=resolved_types,
        batch_size=batch_size,
        alias=alias,
    )

Convert a pandas dataframe into a VALUES sql statement.

Arguments:
  • df: A pandas dataframe to convert.
  • columns_to_types: Mapping of column names to types to assign to the values.
  • batch_size: The maximum number of tuples per batch. Defaults to sys.maxsize if <= 0.
  • alias: The alias to assign to the VALUES expression. Defaults to "t".
Returns:

This function is a generator that yields SELECT expressions built over VALUES clauses.

def set_default_catalog( table: str | sqlglot.expressions.Table, default_catalog: Optional[str]) -> sqlglot.expressions.Table:
def set_default_catalog(
    table: str | exp.Table,
    default_catalog: t.Optional[str],
) -> exp.Table:
    """Return ``table`` as a Table expression, adding ``default_catalog`` when it is missing.

    The catalog is only filled in when the table has a schema (db) part but no catalog part;
    single-part names are left unqualified.

    Args:
        table: A table name or Table expression, converted with ``exp.to_table``.
        default_catalog: Catalog name to apply. A falsy value leaves the table as converted.

    Returns:
        The resulting Table expression.
    """
    result = exp.to_table(table)

    # Nothing to do without a catalog to apply, when one is already present, or when there is
    # no schema part to hang it on.
    if not default_catalog or result.catalog or not result.db:
        return result

    result.set("catalog", exp.parse_identifier(default_catalog))
    return result
@lru_cache(maxsize=16384)
def normalize_model_name( table: str | sqlglot.expressions.Table | sqlglot.expressions.Column, default_catalog: Optional[str], dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> str:
@lru_cache(maxsize=16384)
def normalize_model_name(
    table: str | exp.Table | exp.Column,
    default_catalog: t.Optional[str],
    dialect: DialectType = None,
) -> str:
    """Build the canonical, fully quoted name used to identify a model.

    Results are memoized per argument combination (``lru_cache``), so arguments must be hashable.

    Args:
        table: A model name as a string, a Table, or a Column whose parts are reinterpreted
            as table parts.
        default_catalog: Catalog applied when the name has a schema but no catalog.
        dialect: Dialect whose identifier normalization rules are applied to the name parts.

    Returns:
        The table name with normalized identifiers, every part quoted.
    """
    if isinstance(table, exp.Column):
        # A dotted reference parsed as a column has its parts one slot off compared to a table:
        # the column's "table" part is the table's schema and its "db" part is the catalog.
        as_table = exp.table_(
            table.this, db=table.args.get("table"), catalog=table.args.get("db")
        )
    else:
        # We rely on sqlglot's lenient parsing to accept quotes from other dialects: a name such
        # as '"my_table"' still parses under e.g. spark, whose native quote is a backtick.
        as_table = exp.to_table(table, dialect=dialect)

    qualified = set_default_catalog(as_table, default_catalog)

    # Only identifier *names* are normalized for the target dialect; quoting uses sqlglot's
    # default dialect. exp.table_name(table, dialect=dialect, identify=True) would also switch to
    # the dialect's quote characters, which hurts dialects that normalize names compatibly but
    # quote differently.
    normalized = normalize_identifiers(qualified, dialect=dialect)
    return exp.table_name(normalized, identify=True)
def find_tables( expression: sqlglot.expressions.Expression, default_catalog: Optional[str], dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> Set[str]:
def find_tables(
    expression: exp.Expression, default_catalog: t.Optional[str], dialect: DialectType = None
) -> t.Set[str]:
    """Find all tables referenced in a query.

    Caches the result in ``expression.meta`` under the ``TABLES_META`` key ("sqlmesh.tables").

    Args:
        expression: The query to find the tables in.
        default_catalog: Catalog applied to table names that have a schema but no catalog.
        dialect: The dialect to use for normalization of table names.

    Returns:
        A set of table names normalized with ``normalize_model_name``. Tables without a name and
        references to CTEs defined in the same scope are excluded.

    Note:
        The cache is keyed only on the expression. A later call on the same expression with a
        different ``default_catalog`` or ``dialect`` returns the previously cached set.
    """
    if TABLES_META not in expression.meta:
        expression.meta[TABLES_META] = {
            normalize_model_name(table, default_catalog=default_catalog, dialect=dialect)
            for scope in traverse_scope(expression)
            for table in scope.tables
            # Skip nameless table sources and names that resolve to a CTE of this scope.
            if table.name and table.name not in scope.cte_sources
        }
    return expression.meta[TABLES_META]

Find all tables referenced in a query.

Caches the result in the expression's meta field 'sqlmesh.tables'.

Arguments:
  • expression: The query to find the tables in.
  • default_catalog: The catalog applied to table names that have a schema but no catalog.
  • dialect: The dialect to use for normalization of table names.
Returns:

A Set of all the table names.

def add_table( node: sqlglot.expressions.Expression, table: str) -> sqlglot.expressions.Expression:
def add_table(node: exp.Expression, table: str) -> exp.Expression:
    """Add a table to all columns in an expression.

    Unqualified Column nodes and bare Identifier nodes are replaced by a Column qualified with
    ``table``; every other node is kept. The result of ``node.transform`` is returned.
    """

    def _qualify(candidate: exp.Expression) -> exp.Expression:
        if isinstance(candidate, exp.Column):
            # Columns that already carry a table qualifier are left untouched.
            return candidate if candidate.table else exp.column(candidate.this, table=table)
        if isinstance(candidate, exp.Identifier):
            # NOTE(review): every Identifier the transform visits is wrapped; this may include the
            # name parts of an already-qualified Column. Confirm callers never pass such input.
            return exp.column(candidate, table=table)
        return candidate

    return node.transform(_qualify)

Add a table to all columns in an expression.

def transform_values( values: Tuple[Any, ...], columns_to_types: Dict[str, sqlglot.expressions.DataType]) -> Iterator[Any]:
def transform_values(
    values: t.Tuple[t.Any, ...], columns_to_types: t.Dict[str, exp.DataType]
) -> t.Iterator[t.Any]:
    """Perform transformations on values given columns_to_types.

    Values are paired positionally with the types of ``columns_to_types`` (via ``zip``, so surplus
    values or types are ignored). Conversion rules per value:

    * a list typed as an array with exactly one element type: converted element-wise;
    * a dict typed as a struct with the same number of fields: converted field-by-field,
      matching dict items to struct fields by position;
    * a JSON-typed value: wrapped in ``PARSE_JSON`` over an escaped string literal
      (``None`` becomes SQL NULL, dicts/lists are serialized with ``json.dumps``);
    * anything else: ``exp.convert``.

    Yields:
        One expression per (value, type) pair.
    """
    import json

    def _transform_value(value: t.Any, dtype: exp.DataType) -> t.Any:
        if (
            isinstance(value, list)
            and dtype.is_type(*exp.DataType.ARRAY_TYPES)
            and len(dtype.expressions) == 1
        ):
            element_type = dtype.expressions[0]
            return exp.convert([_transform_value(v, element_type) for v in value])

        if (
            isinstance(value, dict)
            and dtype.is_type(*exp.DataType.STRUCT_TYPES)
            and len(value) == len(dtype.expressions)
        ):
            expressions = []
            # Dict items are matched to struct fields by position, not by name.
            for (field_name, field_value), field_type in zip(value.items(), dtype.expressions):
                if isinstance(field_type, exp.ColumnDef):
                    field_type = field_type.kind
                else:
                    field_type = exp.DataType.build(exp.DataType.Type.UNKNOWN)

                expressions.append(
                    exp.PropertyEQ(
                        this=exp.to_identifier(field_name),
                        expression=_transform_value(field_value, field_type),
                    )
                )

            return exp.Struct(expressions=expressions)

        if dtype.is_type(exp.DataType.Type.JSON):
            if value is None:
                # A missing JSON value is SQL NULL, not PARSE_JSON('None').
                return exp.null()
            # Containers are serialized as real JSON instead of their Python repr.
            payload = json.dumps(value) if isinstance(value, (dict, list)) else str(value)
            # Build the string literal directly so quotes inside the payload are escaped by the
            # SQL generator, instead of splicing the payload into SQL text and re-parsing it.
            return exp.func("PARSE_JSON", exp.Literal.string(payload))

        return exp.convert(value)

    for col_value, col_type in zip(values, columns_to_types.values()):
        yield _transform_value(col_value, col_type)

Perform transformations on values given columns_to_types.

def to_schema( sql_path: str | sqlglot.expressions.Table, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.expressions.Table:
def to_schema(sql_path: str | exp.Table, dialect: DialectType = None) -> exp.Table:
    """Convert a dotted schema path into a schema-only Table expression.

    The path (e.g. ``catalog.db``) is parsed as a table reference and each part is shifted one
    level up: db becomes catalog, the table name becomes db, and ``this`` is cleared. A Table
    that already has no ``this`` is returned as is; any other Table input is copied first so the
    caller's expression is not modified.
    """
    if isinstance(sql_path, exp.Table):
        if sql_path.this is None:
            return sql_path
        source: str | exp.Table = sql_path.copy()
    else:
        source = sql_path

    schema = exp.to_table(source, dialect=dialect)
    catalog_part = schema.args.get("db")
    db_part = schema.args.get("this")
    schema.set("catalog", catalog_part)
    schema.set("db", db_part)
    schema.set("this", None)
    return schema
def schema_( db: sqlglot.expressions.Identifier | str, catalog: Union[sqlglot.expressions.Identifier, str, NoneType] = None, quoted: Optional[bool] = None) -> sqlglot.expressions.Table:
def schema_(
    db: exp.Identifier | str,
    catalog: t.Optional[exp.Identifier | str] = None,
    quoted: t.Optional[bool] = None,
) -> exp.Table:
    """Build a Schema.

    A schema is represented as a Table expression whose ``this`` (table name) is None.

    Args:
        db: Database name. A falsy value leaves the db part unset.
        catalog: Catalog name. A falsy value leaves the catalog part unset.
        quoted: Whether to force quotes on the schema's identifiers.

    Returns:
        The new Schema instance.
    """

    def _identifier(name: t.Optional[exp.Identifier | str]) -> t.Optional[exp.Identifier]:
        # Empty or missing names are omitted rather than turned into empty identifiers.
        return exp.to_identifier(name, quoted=quoted) if name else None

    db_identifier = _identifier(db)
    catalog_identifier = _identifier(catalog)
    return exp.Table(this=None, db=db_identifier, catalog=catalog_identifier)

Build a Schema.

Arguments:
  • db: Database name.
  • catalog: Catalog name.
  • quoted: Whether to force quotes on the schema's identifiers.
Returns:

The new Schema instance.

def normalize_mapping_schema( schema: Dict, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> sqlglot.schema.MappingSchema:
def normalize_mapping_schema(schema: t.Dict, dialect: DialectType) -> MappingSchema:
    """Wrap ``schema`` in a MappingSchema after stripping quotes via ``_unquote_schema``.

    ``normalize=False`` is passed, so MappingSchema does not normalize the names itself.
    """
    unquoted = _unquote_schema(schema)
    return MappingSchema(unquoted, dialect=dialect, normalize=False)
@contextmanager
def normalize_and_quote( query: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType], default_catalog: Optional[str], quote: bool = True) -> Iterator[~E]:
@contextmanager
def normalize_and_quote(
    query: E, dialect: DialectType, default_catalog: t.Optional[str], quote: bool = True
) -> t.Iterator[E]:
    """Qualify and normalize ``query`` for the duration of a ``with`` block, then quote it.

    On entry, ``qualify_tables`` (with ``default_catalog``) and ``normalize_identifiers`` run on
    ``query``. Their return values are not used, so presumably they modify it in place. The same
    query object is then yielded. When the block exits normally and ``quote`` is True,
    identifiers are quoted for ``dialect``. If the block raises, quoting is skipped because the
    yield is not wrapped in try/finally.
    """
    qualify_tables(query, catalog=default_catalog, dialect=dialect)
    normalize_identifiers(query, dialect=dialect)
    yield query
    if not quote:
        return
    quote_identifiers(query, dialect=dialect)
def interpret_expression( e: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression | str | int | float | bool:
def interpret_expression(e: exp.Expression) -> exp.Expression | str | int | float | bool:
    """Convert a literal SQL expression into the equivalent Python value when possible.

    Integer literals become ``int`` and other numeric literals ``float``, including negated
    numbers such as ``-1``. String literals yield their text and booleans their ``bool`` value.
    Any other expression is returned unchanged.
    """
    if isinstance(e, exp.Neg) and e.this.is_number:
        # A negative number parses as Neg(<number literal>). The Neg node's `this` is the
        # Literal node rather than its text, so interpret the operand and negate it instead of
        # calling int()/float() on the node.
        operand = interpret_expression(e.this)
        if isinstance(operand, (int, float)):
            return -operand
        return e
    if e.is_int:
        return int(e.this)
    if e.is_number:
        return float(e.this)
    if isinstance(e, (exp.Literal, exp.Boolean)):
        return e.this
    return e
def interpret_key_value_pairs( e: sqlglot.expressions.Tuple) -> Dict[str, sqlglot.expressions.Expression | str | int | float | bool]:
def interpret_key_value_pairs(
    e: exp.Tuple,
) -> t.Dict[str, exp.Expression | str | int | float | bool]:
    """Turn a tuple of key/value pairs into a dict of Python values.

    For each element, the name of its ``this`` is the key, and its ``expression`` is converted
    with ``interpret_expression`` to form the value. A repeated key keeps its last value.
    """
    result: t.Dict[str, exp.Expression | str | int | float | bool] = {}
    for pair in e.expressions:
        key = pair.this.name
        result[key] = interpret_expression(pair.expression)
    return result
def extract_func_call( v: sqlglot.expressions.Expression, allow_tuples: bool = False) -> Tuple[str, Dict[str, sqlglot.expressions.Expression]]:
def extract_func_call(
    v: exp.Expression, allow_tuples: bool = False
) -> t.Tuple[str, t.Dict[str, exp.Expression]]:
    """Split a call-like expression into its lower-cased name and keyword arguments.

    Supported shapes:
      * ``exp.Anonymous``: its name and its argument list.
      * any other ``exp.Func``: its ``sql_name()`` and its argument expressions.
      * ``exp.Paren``: an empty name and the wrapped expression as the only argument.
      * ``exp.Tuple``: an empty name and its elements; only allowed when ``allow_tuples`` is set.
      * anything else: ``(v.name.lower(), {})``, i.e. a call without arguments.

    Returns:
        A ``(name, kwargs)`` pair. ``kwargs`` maps lower-cased argument names to value expressions.

    Raises:
        ConfigError: If a tuple is given while ``allow_tuples`` is False, or if an argument is not
            a ``key := value`` (PropertyEQ) or ``key = value`` (EQ) pair.
    """
    kwargs: t.Dict[str, exp.Expression] = {}

    if isinstance(v, exp.Anonymous):
        func = v.name
        args = v.expressions
    elif isinstance(v, exp.Func):
        func = v.sql_name()
        # Recognized functions keep their arguments under typed keys, and some of those values
        # are lists (variable-length arguments) or unset. iter_expressions flattens the lists and
        # skips the empty slots, so only real argument expressions are validated below.
        args = list(v.iter_expressions())
    elif isinstance(v, exp.Paren):
        func = ""
        args = [v.this]
    elif isinstance(v, exp.Tuple):  # airflow only
        if not allow_tuples:
            raise ConfigError("Audit name is missing (eg. MY_AUDIT())")

        func = ""
        args = v.expressions
    else:
        return v.name.lower(), {}

    for arg in args:
        if not isinstance(arg, (exp.PropertyEQ, exp.EQ)):
            raise ConfigError(
                f"Function '{func}' must be called with key-value arguments like {func}(arg := value)."
            )
        kwargs[arg.left.name.lower()] = arg.right
    return func.lower(), kwargs
def extract_function_calls(func_calls: Any, allow_tuples: bool = False) -> Any:
def extract_function_calls(func_calls: t.Any, allow_tuples: bool = False) -> t.Any:
    """Used for extracting function calls for signals or audits.

    Accepted inputs:
      * parsed SQL: a Tuple/Array of calls, a Paren-wrapped call, or a single expression. Each
        call is handled by ``extract_func_call``.
      * Python config: a list whose entries are dicts (with a "name" key unless ``allow_tuples``
        is set) or ``(name, args)`` pairs. String argument values are parsed into expressions.

    Returns:
        A list of ``(lower-cased name, {arg name: value})`` tuples. For any other input,
        ``func_calls`` itself if it is truthy, otherwise an empty list.

    Raises:
        ConfigError: If a list entry is neither a dict nor a tuple/list.
    """
    if isinstance(func_calls, (exp.Tuple, exp.Array)):
        return [extract_func_call(i, allow_tuples=allow_tuples) for i in func_calls.expressions]
    if isinstance(func_calls, exp.Paren):
        return [extract_func_call(func_calls.this, allow_tuples=allow_tuples)]
    if isinstance(func_calls, exp.Expression):
        return [extract_func_call(func_calls, allow_tuples=allow_tuples)]
    if isinstance(func_calls, list):
        function_calls = []
        for entry in func_calls:
            if isinstance(entry, dict):
                # Pop "name" from a copy: mutating the caller's dict would make a second pass
                # over the same config raise KeyError('name').
                args = dict(entry)
                name = "" if allow_tuples else args.pop("name")
            elif isinstance(entry, (tuple, list)):
                name, args = entry
            else:
                raise ConfigError(f"Audit must be a dictionary or named tuple. Got {entry}.")

            function_calls.append(
                (
                    name.lower(),
                    {
                        key: parse_one(value) if isinstance(value, str) else value
                        for key, value in args.items()
                    },
                )
            )

        return function_calls

    return func_calls or []

Used for extracting function calls for signals or audits.

def is_meta_expression(v: Any) -> bool:
def is_meta_expression(v: t.Any) -> bool:
    """Return True if ``v`` is an ``Audit``, ``Metric`` or ``Model`` expression."""
    meta_types = (Audit, Metric, Model)
    return isinstance(v, meta_types)
def replace_merge_table_aliases( expression: sqlglot.expressions.Expression, dialect: Optional[str] = None) -> sqlglot.expressions.Expression:
def replace_merge_table_aliases(
    expression: exp.Expression, dialect: t.Optional[str] = None
) -> exp.Expression:
    """
    Resolves references from the "source" and "target" tables (or their DBT equivalents)
    with the corresponding SQLMesh merge aliases (MERGE_SOURCE_ALIAS and MERGE_TARGET_ALIAS)

    Only the leading qualifier of a qualified Column is rewritten, in place. An unqualified
    column whose own name happens to be e.g. "source" is left alone. Non-Column expressions are
    returned unchanged. ``dialect`` is accepted for interface compatibility but is not used here.
    """
    from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS

    if not isinstance(expression, exp.Column):
        return expression

    parts = expression.parts
    # With a single part, parts[0] is the column name itself, not a table reference.
    if len(parts) < 2:
        return expression

    qualifier = parts[0]
    qualifier_name = qualifier.name.lower()
    if qualifier_name in ("target", "dbt_internal_dest", "__merge_target__"):
        qualifier.replace(exp.to_identifier(MERGE_TARGET_ALIAS, quoted=True))
    elif qualifier_name in ("source", "dbt_internal_source", "__merge_source__"):
        qualifier.replace(exp.to_identifier(MERGE_SOURCE_ALIAS, quoted=True))

    return expression

Resolves references from the "source" and "target" tables (or their DBT equivalents) with the corresponding SQLMesh merge aliases (MERGE_SOURCE_ALIAS and MERGE_TARGET_ALIAS)