Edit on GitHub

sqlmesh.core.engine_adapter.spark

  1from __future__ import annotations
  2
  3import logging
  4import typing as t
  5from functools import partial
  6
  7from sqlglot import exp
  8
  9from sqlmesh.core.dialect import to_schema
 10from sqlmesh.core.engine_adapter.mixins import (
 11    GetCurrentCatalogFromFunctionMixin,
 12    HiveMetastoreTablePropertiesMixin,
 13    RowDiffMixin,
 14)
 15from sqlmesh.core.engine_adapter.shared import (
 16    CatalogSupport,
 17    CommentCreationTable,
 18    CommentCreationView,
 19    DataObject,
 20    DataObjectType,
 21    InsertOverwriteStrategy,
 22    SourceQuery,
 23    set_catalog,
 24)
 25from sqlmesh.utils import classproperty, get_source_columns_to_types
 26from sqlmesh.utils.errors import SQLMeshError
 27
 28if t.TYPE_CHECKING:
 29    import pandas as pd
 30    from pyspark.sql import types as spark_types
 31
 32    from sqlmesh.core._typing import SchemaName, TableName
 33    from sqlmesh.core.engine_adapter._typing import (
 34        DF,
 35        PySparkDataFrame,
 36        PySparkSession,
 37        Query,
 38    )
 39    from sqlmesh.core.engine_adapter.base import QueryOrDF
 40    from sqlmesh.engines.spark.db_api.spark_session import SparkSessionConnection
 41
 42
 43logger = logging.getLogger(__name__)
 44
 45
@set_catalog()
class SparkEngineAdapter(
    GetCurrentCatalogFromFunctionMixin, HiveMetastoreTablePropertiesMixin, RowDiffMixin
):
    """Engine adapter that talks to Spark using the ``spark`` SQL dialect.

    The class converts between sqlglot and PySpark data types, turns Pandas/PySpark
    DataFrames into queryable global temp views, lists data objects with
    ``SHOW TABLE EXTENDED`` and provides write-audit-publish (WAP) helpers that
    operate on Iceberg branches.
    """

    DIALECT = "spark"
    SUPPORTS_TRANSACTIONS = False
    INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.INSERT_OVERWRITE
    COMMENT_CREATION_TABLE = CommentCreationTable.IN_SCHEMA_DEF_NO_CTAS
    COMMENT_CREATION_VIEW = CommentCreationView.IN_SCHEMA_DEF_NO_COMMANDS
    # Note: Some formats (like Delta and Iceberg) support REPLACE TABLE but since we don't
    # currently check for storage formats we say we don't support REPLACE TABLE
    SUPPORTS_REPLACE_TABLE = False
    QUOTE_IDENTIFIERS_IN_VIEWS = False
    SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA"]

    # Prefixes combined to build and to recognise WAP branch identifiers
    # (``branch_`` + ``wap_`` + <wap id>).
    WAP_PREFIX = "wap_"
    BRANCH_PREFIX = "branch_"
    SCHEMA_DIFFER_KWARGS = {
        "parameterized_type_defaults": {
            # default decimal precision varies across backends
            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(), (0,)],
        },
    }
 69
 70    @property
 71    def connection(self) -> SparkSessionConnection:
 72        return self._connection_pool.get()
 73
 74    @property
 75    def spark(self) -> PySparkSession:
 76        return self.connection.spark
 77
 78    @property
 79    def _use_spark_session(self) -> bool:
 80        return True
 81
 82    @property
 83    def use_serverless(self) -> bool:
 84        return False
 85
 86    @property
 87    def catalog_support(self) -> CatalogSupport:
 88        return CatalogSupport.FULL_SUPPORT
 89
 90    @classproperty
 91    def _sqlglot_to_spark_primitive_mapping(self) -> t.Dict[t.Any, t.Any]:
 92        from pyspark.sql import types as spark_types
 93
 94        return {
 95            exp.DataType.Type.TINYINT: spark_types.ByteType,
 96            exp.DataType.Type.SMALLINT: spark_types.ShortType,
 97            exp.DataType.Type.INT: spark_types.IntegerType,
 98            exp.DataType.Type.BIGINT: spark_types.LongType,
 99            exp.DataType.Type.FLOAT: spark_types.FloatType,
100            exp.DataType.Type.DOUBLE: spark_types.DoubleType,
101            exp.DataType.Type.DECIMAL: spark_types.DecimalType,
102            exp.DataType.Type.VARCHAR: spark_types.StringType,
103            exp.DataType.Type.CHAR: spark_types.StringType,
104            exp.DataType.Type.TEXT: spark_types.StringType,
105            exp.DataType.Type.BINARY: spark_types.BinaryType,
106            exp.DataType.Type.BOOLEAN: spark_types.BooleanType,
107            exp.DataType.Type.DATE: spark_types.DateType,
108            exp.DataType.Type.TIMESTAMPNTZ: spark_types.TimestampNTZType,
109            exp.DataType.Type.DATETIME: spark_types.TimestampNTZType,
110            exp.DataType.Type.TIMESTAMPLTZ: spark_types.TimestampType,
111            exp.DataType.Type.TIMESTAMP: spark_types.TimestampType,
112            exp.DataType.Type.TIMESTAMPTZ: spark_types.TimestampType,
113        }
114
115    @classproperty
116    def _sqlglot_to_spark_complex_mapping(self) -> t.Dict[t.Any, t.Any]:
117        from pyspark.sql import types as spark_types
118
119        return {
120            exp.DataType.Type.ARRAY: spark_types.ArrayType,
121            exp.DataType.Type.MAP: spark_types.MapType,
122            exp.DataType.Type.STRUCT: spark_types.StructType,
123        }
124
125    @classproperty
126    def _spark_to_sqlglot_primitive_mapping(self) -> t.Dict[t.Any, t.Any]:
127        return {v: k for k, v in self._sqlglot_to_spark_primitive_mapping.items()}
128
129    @classproperty
130    def _spark_to_sqlglot_complex_mapping(self) -> t.Dict[t.Any, t.Any]:
131        return {v: k for k, v in self._sqlglot_to_spark_complex_mapping.items()}
132
133    @classmethod
134    def spark_to_sqlglot_types(cls, input: spark_types.StructType) -> t.Dict[str, exp.DataType]:
135        from pyspark.sql import types as spark_types
136
137        def spark_complex_to_sqlglot_complex(
138            complex_type: t.Union[
139                spark_types.StructType, spark_types.ArrayType, spark_types.MapType
140            ],
141        ) -> exp.DataType:
142            def get_fields(
143                complex_type: t.Union[
144                    spark_types.StructType, spark_types.ArrayType, spark_types.MapType
145                ],
146            ) -> t.Sequence[spark_types.DataType]:
147                if isinstance(complex_type, spark_types.StructType):
148                    return complex_type.fields
149                if isinstance(complex_type, spark_types.ArrayType):
150                    return [complex_type.elementType]
151                if isinstance(complex_type, spark_types.MapType):
152                    return [complex_type.keyType, complex_type.valueType]
153                raise SQLMeshError(f"Unsupported complex type: {complex_type}")
154
155            expressions: t.List[t.Union[exp.ColumnDef, exp.DataType]] = []
156            fields = get_fields(complex_type)
157            for field in fields:
158                if isinstance(field, (spark_types.StructType, spark_types.MapType)):
159                    expressions.append(spark_complex_to_sqlglot_complex(field))
160                elif isinstance(field, spark_types.StructField):
161                    sqlglot_data_type = cls._spark_to_sqlglot_primitive_mapping.get(
162                        type(field.dataType)
163                    ) or spark_complex_to_sqlglot_complex(
164                        field.dataType  # type: ignore
165                    )
166                    kind = (
167                        sqlglot_data_type
168                        if isinstance(sqlglot_data_type, exp.DataType)
169                        else exp.DataType(this=sqlglot_data_type)
170                    )
171                    expressions.append(exp.ColumnDef(this=exp.to_identifier(field.name), kind=kind))
172                else:
173                    kind = exp.DataType(this=cls._spark_to_sqlglot_primitive_mapping[type(field)])
174                    expressions.append(kind)
175            dtype = cls._spark_to_sqlglot_complex_mapping[type(complex_type)]
176            return exp.DataType(
177                this=dtype,
178                expressions=expressions,
179                nested=True,
180            )
181
182        resp = spark_complex_to_sqlglot_complex(input)
183        return {column_def.this.name: column_def.args["kind"] for column_def in resp.expressions}
184
185    @classmethod
186    def sqlglot_to_spark_types(cls, input: t.Dict[str, exp.DataType]) -> spark_types.StructType:
187        from pyspark.sql import types as spark_types
188
189        def sqlglot_complex_to_spark_complex(complex_type: exp.DataType) -> spark_types.DataType:
190            is_struct = complex_type.is_type(exp.DataType.Type.STRUCT)
191            expressions = []
192            for column_def in complex_type.expressions:
193                col_name = column_def.this.name if is_struct else None
194                data_type = column_def.args["kind"] if is_struct else column_def
195                primitive_func = cls._sqlglot_to_spark_primitive_mapping.get(data_type.this)
196                type_func = (
197                    primitive_func
198                    if primitive_func
199                    else partial(sqlglot_complex_to_spark_complex, data_type)
200                )
201                if is_struct:
202                    expressions.append(spark_types.StructField(col_name, type_func()))  # type: ignore
203                else:
204                    expressions.append(type_func())  # type: ignore
205            klass = cls._sqlglot_to_spark_complex_mapping[complex_type.this]
206            if is_struct:
207                return klass(expressions)
208            return klass(*expressions)
209
210        return t.cast(
211            spark_types.StructType,
212            sqlglot_complex_to_spark_complex(
213                exp.DataType(
214                    this=exp.DataType.Type.STRUCT,
215                    expressions=[
216                        exp.ColumnDef(this=exp.to_identifier(column), kind=data_type)
217                        for column, data_type in input.items()
218                    ],
219                )
220            ),
221        )
222
223    @classmethod
224    def is_pyspark_df(cls, value: t.Any) -> bool:
225        return hasattr(value, "sparkSession")
226
227    @classmethod
228    def try_get_pyspark_df(cls, value: t.Any) -> t.Optional[PySparkDataFrame]:
229        if cls.is_pyspark_df(value):
230            return value
231        return None
232
233    @classmethod
234    def try_get_pandas_df(cls, value: t.Any) -> t.Optional[pd.DataFrame]:
235        import pandas as pd
236
237        if isinstance(value, pd.DataFrame):
238            return value
239        return None
240
241    @t.overload
242    def _columns_to_types(
243        self,
244        query_or_df: DF,
245        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
246        source_columns: t.Optional[t.List[str]] = None,
247    ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ...
248
249    @t.overload
250    def _columns_to_types(
251        self,
252        query_or_df: Query,
253        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
254        source_columns: t.Optional[t.List[str]] = None,
255    ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ...
256
257    def _columns_to_types(
258        self,
259        query_or_df: QueryOrDF,
260        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
261        source_columns: t.Optional[t.List[str]] = None,
262    ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]:
263        if target_columns_to_types:
264            return target_columns_to_types, list(source_columns or target_columns_to_types)
265        if self.is_pyspark_df(query_or_df):
266            from pyspark.sql import DataFrame
267
268            target_columns_to_types = self.spark_to_sqlglot_types(
269                t.cast(DataFrame, query_or_df).schema
270            )
271            return target_columns_to_types, list(source_columns or target_columns_to_types)
272        return super()._columns_to_types(
273            query_or_df, target_columns_to_types, source_columns=source_columns
274        )
275
276    def _df_to_source_queries(
277        self,
278        df: DF,
279        target_columns_to_types: t.Dict[str, exp.DataType],
280        batch_size: int,
281        target_table: TableName,
282        source_columns: t.Optional[t.List[str]] = None,
283    ) -> t.List[SourceQuery]:
284        df = self._ensure_pyspark_df(df, target_columns_to_types, source_columns=source_columns)
285
286        def query_factory() -> Query:
287            temp_table = self._get_temp_table(target_table or "spark", table_only=True)
288            df.createOrReplaceGlobalTempView(temp_table.sql(dialect=self.dialect))  # type: ignore
289            temp_table.set("db", "global_temp")
290            return exp.select(*self._select_columns(target_columns_to_types)).from_(temp_table)
291
292        return [SourceQuery(query_factory=query_factory)]
293
294    def _ensure_pyspark_df(
295        self,
296        generic_df: DF,
297        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
298        source_columns: t.Optional[t.List[str]] = None,
299    ) -> PySparkDataFrame:
300        pyspark_df = self.try_get_pyspark_df(generic_df)
301        if not pyspark_df:
302            df = self.try_get_pandas_df(generic_df)
303            if df is None:
304                raise SQLMeshError(
305                    "Ensure PySpark DF can only be run on a PySpark or Pandas DataFrame"
306                )
307
308            if target_columns_to_types:
309                source_columns_to_types = get_source_columns_to_types(
310                    target_columns_to_types, source_columns
311                )
312                # ensure Pandas dataframe column order matches columns_to_types
313                df = df[list(source_columns_to_types)]
314            else:
315                source_columns_to_types = None
316            kwargs = (
317                dict(schema=self.sqlglot_to_spark_types(source_columns_to_types))
318                if source_columns_to_types
319                else {}
320            )
321            pyspark_df = self.spark.createDataFrame(df, **kwargs)  # type: ignore
322        if target_columns_to_types:
323            select_columns = self._casted_columns(
324                target_columns_to_types, source_columns=source_columns
325            )
326            pyspark_df = pyspark_df.selectExpr(*[x.sql(self.dialect) for x in select_columns])  # type: ignore
327        return pyspark_df
328
329    def _get_temp_table(
330        self, table: TableName, table_only: bool = False, quoted: bool = True
331    ) -> exp.Table:
332        """
333        Returns the name of the temp table that should be used for the given table name.
334        """
335        table = super()._get_temp_table(table, table_only=table_only)
336        table_name_id = table.args["this"]
337        # Spark with local filesystem has an issue with temp tables that start with __temp so
338        # we update here to remove the leading double underscore
339        table_name_id.set("this", table_name_id.this.replace("__temp_", "temp_"))
340        return table
341
342    def fetchdf(
343        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
344    ) -> pd.DataFrame:
345        return self.fetch_pyspark_df(query, quote_identifiers=quote_identifiers).toPandas()
346
347    def fetch_pyspark_df(
348        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
349    ) -> PySparkDataFrame:
350        return self._ensure_pyspark_df(
351            self._fetch_native_df(query, quote_identifiers=quote_identifiers)
352        )
353
354    def _get_data_objects(
355        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
356    ) -> t.List[DataObject]:
357        schema_name = to_schema(schema_name).sql(dialect=self.dialect)
358        pattern = "*" if object_names is None else "|".join(object_names)
359        sql = f"SHOW TABLE EXTENDED IN {schema_name} LIKE '{pattern}'"
360        try:
361            results = (
362                self.fetch_pyspark_df(sql).collect()
363                if self._use_spark_session
364                else self.fetchdf(sql).to_dict("records")
365            )
366        # Improvement: Figure out all the different exceptions we could get from executing a query either with or
367        # without a Spark Session. In addition Databricks would need to be updated to handle it's own exceptions.
368        # Therefore just doing except Exception for now.
369        except Exception:
370            return []
371        data_objects = []
372        catalog = self.get_current_catalog()
373        for row in results:  # type: ignore
374            row_dict = row.asDict() if not isinstance(row, dict) else row
375            if row_dict.get("isTemporary"):
376                continue
377            schema = row_dict.get("namespace") or row_dict.get("database")
378            data_objects.append(
379                DataObject(
380                    catalog=catalog,
381                    schema=schema,
382                    name=row_dict["tableName"],
383                    type=(
384                        DataObjectType.VIEW
385                        if "Type: VIEW" in row_dict["information"]
386                        else DataObjectType.TABLE
387                    ),
388                )
389            )
390        return data_objects
391
392    def get_current_catalog(self) -> t.Optional[str]:
393        if self._use_spark_session:
394            return self.connection.get_current_catalog()
395        return super().get_current_catalog()
396
397    def set_current_catalog(self, catalog_name: str) -> None:
398        self.connection.set_current_catalog(catalog_name)
399
400    def _get_current_schema(self) -> str:
401        if self._use_spark_session:
402            return self.spark.catalog.currentDatabase()
403        return self.fetchone(exp.select(exp.func("current_database")))[0]  # type: ignore
404
405    def get_data_object(
406        self, target_name: TableName, safe_to_cache: bool = False
407    ) -> t.Optional[DataObject]:
408        target_table = exp.to_table(target_name)
409        if isinstance(target_table.this, exp.Dot) and target_table.this.expression.name.startswith(
410            f"{self.BRANCH_PREFIX}{self.WAP_PREFIX}"
411        ):
412            # Exclude the branch name
413            target_table.set("this", target_table.this.this)
414        return super().get_data_object(target_table, safe_to_cache=safe_to_cache)
415
416    def create_state_table(
417        self,
418        table_name: str,
419        target_columns_to_types: t.Dict[str, exp.DataType],
420        primary_key: t.Optional[t.Tuple[str, ...]] = None,
421    ) -> None:
422        self.create_table(
423            table_name,
424            target_columns_to_types,
425            partitioned_by=[exp.column(x) for x in primary_key] if primary_key else None,
426        )
427
428    def _native_df_to_pandas_df(
429        self,
430        query_or_df: QueryOrDF,
431    ) -> t.Union[Query, pd.DataFrame]:
432        if pyspark_df := self.try_get_pyspark_df(query_or_df):
433            return pyspark_df.toPandas()
434
435        return super()._native_df_to_pandas_df(query_or_df)
436
    def _create_table(
        self,
        table_name_or_schema: t.Union[exp.Schema, TableName],
        expression: t.Optional[exp.Expr],
        exists: bool = True,
        replace: bool = False,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        table_description: t.Optional[str] = None,
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        table_kind: t.Optional[str] = None,
        track_rows_processed: bool = True,
        **kwargs: t.Any,
    ) -> None:
        """Create a table, with extra handling for WAP branch names and Iceberg tables.

        On top of the parent implementation this method:
          * strips a trailing WAP branch part from the table identifier when no
            query expression is given, and
          * when WAP is enabled, the table is treated as WAP-capable (the
            ``storage_format`` keyword is ``iceberg`` or ``wap_supported`` says so),
            ``exists`` is true and the table did not exist before the call, runs
            an ``INSERT ... SELECT *`` from the table into itself afterwards.
        """
        # NOTE(review): ``table_kind`` is accepted but not forwarded to the parent
        # call below; confirm whether that is intentional.
        table_name = (
            table_name_or_schema.this
            if isinstance(table_name_or_schema, exp.Schema)
            else exp.to_table(table_name_or_schema)
        )
        # Spark doesn't support creating a wap table DDL. When there is no query expression and the
        # identifier ends with a WAP branch part, that part is dropped in place so the name refers to
        # the base table. (An earlier version of this comment described an early return; the code
        # below does not return here.)
        if not expression and isinstance(table_name.this, exp.Dot):
            wap_id = table_name.this.parts[-1].name
            if wap_id.startswith(f"{self.BRANCH_PREFIX}{self.WAP_PREFIX}"):
                table_name.set("this", table_name.this.this)

        do_dummy_insert = False
        if self.wap_enabled:
            wap_supported = (
                kwargs.get("storage_format") or ""
            ).lower() == "iceberg" or self.wap_supported(table_name)
            # The existence check has to happen before the table is created below.
            do_dummy_insert = (
                False if not wap_supported or not exists else not self.table_exists(table_name)
            )
        super()._create_table(
            table_name_or_schema,
            expression,
            exists=exists,
            replace=replace,
            target_columns_to_types=target_columns_to_types,
            table_description=table_description,
            column_descriptions=column_descriptions,
            track_rows_processed=track_rows_processed,
            **kwargs,
        )
        # Re-derive the table name from the original argument for the insert below.
        table_name = (
            table_name_or_schema.this
            if isinstance(table_name_or_schema, exp.Schema)
            else exp.to_table(table_name_or_schema)
        )
        if do_dummy_insert:
            # Performing a dummy insert to create a dummy snapshot for Iceberg tables
            # to workaround https://github.com/apache/iceberg/issues/8849.
            dummy_insert = exp.insert(exp.select("*").from_(table_name), table_name)
            self.execute(dummy_insert)
491
492    def wap_supported(self, table_name: TableName) -> bool:
493        fqn = self._ensure_fqn(table_name)
494        return (
495            self.spark.conf.get(f"spark.sql.catalog.{fqn.catalog}")
496            == "org.apache.iceberg.spark.SparkCatalog"
497        )
498
499    def wap_table_name(self, table_name: TableName, wap_id: str) -> str:
500        branch_name = self._wap_branch_name(wap_id)
501        fqn = self._ensure_fqn(table_name)
502        return exp.Dot.build([fqn, exp.to_identifier(f"{self.BRANCH_PREFIX}{branch_name}")]).sql(
503            dialect=self.dialect
504        )
505
506    def wap_prepare(self, table_name: TableName, wap_id: str) -> str:
507        branch_name = self._wap_branch_name(wap_id)
508        fqn = self._ensure_fqn(table_name)
509        self.execute(f"ALTER TABLE {fqn.sql(dialect=self.dialect)} CREATE BRANCH {branch_name}")
510        return self.wap_table_name(table_name, wap_id)
511
512    def wap_publish(self, table_name: TableName, wap_id: str) -> None:
513        branch_name = self._wap_branch_name(wap_id)
514        fqn = self._ensure_fqn(table_name)
515
516        get_snapshot_id_query = (
517            exp.select("snapshot_id")
518            .from_(exp.Dot.build([fqn, exp.to_identifier("refs")]))
519            .where(exp.column("name").eq(branch_name))
520        )
521        iceberg_snapshot_ids = self.fetchall(get_snapshot_id_query)
522        if not iceberg_snapshot_ids:
523            raise SQLMeshError(f"Could not find Iceberg branch '{branch_name}'.")
524        iceberg_snapshot_id = iceberg_snapshot_ids[0][0]
525
526        logger.info(
527            "Cherry-picking Iceberg snapshot %s into table '%s'...", iceberg_snapshot_id, fqn
528        )
529
530        self.execute(
531            f"CALL {fqn.catalog}.system.cherrypick_snapshot('{fqn.db}.{fqn.name}', {iceberg_snapshot_id})"
532        )
533        self.execute(f"ALTER TABLE {fqn.sql(dialect=self.dialect)} DROP BRANCH {branch_name}")
534
535    def _ensure_fqn(self, table_name: TableName) -> exp.Table:
536        if isinstance(table_name, exp.Table):
537            table_name = table_name.copy()
538        table = exp.to_table(table_name, dialect=self.dialect)
539        if not table.catalog:
540            table.set("catalog", self.get_current_catalog())
541        if not table.db:
542            table.set("db", self._get_current_schema())
543        return table
544
545    def _build_create_comment_column_exp(
546        self, table: exp.Table, column_name: str, column_comment: str, table_kind: str = "TABLE"
547    ) -> exp.Comment | str:
548        table_sql = table.sql(dialect=self.dialect, identify=True)
549        column_sql = exp.column(column_name).sql(dialect=self.dialect, identify=True)
550
551        truncated_comment = self._truncate_column_comment(column_comment)
552        comment_sql = exp.Literal.string(truncated_comment).sql(dialect=self.dialect)
553
554        return f"ALTER TABLE {table_sql} ALTER COLUMN {column_sql} COMMENT {comment_sql}"
555
556    @classmethod
557    def _wap_branch_name(cls, wap_id: str) -> str:
558        return f"{cls.WAP_PREFIX}{wap_id}"
logger = <Logger sqlmesh.core.engine_adapter.spark (WARNING)>
 47@set_catalog()
 48class SparkEngineAdapter(
 49    GetCurrentCatalogFromFunctionMixin, HiveMetastoreTablePropertiesMixin, RowDiffMixin
 50):
 51    DIALECT = "spark"
 52    SUPPORTS_TRANSACTIONS = False
 53    INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.INSERT_OVERWRITE
 54    COMMENT_CREATION_TABLE = CommentCreationTable.IN_SCHEMA_DEF_NO_CTAS
 55    COMMENT_CREATION_VIEW = CommentCreationView.IN_SCHEMA_DEF_NO_COMMANDS
 56    # Note: Some formats (like Delta and Iceberg) support REPLACE TABLE but since we don't
 57    # currently check for storage formats we say we don't support REPLACE TABLE
 58    SUPPORTS_REPLACE_TABLE = False
 59    QUOTE_IDENTIFIERS_IN_VIEWS = False
 60    SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA"]
 61
 62    WAP_PREFIX = "wap_"
 63    BRANCH_PREFIX = "branch_"
 64    SCHEMA_DIFFER_KWARGS = {
 65        "parameterized_type_defaults": {
 66            # default decimal precision varies across backends
 67            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(), (0,)],
 68        },
 69    }
 70
 71    @property
 72    def connection(self) -> SparkSessionConnection:
 73        return self._connection_pool.get()
 74
 75    @property
 76    def spark(self) -> PySparkSession:
 77        return self.connection.spark
 78
 79    @property
 80    def _use_spark_session(self) -> bool:
 81        return True
 82
 83    @property
 84    def use_serverless(self) -> bool:
 85        return False
 86
 87    @property
 88    def catalog_support(self) -> CatalogSupport:
 89        return CatalogSupport.FULL_SUPPORT
 90
 91    @classproperty
 92    def _sqlglot_to_spark_primitive_mapping(self) -> t.Dict[t.Any, t.Any]:
 93        from pyspark.sql import types as spark_types
 94
 95        return {
 96            exp.DataType.Type.TINYINT: spark_types.ByteType,
 97            exp.DataType.Type.SMALLINT: spark_types.ShortType,
 98            exp.DataType.Type.INT: spark_types.IntegerType,
 99            exp.DataType.Type.BIGINT: spark_types.LongType,
100            exp.DataType.Type.FLOAT: spark_types.FloatType,
101            exp.DataType.Type.DOUBLE: spark_types.DoubleType,
102            exp.DataType.Type.DECIMAL: spark_types.DecimalType,
103            exp.DataType.Type.VARCHAR: spark_types.StringType,
104            exp.DataType.Type.CHAR: spark_types.StringType,
105            exp.DataType.Type.TEXT: spark_types.StringType,
106            exp.DataType.Type.BINARY: spark_types.BinaryType,
107            exp.DataType.Type.BOOLEAN: spark_types.BooleanType,
108            exp.DataType.Type.DATE: spark_types.DateType,
109            exp.DataType.Type.TIMESTAMPNTZ: spark_types.TimestampNTZType,
110            exp.DataType.Type.DATETIME: spark_types.TimestampNTZType,
111            exp.DataType.Type.TIMESTAMPLTZ: spark_types.TimestampType,
112            exp.DataType.Type.TIMESTAMP: spark_types.TimestampType,
113            exp.DataType.Type.TIMESTAMPTZ: spark_types.TimestampType,
114        }
115
116    @classproperty
117    def _sqlglot_to_spark_complex_mapping(self) -> t.Dict[t.Any, t.Any]:
118        from pyspark.sql import types as spark_types
119
120        return {
121            exp.DataType.Type.ARRAY: spark_types.ArrayType,
122            exp.DataType.Type.MAP: spark_types.MapType,
123            exp.DataType.Type.STRUCT: spark_types.StructType,
124        }
125
126    @classproperty
127    def _spark_to_sqlglot_primitive_mapping(self) -> t.Dict[t.Any, t.Any]:
128        return {v: k for k, v in self._sqlglot_to_spark_primitive_mapping.items()}
129
130    @classproperty
131    def _spark_to_sqlglot_complex_mapping(self) -> t.Dict[t.Any, t.Any]:
132        return {v: k for k, v in self._sqlglot_to_spark_complex_mapping.items()}
133
134    @classmethod
135    def spark_to_sqlglot_types(cls, input: spark_types.StructType) -> t.Dict[str, exp.DataType]:
136        from pyspark.sql import types as spark_types
137
138        def spark_complex_to_sqlglot_complex(
139            complex_type: t.Union[
140                spark_types.StructType, spark_types.ArrayType, spark_types.MapType
141            ],
142        ) -> exp.DataType:
143            def get_fields(
144                complex_type: t.Union[
145                    spark_types.StructType, spark_types.ArrayType, spark_types.MapType
146                ],
147            ) -> t.Sequence[spark_types.DataType]:
148                if isinstance(complex_type, spark_types.StructType):
149                    return complex_type.fields
150                if isinstance(complex_type, spark_types.ArrayType):
151                    return [complex_type.elementType]
152                if isinstance(complex_type, spark_types.MapType):
153                    return [complex_type.keyType, complex_type.valueType]
154                raise SQLMeshError(f"Unsupported complex type: {complex_type}")
155
156            expressions: t.List[t.Union[exp.ColumnDef, exp.DataType]] = []
157            fields = get_fields(complex_type)
158            for field in fields:
159                if isinstance(field, (spark_types.StructType, spark_types.MapType)):
160                    expressions.append(spark_complex_to_sqlglot_complex(field))
161                elif isinstance(field, spark_types.StructField):
162                    sqlglot_data_type = cls._spark_to_sqlglot_primitive_mapping.get(
163                        type(field.dataType)
164                    ) or spark_complex_to_sqlglot_complex(
165                        field.dataType  # type: ignore
166                    )
167                    kind = (
168                        sqlglot_data_type
169                        if isinstance(sqlglot_data_type, exp.DataType)
170                        else exp.DataType(this=sqlglot_data_type)
171                    )
172                    expressions.append(exp.ColumnDef(this=exp.to_identifier(field.name), kind=kind))
173                else:
174                    kind = exp.DataType(this=cls._spark_to_sqlglot_primitive_mapping[type(field)])
175                    expressions.append(kind)
176            dtype = cls._spark_to_sqlglot_complex_mapping[type(complex_type)]
177            return exp.DataType(
178                this=dtype,
179                expressions=expressions,
180                nested=True,
181            )
182
183        resp = spark_complex_to_sqlglot_complex(input)
184        return {column_def.this.name: column_def.args["kind"] for column_def in resp.expressions}
185
186    @classmethod
187    def sqlglot_to_spark_types(cls, input: t.Dict[str, exp.DataType]) -> spark_types.StructType:
188        from pyspark.sql import types as spark_types
189
190        def sqlglot_complex_to_spark_complex(complex_type: exp.DataType) -> spark_types.DataType:
191            is_struct = complex_type.is_type(exp.DataType.Type.STRUCT)
192            expressions = []
193            for column_def in complex_type.expressions:
194                col_name = column_def.this.name if is_struct else None
195                data_type = column_def.args["kind"] if is_struct else column_def
196                primitive_func = cls._sqlglot_to_spark_primitive_mapping.get(data_type.this)
197                type_func = (
198                    primitive_func
199                    if primitive_func
200                    else partial(sqlglot_complex_to_spark_complex, data_type)
201                )
202                if is_struct:
203                    expressions.append(spark_types.StructField(col_name, type_func()))  # type: ignore
204                else:
205                    expressions.append(type_func())  # type: ignore
206            klass = cls._sqlglot_to_spark_complex_mapping[complex_type.this]
207            if is_struct:
208                return klass(expressions)
209            return klass(*expressions)
210
211        return t.cast(
212            spark_types.StructType,
213            sqlglot_complex_to_spark_complex(
214                exp.DataType(
215                    this=exp.DataType.Type.STRUCT,
216                    expressions=[
217                        exp.ColumnDef(this=exp.to_identifier(column), kind=data_type)
218                        for column, data_type in input.items()
219                    ],
220                )
221            ),
222        )
223
224    @classmethod
225    def is_pyspark_df(cls, value: t.Any) -> bool:
226        return hasattr(value, "sparkSession")
227
228    @classmethod
229    def try_get_pyspark_df(cls, value: t.Any) -> t.Optional[PySparkDataFrame]:
230        if cls.is_pyspark_df(value):
231            return value
232        return None
233
234    @classmethod
235    def try_get_pandas_df(cls, value: t.Any) -> t.Optional[pd.DataFrame]:
236        import pandas as pd
237
238        if isinstance(value, pd.DataFrame):
239            return value
240        return None
241
242    @t.overload
243    def _columns_to_types(
244        self,
245        query_or_df: DF,
246        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
247        source_columns: t.Optional[t.List[str]] = None,
248    ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ...
249
250    @t.overload
251    def _columns_to_types(
252        self,
253        query_or_df: Query,
254        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
255        source_columns: t.Optional[t.List[str]] = None,
256    ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ...
257
258    def _columns_to_types(
259        self,
260        query_or_df: QueryOrDF,
261        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
262        source_columns: t.Optional[t.List[str]] = None,
263    ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]:
264        if target_columns_to_types:
265            return target_columns_to_types, list(source_columns or target_columns_to_types)
266        if self.is_pyspark_df(query_or_df):
267            from pyspark.sql import DataFrame
268
269            target_columns_to_types = self.spark_to_sqlglot_types(
270                t.cast(DataFrame, query_or_df).schema
271            )
272            return target_columns_to_types, list(source_columns or target_columns_to_types)
273        return super()._columns_to_types(
274            query_or_df, target_columns_to_types, source_columns=source_columns
275        )
276
277    def _df_to_source_queries(
278        self,
279        df: DF,
280        target_columns_to_types: t.Dict[str, exp.DataType],
281        batch_size: int,
282        target_table: TableName,
283        source_columns: t.Optional[t.List[str]] = None,
284    ) -> t.List[SourceQuery]:
285        df = self._ensure_pyspark_df(df, target_columns_to_types, source_columns=source_columns)
286
287        def query_factory() -> Query:
288            temp_table = self._get_temp_table(target_table or "spark", table_only=True)
289            df.createOrReplaceGlobalTempView(temp_table.sql(dialect=self.dialect))  # type: ignore
290            temp_table.set("db", "global_temp")
291            return exp.select(*self._select_columns(target_columns_to_types)).from_(temp_table)
292
293        return [SourceQuery(query_factory=query_factory)]
294
295    def _ensure_pyspark_df(
296        self,
297        generic_df: DF,
298        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
299        source_columns: t.Optional[t.List[str]] = None,
300    ) -> PySparkDataFrame:
301        pyspark_df = self.try_get_pyspark_df(generic_df)
302        if not pyspark_df:
303            df = self.try_get_pandas_df(generic_df)
304            if df is None:
305                raise SQLMeshError(
306                    "Ensure PySpark DF can only be run on a PySpark or Pandas DataFrame"
307                )
308
309            if target_columns_to_types:
310                source_columns_to_types = get_source_columns_to_types(
311                    target_columns_to_types, source_columns
312                )
313                # ensure Pandas dataframe column order matches columns_to_types
314                df = df[list(source_columns_to_types)]
315            else:
316                source_columns_to_types = None
317            kwargs = (
318                dict(schema=self.sqlglot_to_spark_types(source_columns_to_types))
319                if source_columns_to_types
320                else {}
321            )
322            pyspark_df = self.spark.createDataFrame(df, **kwargs)  # type: ignore
323        if target_columns_to_types:
324            select_columns = self._casted_columns(
325                target_columns_to_types, source_columns=source_columns
326            )
327            pyspark_df = pyspark_df.selectExpr(*[x.sql(self.dialect) for x in select_columns])  # type: ignore
328        return pyspark_df
329
330    def _get_temp_table(
331        self, table: TableName, table_only: bool = False, quoted: bool = True
332    ) -> exp.Table:
333        """
334        Returns the name of the temp table that should be used for the given table name.
335        """
336        table = super()._get_temp_table(table, table_only=table_only)
337        table_name_id = table.args["this"]
338        # Spark with local filesystem has an issue with temp tables that start with __temp so
339        # we update here to remove the leading double underscore
340        table_name_id.set("this", table_name_id.this.replace("__temp_", "temp_"))
341        return table
342
343    def fetchdf(
344        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
345    ) -> pd.DataFrame:
346        return self.fetch_pyspark_df(query, quote_identifiers=quote_identifiers).toPandas()
347
348    def fetch_pyspark_df(
349        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
350    ) -> PySparkDataFrame:
351        return self._ensure_pyspark_df(
352            self._fetch_native_df(query, quote_identifiers=quote_identifiers)
353        )
354
355    def _get_data_objects(
356        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
357    ) -> t.List[DataObject]:
358        schema_name = to_schema(schema_name).sql(dialect=self.dialect)
359        pattern = "*" if object_names is None else "|".join(object_names)
360        sql = f"SHOW TABLE EXTENDED IN {schema_name} LIKE '{pattern}'"
361        try:
362            results = (
363                self.fetch_pyspark_df(sql).collect()
364                if self._use_spark_session
365                else self.fetchdf(sql).to_dict("records")
366            )
367        # Improvement: Figure out all the different exceptions we could get from executing a query either with or
368        # without a Spark Session. In addition Databricks would need to be updated to handle it's own exceptions.
369        # Therefore just doing except Exception for now.
370        except Exception:
371            return []
372        data_objects = []
373        catalog = self.get_current_catalog()
374        for row in results:  # type: ignore
375            row_dict = row.asDict() if not isinstance(row, dict) else row
376            if row_dict.get("isTemporary"):
377                continue
378            schema = row_dict.get("namespace") or row_dict.get("database")
379            data_objects.append(
380                DataObject(
381                    catalog=catalog,
382                    schema=schema,
383                    name=row_dict["tableName"],
384                    type=(
385                        DataObjectType.VIEW
386                        if "Type: VIEW" in row_dict["information"]
387                        else DataObjectType.TABLE
388                    ),
389                )
390            )
391        return data_objects
392
393    def get_current_catalog(self) -> t.Optional[str]:
394        if self._use_spark_session:
395            return self.connection.get_current_catalog()
396        return super().get_current_catalog()
397
398    def set_current_catalog(self, catalog_name: str) -> None:
399        self.connection.set_current_catalog(catalog_name)
400
401    def _get_current_schema(self) -> str:
402        if self._use_spark_session:
403            return self.spark.catalog.currentDatabase()
404        return self.fetchone(exp.select(exp.func("current_database")))[0]  # type: ignore
405
406    def get_data_object(
407        self, target_name: TableName, safe_to_cache: bool = False
408    ) -> t.Optional[DataObject]:
409        target_table = exp.to_table(target_name)
410        if isinstance(target_table.this, exp.Dot) and target_table.this.expression.name.startswith(
411            f"{self.BRANCH_PREFIX}{self.WAP_PREFIX}"
412        ):
413            # Exclude the branch name
414            target_table.set("this", target_table.this.this)
415        return super().get_data_object(target_table, safe_to_cache=safe_to_cache)
416
417    def create_state_table(
418        self,
419        table_name: str,
420        target_columns_to_types: t.Dict[str, exp.DataType],
421        primary_key: t.Optional[t.Tuple[str, ...]] = None,
422    ) -> None:
423        self.create_table(
424            table_name,
425            target_columns_to_types,
426            partitioned_by=[exp.column(x) for x in primary_key] if primary_key else None,
427        )
428
429    def _native_df_to_pandas_df(
430        self,
431        query_or_df: QueryOrDF,
432    ) -> t.Union[Query, pd.DataFrame]:
433        if pyspark_df := self.try_get_pyspark_df(query_or_df):
434            return pyspark_df.toPandas()
435
436        return super()._native_df_to_pandas_df(query_or_df)
437
438    def _create_table(
439        self,
440        table_name_or_schema: t.Union[exp.Schema, TableName],
441        expression: t.Optional[exp.Expr],
442        exists: bool = True,
443        replace: bool = False,
444        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
445        table_description: t.Optional[str] = None,
446        column_descriptions: t.Optional[t.Dict[str, str]] = None,
447        table_kind: t.Optional[str] = None,
448        track_rows_processed: bool = True,
449        **kwargs: t.Any,
450    ) -> None:
451        table_name = (
452            table_name_or_schema.this
453            if isinstance(table_name_or_schema, exp.Schema)
454            else exp.to_table(table_name_or_schema)
455        )
456        # Spark doesn't support creating a wap table DDL. Therefore we check if this is a wap table and if it is,
457        # this is not a replace, and the table already exists then we can safely just return. Otherwise we let it error.
458        if not expression and isinstance(table_name.this, exp.Dot):
459            wap_id = table_name.this.parts[-1].name
460            if wap_id.startswith(f"{self.BRANCH_PREFIX}{self.WAP_PREFIX}"):
461                table_name.set("this", table_name.this.this)
462
463        do_dummy_insert = False
464        if self.wap_enabled:
465            wap_supported = (
466                kwargs.get("storage_format") or ""
467            ).lower() == "iceberg" or self.wap_supported(table_name)
468            do_dummy_insert = (
469                False if not wap_supported or not exists else not self.table_exists(table_name)
470            )
471        super()._create_table(
472            table_name_or_schema,
473            expression,
474            exists=exists,
475            replace=replace,
476            target_columns_to_types=target_columns_to_types,
477            table_description=table_description,
478            column_descriptions=column_descriptions,
479            track_rows_processed=track_rows_processed,
480            **kwargs,
481        )
482        table_name = (
483            table_name_or_schema.this
484            if isinstance(table_name_or_schema, exp.Schema)
485            else exp.to_table(table_name_or_schema)
486        )
487        if do_dummy_insert:
488            # Performing a dummy insert to create a dummy snapshot for Iceberg tables
489            # to workaround https://github.com/apache/iceberg/issues/8849.
490            dummy_insert = exp.insert(exp.select("*").from_(table_name), table_name)
491            self.execute(dummy_insert)
492
493    def wap_supported(self, table_name: TableName) -> bool:
494        fqn = self._ensure_fqn(table_name)
495        return (
496            self.spark.conf.get(f"spark.sql.catalog.{fqn.catalog}")
497            == "org.apache.iceberg.spark.SparkCatalog"
498        )
499
500    def wap_table_name(self, table_name: TableName, wap_id: str) -> str:
501        branch_name = self._wap_branch_name(wap_id)
502        fqn = self._ensure_fqn(table_name)
503        return exp.Dot.build([fqn, exp.to_identifier(f"{self.BRANCH_PREFIX}{branch_name}")]).sql(
504            dialect=self.dialect
505        )
506
507    def wap_prepare(self, table_name: TableName, wap_id: str) -> str:
508        branch_name = self._wap_branch_name(wap_id)
509        fqn = self._ensure_fqn(table_name)
510        self.execute(f"ALTER TABLE {fqn.sql(dialect=self.dialect)} CREATE BRANCH {branch_name}")
511        return self.wap_table_name(table_name, wap_id)
512
513    def wap_publish(self, table_name: TableName, wap_id: str) -> None:
514        branch_name = self._wap_branch_name(wap_id)
515        fqn = self._ensure_fqn(table_name)
516
517        get_snapshot_id_query = (
518            exp.select("snapshot_id")
519            .from_(exp.Dot.build([fqn, exp.to_identifier("refs")]))
520            .where(exp.column("name").eq(branch_name))
521        )
522        iceberg_snapshot_ids = self.fetchall(get_snapshot_id_query)
523        if not iceberg_snapshot_ids:
524            raise SQLMeshError(f"Could not find Iceberg branch '{branch_name}'.")
525        iceberg_snapshot_id = iceberg_snapshot_ids[0][0]
526
527        logger.info(
528            "Cherry-picking Iceberg snapshot %s into table '%s'...", iceberg_snapshot_id, fqn
529        )
530
531        self.execute(
532            f"CALL {fqn.catalog}.system.cherrypick_snapshot('{fqn.db}.{fqn.name}', {iceberg_snapshot_id})"
533        )
534        self.execute(f"ALTER TABLE {fqn.sql(dialect=self.dialect)} DROP BRANCH {branch_name}")
535
536    def _ensure_fqn(self, table_name: TableName) -> exp.Table:
537        if isinstance(table_name, exp.Table):
538            table_name = table_name.copy()
539        table = exp.to_table(table_name, dialect=self.dialect)
540        if not table.catalog:
541            table.set("catalog", self.get_current_catalog())
542        if not table.db:
543            table.set("db", self._get_current_schema())
544        return table
545
546    def _build_create_comment_column_exp(
547        self, table: exp.Table, column_name: str, column_comment: str, table_kind: str = "TABLE"
548    ) -> exp.Comment | str:
549        table_sql = table.sql(dialect=self.dialect, identify=True)
550        column_sql = exp.column(column_name).sql(dialect=self.dialect, identify=True)
551
552        truncated_comment = self._truncate_column_comment(column_comment)
553        comment_sql = exp.Literal.string(truncated_comment).sql(dialect=self.dialect)
554
555        return f"ALTER TABLE {table_sql} ALTER COLUMN {column_sql} COMMENT {comment_sql}"
556
557    @classmethod
558    def _wap_branch_name(cls, wap_id: str) -> str:
559        return f"{cls.WAP_PREFIX}{wap_id}"

Base class wrapping a Database API compliant connection.

The EngineAdapter is an easily-subclassable interface that interacts with the underlying engine and data store.

Arguments:
  • connection_factory_or_pool: a callable which produces a new Database API-compliant connection on every call.
  • dialect: The dialect with which this adapter is associated.
  • multithreaded: Indicates whether this adapter will be used by more than one thread.
DIALECT = 'spark'
SUPPORTS_TRANSACTIONS = False
INSERT_OVERWRITE_STRATEGY = <InsertOverwriteStrategy.INSERT_OVERWRITE: 2>
COMMENT_CREATION_TABLE = <CommentCreationTable.IN_SCHEMA_DEF_NO_CTAS: 3>
COMMENT_CREATION_VIEW = <CommentCreationView.IN_SCHEMA_DEF_NO_COMMANDS: 3>
SUPPORTS_REPLACE_TABLE = False
QUOTE_IDENTIFIERS_IN_VIEWS = False
SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ['DATABASE', 'SCHEMA']
WAP_PREFIX = 'wap_'
BRANCH_PREFIX = 'branch_'
SCHEMA_DIFFER_KWARGS = {'parameterized_type_defaults': {<DType.DECIMAL: 'DECIMAL'>: [(), (0,)]}}
connection: SparkSessionConnection
71    @property
72    def connection(self) -> SparkSessionConnection:
73        return self._connection_pool.get()
spark: PySparkSession
75    @property
76    def spark(self) -> PySparkSession:
77        return self.connection.spark
use_serverless: bool
83    @property
84    def use_serverless(self) -> bool:
85        return False
87    @property
88    def catalog_support(self) -> CatalogSupport:
89        return CatalogSupport.FULL_SUPPORT
@classmethod
def spark_to_sqlglot_types( cls, input: pyspark.sql.types.StructType) -> Dict[str, sqlglot.expressions.datatypes.DataType]:
134    @classmethod
135    def spark_to_sqlglot_types(cls, input: spark_types.StructType) -> t.Dict[str, exp.DataType]:
136        from pyspark.sql import types as spark_types
137
138        def spark_complex_to_sqlglot_complex(
139            complex_type: t.Union[
140                spark_types.StructType, spark_types.ArrayType, spark_types.MapType
141            ],
142        ) -> exp.DataType:
143            def get_fields(
144                complex_type: t.Union[
145                    spark_types.StructType, spark_types.ArrayType, spark_types.MapType
146                ],
147            ) -> t.Sequence[spark_types.DataType]:
148                if isinstance(complex_type, spark_types.StructType):
149                    return complex_type.fields
150                if isinstance(complex_type, spark_types.ArrayType):
151                    return [complex_type.elementType]
152                if isinstance(complex_type, spark_types.MapType):
153                    return [complex_type.keyType, complex_type.valueType]
154                raise SQLMeshError(f"Unsupported complex type: {complex_type}")
155
156            expressions: t.List[t.Union[exp.ColumnDef, exp.DataType]] = []
157            fields = get_fields(complex_type)
158            for field in fields:
159                if isinstance(field, (spark_types.StructType, spark_types.MapType)):
160                    expressions.append(spark_complex_to_sqlglot_complex(field))
161                elif isinstance(field, spark_types.StructField):
162                    sqlglot_data_type = cls._spark_to_sqlglot_primitive_mapping.get(
163                        type(field.dataType)
164                    ) or spark_complex_to_sqlglot_complex(
165                        field.dataType  # type: ignore
166                    )
167                    kind = (
168                        sqlglot_data_type
169                        if isinstance(sqlglot_data_type, exp.DataType)
170                        else exp.DataType(this=sqlglot_data_type)
171                    )
172                    expressions.append(exp.ColumnDef(this=exp.to_identifier(field.name), kind=kind))
173                else:
174                    kind = exp.DataType(this=cls._spark_to_sqlglot_primitive_mapping[type(field)])
175                    expressions.append(kind)
176            dtype = cls._spark_to_sqlglot_complex_mapping[type(complex_type)]
177            return exp.DataType(
178                this=dtype,
179                expressions=expressions,
180                nested=True,
181            )
182
183        resp = spark_complex_to_sqlglot_complex(input)
184        return {column_def.this.name: column_def.args["kind"] for column_def in resp.expressions}
@classmethod
def sqlglot_to_spark_types( cls, input: Dict[str, sqlglot.expressions.datatypes.DataType]) -> pyspark.sql.types.StructType:
186    @classmethod
187    def sqlglot_to_spark_types(cls, input: t.Dict[str, exp.DataType]) -> spark_types.StructType:
188        from pyspark.sql import types as spark_types
189
190        def sqlglot_complex_to_spark_complex(complex_type: exp.DataType) -> spark_types.DataType:
191            is_struct = complex_type.is_type(exp.DataType.Type.STRUCT)
192            expressions = []
193            for column_def in complex_type.expressions:
194                col_name = column_def.this.name if is_struct else None
195                data_type = column_def.args["kind"] if is_struct else column_def
196                primitive_func = cls._sqlglot_to_spark_primitive_mapping.get(data_type.this)
197                type_func = (
198                    primitive_func
199                    if primitive_func
200                    else partial(sqlglot_complex_to_spark_complex, data_type)
201                )
202                if is_struct:
203                    expressions.append(spark_types.StructField(col_name, type_func()))  # type: ignore
204                else:
205                    expressions.append(type_func())  # type: ignore
206            klass = cls._sqlglot_to_spark_complex_mapping[complex_type.this]
207            if is_struct:
208                return klass(expressions)
209            return klass(*expressions)
210
211        return t.cast(
212            spark_types.StructType,
213            sqlglot_complex_to_spark_complex(
214                exp.DataType(
215                    this=exp.DataType.Type.STRUCT,
216                    expressions=[
217                        exp.ColumnDef(this=exp.to_identifier(column), kind=data_type)
218                        for column, data_type in input.items()
219                    ],
220                )
221            ),
222        )
@classmethod
def is_pyspark_df(cls, value: Any) -> bool:
224    @classmethod
225    def is_pyspark_df(cls, value: t.Any) -> bool:
226        return hasattr(value, "sparkSession")
@classmethod
def try_get_pyspark_df(cls, value: Any) -> Optional[PySparkDataFrame]:
228    @classmethod
229    def try_get_pyspark_df(cls, value: t.Any) -> t.Optional[PySparkDataFrame]:
230        if cls.is_pyspark_df(value):
231            return value
232        return None
@classmethod
def try_get_pandas_df(cls, value: Any) -> Optional[pandas.core.frame.DataFrame]:
234    @classmethod
235    def try_get_pandas_df(cls, value: t.Any) -> t.Optional[pd.DataFrame]:
236        import pandas as pd
237
238        if isinstance(value, pd.DataFrame):
239            return value
240        return None
def fetchdf( self, query: Union[sqlglot.expressions.core.Expr, str], quote_identifiers: bool = False) -> pandas.core.frame.DataFrame:
343    def fetchdf(
344        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
345    ) -> pd.DataFrame:
346        return self.fetch_pyspark_df(query, quote_identifiers=quote_identifiers).toPandas()

Fetches a Pandas DataFrame from the cursor

def fetch_pyspark_df( self, query: Union[sqlglot.expressions.core.Expr, str], quote_identifiers: bool = False) -> PySparkDataFrame:
348    def fetch_pyspark_df(
349        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
350    ) -> PySparkDataFrame:
351        return self._ensure_pyspark_df(
352            self._fetch_native_df(query, quote_identifiers=quote_identifiers)
353        )

Fetches a PySpark DataFrame from the cursor

def get_current_catalog(self) -> Optional[str]:
393    def get_current_catalog(self) -> t.Optional[str]:
394        if self._use_spark_session:
395            return self.connection.get_current_catalog()
396        return super().get_current_catalog()

Returns the catalog name of the current connection.

def set_current_catalog(self, catalog_name: str) -> None:
398    def set_current_catalog(self, catalog_name: str) -> None:
399        self.connection.set_current_catalog(catalog_name)

Sets the catalog name of the current connection.

def get_data_object( self, target_name: TableName, safe_to_cache: bool = False) -> Optional[sqlmesh.core.engine_adapter.shared.DataObject]:
406    def get_data_object(
407        self, target_name: TableName, safe_to_cache: bool = False
408    ) -> t.Optional[DataObject]:
409        target_table = exp.to_table(target_name)
410        if isinstance(target_table.this, exp.Dot) and target_table.this.expression.name.startswith(
411            f"{self.BRANCH_PREFIX}{self.WAP_PREFIX}"
412        ):
413            # Exclude the branch name
414            target_table.set("this", target_table.this.this)
415        return super().get_data_object(target_table, safe_to_cache=safe_to_cache)
def create_state_table( self, table_name: str, target_columns_to_types: Dict[str, sqlglot.expressions.datatypes.DataType], primary_key: Optional[Tuple[str, ...]] = None) -> None:
417    def create_state_table(
418        self,
419        table_name: str,
420        target_columns_to_types: t.Dict[str, exp.DataType],
421        primary_key: t.Optional[t.Tuple[str, ...]] = None,
422    ) -> None:
423        self.create_table(
424            table_name,
425            target_columns_to_types,
426            partitioned_by=[exp.column(x) for x in primary_key] if primary_key else None,
427        )

Create a table to store SQLMesh internal state.

Arguments:
  • table_name: The name of the table to create. Can be fully qualified or just table name.
  • target_columns_to_types: A mapping between the column name and its data type.
  • primary_key: Determines the table primary key.
def wap_supported(self, table_name: TableName) -> bool:
493    def wap_supported(self, table_name: TableName) -> bool:
494        fqn = self._ensure_fqn(table_name)
495        return (
496            self.spark.conf.get(f"spark.sql.catalog.{fqn.catalog}")
497            == "org.apache.iceberg.spark.SparkCatalog"
498        )

Returns whether WAP for the target table is supported.

def wap_table_name(self, table_name: TableName, wap_id: str) -> str:
500    def wap_table_name(self, table_name: TableName, wap_id: str) -> str:
501        branch_name = self._wap_branch_name(wap_id)
502        fqn = self._ensure_fqn(table_name)
503        return exp.Dot.build([fqn, exp.to_identifier(f"{self.BRANCH_PREFIX}{branch_name}")]).sql(
504            dialect=self.dialect
505        )

Returns the updated table name for the given WAP ID.

Arguments:
  • table_name: The name of the target table.
  • wap_id: The WAP ID to prepare.
Returns:

The updated table name that should be used for writing.

def wap_prepare(self, table_name: TableName, wap_id: str) -> str:
507    def wap_prepare(self, table_name: TableName, wap_id: str) -> str:
508        branch_name = self._wap_branch_name(wap_id)
509        fqn = self._ensure_fqn(table_name)
510        self.execute(f"ALTER TABLE {fqn.sql(dialect=self.dialect)} CREATE BRANCH {branch_name}")
511        return self.wap_table_name(table_name, wap_id)

Prepares the target table for WAP and returns the updated table name.

Arguments:
  • table_name: The name of the target table.
  • wap_id: The WAP ID to prepare.
Returns:

The updated table name that should be used for writing.

def wap_publish(self, table_name: TableName, wap_id: str) -> None:
513    def wap_publish(self, table_name: TableName, wap_id: str) -> None:
514        branch_name = self._wap_branch_name(wap_id)
515        fqn = self._ensure_fqn(table_name)
516
517        get_snapshot_id_query = (
518            exp.select("snapshot_id")
519            .from_(exp.Dot.build([fqn, exp.to_identifier("refs")]))
520            .where(exp.column("name").eq(branch_name))
521        )
522        iceberg_snapshot_ids = self.fetchall(get_snapshot_id_query)
523        if not iceberg_snapshot_ids:
524            raise SQLMeshError(f"Could not find Iceberg branch '{branch_name}'.")
525        iceberg_snapshot_id = iceberg_snapshot_ids[0][0]
526
527        logger.info(
528            "Cherry-picking Iceberg snapshot %s into table '%s'...", iceberg_snapshot_id, fqn
529        )
530
531        self.execute(
532            f"CALL {fqn.catalog}.system.cherrypick_snapshot('{fqn.db}.{fqn.name}', {iceberg_snapshot_id})"
533        )
534        self.execute(f"ALTER TABLE {fqn.sql(dialect=self.dialect)} DROP BRANCH {branch_name}")

Publishes changes with the given WAP ID to the target table.

Arguments:
  • table_name: The name of the target table.
  • wap_id: The WAP ID to publish.
Inherited Members
sqlmesh.core.engine_adapter.base.EngineAdapter
EngineAdapter
DEFAULT_BATCH_SIZE
DATA_OBJECT_FILTER_BATCH_SIZE
SUPPORTS_INDEXES
SUPPORTS_MATERIALIZED_VIEWS
SUPPORTS_MATERIALIZED_VIEW_SCHEMA
SUPPORTS_VIEW_SCHEMA
SUPPORTS_CLONING
SUPPORTS_MANAGED_MODELS
SUPPORTS_CREATE_DROP_CATALOG
SUPPORTS_TUPLE_IN
HAS_VIEW_BINDING
SUPPORTS_GRANTS
DEFAULT_CATALOG_TYPE
MAX_IDENTIFIER_LENGTH
ATTACH_CORRELATION_ID
SUPPORTS_QUERY_EXECUTION_TRACKING
SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS
dialect
correlation_id
with_settings
cursor
snowpark
bigframe
comments_enabled
schema_differ
default_catalog
engine_run_mode
recycle
close
get_catalog_type
get_catalog_type_from_table
current_catalog_type
replace_query
create_index
create_table
create_managed_table
ctas
create_table_like
clone_table
drop_data_object
drop_table
drop_managed_table
get_alter_operations
alter_table
create_view
create_schema
drop_schema
drop_view
create_catalog
drop_catalog
columns
table_exists
delete_from
insert_append
insert_overwrite_by_partition
insert_overwrite_by_time_partition
update_table
scd_type_2_by_time
scd_type_2_by_column
merge
rename_table
get_data_objects
fetchone
fetchall
wap_enabled
sync_grants_config
transaction
session
execute
temp_table
drop_data_object_on_type_mismatch
ensure_nulls_for_unmatched_after_join
use_server_nulls_for_unmatched_after_join
ping
get_table_last_modified_ts
sqlmesh.core.engine_adapter.mixins.GetCurrentCatalogFromFunctionMixin
CURRENT_CATALOG_EXPRESSION
sqlmesh.core.engine_adapter.mixins.HiveMetastoreTablePropertiesMixin
MAX_TABLE_COMMENT_LENGTH
MAX_COLUMN_COMMENT_LENGTH
sqlmesh.core.engine_adapter.mixins.RowDiffMixin
MAX_TIMESTAMP_PRECISION
concat_columns
normalize_value