Edit on GitHub

sqlmesh.core.engine_adapter.snowflake

  1from __future__ import annotations
  2
  3import contextlib
  4import logging
  5import typing as t
  6
  7from sqlglot import exp
  8from sqlglot.helper import ensure_list
  9from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
 10from sqlglot.optimizer.qualify_columns import quote_identifiers
 11
 12import sqlmesh.core.constants as c
 13from sqlmesh.core.dialect import to_schema
 14from sqlmesh.core.engine_adapter.mixins import (
 15    GetCurrentCatalogFromFunctionMixin,
 16    ClusteredByMixin,
 17    RowDiffMixin,
 18    GrantsFromInfoSchemaMixin,
 19)
 20from sqlmesh.core.engine_adapter.shared import (
 21    CatalogSupport,
 22    DataObject,
 23    DataObjectType,
 24    SourceQuery,
 25    set_catalog,
 26)
 27from sqlmesh.utils import optional_import, get_source_columns_to_types
 28from sqlmesh.utils.errors import SQLMeshError
 29from sqlmesh.utils.pandas import columns_to_types_from_dtypes
 30
 31logger = logging.getLogger(__name__)
 32snowpark = optional_import("snowflake.snowpark")
 33
 34if t.TYPE_CHECKING:
 35    import pandas as pd
 36
 37    from sqlmesh.core._typing import SchemaName, SessionProperties, TableName
 38    from sqlmesh.core.engine_adapter._typing import (
 39        DF,
 40        Query,
 41        QueryOrDF,
 42        SnowparkSession,
 43    )
 44    from sqlmesh.core.node import IntervalUnit
 45
 46
 47@set_catalog(
 48    override_mapping={
 49        "_get_data_objects": CatalogSupport.REQUIRES_SET_CATALOG,
 50        "create_schema": CatalogSupport.REQUIRES_SET_CATALOG,
 51        "drop_schema": CatalogSupport.REQUIRES_SET_CATALOG,
 52        "drop_catalog": CatalogSupport.REQUIRES_SET_CATALOG,  # needs a catalog to issue a query to information_schema.databases even though the result is global
 53    }
 54)
 55class SnowflakeEngineAdapter(
 56    GetCurrentCatalogFromFunctionMixin, ClusteredByMixin, RowDiffMixin, GrantsFromInfoSchemaMixin
 57):
 58    DIALECT = "snowflake"
 59    SUPPORTS_MATERIALIZED_VIEWS = True
 60    SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True
 61    SUPPORTS_CLONING = True
 62    SUPPORTS_MANAGED_MODELS = True
 63    CURRENT_CATALOG_EXPRESSION = exp.func("current_database")
 64    SUPPORTS_CREATE_DROP_CATALOG = True
 65    SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS = True
 66    SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA", "TABLE"]
 67    SCHEMA_DIFFER_KWARGS = {
 68        "parameterized_type_defaults": {
 69            exp.DataType.build("BINARY", dialect=DIALECT).this: [(8388608,)],
 70            exp.DataType.build("VARBINARY", dialect=DIALECT).this: [(8388608,)],
 71            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(38, 0), (0,)],
 72            exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)],
 73            exp.DataType.build("NCHAR", dialect=DIALECT).this: [(1,)],
 74            exp.DataType.build("VARCHAR", dialect=DIALECT).this: [(16777216,)],
 75            exp.DataType.build("TIME", dialect=DIALECT).this: [(9,)],
 76            exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(9,)],
 77            exp.DataType.build("TIMESTAMP_LTZ", dialect=DIALECT).this: [(9,)],
 78            exp.DataType.build("TIMESTAMP_NTZ", dialect=DIALECT).this: [(9,)],
 79            exp.DataType.build("TIMESTAMP_TZ", dialect=DIALECT).this: [(9,)],
 80        },
 81    }
 82    MANAGED_TABLE_KIND = "DYNAMIC TABLE"
 83    SNOWPARK = "snowpark"
 84    SUPPORTS_QUERY_EXECUTION_TRACKING = True
 85    SUPPORTS_GRANTS = True
 86    CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expr = exp.func("CURRENT_ROLE")
 87    USE_CATALOG_IN_GRANTS = True
 88
 89    @contextlib.contextmanager
 90    def session(self, properties: SessionProperties) -> t.Iterator[None]:
 91        warehouse = properties.get("warehouse")
 92        if not warehouse:
 93            yield
 94            return
 95
 96        if isinstance(warehouse, str):
 97            warehouse = exp.to_identifier(warehouse)
 98        if not isinstance(warehouse, exp.Expr):
 99            raise SQLMeshError(f"Invalid warehouse: '{warehouse}'")
100
101        warehouse_exp = quote_identifiers(
102            normalize_identifiers(warehouse, dialect=self.dialect), dialect=self.dialect
103        )
104        warehouse_sql = warehouse_exp.sql(dialect=self.dialect)
105        current_warehouse_sql = self._current_warehouse.sql(dialect=self.dialect)
106
107        if warehouse_sql == current_warehouse_sql:
108            yield
109            return
110
111        self.execute(f"USE WAREHOUSE {warehouse_sql}")
112        try:
113            yield
114        finally:
115            self.execute(f"USE WAREHOUSE {current_warehouse_sql}")
116
117    @property
118    def _current_warehouse(self) -> exp.Identifier:
119        current_warehouse_str = self.fetchone("SELECT CURRENT_WAREHOUSE()")[0]  # type: ignore
120        # The warehouse value returned by Snowflake is already normalized, so only quoting is needed.
121        return quote_identifiers(exp.to_identifier(current_warehouse_str), dialect=self.dialect)
122
123    @property
124    def snowpark(self) -> t.Optional[SnowparkSession]:
125        if snowpark:
126            if not self._connection_pool.get_attribute(self.SNOWPARK):
127                # Snowpark sessions are not thread safe so we create a session per thread to prevent them from interfering with each other
128                # The sessions are cleaned up when close() is called
129                new_session = snowpark.Session.builder.configs(
130                    {"connection": self._connection_pool.get()}
131                ).create()
132                self._connection_pool.set_attribute(self.SNOWPARK, new_session)
133
134            return self._connection_pool.get_attribute(self.SNOWPARK)
135
136        return None
137
138    @property
139    def catalog_support(self) -> CatalogSupport:
140        return CatalogSupport.FULL_SUPPORT
141
142    @staticmethod
143    def _grant_object_kind(table_type: DataObjectType) -> str:
144        if table_type == DataObjectType.VIEW:
145            return "VIEW"
146        if table_type == DataObjectType.MATERIALIZED_VIEW:
147            return "MATERIALIZED VIEW"
148        if table_type == DataObjectType.MANAGED_TABLE:
149            return "DYNAMIC TABLE"
150        return "TABLE"
151
152    def _get_current_schema(self) -> str:
153        """Returns the current default schema for the connection."""
154        result = self.fetchone("SELECT CURRENT_SCHEMA()")
155        if not result or not result[0]:
156            raise SQLMeshError("Unable to determine current schema")
157        return str(result[0])
158
159    def _create_catalog(self, catalog_name: exp.Identifier) -> None:
160        props = exp.Properties(
161            expressions=[exp.SchemaCommentProperty(this=exp.Literal.string(c.SQLMESH_MANAGED))]
162        )
163        self.execute(
164            exp.Create(
165                this=exp.Table(this=catalog_name), kind="DATABASE", exists=True, properties=props
166            )
167        )
168
169    def _drop_catalog(self, catalog_name: exp.Identifier) -> None:
170        # only drop the catalog if it was created by SQLMesh, which is indicated by its comment matching {c.SQLMESH_MANAGED}
171        exists_check = (
172            exp.select(exp.Literal.number(1))
173            .from_(exp.to_table("information_schema.databases"))
174            .where(
175                exp.and_(
176                    exp.column("database_name").eq(exp.Literal.string(catalog_name)),
177                    exp.column("comment").eq(exp.Literal.string(c.SQLMESH_MANAGED)),
178                )
179            )
180        )
181        normalize_identifiers(exists_check, dialect=self.dialect)
182        if self.fetchone(exists_check, quote_identifiers=True) is not None:
183            self.execute(exp.Drop(this=exp.Table(this=catalog_name), kind="DATABASE", exists=True))
184        else:
185            logger.warning(
186                f"Not dropping database {catalog_name.sql(dialect=self.dialect)} because there is no indication it is '{c.SQLMESH_MANAGED}'"
187            )
188
189    def _create_table(
190        self,
191        table_name_or_schema: t.Union[exp.Schema, TableName],
192        expression: t.Optional[exp.Expr],
193        exists: bool = True,
194        replace: bool = False,
195        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
196        table_description: t.Optional[str] = None,
197        column_descriptions: t.Optional[t.Dict[str, str]] = None,
198        table_kind: t.Optional[str] = None,
199        track_rows_processed: bool = True,
200        **kwargs: t.Any,
201    ) -> None:
202        table_format = kwargs.get("table_format")
203        if table_format and isinstance(table_format, str):
204            table_format = table_format.upper()
205            if not table_kind:
206                table_kind = f"{table_format} TABLE"
207            elif table_kind == self.MANAGED_TABLE_KIND:
208                table_kind = f"DYNAMIC {table_format} TABLE"
209
210        super()._create_table(
211            table_name_or_schema=table_name_or_schema,
212            expression=expression,
213            exists=exists,
214            replace=replace,
215            target_columns_to_types=target_columns_to_types,
216            table_description=table_description,
217            column_descriptions=column_descriptions,
218            table_kind=table_kind,
219            track_rows_processed=False,  # snowflake tracks CTAS row counts incorrectly
220            **kwargs,
221        )
222
223    def create_managed_table(
224        self,
225        table_name: TableName,
226        query: Query,
227        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
228        partitioned_by: t.Optional[t.List[exp.Expr]] = None,
229        clustered_by: t.Optional[t.List[exp.Expr]] = None,
230        table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
231        table_description: t.Optional[str] = None,
232        column_descriptions: t.Optional[t.Dict[str, str]] = None,
233        source_columns: t.Optional[t.List[str]] = None,
234        **kwargs: t.Any,
235    ) -> None:
236        target_table = exp.to_table(table_name)
237
238        # Snowflake defaults to uppercase and it also makes the property presence checks
239        # easier
240        table_properties = {k.upper(): v for k, v in (table_properties or {}).items()}
241
242        # the WAREHOUSE property is required for a Dynamic Table
243        if "WAREHOUSE" not in table_properties:
244            table_properties["WAREHOUSE"] = self._current_warehouse
245
246        # so is TARGET_LAG
247        # ref: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table
248        if "TARGET_LAG" not in table_properties:
249            raise SQLMeshError(
250                "`target_lag` must be specified in the model physical_properties for a Snowflake Dynamic Table"
251            )
252
253        source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types(
254            query, target_columns_to_types, target_table=target_table, source_columns=source_columns
255        )
256
257        self._create_table_from_source_queries(
258            target_table,
259            source_queries,
260            target_columns_to_types,
261            replace=self.SUPPORTS_REPLACE_TABLE,
262            partitioned_by=partitioned_by,
263            clustered_by=clustered_by,
264            table_properties=table_properties,
265            table_description=table_description,
266            column_descriptions=column_descriptions,
267            table_kind=self.MANAGED_TABLE_KIND,
268            **kwargs,
269        )
270
271    def create_view(
272        self,
273        view_name: TableName,
274        query_or_df: QueryOrDF,
275        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
276        replace: bool = True,
277        materialized: bool = False,
278        materialized_properties: t.Optional[t.Dict[str, t.Any]] = None,
279        table_description: t.Optional[str] = None,
280        column_descriptions: t.Optional[t.Dict[str, str]] = None,
281        view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
282        source_columns: t.Optional[t.List[str]] = None,
283        **create_kwargs: t.Any,
284    ) -> None:
285        properties = create_kwargs.pop("properties", None)
286        if not properties:
287            properties = exp.Properties(expressions=[])
288        if replace:
289            properties.append("expressions", exp.CopyGrantsProperty())
290
291        super().create_view(
292            view_name=view_name,
293            query_or_df=query_or_df,
294            target_columns_to_types=target_columns_to_types,
295            replace=replace,
296            materialized=materialized,
297            materialized_properties=materialized_properties,
298            table_description=table_description,
299            column_descriptions=column_descriptions,
300            view_properties=view_properties,
301            properties=properties,
302            source_columns=source_columns,
303            **create_kwargs,
304        )
305
306    def drop_managed_table(self, table_name: TableName, exists: bool = True) -> None:
307        self._drop_object(table_name, exists, kind=self.MANAGED_TABLE_KIND)
308
309    def _build_table_properties_exp(
310        self,
311        catalog_name: t.Optional[str] = None,
312        table_format: t.Optional[str] = None,
313        storage_format: t.Optional[str] = None,
314        partitioned_by: t.Optional[t.List[exp.Expr]] = None,
315        partition_interval_unit: t.Optional[IntervalUnit] = None,
316        clustered_by: t.Optional[t.List[exp.Expr]] = None,
317        table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
318        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
319        table_description: t.Optional[str] = None,
320        table_kind: t.Optional[str] = None,
321        **kwargs: t.Any,
322    ) -> t.Optional[exp.Properties]:
323        properties: t.List[exp.Expr] = []
324
325        # TODO: there is some overlap with the base class and other engine adapters
326        # we need a better way of filtering table properties relevent to the current engine
327        # and using those to build the expression
328        if table_description:
329            properties.append(
330                exp.SchemaCommentProperty(
331                    this=exp.Literal.string(self._truncate_table_comment(table_description))
332                )
333            )
334
335        if (
336            clustered_by
337            and (clustered_by_prop := self._build_clustered_by_exp(clustered_by)) is not None
338        ):
339            properties.append(clustered_by_prop)
340
341        if table_properties:
342            table_properties = {k.upper(): v for k, v in table_properties.items()}
343            # if we are creating a non-dynamic table; remove any properties that are only valid for dynamic tables
344            # this is necessary because we create "normal" tables from the same managed model definition for dev previews and the "normal" tables dont support these parameters
345            if "DYNAMIC" not in (table_kind or "").upper():
346                for prop in {"WAREHOUSE", "TARGET_LAG", "REFRESH_MODE", "INITIALIZE"}:
347                    table_properties.pop(prop, None)
348
349            table_type = self._pop_creatable_type_from_properties(table_properties)
350            properties.extend(ensure_list(table_type))
351
352            properties.extend(self._table_or_view_properties_to_expressions(table_properties))
353
354        return exp.Properties(expressions=properties) if properties else None
355
356    def _df_to_source_queries(
357        self,
358        df: DF,
359        target_columns_to_types: t.Dict[str, exp.DataType],
360        batch_size: int,
361        target_table: TableName,
362        source_columns: t.Optional[t.List[str]] = None,
363    ) -> t.List[SourceQuery]:
364        import pandas as pd
365        from pandas.api.types import is_datetime64_any_dtype
366
367        source_columns_to_types = get_source_columns_to_types(
368            target_columns_to_types, source_columns
369        )
370
371        temp_table = self._get_temp_table(
372            target_table or "pandas", quoted=False
373        )  # write_pandas() re-quotes everything without checking if its already quoted
374
375        is_snowpark_dataframe = snowpark and isinstance(df, snowpark.dataframe.DataFrame)
376
377        def query_factory() -> Query:
378            # The catalog needs to be normalized before being passed to Snowflake's library functions because they
379            # just wrap whatever they are given in quotes without checking if its already quoted
380            database = (
381                normalize_identifiers(temp_table.catalog, dialect=self.dialect)
382                if temp_table.catalog
383                else None
384            )
385
386            if is_snowpark_dataframe:
387                temp_table.set("catalog", database)
388
389                # only quote columns if they arent already quoted
390                # if the Snowpark dataframe was created from a Pandas dataframe via snowpark.create_dataframe(pandas_df),
391                # then they will be quoted already. But if the Snowpark dataframe was created manually by the user, then the
392                # columns may not be quoted
393                columns_already_quoted = all(
394                    col.startswith('"') and col.endswith('"') for col in df.columns
395                )
396                local_df = df
397                if not columns_already_quoted:
398                    local_df = df.rename(
399                        {
400                            col: exp.to_identifier(col).sql(dialect=self.dialect, identify=True)
401                            for col in source_columns_to_types
402                        }
403                    )  # type: ignore
404                local_df.createOrReplaceTempView(
405                    temp_table.sql(dialect=self.dialect, identify=True)
406                )  # type: ignore
407            elif isinstance(df, pd.DataFrame):
408                from snowflake.connector.pandas_tools import write_pandas
409
410                ordered_df = df[list(source_columns_to_types)]
411
412                # Workaround for https://github.com/snowflakedb/snowflake-connector-python/issues/1034
413                # The above issue has already been fixed upstream, but we keep the following
414                # line anyway in order to support a wider range of Snowflake versions.
415                schema = temp_table.db
416                if temp_table.catalog:
417                    schema = f"{temp_table.catalog}.{schema}"
418                self.set_current_schema(schema)
419
420                # See: https://stackoverflow.com/a/75627721
421                for column, kind in source_columns_to_types.items():
422                    if is_datetime64_any_dtype(ordered_df.dtypes[column]):
423                        if kind.is_type("date"):  # type: ignore
424                            ordered_df[column] = pd.to_datetime(ordered_df[column]).dt.date  # type: ignore
425                        elif getattr(ordered_df.dtypes[column], "tz", None) is not None:  # type: ignore
426                            ordered_df[column] = pd.to_datetime(ordered_df[column]).dt.strftime(
427                                "%Y-%m-%d %H:%M:%S.%f%z"
428                            )  # type: ignore
429                        # https://github.com/snowflakedb/snowflake-connector-python/issues/1677
430                        else:  # type: ignore
431                            ordered_df[column] = pd.to_datetime(ordered_df[column]).dt.strftime(
432                                "%Y-%m-%d %H:%M:%S.%f"
433                            )  # type: ignore
434
435                # create the table first using our usual method ensure the column datatypes match what we parsed with sqlglot
436                # otherwise we would be trusting `write_pandas()` from the snowflake lib to do this correctly
437                self.create_table(temp_table, source_columns_to_types, table_kind="TEMPORARY TABLE")
438
439                write_pandas(
440                    self._connection_pool.get(),
441                    ordered_df,
442                    temp_table.name,
443                    schema=temp_table.db or None,
444                    database=database.sql(dialect=self.dialect) if database else None,
445                    chunk_size=self.DEFAULT_BATCH_SIZE,
446                    overwrite=True,
447                    table_type="temp",
448                )
449            else:
450                raise SQLMeshError(
451                    f"Unknown dataframe type: {type(df)} for {target_table}. Expecting pandas or snowpark."
452                )
453
454            return exp.select(
455                *self._casted_columns(target_columns_to_types, source_columns=source_columns)
456            ).from_(temp_table)
457
458        def cleanup() -> None:
459            if is_snowpark_dataframe:
460                if hasattr(df, "table_name"):
461                    if isinstance(df.table_name, str):
462                        # created by the Snowpark library if the Snowpark DataFrame was created from a Pandas DataFrame
463                        # (if the Snowpark DataFrame was created via native means then there is no 'table_name' property and no temp table)
464                        self.drop_table(df.table_name)
465                self.drop_view(temp_table)
466            else:
467                self.drop_table(temp_table)
468
469        # the cleanup_func technically isnt needed because the temp table gets dropped when the session ends
470        # but boy does it make our multi-adapter integration tests easier to write
471        return [SourceQuery(query_factory=query_factory, cleanup_func=cleanup)]
472
473    def _fetch_native_df(
474        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
475    ) -> DF:
476        import pandas as pd
477        from snowflake.connector.errors import NotSupportedError
478
479        self.execute(query, quote_identifiers=quote_identifiers)
480
481        try:
482            return self.cursor.fetch_pandas_all()
483        except NotSupportedError:
484            # Sometimes Snowflake will not return results as an Arrow result and the fetch from
485            # pandas will fail (Ex: `SHOW TERSE OBJECTS IN SCHEMA`). Therefore we manually convert
486            # the result into a DataFrame when this happens.
487            rows = self.cursor.fetchall()
488            columns = self.cursor._result_set.batches[0].column_names
489            return pd.DataFrame([dict(zip(columns, row)) for row in rows])
490
491    def _native_df_to_pandas_df(
492        self,
493        query_or_df: QueryOrDF,
494    ) -> t.Union[Query, pd.DataFrame]:
495        if snowpark and isinstance(query_or_df, snowpark.DataFrame):
496            return query_or_df.to_pandas()
497
498        return super()._native_df_to_pandas_df(query_or_df)
499
500    def _get_data_objects(
501        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
502    ) -> t.List[DataObject]:
503        """
504        Returns all the data objects that exist in the given schema and optionally catalog.
505        """
506
507        schema = to_schema(schema_name)
508        catalog_name = schema.catalog or self.get_current_catalog()
509
510        query = (
511            exp.select(
512                exp.column("TABLE_CATALOG").as_("catalog"),
513                exp.column("TABLE_NAME").as_("name"),
514                exp.column("TABLE_SCHEMA").as_("schema_name"),
515                exp.case()
516                .when(
517                    exp.And(
518                        this=exp.column("TABLE_TYPE").eq("BASE TABLE"),
519                        expression=exp.column("IS_DYNAMIC").eq("YES"),
520                    ),
521                    exp.Literal.string("MANAGED_TABLE"),
522                )
523                .when(exp.column("TABLE_TYPE").eq("BASE TABLE"), exp.Literal.string("TABLE"))
524                .when(exp.column("TABLE_TYPE").eq("TEMPORARY TABLE"), exp.Literal.string("TABLE"))
525                .when(exp.column("TABLE_TYPE").eq("LOCAL TEMPORARY"), exp.Literal.string("TABLE"))
526                .when(exp.column("TABLE_TYPE").eq("EXTERNAL TABLE"), exp.Literal.string("TABLE"))
527                .when(exp.column("TABLE_TYPE").eq("EVENT TABLE"), exp.Literal.string("TABLE"))
528                .when(exp.column("TABLE_TYPE").eq("VIEW"), exp.Literal.string("VIEW"))
529                .when(
530                    exp.column("TABLE_TYPE").eq("MATERIALIZED VIEW"),
531                    exp.Literal.string("MATERIALIZED_VIEW"),
532                )
533                .else_(exp.column("TABLE_TYPE"))
534                .as_("type"),
535                exp.column("CLUSTERING_KEY").as_("clustering_key"),
536            )
537            .from_(exp.table_("TABLES", db="INFORMATION_SCHEMA", catalog=catalog_name))
538            .where(exp.column("TABLE_SCHEMA").eq(schema.db))
539            # Snowflake seems to have delayed internal metadata updates and will sometimes return duplicates
540            .distinct()
541        )
542        if object_names:
543            query = query.where(exp.column("TABLE_NAME").isin(*object_names))
544
545        # exclude SNOWPARK_TEMP_TABLE tables that are managed by the Snowpark library and are an implementation
546        # detail of dealing with DataFrame's
547        query = query.where(exp.column("TABLE_NAME").like("SNOWPARK_TEMP_TABLE%").not_())
548
549        df = self.fetchdf(query, quote_identifiers=True)
550        if df.empty:
551            return []
552        return [
553            DataObject(
554                catalog=row.catalog,  # type: ignore
555                schema=row.schema_name,  # type: ignore
556                name=row.name,  # type: ignore
557                type=DataObjectType.from_str(row.type),  # type: ignore
558                clustering_key=row.clustering_key,  # type: ignore
559            )
560            # lowercase the column names for cases where Snowflake might return uppercase column names for certain catalogs
561            for row in df.rename(columns={col: col.lower() for col in df.columns}).itertuples()
562        ]
563
564    def _get_grant_expression(self, table: exp.Table) -> exp.Expr:
565        # Upon execute the catalog in table expressions are properly normalized to handle the case where a user provides
566        # the default catalog in their connection config. This doesn't though update catalogs in strings like when querying
567        # the information schema. So we need to manually replace those here.
568        expression = super()._get_grant_expression(table)
569        for col_exp in expression.find_all(exp.Column):
570            if col_exp.this.name == "table_catalog":
571                and_exp = col_exp.parent
572                assert and_exp is not None, "Expected column expression to have a parent"
573                assert and_exp.expression, "Expected AND expression to have an expression"
574                normalized_catalog = self._normalize_catalog(
575                    exp.table_("placeholder", db="placeholder", catalog=and_exp.expression.this)
576                )
577                and_exp.set(
578                    "expression",
579                    exp.Literal.string(normalized_catalog.args["catalog"].alias_or_name),
580                )
581        return expression
582
583    def set_current_catalog(self, catalog: str) -> None:
584        self.execute(exp.Use(this=exp.to_identifier(catalog)))
585
586    def set_current_schema(self, schema: str) -> None:
587        self.execute(exp.Use(kind="SCHEMA", this=to_schema(schema)))
588
589    def _normalize_catalog(self, expression: exp.Expr) -> exp.Expr:
590        # note: important to use self._default_catalog instead of the self.default_catalog property
591        # otherwise we get RecursionError: maximum recursion depth exceeded
592        # because it calls get_current_catalog(), which executes a query, which needs the default catalog, which calls get_current_catalog()... etc
593        if self._default_catalog:
594            # the purpose of this function is to identify instances where the default catalog is being used
595            # (so that we can replace it with the actual catalog as specified in the gateway)
596            #
597            # we can't do a direct string comparison because the catalog value on the model
598            # gets changed when it's normalized as part of generating `model.fqn`
599            def unquote_and_lower(identifier: str) -> str:
600                return exp.parse_identifier(identifier).name.lower()
601
602            default_catalog_unquoted = unquote_and_lower(self._default_catalog)
603            default_catalog_normalized = normalize_identifiers(
604                self._default_catalog, dialect=self.dialect
605            )
606
607            def catalog_rewriter(node: exp.Expr) -> exp.Expr:
608                if isinstance(node, exp.Table):
609                    if node.catalog:
610                        # only replace the catalog on the model with the target catalog if the two are functionally equivalent
611                        if unquote_and_lower(node.catalog) == default_catalog_unquoted:
612                            node.set("catalog", default_catalog_normalized)
613                elif isinstance(node, exp.Use) and isinstance(node.this, exp.Identifier):
614                    if unquote_and_lower(node.this.output_name) == default_catalog_unquoted:
615                        node.set("this", default_catalog_normalized)
616                return node
617
618            # Rewrite whatever default catalog is present on the query to be compatible with what the user supplied in the
619            # Snowflake connection config. This is because the catalog present on the model gets normalized and quoted to match
620            # the source dialect, which isnt always compatible with Snowflake
621            expression = expression.transform(catalog_rewriter)
622        return expression
623
624    def _to_sql(self, expression: exp.Expr, quote: bool = True, **kwargs: t.Any) -> str:
625        return super()._to_sql(
626            expression=self._normalize_catalog(expression), quote=quote, **kwargs
627        )
628
629    def _create_column_comments(
630        self,
631        table_name: TableName,
632        column_comments: t.Dict[str, str],
633        table_kind: str = "TABLE",
634        materialized_view: bool = False,
635    ) -> None:
636        """
637        Reference: https://docs.snowflake.com/en/sql-reference/sql/alter-table-column#syntax
638        """
639        if not column_comments:
640            return
641
642        table = exp.to_table(table_name)
643        table_sql = self._to_sql(table)
644
645        list_comment_sql = []
646        for column_name, column_comment in column_comments.items():
647            column_sql = exp.column(column_name).sql(dialect=self.dialect, identify=True)
648
649            truncated_comment = self._truncate_column_comment(column_comment)
650            comment_sql = exp.Literal.string(truncated_comment).sql(dialect=self.dialect)
651
652            list_comment_sql.append(f"COLUMN {column_sql} COMMENT {comment_sql}")
653
654        combined_sql = f"ALTER {table_kind} {table_sql} ALTER {', '.join(list_comment_sql)}"
655        try:
656            self.execute(combined_sql)
657        except Exception:
658            logger.warning(
659                f"Column comments for table '{table.alias_or_name}' not registered - this may be due to limited permissions.",
660                exc_info=True,
661            )
662
663    def clone_table(
664        self,
665        target_table_name: TableName,
666        source_table_name: TableName,
667        replace: bool = False,
668        exists: bool = True,
669        clone_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
670        **kwargs: t.Any,
671    ) -> None:
672        # The Snowflake adapter should use the transient property to clone transient tables
673        if physical_properties := kwargs.get("rendered_physical_properties"):
674            table_type = self._pop_creatable_type_from_properties(physical_properties)
675            if isinstance(table_type, exp.TransientProperty):
676                kwargs["properties"] = exp.Properties(expressions=[table_type])
677
678        super().clone_table(
679            target_table_name,
680            source_table_name,
681            replace=replace,
682            clone_kwargs=clone_kwargs,
683            **kwargs,
684        )
685
686    @t.overload
687    def _columns_to_types(
688        self,
689        query_or_df: DF,
690        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
691        source_columns: t.Optional[t.List[str]] = None,
692    ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ...
693
694    @t.overload
695    def _columns_to_types(
696        self,
697        query_or_df: Query,
698        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
699        source_columns: t.Optional[t.List[str]] = None,
700    ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ...
701
702    def _columns_to_types(
703        self,
704        query_or_df: QueryOrDF,
705        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
706        source_columns: t.Optional[t.List[str]] = None,
707    ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]:
708        if not target_columns_to_types and snowpark and isinstance(query_or_df, snowpark.DataFrame):
709            target_columns_to_types = columns_to_types_from_dtypes(
710                query_or_df.sample(n=1).to_pandas().dtypes.items()
711            )
712            return target_columns_to_types, list(source_columns or target_columns_to_types)
713
714        return super()._columns_to_types(
715            query_or_df, target_columns_to_types, source_columns=source_columns
716        )
717
718    def close(self) -> t.Any:
719        if snowpark_session := self._connection_pool.get_attribute(self.SNOWPARK):
720            snowpark_session.close()  # type: ignore
721            self._connection_pool.set_attribute(self.SNOWPARK, None)
722
723        return super().close()
724
725    def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]:
726        from sqlmesh.utils.date import to_timestamp
727
728        num_tables = len(table_names)
729
730        query = "SELECT LAST_ALTERED FROM INFORMATION_SCHEMA.TABLES WHERE"
731        for i, table_name in enumerate(table_names):
732            table = exp.to_table(table_name)
733            query += f"""(TABLE_NAME = '{table.name}' AND TABLE_SCHEMA = '{table.db}' AND TABLE_CATALOG = '{table.catalog}')"""
734            if i < num_tables - 1:
735                query += " OR "
736
737        result = self.fetchall(query)
738        return [to_timestamp(row[0]) for row in result]
logger = <Logger sqlmesh.core.engine_adapter.snowflake (WARNING)>
@set_catalog(override_mapping={'_get_data_objects': CatalogSupport.REQUIRES_SET_CATALOG, 'create_schema': CatalogSupport.REQUIRES_SET_CATALOG, 'drop_schema': CatalogSupport.REQUIRES_SET_CATALOG, 'drop_catalog': CatalogSupport.REQUIRES_SET_CATALOG})
class SnowflakeEngineAdapter(sqlmesh.core.engine_adapter.mixins.GetCurrentCatalogFromFunctionMixin, sqlmesh.core.engine_adapter.mixins.ClusteredByMixin, sqlmesh.core.engine_adapter.mixins.RowDiffMixin, sqlmesh.core.engine_adapter.mixins.GrantsFromInfoSchemaMixin):
 48@set_catalog(
 49    override_mapping={
 50        "_get_data_objects": CatalogSupport.REQUIRES_SET_CATALOG,
 51        "create_schema": CatalogSupport.REQUIRES_SET_CATALOG,
 52        "drop_schema": CatalogSupport.REQUIRES_SET_CATALOG,
 53        "drop_catalog": CatalogSupport.REQUIRES_SET_CATALOG,  # needs a catalog to issue a query to information_schema.databases even though the result is global
 54    }
 55)
 56class SnowflakeEngineAdapter(
 57    GetCurrentCatalogFromFunctionMixin, ClusteredByMixin, RowDiffMixin, GrantsFromInfoSchemaMixin
 58):
    # sqlglot dialect name; also used below to build the type keys of SCHEMA_DIFFER_KWARGS.
    DIALECT = "snowflake"
    # Capability flags. They are read by code outside this class body (consumers not shown here).
    SUPPORTS_MATERIALIZED_VIEWS = True
    SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True
    SUPPORTS_CLONING = True
    SUPPORTS_MANAGED_MODELS = True
    # Expression that evaluates to the current catalog (the CURRENT_DATABASE function).
    CURRENT_CATALOG_EXPRESSION = exp.func("current_database")
    SUPPORTS_CREATE_DROP_CATALOG = True
    SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS = True
    # Object kinds for which a cascading drop is listed as supported.
    SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA", "TABLE"]
    # Parameter tuples keyed by data type.
    # NOTE(review): presumably the parameters implied when a type is declared without
    # explicit ones, so the schema differ treats both spellings as equal - confirm.
    SCHEMA_DIFFER_KWARGS = {
        "parameterized_type_defaults": {
            exp.DataType.build("BINARY", dialect=DIALECT).this: [(8388608,)],
            exp.DataType.build("VARBINARY", dialect=DIALECT).this: [(8388608,)],
            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(38, 0), (0,)],
            exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)],
            exp.DataType.build("NCHAR", dialect=DIALECT).this: [(1,)],
            exp.DataType.build("VARCHAR", dialect=DIALECT).this: [(16777216,)],
            exp.DataType.build("TIME", dialect=DIALECT).this: [(9,)],
            exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(9,)],
            exp.DataType.build("TIMESTAMP_LTZ", dialect=DIALECT).this: [(9,)],
            exp.DataType.build("TIMESTAMP_NTZ", dialect=DIALECT).this: [(9,)],
            exp.DataType.build("TIMESTAMP_TZ", dialect=DIALECT).this: [(9,)],
        },
    }
    # Object kind keyword used when creating and dropping managed tables.
    MANAGED_TABLE_KIND = "DYNAMIC TABLE"
    # Connection-pool attribute name under which the Snowpark session is stored.
    SNOWPARK = "snowpark"
    SUPPORTS_QUERY_EXECUTION_TRACKING = True
    SUPPORTS_GRANTS = True
    # Expression that evaluates to the current role (the CURRENT_ROLE function).
    CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expr = exp.func("CURRENT_ROLE")
    USE_CATALOG_IN_GRANTS = True
 89
 90    @contextlib.contextmanager
 91    def session(self, properties: SessionProperties) -> t.Iterator[None]:
 92        warehouse = properties.get("warehouse")
 93        if not warehouse:
 94            yield
 95            return
 96
 97        if isinstance(warehouse, str):
 98            warehouse = exp.to_identifier(warehouse)
 99        if not isinstance(warehouse, exp.Expr):
100            raise SQLMeshError(f"Invalid warehouse: '{warehouse}'")
101
102        warehouse_exp = quote_identifiers(
103            normalize_identifiers(warehouse, dialect=self.dialect), dialect=self.dialect
104        )
105        warehouse_sql = warehouse_exp.sql(dialect=self.dialect)
106        current_warehouse_sql = self._current_warehouse.sql(dialect=self.dialect)
107
108        if warehouse_sql == current_warehouse_sql:
109            yield
110            return
111
112        self.execute(f"USE WAREHOUSE {warehouse_sql}")
113        try:
114            yield
115        finally:
116            self.execute(f"USE WAREHOUSE {current_warehouse_sql}")
117
118    @property
119    def _current_warehouse(self) -> exp.Identifier:
120        current_warehouse_str = self.fetchone("SELECT CURRENT_WAREHOUSE()")[0]  # type: ignore
121        # The warehouse value returned by Snowflake is already normalized, so only quoting is needed.
122        return quote_identifiers(exp.to_identifier(current_warehouse_str), dialect=self.dialect)
123
124    @property
125    def snowpark(self) -> t.Optional[SnowparkSession]:
126        if snowpark:
127            if not self._connection_pool.get_attribute(self.SNOWPARK):
128                # Snowpark sessions are not thread safe so we create a session per thread to prevent them from interfering with each other
129                # The sessions are cleaned up when close() is called
130                new_session = snowpark.Session.builder.configs(
131                    {"connection": self._connection_pool.get()}
132                ).create()
133                self._connection_pool.set_attribute(self.SNOWPARK, new_session)
134
135            return self._connection_pool.get_attribute(self.SNOWPARK)
136
137        return None
138
139    @property
140    def catalog_support(self) -> CatalogSupport:
141        return CatalogSupport.FULL_SUPPORT
142
143    @staticmethod
144    def _grant_object_kind(table_type: DataObjectType) -> str:
145        if table_type == DataObjectType.VIEW:
146            return "VIEW"
147        if table_type == DataObjectType.MATERIALIZED_VIEW:
148            return "MATERIALIZED VIEW"
149        if table_type == DataObjectType.MANAGED_TABLE:
150            return "DYNAMIC TABLE"
151        return "TABLE"
152
153    def _get_current_schema(self) -> str:
154        """Returns the current default schema for the connection."""
155        result = self.fetchone("SELECT CURRENT_SCHEMA()")
156        if not result or not result[0]:
157            raise SQLMeshError("Unable to determine current schema")
158        return str(result[0])
159
160    def _create_catalog(self, catalog_name: exp.Identifier) -> None:
161        props = exp.Properties(
162            expressions=[exp.SchemaCommentProperty(this=exp.Literal.string(c.SQLMESH_MANAGED))]
163        )
164        self.execute(
165            exp.Create(
166                this=exp.Table(this=catalog_name), kind="DATABASE", exists=True, properties=props
167            )
168        )
169
170    def _drop_catalog(self, catalog_name: exp.Identifier) -> None:
171        # only drop the catalog if it was created by SQLMesh, which is indicated by its comment matching {c.SQLMESH_MANAGED}
172        exists_check = (
173            exp.select(exp.Literal.number(1))
174            .from_(exp.to_table("information_schema.databases"))
175            .where(
176                exp.and_(
177                    exp.column("database_name").eq(exp.Literal.string(catalog_name)),
178                    exp.column("comment").eq(exp.Literal.string(c.SQLMESH_MANAGED)),
179                )
180            )
181        )
182        normalize_identifiers(exists_check, dialect=self.dialect)
183        if self.fetchone(exists_check, quote_identifiers=True) is not None:
184            self.execute(exp.Drop(this=exp.Table(this=catalog_name), kind="DATABASE", exists=True))
185        else:
186            logger.warning(
187                f"Not dropping database {catalog_name.sql(dialect=self.dialect)} because there is no indication it is '{c.SQLMESH_MANAGED}'"
188            )
189
190    def _create_table(
191        self,
192        table_name_or_schema: t.Union[exp.Schema, TableName],
193        expression: t.Optional[exp.Expr],
194        exists: bool = True,
195        replace: bool = False,
196        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
197        table_description: t.Optional[str] = None,
198        column_descriptions: t.Optional[t.Dict[str, str]] = None,
199        table_kind: t.Optional[str] = None,
200        track_rows_processed: bool = True,
201        **kwargs: t.Any,
202    ) -> None:
203        table_format = kwargs.get("table_format")
204        if table_format and isinstance(table_format, str):
205            table_format = table_format.upper()
206            if not table_kind:
207                table_kind = f"{table_format} TABLE"
208            elif table_kind == self.MANAGED_TABLE_KIND:
209                table_kind = f"DYNAMIC {table_format} TABLE"
210
211        super()._create_table(
212            table_name_or_schema=table_name_or_schema,
213            expression=expression,
214            exists=exists,
215            replace=replace,
216            target_columns_to_types=target_columns_to_types,
217            table_description=table_description,
218            column_descriptions=column_descriptions,
219            table_kind=table_kind,
220            track_rows_processed=False,  # snowflake tracks CTAS row counts incorrectly
221            **kwargs,
222        )
223
224    def create_managed_table(
225        self,
226        table_name: TableName,
227        query: Query,
228        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
229        partitioned_by: t.Optional[t.List[exp.Expr]] = None,
230        clustered_by: t.Optional[t.List[exp.Expr]] = None,
231        table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
232        table_description: t.Optional[str] = None,
233        column_descriptions: t.Optional[t.Dict[str, str]] = None,
234        source_columns: t.Optional[t.List[str]] = None,
235        **kwargs: t.Any,
236    ) -> None:
237        target_table = exp.to_table(table_name)
238
239        # Snowflake defaults to uppercase and it also makes the property presence checks
240        # easier
241        table_properties = {k.upper(): v for k, v in (table_properties or {}).items()}
242
243        # the WAREHOUSE property is required for a Dynamic Table
244        if "WAREHOUSE" not in table_properties:
245            table_properties["WAREHOUSE"] = self._current_warehouse
246
247        # so is TARGET_LAG
248        # ref: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table
249        if "TARGET_LAG" not in table_properties:
250            raise SQLMeshError(
251                "`target_lag` must be specified in the model physical_properties for a Snowflake Dynamic Table"
252            )
253
254        source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types(
255            query, target_columns_to_types, target_table=target_table, source_columns=source_columns
256        )
257
258        self._create_table_from_source_queries(
259            target_table,
260            source_queries,
261            target_columns_to_types,
262            replace=self.SUPPORTS_REPLACE_TABLE,
263            partitioned_by=partitioned_by,
264            clustered_by=clustered_by,
265            table_properties=table_properties,
266            table_description=table_description,
267            column_descriptions=column_descriptions,
268            table_kind=self.MANAGED_TABLE_KIND,
269            **kwargs,
270        )
271
272    def create_view(
273        self,
274        view_name: TableName,
275        query_or_df: QueryOrDF,
276        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
277        replace: bool = True,
278        materialized: bool = False,
279        materialized_properties: t.Optional[t.Dict[str, t.Any]] = None,
280        table_description: t.Optional[str] = None,
281        column_descriptions: t.Optional[t.Dict[str, str]] = None,
282        view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
283        source_columns: t.Optional[t.List[str]] = None,
284        **create_kwargs: t.Any,
285    ) -> None:
286        properties = create_kwargs.pop("properties", None)
287        if not properties:
288            properties = exp.Properties(expressions=[])
289        if replace:
290            properties.append("expressions", exp.CopyGrantsProperty())
291
292        super().create_view(
293            view_name=view_name,
294            query_or_df=query_or_df,
295            target_columns_to_types=target_columns_to_types,
296            replace=replace,
297            materialized=materialized,
298            materialized_properties=materialized_properties,
299            table_description=table_description,
300            column_descriptions=column_descriptions,
301            view_properties=view_properties,
302            properties=properties,
303            source_columns=source_columns,
304            **create_kwargs,
305        )
306
307    def drop_managed_table(self, table_name: TableName, exists: bool = True) -> None:
308        self._drop_object(table_name, exists, kind=self.MANAGED_TABLE_KIND)
309
310    def _build_table_properties_exp(
311        self,
312        catalog_name: t.Optional[str] = None,
313        table_format: t.Optional[str] = None,
314        storage_format: t.Optional[str] = None,
315        partitioned_by: t.Optional[t.List[exp.Expr]] = None,
316        partition_interval_unit: t.Optional[IntervalUnit] = None,
317        clustered_by: t.Optional[t.List[exp.Expr]] = None,
318        table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
319        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
320        table_description: t.Optional[str] = None,
321        table_kind: t.Optional[str] = None,
322        **kwargs: t.Any,
323    ) -> t.Optional[exp.Properties]:
324        properties: t.List[exp.Expr] = []
325
326        # TODO: there is some overlap with the base class and other engine adapters
327        # we need a better way of filtering table properties relevent to the current engine
328        # and using those to build the expression
329        if table_description:
330            properties.append(
331                exp.SchemaCommentProperty(
332                    this=exp.Literal.string(self._truncate_table_comment(table_description))
333                )
334            )
335
336        if (
337            clustered_by
338            and (clustered_by_prop := self._build_clustered_by_exp(clustered_by)) is not None
339        ):
340            properties.append(clustered_by_prop)
341
342        if table_properties:
343            table_properties = {k.upper(): v for k, v in table_properties.items()}
344            # if we are creating a non-dynamic table; remove any properties that are only valid for dynamic tables
345            # this is necessary because we create "normal" tables from the same managed model definition for dev previews and the "normal" tables dont support these parameters
346            if "DYNAMIC" not in (table_kind or "").upper():
347                for prop in {"WAREHOUSE", "TARGET_LAG", "REFRESH_MODE", "INITIALIZE"}:
348                    table_properties.pop(prop, None)
349
350            table_type = self._pop_creatable_type_from_properties(table_properties)
351            properties.extend(ensure_list(table_type))
352
353            properties.extend(self._table_or_view_properties_to_expressions(table_properties))
354
355        return exp.Properties(expressions=properties) if properties else None
356
    def _df_to_source_queries(
        self,
        df: DF,
        target_columns_to_types: t.Dict[str, exp.DataType],
        batch_size: int,
        target_table: TableName,
        source_columns: t.Optional[t.List[str]] = None,
    ) -> t.List[SourceQuery]:
        """Stage a pandas or Snowpark DataFrame in Snowflake and expose it as a source query.

        Nothing is written when this method returns. The staging happens when the returned
        ``SourceQuery``'s ``query_factory`` is invoked: a Snowpark DataFrame is registered as
        a temporary view, a pandas DataFrame is uploaded into a temporary table with
        ``write_pandas``. The factory then returns a SELECT of the casted target columns
        from that temporary object. The ``cleanup_func`` drops the temporary object again.

        Args:
            df: The pandas or Snowpark DataFrame to load.
            target_columns_to_types: Column names and types of the target table.
            batch_size: Not referenced in this method; ``write_pandas`` is called with
                ``self.DEFAULT_BATCH_SIZE`` as its chunk size.
            target_table: Table the data is destined for; used to derive the temp table name.
            source_columns: Optional names of the columns present in ``df``.

        Returns:
            A single-element list containing the ``SourceQuery``.

        Raises:
            SQLMeshError: From the query factory, if ``df`` is neither a pandas nor a
                Snowpark DataFrame.
        """
        import pandas as pd
        from pandas.api.types import is_datetime64_any_dtype

        source_columns_to_types = get_source_columns_to_types(
            target_columns_to_types, source_columns
        )

        temp_table = self._get_temp_table(
            target_table or "pandas", quoted=False
        )  # write_pandas() re-quotes everything without checking if it's already quoted

        # Falsy when the Snowpark library is not installed (`snowpark` is then None).
        is_snowpark_dataframe = snowpark and isinstance(df, snowpark.dataframe.DataFrame)

        def query_factory() -> Query:
            # The catalog needs to be normalized before being passed to Snowflake's library functions because they
            # just wrap whatever they are given in quotes without checking if it's already quoted
            database = (
                normalize_identifiers(temp_table.catalog, dialect=self.dialect)
                if temp_table.catalog
                else None
            )

            if is_snowpark_dataframe:
                temp_table.set("catalog", database)

                # only quote columns if they aren't already quoted
                # if the Snowpark dataframe was created from a Pandas dataframe via snowpark.create_dataframe(pandas_df),
                # then they will be quoted already. But if the Snowpark dataframe was created manually by the user, then the
                # columns may not be quoted
                columns_already_quoted = all(
                    col.startswith('"') and col.endswith('"') for col in df.columns
                )
                local_df = df
                if not columns_already_quoted:
                    local_df = df.rename(
                        {
                            col: exp.to_identifier(col).sql(dialect=self.dialect, identify=True)
                            for col in source_columns_to_types
                        }
                    )  # type: ignore
                # Register the (possibly renamed) DataFrame as a temp view named after temp_table.
                local_df.createOrReplaceTempView(
                    temp_table.sql(dialect=self.dialect, identify=True)
                )  # type: ignore
            elif isinstance(df, pd.DataFrame):
                from snowflake.connector.pandas_tools import write_pandas

                # Select only the source columns, in source_columns_to_types order.
                ordered_df = df[list(source_columns_to_types)]

                # Workaround for https://github.com/snowflakedb/snowflake-connector-python/issues/1034
                # The above issue has already been fixed upstream, but we keep the following
                # line anyway in order to support a wider range of Snowflake versions.
                schema = temp_table.db
                if temp_table.catalog:
                    schema = f"{temp_table.catalog}.{schema}"
                self.set_current_schema(schema)

                # Convert datetime columns before upload: DATE targets become date objects,
                # all other datetime columns become formatted strings (with the UTC offset
                # included when the column is timezone-aware).
                # See: https://stackoverflow.com/a/75627721
                for column, kind in source_columns_to_types.items():
                    if is_datetime64_any_dtype(ordered_df.dtypes[column]):
                        if kind.is_type("date"):  # type: ignore
                            ordered_df[column] = pd.to_datetime(ordered_df[column]).dt.date  # type: ignore
                        elif getattr(ordered_df.dtypes[column], "tz", None) is not None:  # type: ignore
                            ordered_df[column] = pd.to_datetime(ordered_df[column]).dt.strftime(
                                "%Y-%m-%d %H:%M:%S.%f%z"
                            )  # type: ignore
                        # https://github.com/snowflakedb/snowflake-connector-python/issues/1677
                        else:  # type: ignore
                            ordered_df[column] = pd.to_datetime(ordered_df[column]).dt.strftime(
                                "%Y-%m-%d %H:%M:%S.%f"
                            )  # type: ignore

                # create the table first using our usual method to ensure the column datatypes match what we parsed with sqlglot
                # otherwise we would be trusting `write_pandas()` from the snowflake lib to do this correctly
                self.create_table(temp_table, source_columns_to_types, table_kind="TEMPORARY TABLE")

                write_pandas(
                    self._connection_pool.get(),
                    ordered_df,
                    temp_table.name,
                    schema=temp_table.db or None,
                    database=database.sql(dialect=self.dialect) if database else None,
                    chunk_size=self.DEFAULT_BATCH_SIZE,
                    overwrite=True,
                    table_type="temp",
                )
            else:
                raise SQLMeshError(
                    f"Unknown dataframe type: {type(df)} for {target_table}. Expecting pandas or snowpark."
                )

            # Read the staged data back, cast to the target column types.
            return exp.select(
                *self._casted_columns(target_columns_to_types, source_columns=source_columns)
            ).from_(temp_table)

        def cleanup() -> None:
            if is_snowpark_dataframe:
                if hasattr(df, "table_name"):
                    if isinstance(df.table_name, str):
                        # created by the Snowpark library if the Snowpark DataFrame was created from a Pandas DataFrame
                        # (if the Snowpark DataFrame was created via native means then there is no 'table_name' property and no temp table)
                        self.drop_table(df.table_name)
                self.drop_view(temp_table)
            else:
                self.drop_table(temp_table)

        # the cleanup_func technically isn't needed because the temp table gets dropped when the session ends
        # but boy does it make our multi-adapter integration tests easier to write
        return [SourceQuery(query_factory=query_factory, cleanup_func=cleanup)]
473
474    def _fetch_native_df(
475        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
476    ) -> DF:
477        import pandas as pd
478        from snowflake.connector.errors import NotSupportedError
479
480        self.execute(query, quote_identifiers=quote_identifiers)
481
482        try:
483            return self.cursor.fetch_pandas_all()
484        except NotSupportedError:
485            # Sometimes Snowflake will not return results as an Arrow result and the fetch from
486            # pandas will fail (Ex: `SHOW TERSE OBJECTS IN SCHEMA`). Therefore we manually convert
487            # the result into a DataFrame when this happens.
488            rows = self.cursor.fetchall()
489            columns = self.cursor._result_set.batches[0].column_names
490            return pd.DataFrame([dict(zip(columns, row)) for row in rows])
491
492    def _native_df_to_pandas_df(
493        self,
494        query_or_df: QueryOrDF,
495    ) -> t.Union[Query, pd.DataFrame]:
496        if snowpark and isinstance(query_or_df, snowpark.DataFrame):
497            return query_or_df.to_pandas()
498
499        return super()._native_df_to_pandas_df(query_or_df)
500
501    def _get_data_objects(
502        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
503    ) -> t.List[DataObject]:
504        """
505        Returns all the data objects that exist in the given schema and optionally catalog.
506        """
507
508        schema = to_schema(schema_name)
509        catalog_name = schema.catalog or self.get_current_catalog()
510
511        query = (
512            exp.select(
513                exp.column("TABLE_CATALOG").as_("catalog"),
514                exp.column("TABLE_NAME").as_("name"),
515                exp.column("TABLE_SCHEMA").as_("schema_name"),
516                exp.case()
517                .when(
518                    exp.And(
519                        this=exp.column("TABLE_TYPE").eq("BASE TABLE"),
520                        expression=exp.column("IS_DYNAMIC").eq("YES"),
521                    ),
522                    exp.Literal.string("MANAGED_TABLE"),
523                )
524                .when(exp.column("TABLE_TYPE").eq("BASE TABLE"), exp.Literal.string("TABLE"))
525                .when(exp.column("TABLE_TYPE").eq("TEMPORARY TABLE"), exp.Literal.string("TABLE"))
526                .when(exp.column("TABLE_TYPE").eq("LOCAL TEMPORARY"), exp.Literal.string("TABLE"))
527                .when(exp.column("TABLE_TYPE").eq("EXTERNAL TABLE"), exp.Literal.string("TABLE"))
528                .when(exp.column("TABLE_TYPE").eq("EVENT TABLE"), exp.Literal.string("TABLE"))
529                .when(exp.column("TABLE_TYPE").eq("VIEW"), exp.Literal.string("VIEW"))
530                .when(
531                    exp.column("TABLE_TYPE").eq("MATERIALIZED VIEW"),
532                    exp.Literal.string("MATERIALIZED_VIEW"),
533                )
534                .else_(exp.column("TABLE_TYPE"))
535                .as_("type"),
536                exp.column("CLUSTERING_KEY").as_("clustering_key"),
537            )
538            .from_(exp.table_("TABLES", db="INFORMATION_SCHEMA", catalog=catalog_name))
539            .where(exp.column("TABLE_SCHEMA").eq(schema.db))
540            # Snowflake seems to have delayed internal metadata updates and will sometimes return duplicates
541            .distinct()
542        )
543        if object_names:
544            query = query.where(exp.column("TABLE_NAME").isin(*object_names))
545
546        # exclude SNOWPARK_TEMP_TABLE tables that are managed by the Snowpark library and are an implementation
547        # detail of dealing with DataFrame's
548        query = query.where(exp.column("TABLE_NAME").like("SNOWPARK_TEMP_TABLE%").not_())
549
550        df = self.fetchdf(query, quote_identifiers=True)
551        if df.empty:
552            return []
553        return [
554            DataObject(
555                catalog=row.catalog,  # type: ignore
556                schema=row.schema_name,  # type: ignore
557                name=row.name,  # type: ignore
558                type=DataObjectType.from_str(row.type),  # type: ignore
559                clustering_key=row.clustering_key,  # type: ignore
560            )
561            # lowercase the column names for cases where Snowflake might return uppercase column names for certain catalogs
562            for row in df.rename(columns={col: col.lower() for col in df.columns}).itertuples()
563        ]
564
565    def _get_grant_expression(self, table: exp.Table) -> exp.Expr:
566        # Upon execute the catalog in table expressions are properly normalized to handle the case where a user provides
567        # the default catalog in their connection config. This doesn't though update catalogs in strings like when querying
568        # the information schema. So we need to manually replace those here.
569        expression = super()._get_grant_expression(table)
570        for col_exp in expression.find_all(exp.Column):
571            if col_exp.this.name == "table_catalog":
572                and_exp = col_exp.parent
573                assert and_exp is not None, "Expected column expression to have a parent"
574                assert and_exp.expression, "Expected AND expression to have an expression"
575                normalized_catalog = self._normalize_catalog(
576                    exp.table_("placeholder", db="placeholder", catalog=and_exp.expression.this)
577                )
578                and_exp.set(
579                    "expression",
580                    exp.Literal.string(normalized_catalog.args["catalog"].alias_or_name),
581                )
582        return expression
583
584    def set_current_catalog(self, catalog: str) -> None:
585        self.execute(exp.Use(this=exp.to_identifier(catalog)))
586
587    def set_current_schema(self, schema: str) -> None:
588        self.execute(exp.Use(kind="SCHEMA", this=to_schema(schema)))
589
590    def _normalize_catalog(self, expression: exp.Expr) -> exp.Expr:
591        # note: important to use self._default_catalog instead of the self.default_catalog property
592        # otherwise we get RecursionError: maximum recursion depth exceeded
593        # because it calls get_current_catalog(), which executes a query, which needs the default catalog, which calls get_current_catalog()... etc
594        if self._default_catalog:
595            # the purpose of this function is to identify instances where the default catalog is being used
596            # (so that we can replace it with the actual catalog as specified in the gateway)
597            #
598            # we can't do a direct string comparison because the catalog value on the model
599            # gets changed when it's normalized as part of generating `model.fqn`
600            def unquote_and_lower(identifier: str) -> str:
601                return exp.parse_identifier(identifier).name.lower()
602
603            default_catalog_unquoted = unquote_and_lower(self._default_catalog)
604            default_catalog_normalized = normalize_identifiers(
605                self._default_catalog, dialect=self.dialect
606            )
607
608            def catalog_rewriter(node: exp.Expr) -> exp.Expr:
609                if isinstance(node, exp.Table):
610                    if node.catalog:
611                        # only replace the catalog on the model with the target catalog if the two are functionally equivalent
612                        if unquote_and_lower(node.catalog) == default_catalog_unquoted:
613                            node.set("catalog", default_catalog_normalized)
614                elif isinstance(node, exp.Use) and isinstance(node.this, exp.Identifier):
615                    if unquote_and_lower(node.this.output_name) == default_catalog_unquoted:
616                        node.set("this", default_catalog_normalized)
617                return node
618
619            # Rewrite whatever default catalog is present on the query to be compatible with what the user supplied in the
620            # Snowflake connection config. This is because the catalog present on the model gets normalized and quoted to match
621            # the source dialect, which isnt always compatible with Snowflake
622            expression = expression.transform(catalog_rewriter)
623        return expression
624
625    def _to_sql(self, expression: exp.Expr, quote: bool = True, **kwargs: t.Any) -> str:
626        return super()._to_sql(
627            expression=self._normalize_catalog(expression), quote=quote, **kwargs
628        )
629
630    def _create_column_comments(
631        self,
632        table_name: TableName,
633        column_comments: t.Dict[str, str],
634        table_kind: str = "TABLE",
635        materialized_view: bool = False,
636    ) -> None:
637        """
638        Reference: https://docs.snowflake.com/en/sql-reference/sql/alter-table-column#syntax
639        """
640        if not column_comments:
641            return
642
643        table = exp.to_table(table_name)
644        table_sql = self._to_sql(table)
645
646        list_comment_sql = []
647        for column_name, column_comment in column_comments.items():
648            column_sql = exp.column(column_name).sql(dialect=self.dialect, identify=True)
649
650            truncated_comment = self._truncate_column_comment(column_comment)
651            comment_sql = exp.Literal.string(truncated_comment).sql(dialect=self.dialect)
652
653            list_comment_sql.append(f"COLUMN {column_sql} COMMENT {comment_sql}")
654
655        combined_sql = f"ALTER {table_kind} {table_sql} ALTER {', '.join(list_comment_sql)}"
656        try:
657            self.execute(combined_sql)
658        except Exception:
659            logger.warning(
660                f"Column comments for table '{table.alias_or_name}' not registered - this may be due to limited permissions.",
661                exc_info=True,
662            )
663
664    def clone_table(
665        self,
666        target_table_name: TableName,
667        source_table_name: TableName,
668        replace: bool = False,
669        exists: bool = True,
670        clone_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
671        **kwargs: t.Any,
672    ) -> None:
673        # The Snowflake adapter should use the transient property to clone transient tables
674        if physical_properties := kwargs.get("rendered_physical_properties"):
675            table_type = self._pop_creatable_type_from_properties(physical_properties)
676            if isinstance(table_type, exp.TransientProperty):
677                kwargs["properties"] = exp.Properties(expressions=[table_type])
678
679        super().clone_table(
680            target_table_name,
681            source_table_name,
682            replace=replace,
683            clone_kwargs=clone_kwargs,
684            **kwargs,
685        )
686
687    @t.overload
688    def _columns_to_types(
689        self,
690        query_or_df: DF,
691        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
692        source_columns: t.Optional[t.List[str]] = None,
693    ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ...
694
695    @t.overload
696    def _columns_to_types(
697        self,
698        query_or_df: Query,
699        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
700        source_columns: t.Optional[t.List[str]] = None,
701    ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ...
702
703    def _columns_to_types(
704        self,
705        query_or_df: QueryOrDF,
706        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
707        source_columns: t.Optional[t.List[str]] = None,
708    ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]:
709        if not target_columns_to_types and snowpark and isinstance(query_or_df, snowpark.DataFrame):
710            target_columns_to_types = columns_to_types_from_dtypes(
711                query_or_df.sample(n=1).to_pandas().dtypes.items()
712            )
713            return target_columns_to_types, list(source_columns or target_columns_to_types)
714
715        return super()._columns_to_types(
716            query_or_df, target_columns_to_types, source_columns=source_columns
717        )
718
719    def close(self) -> t.Any:
720        if snowpark_session := self._connection_pool.get_attribute(self.SNOWPARK):
721            snowpark_session.close()  # type: ignore
722            self._connection_pool.set_attribute(self.SNOWPARK, None)
723
724        return super().close()
725
726    def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]:
727        from sqlmesh.utils.date import to_timestamp
728
729        num_tables = len(table_names)
730
731        query = "SELECT LAST_ALTERED FROM INFORMATION_SCHEMA.TABLES WHERE"
732        for i, table_name in enumerate(table_names):
733            table = exp.to_table(table_name)
734            query += f"""(TABLE_NAME = '{table.name}' AND TABLE_SCHEMA = '{table.db}' AND TABLE_CATALOG = '{table.catalog}')"""
735            if i < num_tables - 1:
736                query += " OR "
737
738        result = self.fetchall(query)
739        return [to_timestamp(row[0]) for row in result]

Base class wrapping a Database API compliant connection.

The EngineAdapter is an easily-subclassable interface that interacts with the underlying engine and data store.

Arguments:
  • connection_factory_or_pool: a callable which produces a new Database API-compliant connection on every call.
  • dialect: The dialect with which this adapter is associated.
  • multithreaded: Indicates whether this adapter will be used by more than one thread.
DIALECT = 'snowflake'
SUPPORTS_MATERIALIZED_VIEWS = True
SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True
SUPPORTS_CLONING = True
SUPPORTS_MANAGED_MODELS = True
CURRENT_CATALOG_EXPRESSION = CurrentDatabase()
SUPPORTS_CREATE_DROP_CATALOG = True
SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS = True
SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ['DATABASE', 'SCHEMA', 'TABLE']
SCHEMA_DIFFER_KWARGS = {'parameterized_type_defaults': {<DType.BINARY: 'BINARY'>: [(8388608,)], <DType.VARBINARY: 'VARBINARY'>: [(8388608,)], <DType.DECIMAL: 'DECIMAL'>: [(38, 0), (0,)], <DType.CHAR: 'CHAR'>: [(1,)], <DType.NCHAR: 'NCHAR'>: [(1,)], <DType.VARCHAR: 'VARCHAR'>: [(16777216,)], <DType.TIME: 'TIME'>: [(9,)], <DType.TIMESTAMP: 'TIMESTAMP'>: [(9,)], <DType.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>: [(9,)], <DType.TIMESTAMPNTZ: 'TIMESTAMPNTZ'>: [(9,)], <DType.TIMESTAMPTZ: 'TIMESTAMPTZ'>: [(9,)]}}
MANAGED_TABLE_KIND = 'DYNAMIC TABLE'
SNOWPARK = 'snowpark'
SUPPORTS_QUERY_EXECUTION_TRACKING = True
SUPPORTS_GRANTS = True
CURRENT_USER_OR_ROLE_EXPRESSION: sqlglot.expressions.core.Expr = CurrentRole()
USE_CATALOG_IN_GRANTS = True
@contextlib.contextmanager
def session( self, properties: Dict[str, sqlglot.expressions.core.Expr | str | int | float | bool]) -> Iterator[NoneType]:
 90    @contextlib.contextmanager
 91    def session(self, properties: SessionProperties) -> t.Iterator[None]:
 92        warehouse = properties.get("warehouse")
 93        if not warehouse:
 94            yield
 95            return
 96
 97        if isinstance(warehouse, str):
 98            warehouse = exp.to_identifier(warehouse)
 99        if not isinstance(warehouse, exp.Expr):
100            raise SQLMeshError(f"Invalid warehouse: '{warehouse}'")
101
102        warehouse_exp = quote_identifiers(
103            normalize_identifiers(warehouse, dialect=self.dialect), dialect=self.dialect
104        )
105        warehouse_sql = warehouse_exp.sql(dialect=self.dialect)
106        current_warehouse_sql = self._current_warehouse.sql(dialect=self.dialect)
107
108        if warehouse_sql == current_warehouse_sql:
109            yield
110            return
111
112        self.execute(f"USE WAREHOUSE {warehouse_sql}")
113        try:
114            yield
115        finally:
116            self.execute(f"USE WAREHOUSE {current_warehouse_sql}")

A session context manager.

snowpark: Optional[<MagicMock id='132726886342592'>]
124    @property
125    def snowpark(self) -> t.Optional[SnowparkSession]:
126        if snowpark:
127            if not self._connection_pool.get_attribute(self.SNOWPARK):
128                # Snowpark sessions are not thread safe so we create a session per thread to prevent them from interfering with each other
129                # The sessions are cleaned up when close() is called
130                new_session = snowpark.Session.builder.configs(
131                    {"connection": self._connection_pool.get()}
132                ).create()
133                self._connection_pool.set_attribute(self.SNOWPARK, new_session)
134
135            return self._connection_pool.get_attribute(self.SNOWPARK)
136
137        return None
139    @property
140    def catalog_support(self) -> CatalogSupport:
141        return CatalogSupport.FULL_SUPPORT
def create_managed_table( self, table_name: Union[str, sqlglot.expressions.query.Table], query: <MagicMock id='132726885539536'>, target_columns_to_types: Optional[Dict[str, sqlglot.expressions.datatypes.DataType]] = None, partitioned_by: Optional[List[sqlglot.expressions.core.Expr]] = None, clustered_by: Optional[List[sqlglot.expressions.core.Expr]] = None, table_properties: Optional[Dict[str, sqlglot.expressions.core.Expr]] = None, table_description: Optional[str] = None, column_descriptions: Optional[Dict[str, str]] = None, source_columns: Optional[List[str]] = None, **kwargs: Any) -> None:
224    def create_managed_table(
225        self,
226        table_name: TableName,
227        query: Query,
228        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
229        partitioned_by: t.Optional[t.List[exp.Expr]] = None,
230        clustered_by: t.Optional[t.List[exp.Expr]] = None,
231        table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
232        table_description: t.Optional[str] = None,
233        column_descriptions: t.Optional[t.Dict[str, str]] = None,
234        source_columns: t.Optional[t.List[str]] = None,
235        **kwargs: t.Any,
236    ) -> None:
237        target_table = exp.to_table(table_name)
238
239        # Snowflake defaults to uppercase and it also makes the property presence checks
240        # easier
241        table_properties = {k.upper(): v for k, v in (table_properties or {}).items()}
242
243        # the WAREHOUSE property is required for a Dynamic Table
244        if "WAREHOUSE" not in table_properties:
245            table_properties["WAREHOUSE"] = self._current_warehouse
246
247        # so is TARGET_LAG
248        # ref: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table
249        if "TARGET_LAG" not in table_properties:
250            raise SQLMeshError(
251                "`target_lag` must be specified in the model physical_properties for a Snowflake Dynamic Table"
252            )
253
254        source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types(
255            query, target_columns_to_types, target_table=target_table, source_columns=source_columns
256        )
257
258        self._create_table_from_source_queries(
259            target_table,
260            source_queries,
261            target_columns_to_types,
262            replace=self.SUPPORTS_REPLACE_TABLE,
263            partitioned_by=partitioned_by,
264            clustered_by=clustered_by,
265            table_properties=table_properties,
266            table_description=table_description,
267            column_descriptions=column_descriptions,
268            table_kind=self.MANAGED_TABLE_KIND,
269            **kwargs,
270        )

Create a managed table using a query.

"Managed" means that once the table is created, the data is kept up to date by the underlying database engine and not SQLMesh.

Arguments:
  • table_name: The name of the table to create. Can be fully qualified or just table name.
  • query: The SQL query for the engine to base the managed table on
  • target_columns_to_types: A mapping between the column name and its data type.
  • partitioned_by: The partition columns or engine specific expressions, only applicable in certain engines. (eg. (ds, hour))
  • clustered_by: The cluster columns or engine specific expressions, only applicable in certain engines. (eg. (ds, hour))
  • table_properties: Optional mapping of engine-specific properties to be set on the managed table
  • table_description: Optional table description from MODEL DDL.
  • column_descriptions: Optional column descriptions from model query.
  • kwargs: Optional create table properties.
def create_view( self, view_name: Union[str, sqlglot.expressions.query.Table], query_or_df: <MagicMock id='132726885974128'>, target_columns_to_types: Optional[Dict[str, sqlglot.expressions.datatypes.DataType]] = None, replace: bool = True, materialized: bool = False, materialized_properties: Optional[Dict[str, Any]] = None, table_description: Optional[str] = None, column_descriptions: Optional[Dict[str, str]] = None, view_properties: Optional[Dict[str, sqlglot.expressions.core.Expr]] = None, source_columns: Optional[List[str]] = None, **create_kwargs: Any) -> None:
272    def create_view(
273        self,
274        view_name: TableName,
275        query_or_df: QueryOrDF,
276        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
277        replace: bool = True,
278        materialized: bool = False,
279        materialized_properties: t.Optional[t.Dict[str, t.Any]] = None,
280        table_description: t.Optional[str] = None,
281        column_descriptions: t.Optional[t.Dict[str, str]] = None,
282        view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
283        source_columns: t.Optional[t.List[str]] = None,
284        **create_kwargs: t.Any,
285    ) -> None:
286        properties = create_kwargs.pop("properties", None)
287        if not properties:
288            properties = exp.Properties(expressions=[])
289        if replace:
290            properties.append("expressions", exp.CopyGrantsProperty())
291
292        super().create_view(
293            view_name=view_name,
294            query_or_df=query_or_df,
295            target_columns_to_types=target_columns_to_types,
296            replace=replace,
297            materialized=materialized,
298            materialized_properties=materialized_properties,
299            table_description=table_description,
300            column_descriptions=column_descriptions,
301            view_properties=view_properties,
302            properties=properties,
303            source_columns=source_columns,
304            **create_kwargs,
305        )

Create a view with a query or dataframe.

If a dataframe is passed in, it will be converted into a literal values statement. This should only be done if the dataframe is very small!

Arguments:
  • view_name: The view name.
  • query_or_df: A query or dataframe.
  • target_columns_to_types: Columns to use in the view statement.
  • replace: Whether or not to replace an existing view. Defaults to True.
  • materialized: Whether to create a materialized view. Only used for engines that support this feature.
  • materialized_properties: Optional materialized view properties to add to the view.
  • table_description: Optional table description from MODEL DDL.
  • column_descriptions: Optional column descriptions from model query.
  • view_properties: Optional view properties to add to the view.
  • create_kwargs: Additional kwargs to pass into the Create expression
def drop_managed_table( self, table_name: Union[str, sqlglot.expressions.query.Table], exists: bool = True) -> None:
307    def drop_managed_table(self, table_name: TableName, exists: bool = True) -> None:
308        self._drop_object(table_name, exists, kind=self.MANAGED_TABLE_KIND)

Drops a managed table.

Arguments:
  • table_name: The name of the table to drop.
  • exists: If exists, defaults to True.
def set_current_catalog(self, catalog: str) -> None:
584    def set_current_catalog(self, catalog: str) -> None:
585        self.execute(exp.Use(this=exp.to_identifier(catalog)))

Sets the catalog name of the current connection.

def set_current_schema(self, schema: str) -> None:
587    def set_current_schema(self, schema: str) -> None:
588        self.execute(exp.Use(kind="SCHEMA", this=to_schema(schema)))
def clone_table( self, target_table_name: Union[str, sqlglot.expressions.query.Table], source_table_name: Union[str, sqlglot.expressions.query.Table], replace: bool = False, exists: bool = True, clone_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
664    def clone_table(
665        self,
666        target_table_name: TableName,
667        source_table_name: TableName,
668        replace: bool = False,
669        exists: bool = True,
670        clone_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
671        **kwargs: t.Any,
672    ) -> None:
673        # The Snowflake adapter should use the transient property to clone transient tables
674        if physical_properties := kwargs.get("rendered_physical_properties"):
675            table_type = self._pop_creatable_type_from_properties(physical_properties)
676            if isinstance(table_type, exp.TransientProperty):
677                kwargs["properties"] = exp.Properties(expressions=[table_type])
678
679        super().clone_table(
680            target_table_name,
681            source_table_name,
682            replace=replace,
683            clone_kwargs=clone_kwargs,
684            **kwargs,
685        )

Creates a table with the target name by cloning the source table.

Arguments:
  • target_table_name: The name of the table that should be created.
  • source_table_name: The name of the source table that should be cloned.
  • replace: Whether or not to replace an existing table.
  • exists: Indicates whether to include the IF NOT EXISTS check.
def close(self) -> Any:
719    def close(self) -> t.Any:
720        if snowpark_session := self._connection_pool.get_attribute(self.SNOWPARK):
721            snowpark_session.close()  # type: ignore
722            self._connection_pool.set_attribute(self.SNOWPARK, None)
723
724        return super().close()

Closes all open connections and releases all allocated resources.

def get_table_last_modified_ts( self, table_names: List[Union[str, sqlglot.expressions.query.Table]]) -> List[int]:
726    def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]:
727        from sqlmesh.utils.date import to_timestamp
728
729        num_tables = len(table_names)
730
731        query = "SELECT LAST_ALTERED FROM INFORMATION_SCHEMA.TABLES WHERE"
732        for i, table_name in enumerate(table_names):
733            table = exp.to_table(table_name)
734            query += f"""(TABLE_NAME = '{table.name}' AND TABLE_SCHEMA = '{table.db}' AND TABLE_CATALOG = '{table.catalog}')"""
735            if i < num_tables - 1:
736                query += " OR "
737
738        result = self.fetchall(query)
739        return [to_timestamp(row[0]) for row in result]
def get_alter_operations( self, current_table_name: Union[str, sqlglot.expressions.query.Table], target_table_name: Union[str, sqlglot.expressions.query.Table], *, ignore_destructive: bool = False, ignore_additive: bool = False) -> List[sqlmesh.core.schema_diff.TableAlterOperation]:
358    def get_alter_operations(
359        self,
360        current_table_name: TableName,
361        target_table_name: TableName,
362        *,
363        ignore_destructive: bool = False,
364        ignore_additive: bool = False,
365    ) -> t.List[TableAlterOperation]:
366        operations = super().get_alter_operations(
367            current_table_name,
368            target_table_name,
369            ignore_destructive=ignore_destructive,
370            ignore_additive=ignore_additive,
371        )
372
373        # check for a change in clustering
374        current_table = exp.to_table(current_table_name)
375        target_table = exp.to_table(target_table_name)
376
377        current_table_schema = schema_(current_table.db, catalog=current_table.catalog)
378        target_table_schema = schema_(target_table.db, catalog=target_table.catalog)
379
380        current_table_info = seq_get(
381            self.get_data_objects(current_table_schema, {current_table.name}), 0
382        )
383        target_table_info = seq_get(
384            self.get_data_objects(target_table_schema, {target_table.name}), 0
385        )
386
387        if current_table_info and target_table_info:
388            if target_table_info.is_clustered:
389                if target_table_info.clustering_key and (
390                    current_table_info.clustering_key != target_table_info.clustering_key
391                ):
392                    operations.append(
393                        TableAlterChangeClusterKeyOperation(
394                            target_table=current_table,
395                            clustering_key=target_table_info.clustering_key,
396                            dialect=self.dialect,
397                        )
398                    )
399            elif current_table_info.is_clustered:
400                operations.append(TableAlterDropClusterKeyOperation(target_table=current_table))
401
402        return operations

Determines the alter statements needed to change the current table into the structure of the target table.

Inherited Members
sqlmesh.core.engine_adapter.base.EngineAdapter
EngineAdapter
DEFAULT_BATCH_SIZE
DATA_OBJECT_FILTER_BATCH_SIZE
SUPPORTS_TRANSACTIONS
SUPPORTS_INDEXES
COMMENT_CREATION_TABLE
COMMENT_CREATION_VIEW
MAX_TABLE_COMMENT_LENGTH
MAX_COLUMN_COMMENT_LENGTH
INSERT_OVERWRITE_STRATEGY
SUPPORTS_VIEW_SCHEMA
SUPPORTS_TUPLE_IN
HAS_VIEW_BINDING
SUPPORTS_REPLACE_TABLE
DEFAULT_CATALOG_TYPE
QUOTE_IDENTIFIERS_IN_VIEWS
MAX_IDENTIFIER_LENGTH
ATTACH_CORRELATION_ID
dialect
correlation_id
with_settings
cursor
connection
spark
bigframe
comments_enabled
schema_differ
default_catalog
engine_run_mode
recycle
get_catalog_type
get_catalog_type_from_table
current_catalog_type
replace_query
create_index
create_table
ctas
create_state_table
create_table_like
drop_data_object
drop_table
alter_table
create_schema
drop_schema
drop_view
create_catalog
drop_catalog
columns
table_exists
delete_from
insert_append
insert_overwrite_by_partition
insert_overwrite_by_time_partition
update_table
scd_type_2_by_time
scd_type_2_by_column
merge
rename_table
get_data_object
get_data_objects
fetchone
fetchall
fetchdf
fetch_pyspark_df
wap_enabled
wap_supported
wap_table_name
wap_prepare
wap_publish
sync_grants_config
transaction
execute
temp_table
drop_data_object_on_type_mismatch
ensure_nulls_for_unmatched_after_join
use_server_nulls_for_unmatched_after_join
ping
sqlmesh.core.engine_adapter.mixins.GetCurrentCatalogFromFunctionMixin
get_current_catalog
sqlmesh.core.engine_adapter.mixins.RowDiffMixin
MAX_TIMESTAMP_PRECISION
concat_columns
normalize_value
sqlmesh.core.engine_adapter.mixins.GrantsFromInfoSchemaMixin
SUPPORTS_MULTIPLE_GRANT_PRINCIPALS
GRANT_INFORMATION_SCHEMA_TABLE_NAME