Edit on GitHub

sqlmesh.core.engine_adapter.athena

  1from __future__ import annotations
  2from functools import lru_cache
  3import typing as t
  4import logging
  5from sqlglot import exp
  6from sqlmesh.core.dialect import to_schema
  7from sqlmesh.utils.aws import validate_s3_uri, parse_s3_uri
  8from sqlmesh.core.engine_adapter.mixins import PandasNativeFetchDFSupportMixin, RowDiffMixin
  9from sqlmesh.core.engine_adapter.trino import TrinoEngineAdapter
 10from sqlmesh.core.node import IntervalUnit
 11import posixpath
 12from sqlmesh.utils.errors import SQLMeshError
 13from sqlmesh.core.engine_adapter.shared import (
 14    CatalogSupport,
 15    DataObject,
 16    DataObjectType,
 17    CommentCreationTable,
 18    CommentCreationView,
 19    SourceQuery,
 20    InsertOverwriteStrategy,
 21)
 22
 23if t.TYPE_CHECKING:
 24    from sqlmesh.core._typing import SchemaName, TableName
 25    from sqlmesh.core.engine_adapter._typing import QueryOrDF
 26
 27    TableType = t.Union[t.Literal["hive"], t.Literal["iceberg"]]
 28
 29logger = logging.getLogger(__name__)
 30
 31
class AthenaEngineAdapter(PandasNativeFetchDFSupportMixin, RowDiffMixin):
    """Engine adapter for the "athena" SQL dialect.

    The class-level flags below declare which engine features this adapter
    treats as available; the methods that follow distinguish between Hive and
    Iceberg tables where the two need different handling.
    """

    DIALECT = "athena"
    SUPPORTS_TRANSACTIONS = False
    SUPPORTS_REPLACE_TABLE = False
    # Athena's support for table and column comments is too patchy to consider "supported"
    # Hive tables: Table + Column comments are supported
    # Iceberg tables: Column comments only
    # CTAS, Views: No comment support at all
    COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED
    COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED
    # Reuse the Trino adapter's schema differ settings as-is
    SCHEMA_DIFFER_KWARGS = TrinoEngineAdapter.SCHEMA_DIFFER_KWARGS
    MAX_TIMESTAMP_PRECISION = 3  # copied from Trino
    # Athena does not deal with comments well, e.g:
    # >>> self._execute('/* test */ DESCRIBE foo')
    #     pyathena.error.OperationalError: FAILED: ParseException line 1:0 cannot recognize input near '/' '*' 'test'
    ATTACH_CORRELATION_ID = False
    SUPPORTS_QUERY_EXECUTION_TRACKING = True
    # NOTE(review): presumably the object kinds for which DROP ... CASCADE may be emitted - confirm against the base adapter
    SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA"]
 50
 51    def __init__(
 52        self, *args: t.Any, s3_warehouse_location: t.Optional[str] = None, **kwargs: t.Any
 53    ):
 54        # Need to pass s3_warehouse_location to the superclass so that it goes into _extra_config
 55        # which means that EngineAdapter.with_settings() keeps this property when it makes a clone
 56        super().__init__(*args, s3_warehouse_location=s3_warehouse_location, **kwargs)
 57        self.s3_warehouse_location = s3_warehouse_location
 58
 59        self._default_catalog = self._default_catalog or "awsdatacatalog"
 60
 61    @property
 62    def s3_warehouse_location(self) -> t.Optional[str]:
 63        return self._s3_warehouse_location
 64
 65    @s3_warehouse_location.setter
 66    def s3_warehouse_location(self, value: t.Optional[str]) -> None:
 67        if value:
 68            value = validate_s3_uri(value, base=True)
 69        self._s3_warehouse_location = value
 70
 71    @property
 72    def s3_warehouse_location_or_raise(self) -> str:
 73        # this makes tests easier to write without extra null checks to keep mypy happy
 74        if location := self.s3_warehouse_location:
 75            return location
 76
 77        raise SQLMeshError("s3_warehouse_location was expected to be populated; it isnt")
 78
    @property
    def catalog_support(self) -> CatalogSupport:
        """Report that this adapter works against a single catalog only."""
        # Athena has the concept of catalogs but the current catalog is set in the connection parameters with no way to query or change it after that
        # It also cant create new catalogs, you have to configure them in AWS. Typically, catalogs that are not "awsdatacatalog"
        # are pointers to the "awsdatacatalog" of other AWS accounts
        return CatalogSupport.SINGLE_CATALOG_ONLY
 85
    def create_state_table(
        self,
        table_name: str,
        target_columns_to_types: t.Dict[str, exp.DataType],
        primary_key: t.Optional[t.Tuple[str, ...]] = None,
    ) -> None:
        """Create a state table by delegating to `create_table` with the Iceberg table format.

        Args:
            table_name: Name of the table to create.
            target_columns_to_types: Mapping of column name to column type.
            primary_key: Optional primary key columns, passed through to `create_table`.
        """
        self.create_table(
            table_name,
            target_columns_to_types,
            primary_key=primary_key,
            # it's painfully slow, but it works
            table_format="iceberg",
        )
 99
100    def _get_data_objects(
101        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
102    ) -> t.List[DataObject]:
103        """
104        Returns all the data objects that exist in the given schema and optionally catalog.
105        """
106        schema_name = to_schema(schema_name)
107        schema = schema_name.db
108        query = (
109            exp.select(
110                exp.column("table_catalog").as_("catalog"),
111                exp.column("table_schema", table="t").as_("schema"),
112                exp.column("table_name", table="t").as_("name"),
113                exp.case()
114                .when(
115                    exp.column("table_type", table="t").eq("BASE TABLE"),
116                    exp.Literal.string("table"),
117                )
118                .else_(exp.column("table_type", table="t"))
119                .as_("type"),
120            )
121            .from_(exp.to_table("information_schema.tables", alias="t"))
122            .where(exp.column("table_schema", table="t").eq(schema))
123        )
124        if object_names:
125            query = query.where(exp.column("table_name", table="t").isin(*object_names))
126
127        df = self.fetchdf(query)
128
129        return [
130            DataObject(
131                catalog=row.catalog,  # type: ignore
132                schema=row.schema,  # type: ignore
133                name=row.name,  # type: ignore
134                type=DataObjectType.from_str(row.type),  # type: ignore
135            )
136            for row in df.itertuples()
137        ]
138
139    def columns(
140        self, table_name: TableName, include_pseudo_columns: bool = False
141    ) -> t.Dict[str, exp.DataType]:
142        table = exp.to_table(table_name)
143        # note: the data_type column contains the full parameterized type, eg 'varchar(10)'
144        query = (
145            exp.select("column_name", "data_type")
146            .from_("information_schema.columns")
147            .where(exp.column("table_schema").eq(table.db), exp.column("table_name").eq(table.name))
148            .order_by("ordinal_position")
149        )
150        result = self.fetchdf(query, quote_identifiers=True)
151        return {
152            str(r.column_name): exp.DataType.build(str(r.data_type))
153            for r in result.itertuples(index=False)
154        }
155
156    def _create_schema(
157        self,
158        schema_name: SchemaName,
159        ignore_if_exists: bool,
160        warn_on_error: bool,
161        properties: t.List[exp.Expr],
162        kind: str,
163    ) -> None:
164        if location := self._table_location(table_properties=None, table=exp.to_table(schema_name)):
165            # don't add extra LocationProperty's if one already exists
166            if not any(p for p in properties if isinstance(p, exp.LocationProperty)):
167                properties.append(location)
168
169        return super()._create_schema(
170            schema_name=schema_name,
171            ignore_if_exists=ignore_if_exists,
172            warn_on_error=warn_on_error,
173            properties=properties,
174            kind=kind,
175        )
176
177    def _build_create_table_exp(
178        self,
179        table_name_or_schema: t.Union[exp.Schema, TableName],
180        expression: t.Optional[exp.Expr],
181        exists: bool = True,
182        replace: bool = False,
183        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
184        table_description: t.Optional[str] = None,
185        table_kind: t.Optional[str] = None,
186        partitioned_by: t.Optional[t.List[exp.Expr]] = None,
187        table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
188        **kwargs: t.Any,
189    ) -> exp.Create:
190        exists = False if replace else exists
191
192        table: exp.Table
193        if isinstance(table_name_or_schema, str):
194            table = exp.to_table(table_name_or_schema)
195        elif isinstance(table_name_or_schema, exp.Schema):
196            table = table_name_or_schema.this
197        else:
198            table = table_name_or_schema
199
200        properties = self._build_table_properties_exp(
201            table=table,
202            expression=expression,
203            target_columns_to_types=target_columns_to_types,
204            partitioned_by=partitioned_by,
205            table_properties=table_properties,
206            table_description=table_description,
207            table_kind=table_kind,
208            **kwargs,
209        )
210
211        is_hive = self._table_type(kwargs.get("table_format", None)) == "hive"
212
213        # Filter any PARTITIONED BY properties from the main column list since they cant be specified in both places
214        # ref: https://docs.aws.amazon.com/athena/latest/ug/partitions.html
215        if is_hive and partitioned_by and isinstance(table_name_or_schema, exp.Schema):
216            partitioned_by_column_names = {e.name for e in partitioned_by}
217            filtered_expressions = [
218                e
219                for e in table_name_or_schema.expressions
220                if isinstance(e, exp.ColumnDef) and e.this.name not in partitioned_by_column_names
221            ]
222            table_name_or_schema.args["expressions"] = filtered_expressions
223
224        return exp.Create(
225            this=table_name_or_schema,
226            kind=table_kind or "TABLE",
227            replace=replace,
228            exists=exists,
229            expression=expression,
230            properties=properties,
231        )
232
233    def _build_table_properties_exp(
234        self,
235        catalog_name: t.Optional[str] = None,
236        table_format: t.Optional[str] = None,
237        storage_format: t.Optional[str] = None,
238        partitioned_by: t.Optional[t.List[exp.Expr]] = None,
239        partition_interval_unit: t.Optional[IntervalUnit] = None,
240        clustered_by: t.Optional[t.List[exp.Expr]] = None,
241        table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
242        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
243        table_description: t.Optional[str] = None,
244        table_kind: t.Optional[str] = None,
245        table: t.Optional[exp.Table] = None,
246        expression: t.Optional[exp.Expr] = None,
247        **kwargs: t.Any,
248    ) -> t.Optional[exp.Properties]:
249        properties: t.List[exp.Expr] = []
250        table_properties = table_properties or {}
251
252        is_hive = self._table_type(table_format) == "hive"
253        is_iceberg = not is_hive
254
255        if is_hive and not expression:
256            # Hive tables are CREATE EXTERNAL TABLE, Iceberg tables are CREATE TABLE
257            # Unless it's a CTAS, those are always CREATE TABLE
258            properties.append(exp.ExternalProperty())
259
260        if table_format:
261            properties.append(
262                exp.Property(this=exp.var("table_type"), value=exp.Literal.string(table_format))
263            )
264
265        if table_description:
266            properties.append(exp.SchemaCommentProperty(this=exp.Literal.string(table_description)))
267
268        if partitioned_by:
269            schema_expressions: t.List[exp.Expr] = []
270            if is_hive and target_columns_to_types:
271                # For Hive-style tables, you cannot include the partitioned by columns in the main set of columns
272                # In the PARTITIONED BY expression, you also cant just include the column names, you need to include the data type as well
273                # ref: https://docs.aws.amazon.com/athena/latest/ug/partitions.html
274                for match_name, match_dtype in self._find_matching_columns(
275                    partitioned_by, target_columns_to_types
276                ):
277                    column_def = exp.ColumnDef(this=exp.to_identifier(match_name), kind=match_dtype)
278                    schema_expressions.append(column_def)
279            else:
280                schema_expressions = partitioned_by
281
282            properties.append(
283                exp.PartitionedByProperty(this=exp.Schema(expressions=schema_expressions))
284            )
285
286        if clustered_by:
287            # Athena itself supports CLUSTERED BY, via the syntax CLUSTERED BY (col) INTO <n> BUCKETS
288            # However, SQLMesh is more closely aligned with BigQuery's notion of clustering and
289            # defines `clustered_by` as a List[str] with no way of indicating the number of buckets
290            #
291            # Athena's concept of CLUSTER BY is more like Iceberg's `bucket(<num_buckets>, col)` partition transform
292            logging.warning("clustered_by is not supported in the Athena adapter at this time")
293
294        if storage_format:
295            if is_iceberg:
296                # TBLPROPERTIES('format'='parquet')
297                table_properties["format"] = exp.Literal.string(storage_format)
298            else:
299                # STORED AS PARQUET
300                properties.append(exp.FileFormatProperty(this=storage_format))
301
302        if table and (location := self._table_location_or_raise(table_properties, table)):
303            properties.append(location)
304
305            if is_iceberg and expression:
306                # To make a CTAS expression persist as iceberg, alongside setting `table_type=iceberg`, you also need to set is_external=false
307                # Note that SQLGlot does the right thing with LocationProperty and writes it as `location` (Iceberg) instead of `external_location` (Hive)
308                # ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties
309                properties.append(exp.Property(this=exp.var("is_external"), value="false"))
310
311        for name, value in table_properties.items():
312            properties.append(exp.Property(this=exp.var(name), value=value))
313
314        if properties:
315            return exp.Properties(expressions=properties)
316
317        return None
318
319    def drop_table(self, table_name: TableName, exists: bool = True, **kwargs: t.Any) -> None:
320        table = exp.to_table(table_name)
321
322        if self._query_table_type(table) == "hive":
323            self._truncate_table(table)
324
325        return super().drop_table(table_name=table, exists=exists, **kwargs)
326
327    def _truncate_table(self, table_name: TableName) -> None:
328        table = exp.to_table(table_name)
329
330        # Truncating an Iceberg table is just DELETE FROM <table>
331        if self._query_table_type(table) == "iceberg":
332            return self.delete_from(table, exp.true())
333
334        # Truncating a partitioned Hive table is dropping all partitions and deleting the data from S3
335        if self._is_hive_partitioned_table(table):
336            self._clear_partition_data(table, exp.true())
337        elif s3_location := self._query_table_s3_location(table):
338            # Truncating a non-partitioned Hive table is clearing out all data in its Location
339            self._clear_s3_location(s3_location)
340
341    def _table_type(self, table_format: t.Optional[str] = None) -> TableType:
342        """
343        Interpret the "table_format" property to check if this is a Hive or an Iceberg table
344        """
345        if table_format and table_format.lower() == "iceberg":
346            return "iceberg"
347
348        # if we cant detect any indication of Iceberg, this is a Hive table
349        return "hive"
350
351    def _query_table_type(self, table: exp.Table) -> t.Optional[TableType]:
352        if self.table_exists(table):
353            return self._query_table_type_or_raise(table)
354        return None
355
356    @lru_cache()
357    def _query_table_type_or_raise(self, table: exp.Table) -> TableType:
358        """
359        Hit the DB to check if this is a Hive or an Iceberg table.
360
361        Note that in order to @lru_cache() this method, we have the following assumptions:
362         - The table must exist (otherwise we would cache None if this method was called before table creation and always return None after creation)
363         - The table type will not change within the same SQLMesh session
364        """
365        # Note: SHOW TBLPROPERTIES gets parsed by SQLGlot as an exp.Command anyway so we just use a string here
366        # This also means we need to use dialect="hive" instead of dialect="athena" so that the identifiers get the correct quoting (backticks)
367        for row in self.fetchall(f"SHOW TBLPROPERTIES {table.sql(dialect='hive', identify=True)}"):
368            # This query returns a single column with values like 'EXTERNAL\tTRUE'
369            row_lower = row[0].lower()
370            if "external" in row_lower and "true" in row_lower:
371                return "hive"
372        return "iceberg"
373
374    def _is_hive_partitioned_table(self, table: exp.Table) -> bool:
375        try:
376            self._list_partitions(table=table, where=None, limit=1)
377            return True
378        except Exception as e:
379            if "TABLE_NOT_FOUND" in str(e):
380                return False
381            raise e
382
383    def _table_location_or_raise(
384        self, table_properties: t.Optional[t.Dict[str, exp.Expr]], table: exp.Table
385    ) -> exp.LocationProperty:
386        location = self._table_location(table_properties, table)
387        if not location:
388            raise SQLMeshError(
389                f"Cannot figure out location for table {table}. Please either set `s3_base_location` in `physical_properties` or set `s3_warehouse_location` in the Athena connection config"
390            )
391        return location
392
393    def _table_location(
394        self,
395        table_properties: t.Optional[t.Dict[str, exp.Expr]],
396        table: exp.Table,
397    ) -> t.Optional[exp.LocationProperty]:
398        base_uri: str
399
400        # If the user has manually specified a `s3_base_location`, use it
401        if table_properties and "s3_base_location" in table_properties:
402            s3_base_location_property = table_properties.pop(
403                "s3_base_location"
404            )  # pop because it's handled differently and we dont want it to end up in the TBLPROPERTIES clause
405            if isinstance(s3_base_location_property, exp.Expr):
406                base_uri = s3_base_location_property.name
407            else:
408                base_uri = s3_base_location_property
409
410        elif self.s3_warehouse_location:
411            # If the user has set `s3_warehouse_location` in the connection config, the base URI is <s3_warehouse_location>/<catalog>/<schema>/
412            base_uri = posixpath.join(
413                self.s3_warehouse_location, table.catalog or "", table.db or ""
414            )
415        else:
416            return None
417
418        full_uri = validate_s3_uri(posixpath.join(base_uri, table.text("this") or ""), base=True)
419        return exp.LocationProperty(this=exp.Literal.string(full_uri))
420
421    def _find_matching_columns(
422        self, partitioned_by: t.List[exp.Expr], columns_to_types: t.Dict[str, exp.DataType]
423    ) -> t.List[t.Tuple[str, exp.DataType]]:
424        matches = []
425        for col in partitioned_by:
426            # TODO: do we care about normalization?
427            key = col.name
428            if isinstance(col, exp.Column) and (match_dtype := columns_to_types.get(key)):
429                matches.append((key, match_dtype))
430        return matches
431
    def replace_query(
        self,
        table_name: TableName,
        query_or_df: QueryOrDF,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        table_description: t.Optional[str] = None,
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        source_columns: t.Optional[t.List[str]] = None,
        supports_replace_table_override: t.Optional[bool] = None,
        **kwargs: t.Any,
    ) -> None:
        """Replace a table's contents, dropping an existing Hive table first.

        If the table exists and is a Hive table it is dropped via `drop_table`
        before the superclass `replace_query` runs; otherwise the call is a
        plain pass-through to the superclass.

        NOTE(review): `supports_replace_table_override` is accepted but is not
        forwarded to the superclass call below - confirm this is intentional.
        """
        table = exp.to_table(table_name)

        # `_query_table_type` returns None for a missing table, so nothing is dropped in that case
        if self._query_table_type(table=table) == "hive":
            self.drop_table(table)

        return super().replace_query(
            table_name=table,
            query_or_df=query_or_df,
            target_columns_to_types=target_columns_to_types,
            table_description=table_description,
            column_descriptions=column_descriptions,
            source_columns=source_columns,
            **kwargs,
        )
457
458    def _insert_overwrite_by_time_partition(
459        self,
460        table_name: TableName,
461        source_queries: t.List[SourceQuery],
462        target_columns_to_types: t.Dict[str, exp.DataType],
463        where: exp.Condition,
464        **kwargs: t.Any,
465    ) -> None:
466        table = exp.to_table(table_name)
467
468        table_type = self._query_table_type(table)
469
470        if table_type == "iceberg":
471            # Iceberg tables work as expected, we can use the default behaviour
472            return super()._insert_overwrite_by_time_partition(
473                table, source_queries, target_columns_to_types, where, **kwargs
474            )
475
476        # For Hive tables, we need to drop all the partitions covered by the query and delete the data from S3
477        self._clear_partition_data(table, where)
478
479        # Now the data is physically gone, we can continue with inserting a new partition
480        return super()._insert_overwrite_by_time_partition(
481            table,
482            source_queries,
483            target_columns_to_types,
484            where,
485            insert_overwrite_strategy_override=InsertOverwriteStrategy.INTO_IS_OVERWRITE,  # since we already cleared the data
486            **kwargs,
487        )
488
489    def _clear_partition_data(self, table: exp.Table, where: t.Optional[exp.Condition]) -> None:
490        if partitions_to_drop := self._list_partitions(table, where):
491            for _, s3_location in partitions_to_drop:
492                logger.debug(
493                    f"Clearing S3 location for '{table.sql(dialect=self.dialect)}': {s3_location}"
494                )
495                self._clear_s3_location(s3_location)
496
497            partition_values = [k for k, _ in partitions_to_drop]
498            logger.debug(
499                f"Dropping partitions for '{table.sql(dialect=self.dialect)}' from metastore: {partition_values}"
500            )
501            self._drop_partitions_from_metastore(table, partition_values)
502
503    def _list_partitions(
504        self,
505        table: exp.Table,
506        where: t.Optional[exp.Condition] = None,
507        limit: t.Optional[int] = None,
508    ) -> t.List[t.Tuple[t.List[str], str]]:
509        # Use Athena's magic "$partitions" metadata table to identify the partitions to drop
510        # Doing it this way allows us to use SQL to filter the partition list
511        partition_table_name = table.copy()
512        partition_table_name.this.replace(
513            exp.to_identifier(f"{table.name}$partitions", quoted=True)
514        )
515
516        query = exp.select("*").from_(partition_table_name).where(where)
517        if limit:
518            query = query.limit(limit)
519
520        partition_values = [list(r) for r in self.fetchall(query, quote_identifiers=True)]
521
522        if partition_values:
523            response = self._glue_client.batch_get_partition(
524                DatabaseName=table.db,
525                TableName=table.name,
526                PartitionsToGet=[{"Values": [str(v) for v in lst]} for lst in partition_values],
527            )
528            return sorted(
529                [(p["Values"], p["StorageDescriptor"]["Location"]) for p in response["Partitions"]]
530            )
531
532        return []
533
534    def _query_table_s3_location(self, table: exp.Table) -> str:
535        response = self._glue_client.get_table(DatabaseName=table.db, Name=table.name)
536
537        # Athena wont let you create a table without a location, so *theoretically* this should never be empty
538        if location := response.get("Table", {}).get("StorageDescriptor", {}).get("Location", None):
539            return location
540
541        raise SQLMeshError(f"Table {table} has no location set in the metastore!")
542
543    def _drop_partitions_from_metastore(
544        self, table: exp.Table, partition_values: t.List[t.List[str]]
545    ) -> None:
546        # todo: switch to itertools.batched when our minimum supported Python is 3.12
547        # 25 = maximum number of partitions that batch_delete_partition can process at once
548        # ref: https://docs.aws.amazon.com/glue/latest/webapi/API_BatchDeletePartition.html#API_BatchDeletePartition_RequestParameters
549        def _chunks() -> t.Iterable[t.List[t.List[str]]]:
550            for i in range(0, len(partition_values), 25):
551                yield partition_values[i : i + 25]
552
553        for batch in _chunks():
554            self._glue_client.batch_delete_partition(
555                DatabaseName=table.db,
556                TableName=table.name,
557                PartitionsToDelete=[{"Values": v} for v in batch],
558            )
559
560    def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None:
561        table = exp.to_table(table_name)
562
563        table_type = self._query_table_type(table)
564
565        # If Iceberg, DELETE operations work as expected
566        if table_type == "iceberg":
567            return super().delete_from(table, where)
568
569        # If Hive, DELETE is an error
570        if table_type == "hive":
571            # However, if there are no actual records to delete, we can make DELETE a no-op
572            # This simplifies a bunch of calling code that just assumes DELETE works (which to be fair is a reasonable assumption since it does for every other engine)
573            empty_check = (
574                exp.select("*").from_(table).where(where).limit(1)
575            )  # deliberately not count(*) because we want the engine to stop as soon as it finds a record
576            if len(self.fetchall(empty_check)) > 0:
577                raise SQLMeshError("Cannot delete individual records from a Hive table")
578
579        return None
580
581    def _clear_s3_location(self, s3_uri: str) -> None:
582        s3 = self._s3_client
583
584        bucket, key = parse_s3_uri(s3_uri)
585        if not key.endswith("/"):
586            key = f"{key}/"
587
588        keys_to_delete = []
589
590        # note: uses Delimiter=/ to prevent stepping into folders
591        # the assumption is that all the files in a partition live directly at the partition `Location`
592        for page in s3.get_paginator("list_objects_v2").paginate(
593            Bucket=bucket, Prefix=key, Delimiter="/"
594        ):
595            # list_objects_v2() returns 1000 keys per page so that lines up nicely with delete_objects() being able to delete 1000 keys at a time
596            keys = [item["Key"] for item in page.get("Contents", [])]
597            if keys:
598                keys_to_delete.append(keys)
599
600        for chunk in keys_to_delete:
601            s3.delete_objects(Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]})
602
    @property
    def _glue_client(self) -> t.Any:
        """A "glue" client obtained from `_boto3_client`; requested again on every access."""
        return self._boto3_client("glue")
606
    @property
    def _s3_client(self) -> t.Any:
        """An "s3" client obtained from `_boto3_client`; requested again on every access."""
        return self._boto3_client("s3")
610
    def _boto3_client(self, name: str) -> t.Any:
        """Create a client for the named service from the adapter connection's session.

        The region, config and extra client kwargs are all taken from the connection object.

        Args:
            name: Service name passed to `session.client`, e.g. "glue" or "s3".
        """
        # use the client factory from PyAthena which is already configured with the correct AWS details
        conn = self.connection
        return conn.session.client(
            name,
            region_name=conn.region_name,
            config=conn.config,
            **conn._client_kwargs,
        )  # type: ignore
620
    def get_current_catalog(self) -> t.Optional[str]:
        """Return the `catalog_name` attribute of the underlying connection."""
        return self.connection.catalog_name
logger = <Logger sqlmesh.core.engine_adapter.athena (WARNING)>
 33class AthenaEngineAdapter(PandasNativeFetchDFSupportMixin, RowDiffMixin):
 34    DIALECT = "athena"
 35    SUPPORTS_TRANSACTIONS = False
 36    SUPPORTS_REPLACE_TABLE = False
 37    # Athena's support for table and column comments is too patchy to consider "supported"
 38    # Hive tables: Table + Column comments are supported
 39    # Iceberg tables: Column comments only
 40    # CTAS, Views: No comment support at all
 41    COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED
 42    COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED
 43    SCHEMA_DIFFER_KWARGS = TrinoEngineAdapter.SCHEMA_DIFFER_KWARGS
 44    MAX_TIMESTAMP_PRECISION = 3  # copied from Trino
 45    # Athena does not deal with comments well, e.g:
 46    # >>> self._execute('/* test */ DESCRIBE foo')
 47    #     pyathena.error.OperationalError: FAILED: ParseException line 1:0 cannot recognize input near '/' '*' 'test'
 48    ATTACH_CORRELATION_ID = False
 49    SUPPORTS_QUERY_EXECUTION_TRACKING = True
 50    SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA"]
 51
 52    def __init__(
 53        self, *args: t.Any, s3_warehouse_location: t.Optional[str] = None, **kwargs: t.Any
 54    ):
 55        # Need to pass s3_warehouse_location to the superclass so that it goes into _extra_config
 56        # which means that EngineAdapter.with_settings() keeps this property when it makes a clone
 57        super().__init__(*args, s3_warehouse_location=s3_warehouse_location, **kwargs)
 58        self.s3_warehouse_location = s3_warehouse_location
 59
 60        self._default_catalog = self._default_catalog or "awsdatacatalog"
 61
 62    @property
 63    def s3_warehouse_location(self) -> t.Optional[str]:
 64        return self._s3_warehouse_location
 65
 66    @s3_warehouse_location.setter
 67    def s3_warehouse_location(self, value: t.Optional[str]) -> None:
 68        if value:
 69            value = validate_s3_uri(value, base=True)
 70        self._s3_warehouse_location = value
 71
 72    @property
 73    def s3_warehouse_location_or_raise(self) -> str:
 74        # this makes tests easier to write without extra null checks to keep mypy happy
 75        if location := self.s3_warehouse_location:
 76            return location
 77
 78        raise SQLMeshError("s3_warehouse_location was expected to be populated; it isnt")
 79
 80    @property
 81    def catalog_support(self) -> CatalogSupport:
 82        # Athena has the concept of catalogs but the current catalog is set in the connection parameters with no way to query or change it after that
 83        # It also cant create new catalogs, you have to configure them in AWS. Typically, catalogs that are not "awsdatacatalog"
 84        # are pointers to the "awsdatacatalog" of other AWS accounts
 85        return CatalogSupport.SINGLE_CATALOG_ONLY
 86
 87    def create_state_table(
 88        self,
 89        table_name: str,
 90        target_columns_to_types: t.Dict[str, exp.DataType],
 91        primary_key: t.Optional[t.Tuple[str, ...]] = None,
 92    ) -> None:
 93        self.create_table(
 94            table_name,
 95            target_columns_to_types,
 96            primary_key=primary_key,
 97            # it's painfully slow, but it works
 98            table_format="iceberg",
 99        )
100
101    def _get_data_objects(
102        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
103    ) -> t.List[DataObject]:
104        """
105        Returns all the data objects that exist in the given schema and optionally catalog.
106        """
107        schema_name = to_schema(schema_name)
108        schema = schema_name.db
109        query = (
110            exp.select(
111                exp.column("table_catalog").as_("catalog"),
112                exp.column("table_schema", table="t").as_("schema"),
113                exp.column("table_name", table="t").as_("name"),
114                exp.case()
115                .when(
116                    exp.column("table_type", table="t").eq("BASE TABLE"),
117                    exp.Literal.string("table"),
118                )
119                .else_(exp.column("table_type", table="t"))
120                .as_("type"),
121            )
122            .from_(exp.to_table("information_schema.tables", alias="t"))
123            .where(exp.column("table_schema", table="t").eq(schema))
124        )
125        if object_names:
126            query = query.where(exp.column("table_name", table="t").isin(*object_names))
127
128        df = self.fetchdf(query)
129
130        return [
131            DataObject(
132                catalog=row.catalog,  # type: ignore
133                schema=row.schema,  # type: ignore
134                name=row.name,  # type: ignore
135                type=DataObjectType.from_str(row.type),  # type: ignore
136            )
137            for row in df.itertuples()
138        ]
139
140    def columns(
141        self, table_name: TableName, include_pseudo_columns: bool = False
142    ) -> t.Dict[str, exp.DataType]:
143        table = exp.to_table(table_name)
144        # note: the data_type column contains the full parameterized type, eg 'varchar(10)'
145        query = (
146            exp.select("column_name", "data_type")
147            .from_("information_schema.columns")
148            .where(exp.column("table_schema").eq(table.db), exp.column("table_name").eq(table.name))
149            .order_by("ordinal_position")
150        )
151        result = self.fetchdf(query, quote_identifiers=True)
152        return {
153            str(r.column_name): exp.DataType.build(str(r.data_type))
154            for r in result.itertuples(index=False)
155        }
156
157    def _create_schema(
158        self,
159        schema_name: SchemaName,
160        ignore_if_exists: bool,
161        warn_on_error: bool,
162        properties: t.List[exp.Expr],
163        kind: str,
164    ) -> None:
165        if location := self._table_location(table_properties=None, table=exp.to_table(schema_name)):
166            # don't add extra LocationProperty's if one already exists
167            if not any(p for p in properties if isinstance(p, exp.LocationProperty)):
168                properties.append(location)
169
170        return super()._create_schema(
171            schema_name=schema_name,
172            ignore_if_exists=ignore_if_exists,
173            warn_on_error=warn_on_error,
174            properties=properties,
175            kind=kind,
176        )
177
178    def _build_create_table_exp(
179        self,
180        table_name_or_schema: t.Union[exp.Schema, TableName],
181        expression: t.Optional[exp.Expr],
182        exists: bool = True,
183        replace: bool = False,
184        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
185        table_description: t.Optional[str] = None,
186        table_kind: t.Optional[str] = None,
187        partitioned_by: t.Optional[t.List[exp.Expr]] = None,
188        table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
189        **kwargs: t.Any,
190    ) -> exp.Create:
191        exists = False if replace else exists
192
193        table: exp.Table
194        if isinstance(table_name_or_schema, str):
195            table = exp.to_table(table_name_or_schema)
196        elif isinstance(table_name_or_schema, exp.Schema):
197            table = table_name_or_schema.this
198        else:
199            table = table_name_or_schema
200
201        properties = self._build_table_properties_exp(
202            table=table,
203            expression=expression,
204            target_columns_to_types=target_columns_to_types,
205            partitioned_by=partitioned_by,
206            table_properties=table_properties,
207            table_description=table_description,
208            table_kind=table_kind,
209            **kwargs,
210        )
211
212        is_hive = self._table_type(kwargs.get("table_format", None)) == "hive"
213
214        # Filter any PARTITIONED BY properties from the main column list since they cant be specified in both places
215        # ref: https://docs.aws.amazon.com/athena/latest/ug/partitions.html
216        if is_hive and partitioned_by and isinstance(table_name_or_schema, exp.Schema):
217            partitioned_by_column_names = {e.name for e in partitioned_by}
218            filtered_expressions = [
219                e
220                for e in table_name_or_schema.expressions
221                if isinstance(e, exp.ColumnDef) and e.this.name not in partitioned_by_column_names
222            ]
223            table_name_or_schema.args["expressions"] = filtered_expressions
224
225        return exp.Create(
226            this=table_name_or_schema,
227            kind=table_kind or "TABLE",
228            replace=replace,
229            exists=exists,
230            expression=expression,
231            properties=properties,
232        )
233
234    def _build_table_properties_exp(
235        self,
236        catalog_name: t.Optional[str] = None,
237        table_format: t.Optional[str] = None,
238        storage_format: t.Optional[str] = None,
239        partitioned_by: t.Optional[t.List[exp.Expr]] = None,
240        partition_interval_unit: t.Optional[IntervalUnit] = None,
241        clustered_by: t.Optional[t.List[exp.Expr]] = None,
242        table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
243        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
244        table_description: t.Optional[str] = None,
245        table_kind: t.Optional[str] = None,
246        table: t.Optional[exp.Table] = None,
247        expression: t.Optional[exp.Expr] = None,
248        **kwargs: t.Any,
249    ) -> t.Optional[exp.Properties]:
250        properties: t.List[exp.Expr] = []
251        table_properties = table_properties or {}
252
253        is_hive = self._table_type(table_format) == "hive"
254        is_iceberg = not is_hive
255
256        if is_hive and not expression:
257            # Hive tables are CREATE EXTERNAL TABLE, Iceberg tables are CREATE TABLE
258            # Unless it's a CTAS, those are always CREATE TABLE
259            properties.append(exp.ExternalProperty())
260
261        if table_format:
262            properties.append(
263                exp.Property(this=exp.var("table_type"), value=exp.Literal.string(table_format))
264            )
265
266        if table_description:
267            properties.append(exp.SchemaCommentProperty(this=exp.Literal.string(table_description)))
268
269        if partitioned_by:
270            schema_expressions: t.List[exp.Expr] = []
271            if is_hive and target_columns_to_types:
272                # For Hive-style tables, you cannot include the partitioned by columns in the main set of columns
273                # In the PARTITIONED BY expression, you also cant just include the column names, you need to include the data type as well
274                # ref: https://docs.aws.amazon.com/athena/latest/ug/partitions.html
275                for match_name, match_dtype in self._find_matching_columns(
276                    partitioned_by, target_columns_to_types
277                ):
278                    column_def = exp.ColumnDef(this=exp.to_identifier(match_name), kind=match_dtype)
279                    schema_expressions.append(column_def)
280            else:
281                schema_expressions = partitioned_by
282
283            properties.append(
284                exp.PartitionedByProperty(this=exp.Schema(expressions=schema_expressions))
285            )
286
287        if clustered_by:
288            # Athena itself supports CLUSTERED BY, via the syntax CLUSTERED BY (col) INTO <n> BUCKETS
289            # However, SQLMesh is more closely aligned with BigQuery's notion of clustering and
290            # defines `clustered_by` as a List[str] with no way of indicating the number of buckets
291            #
292            # Athena's concept of CLUSTER BY is more like Iceberg's `bucket(<num_buckets>, col)` partition transform
293            logging.warning("clustered_by is not supported in the Athena adapter at this time")
294
295        if storage_format:
296            if is_iceberg:
297                # TBLPROPERTIES('format'='parquet')
298                table_properties["format"] = exp.Literal.string(storage_format)
299            else:
300                # STORED AS PARQUET
301                properties.append(exp.FileFormatProperty(this=storage_format))
302
303        if table and (location := self._table_location_or_raise(table_properties, table)):
304            properties.append(location)
305
306            if is_iceberg and expression:
307                # To make a CTAS expression persist as iceberg, alongside setting `table_type=iceberg`, you also need to set is_external=false
308                # Note that SQLGlot does the right thing with LocationProperty and writes it as `location` (Iceberg) instead of `external_location` (Hive)
309                # ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties
310                properties.append(exp.Property(this=exp.var("is_external"), value="false"))
311
312        for name, value in table_properties.items():
313            properties.append(exp.Property(this=exp.var(name), value=value))
314
315        if properties:
316            return exp.Properties(expressions=properties)
317
318        return None
319
320    def drop_table(self, table_name: TableName, exists: bool = True, **kwargs: t.Any) -> None:
321        table = exp.to_table(table_name)
322
323        if self._query_table_type(table) == "hive":
324            self._truncate_table(table)
325
326        return super().drop_table(table_name=table, exists=exists, **kwargs)
327
328    def _truncate_table(self, table_name: TableName) -> None:
329        table = exp.to_table(table_name)
330
331        # Truncating an Iceberg table is just DELETE FROM <table>
332        if self._query_table_type(table) == "iceberg":
333            return self.delete_from(table, exp.true())
334
335        # Truncating a partitioned Hive table is dropping all partitions and deleting the data from S3
336        if self._is_hive_partitioned_table(table):
337            self._clear_partition_data(table, exp.true())
338        elif s3_location := self._query_table_s3_location(table):
339            # Truncating a non-partitioned Hive table is clearing out all data in its Location
340            self._clear_s3_location(s3_location)
341
342    def _table_type(self, table_format: t.Optional[str] = None) -> TableType:
343        """
344        Interpret the "table_format" property to check if this is a Hive or an Iceberg table
345        """
346        if table_format and table_format.lower() == "iceberg":
347            return "iceberg"
348
349        # if we cant detect any indication of Iceberg, this is a Hive table
350        return "hive"
351
352    def _query_table_type(self, table: exp.Table) -> t.Optional[TableType]:
353        if self.table_exists(table):
354            return self._query_table_type_or_raise(table)
355        return None
356
357    @lru_cache()
358    def _query_table_type_or_raise(self, table: exp.Table) -> TableType:
359        """
360        Hit the DB to check if this is a Hive or an Iceberg table.
361
362        Note that in order to @lru_cache() this method, we have the following assumptions:
363         - The table must exist (otherwise we would cache None if this method was called before table creation and always return None after creation)
364         - The table type will not change within the same SQLMesh session
365        """
366        # Note: SHOW TBLPROPERTIES gets parsed by SQLGlot as an exp.Command anyway so we just use a string here
367        # This also means we need to use dialect="hive" instead of dialect="athena" so that the identifiers get the correct quoting (backticks)
368        for row in self.fetchall(f"SHOW TBLPROPERTIES {table.sql(dialect='hive', identify=True)}"):
369            # This query returns a single column with values like 'EXTERNAL\tTRUE'
370            row_lower = row[0].lower()
371            if "external" in row_lower and "true" in row_lower:
372                return "hive"
373        return "iceberg"
374
375    def _is_hive_partitioned_table(self, table: exp.Table) -> bool:
376        try:
377            self._list_partitions(table=table, where=None, limit=1)
378            return True
379        except Exception as e:
380            if "TABLE_NOT_FOUND" in str(e):
381                return False
382            raise e
383
384    def _table_location_or_raise(
385        self, table_properties: t.Optional[t.Dict[str, exp.Expr]], table: exp.Table
386    ) -> exp.LocationProperty:
387        location = self._table_location(table_properties, table)
388        if not location:
389            raise SQLMeshError(
390                f"Cannot figure out location for table {table}. Please either set `s3_base_location` in `physical_properties` or set `s3_warehouse_location` in the Athena connection config"
391            )
392        return location
393
394    def _table_location(
395        self,
396        table_properties: t.Optional[t.Dict[str, exp.Expr]],
397        table: exp.Table,
398    ) -> t.Optional[exp.LocationProperty]:
399        base_uri: str
400
401        # If the user has manually specified a `s3_base_location`, use it
402        if table_properties and "s3_base_location" in table_properties:
403            s3_base_location_property = table_properties.pop(
404                "s3_base_location"
405            )  # pop because it's handled differently and we dont want it to end up in the TBLPROPERTIES clause
406            if isinstance(s3_base_location_property, exp.Expr):
407                base_uri = s3_base_location_property.name
408            else:
409                base_uri = s3_base_location_property
410
411        elif self.s3_warehouse_location:
412            # If the user has set `s3_warehouse_location` in the connection config, the base URI is <s3_warehouse_location>/<catalog>/<schema>/
413            base_uri = posixpath.join(
414                self.s3_warehouse_location, table.catalog or "", table.db or ""
415            )
416        else:
417            return None
418
419        full_uri = validate_s3_uri(posixpath.join(base_uri, table.text("this") or ""), base=True)
420        return exp.LocationProperty(this=exp.Literal.string(full_uri))
421
422    def _find_matching_columns(
423        self, partitioned_by: t.List[exp.Expr], columns_to_types: t.Dict[str, exp.DataType]
424    ) -> t.List[t.Tuple[str, exp.DataType]]:
425        matches = []
426        for col in partitioned_by:
427            # TODO: do we care about normalization?
428            key = col.name
429            if isinstance(col, exp.Column) and (match_dtype := columns_to_types.get(key)):
430                matches.append((key, match_dtype))
431        return matches
432
433    def replace_query(
434        self,
435        table_name: TableName,
436        query_or_df: QueryOrDF,
437        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
438        table_description: t.Optional[str] = None,
439        column_descriptions: t.Optional[t.Dict[str, str]] = None,
440        source_columns: t.Optional[t.List[str]] = None,
441        supports_replace_table_override: t.Optional[bool] = None,
442        **kwargs: t.Any,
443    ) -> None:
444        table = exp.to_table(table_name)
445
446        if self._query_table_type(table=table) == "hive":
447            self.drop_table(table)
448
449        return super().replace_query(
450            table_name=table,
451            query_or_df=query_or_df,
452            target_columns_to_types=target_columns_to_types,
453            table_description=table_description,
454            column_descriptions=column_descriptions,
455            source_columns=source_columns,
456            **kwargs,
457        )
458
459    def _insert_overwrite_by_time_partition(
460        self,
461        table_name: TableName,
462        source_queries: t.List[SourceQuery],
463        target_columns_to_types: t.Dict[str, exp.DataType],
464        where: exp.Condition,
465        **kwargs: t.Any,
466    ) -> None:
467        table = exp.to_table(table_name)
468
469        table_type = self._query_table_type(table)
470
471        if table_type == "iceberg":
472            # Iceberg tables work as expected, we can use the default behaviour
473            return super()._insert_overwrite_by_time_partition(
474                table, source_queries, target_columns_to_types, where, **kwargs
475            )
476
477        # For Hive tables, we need to drop all the partitions covered by the query and delete the data from S3
478        self._clear_partition_data(table, where)
479
480        # Now the data is physically gone, we can continue with inserting a new partition
481        return super()._insert_overwrite_by_time_partition(
482            table,
483            source_queries,
484            target_columns_to_types,
485            where,
486            insert_overwrite_strategy_override=InsertOverwriteStrategy.INTO_IS_OVERWRITE,  # since we already cleared the data
487            **kwargs,
488        )
489
490    def _clear_partition_data(self, table: exp.Table, where: t.Optional[exp.Condition]) -> None:
491        if partitions_to_drop := self._list_partitions(table, where):
492            for _, s3_location in partitions_to_drop:
493                logger.debug(
494                    f"Clearing S3 location for '{table.sql(dialect=self.dialect)}': {s3_location}"
495                )
496                self._clear_s3_location(s3_location)
497
498            partition_values = [k for k, _ in partitions_to_drop]
499            logger.debug(
500                f"Dropping partitions for '{table.sql(dialect=self.dialect)}' from metastore: {partition_values}"
501            )
502            self._drop_partitions_from_metastore(table, partition_values)
503
504    def _list_partitions(
505        self,
506        table: exp.Table,
507        where: t.Optional[exp.Condition] = None,
508        limit: t.Optional[int] = None,
509    ) -> t.List[t.Tuple[t.List[str], str]]:
510        # Use Athena's magic "$partitions" metadata table to identify the partitions to drop
511        # Doing it this way allows us to use SQL to filter the partition list
512        partition_table_name = table.copy()
513        partition_table_name.this.replace(
514            exp.to_identifier(f"{table.name}$partitions", quoted=True)
515        )
516
517        query = exp.select("*").from_(partition_table_name).where(where)
518        if limit:
519            query = query.limit(limit)
520
521        partition_values = [list(r) for r in self.fetchall(query, quote_identifiers=True)]
522
523        if partition_values:
524            response = self._glue_client.batch_get_partition(
525                DatabaseName=table.db,
526                TableName=table.name,
527                PartitionsToGet=[{"Values": [str(v) for v in lst]} for lst in partition_values],
528            )
529            return sorted(
530                [(p["Values"], p["StorageDescriptor"]["Location"]) for p in response["Partitions"]]
531            )
532
533        return []
534
535    def _query_table_s3_location(self, table: exp.Table) -> str:
536        response = self._glue_client.get_table(DatabaseName=table.db, Name=table.name)
537
538        # Athena wont let you create a table without a location, so *theoretically* this should never be empty
539        if location := response.get("Table", {}).get("StorageDescriptor", {}).get("Location", None):
540            return location
541
542        raise SQLMeshError(f"Table {table} has no location set in the metastore!")
543
544    def _drop_partitions_from_metastore(
545        self, table: exp.Table, partition_values: t.List[t.List[str]]
546    ) -> None:
547        # todo: switch to itertools.batched when our minimum supported Python is 3.12
548        # 25 = maximum number of partitions that batch_delete_partition can process at once
549        # ref: https://docs.aws.amazon.com/glue/latest/webapi/API_BatchDeletePartition.html#API_BatchDeletePartition_RequestParameters
550        def _chunks() -> t.Iterable[t.List[t.List[str]]]:
551            for i in range(0, len(partition_values), 25):
552                yield partition_values[i : i + 25]
553
554        for batch in _chunks():
555            self._glue_client.batch_delete_partition(
556                DatabaseName=table.db,
557                TableName=table.name,
558                PartitionsToDelete=[{"Values": v} for v in batch],
559            )
560
561    def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None:
562        table = exp.to_table(table_name)
563
564        table_type = self._query_table_type(table)
565
566        # If Iceberg, DELETE operations work as expected
567        if table_type == "iceberg":
568            return super().delete_from(table, where)
569
570        # If Hive, DELETE is an error
571        if table_type == "hive":
572            # However, if there are no actual records to delete, we can make DELETE a no-op
573            # This simplifies a bunch of calling code that just assumes DELETE works (which to be fair is a reasonable assumption since it does for every other engine)
574            empty_check = (
575                exp.select("*").from_(table).where(where).limit(1)
576            )  # deliberately not count(*) because we want the engine to stop as soon as it finds a record
577            if len(self.fetchall(empty_check)) > 0:
578                raise SQLMeshError("Cannot delete individual records from a Hive table")
579
580        return None
581
582    def _clear_s3_location(self, s3_uri: str) -> None:
583        s3 = self._s3_client
584
585        bucket, key = parse_s3_uri(s3_uri)
586        if not key.endswith("/"):
587            key = f"{key}/"
588
589        keys_to_delete = []
590
591        # note: uses Delimiter=/ to prevent stepping into folders
592        # the assumption is that all the files in a partition live directly at the partition `Location`
593        for page in s3.get_paginator("list_objects_v2").paginate(
594            Bucket=bucket, Prefix=key, Delimiter="/"
595        ):
596            # list_objects_v2() returns 1000 keys per page so that lines up nicely with delete_objects() being able to delete 1000 keys at a time
597            keys = [item["Key"] for item in page.get("Contents", [])]
598            if keys:
599                keys_to_delete.append(keys)
600
601        for chunk in keys_to_delete:
602            s3.delete_objects(Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]})
603
604    @property
605    def _glue_client(self) -> t.Any:
606        return self._boto3_client("glue")
607
608    @property
609    def _s3_client(self) -> t.Any:
610        return self._boto3_client("s3")
611
612    def _boto3_client(self, name: str) -> t.Any:
613        # use the client factory from PyAthena which is already configured with the correct AWS details
614        conn = self.connection
615        return conn.session.client(
616            name,
617            region_name=conn.region_name,
618            config=conn.config,
619            **conn._client_kwargs,
620        )  # type: ignore
621
622    def get_current_catalog(self) -> t.Optional[str]:
623        return self.connection.catalog_name

Base class wrapping a Database API compliant connection.

The EngineAdapter is an easily-subclassable interface that interacts with the underlying engine and data store.

Arguments:
  • connection_factory_or_pool: a callable which produces a new Database API-compliant connection on every call.
  • dialect: The dialect with which this adapter is associated.
  • multithreaded: Indicates whether this adapter will be used by more than one thread.
AthenaEngineAdapter( *args: Any, s3_warehouse_location: Optional[str] = None, **kwargs: Any)
52    def __init__(
53        self, *args: t.Any, s3_warehouse_location: t.Optional[str] = None, **kwargs: t.Any
54    ):
55        # Need to pass s3_warehouse_location to the superclass so that it goes into _extra_config
56        # which means that EngineAdapter.with_settings() keeps this property when it makes a clone
57        super().__init__(*args, s3_warehouse_location=s3_warehouse_location, **kwargs)
58        self.s3_warehouse_location = s3_warehouse_location
59
60        self._default_catalog = self._default_catalog or "awsdatacatalog"
DIALECT = 'athena'
SUPPORTS_TRANSACTIONS = False
SUPPORTS_REPLACE_TABLE = False
COMMENT_CREATION_TABLE = <CommentCreationTable.UNSUPPORTED: 1>
COMMENT_CREATION_VIEW = <CommentCreationView.UNSUPPORTED: 1>
SCHEMA_DIFFER_KWARGS = {'parameterized_type_defaults': {<DType.DECIMAL: 'DECIMAL'>: [(), (0,)], <DType.CHAR: 'CHAR'>: [(1,)], <DType.TIMESTAMP: 'TIMESTAMP'>: [(3,)]}}
MAX_TIMESTAMP_PRECISION = 3
ATTACH_CORRELATION_ID = False
SUPPORTS_QUERY_EXECUTION_TRACKING = True
SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ['DATABASE', 'SCHEMA']
s3_warehouse_location: Optional[str]
62    @property
63    def s3_warehouse_location(self) -> t.Optional[str]:
64        return self._s3_warehouse_location
s3_warehouse_location_or_raise: str
72    @property
73    def s3_warehouse_location_or_raise(self) -> str:
74        # this makes tests easier to write without extra null checks to keep mypy happy
75        if location := self.s3_warehouse_location:
76            return location
77
78        raise SQLMeshError("s3_warehouse_location was expected to be populated; it isnt")
80    @property
81    def catalog_support(self) -> CatalogSupport:
82        # Athena has the concept of catalogs but the current catalog is set in the connection parameters with no way to query or change it after that
83        # It also cant create new catalogs, you have to configure them in AWS. Typically, catalogs that are not "awsdatacatalog"
84        # are pointers to the "awsdatacatalog" of other AWS accounts
85        return CatalogSupport.SINGLE_CATALOG_ONLY
def create_state_table( self, table_name: str, target_columns_to_types: Dict[str, sqlglot.expressions.datatypes.DataType], primary_key: Optional[Tuple[str, ...]] = None) -> None:
87    def create_state_table(
88        self,
89        table_name: str,
90        target_columns_to_types: t.Dict[str, exp.DataType],
91        primary_key: t.Optional[t.Tuple[str, ...]] = None,
92    ) -> None:
93        self.create_table(
94            table_name,
95            target_columns_to_types,
96            primary_key=primary_key,
97            # it's painfully slow, but it works
98            table_format="iceberg",
99        )

Create a table to store SQLMesh internal state.

Arguments:
  • table_name: The name of the table to create. Can be fully qualified or just table name.
  • target_columns_to_types: A mapping between the column name and its data type.
  • primary_key: Determines the table primary key.
def columns( self, table_name: Union[str, sqlglot.expressions.query.Table], include_pseudo_columns: bool = False) -> Dict[str, sqlglot.expressions.datatypes.DataType]:
140    def columns(
141        self, table_name: TableName, include_pseudo_columns: bool = False
142    ) -> t.Dict[str, exp.DataType]:
143        table = exp.to_table(table_name)
144        # note: the data_type column contains the full parameterized type, eg 'varchar(10)'
145        query = (
146            exp.select("column_name", "data_type")
147            .from_("information_schema.columns")
148            .where(exp.column("table_schema").eq(table.db), exp.column("table_name").eq(table.name))
149            .order_by("ordinal_position")
150        )
151        result = self.fetchdf(query, quote_identifiers=True)
152        return {
153            str(r.column_name): exp.DataType.build(str(r.data_type))
154            for r in result.itertuples(index=False)
155        }

Fetches column names and types for the target table.

def drop_table( self, table_name: Union[str, sqlglot.expressions.query.Table], exists: bool = True, **kwargs: Any) -> None:
320    def drop_table(self, table_name: TableName, exists: bool = True, **kwargs: t.Any) -> None:
321        table = exp.to_table(table_name)
322
323        if self._query_table_type(table) == "hive":
324            self._truncate_table(table)
325
326        return super().drop_table(table_name=table, exists=exists, **kwargs)

Drops a table.

Arguments:
  • table_name: The name of the table to drop.
  • exists: If exists, defaults to True.
def replace_query( self, table_name: Union[str, sqlglot.expressions.query.Table], query_or_df: QueryOrDF, target_columns_to_types: Optional[Dict[str, sqlglot.expressions.datatypes.DataType]] = None, table_description: Optional[str] = None, column_descriptions: Optional[Dict[str, str]] = None, source_columns: Optional[List[str]] = None, supports_replace_table_override: Optional[bool] = None, **kwargs: Any) -> None:
433    def replace_query(
434        self,
435        table_name: TableName,
436        query_or_df: QueryOrDF,
437        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
438        table_description: t.Optional[str] = None,
439        column_descriptions: t.Optional[t.Dict[str, str]] = None,
440        source_columns: t.Optional[t.List[str]] = None,
441        supports_replace_table_override: t.Optional[bool] = None,
442        **kwargs: t.Any,
443    ) -> None:
444        table = exp.to_table(table_name)
445
446        if self._query_table_type(table=table) == "hive":
447            self.drop_table(table)
448
449        return super().replace_query(
450            table_name=table,
451            query_or_df=query_or_df,
452            target_columns_to_types=target_columns_to_types,
453            table_description=table_description,
454            column_descriptions=column_descriptions,
455            source_columns=source_columns,
456            **kwargs,
457        )

Replaces an existing table with a query.

For partition-based engines (hive, spark), insert overwrite is used. For other systems, create or replace is used.

Arguments:
  • table_name: The name of the table (eg. prod.table)
  • query_or_df: The SQL query to run or a dataframe.
  • target_columns_to_types: Only used if a dataframe is provided. A mapping between the column name and its data type. Expected to be ordered to match the order of values in the dataframe.
  • kwargs: Optional create table properties.
def delete_from( self, table_name: Union[str, sqlglot.expressions.query.Table], where: Union[str, sqlglot.expressions.core.Expr]) -> None:
561    def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None:
562        table = exp.to_table(table_name)
563
564        table_type = self._query_table_type(table)
565
566        # If Iceberg, DELETE operations work as expected
567        if table_type == "iceberg":
568            return super().delete_from(table, where)
569
570        # If Hive, DELETE is an error
571        if table_type == "hive":
572            # However, if there are no actual records to delete, we can make DELETE a no-op
573            # This simplifies a bunch of calling code that just assumes DELETE works (which to be fair is a reasonable assumption since it does for every other engine)
574            empty_check = (
575                exp.select("*").from_(table).where(where).limit(1)
576            )  # deliberately not count(*) because we want the engine to stop as soon as it finds a record
577            if len(self.fetchall(empty_check)) > 0:
578                raise SQLMeshError("Cannot delete individual records from a Hive table")
579
580        return None
def get_current_catalog(self) -> Optional[str]:
622    def get_current_catalog(self) -> t.Optional[str]:
623        return self.connection.catalog_name

Returns the catalog name of the current connection.

Inherited Members
sqlmesh.core.engine_adapter.mixins.RowDiffMixin
concat_columns
normalize_value
sqlmesh.core.engine_adapter.base.EngineAdapter
DEFAULT_BATCH_SIZE
DATA_OBJECT_FILTER_BATCH_SIZE
SUPPORTS_INDEXES
MAX_TABLE_COMMENT_LENGTH
MAX_COLUMN_COMMENT_LENGTH
INSERT_OVERWRITE_STRATEGY
SUPPORTS_MATERIALIZED_VIEWS
SUPPORTS_MATERIALIZED_VIEW_SCHEMA
SUPPORTS_VIEW_SCHEMA
SUPPORTS_CLONING
SUPPORTS_MANAGED_MODELS
SUPPORTS_CREATE_DROP_CATALOG
SUPPORTS_TUPLE_IN
HAS_VIEW_BINDING
SUPPORTS_GRANTS
DEFAULT_CATALOG_TYPE
QUOTE_IDENTIFIERS_IN_VIEWS
MAX_IDENTIFIER_LENGTH
SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS
dialect
correlation_id
with_settings
cursor
connection
spark
snowpark
bigframe
comments_enabled
schema_differ
default_catalog
engine_run_mode
recycle
close
set_current_catalog
get_catalog_type
get_catalog_type_from_table
current_catalog_type
create_index
create_table
create_managed_table
ctas
create_table_like
clone_table
drop_data_object
drop_managed_table
get_alter_operations
alter_table
create_view
create_schema
drop_schema
drop_view
create_catalog
drop_catalog
table_exists
insert_append
insert_overwrite_by_partition
insert_overwrite_by_time_partition
update_table
scd_type_2_by_time
scd_type_2_by_column
merge
rename_table
get_data_object
get_data_objects
fetchone
fetchall
fetchdf
fetch_pyspark_df
wap_enabled
wap_supported
wap_table_name
wap_prepare
wap_publish
sync_grants_config
transaction
session
execute
temp_table
drop_data_object_on_type_mismatch
ensure_nulls_for_unmatched_after_join
use_server_nulls_for_unmatched_after_join
ping
get_table_last_modified_ts