sqlmesh.core.engine_adapter.athena
1from __future__ import annotations 2from functools import lru_cache 3import typing as t 4import logging 5from sqlglot import exp 6from sqlmesh.core.dialect import to_schema 7from sqlmesh.utils.aws import validate_s3_uri, parse_s3_uri 8from sqlmesh.core.engine_adapter.mixins import PandasNativeFetchDFSupportMixin, RowDiffMixin 9from sqlmesh.core.engine_adapter.trino import TrinoEngineAdapter 10from sqlmesh.core.node import IntervalUnit 11import posixpath 12from sqlmesh.utils.errors import SQLMeshError 13from sqlmesh.core.engine_adapter.shared import ( 14 CatalogSupport, 15 DataObject, 16 DataObjectType, 17 CommentCreationTable, 18 CommentCreationView, 19 SourceQuery, 20 InsertOverwriteStrategy, 21) 22 23if t.TYPE_CHECKING: 24 from sqlmesh.core._typing import SchemaName, TableName 25 from sqlmesh.core.engine_adapter._typing import QueryOrDF 26 27 TableType = t.Union[t.Literal["hive"], t.Literal["iceberg"]] 28 29logger = logging.getLogger(__name__) 30 31 32class AthenaEngineAdapter(PandasNativeFetchDFSupportMixin, RowDiffMixin): 33 DIALECT = "athena" 34 SUPPORTS_TRANSACTIONS = False 35 SUPPORTS_REPLACE_TABLE = False 36 # Athena's support for table and column comments is too patchy to consider "supported" 37 # Hive tables: Table + Column comments are supported 38 # Iceberg tables: Column comments only 39 # CTAS, Views: No comment support at all 40 COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED 41 COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED 42 SCHEMA_DIFFER_KWARGS = TrinoEngineAdapter.SCHEMA_DIFFER_KWARGS 43 MAX_TIMESTAMP_PRECISION = 3 # copied from Trino 44 # Athena does not deal with comments well, e.g: 45 # >>> self._execute('/* test */ DESCRIBE foo') 46 # pyathena.error.OperationalError: FAILED: ParseException line 1:0 cannot recognize input near '/' '*' 'test' 47 ATTACH_CORRELATION_ID = False 48 SUPPORTS_QUERY_EXECUTION_TRACKING = True 49 SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA"] 50 51 def __init__( 52 self, *args: t.Any, 
s3_warehouse_location: t.Optional[str] = None, **kwargs: t.Any 53 ): 54 # Need to pass s3_warehouse_location to the superclass so that it goes into _extra_config 55 # which means that EngineAdapter.with_settings() keeps this property when it makes a clone 56 super().__init__(*args, s3_warehouse_location=s3_warehouse_location, **kwargs) 57 self.s3_warehouse_location = s3_warehouse_location 58 59 self._default_catalog = self._default_catalog or "awsdatacatalog" 60 61 @property 62 def s3_warehouse_location(self) -> t.Optional[str]: 63 return self._s3_warehouse_location 64 65 @s3_warehouse_location.setter 66 def s3_warehouse_location(self, value: t.Optional[str]) -> None: 67 if value: 68 value = validate_s3_uri(value, base=True) 69 self._s3_warehouse_location = value 70 71 @property 72 def s3_warehouse_location_or_raise(self) -> str: 73 # this makes tests easier to write without extra null checks to keep mypy happy 74 if location := self.s3_warehouse_location: 75 return location 76 77 raise SQLMeshError("s3_warehouse_location was expected to be populated; it isnt") 78 79 @property 80 def catalog_support(self) -> CatalogSupport: 81 # Athena has the concept of catalogs but the current catalog is set in the connection parameters with no way to query or change it after that 82 # It also cant create new catalogs, you have to configure them in AWS. 
Typically, catalogs that are not "awsdatacatalog" 83 # are pointers to the "awsdatacatalog" of other AWS accounts 84 return CatalogSupport.SINGLE_CATALOG_ONLY 85 86 def create_state_table( 87 self, 88 table_name: str, 89 target_columns_to_types: t.Dict[str, exp.DataType], 90 primary_key: t.Optional[t.Tuple[str, ...]] = None, 91 ) -> None: 92 self.create_table( 93 table_name, 94 target_columns_to_types, 95 primary_key=primary_key, 96 # it's painfully slow, but it works 97 table_format="iceberg", 98 ) 99 100 def _get_data_objects( 101 self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None 102 ) -> t.List[DataObject]: 103 """ 104 Returns all the data objects that exist in the given schema and optionally catalog. 105 """ 106 schema_name = to_schema(schema_name) 107 schema = schema_name.db 108 query = ( 109 exp.select( 110 exp.column("table_catalog").as_("catalog"), 111 exp.column("table_schema", table="t").as_("schema"), 112 exp.column("table_name", table="t").as_("name"), 113 exp.case() 114 .when( 115 exp.column("table_type", table="t").eq("BASE TABLE"), 116 exp.Literal.string("table"), 117 ) 118 .else_(exp.column("table_type", table="t")) 119 .as_("type"), 120 ) 121 .from_(exp.to_table("information_schema.tables", alias="t")) 122 .where(exp.column("table_schema", table="t").eq(schema)) 123 ) 124 if object_names: 125 query = query.where(exp.column("table_name", table="t").isin(*object_names)) 126 127 df = self.fetchdf(query) 128 129 return [ 130 DataObject( 131 catalog=row.catalog, # type: ignore 132 schema=row.schema, # type: ignore 133 name=row.name, # type: ignore 134 type=DataObjectType.from_str(row.type), # type: ignore 135 ) 136 for row in df.itertuples() 137 ] 138 139 def columns( 140 self, table_name: TableName, include_pseudo_columns: bool = False 141 ) -> t.Dict[str, exp.DataType]: 142 table = exp.to_table(table_name) 143 # note: the data_type column contains the full parameterized type, eg 'varchar(10)' 144 query = ( 145 
exp.select("column_name", "data_type") 146 .from_("information_schema.columns") 147 .where(exp.column("table_schema").eq(table.db), exp.column("table_name").eq(table.name)) 148 .order_by("ordinal_position") 149 ) 150 result = self.fetchdf(query, quote_identifiers=True) 151 return { 152 str(r.column_name): exp.DataType.build(str(r.data_type)) 153 for r in result.itertuples(index=False) 154 } 155 156 def _create_schema( 157 self, 158 schema_name: SchemaName, 159 ignore_if_exists: bool, 160 warn_on_error: bool, 161 properties: t.List[exp.Expr], 162 kind: str, 163 ) -> None: 164 if location := self._table_location(table_properties=None, table=exp.to_table(schema_name)): 165 # don't add extra LocationProperty's if one already exists 166 if not any(p for p in properties if isinstance(p, exp.LocationProperty)): 167 properties.append(location) 168 169 return super()._create_schema( 170 schema_name=schema_name, 171 ignore_if_exists=ignore_if_exists, 172 warn_on_error=warn_on_error, 173 properties=properties, 174 kind=kind, 175 ) 176 177 def _build_create_table_exp( 178 self, 179 table_name_or_schema: t.Union[exp.Schema, TableName], 180 expression: t.Optional[exp.Expr], 181 exists: bool = True, 182 replace: bool = False, 183 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 184 table_description: t.Optional[str] = None, 185 table_kind: t.Optional[str] = None, 186 partitioned_by: t.Optional[t.List[exp.Expr]] = None, 187 table_properties: t.Optional[t.Dict[str, exp.Expr]] = None, 188 **kwargs: t.Any, 189 ) -> exp.Create: 190 exists = False if replace else exists 191 192 table: exp.Table 193 if isinstance(table_name_or_schema, str): 194 table = exp.to_table(table_name_or_schema) 195 elif isinstance(table_name_or_schema, exp.Schema): 196 table = table_name_or_schema.this 197 else: 198 table = table_name_or_schema 199 200 properties = self._build_table_properties_exp( 201 table=table, 202 expression=expression, 203 
target_columns_to_types=target_columns_to_types, 204 partitioned_by=partitioned_by, 205 table_properties=table_properties, 206 table_description=table_description, 207 table_kind=table_kind, 208 **kwargs, 209 ) 210 211 is_hive = self._table_type(kwargs.get("table_format", None)) == "hive" 212 213 # Filter any PARTITIONED BY properties from the main column list since they cant be specified in both places 214 # ref: https://docs.aws.amazon.com/athena/latest/ug/partitions.html 215 if is_hive and partitioned_by and isinstance(table_name_or_schema, exp.Schema): 216 partitioned_by_column_names = {e.name for e in partitioned_by} 217 filtered_expressions = [ 218 e 219 for e in table_name_or_schema.expressions 220 if isinstance(e, exp.ColumnDef) and e.this.name not in partitioned_by_column_names 221 ] 222 table_name_or_schema.args["expressions"] = filtered_expressions 223 224 return exp.Create( 225 this=table_name_or_schema, 226 kind=table_kind or "TABLE", 227 replace=replace, 228 exists=exists, 229 expression=expression, 230 properties=properties, 231 ) 232 233 def _build_table_properties_exp( 234 self, 235 catalog_name: t.Optional[str] = None, 236 table_format: t.Optional[str] = None, 237 storage_format: t.Optional[str] = None, 238 partitioned_by: t.Optional[t.List[exp.Expr]] = None, 239 partition_interval_unit: t.Optional[IntervalUnit] = None, 240 clustered_by: t.Optional[t.List[exp.Expr]] = None, 241 table_properties: t.Optional[t.Dict[str, exp.Expr]] = None, 242 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 243 table_description: t.Optional[str] = None, 244 table_kind: t.Optional[str] = None, 245 table: t.Optional[exp.Table] = None, 246 expression: t.Optional[exp.Expr] = None, 247 **kwargs: t.Any, 248 ) -> t.Optional[exp.Properties]: 249 properties: t.List[exp.Expr] = [] 250 table_properties = table_properties or {} 251 252 is_hive = self._table_type(table_format) == "hive" 253 is_iceberg = not is_hive 254 255 if is_hive and not expression: 256 
# Hive tables are CREATE EXTERNAL TABLE, Iceberg tables are CREATE TABLE 257 # Unless it's a CTAS, those are always CREATE TABLE 258 properties.append(exp.ExternalProperty()) 259 260 if table_format: 261 properties.append( 262 exp.Property(this=exp.var("table_type"), value=exp.Literal.string(table_format)) 263 ) 264 265 if table_description: 266 properties.append(exp.SchemaCommentProperty(this=exp.Literal.string(table_description))) 267 268 if partitioned_by: 269 schema_expressions: t.List[exp.Expr] = [] 270 if is_hive and target_columns_to_types: 271 # For Hive-style tables, you cannot include the partitioned by columns in the main set of columns 272 # In the PARTITIONED BY expression, you also cant just include the column names, you need to include the data type as well 273 # ref: https://docs.aws.amazon.com/athena/latest/ug/partitions.html 274 for match_name, match_dtype in self._find_matching_columns( 275 partitioned_by, target_columns_to_types 276 ): 277 column_def = exp.ColumnDef(this=exp.to_identifier(match_name), kind=match_dtype) 278 schema_expressions.append(column_def) 279 else: 280 schema_expressions = partitioned_by 281 282 properties.append( 283 exp.PartitionedByProperty(this=exp.Schema(expressions=schema_expressions)) 284 ) 285 286 if clustered_by: 287 # Athena itself supports CLUSTERED BY, via the syntax CLUSTERED BY (col) INTO <n> BUCKETS 288 # However, SQLMesh is more closely aligned with BigQuery's notion of clustering and 289 # defines `clustered_by` as a List[str] with no way of indicating the number of buckets 290 # 291 # Athena's concept of CLUSTER BY is more like Iceberg's `bucket(<num_buckets>, col)` partition transform 292 logging.warning("clustered_by is not supported in the Athena adapter at this time") 293 294 if storage_format: 295 if is_iceberg: 296 # TBLPROPERTIES('format'='parquet') 297 table_properties["format"] = exp.Literal.string(storage_format) 298 else: 299 # STORED AS PARQUET 300 
properties.append(exp.FileFormatProperty(this=storage_format)) 301 302 if table and (location := self._table_location_or_raise(table_properties, table)): 303 properties.append(location) 304 305 if is_iceberg and expression: 306 # To make a CTAS expression persist as iceberg, alongside setting `table_type=iceberg`, you also need to set is_external=false 307 # Note that SQLGlot does the right thing with LocationProperty and writes it as `location` (Iceberg) instead of `external_location` (Hive) 308 # ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties 309 properties.append(exp.Property(this=exp.var("is_external"), value="false")) 310 311 for name, value in table_properties.items(): 312 properties.append(exp.Property(this=exp.var(name), value=value)) 313 314 if properties: 315 return exp.Properties(expressions=properties) 316 317 return None 318 319 def drop_table(self, table_name: TableName, exists: bool = True, **kwargs: t.Any) -> None: 320 table = exp.to_table(table_name) 321 322 if self._query_table_type(table) == "hive": 323 self._truncate_table(table) 324 325 return super().drop_table(table_name=table, exists=exists, **kwargs) 326 327 def _truncate_table(self, table_name: TableName) -> None: 328 table = exp.to_table(table_name) 329 330 # Truncating an Iceberg table is just DELETE FROM <table> 331 if self._query_table_type(table) == "iceberg": 332 return self.delete_from(table, exp.true()) 333 334 # Truncating a partitioned Hive table is dropping all partitions and deleting the data from S3 335 if self._is_hive_partitioned_table(table): 336 self._clear_partition_data(table, exp.true()) 337 elif s3_location := self._query_table_s3_location(table): 338 # Truncating a non-partitioned Hive table is clearing out all data in its Location 339 self._clear_s3_location(s3_location) 340 341 def _table_type(self, table_format: t.Optional[str] = None) -> TableType: 342 """ 343 Interpret the "table_format" property to check if this is a 
Hive or an Iceberg table 344 """ 345 if table_format and table_format.lower() == "iceberg": 346 return "iceberg" 347 348 # if we cant detect any indication of Iceberg, this is a Hive table 349 return "hive" 350 351 def _query_table_type(self, table: exp.Table) -> t.Optional[TableType]: 352 if self.table_exists(table): 353 return self._query_table_type_or_raise(table) 354 return None 355 356 @lru_cache() 357 def _query_table_type_or_raise(self, table: exp.Table) -> TableType: 358 """ 359 Hit the DB to check if this is a Hive or an Iceberg table. 360 361 Note that in order to @lru_cache() this method, we have the following assumptions: 362 - The table must exist (otherwise we would cache None if this method was called before table creation and always return None after creation) 363 - The table type will not change within the same SQLMesh session 364 """ 365 # Note: SHOW TBLPROPERTIES gets parsed by SQLGlot as an exp.Command anyway so we just use a string here 366 # This also means we need to use dialect="hive" instead of dialect="athena" so that the identifiers get the correct quoting (backticks) 367 for row in self.fetchall(f"SHOW TBLPROPERTIES {table.sql(dialect='hive', identify=True)}"): 368 # This query returns a single column with values like 'EXTERNAL\tTRUE' 369 row_lower = row[0].lower() 370 if "external" in row_lower and "true" in row_lower: 371 return "hive" 372 return "iceberg" 373 374 def _is_hive_partitioned_table(self, table: exp.Table) -> bool: 375 try: 376 self._list_partitions(table=table, where=None, limit=1) 377 return True 378 except Exception as e: 379 if "TABLE_NOT_FOUND" in str(e): 380 return False 381 raise e 382 383 def _table_location_or_raise( 384 self, table_properties: t.Optional[t.Dict[str, exp.Expr]], table: exp.Table 385 ) -> exp.LocationProperty: 386 location = self._table_location(table_properties, table) 387 if not location: 388 raise SQLMeshError( 389 f"Cannot figure out location for table {table}. 
Please either set `s3_base_location` in `physical_properties` or set `s3_warehouse_location` in the Athena connection config" 390 ) 391 return location 392 393 def _table_location( 394 self, 395 table_properties: t.Optional[t.Dict[str, exp.Expr]], 396 table: exp.Table, 397 ) -> t.Optional[exp.LocationProperty]: 398 base_uri: str 399 400 # If the user has manually specified a `s3_base_location`, use it 401 if table_properties and "s3_base_location" in table_properties: 402 s3_base_location_property = table_properties.pop( 403 "s3_base_location" 404 ) # pop because it's handled differently and we dont want it to end up in the TBLPROPERTIES clause 405 if isinstance(s3_base_location_property, exp.Expr): 406 base_uri = s3_base_location_property.name 407 else: 408 base_uri = s3_base_location_property 409 410 elif self.s3_warehouse_location: 411 # If the user has set `s3_warehouse_location` in the connection config, the base URI is <s3_warehouse_location>/<catalog>/<schema>/ 412 base_uri = posixpath.join( 413 self.s3_warehouse_location, table.catalog or "", table.db or "" 414 ) 415 else: 416 return None 417 418 full_uri = validate_s3_uri(posixpath.join(base_uri, table.text("this") or ""), base=True) 419 return exp.LocationProperty(this=exp.Literal.string(full_uri)) 420 421 def _find_matching_columns( 422 self, partitioned_by: t.List[exp.Expr], columns_to_types: t.Dict[str, exp.DataType] 423 ) -> t.List[t.Tuple[str, exp.DataType]]: 424 matches = [] 425 for col in partitioned_by: 426 # TODO: do we care about normalization? 
427 key = col.name 428 if isinstance(col, exp.Column) and (match_dtype := columns_to_types.get(key)): 429 matches.append((key, match_dtype)) 430 return matches 431 432 def replace_query( 433 self, 434 table_name: TableName, 435 query_or_df: QueryOrDF, 436 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 437 table_description: t.Optional[str] = None, 438 column_descriptions: t.Optional[t.Dict[str, str]] = None, 439 source_columns: t.Optional[t.List[str]] = None, 440 supports_replace_table_override: t.Optional[bool] = None, 441 **kwargs: t.Any, 442 ) -> None: 443 table = exp.to_table(table_name) 444 445 if self._query_table_type(table=table) == "hive": 446 self.drop_table(table) 447 448 return super().replace_query( 449 table_name=table, 450 query_or_df=query_or_df, 451 target_columns_to_types=target_columns_to_types, 452 table_description=table_description, 453 column_descriptions=column_descriptions, 454 source_columns=source_columns, 455 **kwargs, 456 ) 457 458 def _insert_overwrite_by_time_partition( 459 self, 460 table_name: TableName, 461 source_queries: t.List[SourceQuery], 462 target_columns_to_types: t.Dict[str, exp.DataType], 463 where: exp.Condition, 464 **kwargs: t.Any, 465 ) -> None: 466 table = exp.to_table(table_name) 467 468 table_type = self._query_table_type(table) 469 470 if table_type == "iceberg": 471 # Iceberg tables work as expected, we can use the default behaviour 472 return super()._insert_overwrite_by_time_partition( 473 table, source_queries, target_columns_to_types, where, **kwargs 474 ) 475 476 # For Hive tables, we need to drop all the partitions covered by the query and delete the data from S3 477 self._clear_partition_data(table, where) 478 479 # Now the data is physically gone, we can continue with inserting a new partition 480 return super()._insert_overwrite_by_time_partition( 481 table, 482 source_queries, 483 target_columns_to_types, 484 where, 485 
insert_overwrite_strategy_override=InsertOverwriteStrategy.INTO_IS_OVERWRITE, # since we already cleared the data 486 **kwargs, 487 ) 488 489 def _clear_partition_data(self, table: exp.Table, where: t.Optional[exp.Condition]) -> None: 490 if partitions_to_drop := self._list_partitions(table, where): 491 for _, s3_location in partitions_to_drop: 492 logger.debug( 493 f"Clearing S3 location for '{table.sql(dialect=self.dialect)}': {s3_location}" 494 ) 495 self._clear_s3_location(s3_location) 496 497 partition_values = [k for k, _ in partitions_to_drop] 498 logger.debug( 499 f"Dropping partitions for '{table.sql(dialect=self.dialect)}' from metastore: {partition_values}" 500 ) 501 self._drop_partitions_from_metastore(table, partition_values) 502 503 def _list_partitions( 504 self, 505 table: exp.Table, 506 where: t.Optional[exp.Condition] = None, 507 limit: t.Optional[int] = None, 508 ) -> t.List[t.Tuple[t.List[str], str]]: 509 # Use Athena's magic "$partitions" metadata table to identify the partitions to drop 510 # Doing it this way allows us to use SQL to filter the partition list 511 partition_table_name = table.copy() 512 partition_table_name.this.replace( 513 exp.to_identifier(f"{table.name}$partitions", quoted=True) 514 ) 515 516 query = exp.select("*").from_(partition_table_name).where(where) 517 if limit: 518 query = query.limit(limit) 519 520 partition_values = [list(r) for r in self.fetchall(query, quote_identifiers=True)] 521 522 if partition_values: 523 response = self._glue_client.batch_get_partition( 524 DatabaseName=table.db, 525 TableName=table.name, 526 PartitionsToGet=[{"Values": [str(v) for v in lst]} for lst in partition_values], 527 ) 528 return sorted( 529 [(p["Values"], p["StorageDescriptor"]["Location"]) for p in response["Partitions"]] 530 ) 531 532 return [] 533 534 def _query_table_s3_location(self, table: exp.Table) -> str: 535 response = self._glue_client.get_table(DatabaseName=table.db, Name=table.name) 536 537 # Athena wont let you 
create a table without a location, so *theoretically* this should never be empty 538 if location := response.get("Table", {}).get("StorageDescriptor", {}).get("Location", None): 539 return location 540 541 raise SQLMeshError(f"Table {table} has no location set in the metastore!") 542 543 def _drop_partitions_from_metastore( 544 self, table: exp.Table, partition_values: t.List[t.List[str]] 545 ) -> None: 546 # todo: switch to itertools.batched when our minimum supported Python is 3.12 547 # 25 = maximum number of partitions that batch_delete_partition can process at once 548 # ref: https://docs.aws.amazon.com/glue/latest/webapi/API_BatchDeletePartition.html#API_BatchDeletePartition_RequestParameters 549 def _chunks() -> t.Iterable[t.List[t.List[str]]]: 550 for i in range(0, len(partition_values), 25): 551 yield partition_values[i : i + 25] 552 553 for batch in _chunks(): 554 self._glue_client.batch_delete_partition( 555 DatabaseName=table.db, 556 TableName=table.name, 557 PartitionsToDelete=[{"Values": v} for v in batch], 558 ) 559 560 def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None: 561 table = exp.to_table(table_name) 562 563 table_type = self._query_table_type(table) 564 565 # If Iceberg, DELETE operations work as expected 566 if table_type == "iceberg": 567 return super().delete_from(table, where) 568 569 # If Hive, DELETE is an error 570 if table_type == "hive": 571 # However, if there are no actual records to delete, we can make DELETE a no-op 572 # This simplifies a bunch of calling code that just assumes DELETE works (which to be fair is a reasonable assumption since it does for every other engine) 573 empty_check = ( 574 exp.select("*").from_(table).where(where).limit(1) 575 ) # deliberately not count(*) because we want the engine to stop as soon as it finds a record 576 if len(self.fetchall(empty_check)) > 0: 577 raise SQLMeshError("Cannot delete individual records from a Hive table") 578 579 return None 580 581 def 
_clear_s3_location(self, s3_uri: str) -> None: 582 s3 = self._s3_client 583 584 bucket, key = parse_s3_uri(s3_uri) 585 if not key.endswith("/"): 586 key = f"{key}/" 587 588 keys_to_delete = [] 589 590 # note: uses Delimiter=/ to prevent stepping into folders 591 # the assumption is that all the files in a partition live directly at the partition `Location` 592 for page in s3.get_paginator("list_objects_v2").paginate( 593 Bucket=bucket, Prefix=key, Delimiter="/" 594 ): 595 # list_objects_v2() returns 1000 keys per page so that lines up nicely with delete_objects() being able to delete 1000 keys at a time 596 keys = [item["Key"] for item in page.get("Contents", [])] 597 if keys: 598 keys_to_delete.append(keys) 599 600 for chunk in keys_to_delete: 601 s3.delete_objects(Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]}) 602 603 @property 604 def _glue_client(self) -> t.Any: 605 return self._boto3_client("glue") 606 607 @property 608 def _s3_client(self) -> t.Any: 609 return self._boto3_client("s3") 610 611 def _boto3_client(self, name: str) -> t.Any: 612 # use the client factory from PyAthena which is already configured with the correct AWS details 613 conn = self.connection 614 return conn.session.client( 615 name, 616 region_name=conn.region_name, 617 config=conn.config, 618 **conn._client_kwargs, 619 ) # type: ignore 620 621 def get_current_catalog(self) -> t.Optional[str]: 622 return self.connection.catalog_name
logger =
<Logger sqlmesh.core.engine_adapter.athena (WARNING)>
class
AthenaEngineAdapter(sqlmesh.core.engine_adapter.mixins.PandasNativeFetchDFSupportMixin, sqlmesh.core.engine_adapter.mixins.RowDiffMixin):
33class AthenaEngineAdapter(PandasNativeFetchDFSupportMixin, RowDiffMixin): 34 DIALECT = "athena" 35 SUPPORTS_TRANSACTIONS = False 36 SUPPORTS_REPLACE_TABLE = False 37 # Athena's support for table and column comments is too patchy to consider "supported" 38 # Hive tables: Table + Column comments are supported 39 # Iceberg tables: Column comments only 40 # CTAS, Views: No comment support at all 41 COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED 42 COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED 43 SCHEMA_DIFFER_KWARGS = TrinoEngineAdapter.SCHEMA_DIFFER_KWARGS 44 MAX_TIMESTAMP_PRECISION = 3 # copied from Trino 45 # Athena does not deal with comments well, e.g: 46 # >>> self._execute('/* test */ DESCRIBE foo') 47 # pyathena.error.OperationalError: FAILED: ParseException line 1:0 cannot recognize input near '/' '*' 'test' 48 ATTACH_CORRELATION_ID = False 49 SUPPORTS_QUERY_EXECUTION_TRACKING = True 50 SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA"] 51 52 def __init__( 53 self, *args: t.Any, s3_warehouse_location: t.Optional[str] = None, **kwargs: t.Any 54 ): 55 # Need to pass s3_warehouse_location to the superclass so that it goes into _extra_config 56 # which means that EngineAdapter.with_settings() keeps this property when it makes a clone 57 super().__init__(*args, s3_warehouse_location=s3_warehouse_location, **kwargs) 58 self.s3_warehouse_location = s3_warehouse_location 59 60 self._default_catalog = self._default_catalog or "awsdatacatalog" 61 62 @property 63 def s3_warehouse_location(self) -> t.Optional[str]: 64 return self._s3_warehouse_location 65 66 @s3_warehouse_location.setter 67 def s3_warehouse_location(self, value: t.Optional[str]) -> None: 68 if value: 69 value = validate_s3_uri(value, base=True) 70 self._s3_warehouse_location = value 71 72 @property 73 def s3_warehouse_location_or_raise(self) -> str: 74 # this makes tests easier to write without extra null checks to keep mypy happy 75 if location := 
self.s3_warehouse_location: 76 return location 77 78 raise SQLMeshError("s3_warehouse_location was expected to be populated; it isnt") 79 80 @property 81 def catalog_support(self) -> CatalogSupport: 82 # Athena has the concept of catalogs but the current catalog is set in the connection parameters with no way to query or change it after that 83 # It also cant create new catalogs, you have to configure them in AWS. Typically, catalogs that are not "awsdatacatalog" 84 # are pointers to the "awsdatacatalog" of other AWS accounts 85 return CatalogSupport.SINGLE_CATALOG_ONLY 86 87 def create_state_table( 88 self, 89 table_name: str, 90 target_columns_to_types: t.Dict[str, exp.DataType], 91 primary_key: t.Optional[t.Tuple[str, ...]] = None, 92 ) -> None: 93 self.create_table( 94 table_name, 95 target_columns_to_types, 96 primary_key=primary_key, 97 # it's painfully slow, but it works 98 table_format="iceberg", 99 ) 100 101 def _get_data_objects( 102 self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None 103 ) -> t.List[DataObject]: 104 """ 105 Returns all the data objects that exist in the given schema and optionally catalog. 
106 """ 107 schema_name = to_schema(schema_name) 108 schema = schema_name.db 109 query = ( 110 exp.select( 111 exp.column("table_catalog").as_("catalog"), 112 exp.column("table_schema", table="t").as_("schema"), 113 exp.column("table_name", table="t").as_("name"), 114 exp.case() 115 .when( 116 exp.column("table_type", table="t").eq("BASE TABLE"), 117 exp.Literal.string("table"), 118 ) 119 .else_(exp.column("table_type", table="t")) 120 .as_("type"), 121 ) 122 .from_(exp.to_table("information_schema.tables", alias="t")) 123 .where(exp.column("table_schema", table="t").eq(schema)) 124 ) 125 if object_names: 126 query = query.where(exp.column("table_name", table="t").isin(*object_names)) 127 128 df = self.fetchdf(query) 129 130 return [ 131 DataObject( 132 catalog=row.catalog, # type: ignore 133 schema=row.schema, # type: ignore 134 name=row.name, # type: ignore 135 type=DataObjectType.from_str(row.type), # type: ignore 136 ) 137 for row in df.itertuples() 138 ] 139 140 def columns( 141 self, table_name: TableName, include_pseudo_columns: bool = False 142 ) -> t.Dict[str, exp.DataType]: 143 table = exp.to_table(table_name) 144 # note: the data_type column contains the full parameterized type, eg 'varchar(10)' 145 query = ( 146 exp.select("column_name", "data_type") 147 .from_("information_schema.columns") 148 .where(exp.column("table_schema").eq(table.db), exp.column("table_name").eq(table.name)) 149 .order_by("ordinal_position") 150 ) 151 result = self.fetchdf(query, quote_identifiers=True) 152 return { 153 str(r.column_name): exp.DataType.build(str(r.data_type)) 154 for r in result.itertuples(index=False) 155 } 156 157 def _create_schema( 158 self, 159 schema_name: SchemaName, 160 ignore_if_exists: bool, 161 warn_on_error: bool, 162 properties: t.List[exp.Expr], 163 kind: str, 164 ) -> None: 165 if location := self._table_location(table_properties=None, table=exp.to_table(schema_name)): 166 # don't add extra LocationProperty's if one already exists 167 if not any(p 
for p in properties if isinstance(p, exp.LocationProperty)): 168 properties.append(location) 169 170 return super()._create_schema( 171 schema_name=schema_name, 172 ignore_if_exists=ignore_if_exists, 173 warn_on_error=warn_on_error, 174 properties=properties, 175 kind=kind, 176 ) 177 178 def _build_create_table_exp( 179 self, 180 table_name_or_schema: t.Union[exp.Schema, TableName], 181 expression: t.Optional[exp.Expr], 182 exists: bool = True, 183 replace: bool = False, 184 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 185 table_description: t.Optional[str] = None, 186 table_kind: t.Optional[str] = None, 187 partitioned_by: t.Optional[t.List[exp.Expr]] = None, 188 table_properties: t.Optional[t.Dict[str, exp.Expr]] = None, 189 **kwargs: t.Any, 190 ) -> exp.Create: 191 exists = False if replace else exists 192 193 table: exp.Table 194 if isinstance(table_name_or_schema, str): 195 table = exp.to_table(table_name_or_schema) 196 elif isinstance(table_name_or_schema, exp.Schema): 197 table = table_name_or_schema.this 198 else: 199 table = table_name_or_schema 200 201 properties = self._build_table_properties_exp( 202 table=table, 203 expression=expression, 204 target_columns_to_types=target_columns_to_types, 205 partitioned_by=partitioned_by, 206 table_properties=table_properties, 207 table_description=table_description, 208 table_kind=table_kind, 209 **kwargs, 210 ) 211 212 is_hive = self._table_type(kwargs.get("table_format", None)) == "hive" 213 214 # Filter any PARTITIONED BY properties from the main column list since they cant be specified in both places 215 # ref: https://docs.aws.amazon.com/athena/latest/ug/partitions.html 216 if is_hive and partitioned_by and isinstance(table_name_or_schema, exp.Schema): 217 partitioned_by_column_names = {e.name for e in partitioned_by} 218 filtered_expressions = [ 219 e 220 for e in table_name_or_schema.expressions 221 if isinstance(e, exp.ColumnDef) and e.this.name not in partitioned_by_column_names 
222 ] 223 table_name_or_schema.args["expressions"] = filtered_expressions 224 225 return exp.Create( 226 this=table_name_or_schema, 227 kind=table_kind or "TABLE", 228 replace=replace, 229 exists=exists, 230 expression=expression, 231 properties=properties, 232 ) 233 234 def _build_table_properties_exp( 235 self, 236 catalog_name: t.Optional[str] = None, 237 table_format: t.Optional[str] = None, 238 storage_format: t.Optional[str] = None, 239 partitioned_by: t.Optional[t.List[exp.Expr]] = None, 240 partition_interval_unit: t.Optional[IntervalUnit] = None, 241 clustered_by: t.Optional[t.List[exp.Expr]] = None, 242 table_properties: t.Optional[t.Dict[str, exp.Expr]] = None, 243 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 244 table_description: t.Optional[str] = None, 245 table_kind: t.Optional[str] = None, 246 table: t.Optional[exp.Table] = None, 247 expression: t.Optional[exp.Expr] = None, 248 **kwargs: t.Any, 249 ) -> t.Optional[exp.Properties]: 250 properties: t.List[exp.Expr] = [] 251 table_properties = table_properties or {} 252 253 is_hive = self._table_type(table_format) == "hive" 254 is_iceberg = not is_hive 255 256 if is_hive and not expression: 257 # Hive tables are CREATE EXTERNAL TABLE, Iceberg tables are CREATE TABLE 258 # Unless it's a CTAS, those are always CREATE TABLE 259 properties.append(exp.ExternalProperty()) 260 261 if table_format: 262 properties.append( 263 exp.Property(this=exp.var("table_type"), value=exp.Literal.string(table_format)) 264 ) 265 266 if table_description: 267 properties.append(exp.SchemaCommentProperty(this=exp.Literal.string(table_description))) 268 269 if partitioned_by: 270 schema_expressions: t.List[exp.Expr] = [] 271 if is_hive and target_columns_to_types: 272 # For Hive-style tables, you cannot include the partitioned by columns in the main set of columns 273 # In the PARTITIONED BY expression, you also cant just include the column names, you need to include the data type as well 274 # ref: 
https://docs.aws.amazon.com/athena/latest/ug/partitions.html 275 for match_name, match_dtype in self._find_matching_columns( 276 partitioned_by, target_columns_to_types 277 ): 278 column_def = exp.ColumnDef(this=exp.to_identifier(match_name), kind=match_dtype) 279 schema_expressions.append(column_def) 280 else: 281 schema_expressions = partitioned_by 282 283 properties.append( 284 exp.PartitionedByProperty(this=exp.Schema(expressions=schema_expressions)) 285 ) 286 287 if clustered_by: 288 # Athena itself supports CLUSTERED BY, via the syntax CLUSTERED BY (col) INTO <n> BUCKETS 289 # However, SQLMesh is more closely aligned with BigQuery's notion of clustering and 290 # defines `clustered_by` as a List[str] with no way of indicating the number of buckets 291 # 292 # Athena's concept of CLUSTER BY is more like Iceberg's `bucket(<num_buckets>, col)` partition transform 293 logging.warning("clustered_by is not supported in the Athena adapter at this time") 294 295 if storage_format: 296 if is_iceberg: 297 # TBLPROPERTIES('format'='parquet') 298 table_properties["format"] = exp.Literal.string(storage_format) 299 else: 300 # STORED AS PARQUET 301 properties.append(exp.FileFormatProperty(this=storage_format)) 302 303 if table and (location := self._table_location_or_raise(table_properties, table)): 304 properties.append(location) 305 306 if is_iceberg and expression: 307 # To make a CTAS expression persist as iceberg, alongside setting `table_type=iceberg`, you also need to set is_external=false 308 # Note that SQLGlot does the right thing with LocationProperty and writes it as `location` (Iceberg) instead of `external_location` (Hive) 309 # ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties 310 properties.append(exp.Property(this=exp.var("is_external"), value="false")) 311 312 for name, value in table_properties.items(): 313 properties.append(exp.Property(this=exp.var(name), value=value)) 314 315 if properties: 316 return 
exp.Properties(expressions=properties) 317 318 return None 319 320 def drop_table(self, table_name: TableName, exists: bool = True, **kwargs: t.Any) -> None: 321 table = exp.to_table(table_name) 322 323 if self._query_table_type(table) == "hive": 324 self._truncate_table(table) 325 326 return super().drop_table(table_name=table, exists=exists, **kwargs) 327 328 def _truncate_table(self, table_name: TableName) -> None: 329 table = exp.to_table(table_name) 330 331 # Truncating an Iceberg table is just DELETE FROM <table> 332 if self._query_table_type(table) == "iceberg": 333 return self.delete_from(table, exp.true()) 334 335 # Truncating a partitioned Hive table is dropping all partitions and deleting the data from S3 336 if self._is_hive_partitioned_table(table): 337 self._clear_partition_data(table, exp.true()) 338 elif s3_location := self._query_table_s3_location(table): 339 # Truncating a non-partitioned Hive table is clearing out all data in its Location 340 self._clear_s3_location(s3_location) 341 342 def _table_type(self, table_format: t.Optional[str] = None) -> TableType: 343 """ 344 Interpret the "table_format" property to check if this is a Hive or an Iceberg table 345 """ 346 if table_format and table_format.lower() == "iceberg": 347 return "iceberg" 348 349 # if we cant detect any indication of Iceberg, this is a Hive table 350 return "hive" 351 352 def _query_table_type(self, table: exp.Table) -> t.Optional[TableType]: 353 if self.table_exists(table): 354 return self._query_table_type_or_raise(table) 355 return None 356 357 @lru_cache() 358 def _query_table_type_or_raise(self, table: exp.Table) -> TableType: 359 """ 360 Hit the DB to check if this is a Hive or an Iceberg table. 
361 362 Note that in order to @lru_cache() this method, we have the following assumptions: 363 - The table must exist (otherwise we would cache None if this method was called before table creation and always return None after creation) 364 - The table type will not change within the same SQLMesh session 365 """ 366 # Note: SHOW TBLPROPERTIES gets parsed by SQLGlot as an exp.Command anyway so we just use a string here 367 # This also means we need to use dialect="hive" instead of dialect="athena" so that the identifiers get the correct quoting (backticks) 368 for row in self.fetchall(f"SHOW TBLPROPERTIES {table.sql(dialect='hive', identify=True)}"): 369 # This query returns a single column with values like 'EXTERNAL\tTRUE' 370 row_lower = row[0].lower() 371 if "external" in row_lower and "true" in row_lower: 372 return "hive" 373 return "iceberg" 374 375 def _is_hive_partitioned_table(self, table: exp.Table) -> bool: 376 try: 377 self._list_partitions(table=table, where=None, limit=1) 378 return True 379 except Exception as e: 380 if "TABLE_NOT_FOUND" in str(e): 381 return False 382 raise e 383 384 def _table_location_or_raise( 385 self, table_properties: t.Optional[t.Dict[str, exp.Expr]], table: exp.Table 386 ) -> exp.LocationProperty: 387 location = self._table_location(table_properties, table) 388 if not location: 389 raise SQLMeshError( 390 f"Cannot figure out location for table {table}. 
Please either set `s3_base_location` in `physical_properties` or set `s3_warehouse_location` in the Athena connection config" 391 ) 392 return location 393 394 def _table_location( 395 self, 396 table_properties: t.Optional[t.Dict[str, exp.Expr]], 397 table: exp.Table, 398 ) -> t.Optional[exp.LocationProperty]: 399 base_uri: str 400 401 # If the user has manually specified a `s3_base_location`, use it 402 if table_properties and "s3_base_location" in table_properties: 403 s3_base_location_property = table_properties.pop( 404 "s3_base_location" 405 ) # pop because it's handled differently and we dont want it to end up in the TBLPROPERTIES clause 406 if isinstance(s3_base_location_property, exp.Expr): 407 base_uri = s3_base_location_property.name 408 else: 409 base_uri = s3_base_location_property 410 411 elif self.s3_warehouse_location: 412 # If the user has set `s3_warehouse_location` in the connection config, the base URI is <s3_warehouse_location>/<catalog>/<schema>/ 413 base_uri = posixpath.join( 414 self.s3_warehouse_location, table.catalog or "", table.db or "" 415 ) 416 else: 417 return None 418 419 full_uri = validate_s3_uri(posixpath.join(base_uri, table.text("this") or ""), base=True) 420 return exp.LocationProperty(this=exp.Literal.string(full_uri)) 421 422 def _find_matching_columns( 423 self, partitioned_by: t.List[exp.Expr], columns_to_types: t.Dict[str, exp.DataType] 424 ) -> t.List[t.Tuple[str, exp.DataType]]: 425 matches = [] 426 for col in partitioned_by: 427 # TODO: do we care about normalization? 
428 key = col.name 429 if isinstance(col, exp.Column) and (match_dtype := columns_to_types.get(key)): 430 matches.append((key, match_dtype)) 431 return matches 432 433 def replace_query( 434 self, 435 table_name: TableName, 436 query_or_df: QueryOrDF, 437 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 438 table_description: t.Optional[str] = None, 439 column_descriptions: t.Optional[t.Dict[str, str]] = None, 440 source_columns: t.Optional[t.List[str]] = None, 441 supports_replace_table_override: t.Optional[bool] = None, 442 **kwargs: t.Any, 443 ) -> None: 444 table = exp.to_table(table_name) 445 446 if self._query_table_type(table=table) == "hive": 447 self.drop_table(table) 448 449 return super().replace_query( 450 table_name=table, 451 query_or_df=query_or_df, 452 target_columns_to_types=target_columns_to_types, 453 table_description=table_description, 454 column_descriptions=column_descriptions, 455 source_columns=source_columns, 456 **kwargs, 457 ) 458 459 def _insert_overwrite_by_time_partition( 460 self, 461 table_name: TableName, 462 source_queries: t.List[SourceQuery], 463 target_columns_to_types: t.Dict[str, exp.DataType], 464 where: exp.Condition, 465 **kwargs: t.Any, 466 ) -> None: 467 table = exp.to_table(table_name) 468 469 table_type = self._query_table_type(table) 470 471 if table_type == "iceberg": 472 # Iceberg tables work as expected, we can use the default behaviour 473 return super()._insert_overwrite_by_time_partition( 474 table, source_queries, target_columns_to_types, where, **kwargs 475 ) 476 477 # For Hive tables, we need to drop all the partitions covered by the query and delete the data from S3 478 self._clear_partition_data(table, where) 479 480 # Now the data is physically gone, we can continue with inserting a new partition 481 return super()._insert_overwrite_by_time_partition( 482 table, 483 source_queries, 484 target_columns_to_types, 485 where, 486 
insert_overwrite_strategy_override=InsertOverwriteStrategy.INTO_IS_OVERWRITE, # since we already cleared the data 487 **kwargs, 488 ) 489 490 def _clear_partition_data(self, table: exp.Table, where: t.Optional[exp.Condition]) -> None: 491 if partitions_to_drop := self._list_partitions(table, where): 492 for _, s3_location in partitions_to_drop: 493 logger.debug( 494 f"Clearing S3 location for '{table.sql(dialect=self.dialect)}': {s3_location}" 495 ) 496 self._clear_s3_location(s3_location) 497 498 partition_values = [k for k, _ in partitions_to_drop] 499 logger.debug( 500 f"Dropping partitions for '{table.sql(dialect=self.dialect)}' from metastore: {partition_values}" 501 ) 502 self._drop_partitions_from_metastore(table, partition_values) 503 504 def _list_partitions( 505 self, 506 table: exp.Table, 507 where: t.Optional[exp.Condition] = None, 508 limit: t.Optional[int] = None, 509 ) -> t.List[t.Tuple[t.List[str], str]]: 510 # Use Athena's magic "$partitions" metadata table to identify the partitions to drop 511 # Doing it this way allows us to use SQL to filter the partition list 512 partition_table_name = table.copy() 513 partition_table_name.this.replace( 514 exp.to_identifier(f"{table.name}$partitions", quoted=True) 515 ) 516 517 query = exp.select("*").from_(partition_table_name).where(where) 518 if limit: 519 query = query.limit(limit) 520 521 partition_values = [list(r) for r in self.fetchall(query, quote_identifiers=True)] 522 523 if partition_values: 524 response = self._glue_client.batch_get_partition( 525 DatabaseName=table.db, 526 TableName=table.name, 527 PartitionsToGet=[{"Values": [str(v) for v in lst]} for lst in partition_values], 528 ) 529 return sorted( 530 [(p["Values"], p["StorageDescriptor"]["Location"]) for p in response["Partitions"]] 531 ) 532 533 return [] 534 535 def _query_table_s3_location(self, table: exp.Table) -> str: 536 response = self._glue_client.get_table(DatabaseName=table.db, Name=table.name) 537 538 # Athena wont let you 
create a table without a location, so *theoretically* this should never be empty 539 if location := response.get("Table", {}).get("StorageDescriptor", {}).get("Location", None): 540 return location 541 542 raise SQLMeshError(f"Table {table} has no location set in the metastore!") 543 544 def _drop_partitions_from_metastore( 545 self, table: exp.Table, partition_values: t.List[t.List[str]] 546 ) -> None: 547 # todo: switch to itertools.batched when our minimum supported Python is 3.12 548 # 25 = maximum number of partitions that batch_delete_partition can process at once 549 # ref: https://docs.aws.amazon.com/glue/latest/webapi/API_BatchDeletePartition.html#API_BatchDeletePartition_RequestParameters 550 def _chunks() -> t.Iterable[t.List[t.List[str]]]: 551 for i in range(0, len(partition_values), 25): 552 yield partition_values[i : i + 25] 553 554 for batch in _chunks(): 555 self._glue_client.batch_delete_partition( 556 DatabaseName=table.db, 557 TableName=table.name, 558 PartitionsToDelete=[{"Values": v} for v in batch], 559 ) 560 561 def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None: 562 table = exp.to_table(table_name) 563 564 table_type = self._query_table_type(table) 565 566 # If Iceberg, DELETE operations work as expected 567 if table_type == "iceberg": 568 return super().delete_from(table, where) 569 570 # If Hive, DELETE is an error 571 if table_type == "hive": 572 # However, if there are no actual records to delete, we can make DELETE a no-op 573 # This simplifies a bunch of calling code that just assumes DELETE works (which to be fair is a reasonable assumption since it does for every other engine) 574 empty_check = ( 575 exp.select("*").from_(table).where(where).limit(1) 576 ) # deliberately not count(*) because we want the engine to stop as soon as it finds a record 577 if len(self.fetchall(empty_check)) > 0: 578 raise SQLMeshError("Cannot delete individual records from a Hive table") 579 580 return None 581 582 def 
_clear_s3_location(self, s3_uri: str) -> None: 583 s3 = self._s3_client 584 585 bucket, key = parse_s3_uri(s3_uri) 586 if not key.endswith("/"): 587 key = f"{key}/" 588 589 keys_to_delete = [] 590 591 # note: uses Delimiter=/ to prevent stepping into folders 592 # the assumption is that all the files in a partition live directly at the partition `Location` 593 for page in s3.get_paginator("list_objects_v2").paginate( 594 Bucket=bucket, Prefix=key, Delimiter="/" 595 ): 596 # list_objects_v2() returns 1000 keys per page so that lines up nicely with delete_objects() being able to delete 1000 keys at a time 597 keys = [item["Key"] for item in page.get("Contents", [])] 598 if keys: 599 keys_to_delete.append(keys) 600 601 for chunk in keys_to_delete: 602 s3.delete_objects(Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]}) 603 604 @property 605 def _glue_client(self) -> t.Any: 606 return self._boto3_client("glue") 607 608 @property 609 def _s3_client(self) -> t.Any: 610 return self._boto3_client("s3") 611 612 def _boto3_client(self, name: str) -> t.Any: 613 # use the client factory from PyAthena which is already configured with the correct AWS details 614 conn = self.connection 615 return conn.session.client( 616 name, 617 region_name=conn.region_name, 618 config=conn.config, 619 **conn._client_kwargs, 620 ) # type: ignore 621 622 def get_current_catalog(self) -> t.Optional[str]: 623 return self.connection.catalog_name
Base class wrapping a Database API compliant connection.
The EngineAdapter is an easily-subclassable interface that interacts with the underlying engine and data store.
Arguments:
- connection_factory_or_pool: a callable which produces a new Database API-compliant connection on every call.
- dialect: The dialect with which this adapter is associated.
- multithreaded: Indicates whether this adapter will be used by more than one thread.
AthenaEngineAdapter( *args: Any, s3_warehouse_location: Optional[str] = None, **kwargs: Any)
def __init__(
    self,
    *args: t.Any,
    s3_warehouse_location: t.Optional[str] = None,
    **kwargs: t.Any,
):
    """Initialise the adapter and remember the optional S3 warehouse location.

    Args:
        *args: Positional arguments forwarded unchanged to the superclass initialiser.
        s3_warehouse_location: Optional S3 warehouse location. It is stored on the
            instance and also forwarded to the superclass.
        **kwargs: Keyword arguments forwarded unchanged to the superclass initialiser.
    """
    # The location is forwarded to the superclass (not only stored here) so that it goes
    # into _extra_config, which means that EngineAdapter.with_settings() keeps this
    # property when it makes a clone
    super().__init__(*args, s3_warehouse_location=s3_warehouse_location, **kwargs)
    self.s3_warehouse_location = s3_warehouse_location

    # Fall back to "awsdatacatalog" when no default catalog has been set
    self._default_catalog = self._default_catalog or "awsdatacatalog"
SCHEMA_DIFFER_KWARGS =
{'parameterized_type_defaults': {<DType.DECIMAL: 'DECIMAL'>: [(), (0,)], <DType.CHAR: 'CHAR'>: [(1,)], <DType.TIMESTAMP: 'TIMESTAMP'>: [(3,)]}}
s3_warehouse_location_or_raise: str
@property
def s3_warehouse_location_or_raise(self) -> str:
    """Return ``s3_warehouse_location``, raising if it is unset or empty.

    This makes tests easier to write without extra null checks to keep mypy happy.

    Returns:
        The value of ``self.s3_warehouse_location`` when it is truthy.

    Raises:
        SQLMeshError: If ``self.s3_warehouse_location`` is falsy (``None`` or empty).
    """
    configured = self.s3_warehouse_location
    if not configured:
        raise SQLMeshError("s3_warehouse_location was expected to be populated; it isnt")
    return configured
catalog_support: sqlmesh.core.engine_adapter.shared.CatalogSupport
@property
def catalog_support(self) -> CatalogSupport:
    """Report this adapter's catalog support: always ``CatalogSupport.SINGLE_CATALOG_ONLY``."""
    # Athena has the concept of catalogs but the current catalog is set in the connection parameters with no way to query or change it after that
    # It also can't create new catalogs, you have to configure them in AWS. Typically, catalogs that are not "awsdatacatalog"
    # are pointers to the "awsdatacatalog" of other AWS accounts
    return CatalogSupport.SINGLE_CATALOG_ONLY
def
create_state_table( self, table_name: str, target_columns_to_types: Dict[str, sqlglot.expressions.datatypes.DataType], primary_key: Optional[Tuple[str, ...]] = None) -> None:
def create_state_table(
    self,
    table_name: str,
    target_columns_to_types: t.Dict[str, exp.DataType],
    primary_key: t.Optional[t.Tuple[str, ...]] = None,
) -> None:
    """Create a table to store SQLMesh internal state.

    Delegates to ``self.create_table``, always requesting ``table_format="iceberg"``.

    Args:
        table_name: The name of the table to create. Can be fully qualified or just table name.
        target_columns_to_types: A mapping between the column name and its data type.
        primary_key: Determines the table primary key.
    """
    state_table_options = {
        "primary_key": primary_key,
        # it's painfully slow, but it works
        "table_format": "iceberg",
    }
    self.create_table(table_name, target_columns_to_types, **state_table_options)
Create a table to store SQLMesh internal state.
Arguments:
- table_name: The name of the table to create. Can be fully qualified or just table name.
- target_columns_to_types: A mapping between the column name and its data type.
- primary_key: Determines the table primary key.
def
columns( self, table_name: Union[str, sqlglot.expressions.query.Table], include_pseudo_columns: bool = False) -> Dict[str, sqlglot.expressions.datatypes.DataType]:
def columns(
    self, table_name: TableName, include_pseudo_columns: bool = False
) -> t.Dict[str, exp.DataType]:
    """Fetches column names and types for the target table.

    The lookup runs against ``information_schema.columns``, filtered by the table's
    schema and name and ordered by ``ordinal_position``.

    Args:
        table_name: The table whose columns should be listed.
        include_pseudo_columns: Accepted as part of the signature; this implementation
            does not reference it.

    Returns:
        A mapping of column name to its parsed data type, in ordinal order.
    """
    target = exp.to_table(table_name)
    schema_matches = exp.column("table_schema").eq(target.db)
    name_matches = exp.column("table_name").eq(target.name)

    # note: the data_type column contains the full parameterized type, eg 'varchar(10)'
    metadata_query = (
        exp.select("column_name", "data_type")
        .from_("information_schema.columns")
        .where(schema_matches, name_matches)
        .order_by("ordinal_position")
    )

    columns_to_types: t.Dict[str, exp.DataType] = {}
    metadata = self.fetchdf(metadata_query, quote_identifiers=True)
    for row in metadata.itertuples(index=False):
        columns_to_types[str(row.column_name)] = exp.DataType.build(str(row.data_type))
    return columns_to_types
Fetches column names and types for the target table.
def
drop_table( self, table_name: Union[str, sqlglot.expressions.query.Table], exists: bool = True, **kwargs: Any) -> None:
def drop_table(self, table_name: TableName, exists: bool = True, **kwargs: t.Any) -> None:
    """Drops a table, first truncating it when it is a Hive table.

    Args:
        table_name: The name of the table to drop.
        exists: Forwarded unchanged to the superclass ``drop_table``; defaults to True.
        **kwargs: Additional keyword arguments forwarded to the superclass ``drop_table``.
    """
    target = exp.to_table(table_name)

    is_hive = self._query_table_type(target) == "hive"
    if is_hive:
        # NOTE(review): presumably needed because dropping a Hive table in Athena leaves
        # its underlying data behind - confirm
        self._truncate_table(target)

    return super().drop_table(table_name=target, exists=exists, **kwargs)
Drops a table.
Arguments:
- table_name: The name of the table to drop.
- exists: If exists, defaults to True.
def
replace_query( self, table_name: Union[str, sqlglot.expressions.query.Table], query_or_df: QueryOrDF, target_columns_to_types: Optional[Dict[str, sqlglot.expressions.datatypes.DataType]] = None, table_description: Optional[str] = None, column_descriptions: Optional[Dict[str, str]] = None, source_columns: Optional[List[str]] = None, supports_replace_table_override: Optional[bool] = None, **kwargs: Any) -> None:
def replace_query(
    self,
    table_name: TableName,
    query_or_df: QueryOrDF,
    target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
    table_description: t.Optional[str] = None,
    column_descriptions: t.Optional[t.Dict[str, str]] = None,
    source_columns: t.Optional[t.List[str]] = None,
    supports_replace_table_override: t.Optional[bool] = None,
    **kwargs: t.Any,
) -> None:
    """Replaces an existing table with a query.

    An existing Hive table is dropped through this adapter's ``drop_table`` before the
    work is handed to the superclass ``replace_query``.

    Args:
        table_name: The name of the table (eg. prod.table)
        query_or_df: The SQL query to run or a dataframe.
        target_columns_to_types: Forwarded to the superclass implementation.
        table_description: Forwarded to the superclass implementation.
        column_descriptions: Forwarded to the superclass implementation.
        source_columns: Forwarded to the superclass implementation.
        supports_replace_table_override: Accepted but not forwarded to the superclass.
        **kwargs: Optional create table properties, forwarded to the superclass.
    """
    target = exp.to_table(table_name)

    existing_type = self._query_table_type(table=target)
    if existing_type == "hive":
        self.drop_table(target)

    # NOTE(review): supports_replace_table_override is deliberately left out of this
    # call to mirror existing behaviour - confirm that dropping it is intended
    return super().replace_query(
        table_name=target,
        query_or_df=query_or_df,
        target_columns_to_types=target_columns_to_types,
        table_description=table_description,
        column_descriptions=column_descriptions,
        source_columns=source_columns,
        **kwargs,
    )
Replaces an existing table with a query.
For partition based engines (hive, spark), insert overwrite is used. For other systems, create or replace is used.
Arguments:
- table_name: The name of the table (eg. prod.table)
- query_or_df: The SQL query to run or a dataframe.
- target_columns_to_types: Only used if a dataframe is provided. A mapping between the column name and its data type. Expected to be ordered to match the order of values in the dataframe.
- kwargs: Optional create table properties.
def
delete_from( self, table_name: Union[str, sqlglot.expressions.query.Table], where: Union[str, sqlglot.expressions.core.Expr]) -> None:
def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None:
    """Delete the rows matching ``where``, respecting what each table type allows.

    - Iceberg tables: the delete is delegated to the superclass implementation.
    - Hive tables: deleting rows is an error, unless no row matches ``where``, in which
      case the call is a no-op.
    - Any other result from ``_query_table_type``: nothing is done.

    Args:
        table_name: The table to delete from.
        where: The condition selecting the rows to delete.

    Raises:
        SQLMeshError: If the table is a Hive table and at least one row matches ``where``.
    """
    target = exp.to_table(table_name)
    kind = self._query_table_type(target)

    # If Iceberg, DELETE operations work as expected
    if kind == "iceberg":
        return super().delete_from(target, where)

    if kind != "hive":
        return None

    # If Hive, DELETE is an error. However, if there are no actual records to delete, we can
    # make DELETE a no-op. This simplifies a bunch of calling code that just assumes DELETE works.
    # The probe is deliberately LIMIT 1 rather than count(*) because we want the engine to stop
    # as soon as it finds a record
    probe = exp.select("*").from_(target).where(where).limit(1)
    matching_rows = self.fetchall(probe)
    if len(matching_rows) > 0:
        raise SQLMeshError("Cannot delete individual records from a Hive table")

    return None
Inherited Members
- sqlmesh.core.engine_adapter.base.EngineAdapter
- DEFAULT_BATCH_SIZE
- DATA_OBJECT_FILTER_BATCH_SIZE
- SUPPORTS_INDEXES
- MAX_TABLE_COMMENT_LENGTH
- MAX_COLUMN_COMMENT_LENGTH
- INSERT_OVERWRITE_STRATEGY
- SUPPORTS_MATERIALIZED_VIEWS
- SUPPORTS_MATERIALIZED_VIEW_SCHEMA
- SUPPORTS_VIEW_SCHEMA
- SUPPORTS_CLONING
- SUPPORTS_MANAGED_MODELS
- SUPPORTS_CREATE_DROP_CATALOG
- SUPPORTS_TUPLE_IN
- HAS_VIEW_BINDING
- SUPPORTS_GRANTS
- DEFAULT_CATALOG_TYPE
- QUOTE_IDENTIFIERS_IN_VIEWS
- MAX_IDENTIFIER_LENGTH
- SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS
- dialect
- correlation_id
- with_settings
- cursor
- connection
- spark
- snowpark
- bigframe
- comments_enabled
- schema_differ
- default_catalog
- engine_run_mode
- recycle
- close
- set_current_catalog
- get_catalog_type
- get_catalog_type_from_table
- current_catalog_type
- create_index
- create_table
- create_managed_table
- ctas
- create_table_like
- clone_table
- drop_data_object
- drop_managed_table
- get_alter_operations
- alter_table
- create_view
- create_schema
- drop_schema
- drop_view
- create_catalog
- drop_catalog
- table_exists
- insert_append
- insert_overwrite_by_partition
- insert_overwrite_by_time_partition
- update_table
- scd_type_2_by_time
- scd_type_2_by_column
- merge
- rename_table
- get_data_object
- get_data_objects
- fetchone
- fetchall
- fetchdf
- fetch_pyspark_df
- wap_enabled
- wap_supported
- wap_table_name
- wap_prepare
- wap_publish
- sync_grants_config
- transaction
- session
- execute
- temp_table
- drop_data_object_on_type_mismatch
- ensure_nulls_for_unmatched_after_join
- use_server_nulls_for_unmatched_after_join
- ping
- get_table_last_modified_ts