sqlmesh.core.engine_adapter.trino
from __future__ import annotations

import contextlib
import re
import typing as t
from functools import lru_cache

from sqlglot import exp
from sqlglot.helper import seq_get
from tenacity import retry, wait_fixed, stop_after_attempt, retry_if_result

from sqlmesh.core.dialect import schema_, to_schema
from sqlmesh.core.engine_adapter.mixins import (
    GetCurrentCatalogFromFunctionMixin,
    HiveMetastoreTablePropertiesMixin,
    PandasNativeFetchDFSupportMixin,
    RowDiffMixin,
)
from sqlmesh.core.engine_adapter.shared import (
    CatalogSupport,
    CommentCreationTable,
    CommentCreationView,
    DataObject,
    DataObjectType,
    InsertOverwriteStrategy,
    SourceQuery,
    set_catalog,
)
from sqlmesh.utils import get_source_columns_to_types
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.utils.date import TimeLike

if t.TYPE_CHECKING:
    from sqlmesh.core._typing import SchemaName, SessionProperties, TableName
    from sqlmesh.core.engine_adapter._typing import DF, QueryOrDF

# Substrings of a catalog type for which `replace_query` asks the base adapter to
# use table replacement (see `TrinoEngineAdapter.replace_query`).
CATALOG_TYPES_SUPPORTING_REPLACE_TABLE = {"iceberg", "delta_lake"}


@set_catalog()
class TrinoEngineAdapter(
    PandasNativeFetchDFSupportMixin,
    HiveMetastoreTablePropertiesMixin,
    GetCurrentCatalogFromFunctionMixin,
    RowDiffMixin,
):
    """Engine adapter for Trino.

    Several operations branch on the type of the catalog a table lives in
    (e.g. ``hive``, ``iceberg``, ``delta_lake``), because the corresponding
    connectors differ in the statements, session properties and timestamp
    types they accept.
    """

    DIALECT = "trino"
    INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.INTO_IS_OVERWRITE
    # Trino does technically support transactions but it doesn't work correctly with partition overwrite so we
    # disable transactions. If we need to get them enabled again then we would need to disable auto commit on the
    # connector and then figure out how to get insert/overwrite to work correctly without it.
    SUPPORTS_TRANSACTIONS = False
    CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog")
    COMMENT_CREATION_TABLE = CommentCreationTable.IN_SCHEMA_DEF_NO_CTAS
    COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY
    SUPPORTS_REPLACE_TABLE = False
    SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"]
    DEFAULT_CATALOG_TYPE = "hive"
    QUOTE_IDENTIFIERS_IN_VIEWS = False
    SUPPORTS_QUERY_EXECUTION_TRACKING = True
    SCHEMA_DIFFER_KWARGS = {
        "parameterized_type_defaults": {
            # default decimal precision varies across backends
            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(), (0,)],
            exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)],
            exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(3,)],
        },
    }
    # some catalogs support microsecond (precision 6) but it has to be specifically enabled (Hive) or just
    # isn't available (Delta / TIMESTAMP WITH TIME ZONE)
    # and even if you have a TIMESTAMP(6) the date formatting functions still only support millisecond precision
    MAX_TIMESTAMP_PRECISION = 3

    @property
    def schema_location_mapping(self) -> t.Optional[t.Dict[re.Pattern, str]]:
        """The ``schema_location_mapping`` entry of the extra config, or None if unset."""
        return self._extra_config.get("schema_location_mapping")

    @property
    def timestamp_mapping(self) -> t.Optional[t.Dict[exp.DataType, exp.DataType]]:
        """The ``timestamp_mapping`` entry of the extra config, or None if unset."""
        return self._extra_config.get("timestamp_mapping")

    def _apply_timestamp_mapping(
        self, columns_to_types: t.Dict[str, exp.DataType]
    ) -> t.Tuple[t.Dict[str, exp.DataType], t.Set[str]]:
        """Apply custom timestamp mapping to column types.

        Args:
            columns_to_types: Column name -> data type mapping to translate.

        Returns:
            A tuple of (mapped_columns_to_types, mapped_column_names) where mapped_column_names
            contains the names of columns that were found in the mapping. When no mapping is
            configured, the input dict itself is returned together with an empty set.
        """
        if not self.timestamp_mapping:
            return columns_to_types, set()

        result = {}
        mapped_columns: t.Set[str] = set()
        for column, column_type in columns_to_types.items():
            if column_type in self.timestamp_mapping:
                result[column] = self.timestamp_mapping[column_type]
                mapped_columns.add(column)
            else:
                result[column] = column_type
        return result, mapped_columns

    @property
    def catalog_support(self) -> CatalogSupport:
        """Level of multi-catalog support; always ``CatalogSupport.FULL_SUPPORT``."""
        return CatalogSupport.FULL_SUPPORT

    def set_current_catalog(self, catalog: str) -> None:
        """Sets the catalog name of the current connection."""
        self.execute(exp.Use(this=schema_(db="information_schema", catalog=catalog)))

    # NOTE(review): `lru_cache` on an instance method keys the cache on `self` and keeps
    # every adapter instance alive for as long as the cache exists (ruff B019). Left as is
    # because other code may rely on the cache attributes of this method - confirm.
    @lru_cache()
    def get_catalog_type(self, catalog: t.Optional[str]) -> str:
        """Return the type (connector name) of the given catalog.

        An entry in ``self._catalog_type_overrides`` takes precedence. Otherwise the
        connector name is looked up in ``system.metadata.catalogs``. If no catalog is
        given, or the lookup yields nothing, ``DEFAULT_CATALOG_TYPE`` is returned.
        """
        row: t.Tuple = tuple()
        if catalog:
            if catalog_type_override := self._catalog_type_overrides.get(catalog):
                return catalog_type_override
            # NOTE(review): the catalog name is interpolated into the SQL text unescaped;
            # presumably it only ever comes from trusted configuration - confirm.
            row = (
                self.fetchone(
                    f"select connector_name from system.metadata.catalogs where catalog_name='{catalog}'"
                )
                or ()
            )
        return seq_get(row, 0) or self.DEFAULT_CATALOG_TYPE

    @contextlib.contextmanager
    def session(self, properties: SessionProperties) -> t.Iterator[None]:
        """Run the enclosed block under the ``authorization`` session property, if present.

        When ``properties`` has no truthy ``authorization`` entry this is a no-op.
        Otherwise ``SET SESSION AUTHORIZATION`` is executed on entry and
        ``RESET SESSION AUTHORIZATION`` on exit, even if the block raises.

        Raises:
            SQLMeshError: If the authorization value is not a string literal.
        """
        authorization = properties.get("authorization")
        if not authorization:
            yield
            return

        if not isinstance(authorization, exp.Expr):
            authorization = exp.Literal.string(authorization)

        if not authorization.is_string:
            raise SQLMeshError(
                "Invalid value for `session_properties.authorization`. Must be a string literal."
            )

        authorization_sql = authorization.sql(dialect=self.dialect)

        self.execute(f"SET SESSION AUTHORIZATION {authorization_sql}")
        try:
            yield
        finally:
            self.execute("RESET SESSION AUTHORIZATION")

    def replace_query(
        self,
        table_name: TableName,
        query_or_df: QueryOrDF,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        table_description: t.Optional[str] = None,
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        source_columns: t.Optional[t.List[str]] = None,
        supports_replace_table_override: t.Optional[bool] = None,
        **kwargs: t.Any,
    ) -> None:
        """Replace the table contents, enabling table replacement for supporting catalogs.

        Delegates to the base implementation with ``supports_replace_table_override=True``
        when the table's catalog type contains one of
        ``CATALOG_TYPES_SUPPORTING_REPLACE_TABLE``, and ``None`` otherwise.
        """
        catalog_type = self.get_catalog_type_from_table(table_name)
        # User may have a custom catalog type name so we are assuming they keep the catalog type still in the name
        # Ex: `acme_iceberg` would be identified as an iceberg catalog and therefore supports replace table
        # NOTE(review): any value the caller passed for `supports_replace_table_override`
        # is discarded here - confirm that this is intentional.
        supports_replace_table_override = None
        for replace_table_catalog_type in CATALOG_TYPES_SUPPORTING_REPLACE_TABLE:
            if replace_table_catalog_type in catalog_type:
                supports_replace_table_override = True
                break

        super().replace_query(
            table_name=table_name,
            query_or_df=query_or_df,
            target_columns_to_types=target_columns_to_types,
            table_description=table_description,
            column_descriptions=column_descriptions,
            source_columns=source_columns,
            supports_replace_table_override=supports_replace_table_override,
            **kwargs,
        )

    def _insert_overwrite_by_condition(
        self,
        table_name: TableName,
        source_queries: t.List[SourceQuery],
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        where: t.Optional[exp.Condition] = None,
        insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None,
        **kwargs: t.Any,
    ) -> None:
        """Insert-overwrite into a table, choosing the strategy by catalog type.

        With a ``where`` condition on a ``hive`` catalog, the insert runs with the
        catalog's ``insert_existing_partitions_behavior`` session property set to
        ``OVERWRITE``; the property is set back to ``APPEND`` afterwards, even if the
        insert fails. In every other case the base implementation is called with the
        ``DELETE_INSERT`` strategy.
        """
        catalog = exp.to_table(table_name).catalog or self.get_current_catalog()

        if where and self.get_catalog_type(catalog) == "hive":
            # These session properties are only valid for the Trino Hive connector
            # Attempting to set them on an Iceberg catalog will throw an error:
            # "Session property 'catalog.insert_existing_partitions_behavior' does not exist"
            self.execute(f"SET SESSION {catalog}.insert_existing_partitions_behavior='OVERWRITE'")
            try:
                super()._insert_overwrite_by_condition(
                    table_name, source_queries, target_columns_to_types, where
                )
            finally:
                # Always restore the property so that a failed insert does not leave the
                # session silently overwriting partitions on later inserts.
                self.execute(f"SET SESSION {catalog}.insert_existing_partitions_behavior='APPEND'")
        else:
            super()._insert_overwrite_by_condition(
                table_name,
                source_queries,
                target_columns_to_types,
                where,
                insert_overwrite_strategy_override=InsertOverwriteStrategy.DELETE_INSERT,
            )

    def _truncate_table(self, table_name: TableName) -> None:
        """Remove all rows from the table using ``DELETE FROM``."""
        table = exp.to_table(table_name)
        # Some trino connectors don't support truncate so we use delete.
        self.execute(f"DELETE FROM {table.sql(dialect=self.dialect, identify=True)}")

    def _get_data_objects(
        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
    ) -> t.List[DataObject]:
        """
        Returns all the data objects that exist in the given schema and optionally catalog.

        Objects are read from ``<catalog>.information_schema.tables`` and left-joined with
        ``system.metadata.materialized_views`` so that materialized views are reported with
        the type ``materialized_view``. If ``object_names`` is non-empty, only objects with
        those names are returned.
        """
        schema_name = to_schema(schema_name)
        schema = schema_name.db
        catalog = schema_name.catalog or self.get_current_catalog()
        query = (
            exp.select(
                exp.column("table_catalog", table="t").as_("catalog"),
                exp.column("table_schema", table="t").as_("schema"),
                exp.column("table_name", table="t").as_("name"),
                exp.case()
                .when(
                    exp.column("name", table="mv").is_(exp.null()).not_(),
                    exp.Literal.string("materialized_view"),
                )
                .when(
                    exp.column("table_type", table="t").eq("BASE TABLE"),
                    exp.Literal.string("table"),
                )
                .else_(exp.column("table_type", table="t"))
                .as_("type"),
            )
            .from_(exp.to_table(f"{catalog}.information_schema.tables", alias="t"))
            .join(
                exp.to_table("system.metadata.materialized_views", alias="mv"),
                on=exp.and_(
                    exp.column("catalog_name", table="mv").eq(
                        exp.column("table_catalog", table="t")
                    ),
                    exp.column("schema_name", table="mv").eq(exp.column("table_schema", table="t")),
                    exp.column("name", table="mv").eq(exp.column("table_name", table="t")),
                ),
                join_type="left",
            )
            .where(
                exp.and_(
                    exp.column("table_schema", table="t").eq(schema),
                    exp.or_(
                        exp.column("catalog_name", table="mv").is_(exp.null()),
                        exp.column("catalog_name", table="mv").eq(catalog),
                    ),
                    exp.or_(
                        exp.column("schema_name", table="mv").is_(exp.null()),
                        exp.column("schema_name", table="mv").eq(schema),
                    ),
                )
            )
        )
        if object_names:
            query = query.where(exp.column("table_name", table="t").isin(*object_names))
        df = self.fetchdf(query)
        return [
            DataObject(
                catalog=row.catalog,  # type: ignore
                schema=row.schema,  # type: ignore
                name=row.name,  # type: ignore
                type=DataObjectType.from_str(row.type),  # type: ignore
            )
            for row in df.itertuples()
        ]

    def _df_to_source_queries(
        self,
        df: DF,
        target_columns_to_types: t.Dict[str, exp.DataType],
        batch_size: int,
        target_table: TableName,
        source_columns: t.Optional[t.List[str]] = None,
    ) -> t.List[SourceQuery]:
        """Convert a pandas DataFrame into source queries.

        Timezone-aware datetime columns are rewritten as strings with a space separator
        before the DataFrame is handed to the base implementation.
        """
        import pandas as pd
        from pandas.api.types import is_datetime64_any_dtype  # type: ignore

        assert isinstance(df, pd.DataFrame)
        source_columns_to_types = get_source_columns_to_types(
            target_columns_to_types, source_columns
        )

        # Trino does not accept timestamps in ISOFORMAT that include the "T". `execution_time` is stored in
        # Pandas with that format, so we convert the column to a string with the proper format and CAST to
        # timestamp in Trino.
        # NOTE(review): this rewrites the columns of the caller's DataFrame in place.
        for column in source_columns_to_types:
            dtype = df.dtypes[column]
            if is_datetime64_any_dtype(dtype) and getattr(dtype, "tz", None) is not None:
                df[column] = pd.to_datetime(df[column]).map(lambda x: x.isoformat(" "))

        return super()._df_to_source_queries(
            df, target_columns_to_types, batch_size, target_table, source_columns=source_columns
        )

    def _build_schema_exp(
        self,
        table: exp.Table,
        target_columns_to_types: t.Dict[str, exp.DataType],
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        expressions: t.Optional[t.List[exp.PrimaryKey]] = None,
        is_view: bool = False,
        materialized: bool = False,
    ) -> exp.Schema:
        """Build the schema expression after translating timestamp column types.

        The configured timestamp mapping is applied first. For ``delta_lake`` catalogs
        the remaining timestamp columns are then converted with ``_to_delta_ts``.
        """
        target_columns_to_types, mapped_columns = self._apply_timestamp_mapping(
            target_columns_to_types
        )
        if "delta_lake" in self.get_catalog_type_from_table(table):
            target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns)

        # NOTE(review): `materialized` is accepted but not forwarded to the base
        # implementation - confirm whether it should be passed through.
        return super()._build_schema_exp(
            table, target_columns_to_types, column_descriptions, expressions, is_view
        )

    def _scd_type_2(
        self,
        target_table: TableName,
        source_table: QueryOrDF,
        unique_key: t.Sequence[exp.Expr],
        valid_from_col: exp.Column,
        valid_to_col: exp.Column,
        execution_time: t.Union[TimeLike, exp.Column],
        invalidate_hard_deletes: bool = True,
        updated_at_col: t.Optional[exp.Column] = None,
        check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expr]]] = None,
        updated_at_as_valid_from: bool = False,
        execution_time_as_valid_from: bool = False,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        table_description: t.Optional[str] = None,
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        truncate: bool = False,
        source_columns: t.Optional[t.List[str]] = None,
        **kwargs: t.Any,
    ) -> None:
        """Run the base SCD type 2 logic with Trino-compatible timestamp column types.

        If ``target_columns_to_types`` is provided, the configured timestamp mapping is
        applied and, for ``delta_lake`` catalogs, the remaining timestamp columns are
        converted with ``_to_delta_ts``. All other arguments are passed through unchanged.
        """
        mapped_columns: t.Set[str] = set()
        if target_columns_to_types:
            target_columns_to_types, mapped_columns = self._apply_timestamp_mapping(
                target_columns_to_types
            )
        if target_columns_to_types and "delta_lake" in self.get_catalog_type_from_table(
            target_table
        ):
            target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns)

        return super()._scd_type_2(
            target_table,
            source_table,
            unique_key,
            valid_from_col,
            valid_to_col,
            execution_time,
            invalidate_hard_deletes,
            updated_at_col,
            check_columns,
            updated_at_as_valid_from,
            execution_time_as_valid_from,
            target_columns_to_types,
            table_description,
            column_descriptions,
            truncate,
            source_columns,
            **kwargs,
        )

    # delta_lake only supports two timestamp data types. This method converts other
    # timestamp types to those two for use in DDL statements. Trino/delta automatically
    # converts the data values to the correct type on write, so we only need to handle
    # the column types in DDL.
    # - `timestamp(6)` for non-timezone-aware
    # - `timestamp(3) with time zone` for timezone-aware
    # https://trino.io/docs/current/connector/delta-lake.html#delta-lake-to-trino-type-mapping
    def _to_delta_ts(
        self,
        columns_to_types: t.Dict[str, exp.DataType],
        skip_columns: t.Optional[t.Set[str]] = None,
    ) -> t.Dict[str, exp.DataType]:
        """Return a copy of the mapping with timestamp types converted for delta_lake.

        Args:
            columns_to_types: Column name -> data type mapping.
            skip_columns: Names of columns whose types must be left untouched.

        Returns:
            A new dict in which ``TIMESTAMP`` columns become ``timestamp(6)`` and
            ``TIMESTAMPTZ`` columns become ``timestamp(3) with time zone``.
        """
        ts6 = exp.DataType.build("timestamp(6)")
        ts3_tz = exp.DataType.build("timestamp(3) with time zone")
        skip = skip_columns or set()

        def _convert(name: str, data_type: exp.DataType) -> exp.DataType:
            if name in skip:
                return data_type
            if data_type.is_type(exp.DataType.Type.TIMESTAMP):
                return ts6
            if data_type.is_type(exp.DataType.Type.TIMESTAMPTZ):
                return ts3_tz
            return data_type

        return {name: _convert(name, data_type) for name, data_type in columns_to_types.items()}

    # Re-checks once per second, for at most 10 attempts, while the result is falsy.
    @retry(wait=wait_fixed(1), stop=stop_after_attempt(10), retry=retry_if_result(lambda v: not v))
    def _block_until_table_exists(self, table_name: TableName) -> bool:
        """Return whether the table exists; retried by the decorator until it does."""
        return self.table_exists(table_name)

    def _create_schema(
        self,
        schema_name: SchemaName,
        ignore_if_exists: bool,
        warn_on_error: bool,
        properties: t.List[exp.Expr],
        kind: str,
    ) -> None:
        """Create a schema, adding a LOCATION property when a location mapping matches."""
        if mapped_location := self._schema_location(schema_name):
            # Build a new list instead of appending so the caller's list is not mutated.
            properties = [
                *properties,
                exp.LocationProperty(this=exp.Literal.string(mapped_location)),
            ]

        return super()._create_schema(
            schema_name=schema_name,
            ignore_if_exists=ignore_if_exists,
            warn_on_error=warn_on_error,
            properties=properties,
            kind=kind,
        )

    def _create_table(
        self,
        table_name_or_schema: t.Union[exp.Schema, TableName],
        expression: t.Optional[exp.Expr],
        exists: bool = True,
        replace: bool = False,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        table_description: t.Optional[str] = None,
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        table_kind: t.Optional[str] = None,
        track_rows_processed: bool = True,
        **kwargs: t.Any,
    ) -> None:
        """Create the table and, for ``hive`` catalogs, wait until it is visible."""
        super()._create_table(
            table_name_or_schema=table_name_or_schema,
            expression=expression,
            exists=exists,
            replace=replace,
            target_columns_to_types=target_columns_to_types,
            table_description=table_description,
            column_descriptions=column_descriptions,
            table_kind=table_kind,
            track_rows_processed=track_rows_processed,
            **kwargs,
        )

        # extract the table name
        if isinstance(table_name_or_schema, exp.Schema):
            table_name = table_name_or_schema.this
            assert isinstance(table_name, exp.Table)
        else:
            table_name = table_name_or_schema

        if "hive" in self.get_catalog_type_from_table(table_name):
            # the Trino Hive connector can take a few seconds for metadata changes to propagate to all internal threads
            # (even if metadata TTL is set to 0s)
            # Blocking until the table shows up means that subsequent code expecting it to exist immediately will not fail
            self._block_until_table_exists(table_name)

    def _schema_location(self, schema_name: SchemaName) -> t.Optional[str]:
        """Resolve the location configured for a schema, if any.

        The match key is ``<catalog>.<schema>`` when the schema name carries a catalog and
        ``<schema>`` otherwise. The first pattern in ``schema_location_mapping`` that
        matches the key (via ``re.match``) wins; ``@{schema_name}`` and ``@{catalog_name}``
        in its value are replaced with the schema and catalog names.

        Returns:
            The resolved location, or None if no mapping is configured or nothing matches.
        """
        if mapping := self.schema_location_mapping:
            schema = to_schema(schema_name)
            match_key = schema.db

            # only consider the catalog if it is present
            if schema.catalog:
                match_key = f"{schema.catalog}.{match_key}"

            for k, v in mapping.items():
                if re.match(k, match_key):
                    return v.replace("@{schema_name}", schema.db).replace(
                        "@{catalog_name}", schema.catalog
                    )
        return None
CATALOG_TYPES_SUPPORTING_REPLACE_TABLE =
{'iceberg', 'delta_lake'}
@set_catalog()
class TrinoEngineAdapter(
    PandasNativeFetchDFSupportMixin,
    HiveMetastoreTablePropertiesMixin,
    GetCurrentCatalogFromFunctionMixin,
    RowDiffMixin,
):
    """Engine adapter for the Trino dialect.

    Several operations branch on the connector type of the target catalog (see
    ``get_catalog_type``): hive, iceberg and delta_lake catalogs each get slightly different
    handling for table replacement, insert-overwrite, timestamp types and table creation.
    """

    DIALECT = "trino"
    INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.INTO_IS_OVERWRITE
    # Trino does technically support transactions but it doesn't work correctly with partition overwrite so we
    # disable transactions. If we need to get them enabled again then we would need to disable auto commit on the
    # connector and then figure out how to get insert/overwrite to work correctly without it.
    SUPPORTS_TRANSACTIONS = False
    CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog")
    COMMENT_CREATION_TABLE = CommentCreationTable.IN_SCHEMA_DEF_NO_CTAS
    COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY
    SUPPORTS_REPLACE_TABLE = False
    SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"]
    DEFAULT_CATALOG_TYPE = "hive"
    QUOTE_IDENTIFIERS_IN_VIEWS = False
    SUPPORTS_QUERY_EXECUTION_TRACKING = True
    SCHEMA_DIFFER_KWARGS = {
        "parameterized_type_defaults": {
            # default decimal precision varies across backends
            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(), (0,)],
            exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)],
            exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(3,)],
        },
    }
    # some catalogs support microsecond (precision 6) but it has to be specifically enabled (Hive) or just isn't
    # available (Delta / TIMESTAMP WITH TIME ZONE)
    # and even if you have a TIMESTAMP(6) the date formatting functions still only support millisecond precision
    MAX_TIMESTAMP_PRECISION = 3

    @property
    def schema_location_mapping(self) -> t.Optional[t.Dict[re.Pattern, str]]:
        """The ``schema_location_mapping`` entry of the adapter's extra config, if set."""
        return self._extra_config.get("schema_location_mapping")

    @property
    def timestamp_mapping(self) -> t.Optional[t.Dict[exp.DataType, exp.DataType]]:
        """The ``timestamp_mapping`` entry of the adapter's extra config, if set."""
        return self._extra_config.get("timestamp_mapping")

    def _apply_timestamp_mapping(
        self, columns_to_types: t.Dict[str, exp.DataType]
    ) -> t.Tuple[t.Dict[str, exp.DataType], t.Set[str]]:
        """Apply custom timestamp mapping to column types.

        Returns:
            A tuple of (mapped_columns_to_types, mapped_column_names) where mapped_column_names
            contains the names of columns that were found in the mapping.
        """
        if not self.timestamp_mapping:
            # No mapping configured: hand back the input dict itself and an empty set.
            return columns_to_types, set()

        result = {}
        mapped_columns: t.Set[str] = set()
        for column, column_type in columns_to_types.items():
            if column_type in self.timestamp_mapping:
                result[column] = self.timestamp_mapping[column_type]
                mapped_columns.add(column)
            else:
                result[column] = column_type
        return result, mapped_columns

    @property
    def catalog_support(self) -> CatalogSupport:
        """Always ``CatalogSupport.FULL_SUPPORT`` for this adapter."""
        return CatalogSupport.FULL_SUPPORT

    def set_current_catalog(self, catalog: str) -> None:
        """Sets the catalog name of the current connection."""
        self.execute(exp.Use(this=schema_(db="information_schema", catalog=catalog)))

    @lru_cache()
    def get_catalog_type(self, catalog: t.Optional[str]) -> str:
        """Return the connector type backing ``catalog``.

        Resolution order: a configured override for the catalog, then the ``connector_name``
        reported by ``system.metadata.catalogs``, then ``DEFAULT_CATALOG_TYPE``. A falsy
        ``catalog`` skips straight to the default.

        Results are memoised by ``lru_cache``, keyed on ``(self, catalog)``.
        """
        row: t.Tuple = tuple()
        if catalog:
            if catalog_type_override := self._catalog_type_overrides.get(catalog):
                return catalog_type_override
            # Double any single quote so the name cannot terminate the SQL string literal early.
            escaped_catalog = catalog.replace("'", "''")
            row = (
                self.fetchone(
                    f"select connector_name from system.metadata.catalogs where catalog_name='{escaped_catalog}'"
                )
                or ()
            )
        return seq_get(row, 0) or self.DEFAULT_CATALOG_TYPE

    @contextlib.contextmanager
    def session(self, properties: SessionProperties) -> t.Iterator[None]:
        """Context manager that applies the ``authorization`` session property, if present.

        When ``properties["authorization"]`` is set, ``SET SESSION AUTHORIZATION`` is executed on
        entry and ``RESET SESSION AUTHORIZATION`` on exit (also when the body raises).

        Raises:
            SQLMeshError: If the authorization value is an expression that is not a string literal.
        """
        authorization = properties.get("authorization")
        if not authorization:
            yield
            return

        if not isinstance(authorization, exp.Expr):
            authorization = exp.Literal.string(authorization)

        if not authorization.is_string:
            raise SQLMeshError(
                "Invalid value for `session_properties.authorization`. Must be a string literal."
            )

        authorization_sql = authorization.sql(dialect=self.dialect)

        self.execute(f"SET SESSION AUTHORIZATION {authorization_sql}")
        try:
            yield
        finally:
            self.execute("RESET SESSION AUTHORIZATION")

    def replace_query(
        self,
        table_name: TableName,
        query_or_df: QueryOrDF,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        table_description: t.Optional[str] = None,
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        source_columns: t.Optional[t.List[str]] = None,
        supports_replace_table_override: t.Optional[bool] = None,
        **kwargs: t.Any,
    ) -> None:
        """Replace the contents of ``table_name``, enabling REPLACE TABLE for supporting catalogs.

        The override passed to the parent implementation is ``True`` when the table's catalog type
        contains one of ``CATALOG_TYPES_SUPPORTING_REPLACE_TABLE`` and ``None`` otherwise.

        NOTE(review): the incoming ``supports_replace_table_override`` value is discarded and
        recomputed here; confirm that ignoring a caller-supplied value is intended.
        """
        catalog_type = self.get_catalog_type_from_table(table_name)
        # User may have a custom catalog type name so we are assuming they keep the catalog type still in the name
        # Ex: `acme_iceberg` would be identified as an iceberg catalog and therefore supports replace table
        supports_replace_table_override = None
        for replace_table_catalog_type in CATALOG_TYPES_SUPPORTING_REPLACE_TABLE:
            if replace_table_catalog_type in catalog_type:
                supports_replace_table_override = True
                break

        super().replace_query(
            table_name=table_name,
            query_or_df=query_or_df,
            target_columns_to_types=target_columns_to_types,
            table_description=table_description,
            column_descriptions=column_descriptions,
            source_columns=source_columns,
            supports_replace_table_override=supports_replace_table_override,
            **kwargs,
        )

    def _insert_overwrite_by_condition(
        self,
        table_name: TableName,
        source_queries: t.List[SourceQuery],
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        where: t.Optional[exp.Condition] = None,
        insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None,
        **kwargs: t.Any,
    ) -> None:
        """Insert-overwrite into ``table_name`` using a strategy chosen by catalog type.

        For a hive catalog with a ``where`` condition, the catalog session property
        ``insert_existing_partitions_behavior`` is switched to ``OVERWRITE`` around the parent
        implementation and always switched back to ``APPEND`` afterwards. In every other case the
        parent implementation is called with the ``DELETE_INSERT`` strategy.
        """
        catalog = exp.to_table(table_name).catalog or self.get_current_catalog()

        if where and self.get_catalog_type(catalog) == "hive":
            # These session properties are only valid for the Trino Hive connector
            # Attempting to set them on an Iceberg catalog will throw an error:
            # "Session property 'catalog.insert_existing_partitions_behavior' does not exist"
            self.execute(f"SET SESSION {catalog}.insert_existing_partitions_behavior='OVERWRITE'")
            try:
                super()._insert_overwrite_by_condition(
                    table_name, source_queries, target_columns_to_types, where
                )
            finally:
                # Restore the property even when the insert fails so that later inserts on this
                # session do not silently overwrite partitions.
                self.execute(f"SET SESSION {catalog}.insert_existing_partitions_behavior='APPEND'")
        else:
            super()._insert_overwrite_by_condition(
                table_name,
                source_queries,
                target_columns_to_types,
                where,
                insert_overwrite_strategy_override=InsertOverwriteStrategy.DELETE_INSERT,
            )

    def _truncate_table(self, table_name: TableName) -> None:
        """Remove all rows from ``table_name`` with an unconditional ``DELETE FROM``."""
        table = exp.to_table(table_name)
        # Some trino connectors don't support truncate so we use delete.
        self.execute(f"DELETE FROM {table.sql(dialect=self.dialect, identify=True)}")

    def _get_data_objects(
        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
    ) -> t.List[DataObject]:
        """
        Returns all the data objects that exist in the given schema and optionally catalog.

        Rows come from ``<catalog>.information_schema.tables`` left-joined to
        ``system.metadata.materialized_views``; an object with a matching materialized-view row is
        typed ``materialized_view``, a ``BASE TABLE`` is typed ``table`` and anything else keeps its
        ``table_type``. ``object_names``, when given, restricts the result to those table names.
        """
        schema_name = to_schema(schema_name)
        schema = schema_name.db
        catalog = schema_name.catalog or self.get_current_catalog()
        query = (
            exp.select(
                exp.column("table_catalog", table="t").as_("catalog"),
                exp.column("table_schema", table="t").as_("schema"),
                exp.column("table_name", table="t").as_("name"),
                exp.case()
                .when(
                    exp.column("name", table="mv").is_(exp.null()).not_(),
                    exp.Literal.string("materialized_view"),
                )
                .when(
                    exp.column("table_type", table="t").eq("BASE TABLE"),
                    exp.Literal.string("table"),
                )
                .else_(exp.column("table_type", table="t"))
                .as_("type"),
            )
            .from_(exp.to_table(f"{catalog}.information_schema.tables", alias="t"))
            .join(
                exp.to_table("system.metadata.materialized_views", alias="mv"),
                on=exp.and_(
                    exp.column("catalog_name", table="mv").eq(
                        exp.column("table_catalog", table="t")
                    ),
                    exp.column("schema_name", table="mv").eq(exp.column("table_schema", table="t")),
                    exp.column("name", table="mv").eq(exp.column("table_name", table="t")),
                ),
                join_type="left",
            )
            .where(
                exp.and_(
                    exp.column("table_schema", table="t").eq(schema),
                    exp.or_(
                        exp.column("catalog_name", table="mv").is_(exp.null()),
                        exp.column("catalog_name", table="mv").eq(catalog),
                    ),
                    exp.or_(
                        exp.column("schema_name", table="mv").is_(exp.null()),
                        exp.column("schema_name", table="mv").eq(schema),
                    ),
                )
            )
        )
        if object_names:
            query = query.where(exp.column("table_name", table="t").isin(*object_names))
        df = self.fetchdf(query)
        return [
            DataObject(
                catalog=row.catalog,  # type: ignore
                schema=row.schema,  # type: ignore
                name=row.name,  # type: ignore
                type=DataObjectType.from_str(row.type),  # type: ignore
            )
            for row in df.itertuples()
        ]

    def _df_to_source_queries(
        self,
        df: DF,
        target_columns_to_types: t.Dict[str, exp.DataType],
        batch_size: int,
        target_table: TableName,
        source_columns: t.Optional[t.List[str]] = None,
    ) -> t.List[SourceQuery]:
        """Convert a pandas DataFrame into source queries, stringifying tz-aware datetimes first.

        Note that the timezone-aware datetime columns of ``df`` are rewritten in place.
        """
        import pandas as pd
        from pandas.api.types import is_datetime64_any_dtype  # type: ignore

        assert isinstance(df, pd.DataFrame)
        source_columns_to_types = get_source_columns_to_types(
            target_columns_to_types, source_columns
        )

        # Trino does not accept timestamps in ISOFORMAT that include the "T". `execution_time` is stored in
        # Pandas with that format, so we convert the column to a string with the proper format and CAST to
        # timestamp in Trino.
        for column in source_columns_to_types:
            dtype = df.dtypes[column]
            if is_datetime64_any_dtype(dtype) and getattr(dtype, "tz", None) is not None:
                df[column] = pd.to_datetime(df[column]).map(lambda x: x.isoformat(" "))

        return super()._df_to_source_queries(
            df, target_columns_to_types, batch_size, target_table, source_columns=source_columns
        )

    def _build_schema_exp(
        self,
        table: exp.Table,
        target_columns_to_types: t.Dict[str, exp.DataType],
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        expressions: t.Optional[t.List[exp.PrimaryKey]] = None,
        is_view: bool = False,
        materialized: bool = False,
    ) -> exp.Schema:
        """Build the schema expression after applying timestamp mapping and delta_lake coercion.

        Columns rewritten by the custom timestamp mapping are excluded from delta_lake coercion.
        """
        target_columns_to_types, mapped_columns = self._apply_timestamp_mapping(
            target_columns_to_types
        )
        if "delta_lake" in self.get_catalog_type_from_table(table):
            target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns)

        # NOTE(review): `materialized` is accepted but not forwarded to the parent call below;
        # confirm whether that is intentional.
        return super()._build_schema_exp(
            table, target_columns_to_types, column_descriptions, expressions, is_view
        )

    def _scd_type_2(
        self,
        target_table: TableName,
        source_table: QueryOrDF,
        unique_key: t.Sequence[exp.Expr],
        valid_from_col: exp.Column,
        valid_to_col: exp.Column,
        execution_time: t.Union[TimeLike, exp.Column],
        invalidate_hard_deletes: bool = True,
        updated_at_col: t.Optional[exp.Column] = None,
        check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expr]]] = None,
        updated_at_as_valid_from: bool = False,
        execution_time_as_valid_from: bool = False,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        table_description: t.Optional[str] = None,
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        truncate: bool = False,
        source_columns: t.Optional[t.List[str]] = None,
        **kwargs: t.Any,
    ) -> None:
        """Run the parent SCD type 2 logic with Trino-specific timestamp types.

        When ``target_columns_to_types`` is provided it first goes through the custom timestamp
        mapping and, for delta_lake catalogs, through ``_to_delta_ts`` (skipping columns the
        mapping already rewrote). Everything else is forwarded unchanged.
        """
        mapped_columns: t.Set[str] = set()
        if target_columns_to_types:
            target_columns_to_types, mapped_columns = self._apply_timestamp_mapping(
                target_columns_to_types
            )
        if target_columns_to_types and "delta_lake" in self.get_catalog_type_from_table(
            target_table
        ):
            target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns)

        return super()._scd_type_2(
            target_table,
            source_table,
            unique_key,
            valid_from_col,
            valid_to_col,
            execution_time,
            invalidate_hard_deletes,
            updated_at_col,
            check_columns,
            updated_at_as_valid_from,
            execution_time_as_valid_from,
            target_columns_to_types,
            table_description,
            column_descriptions,
            truncate,
            source_columns,
            **kwargs,
        )

    # delta_lake only supports two timestamp data types. This method converts other
    # timestamp types to those two for use in DDL statements. Trino/delta automatically
    # converts the data values to the correct type on write, so we only need to handle
    # the column types in DDL.
    # - `timestamp(6)` for non-timezone-aware
    # - `timestamp(3) with time zone` for timezone-aware
    # https://trino.io/docs/current/connector/delta-lake.html#delta-lake-to-trino-type-mapping
    def _to_delta_ts(
        self,
        columns_to_types: t.Dict[str, exp.DataType],
        skip_columns: t.Optional[t.Set[str]] = None,
    ) -> t.Dict[str, exp.DataType]:
        """Return a copy of ``columns_to_types`` with timestamp types coerced for delta_lake DDL.

        Args:
            columns_to_types: Mapping of column name to its data type.
            skip_columns: Column names whose types must be left untouched.

        Returns:
            A new dict in which every non-skipped TIMESTAMP column becomes ``timestamp(6)`` and
            every non-skipped TIMESTAMPTZ column becomes ``timestamp(3) with time zone``.
        """
        ts6 = exp.DataType.build("timestamp(6)")
        ts3_tz = exp.DataType.build("timestamp(3) with time zone")
        skip = skip_columns or set()

        def _coerce(name: str, data_type: exp.DataType) -> exp.DataType:
            if name in skip:
                return data_type
            if data_type.is_type(exp.DataType.Type.TIMESTAMP):
                return ts6
            if data_type.is_type(exp.DataType.Type.TIMESTAMPTZ):
                return ts3_tz
            return data_type

        return {name: _coerce(name, data_type) for name, data_type in columns_to_types.items()}

    @retry(wait=wait_fixed(1), stop=stop_after_attempt(10), retry=retry_if_result(lambda v: not v))
    def _block_until_table_exists(self, table_name: TableName) -> bool:
        """Check whether ``table_name`` exists, retried while the result is falsy.

        The retry policy is up to 10 attempts with a fixed 1 second wait between them.
        """
        return self.table_exists(table_name)

    def _create_schema(
        self,
        schema_name: SchemaName,
        ignore_if_exists: bool,
        warn_on_error: bool,
        properties: t.List[exp.Expr],
        kind: str,
    ) -> None:
        """Create a schema, adding a LOCATION property when a location mapping matches.

        Args:
            schema_name: The schema to create.
            ignore_if_exists: Forwarded unchanged to the parent implementation.
            warn_on_error: Forwarded unchanged to the parent implementation.
            properties: Schema properties; this list is never modified in place.
            kind: Forwarded unchanged to the parent implementation.
        """
        if mapped_location := self._schema_location(schema_name):
            # Build a new list instead of appending so the caller's list is not mutated
            # (a reused list would otherwise accumulate LOCATION properties across calls).
            properties = [
                *properties,
                exp.LocationProperty(this=exp.Literal.string(mapped_location)),
            ]

        return super()._create_schema(
            schema_name=schema_name,
            ignore_if_exists=ignore_if_exists,
            warn_on_error=warn_on_error,
            properties=properties,
            kind=kind,
        )

    def _create_table(
        self,
        table_name_or_schema: t.Union[exp.Schema, TableName],
        expression: t.Optional[exp.Expr],
        exists: bool = True,
        replace: bool = False,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        table_description: t.Optional[str] = None,
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        table_kind: t.Optional[str] = None,
        track_rows_processed: bool = True,
        **kwargs: t.Any,
    ) -> None:
        """Create the table via the parent implementation, then wait for it on hive catalogs.

        All arguments are forwarded unchanged to the parent ``_create_table``.
        """
        super()._create_table(
            table_name_or_schema=table_name_or_schema,
            expression=expression,
            exists=exists,
            replace=replace,
            target_columns_to_types=target_columns_to_types,
            table_description=table_description,
            column_descriptions=column_descriptions,
            table_kind=table_kind,
            track_rows_processed=track_rows_processed,
            **kwargs,
        )

        # extract the table name
        if isinstance(table_name_or_schema, exp.Schema):
            table_name = table_name_or_schema.this
            assert isinstance(table_name, exp.Table)
        else:
            table_name = table_name_or_schema

        if "hive" in self.get_catalog_type_from_table(table_name):
            # the Trino Hive connector can take a few seconds for metadata changes to propagate to all internal threads
            # (even if metadata TTL is set to 0s)
            # Blocking until the table shows up means that subsequent code expecting it to exist immediately will not fail
            self._block_until_table_exists(table_name)

    def _schema_location(self, schema_name: SchemaName) -> t.Optional[str]:
        """Resolve the configured storage location for ``schema_name``, if any.

        The key matched against each pattern of ``schema_location_mapping`` is
        ``<catalog>.<schema>`` when the schema name carries a catalog and just ``<schema>``
        otherwise. The first pattern for which ``re.match`` succeeds wins.

        Returns:
            The matching location with ``@{schema_name}`` and ``@{catalog_name}`` substituted, or
            ``None`` when there is no mapping or no pattern matches.
        """
        if mapping := self.schema_location_mapping:
            schema = to_schema(schema_name)
            match_key = schema.db

            # only consider the catalog if it is present
            if schema.catalog:
                match_key = f"{schema.catalog}.{match_key}"

            for k, v in mapping.items():
                if re.match(k, match_key):
                    return v.replace("@{schema_name}", schema.db).replace(
                        "@{catalog_name}", schema.catalog
                    )
        return None
Base class wrapping a Database API compliant connection.
The EngineAdapter is an easily-subclassable interface that interacts with the underlying engine and data store.
Arguments:
- connection_factory_or_pool: a callable which produces a new Database API-compliant connection on every call.
- dialect: The dialect with which this adapter is associated.
- multithreaded: Indicates whether this adapter will be used by more than one thread.
SCHEMA_DIFFER_KWARGS =
{'parameterized_type_defaults': {<DType.DECIMAL: 'DECIMAL'>: [(), (0,)], <DType.CHAR: 'CHAR'>: [(1,)], <DType.TIMESTAMP: 'TIMESTAMP'>: [(3,)]}}
timestamp_mapping: Optional[Dict[sqlglot.expressions.datatypes.DataType, sqlglot.expressions.datatypes.DataType]]
catalog_support: sqlmesh.core.engine_adapter.shared.CatalogSupport
def
set_current_catalog(self, catalog: str) -> None:
def set_current_catalog(self, catalog: str) -> None:
    """Point the current connection at ``catalog``.

    Executes a ``USE`` expression targeting the ``information_schema`` schema of that catalog.
    """
    use_target = schema_(db="information_schema", catalog=catalog)
    use_statement = exp.Use(this=use_target)
    self.execute(use_statement)
Sets the catalog name of the current connection.
@lru_cache()
def
get_catalog_type(self, catalog: Optional[str]) -> str:
@lru_cache()
def get_catalog_type(self, catalog: t.Optional[str]) -> str:
    """Return the connector type backing ``catalog``.

    Resolution order: a configured override for the catalog, then the ``connector_name``
    reported by ``system.metadata.catalogs``, then ``DEFAULT_CATALOG_TYPE``. A falsy
    ``catalog`` skips straight to the default.

    Results are memoised by ``lru_cache``, keyed on ``(self, catalog)``.
    """
    row: t.Tuple = tuple()
    if catalog:
        if catalog_type_override := self._catalog_type_overrides.get(catalog):
            return catalog_type_override
        # Double any single quote so the name cannot terminate the SQL string literal early.
        escaped_catalog = catalog.replace("'", "''")
        row = (
            self.fetchone(
                f"select connector_name from system.metadata.catalogs where catalog_name='{escaped_catalog}'"
            )
            or ()
        )
    return seq_get(row, 0) or self.DEFAULT_CATALOG_TYPE
Intended to be overridden for data virtualization systems like Trino that, depending on the target catalog, require slightly different properties to be set when creating / updating tables
@contextlib.contextmanager
def
session( self, properties: Dict[str, sqlglot.expressions.core.Expr | str | int | float | bool]) -> Iterator[NoneType]:
@contextlib.contextmanager
def session(self, properties: SessionProperties) -> t.Iterator[None]:
    """Run the managed block under the ``authorization`` session property, when one is given.

    With an ``authorization`` entry in ``properties``, ``SET SESSION AUTHORIZATION`` is executed
    before the block and ``RESET SESSION AUTHORIZATION`` after it, including when the block
    raises. Without one, the block simply runs.

    Raises:
        SQLMeshError: If the authorization value is an expression that is not a string literal.
    """
    requested = properties.get("authorization")
    if requested:
        # Plain values are wrapped in a string literal; expressions are validated as-is.
        literal = requested if isinstance(requested, exp.Expr) else exp.Literal.string(requested)
        if not literal.is_string:
            raise SQLMeshError(
                "Invalid value for `session_properties.authorization`. Must be a string literal."
            )

        rendered = literal.sql(dialect=self.dialect)
        self.execute(f"SET SESSION AUTHORIZATION {rendered}")
        try:
            yield
        finally:
            self.execute("RESET SESSION AUTHORIZATION")
    else:
        yield
A session context manager.
def
replace_query( self, table_name: Union[str, sqlglot.expressions.query.Table], query_or_df: <MagicMock id='132726885001120'>, target_columns_to_types: Optional[Dict[str, sqlglot.expressions.datatypes.DataType]] = None, table_description: Optional[str] = None, column_descriptions: Optional[Dict[str, str]] = None, source_columns: Optional[List[str]] = None, supports_replace_table_override: Optional[bool] = None, **kwargs: Any) -> None:
def replace_query(
    self,
    table_name: TableName,
    query_or_df: QueryOrDF,
    target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
    table_description: t.Optional[str] = None,
    column_descriptions: t.Optional[t.Dict[str, str]] = None,
    source_columns: t.Optional[t.List[str]] = None,
    supports_replace_table_override: t.Optional[bool] = None,
    **kwargs: t.Any,
) -> None:
    """Replace the contents of ``table_name``, enabling REPLACE TABLE for supporting catalogs.

    The override handed to the parent implementation is ``True`` when the table's catalog type
    contains one of ``CATALOG_TYPES_SUPPORTING_REPLACE_TABLE`` and ``None`` otherwise; the value
    passed in by the caller is not consulted.
    """
    connector = self.get_catalog_type_from_table(table_name)
    # Users may give a catalog a custom type name; we assume the real type is still part of it.
    # Ex: `acme_iceberg` is treated as an iceberg catalog and therefore supports replace table.
    can_replace = any(
        supported in connector for supported in CATALOG_TYPES_SUPPORTING_REPLACE_TABLE
    )
    supports_replace_table_override = True if can_replace else None

    super().replace_query(
        table_name=table_name,
        query_or_df=query_or_df,
        target_columns_to_types=target_columns_to_types,
        table_description=table_description,
        column_descriptions=column_descriptions,
        source_columns=source_columns,
        supports_replace_table_override=supports_replace_table_override,
        **kwargs,
    )
Replaces an existing table with a query.
For partition-based engines (hive, spark), insert overwrite is used. For other systems, create or replace is used.
Arguments:
- table_name: The name of the table (eg. prod.table)
- query_or_df: The SQL query to run or a dataframe.
- target_columns_to_types: Only used if a dataframe is provided. A mapping between the column name and its data type. Expected to be ordered to match the order of values in the dataframe.
- kwargs: Optional create table properties.
Inherited Members
- sqlmesh.core.engine_adapter.base.EngineAdapter
- EngineAdapter
- DEFAULT_BATCH_SIZE
- DATA_OBJECT_FILTER_BATCH_SIZE
- SUPPORTS_INDEXES
- SUPPORTS_MATERIALIZED_VIEWS
- SUPPORTS_MATERIALIZED_VIEW_SCHEMA
- SUPPORTS_VIEW_SCHEMA
- SUPPORTS_CLONING
- SUPPORTS_MANAGED_MODELS
- SUPPORTS_CREATE_DROP_CATALOG
- SUPPORTS_TUPLE_IN
- HAS_VIEW_BINDING
- SUPPORTS_GRANTS
- MAX_IDENTIFIER_LENGTH
- ATTACH_CORRELATION_ID
- SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS
- dialect
- correlation_id
- with_settings
- cursor
- connection
- spark
- snowpark
- bigframe
- comments_enabled
- schema_differ
- default_catalog
- engine_run_mode
- recycle
- close
- get_catalog_type_from_table
- current_catalog_type
- create_index
- create_table
- create_managed_table
- ctas
- create_state_table
- create_table_like
- clone_table
- drop_data_object
- drop_table
- drop_managed_table
- get_alter_operations
- alter_table
- create_view
- create_schema
- drop_schema
- drop_view
- create_catalog
- drop_catalog
- columns
- table_exists
- delete_from
- insert_append
- insert_overwrite_by_partition
- insert_overwrite_by_time_partition
- update_table
- scd_type_2_by_time
- scd_type_2_by_column
- merge
- rename_table
- get_data_object
- get_data_objects
- fetchone
- fetchall
- fetchdf
- fetch_pyspark_df
- wap_enabled
- wap_supported
- wap_table_name
- wap_prepare
- wap_publish
- sync_grants_config
- transaction
- execute
- temp_table
- drop_data_object_on_type_mismatch
- ensure_nulls_for_unmatched_after_join
- use_server_nulls_for_unmatched_after_join
- ping
- get_table_last_modified_ts