Edit on GitHub

sqlmesh.core.dialect

   1from __future__ import annotations
   2
   3import functools
   4import logging
   5import re
   6import sys
   7import typing as t
   8from contextlib import contextmanager
   9from difflib import unified_diff
  10from enum import Enum, auto
  11from functools import lru_cache
  12
  13from sqlglot import Dialect, Generator, ParseError, Parser, Tokenizer, TokenType, exp
  14from sqlglot.dialects.dialect import DialectType
  15from sqlglot.dialects import DuckDB, Snowflake, TSQL
  16import sqlglot.dialects.athena as athena
  17from sqlglot.parsers.athena import AthenaTrinoParser
  18from sqlglot.helper import seq_get
  19from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
  20from sqlglot.optimizer.qualify_columns import quote_identifiers
  21from sqlglot.optimizer.qualify_tables import qualify_tables
  22from sqlglot.optimizer.scope import traverse_scope
  23from sqlglot.schema import MappingSchema
  24from sqlglot.tokens import Token
  25
  26from sqlmesh.core.constants import MAX_MODEL_DEFINITION_SIZE
  27from sqlmesh.utils import get_source_columns_to_types
  28from sqlmesh.utils.errors import SQLMeshError, ConfigError
  29from sqlmesh.utils.pandas import columns_to_types_from_df
  30
  31if t.TYPE_CHECKING:
  32    import pandas as pd
  33
  34    from sqlglot._typing import E
  35
  36
  37SQLMESH_MACRO_PREFIX = "@"
  38
  39TABLES_META = "sqlmesh.tables"
  40
  41logger = logging.getLogger(__name__)
  42
  43
class Model(exp.Expression):
    """AST node for a `MODEL (...)` definition; `expressions` holds its parsed properties."""

    arg_types = {"expressions": True}
  46
  47
class Audit(exp.Expression):
    """AST node for an `AUDIT (...)` definition; `expressions` holds its parsed properties."""

    arg_types = {"expressions": True}
  50
  51
class Metric(exp.Expression):
    """AST node for a `METRIC (...)` definition; `expressions` holds its parsed properties."""

    arg_types = {"expressions": True}
  54
  55
class Jinja(exp.Expression, exp.Func):
    """Base node for Jinja content; `this` carries the wrapped content (required)."""

    arg_types = {"this": True}
  58
  59
class JinjaQuery(Jinja):
    """Jinja node for a query; built by `jinja_query`."""

    pass
  62
  63
class JinjaStatement(Jinja):
    """Jinja node for a statement; built by `jinja_statement`."""

    pass
  66
  67
class VirtualUpdateStatement(exp.Expression):
    """Container node for a list of statements; built by `virtual_statement`."""

    arg_types = {"expressions": True}
  70
  71
class ModelKind(exp.Expression):
    """A model's kind: `this` is the kind name, `expressions` its optional properties."""

    arg_types = {"this": True, "expressions": False}
  74
  75
class MacroVar(exp.Var):
    """A macro variable reference, i.e. a name introduced by the macro prefix `@`."""

    pass
  78
  79
class MacroFunc(exp.Expression, exp.Func):
    """A macro function call; `this` holds the wrapped function-call node."""

    @property
    def name(self) -> str:
        """Return the name of the wrapped node stored in `this`."""
        return self.this.name
  84
  85
class MacroDef(MacroFunc):
    """A `@DEF(...)` call: `this` is its first argument, `expression` its second."""

    arg_types = {"this": True, "expression": True}
  88
  89
class MacroSQL(MacroFunc):
    """A `@SQL(...)` call: `this` is its first argument, `into` the optional second one."""

    arg_types = {"this": True, "into": False}
  92
  93
class MacroStrReplace(MacroFunc):
    """Node produced when the macro prefix is followed by a string or quoted identifier."""

    pass
  96
  97
class PythonCode(exp.Expression):
    """Node holding a required list of `expressions`.

    NOTE(review): presumably these are the lines of a Python model/macro body; the
    producers of this node are not visible in this part of the file - confirm.
    """

    arg_types = {"expressions": True}
 100
 101
class DColonCast(exp.Cast):
    """Cast variant that `format_model_expressions` substitutes for `exp.Cast` (`::` syntax)."""

    pass
 104
 105
class MetricAgg(exp.Expression, exp.AggFunc):
    """Used for computing metrics."""

    arg_types = {"this": True}

    @property
    def output_name(self) -> str:
        """Return the name of the node stored in `this`."""
        return self.this.name
 114
 115
class StagedFilePath(exp.Expression):
    """Represents paths to "staged files" in Snowflake."""

    # Accepts the same arguments as exp.Table; copied so the mapping isn't shared with it.
    arg_types = exp.Table.arg_types.copy()
 120
 121
def _parse_statement(self: Parser) -> t.Optional[exp.Expr]:
    """Parse one statement, trying the SQLMesh DDL parsers in `PARSERS` first.

    If the upper-cased text of the current token is a key in `PARSERS`, the matching
    parser is run on the wrapped body that follows. If that attempt raises a
    `ParseError`, the position is restored and `self.__parse_statement()` is used.

    Returns:
        The parsed expression, or None when there is no current token.

    Raises:
        ParseError: If the fallback fails. When the DDL attempt had also failed, the
            DDL parser's message is raised instead of the fallback's.
    """
    if self._curr is None:
        return None

    parser = PARSERS.get(self._curr.text.upper())
    error_msg = None

    if parser:
        # Capture any available description in the form of a comment
        comments = self._curr.comments

        index = self._index
        try:
            self._advance()
            meta = self._parse_wrapped(lambda: t.cast(t.Callable, parser)(self))
        except ParseError as parse_error:
            # Remember the DDL error and rewind so the regular parser can try instead.
            error_msg = parse_error.args[0]
            self._retreat(index)

        # Only return the DDL expression if we actually managed to parse one. This is
        # done in order to allow parsing standalone identifiers / function calls like
        # "metric", or "model(1, 2, 3)", which collide with SQLMesh's DDL syntax.
        # (`meta` is only bound when no ParseError occurred, i.e. when the index moved.)
        if self._index != index:
            meta.comments = comments
            return meta

    try:
        return self.__parse_statement()  # type: ignore
    except ParseError:
        # Prefer reporting the DDL parser's error when there was one.
        if error_msg:
            raise ParseError(error_msg)
        raise
 154
 155
 156def _parse_lambda(self: Parser, alias: bool = False) -> t.Optional[exp.Expr]:
 157    node = self.__parse_lambda(alias=alias)  # type: ignore
 158    if isinstance(node, exp.Lambda):
 159        node.set("this", self._parse_alias(node.this))
 160    return node
 161
 162
def _parse_id_var(
    self: Parser,
    any_token: bool = True,
    tokens: t.Optional[t.Collection[TokenType]] = None,
) -> t.Optional[exp.Expr]:
    """Parse an identifier that may embed macro references such as `@{name}`.

    The identifier is first parsed by `__parse_id_var`. While further tokens are
    connected to it (per `self._is_connected()`), their text is appended so that the
    whole sequence becomes a single `exp.Identifier`.

    Args:
        any_token: Forwarded to `__parse_id_var`.
        tokens: Forwarded to `__parse_id_var`.

    Returns:
        The (possibly composite) identifier, or None if nothing could be parsed.
    """
    if self._prev and self._prev.text == SQLMESH_MACRO_PREFIX and self._match(TokenType.L_BRACE):
        # Form `@{name}`: the prefix was already consumed and `{` has just been matched.
        identifier = self.__parse_id_var(any_token=any_token, tokens=tokens)  # type: ignore
        if not self._match(TokenType.R_BRACE):
            self.raise_error("Expecting }")
        # NOTE(review): `identifier` is assumed to be non-None here - confirm that
        # `@{}` cannot reach this line with an empty parse result.
        identifier.args["this"] = f"@{{{identifier.name}}}"
    else:
        identifier = self.__parse_id_var(any_token=any_token, tokens=tokens)  # type: ignore

    # Keep absorbing connected tokens into an unquoted identifier. Note that
    # `_match_texts` consumes the `{` / `@` token when it matches.
    while (
        identifier
        and not identifier.args.get("quoted")
        and self._is_connected()
        and (
            self._match_texts(("{", SQLMESH_MACRO_PREFIX))
            or self._curr.token_type not in self.RESERVED_TOKENS
        )
    ):
        this = identifier.name
        brace = False  # whether a closing `}` must follow the next piece

        if self._prev.text == "{":
            this += "{"
            brace = True
        else:
            if self._prev.text == SQLMESH_MACRO_PREFIX:
                this += "@"
            if self._match(TokenType.L_BRACE):
                this += "{"
                brace = True

        next_id = self._parse_id_var(any_token=False)

        if next_id:
            this += next_id.name
        else:
            # Nothing more to append: return what was built before this iteration.
            return identifier

        if brace:
            if self._match(TokenType.R_BRACE):
                this += "}"
            else:
                self.raise_error("Expecting }")

        identifier = self.expression(exp.Identifier(this=this, quoted=identifier.quoted))

    return identifier
 214
 215
def _parse_macro(self: Parser, keyword_macro: str = "") -> t.Optional[exp.Expr]:
    """Parse what follows the macro prefix `@` into a macro node.

    If the previous token is not the macro prefix, this defers to `_parse_parameter`.

    Args:
        keyword_macro: Name of the keyword macro (see `KEYWORD_MACROS`) the caller
            expects. Any other keyword macro makes the parser retreat and yields None.

    Returns:
        One of `MacroDef`, `MacroSQL`, `MacroFunc`, `MacroStrReplace`, `MacroVar`, the
        parsed field itself when its name already contains `@`, or None.
    """
    if self._prev.text != SQLMESH_MACRO_PREFIX:
        return self._parse_parameter()

    comments = self._prev.comments
    index = self._index
    field = self._parse_primary() or self._parse_function(functions={}) or self._parse_id_var()

    def _build_macro(field: t.Optional[exp.Expr]) -> t.Optional[exp.Expr]:
        """Convert the parsed `field` into the corresponding macro node."""
        if isinstance(field, exp.Func):
            macro_name = field.name.upper()
            if macro_name != keyword_macro and macro_name in KEYWORD_MACROS:
                # A keyword macro the caller did not ask for: undo the parse.
                self._retreat(index)
                return None

            if isinstance(field, exp.Anonymous):
                if macro_name == "DEF":
                    return self.expression(
                        MacroDef(
                            this=field.expressions[0],
                            expression=field.expressions[1],
                        ),
                        comments=comments,
                    )
                if macro_name == "SQL":
                    into = field.expressions[1].this.lower() if len(field.expressions) > 1 else None
                    return self.expression(
                        MacroSQL(this=field.expressions[0], into=into), comments=comments
                    )
            else:
                # A known function was parsed: re-wrap it as an anonymous call that
                # keeps its SQL name and all of its argument values.
                field = self.expression(
                    exp.Anonymous(
                        this=field.sql_name(),
                        expressions=list(field.args.values()),
                    ),
                    comments=comments,
                )

            return self.expression(MacroFunc(this=field), comments=comments)

        if field is None:
            return None

        if field.is_string or (isinstance(field, exp.Identifier) and field.quoted):
            return self.expression(
                MacroStrReplace(this=exp.Literal.string(field.this)), comments=comments
            )

        # A name that already contains `@` is returned unchanged.
        if "@" in field.this:
            return field  # type: ignore[return-value]
        return self.expression(MacroVar(this=field.this), comments=comments)

    # For window / null-handling wrappers, only the wrapped call becomes the macro.
    if isinstance(field, (exp.Window, exp.IgnoreNulls, exp.RespectNulls)):
        field.set("this", _build_macro(field.this))
    else:
        field = _build_macro(field)

    return field
 274
 275
# Macro function names that stand in for a SQL clause keyword. `_parse_macro` only accepts
# one of these when the caller explicitly asked for it via `keyword_macro`.
KEYWORD_MACROS = {"WITH", "JOIN", "WHERE", "GROUP_BY", "HAVING", "ORDER_BY", "LIMIT"}
 277
 278
 279def _parse_matching_macro(self: Parser, name: str) -> t.Optional[exp.Expr]:
 280    if not self._match_pair(TokenType.PARAMETER, TokenType.VAR, advance=False) or (
 281        self._next and self._next.text.upper() != name.upper()
 282    ):
 283        return None
 284
 285    self._advance()
 286    return _parse_macro(self, keyword_macro=name)
 287
 288
 289def _parse_body_macro(self: Parser) -> t.Tuple[str, t.Optional[exp.Expr]]:
 290    name = self._next and self._next.text.upper()
 291
 292    if name == "JOIN":
 293        return ("joins", self._parse_join())
 294    if name == "WHERE":
 295        return ("where", self._parse_where())
 296    if name == "GROUP_BY":
 297        return ("group", self._parse_group())
 298    if name == "HAVING":
 299        return ("having", self._parse_having())
 300    if name == "ORDER_BY":
 301        return ("order", self._parse_order())
 302    if name == "LIMIT":
 303        return ("limit", self._parse_limit())
 304    return ("", None)
 305
 306
 307def _parse_with(self: Parser, skip_with_token: bool = False) -> t.Optional[exp.Expr]:
 308    macro = _parse_matching_macro(self, "WITH")
 309    if not macro:
 310        return self.__parse_with(skip_with_token=skip_with_token)  # type: ignore
 311
 312    macro.this.append("expressions", self.__parse_with(skip_with_token=True))  # type: ignore
 313    return macro
 314
 315
 316def _parse_join(
 317    self: Parser, skip_join_token: bool = False, parse_bracket: bool = False
 318) -> t.Optional[exp.Expr]:
 319    index = self._index
 320    method, side, kind = self._parse_join_parts()
 321    macro = _parse_matching_macro(self, "JOIN")
 322    if not macro:
 323        self._retreat(index)
 324        return self.__parse_join(skip_join_token=skip_join_token, parse_bracket=parse_bracket)  # type: ignore
 325
 326    join = self.__parse_join(skip_join_token=True)  # type: ignore
 327    if method:
 328        join.set("method", method.text)
 329    if side:
 330        join.set("side", side.text)
 331    if kind:
 332        join.set("kind", kind.text)
 333
 334    macro.this.append("expressions", join)
 335    return macro
 336
 337
 338def _warn_unsupported(self: Parser) -> None:
 339    from sqlmesh.core.console import get_console
 340
 341    sql = self._find_sql(self._tokens[0], self._tokens[-1])[: self.error_message_context]
 342
 343    get_console().log_warning(
 344        f"'{sql}' could not be semantically understood as it contains unsupported syntax, SQLMesh will treat the command as is. Note that any references to the model's "
 345        "underlying physical table can't be resolved in this case, consider using Jinja as explained here https://sqlmesh.readthedocs.io/en/stable/concepts/macros/macro_variables/#audit-only-variables"
 346    )
 347
 348
 349def _parse_select(
 350    self: Parser,
 351    nested: bool = False,
 352    table: bool = False,
 353    parse_subquery_alias: bool = True,
 354    parse_set_operation: bool = True,
 355    consume_pipe: bool = True,
 356    from_: t.Optional[exp.From] = None,
 357) -> t.Optional[exp.Expr]:
 358    select = self.__parse_select(  # type: ignore
 359        nested=nested,
 360        table=table,
 361        parse_subquery_alias=parse_subquery_alias,
 362        parse_set_operation=parse_set_operation,
 363        consume_pipe=consume_pipe,
 364        from_=from_,
 365    )
 366
 367    if (
 368        not select
 369        and not parse_set_operation
 370        and self._match_pair(TokenType.PARAMETER, TokenType.VAR, advance=False)
 371    ):
 372        self._advance()
 373        return _parse_macro(self)
 374
 375    return select
 376
 377
 378def _parse_where(self: Parser, skip_where_token: bool = False) -> t.Optional[exp.Expr]:
 379    macro = _parse_matching_macro(self, "WHERE")
 380    if not macro:
 381        return self.__parse_where(skip_where_token=skip_where_token)  # type: ignore
 382
 383    macro.this.append("expressions", self.__parse_where(skip_where_token=True))  # type: ignore
 384    return macro
 385
 386
 387def _parse_group(self: Parser, skip_group_by_token: bool = False) -> t.Optional[exp.Expr]:
 388    macro = _parse_matching_macro(self, "GROUP_BY")
 389    if not macro:
 390        return self.__parse_group(skip_group_by_token=skip_group_by_token)  # type: ignore
 391
 392    macro.this.append("expressions", self.__parse_group(skip_group_by_token=True))  # type: ignore
 393    return macro
 394
 395
 396def _parse_having(self: Parser, skip_having_token: bool = False) -> t.Optional[exp.Expr]:
 397    macro = _parse_matching_macro(self, "HAVING")
 398    if not macro:
 399        return self.__parse_having(skip_having_token=skip_having_token)  # type: ignore
 400
 401    macro.this.append("expressions", self.__parse_having(skip_having_token=True))  # type: ignore
 402    return macro
 403
 404
 405def _parse_order(
 406    self: Parser, this: t.Optional[exp.Expr] = None, skip_order_token: bool = False
 407) -> t.Optional[exp.Expr]:
 408    macro = _parse_matching_macro(self, "ORDER_BY")
 409    if not macro:
 410        return self.__parse_order(this, skip_order_token=skip_order_token)  # type: ignore
 411
 412    macro.this.append("expressions", self.__parse_order(this, skip_order_token=True))  # type: ignore
 413    return macro
 414
 415
 416def _parse_limit(
 417    self: Parser,
 418    this: t.Optional[exp.Expr] = None,
 419    top: bool = False,
 420    skip_limit_token: bool = False,
 421) -> t.Optional[exp.Expr]:
 422    macro = _parse_matching_macro(self, "TOP" if top else "LIMIT")
 423    if not macro:
 424        return self.__parse_limit(this, top=top, skip_limit_token=skip_limit_token)  # type: ignore
 425
 426    macro.this.append("expressions", self.__parse_limit(this, top=top, skip_limit_token=True))  # type: ignore
 427    return macro
 428
 429
 430def _parse_value(self: Parser, values: bool = True) -> t.Optional[exp.Expr]:
 431    wrapped = self._match(TokenType.L_PAREN, advance=False)
 432
 433    # The base _parse_value method always constructs a Tuple instance. This is problematic when
 434    # generating values with a macro function, because it's impossible to tell whether the user's
 435    # intention was to construct a row or a column with the VALUES expression. To avoid this, we
 436    # amend the AST such that the Tuple is replaced by the macro function call itself.
 437    expr = self.__parse_value()  # type: ignore
 438    if expr and not wrapped and isinstance(seq_get(expr.expressions, 0), MacroFunc):
 439        return expr.expressions[0]
 440
 441    return expr
 442
 443
 444def _parse_macro_or_clause(self: Parser, parser: t.Callable) -> t.Optional[exp.Expr]:
 445    return _parse_macro(self) if self._match(TokenType.PARAMETER) else parser()
 446
 447
 448def _parse_props(self: Parser) -> t.Optional[exp.Expr]:
 449    key = self._parse_id_var(any_token=True)
 450    if not key:
 451        return None
 452
 453    name = key.name.lower()
 454    if name == "time_data_type":
 455        # TODO: if we make *_data_type a convention to parse things into exp.DataType, we could make this more generic
 456        value = self._parse_types(schema=True)
 457    elif name == "when_matched":
 458        # Parentheses around the WHEN clauses can be used to disambiguate them from other properties
 459        value = self._parse_wrapped(
 460            lambda: _parse_macro_or_clause(self, self._parse_when_matched),
 461            optional=True,
 462        )
 463    elif name == "merge_filter":
 464        value = self._parse_conjunction()
 465    elif self._match(TokenType.L_PAREN):
 466        value = self.expression(exp.Tuple(expressions=self._parse_csv(self._parse_equality)))
 467        self._match_r_paren()
 468    else:
 469        value = self._parse_bracket(self._parse_field(any_token=True))
 470
 471    if name == "path" and value:
 472        # Make sure if we get a windows path that it is converted to posix
 473        value = exp.Literal.string(value.this.replace("\\", "/"))  # type: ignore
 474
 475    return self.expression(exp.Property(this=name, value=value))
 476
 477
 478def _parse_types(
 479    self: Parser,
 480    check_func: bool = False,
 481    schema: bool = False,
 482    allow_identifiers: bool = True,
 483) -> t.Optional[exp.Expr]:
 484    start = self._curr
 485    parsed_type = self.__parse_types(  # type: ignore
 486        check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
 487    )
 488
 489    if schema and parsed_type:
 490        parsed_type.meta["sql"] = self._find_sql(start, self._prev)
 491
 492    return parsed_type
 493
 494
 495# Only needed for Snowflake: its "staged file" syntax (@<path>) clashes with our macro
 496# var syntax. By converting the Var representation to a MacroVar, we should be able to
 497# handle both use cases: if there's no value in the MacroEvaluator's context for that
 498# MacroVar, it'll render into @<path>, so it won't break staged file path references.
 499#
 500# See: https://docs.snowflake.com/en/user-guide/querying-stage
 501def _parse_table_parts(
 502    self: Parser,
 503    schema: bool = False,
 504    is_db_reference: bool = False,
 505    wildcard: bool = False,
 506    fast: bool = False,
 507) -> exp.Table | StagedFilePath:
 508    index = self._index
 509    table = self.__parse_table_parts(  # type: ignore
 510        schema=schema, is_db_reference=is_db_reference, wildcard=wildcard, fast=fast
 511    )
 512
 513    if table is None:
 514        return table  # type: ignore[return-value]
 515
 516    table_arg = table.this
 517    name = table_arg.name if isinstance(table_arg, exp.Var) else ""
 518
 519    if name.startswith(SQLMESH_MACRO_PREFIX):
 520        # In these cases, we don't want to produce a `StagedFilePath` node:
 521        #
 522        # - @'...' needs to parsed as a string template
 523        # - @{foo}.bar needs to be parsed as a table with a macro var part
 524        # - @name(arg1 [, arg2 ...]) needs to be parsed as a macro function call
 525        #
 526        # These cases can unambiguously be parsed using the base `_parse_table_parts`, as there
 527        # is no overlap with staged files https://docs.snowflake.com/en/user-guide/querying-stage
 528        if (
 529            self._prev.token_type == TokenType.STRING
 530            or "{" in name
 531            or (
 532                self._curr
 533                and self._prev.token_type in (TokenType.L_PAREN, TokenType.R_PAREN)
 534                and self._curr.text.upper() not in ("FILE_FORMAT", "PATTERN")
 535                and not (table.args.get("format") or table.args.get("pattern"))
 536            )
 537        ):
 538            self._retreat(index)
 539            return Parser._parse_table_parts(
 540                self, schema=schema, is_db_reference=is_db_reference, fast=fast
 541            )  # type: ignore[return-value]
 542
 543        table_arg.replace(MacroVar(this=name[1:]))
 544        return StagedFilePath(**table.args)
 545
 546    return table
 547
 548
def _parse_if(self: Parser) -> t.Optional[exp.Expr]:
    """Parse an IF function, also supporting the macro form `@IF(condition, statement)`.

    Returns:
        The result of `__parse_if`, or, if that raises a `ParseError`, an anonymous
        `IF` call whose two arguments are the condition and the parsed statement
        (or command).
    """
    # If we fail to parse an IF function with expressions as arguments, we then try
    # to parse a statement / command to support the macro @IF(condition, statement)
    index = self._index
    try:
        return self.__parse_if()  # type: ignore
    except ParseError:
        self._retreat(index)
        self._match_l_paren()

        cond = self._parse_conjunction()
        self._match(TokenType.COMMA)

        # Try to parse a known statement, otherwise fall back to parsing a command
        # Since the trailing `)` token is not expected by the statement parsers, we
        # remove it from the token stream before trying to parse the statement.
        last_token = self._tokens[-1]
        if last_token.token_type == TokenType.R_PAREN:
            # Preserve the comments of the token that is about to be dropped.
            self._tokens[-2].comments.extend(last_token.comments)
            self._tokens.pop()
            if hasattr(self, "_tokens_size"):
                # Keep _tokens_size in sync: sqlglot 30.0.3 caches len(_tokens), and
                # _advance() would otherwise try to read tokens[index + 1] past the new end.
                self._tokens_size -= 1
        else:
            self.raise_error("Expecting )")

        index = self._index
        stmt = self._parse_statement()
        if self._curr:
            # Tokens are left over, so the statement parser did not consume everything:
            # rewind and treat the remainder as an opaque command instead.
            self._retreat(index)
            stmt = self._parse_as_command(self._tokens[index])

        return exp.Anonymous(this="IF", expressions=[cond, stmt])
 583
 584
def _create_parser(expression_type: t.Type[exp.Expr], table_keys: t.List[str]) -> t.Callable:
    """Build a parser function for a SQLMesh DDL block (see `PARSERS`).

    Args:
        expression_type: Node class instantiated with the parsed properties.
        table_keys: Property keys whose value is parsed as table parts.

    Returns:
        A function taking the `Parser` and returning an `expression_type` node whose
        `expressions` are the parsed properties (and any macro function calls).
    """

    def parse(self: Parser) -> t.Optional[exp.Expr]:
        """Parse comma-separated `key value` properties until no further key is found."""
        from sqlmesh.core.model.kind import ModelKindName

        expressions: t.List[exp.Expr] = []

        while True:
            # Every property after the first one must be preceded by a comma.
            prev_property = seq_get(expressions, -1)
            if not self._match(TokenType.COMMA, expression=prev_property) and expressions:
                break

            key_expression = self._parse_id_var(any_token=True)
            if not key_expression:
                break

            # This allows macro functions that programmatically generate the property key-value pair
            if isinstance(key_expression, MacroFunc):
                expressions.append(key_expression)
                continue

            key = key_expression.name.lower()

            start = self._curr
            value: t.Optional[exp.Expr | str]

            if key in table_keys:
                value = self._parse_table_parts()
                if value and self._prev.token_type == TokenType.STRING:
                    self.raise_error(
                        f"'{key}' property cannot be a string value: {value}. "
                        "Please use the identifier syntax instead, e.g. foo.bar instead of 'foo.bar'"
                    )
            elif key == "columns":
                value = self._parse_schema()
            elif key == "kind":
                field = _parse_macro_or_clause(self, lambda: self._parse_id_var(any_token=True))

                # Macros (and a missing kind) are stored as-is.
                if not field or isinstance(field, (MacroVar, MacroFunc)):
                    value = field
                else:
                    try:
                        kind = ModelKindName[field.name.upper()]
                    except KeyError:
                        raise SQLMeshError(
                            f"Model kind specified as '{field.name}', but that is not a valid model kind.\n\nPlease specify one of {', '.join(ModelKindName)}."
                        )

                    # Only these kinds may be followed by a parenthesized property list.
                    if kind in (
                        ModelKindName.INCREMENTAL_BY_TIME_RANGE,
                        ModelKindName.INCREMENTAL_BY_UNIQUE_KEY,
                        ModelKindName.INCREMENTAL_BY_PARTITION,
                        ModelKindName.INCREMENTAL_UNMANAGED,
                        ModelKindName.SEED,
                        ModelKindName.VIEW,
                        ModelKindName.SCD_TYPE_2,
                        ModelKindName.SCD_TYPE_2_BY_TIME,
                        ModelKindName.SCD_TYPE_2_BY_COLUMN,
                        ModelKindName.CUSTOM,
                    ) and self._match(TokenType.L_PAREN, advance=False):
                        props = self._parse_wrapped_csv(functools.partial(_parse_props, self))
                    else:
                        props = None

                    value = self.expression(ModelKind(this=kind.value, expressions=props))
            elif key == "expression":
                value = self._parse_conjunction()
            elif key == "partitioned_by":
                partitioned_by = self._parse_partitioned_by()
                # A schema is flattened into a tuple of its expressions.
                if isinstance(partitioned_by.this, exp.Schema):
                    value = exp.tuple_(*partitioned_by.this.expressions)
                else:
                    value = partitioned_by.this
            else:
                value = self._parse_bracket(self._parse_field(any_token=True))

            # Record the original SQL text of the value on the node.
            if isinstance(value, exp.Expr):
                value.meta["sql"] = self._find_sql(start, self._prev)

            expressions.append(self.expression(exp.Property(this=key, value=value)))

        return self.expression(expression_type(expressions=expressions))

    return parse
 668
 669
# DDL keyword (upper case) -> parser for that block. The list passed to `_create_parser`
# names the properties whose values are parsed as table parts.
PARSERS = {
    "MODEL": _create_parser(Model, ["name"]),
    "AUDIT": _create_parser(Audit, ["model"]),
    "METRIC": _create_parser(Metric, ["name"]),
}
 675
 676
 677def _props_sql(self: Generator, expressions: t.List[exp.Expr]) -> str:
 678    props = []
 679    size = len(expressions)
 680
 681    for i, prop in enumerate(expressions):
 682        if isinstance(prop, MacroFunc):
 683            sql = self.indent(self.sql(prop, comment=False))
 684        else:
 685            sql = self.indent(f"{prop.name} {self.sql(prop, 'value')}")
 686
 687        if i < size - 1:
 688            sql += ","
 689
 690        props.append(self.maybe_comment(sql, expression=prop))
 691
 692    return "\n".join(props)
 693
 694
 695def _on_virtual_update_sql(self: Generator, expressions: t.List[exp.Expr]) -> str:
 696    statements = "\n".join(
 697        self.sql(expression)
 698        if isinstance(expression, JinjaStatement)
 699        else f"{self.sql(expression)};"
 700        for expression in expressions
 701    )
 702    return f"{ON_VIRTUAL_UPDATE_BEGIN};\n{statements}\n{ON_VIRTUAL_UPDATE_END};"
 703
 704
 705def _sqlmesh_ddl_sql(self: Generator, expression: Model | Audit | Metric, name: str) -> str:
 706    return "\n".join([f"{name} (", _props_sql(self, expression.expressions), ")"])
 707
 708
 709def _model_kind_sql(self: Generator, expression: ModelKind) -> str:
 710    props = _props_sql(self, expression.expressions)
 711    if props:
 712        return "\n".join([f"{expression.this} (", props, ")"])
 713    return expression.name.upper()
 714
 715
 716def _macro_keyword_func_sql(self: Generator, expression: exp.Expr) -> str:
 717    name = expression.name
 718    keyword = name.replace("_", " ")
 719    *args, clause = expression.expressions
 720    macro = f"@{name}({self.format_args(*args)})"
 721    return self.sql(clause).replace(keyword, macro, 1)
 722
 723
 724def _macro_func_sql(self: Generator, expression: MacroFunc) -> str:
 725    expression = expression.this
 726    name = expression.name
 727    if name in KEYWORD_MACROS:
 728        sql = _macro_keyword_func_sql(self, expression)
 729    else:
 730        sql = f"@{name}({self.format_args(*expression.expressions)})"
 731    return self.maybe_comment(sql, expression)
 732
 733
 734def _whens_sql(self: Generator, expression: exp.Whens) -> str:
 735    if isinstance(expression.parent, exp.Merge):
 736        return self.whens_sql(expression)
 737
 738    # If the `WHEN` clauses aren't part of a MERGE statement (e.g. they
 739    # appear in the `MODEL` DDL), then we will wrap them with parentheses.
 740    return self.wrap(self.expressions(expression, sep=" ", indent=False))
 741
 742
 743def _override(klass: t.Type[Tokenizer | Parser], func: t.Callable) -> None:
 744    name = func.__name__
 745    setattr(klass, f"_{name}", getattr(klass, name))
 746    setattr(klass, name, func)
 747
 748
 749def format_model_expressions(
 750    expressions: t.List[exp.Expr],
 751    dialect: t.Optional[str] = None,
 752    rewrite_casts: bool = True,
 753    **kwargs: t.Any,
 754) -> str:
 755    """Format a model's expressions into a standardized format.
 756
 757    Args:
 758        expressions: The model's expressions, must be at least model def + query.
 759        dialect: The dialect to render the expressions as.
 760        rewrite_casts: Whether to rewrite all casts to use the :: syntax.
 761        **kwargs: Additional keyword arguments to pass to the sql generator.
 762
 763    Returns:
 764        A string representing the formatted model.
 765    """
 766    if len(expressions) == 1 and is_meta_expression(expressions[0]):
 767        return expressions[0].sql(pretty=True, dialect=dialect)
 768
 769    if rewrite_casts:
 770
 771        def cast_to_colon(node: exp.Expr) -> exp.Expr:
 772            if isinstance(node, exp.Cast) and not any(
 773                # Only convert CAST into :: if it doesn't have additional args set, otherwise this
 774                # conversion could alter the semantics (eg. changing SAFE_CAST in BigQuery to CAST)
 775                arg
 776                for name, arg in node.args.items()
 777                if name not in ("this", "to")
 778            ):
 779                this = node.this
 780
 781                if not isinstance(this, (exp.Binary, exp.Unary)) or isinstance(this, exp.Paren):
 782                    cast = DColonCast(this=this, to=node.to)
 783                    cast.comments = node.comments
 784                    node = cast
 785
 786            exp.replace_children(node, cast_to_colon)
 787            return node
 788
 789        new_expressions = []
 790        for expression in expressions:
 791            expression = expression.copy()
 792            exp.replace_children(expression, cast_to_colon)
 793            new_expressions.append(expression)
 794
 795        expressions = new_expressions
 796
 797    return ";\n\n".join(
 798        expression.sql(pretty=True, dialect=dialect, **kwargs) for expression in expressions
 799    ).strip()
 800
 801
 802def text_diff(
 803    a: t.List[exp.Expr],
 804    b: t.List[exp.Expr],
 805    a_dialect: t.Optional[str] = None,
 806    b_dialect: t.Optional[str] = None,
 807) -> str:
 808    """Find the unified text diff between two expressions."""
 809    a_sql = [
 810        line
 811        for expr in a
 812        for line in expr.sql(pretty=True, comments=False, dialect=a_dialect).split("\n")
 813    ]
 814    b_sql = [
 815        line
 816        for expr in b
 817        for line in expr.sql(pretty=True, comments=False, dialect=b_dialect).split("\n")
 818    ]
 819    return "\n".join(unified_diff(a_sql, b_sql))
 820
 821
# Building blocks of DIALECT_PATTERN, which captures the value of the `dialect` key in a
# `model (...)` / `audit (...)` header as the named group "dialect".
WS_OR_COMMENT = r"(?:\s|--[^\n]*\n|/\*.*?\*/)"  # one whitespace char, `--` line comment or `/* */` comment
HEADER = r"\b(?:model|audit)\b(?=\s*\()"  # the keyword, only when followed by '('
KEY_BOUNDARY = r"(?:\(|,)"  # key is preceded by either '(' or ','
DIALECT_VALUE = r"['\"]?(?P<dialect>[a-z][a-z0-9]*)['\"]?"  # optionally quoted alphanumeric name
VALUE_BOUNDARY = r"(?=,|\))"  # value is followed by comma or closing paren

# Case-insensitive; DOTALL lets `.` span the newlines between the header and the key.
DIALECT_PATTERN = re.compile(
    rf"{HEADER}.*?{KEY_BOUNDARY}{WS_OR_COMMENT}*dialect{WS_OR_COMMENT}+{DIALECT_VALUE}{WS_OR_COMMENT}*{VALUE_BOUNDARY}",
    re.IGNORECASE | re.DOTALL,
)
 832
 833
 834def _is_command_statement(command: str, tokens: t.List[Token], pos: int) -> bool:
 835    try:
 836        return (
 837            tokens[pos].text.upper() == command.upper()
 838            and tokens[pos + 1].token_type == TokenType.SEMICOLON
 839        )
 840    except IndexError:
 841        return False
 842
 843
# Marker keywords that delimit Jinja and ON_VIRTUAL_UPDATE blocks in a definition.
# The `_is_*` helpers in this module look for each marker followed by a semicolon.
JINJA_QUERY_BEGIN = "JINJA_QUERY_BEGIN"
JINJA_STATEMENT_BEGIN = "JINJA_STATEMENT_BEGIN"
JINJA_END = "JINJA_END"
ON_VIRTUAL_UPDATE_BEGIN = "ON_VIRTUAL_UPDATE_BEGIN"
ON_VIRTUAL_UPDATE_END = "ON_VIRTUAL_UPDATE_END"
 849
 850
 851def _is_jinja_statement_begin(tokens: t.List[Token], pos: int) -> bool:
 852    return _is_command_statement(JINJA_STATEMENT_BEGIN, tokens, pos)
 853
 854
 855def _is_jinja_query_begin(tokens: t.List[Token], pos: int) -> bool:
 856    return _is_command_statement(JINJA_QUERY_BEGIN, tokens, pos)
 857
 858
 859def _is_jinja_end(tokens: t.List[Token], pos: int) -> bool:
 860    return _is_command_statement(JINJA_END, tokens, pos)
 861
 862
 863def jinja_query(query: str) -> JinjaQuery:
 864    return JinjaQuery(this=exp.Literal.string(query.strip()))
 865
 866
 867def jinja_statement(statement: str) -> JinjaStatement:
 868    return JinjaStatement(this=exp.Literal.string(statement.strip()))
 869
 870
 871def _is_virtual_statement_begin(tokens: t.List[Token], pos: int) -> bool:
 872    return _is_command_statement(ON_VIRTUAL_UPDATE_BEGIN, tokens, pos)
 873
 874
 875def _is_virtual_statement_end(tokens: t.List[Token], pos: int) -> bool:
 876    return _is_command_statement(ON_VIRTUAL_UPDATE_END, tokens, pos)
 877
 878
 879def virtual_statement(statements: t.List[exp.Expr]) -> VirtualUpdateStatement:
 880    return VirtualUpdateStatement(expressions=statements)
 881
 882
class ChunkType(Enum):
    # Classifies the token chunks that `parse` splits a definition into.
    JINJA_QUERY = auto()  # chunk opened by JINJA_QUERY_BEGIN
    JINJA_STATEMENT = auto()  # chunk opened by JINJA_STATEMENT_BEGIN
    SQL = auto()  # plain SQL chunk (the default chunk type)
    VIRTUAL_STATEMENT = auto()  # chunk belonging to an ON_VIRTUAL_UPDATE block
    VIRTUAL_JINJA_STATEMENT = auto()  # NOTE(review): not referenced by `parse` in this module - confirm usage
 889
 890
 891def parse_one(
 892    sql: str, dialect: t.Optional[str] = None, into: t.Optional[exp.IntoType] = None
 893) -> exp.Expr:
 894    expressions = parse(sql, default_dialect=dialect, match_dialect=False, into=into)
 895    if not expressions:
 896        raise SQLMeshError(f"No expressions found in '{sql}'")
 897    elif len(expressions) > 1:
 898        raise SQLMeshError(f"Multiple expressions found in '{sql}'")
 899    return expressions[0]
 900
 901
def parse(
    sql: str,
    default_dialect: t.Optional[str] = None,
    match_dialect: bool = True,
    into: t.Optional[exp.IntoType] = None,
) -> t.List[exp.Expr]:
    """Parse a sql string.

    Supports parsing model definition.
    If a jinja block is detected, the query is stored as raw string in a Jinja node.

    Args:
        sql: The sql based definition.
        default_dialect: The dialect to use if the model does not specify one.
        match_dialect: Whether to search the first MAX_MODEL_DEFINITION_SIZE characters of `sql`
            with DIALECT_PATTERN and prefer the dialect found there over `default_dialect`.
        into: When provided, SQL chunks are parsed with `parser.parse_into(into, ...)` instead
            of `parser.parse(...)`.

    Returns:
        A list of the parsed expressions: [Model, *Statements, Query, *Statements]
    """
    # `match` is False when match_dialect is disabled, otherwise the regex match (or None).
    match = match_dialect and DIALECT_PATTERN.search(sql[:MAX_MODEL_DEFINITION_SIZE])
    dialect_str = match.group("dialect") if match else None
    dialect = Dialect.get_or_raise(dialect_str or default_dialect)

    tokens = dialect.tokenize(sql)
    # First pass: split the token stream into (tokens, chunk type) pairs.
    chunks: t.List[t.Tuple[t.List[Token], ChunkType]] = [([], ChunkType.SQL)]
    total = len(tokens)

    pos = 0
    virtual = False  # True while scanning between ON_VIRTUAL_UPDATE_BEGIN and ON_VIRTUAL_UPDATE_END
    while pos < total:
        token = tokens[pos]
        if _is_virtual_statement_end(tokens, pos):
            # Keep the end marker in the current chunk, then start a fresh SQL chunk.
            chunks[-1][0].append(token)
            virtual = False
            chunks.append(([], ChunkType.SQL))
            pos += 2  # skip the marker and its semicolon
        elif _is_jinja_end(tokens, pos) or (
            chunks[-1][1] == ChunkType.SQL
            and token.token_type == TokenType.SEMICOLON
            and pos < total - 1
        ):
            # Either a Jinja block ends here, or a non-trailing semicolon ends an SQL chunk.
            if token.token_type == TokenType.SEMICOLON:
                pos += 1
            else:
                # Jinja end statement
                chunks[-1][0].append(token)
                pos += 2
            # NOTE(review): `tokens[pos] != ON_VIRTUAL_UPDATE_END` compares a Token object with a
            # str; unless Token defines equality against strings this is always True. It also
            # indexes `tokens[pos]` after `pos` was advanced, which is out of range if the block
            # ends the token stream while `virtual` is set. Confirm whether `.text` was intended.
            chunks.append(
                (
                    [],
                    ChunkType.VIRTUAL_STATEMENT
                    if virtual and tokens[pos] != ON_VIRTUAL_UPDATE_END
                    else ChunkType.SQL,
                )
            )
        elif _is_jinja_query_begin(tokens, pos):
            chunks.append(([token], ChunkType.JINJA_QUERY))
            pos += 2
        elif _is_jinja_statement_begin(tokens, pos):
            chunks.append(([token], ChunkType.JINJA_STATEMENT))
            pos += 2
        elif _is_virtual_statement_begin(tokens, pos):
            chunks.append(([token], ChunkType.VIRTUAL_STATEMENT))
            pos += 2
            virtual = True
        else:
            # Ordinary token: it belongs to whatever chunk is currently open.
            chunks[-1][0].append(token)
            pos += 1

    parser = dialect.parser()
    expressions: t.List[exp.Expr] = []

    def parse_sql_chunk(chunk: t.List[Token], meta_sql: bool = True) -> t.List[exp.Expr]:
        # Parses one SQL chunk, drops empty results and, when `meta_sql` is set, records the
        # chunk's source text under meta["sql"] on every resulting expression.
        parsed_expressions: t.List[t.Optional[exp.Expr]] = (
            parser.parse(chunk, sql) if into is None else parser.parse_into(into, chunk, sql)
        )
        expressions = []
        for expression in parsed_expressions:
            if expression:
                if meta_sql:
                    expression.meta["sql"] = parser._find_sql(chunk[0], chunk[-1])
                expressions.append(expression)
        return expressions

    def parse_jinja_chunk(chunk: t.List[Token], meta_sql: bool = True) -> exp.Expr:
        # Builds a Jinja node from the raw text between the begin and end marker tokens.
        # `chunk_type` is not a parameter: it is read from the enclosing `parse` scope, i.e. the
        # value assigned by the chunk loop at the bottom of this function.
        start, *_, end = chunk
        # presumably the offsets skip the semicolon that follows the begin marker - TODO confirm
        segment = sql[start.end + 2 : end.start - 1]
        factory = jinja_query if chunk_type == ChunkType.JINJA_QUERY else jinja_statement
        expression = factory(segment.strip())
        if meta_sql:
            expression.meta["sql"] = sql[start.start : end.end + 1]
        return expression

    def parse_virtual_statement(
        chunks: t.List[t.Tuple[t.List[Token], ChunkType]], pos: int
    ) -> t.Tuple[t.List[exp.Expr], int]:
        # For virtual statements we need to handle both SQL and Jinja nested blocks within the chunk
        # Returns the wrapped statements (or an empty list) and the chunk position reached.
        virtual_update_statements: t.List[exp.Expr] = []
        start = chunks[pos][0][0].start

        # Consume chunks until the previously processed chunk ends with ON_VIRTUAL_UPDATE_END.
        while (
            chunks[pos - 1][0] == [] or chunks[pos - 1][0][-1].text.upper() != ON_VIRTUAL_UPDATE_END
        ):
            chunk, chunk_type = chunks[pos]
            if chunk_type == ChunkType.JINJA_STATEMENT:
                virtual_update_statements.append(parse_jinja_chunk(chunk, False))
            else:
                # Drop a leading ON_VIRTUAL_UPDATE_BEGIN token (if present) and the last token.
                virtual_update_statements.extend(
                    parse_sql_chunk(
                        chunk[int(chunk[0].text.upper() == ON_VIRTUAL_UPDATE_BEGIN) : -1], False
                    ),
                )
            pos += 1

        if virtual_update_statements:
            statements = virtual_statement(virtual_update_statements)
            # `chunk` is the last chunk visited by the loop above.
            end = chunk[-1].end + 1
            statements.meta["sql"] = sql[start:end]
            return [statements], pos

        return [], pos

    # Second pass: turn every chunk into expressions according to its type.
    pos = 0
    total_chunks = len(chunks)
    while pos < total_chunks:
        chunk, chunk_type = chunks[pos]
        if chunk_type == ChunkType.VIRTUAL_STATEMENT:
            # The helper consumes several chunks and returns the position it stopped at.
            virtual_expression, pos = parse_virtual_statement(chunks, pos)
            expressions.extend(virtual_expression)
        elif chunk_type == ChunkType.SQL:
            expressions.extend(parse_sql_chunk(chunk))
        else:
            expressions.append(parse_jinja_chunk(chunk))
        pos += 1

    return expressions
1037
1038
def extend_sqlglot() -> None:
    """Extend SQLGlot with SQLMesh's custom macro aware dialect.

    Mutates class-level registries of the SQLGlot base Tokenizer / Parser / Generator and of the
    classes attached to every dialect in `Dialect.classes`, then overrides several parser methods.
    """
    # Start from the base classes and collect every dialect-specific class below.
    tokenizers = {Tokenizer}
    parsers = {Parser}
    generators = {Generator}

    for dialect in Dialect.classes.values():
        # Athena picks a different Tokenizer / Parser / Generator depending on the query
        # so this ensures that the extra ones it defines are also extended
        if dialect == athena.Athena:
            tokenizers.add(athena._TrinoTokenizer)
            parsers.add(AthenaTrinoParser)
            generators.add(athena._TrinoGenerator)
            generators.add(athena._HiveGenerator)

        if hasattr(dialect, "Tokenizer"):
            tokenizers.add(dialect.Tokenizer)
        if hasattr(dialect, "Parser"):
            parsers.add(dialect.Parser)
        if hasattr(dialect, "Generator"):
            generators.add(dialect.Generator)

    # Register the macro prefix with every tokenizer.
    for tokenizer in tokenizers:
        tokenizer.VAR_SINGLE_TOKENS.update(SQLMESH_MACRO_PREFIX)

    # Register the JINJA / METRIC functions and the macro parsers with every parser.
    for parser in parsers:
        parser.FUNCTIONS.update({"JINJA": Jinja.from_arg_list, "METRIC": MetricAgg.from_arg_list})
        parser.PLACEHOLDER_PARSERS.update({TokenType.PARAMETER: _parse_macro})
        parser.QUERY_MODIFIER_PARSERS.update(
            {TokenType.PARAMETER: lambda self: _parse_body_macro(self)}
        )

    for generator in generators:
        # MacroFunc doubles as the marker for "transforms already registered on this generator".
        if MacroFunc not in generator.TRANSFORMS:
            generator.TRANSFORMS.update(
                {
                    Audit: lambda self, e: _sqlmesh_ddl_sql(self, e, "AUDIT"),
                    DColonCast: lambda self, e: f"{self.sql(e, 'this')}::{self.sql(e, 'to')}",
                    Jinja: lambda self, e: e.name,
                    JinjaQuery: lambda self, e: f"{JINJA_QUERY_BEGIN};\n{e.name}\n{JINJA_END};",
                    JinjaStatement: lambda self,
                    e: f"{JINJA_STATEMENT_BEGIN};\n{e.name}\n{JINJA_END};",
                    VirtualUpdateStatement: lambda self, e: _on_virtual_update_sql(self, e),
                    MacroDef: lambda self, e: f"@DEF({self.sql(e.this)}, {self.sql(e.expression)})",
                    MacroFunc: _macro_func_sql,
                    MacroStrReplace: lambda self, e: f"@{self.sql(e.this)}",
                    MacroSQL: lambda self, e: f"@SQL({self.sql(e.this)})",
                    MacroVar: lambda self, e: f"@{e.name}",
                    Metric: lambda self, e: _sqlmesh_ddl_sql(self, e, "METRIC"),
                    Model: lambda self, e: _sqlmesh_ddl_sql(self, e, "MODEL"),
                    ModelKind: _model_kind_sql,
                    PythonCode: lambda self, e: self.expressions(e, sep="\n", indent=False),
                    StagedFilePath: lambda self, e: self.table_sql(e),
                    exp.Whens: _whens_sql,
                }
            )
        # MacroDef doubles as the marker for "separated-comment types already registered".
        if MacroDef not in generator.WITH_SEPARATED_COMMENTS:
            generator.WITH_SEPARATED_COMMENTS = (
                *generator.WITH_SEPARATED_COMMENTS,
                Model,
                MacroDef,
            )

        # NOTE(review): unlike the two updates above this one has no guard, so each call to
        # extend_sqlglot appends MacroStrReplace / MacroVar again - confirm this is harmless.
        generator.UNWRAPPED_INTERVAL_VALUES = (
            *generator.UNWRAPPED_INTERVAL_VALUES,
            MacroStrReplace,
            MacroVar,
        )

    _override(Parser, _parse_select)
    _override(Parser, _parse_statement)
    _override(Parser, _parse_join)
    _override(Parser, _parse_order)
    _override(Parser, _parse_where)
    _override(Parser, _parse_group)
    _override(Parser, _parse_with)
    _override(Parser, _parse_having)
    _override(Parser, _parse_limit)
    _override(Parser, _parse_value)
    _override(Parser, _parse_lambda)
    _override(Parser, _parse_types)
    # TSQL.Parser receives the base Parser._parse_if before the base method is overridden on the
    # next line, so the order of these two calls matters.
    _override(TSQL.Parser, Parser._parse_if)
    _override(Parser, _parse_if)
    _override(Parser, _parse_id_var)
    _override(Parser, _warn_unsupported)
    _override(Snowflake.Parser, _parse_table_parts)

    # DuckDB's prefix absolute power operator `@` clashes with the macro syntax
    DuckDB.Parser.NO_PAREN_FUNCTION_PARSERS.pop("@", None)
1128
1129
def select_from_values(
    values: t.List[t.Tuple[t.Any, ...]],
    columns_to_types: t.Dict[str, exp.DataType],
    batch_size: int = 0,
    alias: str = "t",
) -> t.Iterator[exp.Select]:
    """Yield one casting SELECT over a VALUES expression per batch of rows.

    Args:
        values: List of values to use for the VALUES expression.
        columns_to_types: Mapping of column names to types to assign to the values.
        batch_size: The maximum number of tuples per batch. Defaults to sys.maxsize if <= 0.
        alias: The alias to assign to the values expression. Defaults to "t".

    Returns:
        A generator of SELECT expressions; nothing is yielded when `values` is empty.
    """
    rows_per_batch = batch_size if batch_size > 0 else sys.maxsize
    total_rows = len(values)

    batch_start = 0
    while batch_start < total_rows:
        batch_end = min(batch_start + rows_per_batch, total_rows)
        yield select_from_values_for_batch_range(
            values=values,
            target_columns_to_types=columns_to_types,
            batch_start=batch_start,
            batch_end=batch_end,
            alias=alias,
        )
        batch_start += rows_per_batch
1158
1159
def select_from_values_for_batch_range(
    values: t.List[t.Tuple[t.Any, ...]],
    target_columns_to_types: t.Dict[str, exp.DataType],
    batch_start: int,
    batch_end: int,
    alias: str = "t",
    source_columns: t.Optional[t.List[str]] = None,
) -> exp.Select:
    """Build a SELECT that casts the rows `values[batch_start:batch_end]` to the target types.

    Args:
        values: All rows; only the slice `[batch_start:batch_end]` is rendered.
        target_columns_to_types: Mapping of output column names to their types.
        batch_start: Index of the first row of the batch (inclusive).
        batch_end: Index of the last row of the batch (exclusive).
        alias: The alias to assign to the values expression.
        source_columns: Columns present in `values`. Defaults to all target columns.
            Target columns missing from the source are selected as casted NULLs.

    Returns:
        A SELECT over a VALUES expression. When the batch contains no rows the VALUES clause
        holds a single row of casted NULLs and a `WHERE FALSE` forces a zero-row output.
    """
    source_columns = source_columns or list(target_columns_to_types)
    source_columns_to_types = get_source_columns_to_types(target_columns_to_types, source_columns)

    # Decide on the rows of this batch rather than on `values` as a whole: an empty slice of a
    # non-empty list used to produce an empty VALUES clause and an IndexError further down.
    batch = values[batch_start:batch_end]

    if not batch:
        # Ensures we don't generate an empty VALUES clause & forces a zero-row output
        where = exp.false()
        expressions = [
            tuple(exp.cast(exp.null(), to=kind) for kind in source_columns_to_types.values())
        ]
    else:
        where = None
        expressions = [tuple(transform_values(v, source_columns_to_types)) for v in batch]

    values_exp = exp.values(expressions, alias=alias, columns=source_columns_to_types)
    if batch:
        # BigQuery crashes on `SELECT CAST(x AS TIMESTAMP) FROM UNNEST([NULL]) AS x`, but not
        # on `SELECT CAST(x AS TIMESTAMP) FROM UNNEST([CAST(NULL AS TIMESTAMP)]) AS x`. This
        # ensures nulls under the `Values` expression are cast to avoid similar issues.
        # Only the first row is rewritten.
        for value, kind in zip(
            values_exp.expressions[0].expressions, source_columns_to_types.values()
        ):
            if isinstance(value, exp.Null):
                value.replace(exp.cast(value, to=kind))

    casted_columns = [
        exp.alias_(
            exp.cast(
                exp.column(column) if column in source_columns_to_types else exp.Null(), to=kind
            ),
            column,
            copy=False,
        )
        for column, kind in target_columns_to_types.items()
    ]
    return exp.select(*casted_columns).from_(values_exp, copy=False).where(where, copy=False)
1206
1207
def pandas_to_sql(
    df: pd.DataFrame,
    columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
    batch_size: int = 0,
    alias: str = "t",
) -> t.Iterator[exp.Select]:
    """Render a pandas dataframe as one or more SELECT ... FROM VALUES statements.

    Args:
        df: A pandas dataframe to convert.
        columns_to_types: Mapping of column names to types to assign to the values.
            Derived from the dataframe when omitted or empty.
        batch_size: The maximum number of tuples per batch. Defaults to sys.maxsize if <= 0.
        alias: The alias to assign to the values expression. Defaults to "t".

    Returns:
        A generator that yields one SELECT expression per batch.
    """
    # Rows become plain tuples without the index.
    rows = list(df.itertuples(index=False, name=None))
    resolved_types = columns_to_types or columns_to_types_from_df(df)
    yield from select_from_values(
        values=rows,
        columns_to_types=resolved_types,
        batch_size=batch_size,
        alias=alias,
    )
1231
1232
def set_default_catalog(
    table: str | exp.Table,
    default_catalog: t.Optional[str],
) -> exp.Table:
    """Return `table` as an exp.Table, filling in `default_catalog` when it has a db but no catalog.

    Args:
        table: Table name or table expression.
        default_catalog: Catalog to apply; nothing is changed when it is falsy.

    Returns:
        The table expression produced by `exp.to_table`, possibly with its catalog set.
    """
    resolved = exp.to_table(table)

    # Leave the table alone unless there is a catalog to apply, none is set yet and a db exists.
    if not default_catalog or resolved.catalog or not resolved.db:
        return resolved

    resolved.set("catalog", exp.parse_identifier(default_catalog))
    return resolved
1243
1244
@lru_cache(maxsize=16384)
def normalize_model_name(
    table: str | exp.Table | exp.Column,
    default_catalog: t.Optional[str],
    dialect: DialectType = None,
) -> str:
    """Return the normalized, fully quoted name for a model reference.

    Args:
        table: The model reference as a string, table expression or column expression.
        default_catalog: Catalog applied when the reference has a db but no catalog.
        dialect: Dialect used to parse the reference and to normalize its identifiers.

    Returns:
        The table name with normalized identifiers, rendered with `identify=True`.
    """
    if isinstance(table, exp.Column):
        # Shift the column parts one level up: name -> table, table -> db, db -> catalog.
        column_args = table.args
        resolved = exp.table_(
            table.this, db=column_args.get("table"), catalog=column_args.get("db")
        )
    else:
        # sqlglot's parsing is lenient about quote characters from other dialects and we rely
        # on that. Ex: a normalized name such as '"my_table"' combined with the spark dialect,
        # where backticks ('`') would be the expected quotes, is still parsed correctly.
        resolved = exp.to_table(table, dialect=dialect)

    with_catalog = set_default_catalog(resolved, default_catalog)
    # exp.table_name(table, dialect=dialect, identify=True) would be an alternative, but it would
    # normalize the quotes to the target dialect as well as the names. Normalizing only the names
    # and leaving quoting to the sqlglot dialect lets dialects with compatible normalization
    # strategies but incompatible quoting work together without user hassle.
    normalized = normalize_identifiers(with_catalog, dialect=dialect)
    return exp.table_name(normalized, identify=True)
1267
1268
def find_tables(
    expression: exp.Expr, default_catalog: t.Optional[str], dialect: DialectType = None
) -> t.Set[str]:
    """Find all tables referenced in a query.

    The result is cached on the expression under the meta key TABLES_META; once cached it is
    returned as is, regardless of the `default_catalog` / `dialect` passed later.

    Args:
        expression: The query to find the tables in.
        default_catalog: Catalog applied to table names that have a db but no catalog.
        dialect: The dialect to use for normalization of table names.

    Returns:
        A Set of all the table names.
    """
    meta = expression.meta
    if TABLES_META in meta:
        return meta[TABLES_META]

    table_names: t.Set[str] = set()
    for scope in traverse_scope(expression):
        for table in scope.tables:
            # Skip unnamed tables and references that resolve to a CTE of this scope.
            if not table.name or table.name in scope.cte_sources:
                continue
            table_names.add(
                normalize_model_name(table, default_catalog=default_catalog, dialect=dialect)
            )

    meta[TABLES_META] = table_names
    return table_names
1291
1292
def add_table(node: exp.Expr, table: str) -> exp.Expr:
    """Return a transformed copy of `node` where unqualified columns and identifiers get `table`."""

    def _qualify(current: exp.Expr) -> exp.Expr:
        # Columns that already carry a table are left as they are.
        if isinstance(current, exp.Column):
            if not current.table:
                return exp.column(current.this, table=table)
            return current
        # Bare identifiers are promoted to columns of `table`.
        if isinstance(current, exp.Identifier):
            return exp.column(current, table=table)
        return current

    return node.transform(_qualify)
1304
1305
def transform_values(
    values: t.Tuple[t.Any, ...], columns_to_types: t.Dict[str, exp.DataType]
) -> t.Iterator[t.Any]:
    """Convert each value of a row according to the type of its column.

    Args:
        values: One row of raw values, positionally aligned with `columns_to_types`.
        columns_to_types: Mapping of column names to their types.

    Returns:
        An iterator over the converted values, one per (value, type) pair.
    """

    def _convert(raw: t.Any, kind: exp.DataType) -> t.Any:
        # Lists are converted element-wise when the type is an array with one element type.
        if isinstance(raw, list):
            if kind.is_type(*exp.DataType.ARRAY_TYPES) and len(kind.expressions) == 1:
                item_kind = kind.expressions[0]
                return exp.convert([_convert(item, item_kind) for item in raw])

        # Dicts become structs when the type is a struct with the same number of fields.
        if isinstance(raw, dict):
            if kind.is_type(*exp.DataType.STRUCT_TYPES) and len(raw) == len(kind.expressions):
                members = []
                for (member_name, member_value), member_spec in zip(
                    raw.items(), kind.expressions
                ):
                    # Fields that are not column definitions fall back to the UNKNOWN type.
                    member_kind = (
                        member_spec.kind
                        if isinstance(member_spec, exp.ColumnDef)
                        else exp.DataType.build(exp.DataType.Type.UNKNOWN)
                    )
                    members.append(
                        exp.PropertyEQ(
                            this=exp.to_identifier(member_name),
                            expression=_convert(member_value, member_kind),
                        )
                    )
                return exp.Struct(expressions=members)

        if kind.is_type(exp.DataType.Type.JSON):
            # NOTE(review): the value is interpolated into a quoted SQL snippet without escaping.
            return exp.func("PARSE_JSON", f"'{raw}'")

        return exp.convert(raw)

    for raw_value, column_kind in zip(values, columns_to_types.values()):
        yield _convert(raw_value, column_kind)
1348
1349
def to_schema(sql_path: str | exp.Table, dialect: DialectType = None) -> exp.Table:
    """Interpret `sql_path` as a schema reference: an exp.Table whose `this` is None.

    A table expression that already has no `this` is returned unchanged. Otherwise the parts
    parsed from `sql_path` are shifted one level up (db -> catalog, name -> db). A table
    expression passed in is copied first, so the argument itself is not modified.
    """
    if isinstance(sql_path, exp.Table):
        if sql_path.this is None:
            return sql_path
        source: str | exp.Table = sql_path.copy()
    else:
        source = sql_path

    schema = exp.to_table(source, dialect=dialect)
    parts = schema.args
    schema.set("catalog", parts.get("db"))
    schema.set("db", parts.get("this"))
    schema.set("this", None)
    return schema
1360
1361
def schema_(
    db: exp.Identifier | str,
    catalog: t.Optional[exp.Identifier | str] = None,
    quoted: t.Optional[bool] = None,
) -> exp.Table:
    """Build a Schema.

    Args:
        db: Database name.
        catalog: Catalog name.
        quoted: Whether to force quotes on the schema's identifiers.

    Returns:
        The new Schema instance: an exp.Table with no `this`. Falsy `db` / `catalog`
        values are stored as None.
    """
    db_identifier = None
    if db:
        db_identifier = exp.to_identifier(db, quoted=quoted)

    catalog_identifier = None
    if catalog:
        catalog_identifier = exp.to_identifier(catalog, quoted=quoted)

    return exp.Table(this=None, db=db_identifier, catalog=catalog_identifier)
1382
1383
def normalize_mapping_schema(schema: t.Dict, dialect: DialectType) -> MappingSchema:
    """Build a MappingSchema from `schema` after stripping double quotes from its keys."""
    unquoted = _unquote_schema(schema)
    return MappingSchema(unquoted, dialect=dialect, normalize=False)
1386
1387
def _unquote_schema(schema: t.Dict) -> t.Dict:
    """SQLGlot schema expects unquoted normalized keys.

    Returns a new dict with leading / trailing double quotes stripped from every key,
    recursing into nested dicts; other values are kept as they are.
    """
    unquoted: t.Dict = {}
    for key, value in schema.items():
        bare_key = key.strip('"')
        if isinstance(value, dict):
            unquoted[bare_key] = _unquote_schema(value)
        else:
            unquoted[bare_key] = value
    return unquoted
1393
1394
@contextmanager
def normalize_and_quote(
    query: E, dialect: DialectType, default_catalog: t.Optional[str], quote: bool = True
) -> t.Iterator[E]:
    """Context manager that qualifies and normalizes `query` on entry and quotes it on exit.

    Args:
        query: The expression to process; the same object is yielded.
        dialect: Dialect used for qualification, normalization and quoting.
        default_catalog: Catalog passed to `qualify_tables`.
        quote: Whether to quote identifiers once the managed block finishes.
    """
    qualify_tables(query, catalog=default_catalog, dialect=dialect)
    normalize_identifiers(query, dialect=dialect)

    yield query

    # Reached only when the managed block completes without raising.
    if not quote:
        return
    quote_identifiers(query, dialect=dialect)
1404
1405
def interpret_expression(e: exp.Expr) -> exp.Expr | str | int | float | bool:
    """Convert a literal expression into the matching Python value where possible.

    Integers become `int`, other numbers become `float`, remaining literals and booleans
    yield their `this` value; any other expression is returned unchanged.
    """
    if e.is_int:
        return int(e.this)
    if e.is_number:
        return float(e.this)
    return e.this if isinstance(e, (exp.Literal, exp.Boolean)) else e
1414
1415
def interpret_key_value_pairs(
    e: exp.Tuple,
) -> t.Dict[str, exp.Expr | str | int | float | bool]:
    """Turn a tuple of key/value expressions into a dict of names to interpreted values."""
    pairs: t.Dict[str, exp.Expr | str | int | float | bool] = {}
    for pair in e.expressions:
        pairs[pair.this.name] = interpret_expression(pair.expression)
    return pairs
1420
1421
def extract_func_call(
    v: exp.Expr, allow_tuples: bool = False
) -> t.Tuple[str, t.Dict[str, exp.Expr]]:
    """Split a function-call expression into its lowercased name and keyword arguments.

    Args:
        v: The expression to inspect.
        allow_tuples: Whether a bare tuple (a call without a name) is accepted.

    Returns:
        A `(name, kwargs)` pair. Expressions that are not calls, parens or tuples yield their
        lowercased name and no arguments.

    Raises:
        ConfigError: If a tuple is given while `allow_tuples` is False, or if an argument is
            not a key-value pair.
    """
    # Anonymous is checked before Func so that it takes the first branch.
    if isinstance(v, exp.Anonymous):
        func, args = v.name, v.expressions
    elif isinstance(v, exp.Func):
        func, args = v.sql_name(), list(v.args.values())
    elif isinstance(v, exp.Paren):
        func, args = "", [v.this]
    elif isinstance(v, exp.Tuple):  # airflow only
        if not allow_tuples:
            raise ConfigError("Audit name is missing (eg. MY_AUDIT())")
        func, args = "", v.expressions
    else:
        return v.name.lower(), {}

    keyword_args: t.Dict[str, exp.Expr] = {}
    for arg in args:
        if isinstance(arg, (exp.PropertyEQ, exp.EQ)):
            keyword_args[arg.left.name.lower()] = arg.right
            continue
        raise ConfigError(
            f"Function '{func}' must be called with key-value arguments like {func}(arg := value)."
        )
    return func.lower(), keyword_args
1452
1453
def extract_function_calls(func_calls: t.Any, allow_tuples: bool = False) -> t.Any:
    """Used for extracting function calls for signals or audits.

    Args:
        func_calls: Either a sqlglot expression (a tuple / array of calls, a parenthesized call
            or a single call) or a list whose entries are dicts or `(name, args)` pairs.
        allow_tuples: For expressions, forwarded to `extract_func_call`. For dict entries,
            when True the name is the empty string and the dict is used as the arguments as is;
            when False the name is taken from the dict's "name" key.

    Returns:
        A list of `(lowercased name, kwargs)` tuples. Any other input is returned unchanged
        when truthy, otherwise an empty list is returned.

    Raises:
        ConfigError: If a list entry is neither a dict nor a tuple / list.
        KeyError: If a dict entry lacks the "name" key while `allow_tuples` is False.
    """

    if isinstance(func_calls, (exp.Tuple, exp.Array)):
        return [extract_func_call(i, allow_tuples=allow_tuples) for i in func_calls.expressions]
    if isinstance(func_calls, exp.Paren):
        return [extract_func_call(func_calls.this, allow_tuples=allow_tuples)]
    if isinstance(func_calls, exp.Expr):
        return [extract_func_call(func_calls, allow_tuples=allow_tuples)]
    if isinstance(func_calls, list):
        function_calls = []
        for entry in func_calls:
            if isinstance(entry, dict):
                # Pop the name from a shallow copy: popping from `entry` itself would mutate the
                # caller's config and make a second call on the same input fail with KeyError.
                args = dict(entry)
                name = "" if allow_tuples else args.pop("name")
            elif isinstance(entry, (tuple, list)):
                name, args = entry
            else:
                raise ConfigError(f"Audit must be a dictionary or named tuple. Got {entry}.")

            function_calls.append(
                (
                    name.lower(),
                    {
                        # String values are parsed into expressions; anything else is kept.
                        key: parse_one(value) if isinstance(value, str) else value
                        for key, value in args.items()
                    },
                )
            )

        return function_calls

    return func_calls or []
1487
1488
def is_meta_expression(v: t.Any) -> bool:
    """Return True when `v` is an Audit, Metric or Model expression."""
    meta_types = (Audit, Metric, Model)
    return isinstance(v, meta_types)
1491
1492
def replace_merge_table_aliases(expression: exp.Expr, dialect: t.Optional[str] = None) -> exp.Expr:
    """
    Resolves references from the "source" and "target" tables (or their DBT equivalents)
    with the corresponding SQLMesh merge aliases (MERGE_SOURCE_ALIAS and MERGE_TARGET_ALIAS)

    Only column expressions are touched, and only their first part is replaced (in place).
    The expression passed in is returned. `dialect` is accepted but not used by this function.
    """
    from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS

    if not isinstance(expression, exp.Column):
        return expression

    leading_part = expression.parts[0]
    if not leading_part:
        return expression

    # The first part of the column is compared case-insensitively with the known aliases.
    leading_name = leading_part.this.lower()
    if leading_name in ("target", "dbt_internal_dest", "__merge_target__"):
        leading_part.replace(exp.to_identifier(MERGE_TARGET_ALIAS, quoted=True))
    elif leading_name in ("source", "dbt_internal_source", "__merge_source__"):
        leading_part.replace(exp.to_identifier(MERGE_SOURCE_ALIAS, quoted=True))

    return expression
SQLMESH_MACRO_PREFIX = '@'
TABLES_META = 'sqlmesh.tables'
logger = <Logger sqlmesh.core.dialect (WARNING)>
class Model(sqlglot.expressions.core.Expression):
45class Model(exp.Expression):
46    arg_types = {"expressions": True}
arg_types = {'expressions': True}
key: ClassVar[str] = 'model'
required_args: ClassVar[Set[str]] = {'expressions'}
Inherited Members
sqlglot.expressions.core.Expr
Expr
is_var_len_args
is_subquery
is_cast
is_primitive
dump
load
sqlglot.expressions.core.Expression
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
args
parent
arg_key
index
comments
class Audit(sqlglot.expressions.core.Expression):
49class Audit(exp.Expression):
50    arg_types = {"expressions": True}
arg_types = {'expressions': True}
key: ClassVar[str] = 'audit'
required_args: ClassVar[Set[str]] = {'expressions'}
Inherited Members
sqlglot.expressions.core.Expr
Expr
is_var_len_args
is_subquery
is_cast
is_primitive
dump
load
sqlglot.expressions.core.Expression
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
args
parent
arg_key
index
comments
class Metric(sqlglot.expressions.core.Expression):
53class Metric(exp.Expression):
54    arg_types = {"expressions": True}
arg_types = {'expressions': True}
key: ClassVar[str] = 'metric'
required_args: ClassVar[Set[str]] = {'expressions'}
Inherited Members
sqlglot.expressions.core.Expr
Expr
is_var_len_args
is_subquery
is_cast
is_primitive
dump
load
sqlglot.expressions.core.Expression
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
args
parent
arg_key
index
comments
class Jinja(sqlglot.expressions.core.Expression, sqlglot.expressions.core.Func):
57class Jinja(exp.Expression, exp.Func):
58    arg_types = {"this": True}
arg_types = {'this': True}
key: ClassVar[str] = 'jinja'
required_args: ClassVar[Set[str]] = {'this'}
Inherited Members
sqlglot.expressions.core.Expr
Expr
is_subquery
is_cast
is_primitive
dump
load
sqlglot.expressions.core.Expression
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
args
parent
arg_key
index
comments
sqlglot.expressions.core.Func
is_var_len_args
from_arg_list
sql_names
sql_name
default_parser_mappings
class JinjaQuery(Jinja):
61class JinjaQuery(Jinja):
62    pass
key: ClassVar[str] = 'jinjaquery'
required_args: ClassVar[Set[str]] = {'this'}
Inherited Members
sqlglot.expressions.core.Expr
Expr
is_subquery
is_cast
is_primitive
dump
load
Jinja
arg_types
sqlglot.expressions.core.Expression
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
args
parent
arg_key
index
comments
sqlglot.expressions.core.Func
is_var_len_args
from_arg_list
sql_names
sql_name
default_parser_mappings
class JinjaStatement(Jinja):
65class JinjaStatement(Jinja):
66    pass
key: ClassVar[str] = 'jinjastatement'
required_args: ClassVar[Set[str]] = {'this'}
Inherited Members
sqlglot.expressions.core.Expr
Expr
is_subquery
is_cast
is_primitive
dump
load
Jinja
arg_types
sqlglot.expressions.core.Expression
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
args
parent
arg_key
index
comments
sqlglot.expressions.core.Func
is_var_len_args
from_arg_list
sql_names
sql_name
default_parser_mappings
class VirtualUpdateStatement(sqlglot.expressions.core.Expression):
69class VirtualUpdateStatement(exp.Expression):
70    arg_types = {"expressions": True}
arg_types = {'expressions': True}
key: ClassVar[str] = 'virtualupdatestatement'
required_args: ClassVar[Set[str]] = {'expressions'}
Inherited Members
sqlglot.expressions.core.Expr
Expr
is_var_len_args
is_subquery
is_cast
is_primitive
dump
load
sqlglot.expressions.core.Expression
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
args
parent
arg_key
index
comments
class ModelKind(sqlglot.expressions.core.Expression):
73class ModelKind(exp.Expression):
74    arg_types = {"this": True, "expressions": False}
arg_types = {'this': True, 'expressions': False}
key: ClassVar[str] = 'modelkind'
required_args: ClassVar[Set[str]] = {'this'}
Inherited Members
sqlglot.expressions.core.Expr
Expr
is_var_len_args
is_subquery
is_cast
is_primitive
dump
load
sqlglot.expressions.core.Expression
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
args
parent
arg_key
index
comments
class MacroVar(sqlglot.expressions.core.Var):
77class MacroVar(exp.Var):
78    pass
key: ClassVar[str] = 'macrovar'
required_args: ClassVar[Set[str]] = {'this'}
Inherited Members
sqlglot.expressions.core.Expr
Expr
arg_types
is_var_len_args
is_subquery
is_cast
dump
load
sqlglot.expressions.core.Var
is_primitive
sqlglot.expressions.core.Expression
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
args
parent
arg_key
index
comments
class MacroFunc(sqlglot.expressions.core.Expression, sqlglot.expressions.core.Func):
81class MacroFunc(exp.Expression, exp.Func):
82    @property
83    def name(self) -> str:
84        return self.this.name
name: str
82    @property
83    def name(self) -> str:
84        return self.this.name
key: ClassVar[str] = 'macrofunc'
required_args: ClassVar[Set[str]] = {'this'}
Inherited Members
sqlglot.expressions.core.Expr
Expr
arg_types
is_subquery
is_cast
is_primitive
dump
load
sqlglot.expressions.core.Expression
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
args
parent
arg_key
index
comments
sqlglot.expressions.core.Func
is_var_len_args
from_arg_list
sql_names
sql_name
default_parser_mappings
class MacroDef(MacroFunc):
87class MacroDef(MacroFunc):
88    arg_types = {"this": True, "expression": True}
arg_types = {'this': True, 'expression': True}
key: ClassVar[str] = 'macrodef'
required_args: ClassVar[Set[str]] = {'expression', 'this'}
Inherited Members
sqlglot.expressions.core.Expr
Expr
is_subquery
is_cast
is_primitive
dump
load
MacroFunc
name
sqlglot.expressions.core.Expression
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
args
parent
arg_key
index
comments
sqlglot.expressions.core.Func
is_var_len_args
from_arg_list
sql_names
sql_name
default_parser_mappings
class MacroSQL(MacroFunc):
91class MacroSQL(MacroFunc):
92    arg_types = {"this": True, "into": False}
arg_types = {'this': True, 'into': False}
key: ClassVar[str] = 'macrosql'
required_args: ClassVar[Set[str]] = {'this'}
Inherited Members
sqlglot.expressions.core.Expr
Expr
is_subquery
is_cast
is_primitive
dump
load
MacroFunc
name
sqlglot.expressions.core.Expression
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
args
parent
arg_key
index
comments
sqlglot.expressions.core.Func
is_var_len_args
from_arg_list
sql_names
sql_name
default_parser_mappings
class MacroStrReplace(MacroFunc):
95class MacroStrReplace(MacroFunc):
96    pass
key: ClassVar[str] = 'macrostrreplace'
required_args: ClassVar[Set[str]] = {'this'}
Inherited Members
sqlglot.expressions.core.Expr
Expr
arg_types
is_subquery
is_cast
is_primitive
dump
load
MacroFunc
name
sqlglot.expressions.core.Expression
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
args
parent
arg_key
index
comments
sqlglot.expressions.core.Func
is_var_len_args
from_arg_list
sql_names
sql_name
default_parser_mappings
class PythonCode(sqlglot.expressions.core.Expression):
 99class PythonCode(exp.Expression):
100    arg_types = {"expressions": True}
arg_types = {'expressions': True}
key: ClassVar[str] = 'pythoncode'
required_args: ClassVar[Set[str]] = {'expressions'}
Inherited Members
sqlglot.expressions.core.Expr
Expr
is_var_len_args
is_subquery
is_cast
is_primitive
dump
load
sqlglot.expressions.core.Expression
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
args
parent
arg_key
index
comments
class DColonCast(sqlglot.expressions.functions.Cast):
103class DColonCast(exp.Cast):
104    pass
key: ClassVar[str] = 'dcoloncast'
required_args: ClassVar[Set[str]] = {'to', 'this'}
Inherited Members
sqlglot.expressions.core.Expr
Expr
is_subquery
is_primitive
dump
load
sqlglot.expressions.functions.Cast
is_cast
arg_types
name
to
output_name
is_type
sqlglot.expressions.core.Expression
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
alias_or_name
type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
args
parent
arg_key
index
comments
sqlglot.expressions.core.Func
is_var_len_args
from_arg_list
sql_names
sql_name
default_parser_mappings
class MetricAgg(sqlglot.expressions.core.Expression, sqlglot.expressions.core.AggFunc):
107class MetricAgg(exp.Expression, exp.AggFunc):
108    """Used for computing metrics."""
109
110    arg_types = {"this": True}
111
112    @property
113    def output_name(self) -> str:
114        return self.this.name

Used for computing metrics.

arg_types = {'this': True}
output_name: str
112    @property
113    def output_name(self) -> str:
114        return self.this.name

Name of the output column if this expression is a selection.

If the Expr has no output name, an empty string is returned.

Example:
>>> from sqlglot import parse_one
>>> parse_one("SELECT a").expressions[0].output_name
'a'
>>> parse_one("SELECT b AS c").expressions[0].output_name
'c'
>>> parse_one("SELECT 1 + 2").expressions[0].output_name
''
key: ClassVar[str] = 'metricagg'
required_args: ClassVar[Set[str]] = {'this'}
Inherited Members
sqlglot.expressions.core.Expr
Expr
is_subquery
is_cast
is_primitive
dump
load
sqlglot.expressions.core.Expression
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
args
parent
arg_key
index
comments
sqlglot.expressions.core.Func
is_var_len_args
from_arg_list
sql_names
sql_name
default_parser_mappings
class StagedFilePath(sqlglot.expressions.core.Expression):
117class StagedFilePath(exp.Expression):
118    """Represents paths to "staged files" in Snowflake."""
119
120    arg_types = exp.Table.arg_types.copy()

Represents paths to "staged files" in Snowflake.

arg_types = {'this': False, 'alias': False, 'db': False, 'catalog': False, 'laterals': False, 'joins': False, 'pivots': False, 'hints': False, 'system_time': False, 'version': False, 'format': False, 'pattern': False, 'ordinality': False, 'when': False, 'only': False, 'partition': False, 'changes': False, 'rows_from': False, 'sample': False, 'indexed': False}
key: ClassVar[str] = 'stagedfilepath'
required_args: ClassVar[Set[str]] = set()
Inherited Members
sqlglot.expressions.core.Expr
Expr
is_var_len_args
is_subquery
is_cast
is_primitive
dump
load
sqlglot.expressions.core.Expression
this
expression
expressions
text
is_string
is_number
to_py
is_int
is_star
alias
alias_column_names
name
alias_or_name
output_name
type
is_type
is_leaf
meta
copy
add_comments
pop_comments
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
to_s
sql
transform
replace
pop
assert_is
error_messages
and_
or_
not_
update_positions
as_
isin
between
is_
like
ilike
eq
neq
rlike
div
asc
desc
args
parent
arg_key
index
comments
KEYWORD_MACROS = {'LIMIT', 'ORDER_BY', 'WHERE', 'JOIN', 'WITH', 'GROUP_BY', 'HAVING'}
PARSERS = {'MODEL': <function _create_parser.<locals>.parse>, 'AUDIT': <function _create_parser.<locals>.parse>, 'METRIC': <function _create_parser.<locals>.parse>}
def format_model_expressions( expressions: List[sqlglot.expressions.core.Expr], dialect: Optional[str] = None, rewrite_casts: bool = True, **kwargs: Any) -> str:
750def format_model_expressions(
751    expressions: t.List[exp.Expr],
752    dialect: t.Optional[str] = None,
753    rewrite_casts: bool = True,
754    **kwargs: t.Any,
755) -> str:
756    """Format a model's expressions into a standardized format.
757
758    Args:
759        expressions: The model's expressions, must be at least model def + query.
760        dialect: The dialect to render the expressions as.
761        rewrite_casts: Whether to rewrite all casts to use the :: syntax.
762        **kwargs: Additional keyword arguments to pass to the sql generator.
763
764    Returns:
765        A string representing the formatted model.
766    """
767    if len(expressions) == 1 and is_meta_expression(expressions[0]):
768        return expressions[0].sql(pretty=True, dialect=dialect)
769
770    if rewrite_casts:
771
772        def cast_to_colon(node: exp.Expr) -> exp.Expr:
773            if isinstance(node, exp.Cast) and not any(
774                # Only convert CAST into :: if it doesn't have additional args set, otherwise this
775                # conversion could alter the semantics (eg. changing SAFE_CAST in BigQuery to CAST)
776                arg
777                for name, arg in node.args.items()
778                if name not in ("this", "to")
779            ):
780                this = node.this
781
782                if not isinstance(this, (exp.Binary, exp.Unary)) or isinstance(this, exp.Paren):
783                    cast = DColonCast(this=this, to=node.to)
784                    cast.comments = node.comments
785                    node = cast
786
787            exp.replace_children(node, cast_to_colon)
788            return node
789
790        new_expressions = []
791        for expression in expressions:
792            expression = expression.copy()
793            exp.replace_children(expression, cast_to_colon)
794            new_expressions.append(expression)
795
796        expressions = new_expressions
797
798    return ";\n\n".join(
799        expression.sql(pretty=True, dialect=dialect, **kwargs) for expression in expressions
800    ).strip()

Format a model's expressions into a standardized format.

Arguments:
  • expressions: The model's expressions, must be at least model def + query.
  • dialect: The dialect to render the expressions as.
  • rewrite_casts: Whether to rewrite all casts to use the :: syntax.
  • **kwargs: Additional keyword arguments to pass to the sql generator.
Returns:

A string representing the formatted model.

def text_diff( a: List[sqlglot.expressions.core.Expr], b: List[sqlglot.expressions.core.Expr], a_dialect: Optional[str] = None, b_dialect: Optional[str] = None) -> str:
803def text_diff(
804    a: t.List[exp.Expr],
805    b: t.List[exp.Expr],
806    a_dialect: t.Optional[str] = None,
807    b_dialect: t.Optional[str] = None,
808) -> str:
809    """Find the unified text diff between two expressions."""
810    a_sql = [
811        line
812        for expr in a
813        for line in expr.sql(pretty=True, comments=False, dialect=a_dialect).split("\n")
814    ]
815    b_sql = [
816        line
817        for expr in b
818        for line in expr.sql(pretty=True, comments=False, dialect=b_dialect).split("\n")
819    ]
820    return "\n".join(unified_diff(a_sql, b_sql))

Find the unified text diff between two lists of expressions.

WS_OR_COMMENT = '(?:\\s|--[^\\n]*\\n|/\\*.*?\\*/)'
KEY_BOUNDARY = '(?:\\(|,)'
DIALECT_VALUE = '[\'\\"]?(?P<dialect>[a-z][a-z0-9]*)[\'\\"]?'
VALUE_BOUNDARY = '(?=,|\\))'
DIALECT_PATTERN = re.compile('\\b(?:model|audit)\\b(?=\\s*\\().*?(?:\\(|,)(?:\\s|--[^\\n]*\\n|/\\*.*?\\*/)*dialect(?:\\s|--[^\\n]*\\n|/\\*.*?\\*/)+[\'\\"]?(?P<dialect>[a-z][a-z0-9]*)[\'\\"]?(?:\\s|--[^\\n]*\\n|/\\*.*?\\*/)*(?=,|\\))', re.IGNORECASE|re.DOTALL)
JINJA_QUERY_BEGIN = 'JINJA_QUERY_BEGIN'
JINJA_STATEMENT_BEGIN = 'JINJA_STATEMENT_BEGIN'
JINJA_END = 'JINJA_END'
ON_VIRTUAL_UPDATE_BEGIN = 'ON_VIRTUAL_UPDATE_BEGIN'
ON_VIRTUAL_UPDATE_END = 'ON_VIRTUAL_UPDATE_END'
def jinja_query(query: str) -> JinjaQuery:
864def jinja_query(query: str) -> JinjaQuery:
865    return JinjaQuery(this=exp.Literal.string(query.strip()))
def jinja_statement(statement: str) -> JinjaStatement:
868def jinja_statement(statement: str) -> JinjaStatement:
869    return JinjaStatement(this=exp.Literal.string(statement.strip()))
def virtual_statement( statements: List[sqlglot.expressions.core.Expr]) -> VirtualUpdateStatement:
880def virtual_statement(statements: t.List[exp.Expr]) -> VirtualUpdateStatement:
881    return VirtualUpdateStatement(expressions=statements)
class ChunkType(enum.Enum):
884class ChunkType(Enum):
885    JINJA_QUERY = auto()
886    JINJA_STATEMENT = auto()
887    SQL = auto()
888    VIRTUAL_STATEMENT = auto()
889    VIRTUAL_JINJA_STATEMENT = auto()

An enumeration.

JINJA_QUERY = <ChunkType.JINJA_QUERY: 1>
JINJA_STATEMENT = <ChunkType.JINJA_STATEMENT: 2>
SQL = <ChunkType.SQL: 3>
VIRTUAL_STATEMENT = <ChunkType.VIRTUAL_STATEMENT: 4>
VIRTUAL_JINJA_STATEMENT = <ChunkType.VIRTUAL_JINJA_STATEMENT: 5>
Inherited Members
enum.Enum
name
value
def parse_one( sql: str, dialect: Optional[str] = None, into: Union[Type[sqlglot.expressions.core.Expr], Collection[Type[sqlglot.expressions.core.Expr]], NoneType] = None) -> sqlglot.expressions.core.Expr:
892def parse_one(
893    sql: str, dialect: t.Optional[str] = None, into: t.Optional[exp.IntoType] = None
894) -> exp.Expr:
895    expressions = parse(sql, default_dialect=dialect, match_dialect=False, into=into)
896    if not expressions:
897        raise SQLMeshError(f"No expressions found in '{sql}'")
898    elif len(expressions) > 1:
899        raise SQLMeshError(f"Multiple expressions found in '{sql}'")
900    return expressions[0]
def parse( sql: str, default_dialect: Optional[str] = None, match_dialect: bool = True, into: Union[Type[sqlglot.expressions.core.Expr], Collection[Type[sqlglot.expressions.core.Expr]], NoneType] = None) -> List[sqlglot.expressions.core.Expr]:
 903def parse(
 904    sql: str,
 905    default_dialect: t.Optional[str] = None,
 906    match_dialect: bool = True,
 907    into: t.Optional[exp.IntoType] = None,
 908) -> t.List[exp.Expr]:
 909    """Parse a sql string.
 910
 911    Supports parsing model definition.
 912    If a jinja block is detected, the query is stored as raw string in a Jinja node.
 913
 914    Args:
 915        sql: The sql based definition.
 916        default_dialect: The dialect to use if the model does not specify one.
 917
 918    Returns:
 919        A list of the parsed expressions: [Model, *Statements, Query, *Statements]
 920    """
 921    match = match_dialect and DIALECT_PATTERN.search(sql[:MAX_MODEL_DEFINITION_SIZE])
 922    dialect_str = match.group("dialect") if match else None
 923    dialect = Dialect.get_or_raise(dialect_str or default_dialect)
 924
 925    tokens = dialect.tokenize(sql)
 926    chunks: t.List[t.Tuple[t.List[Token], ChunkType]] = [([], ChunkType.SQL)]
 927    total = len(tokens)
 928
 929    pos = 0
 930    virtual = False
 931    while pos < total:
 932        token = tokens[pos]
 933        if _is_virtual_statement_end(tokens, pos):
 934            chunks[-1][0].append(token)
 935            virtual = False
 936            chunks.append(([], ChunkType.SQL))
 937            pos += 2
 938        elif _is_jinja_end(tokens, pos) or (
 939            chunks[-1][1] == ChunkType.SQL
 940            and token.token_type == TokenType.SEMICOLON
 941            and pos < total - 1
 942        ):
 943            if token.token_type == TokenType.SEMICOLON:
 944                pos += 1
 945            else:
 946                # Jinja end statement
 947                chunks[-1][0].append(token)
 948                pos += 2
 949            chunks.append(
 950                (
 951                    [],
 952                    ChunkType.VIRTUAL_STATEMENT
 953                    if virtual and tokens[pos] != ON_VIRTUAL_UPDATE_END
 954                    else ChunkType.SQL,
 955                )
 956            )
 957        elif _is_jinja_query_begin(tokens, pos):
 958            chunks.append(([token], ChunkType.JINJA_QUERY))
 959            pos += 2
 960        elif _is_jinja_statement_begin(tokens, pos):
 961            chunks.append(([token], ChunkType.JINJA_STATEMENT))
 962            pos += 2
 963        elif _is_virtual_statement_begin(tokens, pos):
 964            chunks.append(([token], ChunkType.VIRTUAL_STATEMENT))
 965            pos += 2
 966            virtual = True
 967        else:
 968            chunks[-1][0].append(token)
 969            pos += 1
 970
 971    parser = dialect.parser()
 972    expressions: t.List[exp.Expr] = []
 973
 974    def parse_sql_chunk(chunk: t.List[Token], meta_sql: bool = True) -> t.List[exp.Expr]:
 975        parsed_expressions: t.List[t.Optional[exp.Expr]] = (
 976            parser.parse(chunk, sql) if into is None else parser.parse_into(into, chunk, sql)
 977        )
 978        expressions = []
 979        for expression in parsed_expressions:
 980            if expression:
 981                if meta_sql:
 982                    expression.meta["sql"] = parser._find_sql(chunk[0], chunk[-1])
 983                expressions.append(expression)
 984        return expressions
 985
 986    def parse_jinja_chunk(chunk: t.List[Token], meta_sql: bool = True) -> exp.Expr:
 987        start, *_, end = chunk
 988        segment = sql[start.end + 2 : end.start - 1]
 989        factory = jinja_query if chunk_type == ChunkType.JINJA_QUERY else jinja_statement
 990        expression = factory(segment.strip())
 991        if meta_sql:
 992            expression.meta["sql"] = sql[start.start : end.end + 1]
 993        return expression
 994
 995    def parse_virtual_statement(
 996        chunks: t.List[t.Tuple[t.List[Token], ChunkType]], pos: int
 997    ) -> t.Tuple[t.List[exp.Expr], int]:
 998        # For virtual statements we need to handle both SQL and Jinja nested blocks within the chunk
 999        virtual_update_statements: t.List[exp.Expr] = []
1000        start = chunks[pos][0][0].start
1001
1002        while (
1003            chunks[pos - 1][0] == [] or chunks[pos - 1][0][-1].text.upper() != ON_VIRTUAL_UPDATE_END
1004        ):
1005            chunk, chunk_type = chunks[pos]
1006            if chunk_type == ChunkType.JINJA_STATEMENT:
1007                virtual_update_statements.append(parse_jinja_chunk(chunk, False))
1008            else:
1009                virtual_update_statements.extend(
1010                    parse_sql_chunk(
1011                        chunk[int(chunk[0].text.upper() == ON_VIRTUAL_UPDATE_BEGIN) : -1], False
1012                    ),
1013                )
1014            pos += 1
1015
1016        if virtual_update_statements:
1017            statements = virtual_statement(virtual_update_statements)
1018            end = chunk[-1].end + 1
1019            statements.meta["sql"] = sql[start:end]
1020            return [statements], pos
1021
1022        return [], pos
1023
1024    pos = 0
1025    total_chunks = len(chunks)
1026    while pos < total_chunks:
1027        chunk, chunk_type = chunks[pos]
1028        if chunk_type == ChunkType.VIRTUAL_STATEMENT:
1029            virtual_expression, pos = parse_virtual_statement(chunks, pos)
1030            expressions.extend(virtual_expression)
1031        elif chunk_type == ChunkType.SQL:
1032            expressions.extend(parse_sql_chunk(chunk))
1033        else:
1034            expressions.append(parse_jinja_chunk(chunk))
1035        pos += 1
1036
1037    return expressions

Parse a sql string.

Supports parsing model definitions. If a jinja block is detected, the query is stored as a raw string in a Jinja node.

Arguments:
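  • sql: The sql based definition.
  • default_dialect: The dialect to use if the model does not specify one.
  • match_dialect: Whether to search the beginning of the sql for a dialect declared in a model or audit definition and use it in preference to default_dialect.
  • into: Optional expression type(s) to parse each SQL chunk into; when omitted, the chunks are parsed as regular statements.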
Returns:

A list of the parsed expressions: [Model, *Statements, Query, *Statements]

def extend_sqlglot() -> None:
1040def extend_sqlglot() -> None:
1041    """Extend SQLGlot with SQLMesh's custom macro aware dialect."""
1042    tokenizers = {Tokenizer}
1043    parsers = {Parser}
1044    generators = {Generator}
1045
1046    for dialect in Dialect.classes.values():
1047        # Athena picks a different Tokenizer / Parser / Generator depending on the query
1048        # so this ensures that the extra ones it defines are also extended
1049        if dialect == athena.Athena:
1050            tokenizers.add(athena._TrinoTokenizer)
1051            parsers.add(AthenaTrinoParser)
1052            generators.add(athena._TrinoGenerator)
1053            generators.add(athena._HiveGenerator)
1054
1055        if hasattr(dialect, "Tokenizer"):
1056            tokenizers.add(dialect.Tokenizer)
1057        if hasattr(dialect, "Parser"):
1058            parsers.add(dialect.Parser)
1059        if hasattr(dialect, "Generator"):
1060            generators.add(dialect.Generator)
1061
1062    for tokenizer in tokenizers:
1063        tokenizer.VAR_SINGLE_TOKENS.update(SQLMESH_MACRO_PREFIX)
1064
1065    for parser in parsers:
1066        parser.FUNCTIONS.update({"JINJA": Jinja.from_arg_list, "METRIC": MetricAgg.from_arg_list})
1067        parser.PLACEHOLDER_PARSERS.update({TokenType.PARAMETER: _parse_macro})
1068        parser.QUERY_MODIFIER_PARSERS.update(
1069            {TokenType.PARAMETER: lambda self: _parse_body_macro(self)}
1070        )
1071
1072    for generator in generators:
1073        if MacroFunc not in generator.TRANSFORMS:
1074            generator.TRANSFORMS.update(
1075                {
1076                    Audit: lambda self, e: _sqlmesh_ddl_sql(self, e, "AUDIT"),
1077                    DColonCast: lambda self, e: f"{self.sql(e, 'this')}::{self.sql(e, 'to')}",
1078                    Jinja: lambda self, e: e.name,
1079                    JinjaQuery: lambda self, e: f"{JINJA_QUERY_BEGIN};\n{e.name}\n{JINJA_END};",
1080                    JinjaStatement: lambda self,
1081                    e: f"{JINJA_STATEMENT_BEGIN};\n{e.name}\n{JINJA_END};",
1082                    VirtualUpdateStatement: lambda self, e: _on_virtual_update_sql(self, e),
1083                    MacroDef: lambda self, e: f"@DEF({self.sql(e.this)}, {self.sql(e.expression)})",
1084                    MacroFunc: _macro_func_sql,
1085                    MacroStrReplace: lambda self, e: f"@{self.sql(e.this)}",
1086                    MacroSQL: lambda self, e: f"@SQL({self.sql(e.this)})",
1087                    MacroVar: lambda self, e: f"@{e.name}",
1088                    Metric: lambda self, e: _sqlmesh_ddl_sql(self, e, "METRIC"),
1089                    Model: lambda self, e: _sqlmesh_ddl_sql(self, e, "MODEL"),
1090                    ModelKind: _model_kind_sql,
1091                    PythonCode: lambda self, e: self.expressions(e, sep="\n", indent=False),
1092                    StagedFilePath: lambda self, e: self.table_sql(e),
1093                    exp.Whens: _whens_sql,
1094                }
1095            )
1096        if MacroDef not in generator.WITH_SEPARATED_COMMENTS:
1097            generator.WITH_SEPARATED_COMMENTS = (
1098                *generator.WITH_SEPARATED_COMMENTS,
1099                Model,
1100                MacroDef,
1101            )
1102
1103        generator.UNWRAPPED_INTERVAL_VALUES = (
1104            *generator.UNWRAPPED_INTERVAL_VALUES,
1105            MacroStrReplace,
1106            MacroVar,
1107        )
1108
1109    _override(Parser, _parse_select)
1110    _override(Parser, _parse_statement)
1111    _override(Parser, _parse_join)
1112    _override(Parser, _parse_order)
1113    _override(Parser, _parse_where)
1114    _override(Parser, _parse_group)
1115    _override(Parser, _parse_with)
1116    _override(Parser, _parse_having)
1117    _override(Parser, _parse_limit)
1118    _override(Parser, _parse_value)
1119    _override(Parser, _parse_lambda)
1120    _override(Parser, _parse_types)
1121    _override(TSQL.Parser, Parser._parse_if)
1122    _override(Parser, _parse_if)
1123    _override(Parser, _parse_id_var)
1124    _override(Parser, _warn_unsupported)
1125    _override(Snowflake.Parser, _parse_table_parts)
1126
1127    # DuckDB's prefix absolute power operator `@` clashes with the macro syntax
1128    DuckDB.Parser.NO_PAREN_FUNCTION_PARSERS.pop("@", None)

Extend SQLGlot with SQLMesh's custom macro aware dialect.

def select_from_values( values: List[Tuple[Any, ...]], columns_to_types: Dict[str, sqlglot.expressions.datatypes.DataType], batch_size: int = 0, alias: str = 't') -> Iterator[sqlglot.expressions.query.Select]:
1131def select_from_values(
1132    values: t.List[t.Tuple[t.Any, ...]],
1133    columns_to_types: t.Dict[str, exp.DataType],
1134    batch_size: int = 0,
1135    alias: str = "t",
1136) -> t.Iterator[exp.Select]:
1137    """Generate a VALUES expression that has a select wrapped around it to cast the values to their correct types.
1138
1139    Args:
1140        values: List of values to use for the VALUES expression.
1141        columns_to_types: Mapping of column names to types to assign to the values.
1142        batch_size: The maximum number of tuples per batches. Defaults to sys.maxsize if <= 0.
1143        alias: The alias to assign to the values expression. If not provided then will default to "t"
1144
1145    Returns:
1146        This method operates as a generator and yields a VALUES expression.
1147    """
1148    if batch_size <= 0:
1149        batch_size = sys.maxsize
1150    num_rows = len(values)
1151    for i in range(0, num_rows, batch_size):
1152        yield select_from_values_for_batch_range(
1153            values=values,
1154            target_columns_to_types=columns_to_types,
1155            batch_start=i,
1156            batch_end=min(i + batch_size, num_rows),
1157            alias=alias,
1158        )

Generate a VALUES expression that has a select wrapped around it to cast the values to their correct types.

Arguments:
  • values: List of values to use for the VALUES expression.
  • columns_to_types: Mapping of column names to types to assign to the values.
  • batch_size: The maximum number of tuples per batch. Defaults to sys.maxsize if <= 0.
  • alias: The alias to assign to the values expression. If not provided then will default to "t"
Returns:

This method operates as a generator and yields a VALUES expression.

def select_from_values_for_batch_range( values: List[Tuple[Any, ...]], target_columns_to_types: Dict[str, sqlglot.expressions.datatypes.DataType], batch_start: int, batch_end: int, alias: str = 't', source_columns: Optional[List[str]] = None) -> sqlglot.expressions.query.Select:
def select_from_values_for_batch_range(
    values: t.List[t.Tuple[t.Any, ...]],
    target_columns_to_types: t.Dict[str, exp.DataType],
    batch_start: int,
    batch_end: int,
    alias: str = "t",
    source_columns: t.Optional[t.List[str]] = None,
) -> exp.Select:
    """Build a SELECT over a VALUES expression for the rows in `values[batch_start:batch_end]`.

    Args:
        values: All available rows; only the requested slice is used.
        target_columns_to_types: Mapping of output column names to the types they are cast to.
        batch_start: Index of the first row to include.
        batch_end: Index one past the last row to include.
        alias: The alias to assign to the VALUES expression. Defaults to "t".
        source_columns: Names of the columns present in `values`. Defaults to all target
            columns. Target columns that are not source columns are selected as typed NULLs.

    Returns:
        A SELECT that casts every target column to its type. When `values` is empty, the
        VALUES clause holds a single row of typed NULLs and a `WHERE FALSE` is added.
    """
    source_columns = source_columns or list(target_columns_to_types)
    source_columns_to_types = get_source_columns_to_types(target_columns_to_types, source_columns)

    if values:
        where = None
        rows = [
            tuple(transform_values(row, source_columns_to_types))
            for row in values[batch_start:batch_end]
        ]
    else:
        # Ensures we don't generate an empty VALUES clause & forces a zero-row output
        where = exp.false()
        rows = [tuple(exp.cast(exp.null(), to=kind) for kind in source_columns_to_types.values())]

    values_exp = exp.values(rows, alias=alias, columns=source_columns_to_types)

    if values:
        # BigQuery crashes on `SELECT CAST(x AS TIMESTAMP) FROM UNNEST([NULL]) AS x`, but not
        # on `SELECT CAST(x AS TIMESTAMP) FROM UNNEST([CAST(NULL AS TIMESTAMP)]) AS x`. This
        # ensures nulls under the `Values` expression are cast to avoid similar issues.
        # Only the first row of the VALUES expression is inspected.
        first_row = values_exp.expressions[0].expressions
        for value, kind in zip(first_row, source_columns_to_types.values()):
            if isinstance(value, exp.Null):
                value.replace(exp.cast(value, to=kind))

    casted_columns = []
    for column, kind in target_columns_to_types.items():
        # Columns missing from the source are emitted as NULLs of the target type.
        projection = exp.column(column) if column in source_columns_to_types else exp.Null()
        casted_columns.append(exp.alias_(exp.cast(projection, to=kind), column, copy=False))

    return exp.select(*casted_columns).from_(values_exp, copy=False).where(where, copy=False)
def pandas_to_sql( df: pandas.core.frame.DataFrame, columns_to_types: Optional[Dict[str, sqlglot.expressions.datatypes.DataType]] = None, batch_size: int = 0, alias: str = 't') -> Iterator[sqlglot.expressions.query.Select]:
def pandas_to_sql(
    df: pd.DataFrame,
    columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
    batch_size: int = 0,
    alias: str = "t",
) -> t.Iterator[exp.Select]:
    """Convert a pandas dataframe into SELECT statements over VALUES expressions.

    Args:
        df: A pandas dataframe to convert.
        columns_to_types: Mapping of column names to types to assign to the values. When
            omitted (or empty), the mapping is derived from the dataframe itself.
        batch_size: The maximum number of tuples per batch. Defaults to sys.maxsize if <= 0.
        alias: The alias to assign to the values expression. Defaults to "t".

    Yields:
        The SELECT expressions produced by `select_from_values` for the dataframe's rows.
    """
    # Rows are materialized as plain tuples, without the index.
    rows = list(df.itertuples(index=False, name=None))
    resolved_types = columns_to_types or columns_to_types_from_df(df)

    yield from select_from_values(
        values=rows,
        columns_to_types=resolved_types,
        batch_size=batch_size,
        alias=alias,
    )

Convert a pandas dataframe into a VALUES sql statement.

Arguments:
  • df: A pandas dataframe to convert.
  • columns_to_types: Mapping of column names to types to assign to the values.
  • batch_size: The maximum number of tuples per batch. Defaults to sys.maxsize if <= 0.
  • alias: The alias to assign to the values expression. If not provided then will default to "t"
Returns:

This method operates as a generator and yields a VALUES expression.

def set_default_catalog( table: str | sqlglot.expressions.query.Table, default_catalog: Optional[str]) -> sqlglot.expressions.query.Table:
def set_default_catalog(
    table: str | exp.Table,
    default_catalog: t.Optional[str],
) -> exp.Table:
    """Fill in the catalog of a table reference that has a db but no catalog.

    Args:
        table: The table name or Table expression to process.
        default_catalog: The catalog to apply. Nothing is changed when this is falsy.

    Returns:
        The Table produced by `exp.to_table`, with its catalog set to `default_catalog` when
        the table has a db and no catalog of its own.
    """
    parsed = exp.to_table(table)

    # Leave the table alone unless there is a catalog to apply, none is set yet, and a db exists.
    if not default_catalog or parsed.catalog or not parsed.db:
        return parsed

    parsed.set("catalog", exp.parse_identifier(default_catalog))
    return parsed
@lru_cache(maxsize=16384)
def normalize_model_name( table: str | sqlglot.expressions.query.Table | sqlglot.expressions.core.Column, default_catalog: Optional[str], dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> str:
@lru_cache(maxsize=16384)
def normalize_model_name(
    table: str | exp.Table | exp.Column,
    default_catalog: t.Optional[str],
    dialect: DialectType = None,
) -> str:
    """Return the normalized, fully quoted name of a model.

    Results are memoized with an LRU cache keyed on the arguments.

    Args:
        table: The model name, as a string, Table or Column. For a Column, its `this`, `table`
            and `db` parts are used as the table name, db and catalog respectively.
        default_catalog: The catalog applied through `set_default_catalog`.
        dialect: The dialect used to parse the name and to normalize its identifiers.

    Returns:
        The table name with normalized identifiers, rendered with `identify=True`.
    """
    if not isinstance(table, exp.Column):
        # We are relying on sqlglot's flexible parsing here to accept quotes from other dialects.
        # Ex: I have a normalized name of '"my_table"' but the dialect is spark and therefore we
        # should expect spark quotes to be backticks ('`') instead of double quotes ('"'). sqlglot
        # today is flexible and will still parse this correctly and we rely on that.
        parsed = exp.to_table(table, dialect=dialect)
    else:
        parsed = exp.table_(table.this, db=table.args.get("table"), catalog=table.args.get("db"))

    qualified = set_default_catalog(parsed, default_catalog)

    # An alternative way to do this is the following:
    #   exp.table_name(table, dialect=dialect, identify=True)
    # This though would result in the names being normalized to the target dialect AND the quotes,
    # while the approach below just normalizes the names.
    # By just normalizing names and using sqlglot dialect for quotes this makes it easier for
    # dialects that have compatible normalization strategies but incompatible quoting to still
    # work together without user hassle.
    normalized = normalize_identifiers(qualified, dialect=dialect)
    return exp.table_name(normalized, identify=True)
def find_tables( expression: sqlglot.expressions.core.Expr, default_catalog: Optional[str], dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> Set[str]:
def find_tables(
    expression: exp.Expr, default_catalog: t.Optional[str], dialect: DialectType = None
) -> t.Set[str]:
    """Find all tables referenced in a query.

    The result is cached in the expression's meta under the `TABLES_META` key. Once cached,
    it is returned as-is on later calls, whatever `default_catalog` and `dialect` are passed.

    Args:
        expression: The query to find the tables in.
        default_catalog: The catalog passed to `normalize_model_name` for each table.
        dialect: The dialect to use for normalization of table names.

    Returns:
        A Set of all the normalized table names.
    """
    if TABLES_META in expression.meta:
        return expression.meta[TABLES_META]

    found: t.Set[str] = set()
    for scope in traverse_scope(expression):
        for table in scope.tables:
            # Skip unnamed tables and references that resolve to a CTE of this scope.
            if table.name and table.name not in scope.cte_sources:
                found.add(
                    normalize_model_name(table, default_catalog=default_catalog, dialect=dialect)
                )

    expression.meta[TABLES_META] = found
    return expression.meta[TABLES_META]

Find all tables referenced in a query.

Caches the result in the expression's meta under the 'sqlmesh.tables' key.

Arguments:
  • expression: The query to find the tables in.
  • dialect: The dialect to use for normalization of table names.
Returns:

A Set of all the table names.

def add_table( node: sqlglot.expressions.core.Expr, table: str) -> sqlglot.expressions.core.Expr:
def add_table(node: exp.Expr, table: str) -> exp.Expr:
    """Add a table to all columns in an expression.

    Args:
        node: The expression to transform.
        table: The table name to attach.

    Returns:
        The result of `node.transform`, in which columns without a table, as well as
        identifiers, are replaced by columns qualified with `table`.
    """

    def _qualify(expr: exp.Expr) -> exp.Expr:
        if isinstance(expr, exp.Column):
            # Columns that already carry a table are kept untouched.
            return expr if expr.table else exp.column(expr.this, table=table)
        if isinstance(expr, exp.Identifier):
            return exp.column(expr, table=table)
        return expr

    return node.transform(_qualify)

Add a table to all columns in an expression.

def transform_values( values: Tuple[Any, ...], columns_to_types: Dict[str, sqlglot.expressions.datatypes.DataType]) -> Iterator[Any]:
def transform_values(
    values: t.Tuple[t.Any, ...], columns_to_types: t.Dict[str, exp.DataType]
) -> t.Iterator[t.Any]:
    """Perform transformations on values given columns_to_types.

    Values are paired positionally with the types in `columns_to_types`:
      - lists targeting an array type with a single element type are converted element-wise;
      - dicts targeting a struct type with a matching number of fields become `Struct`
        expressions, pairing dict items with the struct's fields by position;
      - values targeting a JSON type are wrapped in a `PARSE_JSON` call on their string form;
      - anything else goes through `exp.convert`.

    Args:
        values: A single row of raw values.
        columns_to_types: Mapping of column names to types, in the same order as `values`.

    Yields:
        The transformed value for each column.
    """

    def _transform_value(value: t.Any, dtype: exp.DataType) -> t.Any:
        if (
            isinstance(value, list)
            and dtype.is_type(*exp.DataType.ARRAY_TYPES)
            and len(dtype.expressions) == 1
        ):
            element_type = dtype.expressions[0]
            return exp.convert([_transform_value(v, element_type) for v in value])

        if (
            isinstance(value, dict)
            and dtype.is_type(*exp.DataType.STRUCT_TYPES)
            and len(value) == len(dtype.expressions)
        ):
            expressions = []
            for (field_name, field_value), field_type in zip(value.items(), dtype.expressions):
                # Struct fields without a column definition carry no usable type information.
                if isinstance(field_type, exp.ColumnDef):
                    field_type = field_type.kind
                else:
                    field_type = exp.DataType.build(exp.DataType.Type.UNKNOWN)

                expressions.append(
                    exp.PropertyEQ(
                        this=exp.to_identifier(field_name),
                        expression=_transform_value(field_value, field_type),
                    )
                )

            return exp.Struct(expressions=expressions)

        if dtype.is_type(exp.DataType.Type.JSON):
            # Build the string literal node directly instead of splicing the value into SQL text
            # that is parsed afterwards: a single quote inside the payload would otherwise
            # terminate the literal early and break (or alter) the generated expression.
            return exp.func("PARSE_JSON", exp.Literal.string(f"{value}"))

        return exp.convert(value)

    for col_value, col_type in zip(values, columns_to_types.values()):
        yield _transform_value(col_value, col_type)

Perform transformations on values given columns_to_types.

def to_schema( sql_path: str | sqlglot.expressions.query.Table, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.expressions.query.Table:
def to_schema(sql_path: str | exp.Table, dialect: DialectType = None) -> exp.Table:
    """Interpret a path as a schema reference rather than a table reference.

    The path is parsed as a table and its parts are shifted up one level: the db becomes the
    catalog, the table name becomes the db, and the table name itself is cleared.

    Args:
        sql_path: The schema path or Table expression. A Table whose `this` is already None is
            returned unchanged; any other Table is copied before being converted.
        dialect: The dialect used to parse the path.

    Returns:
        A Table expression with `this` set to None.
    """
    if isinstance(sql_path, exp.Table):
        if sql_path.this is None:
            return sql_path
        source: str | exp.Table = sql_path.copy()
    else:
        source = sql_path

    table = exp.to_table(source, dialect=dialect)

    # Capture the parts before reassigning them so each one moves up exactly one level.
    db_part = table.args.get("db")
    name_part = table.args.get("this")
    table.set("catalog", db_part)
    table.set("db", name_part)
    table.set("this", None)
    return table
def schema_( db: sqlglot.expressions.core.Identifier | str, catalog: Union[sqlglot.expressions.core.Identifier, str, NoneType] = None, quoted: Optional[bool] = None) -> sqlglot.expressions.query.Table:
def schema_(
    db: exp.Identifier | str,
    catalog: t.Optional[exp.Identifier | str] = None,
    quoted: t.Optional[bool] = None,
) -> exp.Table:
    """Build a Schema.

    Args:
        db: Database name. A falsy value leaves the db unset.
        catalog: Catalog name. A falsy value leaves the catalog unset.
        quoted: Whether to force quotes on the schema's identifiers.

    Returns:
        The new Schema instance, represented as a Table expression without a table name.
    """
    db_identifier = exp.to_identifier(db, quoted=quoted) if db else None
    catalog_identifier = exp.to_identifier(catalog, quoted=quoted) if catalog else None
    return exp.Table(this=None, db=db_identifier, catalog=catalog_identifier)

Build a Schema.

Arguments:
  • db: Database name.
  • catalog: Catalog name.
  • quoted: Whether to force quotes on the schema's identifiers.
Returns:

The new Schema instance.

def normalize_mapping_schema( schema: Dict, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> sqlglot.schema.MappingSchema:
def normalize_mapping_schema(schema: t.Dict, dialect: DialectType) -> MappingSchema:
    """Build a MappingSchema from a schema dictionary without re-normalizing it.

    Args:
        schema: The schema mapping. It is passed through `_unquote_schema` first
            (defined elsewhere in this module; presumably strips quotes from the names).
        dialect: The dialect assigned to the resulting MappingSchema.

    Returns:
        A MappingSchema created with `normalize=False`.
    """
    unquoted = _unquote_schema(schema)
    return MappingSchema(unquoted, dialect=dialect, normalize=False)
@contextmanager
def normalize_and_quote( query: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType], default_catalog: Optional[str], quote: bool = True) -> Iterator[~E]:
@contextmanager
def normalize_and_quote(
    query: E, dialect: DialectType, default_catalog: t.Optional[str], quote: bool = True
) -> t.Iterator[E]:
    """Context manager that normalizes a query on entry and quotes its identifiers on exit.

    On entry, tables are qualified with `default_catalog` and identifiers are normalized for
    `dialect`. The return values of those calls are discarded, so this relies on them updating
    `query` in place. On a normal exit, identifiers are quoted when `quote` is true; if the
    managed block raises, the quoting step does not run.

    Args:
        query: The expression to process.
        dialect: The dialect used for qualification, normalization and quoting.
        default_catalog: The catalog passed to `qualify_tables`.
        quote: Whether to quote identifiers when the block exits.

    Yields:
        The same `query` object.
    """
    qualify_tables(query, catalog=default_catalog, dialect=dialect)
    normalize_identifiers(query, dialect=dialect)

    yield query

    if not quote:
        return
    quote_identifiers(query, dialect=dialect)
def interpret_expression( e: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr | str | int | float | bool:
def interpret_expression(e: exp.Expr) -> exp.Expr | str | int | float | bool:
    """Convert a simple expression into the Python value it represents.

    Args:
        e: The expression to interpret.

    Returns:
        An int for integer expressions, a float for other numeric expressions, the wrapped
        value (`this`) for remaining literals and booleans, and the expression itself otherwise.
    """
    # Integers are checked first so they are not widened to floats.
    if e.is_int:
        return int(e.this)
    if e.is_number:
        return float(e.this)
    return e.this if isinstance(e, (exp.Literal, exp.Boolean)) else e
def interpret_key_value_pairs( e: sqlglot.expressions.query.Tuple) -> Dict[str, sqlglot.expressions.core.Expr | str | int | float | bool]:
def interpret_key_value_pairs(
    e: exp.Tuple,
) -> t.Dict[str, exp.Expr | str | int | float | bool]:
    """Turn a tuple of key-value expressions into a dictionary of interpreted values.

    Args:
        e: The tuple whose entries each expose a key (`this`) and a value (`expression`).

    Returns:
        A mapping from each key's name to its value passed through `interpret_expression`.
    """
    pairs: t.Dict[str, exp.Expr | str | int | float | bool] = {}
    for item in e.expressions:
        key = item.this.name
        pairs[key] = interpret_expression(item.expression)
    return pairs
def extract_func_call( v: sqlglot.expressions.core.Expr, allow_tuples: bool = False) -> Tuple[str, Dict[str, sqlglot.expressions.core.Expr]]:
def extract_func_call(
    v: exp.Expr, allow_tuples: bool = False
) -> t.Tuple[str, t.Dict[str, exp.Expr]]:
    """Split a function-call expression into its lower-cased name and keyword arguments.

    Args:
        v: The expression to inspect. Anonymous and known functions contribute their name and
            arguments; a parenthesized expression or a tuple yields an empty name. Any other
            expression is treated as a bare name without arguments.
        allow_tuples: Whether a tuple (a call without a name) is accepted.

    Returns:
        A `(name, kwargs)` pair where `kwargs` maps lower-cased argument names to their values.

    Raises:
        ConfigError: If a tuple is given while `allow_tuples` is false, or if any argument is
            not a key-value pair.
    """
    if isinstance(v, exp.Anonymous):
        func, call_args = v.name, v.expressions
    elif isinstance(v, exp.Func):
        func, call_args = v.sql_name(), list(v.args.values())
    elif isinstance(v, exp.Paren):
        func, call_args = "", [v.this]
    elif isinstance(v, exp.Tuple):  # airflow only
        if not allow_tuples:
            raise ConfigError("Audit name is missing (eg. MY_AUDIT())")
        func, call_args = "", v.expressions
    else:
        return v.name.lower(), {}

    kwargs: t.Dict[str, exp.Expr] = {}
    for call_arg in call_args:
        if isinstance(call_arg, (exp.PropertyEQ, exp.EQ)):
            kwargs[call_arg.left.name.lower()] = call_arg.right
            continue
        raise ConfigError(
            f"Function '{func}' must be called with key-value arguments like {func}(arg := value)."
        )

    return func.lower(), kwargs
def extract_function_calls(func_calls: Any, allow_tuples: bool = False) -> Any:
def extract_function_calls(func_calls: t.Any, allow_tuples: bool = False) -> t.Any:
    """Used for extracting function calls for signals or audits.

    Args:
        func_calls: The calls to extract. Accepted forms:
            - a list whose entries are dicts (keyword arguments, with the call name stored
              under "name" unless `allow_tuples` is set) or `(name, args)` pairs;
            - a Tuple or Array expression, each element being one call;
            - a parenthesized expression wrapping a single call;
            - any other single expression.
            Any other value is returned as-is, or as an empty list when falsy.
        allow_tuples: Whether calls without a name are accepted. For dict entries this means
            the name is the empty string and every key is treated as an argument.

    Returns:
        A list of `(name, kwargs)` pairs with lower-cased names. For list input, string
        argument values are parsed with `parse_one`.

    Raises:
        ConfigError: If a list entry is neither a dict nor a tuple/list.
    """
    if isinstance(func_calls, list):
        function_calls = []
        for entry in func_calls:
            if isinstance(entry, dict):
                # Work on a shallow copy: popping "name" from the caller's dict would make a
                # second pass over the same configuration fail with a KeyError.
                args = dict(entry)
                name = "" if allow_tuples else args.pop("name")
            elif isinstance(entry, (tuple, list)):
                name, args = entry
            else:
                raise ConfigError(f"Audit must be a dictionary or named tuple. Got {entry}.")

            function_calls.append(
                (
                    name.lower(),
                    {
                        key: parse_one(value) if isinstance(value, str) else value
                        for key, value in args.items()
                    },
                )
            )

        return function_calls

    if isinstance(func_calls, (exp.Tuple, exp.Array)):
        return [extract_func_call(i, allow_tuples=allow_tuples) for i in func_calls.expressions]
    if isinstance(func_calls, exp.Paren):
        return [extract_func_call(func_calls.this, allow_tuples=allow_tuples)]
    if isinstance(func_calls, exp.Expr):
        return [extract_func_call(func_calls, allow_tuples=allow_tuples)]

    return func_calls or []

Used for extracting function calls for signals or audits.

def is_meta_expression(v: Any) -> bool:
def is_meta_expression(v: t.Any) -> bool:
    """Return whether `v` is an Audit, Metric or Model expression."""
    meta_types = (Audit, Metric, Model)
    return isinstance(v, meta_types)
def replace_merge_table_aliases( expression: sqlglot.expressions.core.Expr, dialect: Optional[str] = None) -> sqlglot.expressions.core.Expr:
def replace_merge_table_aliases(expression: exp.Expr, dialect: t.Optional[str] = None) -> exp.Expr:
    """
    Resolves references from the "source" and "target" tables (or their DBT equivalents)
    with the corresponding SQLMesh merge aliases (MERGE_SOURCE_ALIAS and MERGE_TARGET_ALIAS)

    Only Column expressions are affected: the first part of the column is compared
    case-insensitively against the known names and, on a match, replaced in place with the
    quoted merge alias.

    Args:
        expression: The expression to inspect.
        dialect: Not used by this function's body.

    Returns:
        The same `expression` object.
    """
    from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS

    if not isinstance(expression, exp.Column):
        return expression

    first_part = expression.parts[0]
    if not first_part:
        return expression

    prefix = first_part.this.lower()
    if prefix in ("target", "dbt_internal_dest", "__merge_target__"):
        first_part.replace(exp.to_identifier(MERGE_TARGET_ALIAS, quoted=True))
    elif prefix in ("source", "dbt_internal_source", "__merge_source__"):
        first_part.replace(exp.to_identifier(MERGE_SOURCE_ALIAS, quoted=True))

    return expression

Resolves references from the "source" and "target" tables (or their DBT equivalents) with the corresponding SQLMesh merge aliases (MERGE_SOURCE_ALIAS and MERGE_TARGET_ALIAS)