Edit on GitHub

sqlmesh.core.engine_adapter.trino

  1from __future__ import annotations
  2
  3import contextlib
  4import re
  5import typing as t
  6from functools import lru_cache
  7
  8from sqlglot import exp
  9from sqlglot.helper import seq_get
 10from tenacity import retry, wait_fixed, stop_after_attempt, retry_if_result
 11
 12from sqlmesh.core.dialect import schema_, to_schema
 13from sqlmesh.core.engine_adapter.mixins import (
 14    GetCurrentCatalogFromFunctionMixin,
 15    HiveMetastoreTablePropertiesMixin,
 16    PandasNativeFetchDFSupportMixin,
 17    RowDiffMixin,
 18)
 19from sqlmesh.core.engine_adapter.shared import (
 20    CatalogSupport,
 21    CommentCreationTable,
 22    CommentCreationView,
 23    DataObject,
 24    DataObjectType,
 25    InsertOverwriteStrategy,
 26    SourceQuery,
 27    set_catalog,
 28)
 29from sqlmesh.utils import get_source_columns_to_types
 30from sqlmesh.utils.errors import SQLMeshError
 31from sqlmesh.utils.date import TimeLike
 32
 33if t.TYPE_CHECKING:
 34    from sqlmesh.core._typing import SchemaName, SessionProperties, TableName
 35    from sqlmesh.core.engine_adapter._typing import DF, QueryOrDF
 36
# Connector type names treated as supporting table replacement. `TrinoEngineAdapter.replace_query`
# matches these as substrings of the resolved catalog type (e.g. `acme_iceberg` matches `iceberg`).
CATALOG_TYPES_SUPPORTING_REPLACE_TABLE = {"iceberg", "delta_lake"}
 38
 39
 40@set_catalog()
 41class TrinoEngineAdapter(
 42    PandasNativeFetchDFSupportMixin,
 43    HiveMetastoreTablePropertiesMixin,
 44    GetCurrentCatalogFromFunctionMixin,
 45    RowDiffMixin,
 46):
    """Engine adapter for the Trino SQL dialect."""

    DIALECT = "trino"
    INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.INTO_IS_OVERWRITE
    # Trino does technically support transactions, but they do not work correctly with partition
    # overwrite, so we disable them. To re-enable them we would need to disable auto commit on the
    # connector and then figure out how to get insert/overwrite to work correctly without it.
    SUPPORTS_TRANSACTIONS = False
    CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog")
    COMMENT_CREATION_TABLE = CommentCreationTable.IN_SCHEMA_DEF_NO_CTAS
    COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY
    SUPPORTS_REPLACE_TABLE = False
    SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"]
    # Catalog type assumed by `get_catalog_type` when none can be resolved.
    DEFAULT_CATALOG_TYPE = "hive"
    QUOTE_IDENTIFIERS_IN_VIEWS = False
    SUPPORTS_QUERY_EXECUTION_TRACKING = True
    SCHEMA_DIFFER_KWARGS = {
        "parameterized_type_defaults": {
            # default decimal precision varies across backends
            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(), (0,)],
            exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)],
            exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(3,)],
        },
    }
    # Some catalogs support microsecond precision (6), but it either has to be specifically enabled
    # (Hive) or just isn't available (Delta / TIMESTAMP WITH TIME ZONE). Even with a TIMESTAMP(6),
    # the date formatting functions still only support millisecond precision.
    MAX_TIMESTAMP_PRECISION = 3
 72
 73    @property
 74    def schema_location_mapping(self) -> t.Optional[t.Dict[re.Pattern, str]]:
 75        return self._extra_config.get("schema_location_mapping")
 76
 77    @property
 78    def timestamp_mapping(self) -> t.Optional[t.Dict[exp.DataType, exp.DataType]]:
 79        return self._extra_config.get("timestamp_mapping")
 80
 81    def _apply_timestamp_mapping(
 82        self, columns_to_types: t.Dict[str, exp.DataType]
 83    ) -> t.Tuple[t.Dict[str, exp.DataType], t.Set[str]]:
 84        """Apply custom timestamp mapping to column types.
 85
 86        Returns:
 87            A tuple of (mapped_columns_to_types, mapped_column_names) where mapped_column_names
 88            contains the names of columns that were found in the mapping.
 89        """
 90        if not self.timestamp_mapping:
 91            return columns_to_types, set()
 92
 93        result = {}
 94        mapped_columns: t.Set[str] = set()
 95        for column, column_type in columns_to_types.items():
 96            if column_type in self.timestamp_mapping:
 97                result[column] = self.timestamp_mapping[column_type]
 98                mapped_columns.add(column)
 99            else:
100                result[column] = column_type
101        return result, mapped_columns
102
103    @property
104    def catalog_support(self) -> CatalogSupport:
105        return CatalogSupport.FULL_SUPPORT
106
107    def set_current_catalog(self, catalog: str) -> None:
108        """Sets the catalog name of the current connection."""
109        self.execute(exp.Use(this=schema_(db="information_schema", catalog=catalog)))
110
111    @lru_cache()
112    def get_catalog_type(self, catalog: t.Optional[str]) -> str:
113        row: t.Tuple = tuple()
114        if catalog:
115            if catalog_type_override := self._catalog_type_overrides.get(catalog):
116                return catalog_type_override
117            row = (
118                self.fetchone(
119                    f"select connector_name from system.metadata.catalogs where catalog_name='{catalog}'"
120                )
121                or ()
122            )
123        return seq_get(row, 0) or self.DEFAULT_CATALOG_TYPE
124
125    @contextlib.contextmanager
126    def session(self, properties: SessionProperties) -> t.Iterator[None]:
127        authorization = properties.get("authorization")
128        if not authorization:
129            yield
130            return
131
132        if not isinstance(authorization, exp.Expr):
133            authorization = exp.Literal.string(authorization)
134
135        if not authorization.is_string:
136            raise SQLMeshError(
137                "Invalid value for `session_properties.authorization`. Must be a string literal."
138            )
139
140        authorization_sql = authorization.sql(dialect=self.dialect)
141
142        self.execute(f"SET SESSION AUTHORIZATION {authorization_sql}")
143        try:
144            yield
145        finally:
146            self.execute("RESET SESSION AUTHORIZATION")
147
148    def replace_query(
149        self,
150        table_name: TableName,
151        query_or_df: QueryOrDF,
152        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
153        table_description: t.Optional[str] = None,
154        column_descriptions: t.Optional[t.Dict[str, str]] = None,
155        source_columns: t.Optional[t.List[str]] = None,
156        supports_replace_table_override: t.Optional[bool] = None,
157        **kwargs: t.Any,
158    ) -> None:
159        catalog_type = self.get_catalog_type_from_table(table_name)
160        # User may have a custom catalog type name so we are assuming they keep the catalog type still in the name
161        # Ex: `acme_iceberg` would be identified as an iceberg catalog and therefore supports replace table
162        supports_replace_table_override = None
163        for replace_table_catalog_type in CATALOG_TYPES_SUPPORTING_REPLACE_TABLE:
164            if replace_table_catalog_type in catalog_type:
165                supports_replace_table_override = True
166                break
167
168        super().replace_query(
169            table_name=table_name,
170            query_or_df=query_or_df,
171            target_columns_to_types=target_columns_to_types,
172            table_description=table_description,
173            column_descriptions=column_descriptions,
174            source_columns=source_columns,
175            supports_replace_table_override=supports_replace_table_override,
176            **kwargs,
177        )
178
179    def _insert_overwrite_by_condition(
180        self,
181        table_name: TableName,
182        source_queries: t.List[SourceQuery],
183        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
184        where: t.Optional[exp.Condition] = None,
185        insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None,
186        **kwargs: t.Any,
187    ) -> None:
188        catalog = exp.to_table(table_name).catalog or self.get_current_catalog()
189
190        if where and self.get_catalog_type(catalog) == "hive":
191            # These session properties are only valid for the Trino Hive connector
192            # Attempting to set them on an Iceberg catalog will throw an error:
193            # "Session property 'catalog.insert_existing_partitions_behavior' does not exist"
194            self.execute(f"SET SESSION {catalog}.insert_existing_partitions_behavior='OVERWRITE'")
195            super()._insert_overwrite_by_condition(
196                table_name, source_queries, target_columns_to_types, where
197            )
198            self.execute(f"SET SESSION {catalog}.insert_existing_partitions_behavior='APPEND'")
199        else:
200            super()._insert_overwrite_by_condition(
201                table_name,
202                source_queries,
203                target_columns_to_types,
204                where,
205                insert_overwrite_strategy_override=InsertOverwriteStrategy.DELETE_INSERT,
206            )
207
208    def _truncate_table(self, table_name: TableName) -> None:
209        table = exp.to_table(table_name)
210        # Some trino connectors don't support truncate so we use delete.
211        self.execute(f"DELETE FROM {table.sql(dialect=self.dialect, identify=True)}")
212
    def _get_data_objects(
        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
    ) -> t.List[DataObject]:
        """
        Returns all the data objects that exist in the given schema and optionally catalog.

        Args:
            schema_name: The schema to inspect. When it carries no catalog, the current catalog
                is used.
            object_names: Optional set of object names; when given (and non-empty) only objects
                with these names are returned.

        Returns:
            One `DataObject` per row found in `<catalog>.information_schema.tables` for the schema.
        """
        schema_name = to_schema(schema_name)
        schema = schema_name.db
        catalog = schema_name.catalog or self.get_current_catalog()
        query = (
            exp.select(
                exp.column("table_catalog", table="t").as_("catalog"),
                exp.column("table_schema", table="t").as_("schema"),
                exp.column("table_name", table="t").as_("name"),
                # Type resolution: a matching row in the materialized views table wins, then
                # 'BASE TABLE' is reported as 'table', otherwise the raw table_type is passed on.
                exp.case()
                .when(
                    exp.column("name", table="mv").is_(exp.null()).not_(),
                    exp.Literal.string("materialized_view"),
                )
                .when(
                    exp.column("table_type", table="t").eq("BASE TABLE"),
                    exp.Literal.string("table"),
                )
                .else_(exp.column("table_type", table="t"))
                .as_("type"),
            )
            .from_(exp.to_table(f"{catalog}.information_schema.tables", alias="t"))
            # Left join so that objects without a materialized view entry are still returned.
            .join(
                exp.to_table("system.metadata.materialized_views", alias="mv"),
                on=exp.and_(
                    exp.column("catalog_name", table="mv").eq(
                        exp.column("table_catalog", table="t")
                    ),
                    exp.column("schema_name", table="mv").eq(exp.column("table_schema", table="t")),
                    exp.column("name", table="mv").eq(exp.column("table_name", table="t")),
                ),
                join_type="left",
            )
            .where(
                exp.and_(
                    exp.column("table_schema", table="t").eq(schema),
                    # The mv columns are NULL for rows without a materialized view match; when
                    # present they must refer to the requested catalog and schema.
                    exp.or_(
                        exp.column("catalog_name", table="mv").is_(exp.null()),
                        exp.column("catalog_name", table="mv").eq(catalog),
                    ),
                    exp.or_(
                        exp.column("schema_name", table="mv").is_(exp.null()),
                        exp.column("schema_name", table="mv").eq(schema),
                    ),
                )
            )
        )
        if object_names:
            # Adds a further condition restricting the result to the requested names.
            query = query.where(exp.column("table_name", table="t").isin(*object_names))
        df = self.fetchdf(query)
        return [
            DataObject(
                catalog=row.catalog,  # type: ignore
                schema=row.schema,  # type: ignore
                name=row.name,  # type: ignore
                type=DataObjectType.from_str(row.type),  # type: ignore
            )
            for row in df.itertuples()
        ]
277
278    def _df_to_source_queries(
279        self,
280        df: DF,
281        target_columns_to_types: t.Dict[str, exp.DataType],
282        batch_size: int,
283        target_table: TableName,
284        source_columns: t.Optional[t.List[str]] = None,
285    ) -> t.List[SourceQuery]:
286        import pandas as pd
287        from pandas.api.types import is_datetime64_any_dtype  # type: ignore
288
289        assert isinstance(df, pd.DataFrame)
290        source_columns_to_types = get_source_columns_to_types(
291            target_columns_to_types, source_columns
292        )
293
294        # Trino does not accept timestamps in ISOFORMAT that include the "T". `execution_time` is stored in
295        # Pandas with that format, so we convert the column to a string with the proper format and CAST to
296        # timestamp in Trino.
297        for column, kind in source_columns_to_types.items():
298            dtype = df.dtypes[column]
299            if is_datetime64_any_dtype(dtype) and getattr(dtype, "tz", None) is not None:
300                df[column] = pd.to_datetime(df[column]).map(lambda x: x.isoformat(" "))
301
302        return super()._df_to_source_queries(
303            df, target_columns_to_types, batch_size, target_table, source_columns=source_columns
304        )
305
306    def _build_schema_exp(
307        self,
308        table: exp.Table,
309        target_columns_to_types: t.Dict[str, exp.DataType],
310        column_descriptions: t.Optional[t.Dict[str, str]] = None,
311        expressions: t.Optional[t.List[exp.PrimaryKey]] = None,
312        is_view: bool = False,
313        materialized: bool = False,
314    ) -> exp.Schema:
315        target_columns_to_types, mapped_columns = self._apply_timestamp_mapping(
316            target_columns_to_types
317        )
318        if "delta_lake" in self.get_catalog_type_from_table(table):
319            target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns)
320
321        return super()._build_schema_exp(
322            table, target_columns_to_types, column_descriptions, expressions, is_view
323        )
324
325    def _scd_type_2(
326        self,
327        target_table: TableName,
328        source_table: QueryOrDF,
329        unique_key: t.Sequence[exp.Expr],
330        valid_from_col: exp.Column,
331        valid_to_col: exp.Column,
332        execution_time: t.Union[TimeLike, exp.Column],
333        invalidate_hard_deletes: bool = True,
334        updated_at_col: t.Optional[exp.Column] = None,
335        check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expr]]] = None,
336        updated_at_as_valid_from: bool = False,
337        execution_time_as_valid_from: bool = False,
338        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
339        table_description: t.Optional[str] = None,
340        column_descriptions: t.Optional[t.Dict[str, str]] = None,
341        truncate: bool = False,
342        source_columns: t.Optional[t.List[str]] = None,
343        **kwargs: t.Any,
344    ) -> None:
345        mapped_columns: t.Set[str] = set()
346        if target_columns_to_types:
347            target_columns_to_types, mapped_columns = self._apply_timestamp_mapping(
348                target_columns_to_types
349            )
350        if target_columns_to_types and "delta_lake" in self.get_catalog_type_from_table(
351            target_table
352        ):
353            target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns)
354
355        return super()._scd_type_2(
356            target_table,
357            source_table,
358            unique_key,
359            valid_from_col,
360            valid_to_col,
361            execution_time,
362            invalidate_hard_deletes,
363            updated_at_col,
364            check_columns,
365            updated_at_as_valid_from,
366            execution_time_as_valid_from,
367            target_columns_to_types,
368            table_description,
369            column_descriptions,
370            truncate,
371            source_columns,
372            **kwargs,
373        )
374
375    # delta_lake only supports two timestamp data types. This method converts other
376    # timestamp types to those two for use in DDL statements. Trino/delta automatically
377    # converts the data values to the correct type on write, so we only need to handle
378    # the column types in DDL.
379    # - `timestamp(6)` for non-timezone-aware
380    # - `timestamp(3) with time zone` for timezone-aware
381    # https://trino.io/docs/current/connector/delta-lake.html#delta-lake-to-trino-type-mapping
382    def _to_delta_ts(
383        self,
384        columns_to_types: t.Dict[str, exp.DataType],
385        skip_columns: t.Optional[t.Set[str]] = None,
386    ) -> t.Dict[str, exp.DataType]:
387        ts6 = exp.DataType.build("timestamp(6)")
388        ts3_tz = exp.DataType.build("timestamp(3) with time zone")
389        skip = skip_columns or set()
390
391        delta_columns_to_types = {
392            k: ts6 if k not in skip and v.is_type(exp.DataType.Type.TIMESTAMP) else v
393            for k, v in columns_to_types.items()
394        }
395
396        delta_columns_to_types = {
397            k: ts3_tz if k not in skip and v.is_type(exp.DataType.Type.TIMESTAMPTZ) else v
398            for k, v in delta_columns_to_types.items()
399        }
400
401        return delta_columns_to_types
402
403    @retry(wait=wait_fixed(1), stop=stop_after_attempt(10), retry=retry_if_result(lambda v: not v))
404    def _block_until_table_exists(self, table_name: TableName) -> bool:
405        return self.table_exists(table_name)
406
407    def _create_schema(
408        self,
409        schema_name: SchemaName,
410        ignore_if_exists: bool,
411        warn_on_error: bool,
412        properties: t.List[exp.Expr],
413        kind: str,
414    ) -> None:
415        if mapped_location := self._schema_location(schema_name):
416            properties.append(exp.LocationProperty(this=exp.Literal.string(mapped_location)))
417
418        return super()._create_schema(
419            schema_name=schema_name,
420            ignore_if_exists=ignore_if_exists,
421            warn_on_error=warn_on_error,
422            properties=properties,
423            kind=kind,
424        )
425
426    def _create_table(
427        self,
428        table_name_or_schema: t.Union[exp.Schema, TableName],
429        expression: t.Optional[exp.Expr],
430        exists: bool = True,
431        replace: bool = False,
432        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
433        table_description: t.Optional[str] = None,
434        column_descriptions: t.Optional[t.Dict[str, str]] = None,
435        table_kind: t.Optional[str] = None,
436        track_rows_processed: bool = True,
437        **kwargs: t.Any,
438    ) -> None:
439        super()._create_table(
440            table_name_or_schema=table_name_or_schema,
441            expression=expression,
442            exists=exists,
443            replace=replace,
444            target_columns_to_types=target_columns_to_types,
445            table_description=table_description,
446            column_descriptions=column_descriptions,
447            table_kind=table_kind,
448            track_rows_processed=track_rows_processed,
449            **kwargs,
450        )
451
452        # extract the table name
453        if isinstance(table_name_or_schema, exp.Schema):
454            table_name = table_name_or_schema.this
455            assert isinstance(table_name, exp.Table)
456        else:
457            table_name = table_name_or_schema
458
459        if "hive" in self.get_catalog_type_from_table(table_name):
460            # the Trino Hive connector can take a few seconds for metadata changes to propagate to all internal threads
461            # (even if metadata TTL is set to 0s)
462            # Blocking until the table shows up means that subsequent code expecting it to exist immediately will not fail
463            self._block_until_table_exists(table_name)
464
465    def _schema_location(self, schema_name: SchemaName) -> t.Optional[str]:
466        if mapping := self.schema_location_mapping:
467            schema = to_schema(schema_name)
468            match_key = schema.db
469
470            # only consider the catalog if it is present
471            if schema.catalog:
472                match_key = f"{schema.catalog}.{match_key}"
473
474            for k, v in mapping.items():
475                if re.match(k, match_key):
476                    return v.replace("@{schema_name}", schema.db).replace(
477                        "@{catalog_name}", schema.catalog
478                    )
479        return None
CATALOG_TYPES_SUPPORTING_REPLACE_TABLE = {'iceberg', 'delta_lake'}
 41@set_catalog()
 42class TrinoEngineAdapter(
 43    PandasNativeFetchDFSupportMixin,
 44    HiveMetastoreTablePropertiesMixin,
 45    GetCurrentCatalogFromFunctionMixin,
 46    RowDiffMixin,
 47):
 48    DIALECT = "trino"
 49    INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.INTO_IS_OVERWRITE
 50    # Trino does technically support transactions but it doesn't work correctly with partition overwrite so we
 51    # disable transactions. If we need to get them enabled again then we would need to disable auto commit on the
 52    # connector and then figure out how to get insert/overwrite to work correctly without it.
 53    SUPPORTS_TRANSACTIONS = False
 54    CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog")
 55    COMMENT_CREATION_TABLE = CommentCreationTable.IN_SCHEMA_DEF_NO_CTAS
 56    COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY
 57    SUPPORTS_REPLACE_TABLE = False
 58    SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"]
 59    DEFAULT_CATALOG_TYPE = "hive"
 60    QUOTE_IDENTIFIERS_IN_VIEWS = False
 61    SUPPORTS_QUERY_EXECUTION_TRACKING = True
 62    SCHEMA_DIFFER_KWARGS = {
 63        "parameterized_type_defaults": {
 64            # default decimal precision varies across backends
 65            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(), (0,)],
 66            exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)],
 67            exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(3,)],
 68        },
 69    }
 70    # some catalogs support microsecond (precision 6) but it has to be specifically enabled (Hive) or just isnt available (Delta / TIMESTAMP WITH TIME ZONE)
 71    # and even if you have a TIMESTAMP(6) the date formatting functions still only support millisecond precision
 72    MAX_TIMESTAMP_PRECISION = 3
 73
 74    @property
 75    def schema_location_mapping(self) -> t.Optional[t.Dict[re.Pattern, str]]:
 76        return self._extra_config.get("schema_location_mapping")
 77
 78    @property
 79    def timestamp_mapping(self) -> t.Optional[t.Dict[exp.DataType, exp.DataType]]:
 80        return self._extra_config.get("timestamp_mapping")
 81
 82    def _apply_timestamp_mapping(
 83        self, columns_to_types: t.Dict[str, exp.DataType]
 84    ) -> t.Tuple[t.Dict[str, exp.DataType], t.Set[str]]:
 85        """Apply custom timestamp mapping to column types.
 86
 87        Returns:
 88            A tuple of (mapped_columns_to_types, mapped_column_names) where mapped_column_names
 89            contains the names of columns that were found in the mapping.
 90        """
 91        if not self.timestamp_mapping:
 92            return columns_to_types, set()
 93
 94        result = {}
 95        mapped_columns: t.Set[str] = set()
 96        for column, column_type in columns_to_types.items():
 97            if column_type in self.timestamp_mapping:
 98                result[column] = self.timestamp_mapping[column_type]
 99                mapped_columns.add(column)
100            else:
101                result[column] = column_type
102        return result, mapped_columns
103
104    @property
105    def catalog_support(self) -> CatalogSupport:
106        return CatalogSupport.FULL_SUPPORT
107
108    def set_current_catalog(self, catalog: str) -> None:
109        """Sets the catalog name of the current connection."""
110        self.execute(exp.Use(this=schema_(db="information_schema", catalog=catalog)))
111
112    @lru_cache()
113    def get_catalog_type(self, catalog: t.Optional[str]) -> str:
114        row: t.Tuple = tuple()
115        if catalog:
116            if catalog_type_override := self._catalog_type_overrides.get(catalog):
117                return catalog_type_override
118            row = (
119                self.fetchone(
120                    f"select connector_name from system.metadata.catalogs where catalog_name='{catalog}'"
121                )
122                or ()
123            )
124        return seq_get(row, 0) or self.DEFAULT_CATALOG_TYPE
125
126    @contextlib.contextmanager
127    def session(self, properties: SessionProperties) -> t.Iterator[None]:
128        authorization = properties.get("authorization")
129        if not authorization:
130            yield
131            return
132
133        if not isinstance(authorization, exp.Expr):
134            authorization = exp.Literal.string(authorization)
135
136        if not authorization.is_string:
137            raise SQLMeshError(
138                "Invalid value for `session_properties.authorization`. Must be a string literal."
139            )
140
141        authorization_sql = authorization.sql(dialect=self.dialect)
142
143        self.execute(f"SET SESSION AUTHORIZATION {authorization_sql}")
144        try:
145            yield
146        finally:
147            self.execute("RESET SESSION AUTHORIZATION")
148
149    def replace_query(
150        self,
151        table_name: TableName,
152        query_or_df: QueryOrDF,
153        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
154        table_description: t.Optional[str] = None,
155        column_descriptions: t.Optional[t.Dict[str, str]] = None,
156        source_columns: t.Optional[t.List[str]] = None,
157        supports_replace_table_override: t.Optional[bool] = None,
158        **kwargs: t.Any,
159    ) -> None:
160        catalog_type = self.get_catalog_type_from_table(table_name)
161        # User may have a custom catalog type name so we are assuming they keep the catalog type still in the name
162        # Ex: `acme_iceberg` would be identified as an iceberg catalog and therefore supports replace table
163        supports_replace_table_override = None
164        for replace_table_catalog_type in CATALOG_TYPES_SUPPORTING_REPLACE_TABLE:
165            if replace_table_catalog_type in catalog_type:
166                supports_replace_table_override = True
167                break
168
169        super().replace_query(
170            table_name=table_name,
171            query_or_df=query_or_df,
172            target_columns_to_types=target_columns_to_types,
173            table_description=table_description,
174            column_descriptions=column_descriptions,
175            source_columns=source_columns,
176            supports_replace_table_override=supports_replace_table_override,
177            **kwargs,
178        )
179
180    def _insert_overwrite_by_condition(
181        self,
182        table_name: TableName,
183        source_queries: t.List[SourceQuery],
184        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
185        where: t.Optional[exp.Condition] = None,
186        insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None,
187        **kwargs: t.Any,
188    ) -> None:
189        catalog = exp.to_table(table_name).catalog or self.get_current_catalog()
190
191        if where and self.get_catalog_type(catalog) == "hive":
192            # These session properties are only valid for the Trino Hive connector
193            # Attempting to set them on an Iceberg catalog will throw an error:
194            # "Session property 'catalog.insert_existing_partitions_behavior' does not exist"
195            self.execute(f"SET SESSION {catalog}.insert_existing_partitions_behavior='OVERWRITE'")
196            super()._insert_overwrite_by_condition(
197                table_name, source_queries, target_columns_to_types, where
198            )
199            self.execute(f"SET SESSION {catalog}.insert_existing_partitions_behavior='APPEND'")
200        else:
201            super()._insert_overwrite_by_condition(
202                table_name,
203                source_queries,
204                target_columns_to_types,
205                where,
206                insert_overwrite_strategy_override=InsertOverwriteStrategy.DELETE_INSERT,
207            )
208
209    def _truncate_table(self, table_name: TableName) -> None:
210        table = exp.to_table(table_name)
211        # Some trino connectors don't support truncate so we use delete.
212        self.execute(f"DELETE FROM {table.sql(dialect=self.dialect, identify=True)}")
213
214    def _get_data_objects(
215        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
216    ) -> t.List[DataObject]:
217        """
218        Returns all the data objects that exist in the given schema and optionally catalog.
219        """
220        schema_name = to_schema(schema_name)
221        schema = schema_name.db
222        catalog = schema_name.catalog or self.get_current_catalog()
223        query = (
224            exp.select(
225                exp.column("table_catalog", table="t").as_("catalog"),
226                exp.column("table_schema", table="t").as_("schema"),
227                exp.column("table_name", table="t").as_("name"),
228                exp.case()
229                .when(
230                    exp.column("name", table="mv").is_(exp.null()).not_(),
231                    exp.Literal.string("materialized_view"),
232                )
233                .when(
234                    exp.column("table_type", table="t").eq("BASE TABLE"),
235                    exp.Literal.string("table"),
236                )
237                .else_(exp.column("table_type", table="t"))
238                .as_("type"),
239            )
240            .from_(exp.to_table(f"{catalog}.information_schema.tables", alias="t"))
241            .join(
242                exp.to_table("system.metadata.materialized_views", alias="mv"),
243                on=exp.and_(
244                    exp.column("catalog_name", table="mv").eq(
245                        exp.column("table_catalog", table="t")
246                    ),
247                    exp.column("schema_name", table="mv").eq(exp.column("table_schema", table="t")),
248                    exp.column("name", table="mv").eq(exp.column("table_name", table="t")),
249                ),
250                join_type="left",
251            )
252            .where(
253                exp.and_(
254                    exp.column("table_schema", table="t").eq(schema),
255                    exp.or_(
256                        exp.column("catalog_name", table="mv").is_(exp.null()),
257                        exp.column("catalog_name", table="mv").eq(catalog),
258                    ),
259                    exp.or_(
260                        exp.column("schema_name", table="mv").is_(exp.null()),
261                        exp.column("schema_name", table="mv").eq(schema),
262                    ),
263                )
264            )
265        )
266        if object_names:
267            query = query.where(exp.column("table_name", table="t").isin(*object_names))
268        df = self.fetchdf(query)
269        return [
270            DataObject(
271                catalog=row.catalog,  # type: ignore
272                schema=row.schema,  # type: ignore
273                name=row.name,  # type: ignore
274                type=DataObjectType.from_str(row.type),  # type: ignore
275            )
276            for row in df.itertuples()
277        ]
278
def _df_to_source_queries(
    self,
    df: DF,
    target_columns_to_types: t.Dict[str, exp.DataType],
    batch_size: int,
    target_table: TableName,
    source_columns: t.Optional[t.List[str]] = None,
) -> t.List[SourceQuery]:
    """Build source queries from a pandas DataFrame, reformatting tz-aware timestamps first.

    Timezone-aware datetime columns are rendered as strings using a space
    separator instead of the ISO "T" before delegating to the parent
    implementation. The caller's DataFrame is left untouched: the conversion
    is applied to a shallow copy.

    Args:
        df: The pandas DataFrame to load.
        target_columns_to_types: Mapping of target column names to their types.
        batch_size: Forwarded unchanged to the parent implementation.
        target_table: Forwarded unchanged to the parent implementation.
        source_columns: Optional list of the columns present in the DataFrame.

    Returns:
        The source queries produced by the parent implementation.
    """
    import pandas as pd
    from pandas.api.types import is_datetime64_any_dtype  # type: ignore

    assert isinstance(df, pd.DataFrame)
    source_columns_to_types = get_source_columns_to_types(
        target_columns_to_types, source_columns
    )

    # Trino does not accept timestamps in ISOFORMAT that include the "T". `execution_time` is stored in
    # Pandas with that format, so we convert the column to a string with the proper format and CAST to
    # timestamp in Trino.
    tz_aware_columns = [
        column
        for column in source_columns_to_types
        if is_datetime64_any_dtype(df.dtypes[column])
        and getattr(df.dtypes[column], "tz", None) is not None
    ]

    if tz_aware_columns:
        # A shallow copy shares the untouched column data but lets us swap in the
        # converted columns without mutating the DataFrame the caller passed in.
        df = df.copy(deep=False)
        for column in tz_aware_columns:
            df[column] = pd.to_datetime(df[column]).map(lambda x: x.isoformat(" "))

    return super()._df_to_source_queries(
        df, target_columns_to_types, batch_size, target_table, source_columns=source_columns
    )
306
def _build_schema_exp(
    self,
    table: exp.Table,
    target_columns_to_types: t.Dict[str, exp.DataType],
    column_descriptions: t.Optional[t.Dict[str, str]] = None,
    expressions: t.Optional[t.List[exp.PrimaryKey]] = None,
    is_view: bool = False,
    materialized: bool = False,
) -> exp.Schema:
    """Build the schema expression after adjusting timestamp column types.

    The adapter's timestamp mapping is applied first. When the table's catalog
    type contains "delta_lake", the remaining timestamp types are then
    converted with `_to_delta_ts`, skipping the columns the mapping already
    handled. The result is passed to the parent implementation.

    Args:
        table: The table the schema is built for.
        target_columns_to_types: Mapping of column names to their types.
        column_descriptions: Optional column comments.
        expressions: Optional additional schema expressions.
        is_view: Whether the schema is for a view.
        materialized: Whether the schema is for a materialized view.

    Returns:
        The schema expression produced by the parent implementation.
    """
    target_columns_to_types, mapped_columns = self._apply_timestamp_mapping(
        target_columns_to_types
    )
    if "delta_lake" in self.get_catalog_type_from_table(table):
        target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns)

    # `materialized` was previously accepted but never forwarded, so the parent
    # always saw its default. NOTE(review): assumes the parent signature declares
    # `materialized` like this override does - confirm against the base class.
    return super()._build_schema_exp(
        table,
        target_columns_to_types,
        column_descriptions,
        expressions,
        is_view,
        materialized=materialized,
    )
325
def _scd_type_2(
    self,
    target_table: TableName,
    source_table: QueryOrDF,
    unique_key: t.Sequence[exp.Expr],
    valid_from_col: exp.Column,
    valid_to_col: exp.Column,
    execution_time: t.Union[TimeLike, exp.Column],
    invalidate_hard_deletes: bool = True,
    updated_at_col: t.Optional[exp.Column] = None,
    check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expr]]] = None,
    updated_at_as_valid_from: bool = False,
    execution_time_as_valid_from: bool = False,
    target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
    table_description: t.Optional[str] = None,
    column_descriptions: t.Optional[t.Dict[str, str]] = None,
    truncate: bool = False,
    source_columns: t.Optional[t.List[str]] = None,
    **kwargs: t.Any,
) -> None:
    """Run the parent SCD type 2 logic with Trino-adjusted timestamp column types.

    When `target_columns_to_types` is provided, the adapter's timestamp mapping
    is applied to it. If the target table's catalog type contains "delta_lake",
    the types are additionally converted with `_to_delta_ts`, skipping the
    columns the mapping already handled. Every other argument is forwarded to
    the parent implementation unchanged.
    """
    already_mapped: t.Set[str] = set()

    if target_columns_to_types:
        target_columns_to_types, already_mapped = self._apply_timestamp_mapping(
            target_columns_to_types
        )
        # Re-check truthiness after the mapping so the catalog lookup only
        # happens when there are still column types to convert.
        if target_columns_to_types and "delta_lake" in self.get_catalog_type_from_table(
            target_table
        ):
            target_columns_to_types = self._to_delta_ts(
                target_columns_to_types, already_mapped
            )

    # Arguments are forwarded positionally, in the parent's declared order.
    forwarded = (
        target_table,
        source_table,
        unique_key,
        valid_from_col,
        valid_to_col,
        execution_time,
        invalidate_hard_deletes,
        updated_at_col,
        check_columns,
        updated_at_as_valid_from,
        execution_time_as_valid_from,
        target_columns_to_types,
        table_description,
        column_descriptions,
        truncate,
        source_columns,
    )
    return super()._scd_type_2(*forwarded, **kwargs)
375
# delta_lake only supports two timestamp data types. This method converts other
# timestamp types to those two for use in DDL statements. Trino/delta automatically
# converts the data values to the correct type on write, so we only need to handle
# the column types in DDL.
# - `timestamp(6)` for non-timezone-aware
# - `timestamp(3) with time zone` for timezone-aware
# https://trino.io/docs/current/connector/delta-lake.html#delta-lake-to-trino-type-mapping
def _to_delta_ts(
    self,
    columns_to_types: t.Dict[str, exp.DataType],
    skip_columns: t.Optional[t.Set[str]] = None,
) -> t.Dict[str, exp.DataType]:
    """Return a copy of the mapping with timestamp types replaced by delta_lake's two.

    Args:
        columns_to_types: Mapping of column names to their types.
        skip_columns: Names of columns whose types must be left as they are.

    Returns:
        A new dict in the same key order. TIMESTAMP columns become
        `timestamp(6)`, TIMESTAMPTZ columns become `timestamp(3) with time zone`,
        and skipped or non-timestamp columns keep their original type.
    """
    naive_ts = exp.DataType.build("timestamp(6)")
    aware_ts = exp.DataType.build("timestamp(3) with time zone")
    excluded = skip_columns or set()

    converted: t.Dict[str, exp.DataType] = {}
    for name, dtype in columns_to_types.items():
        if name in excluded:
            converted[name] = dtype
        elif dtype.is_type(exp.DataType.Type.TIMESTAMP):
            converted[name] = naive_ts
        elif dtype.is_type(exp.DataType.Type.TIMESTAMPTZ):
            converted[name] = aware_ts
        else:
            converted[name] = dtype

    return converted
403
@retry(wait=wait_fixed(1), stop=stop_after_attempt(10), retry=retry_if_result(lambda v: not v))
def _block_until_table_exists(self, table_name: TableName) -> bool:
    """Report whether `table_name` exists, retrying while the answer is falsy.

    The `retry` decorator re-invokes this method whenever it returns a falsy
    value, waiting a fixed 1 between attempts and stopping after 10 attempts.
    """
    exists = self.table_exists(table_name)
    return exists
407
def _create_schema(
    self,
    schema_name: SchemaName,
    ignore_if_exists: bool,
    warn_on_error: bool,
    properties: t.List[exp.Expr],
    kind: str,
) -> None:
    """Create a schema, adding a LOCATION property when a mapped location matches.

    Args:
        schema_name: The schema to create.
        ignore_if_exists: Forwarded unchanged to the parent implementation.
        warn_on_error: Forwarded unchanged to the parent implementation.
        properties: Schema properties; this list is not modified.
        kind: Forwarded unchanged to the parent implementation.
    """
    if mapped_location := self._schema_location(schema_name):
        # Build a new list instead of appending so the caller's list is not
        # mutated (repeated calls with the same list would otherwise
        # accumulate duplicate location properties).
        properties = [
            *properties,
            exp.LocationProperty(this=exp.Literal.string(mapped_location)),
        ]

    return super()._create_schema(
        schema_name=schema_name,
        ignore_if_exists=ignore_if_exists,
        warn_on_error=warn_on_error,
        properties=properties,
        kind=kind,
    )
426
def _create_table(
    self,
    table_name_or_schema: t.Union[exp.Schema, TableName],
    expression: t.Optional[exp.Expr],
    exists: bool = True,
    replace: bool = False,
    target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
    table_description: t.Optional[str] = None,
    column_descriptions: t.Optional[t.Dict[str, str]] = None,
    table_kind: t.Optional[str] = None,
    track_rows_processed: bool = True,
    **kwargs: t.Any,
) -> None:
    """Create the table via the parent, then wait for it to be visible on hive catalogs.

    All arguments are forwarded to the parent implementation unchanged. When
    the table's catalog type contains "hive", the method then blocks until
    `_block_until_table_exists` reports the table.
    """
    super()._create_table(
        table_name_or_schema=table_name_or_schema,
        expression=expression,
        exists=exists,
        replace=replace,
        target_columns_to_types=target_columns_to_types,
        table_description=table_description,
        column_descriptions=column_descriptions,
        table_kind=table_kind,
        track_rows_processed=track_rows_processed,
        **kwargs,
    )

    # A Schema wraps the table it describes; unwrap it to get the table name.
    created_table = table_name_or_schema
    if isinstance(created_table, exp.Schema):
        created_table = created_table.this
        assert isinstance(created_table, exp.Table)

    if "hive" not in self.get_catalog_type_from_table(created_table):
        return

    # the Trino Hive connector can take a few seconds for metadata changes to propagate to all internal threads
    # (even if metadata TTL is set to 0s)
    # Blocking until the table shows up means that subsequent code expecting it to exist immediately will not fail
    self._block_until_table_exists(created_table)
465
466    def _schema_location(self, schema_name: SchemaName) -> t.Optional[str]:
467        if mapping := self.schema_location_mapping:
468            schema = to_schema(schema_name)
469            match_key = schema.db
470
471            # only consider the catalog if it is present
472            if schema.catalog:
473                match_key = f"{schema.catalog}.{match_key}"
474
475            for k, v in mapping.items():
476                if re.match(k, match_key):
477                    return v.replace("@{schema_name}", schema.db).replace(
478                        "@{catalog_name}", schema.catalog
479                    )
480        return None

Base class wrapping a Database API compliant connection.

The EngineAdapter is an easily-subclassable interface that interacts with the underlying engine and data store.

Arguments:
  • connection_factory_or_pool: a callable which produces a new Database API-compliant connection on every call.
  • dialect: The dialect with which this adapter is associated.
  • multithreaded: Indicates whether this adapter will be used by more than one thread.
DIALECT = 'trino'
INSERT_OVERWRITE_STRATEGY = <InsertOverwriteStrategy.INTO_IS_OVERWRITE: 4>
SUPPORTS_TRANSACTIONS = False
CURRENT_CATALOG_EXPRESSION = Column( this=Identifier(this=current_catalog, quoted=False))
COMMENT_CREATION_TABLE = <CommentCreationTable.IN_SCHEMA_DEF_NO_CTAS: 3>
COMMENT_CREATION_VIEW = <CommentCreationView.COMMENT_COMMAND_ONLY: 4>
SUPPORTS_REPLACE_TABLE = False
SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ['SCHEMA']
DEFAULT_CATALOG_TYPE = 'hive'
QUOTE_IDENTIFIERS_IN_VIEWS = False
SUPPORTS_QUERY_EXECUTION_TRACKING = True
SCHEMA_DIFFER_KWARGS = {'parameterized_type_defaults': {<DType.DECIMAL: 'DECIMAL'>: [(), (0,)], <DType.CHAR: 'CHAR'>: [(1,)], <DType.TIMESTAMP: 'TIMESTAMP'>: [(3,)]}}
MAX_TIMESTAMP_PRECISION = 3
schema_location_mapping: Optional[Dict[re.Pattern, str]]
@property
def schema_location_mapping(self) -> t.Optional[t.Dict[re.Pattern, str]]:
    """The "schema_location_mapping" entry looked up in the adapter's extra config."""
    extra_config = self._extra_config
    return extra_config.get("schema_location_mapping")
timestamp_mapping: Optional[Dict[sqlglot.expressions.datatypes.DataType, sqlglot.expressions.datatypes.DataType]]
@property
def timestamp_mapping(self) -> t.Optional[t.Dict[exp.DataType, exp.DataType]]:
    """The "timestamp_mapping" entry looked up in the adapter's extra config."""
    extra_config = self._extra_config
    return extra_config.get("timestamp_mapping")
@property
def catalog_support(self) -> CatalogSupport:
    """The catalog support level of this adapter: always `CatalogSupport.FULL_SUPPORT`."""
    support_level = CatalogSupport.FULL_SUPPORT
    return support_level
def set_current_catalog(self, catalog: str) -> None:
def set_current_catalog(self, catalog: str) -> None:
    """Sets the catalog name of the current connection.

    Executes a USE statement that targets the `information_schema` schema of
    the given catalog.
    """
    target_schema = schema_(db="information_schema", catalog=catalog)
    use_statement = exp.Use(this=target_schema)
    self.execute(use_statement)

Sets the catalog name of the current connection.

@lru_cache()
def get_catalog_type(self, catalog: Optional[str]) -> str:
@lru_cache()
def get_catalog_type(self, catalog: t.Optional[str]) -> str:
    """Return the connector type of a catalog.

    Resolution order: a configured override for the catalog, then the
    `connector_name` reported by `system.metadata.catalogs`, then
    `DEFAULT_CATALOG_TYPE`. A falsy `catalog` goes straight to the default.

    Args:
        catalog: The catalog name, or None.

    Returns:
        The catalog type as a string.
    """
    # NOTE(review): `lru_cache` on a method keys on `self` as well, so the cache
    # holds a reference to every adapter instance it has seen.
    row: t.Tuple = tuple()
    if catalog:
        if catalog_type_override := self._catalog_type_overrides.get(catalog):
            return catalog_type_override
        # Render the name as a SQL string literal so that embedded quote
        # characters are escaped instead of terminating the literal early.
        catalog_literal = exp.Literal.string(catalog).sql(dialect=self.dialect)
        row = (
            self.fetchone(
                f"select connector_name from system.metadata.catalogs where catalog_name={catalog_literal}"
            )
            or ()
        )
    return seq_get(row, 0) or self.DEFAULT_CATALOG_TYPE

Intended to be overridden for data virtualization systems like Trino that, depending on the target catalog, require slightly different properties to be set when creating / updating tables.

@contextlib.contextmanager
def session( self, properties: Dict[str, sqlglot.expressions.core.Expr | str | int | float | bool]) -> Iterator[NoneType]:
@contextlib.contextmanager
def session(self, properties: SessionProperties) -> t.Iterator[None]:
    """A session context manager that optionally switches the session authorization.

    When `properties` has a truthy "authorization" entry, SET SESSION
    AUTHORIZATION is executed on entry and RESET SESSION AUTHORIZATION on exit,
    even if the body raises. Otherwise the context does nothing.

    Raises:
        SQLMeshError: If the authorization value is an expression that is not
            a string literal.
    """
    requested = properties.get("authorization")
    if requested:
        # Plain values are wrapped in a string literal; expressions are used as given.
        literal = requested if isinstance(requested, exp.Expr) else exp.Literal.string(requested)

        if not literal.is_string:
            raise SQLMeshError(
                "Invalid value for `session_properties.authorization`. Must be a string literal."
            )

        rendered = literal.sql(dialect=self.dialect)
        self.execute(f"SET SESSION AUTHORIZATION {rendered}")
        try:
            yield
        finally:
            self.execute("RESET SESSION AUTHORIZATION")
    else:
        yield

A session context manager.

def replace_query( self, table_name: Union[str, sqlglot.expressions.query.Table], query_or_df: QueryOrDF, target_columns_to_types: Optional[Dict[str, sqlglot.expressions.datatypes.DataType]] = None, table_description: Optional[str] = None, column_descriptions: Optional[Dict[str, str]] = None, source_columns: Optional[List[str]] = None, supports_replace_table_override: Optional[bool] = None, **kwargs: Any) -> None:
def replace_query(
    self,
    table_name: TableName,
    query_or_df: QueryOrDF,
    target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
    table_description: t.Optional[str] = None,
    column_descriptions: t.Optional[t.Dict[str, str]] = None,
    source_columns: t.Optional[t.List[str]] = None,
    supports_replace_table_override: t.Optional[bool] = None,
    **kwargs: t.Any,
) -> None:
    """Replace an existing table with a query, enabling replace-table per catalog type.

    Args:
        table_name: The name of the table (eg. prod.table).
        query_or_df: The SQL query to run or a dataframe.
        target_columns_to_types: Mapping between column names and their types.
        table_description: Optional table comment.
        column_descriptions: Optional column comments.
        source_columns: Optional list of the columns present in the source.
        supports_replace_table_override: Explicit override forwarded to the
            parent. When None, it is set to True if the table's catalog type
            contains one of `CATALOG_TYPES_SUPPORTING_REPLACE_TABLE`, and
            left as None otherwise.
        kwargs: Optional create table properties.
    """
    # Only derive the override when the caller did not supply one; previously an
    # explicit value was unconditionally discarded.
    if supports_replace_table_override is None:
        catalog_type = self.get_catalog_type_from_table(table_name)
        # User may have a custom catalog type name so we are assuming they keep the catalog type still in the name
        # Ex: `acme_iceberg` would be identified as an iceberg catalog and therefore supports replace table
        if any(
            replace_table_catalog_type in catalog_type
            for replace_table_catalog_type in CATALOG_TYPES_SUPPORTING_REPLACE_TABLE
        ):
            supports_replace_table_override = True

    super().replace_query(
        table_name=table_name,
        query_or_df=query_or_df,
        target_columns_to_types=target_columns_to_types,
        table_description=table_description,
        column_descriptions=column_descriptions,
        source_columns=source_columns,
        supports_replace_table_override=supports_replace_table_override,
        **kwargs,
    )

Replaces an existing table with a query.

For partition-based engines (hive, spark), insert overwrite is used. For other systems, create or replace is used.

Arguments:
  • table_name: The name of the table (eg. prod.table)
  • query_or_df: The SQL query to run or a dataframe.
  • target_columns_to_types: Only used if a dataframe is provided. A mapping between the column name and its data type. Expected to be ordered to match the order of values in the dataframe.
  • kwargs: Optional create table properties.
Inherited Members
sqlmesh.core.engine_adapter.base.EngineAdapter
EngineAdapter
DEFAULT_BATCH_SIZE
DATA_OBJECT_FILTER_BATCH_SIZE
SUPPORTS_INDEXES
SUPPORTS_MATERIALIZED_VIEWS
SUPPORTS_MATERIALIZED_VIEW_SCHEMA
SUPPORTS_VIEW_SCHEMA
SUPPORTS_CLONING
SUPPORTS_MANAGED_MODELS
SUPPORTS_CREATE_DROP_CATALOG
SUPPORTS_TUPLE_IN
HAS_VIEW_BINDING
SUPPORTS_GRANTS
MAX_IDENTIFIER_LENGTH
ATTACH_CORRELATION_ID
SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS
dialect
correlation_id
with_settings
cursor
connection
spark
snowpark
bigframe
comments_enabled
schema_differ
default_catalog
engine_run_mode
recycle
close
get_catalog_type_from_table
current_catalog_type
create_index
create_table
create_managed_table
ctas
create_state_table
create_table_like
clone_table
drop_data_object
drop_table
drop_managed_table
get_alter_operations
alter_table
create_view
create_schema
drop_schema
drop_view
create_catalog
drop_catalog
columns
table_exists
delete_from
insert_append
insert_overwrite_by_partition
insert_overwrite_by_time_partition
update_table
scd_type_2_by_time
scd_type_2_by_column
merge
rename_table
get_data_object
get_data_objects
fetchone
fetchall
fetchdf
fetch_pyspark_df
wap_enabled
wap_supported
wap_table_name
wap_prepare
wap_publish
sync_grants_config
transaction
execute
temp_table
drop_data_object_on_type_mismatch
ensure_nulls_for_unmatched_after_join
use_server_nulls_for_unmatched_after_join
ping
get_table_last_modified_ts
sqlmesh.core.engine_adapter.mixins.HiveMetastoreTablePropertiesMixin
MAX_TABLE_COMMENT_LENGTH
MAX_COLUMN_COMMENT_LENGTH
sqlmesh.core.engine_adapter.mixins.GetCurrentCatalogFromFunctionMixin
get_current_catalog
sqlmesh.core.engine_adapter.mixins.RowDiffMixin
concat_columns
normalize_value