Edit on GitHub

sqlmesh.core.engine_adapter.databricks

  1from __future__ import annotations
  2
  3import logging
  4import typing as t
  5from functools import partial
  6
  7from sqlglot import exp
  8
  9from sqlmesh.core.dialect import to_schema
 10from sqlmesh.core.engine_adapter.mixins import GrantsFromInfoSchemaMixin
 11from sqlmesh.core.engine_adapter.shared import (
 12    CatalogSupport,
 13    DataObject,
 14    DataObjectType,
 15    InsertOverwriteStrategy,
 16    SourceQuery,
 17)
 18from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
 19from sqlmesh.core.node import IntervalUnit
 20from sqlmesh.core.schema_diff import NestedSupport
 21from sqlmesh.engines.spark.db_api.spark_session import connection, SparkSessionConnection
 22from sqlmesh.utils.errors import SQLMeshError, MissingDefaultCatalogError
 23
 24if t.TYPE_CHECKING:
 25    import pandas as pd
 26
 27    from sqlmesh.core._typing import SchemaName, TableName, SessionProperties
 28    from sqlmesh.core.engine_adapter._typing import DF, PySparkSession, Query
 29
# Module-level logger. DatabricksEngineAdapter.get_current_catalog uses it to warn when the
# SQL connector and Databricks Connect report different current catalogs.
logger = logging.getLogger(__name__)
 31
 32
 33class DatabricksEngineAdapter(SparkEngineAdapter, GrantsFromInfoSchemaMixin):
 34    DIALECT = "databricks"
 35    INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE
 36    SUPPORTS_CLONING = True
 37    SUPPORTS_MATERIALIZED_VIEWS = True
 38    SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True
 39    SUPPORTS_GRANTS = True
 40    USE_CATALOG_IN_GRANTS = True
 41    # Spark has this set to false for compatibility when mixing with Trino but that isn't a concern with Databricks
 42    QUOTE_IDENTIFIERS_IN_VIEWS = True
 43    SCHEMA_DIFFER_KWARGS = {
 44        "support_positional_add": True,
 45        "nested_support": NestedSupport.ALL,
 46        "array_element_selector": "element",
 47        "parameterized_type_defaults": {
 48            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(10, 0), (0,)],
 49        },
 50    }
 51
 52    def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
 53        super().__init__(*args, **kwargs)
 54        self._set_spark_engine_adapter_if_needed()
 55
 56    @classmethod
 57    def can_access_spark_session(cls, disable_spark_session: bool) -> bool:
 58        from sqlmesh import RuntimeEnv
 59
 60        if disable_spark_session:
 61            return False
 62
 63        return RuntimeEnv.get().is_databricks
 64
 65    @classmethod
 66    def can_access_databricks_connect(cls, disable_databricks_connect: bool) -> bool:
 67        if disable_databricks_connect:
 68            return False
 69
 70        try:
 71            from databricks.connect import DatabricksSession  # noqa
 72
 73            return True
 74        except ImportError:
 75            return False
 76
 77    @property
 78    def _use_spark_session(self) -> bool:
 79        if self.can_access_spark_session(bool(self._extra_config.get("disable_spark_session"))):
 80            return True
 81
 82        if self.can_access_databricks_connect(
 83            bool(self._extra_config.get("disable_databricks_connect"))
 84        ):
 85            if self._extra_config.get("databricks_connect_use_serverless"):
 86                return True
 87
 88            if {
 89                "databricks_connect_cluster_id",
 90                "databricks_connect_server_hostname",
 91                "databricks_connect_access_token",
 92            }.issubset(self._extra_config):
 93                return True
 94
 95        return False
 96
 97    @property
 98    def is_spark_session_connection(self) -> bool:
 99        return isinstance(self.connection, SparkSessionConnection)
100
101    def _set_spark_engine_adapter_if_needed(self) -> None:
102        self._spark_engine_adapter = None
103
104        if not self._use_spark_session or self.is_spark_session_connection:
105            return
106
107        from databricks.connect import DatabricksSession
108
109        connect_kwargs = dict(
110            host=self._extra_config["databricks_connect_server_hostname"],
111            token=self._extra_config.get("databricks_connect_access_token"),
112        )
113        if "databricks_connect_use_serverless" in self._extra_config:
114            connect_kwargs["serverless"] = True
115        else:
116            connect_kwargs["cluster_id"] = self._extra_config["databricks_connect_cluster_id"]
117
118        catalog = self._extra_config.get("catalog")
119        spark = (
120            DatabricksSession.builder.remote(**connect_kwargs).userAgent("sqlmesh").getOrCreate()
121        )
122        self._spark_engine_adapter = SparkEngineAdapter(
123            partial(connection, spark=spark, catalog=catalog),
124            default_catalog=catalog,
125            execute_log_level=self._execute_log_level,
126            multithreaded=self._multithreaded,
127            sql_gen_kwargs=self._sql_gen_kwargs,
128            register_comments=self._register_comments,
129            pre_ping=self._pre_ping,
130            pretty_sql=self._pretty_sql,
131        )
132
133    @property
134    def cursor(self) -> t.Any:
135        if (
136            self._connection_pool.get_attribute("use_spark_engine_adapter")
137            and not self.is_spark_session_connection
138        ):
139            return self._spark_engine_adapter.cursor  # type: ignore
140        return super().cursor
141
142    @property
143    def spark(self) -> PySparkSession:
144        if not self._use_spark_session:
145            raise SQLMeshError(
146                "SparkSession is not available. "
147                "Either run from a Databricks Notebook or "
148                "install `databricks-connect` and configure it to connect to your Databricks cluster."
149            )
150        if self.is_spark_session_connection:
151            return self.connection.spark
152        return self._spark_engine_adapter.spark  # type: ignore
153
154    @property
155    def catalog_support(self) -> CatalogSupport:
156        return CatalogSupport.FULL_SUPPORT
157
158    @staticmethod
159    def _grant_object_kind(table_type: DataObjectType) -> str:
160        if table_type == DataObjectType.VIEW:
161            return "VIEW"
162        if table_type == DataObjectType.MATERIALIZED_VIEW:
163            return "MATERIALIZED VIEW"
164        return "TABLE"
165
166    def _get_grant_expression(self, table: exp.Table) -> exp.Expr:
167        # We only care about explicitly granted privileges and not inherited ones
168        # if this is removed you would see grants inherited from the catalog get returned
169        expression = super()._get_grant_expression(table)
170        expression.args["where"].set(
171            "this",
172            exp.and_(
173                expression.args["where"].this,
174                exp.column("inherited_from").eq(exp.Literal.string("NONE")),
175                wrap=False,
176            ),
177        )
178        return expression
179
180    def _begin_session(self, properties: SessionProperties) -> t.Any:
181        """Begin a new session."""
182        # Align the different possible connectors to a single catalog
183        self.set_current_catalog(self.default_catalog)  # type: ignore
184
185    def _end_session(self) -> None:
186        self._connection_pool.set_attribute("use_spark_engine_adapter", False)
187
188    def _df_to_source_queries(
189        self,
190        df: DF,
191        target_columns_to_types: t.Dict[str, exp.DataType],
192        batch_size: int,
193        target_table: TableName,
194        source_columns: t.Optional[t.List[str]] = None,
195    ) -> t.List[SourceQuery]:
196        if not self._use_spark_session:
197            return super(SparkEngineAdapter, self)._df_to_source_queries(
198                df, target_columns_to_types, batch_size, target_table, source_columns=source_columns
199            )
200        pyspark_df = self._ensure_pyspark_df(
201            df, target_columns_to_types, source_columns=source_columns
202        )
203
204        def query_factory() -> Query:
205            temp_table = self._get_temp_table(target_table or "spark", table_only=True)
206            pyspark_df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect))
207            self._connection_pool.set_attribute("use_spark_engine_adapter", True)
208            return exp.select(*self._select_columns(target_columns_to_types)).from_(temp_table)
209
210        return [SourceQuery(query_factory=query_factory)]
211
212    def _fetch_native_df(
213        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
214    ) -> DF:
215        """Fetches a DataFrame that can be either Pandas or PySpark from the cursor"""
216        if self.is_spark_session_connection:
217            return super()._fetch_native_df(query, quote_identifiers=quote_identifiers)
218        if self._spark_engine_adapter:
219            return self._spark_engine_adapter._fetch_native_df(  # type: ignore
220                query, quote_identifiers=quote_identifiers
221            )
222        self.execute(query)
223        return self.cursor.fetchall_arrow().to_pandas()
224
225    def fetchdf(
226        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
227    ) -> pd.DataFrame:
228        """
229        Returns a Pandas DataFrame from a query or expression.
230        """
231        import pandas as pd
232
233        df = self._fetch_native_df(query, quote_identifiers=quote_identifiers)
234        if not isinstance(df, pd.DataFrame):
235            return df.toPandas()
236        return df
237
238    def get_current_catalog(self) -> t.Optional[str]:
239        pyspark_catalog = None
240        sql_connector_catalog = None
241        if self._spark_engine_adapter:
242            from py4j.protocol import Py4JError
243            from pyspark.errors.exceptions.connect import SparkConnectGrpcException
244
245            try:
246                # Note: Spark 3.4+ Only API
247                pyspark_catalog = self._spark_engine_adapter.get_current_catalog()
248            except (Py4JError, SparkConnectGrpcException):
249                pass
250        elif self.is_spark_session_connection:
251            pyspark_catalog = self.connection.spark.catalog.currentCatalog()
252        if not self.is_spark_session_connection:
253            result = self.fetchone(exp.select(self.CURRENT_CATALOG_EXPRESSION))
254            sql_connector_catalog = result[0] if result else None
255        if self._spark_engine_adapter and pyspark_catalog != sql_connector_catalog:
256            logger.warning(
257                f"Current catalog mismatch between Databricks SQL Connector and Databricks-Connect: `{sql_connector_catalog}` != `{pyspark_catalog}`. Set `catalog` connection property to make them the same."
258            )
259        return pyspark_catalog or sql_connector_catalog
260
261    def set_current_catalog(self, catalog_name: str) -> None:
262        def _set_spark_session_current_catalog(spark: PySparkSession) -> None:
263            from py4j.protocol import Py4JError
264            from pyspark.errors.exceptions.connect import SparkConnectGrpcException
265
266            try:
267                # Note: Spark 3.4+ Only API
268                spark.catalog.setCurrentCatalog(catalog_name)
269            except (Py4JError, SparkConnectGrpcException):
270                pass
271
272        # Since Databricks splits commands across the Dataframe API and the SQL Connector
273        # (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both
274        # are set to the same catalog since they maintain their default catalog separately
275        self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG"))
276        if self.is_spark_session_connection:
277            _set_spark_session_current_catalog(self.connection.spark)
278
279        if self._spark_engine_adapter:
280            _set_spark_session_current_catalog(self._spark_engine_adapter.spark)
281
282    def _get_data_objects(
283        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
284    ) -> t.List[DataObject]:
285        """
286        Returns all the data objects that exist in the given schema and catalog.
287        """
288        schema = to_schema(schema_name)
289        catalog_name = schema.catalog or self.get_current_catalog()
290        query = (
291            exp.select(
292                exp.column("table_name").as_("name"),
293                exp.column("table_schema").as_("schema"),
294                exp.column("table_catalog").as_("catalog"),
295                exp.case(exp.column("table_type"))
296                .when(exp.Literal.string("VIEW"), exp.Literal.string("view"))
297                .when(
298                    exp.Literal.string("MATERIALIZED_VIEW"), exp.Literal.string("materialized_view")
299                )
300                .else_(exp.Literal.string("table"))
301                .as_("type"),
302            )
303            .from_(
304                # always query `system` information_schema
305                exp.table_("tables", "information_schema", "system")
306            )
307            .where(exp.column("table_catalog").eq(catalog_name))
308            .where(exp.column("table_schema").eq(schema.db))
309        )
310
311        if object_names:
312            query = query.where(exp.column("table_name").isin(*object_names))
313
314        df = self.fetchdf(query)
315        return [
316            DataObject(
317                catalog=row.catalog,  # type: ignore
318                schema=row.schema,  # type: ignore
319                name=row.name,  # type: ignore
320                type=DataObjectType.from_str(row.type),  # type: ignore
321            )
322            for row in df.itertuples()
323        ]
324
325    def clone_table(
326        self,
327        target_table_name: TableName,
328        source_table_name: TableName,
329        replace: bool = False,
330        exists: bool = True,
331        clone_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
332        **kwargs: t.Any,
333    ) -> None:
334        clone_kwargs = clone_kwargs or {}
335        clone_kwargs["shallow"] = True
336        super().clone_table(
337            target_table_name,
338            source_table_name,
339            replace=replace,
340            clone_kwargs=clone_kwargs,
341            **kwargs,
342        )
343
344    def wap_supported(self, table_name: TableName) -> bool:
345        return False
346
347    def close(self) -> t.Any:
348        """Closes all open connections and releases all allocated resources."""
349        super().close()
350        if self._spark_engine_adapter:
351            self._spark_engine_adapter.close()
352
353    @property
354    def default_catalog(self) -> t.Optional[str]:
355        try:
356            return super().default_catalog
357        except MissingDefaultCatalogError as e:
358            raise MissingDefaultCatalogError(
359                "Could not determine default catalog. Define the connection property `catalog` since it can't be inferred from your connection. See SQLMesh Databricks documentation for details"
360            ) from e
361
362    def _build_table_properties_exp(
363        self,
364        catalog_name: t.Optional[str] = None,
365        table_format: t.Optional[str] = None,
366        storage_format: t.Optional[str] = None,
367        partitioned_by: t.Optional[t.List[exp.Expr]] = None,
368        partition_interval_unit: t.Optional[IntervalUnit] = None,
369        clustered_by: t.Optional[t.List[exp.Expr]] = None,
370        table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
371        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
372        table_description: t.Optional[str] = None,
373        table_kind: t.Optional[str] = None,
374        **kwargs: t.Any,
375    ) -> t.Optional[exp.Properties]:
376        properties = super()._build_table_properties_exp(
377            catalog_name=catalog_name,
378            table_format=table_format,
379            storage_format=storage_format,
380            partitioned_by=partitioned_by,
381            partition_interval_unit=partition_interval_unit,
382            clustered_by=clustered_by,
383            table_properties=table_properties,
384            target_columns_to_types=target_columns_to_types,
385            table_description=table_description,
386            table_kind=table_kind,
387        )
388        if clustered_by:
389            # Databricks expects wrapped CLUSTER BY expressions
390            clustered_by_exp = exp.Cluster(
391                expressions=[exp.Tuple(expressions=[c.copy() for c in clustered_by])]
392            )
393            expressions = properties.expressions if properties else []
394            expressions.append(clustered_by_exp)
395            properties = exp.Properties(expressions=expressions)
396        return properties
397
398    def _build_column_defs(
399        self,
400        target_columns_to_types: t.Dict[str, exp.DataType],
401        column_descriptions: t.Optional[t.Dict[str, str]] = None,
402        is_view: bool = False,
403        materialized: bool = False,
404    ) -> t.List[exp.ColumnDef]:
405        # Databricks requires column types to be specified when adding column comments
406        # in CREATE MATERIALIZED VIEW statements. Override is_view to False to force
407        # column types to be included when comments are present.
408        if is_view and materialized and column_descriptions:
409            is_view = False
410
411        return super()._build_column_defs(
412            target_columns_to_types, column_descriptions, is_view, materialized
413        )
logger = <Logger sqlmesh.core.engine_adapter.databricks (WARNING)>
 34class DatabricksEngineAdapter(SparkEngineAdapter, GrantsFromInfoSchemaMixin):
 35    DIALECT = "databricks"
 36    INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE
 37    SUPPORTS_CLONING = True
 38    SUPPORTS_MATERIALIZED_VIEWS = True
 39    SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True
 40    SUPPORTS_GRANTS = True
 41    USE_CATALOG_IN_GRANTS = True
 42    # Spark has this set to false for compatibility when mixing with Trino but that isn't a concern with Databricks
 43    QUOTE_IDENTIFIERS_IN_VIEWS = True
 44    SCHEMA_DIFFER_KWARGS = {
 45        "support_positional_add": True,
 46        "nested_support": NestedSupport.ALL,
 47        "array_element_selector": "element",
 48        "parameterized_type_defaults": {
 49            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(10, 0), (0,)],
 50        },
 51    }
 52
 53    def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
 54        super().__init__(*args, **kwargs)
 55        self._set_spark_engine_adapter_if_needed()
 56
 57    @classmethod
 58    def can_access_spark_session(cls, disable_spark_session: bool) -> bool:
 59        from sqlmesh import RuntimeEnv
 60
 61        if disable_spark_session:
 62            return False
 63
 64        return RuntimeEnv.get().is_databricks
 65
 66    @classmethod
 67    def can_access_databricks_connect(cls, disable_databricks_connect: bool) -> bool:
 68        if disable_databricks_connect:
 69            return False
 70
 71        try:
 72            from databricks.connect import DatabricksSession  # noqa
 73
 74            return True
 75        except ImportError:
 76            return False
 77
 78    @property
 79    def _use_spark_session(self) -> bool:
 80        if self.can_access_spark_session(bool(self._extra_config.get("disable_spark_session"))):
 81            return True
 82
 83        if self.can_access_databricks_connect(
 84            bool(self._extra_config.get("disable_databricks_connect"))
 85        ):
 86            if self._extra_config.get("databricks_connect_use_serverless"):
 87                return True
 88
 89            if {
 90                "databricks_connect_cluster_id",
 91                "databricks_connect_server_hostname",
 92                "databricks_connect_access_token",
 93            }.issubset(self._extra_config):
 94                return True
 95
 96        return False
 97
 98    @property
 99    def is_spark_session_connection(self) -> bool:
100        return isinstance(self.connection, SparkSessionConnection)
101
102    def _set_spark_engine_adapter_if_needed(self) -> None:
103        self._spark_engine_adapter = None
104
105        if not self._use_spark_session or self.is_spark_session_connection:
106            return
107
108        from databricks.connect import DatabricksSession
109
110        connect_kwargs = dict(
111            host=self._extra_config["databricks_connect_server_hostname"],
112            token=self._extra_config.get("databricks_connect_access_token"),
113        )
114        if "databricks_connect_use_serverless" in self._extra_config:
115            connect_kwargs["serverless"] = True
116        else:
117            connect_kwargs["cluster_id"] = self._extra_config["databricks_connect_cluster_id"]
118
119        catalog = self._extra_config.get("catalog")
120        spark = (
121            DatabricksSession.builder.remote(**connect_kwargs).userAgent("sqlmesh").getOrCreate()
122        )
123        self._spark_engine_adapter = SparkEngineAdapter(
124            partial(connection, spark=spark, catalog=catalog),
125            default_catalog=catalog,
126            execute_log_level=self._execute_log_level,
127            multithreaded=self._multithreaded,
128            sql_gen_kwargs=self._sql_gen_kwargs,
129            register_comments=self._register_comments,
130            pre_ping=self._pre_ping,
131            pretty_sql=self._pretty_sql,
132        )
133
134    @property
135    def cursor(self) -> t.Any:
136        if (
137            self._connection_pool.get_attribute("use_spark_engine_adapter")
138            and not self.is_spark_session_connection
139        ):
140            return self._spark_engine_adapter.cursor  # type: ignore
141        return super().cursor
142
143    @property
144    def spark(self) -> PySparkSession:
145        if not self._use_spark_session:
146            raise SQLMeshError(
147                "SparkSession is not available. "
148                "Either run from a Databricks Notebook or "
149                "install `databricks-connect` and configure it to connect to your Databricks cluster."
150            )
151        if self.is_spark_session_connection:
152            return self.connection.spark
153        return self._spark_engine_adapter.spark  # type: ignore
154
155    @property
156    def catalog_support(self) -> CatalogSupport:
157        return CatalogSupport.FULL_SUPPORT
158
159    @staticmethod
160    def _grant_object_kind(table_type: DataObjectType) -> str:
161        if table_type == DataObjectType.VIEW:
162            return "VIEW"
163        if table_type == DataObjectType.MATERIALIZED_VIEW:
164            return "MATERIALIZED VIEW"
165        return "TABLE"
166
167    def _get_grant_expression(self, table: exp.Table) -> exp.Expr:
168        # We only care about explicitly granted privileges and not inherited ones
169        # if this is removed you would see grants inherited from the catalog get returned
170        expression = super()._get_grant_expression(table)
171        expression.args["where"].set(
172            "this",
173            exp.and_(
174                expression.args["where"].this,
175                exp.column("inherited_from").eq(exp.Literal.string("NONE")),
176                wrap=False,
177            ),
178        )
179        return expression
180
181    def _begin_session(self, properties: SessionProperties) -> t.Any:
182        """Begin a new session."""
183        # Align the different possible connectors to a single catalog
184        self.set_current_catalog(self.default_catalog)  # type: ignore
185
186    def _end_session(self) -> None:
187        self._connection_pool.set_attribute("use_spark_engine_adapter", False)
188
189    def _df_to_source_queries(
190        self,
191        df: DF,
192        target_columns_to_types: t.Dict[str, exp.DataType],
193        batch_size: int,
194        target_table: TableName,
195        source_columns: t.Optional[t.List[str]] = None,
196    ) -> t.List[SourceQuery]:
197        if not self._use_spark_session:
198            return super(SparkEngineAdapter, self)._df_to_source_queries(
199                df, target_columns_to_types, batch_size, target_table, source_columns=source_columns
200            )
201        pyspark_df = self._ensure_pyspark_df(
202            df, target_columns_to_types, source_columns=source_columns
203        )
204
205        def query_factory() -> Query:
206            temp_table = self._get_temp_table(target_table or "spark", table_only=True)
207            pyspark_df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect))
208            self._connection_pool.set_attribute("use_spark_engine_adapter", True)
209            return exp.select(*self._select_columns(target_columns_to_types)).from_(temp_table)
210
211        return [SourceQuery(query_factory=query_factory)]
212
213    def _fetch_native_df(
214        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
215    ) -> DF:
216        """Fetches a DataFrame that can be either Pandas or PySpark from the cursor"""
217        if self.is_spark_session_connection:
218            return super()._fetch_native_df(query, quote_identifiers=quote_identifiers)
219        if self._spark_engine_adapter:
220            return self._spark_engine_adapter._fetch_native_df(  # type: ignore
221                query, quote_identifiers=quote_identifiers
222            )
223        self.execute(query)
224        return self.cursor.fetchall_arrow().to_pandas()
225
226    def fetchdf(
227        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
228    ) -> pd.DataFrame:
229        """
230        Returns a Pandas DataFrame from a query or expression.
231        """
232        import pandas as pd
233
234        df = self._fetch_native_df(query, quote_identifiers=quote_identifiers)
235        if not isinstance(df, pd.DataFrame):
236            return df.toPandas()
237        return df
238
239    def get_current_catalog(self) -> t.Optional[str]:
240        pyspark_catalog = None
241        sql_connector_catalog = None
242        if self._spark_engine_adapter:
243            from py4j.protocol import Py4JError
244            from pyspark.errors.exceptions.connect import SparkConnectGrpcException
245
246            try:
247                # Note: Spark 3.4+ Only API
248                pyspark_catalog = self._spark_engine_adapter.get_current_catalog()
249            except (Py4JError, SparkConnectGrpcException):
250                pass
251        elif self.is_spark_session_connection:
252            pyspark_catalog = self.connection.spark.catalog.currentCatalog()
253        if not self.is_spark_session_connection:
254            result = self.fetchone(exp.select(self.CURRENT_CATALOG_EXPRESSION))
255            sql_connector_catalog = result[0] if result else None
256        if self._spark_engine_adapter and pyspark_catalog != sql_connector_catalog:
257            logger.warning(
258                f"Current catalog mismatch between Databricks SQL Connector and Databricks-Connect: `{sql_connector_catalog}` != `{pyspark_catalog}`. Set `catalog` connection property to make them the same."
259            )
260        return pyspark_catalog or sql_connector_catalog
261
262    def set_current_catalog(self, catalog_name: str) -> None:
263        def _set_spark_session_current_catalog(spark: PySparkSession) -> None:
264            from py4j.protocol import Py4JError
265            from pyspark.errors.exceptions.connect import SparkConnectGrpcException
266
267            try:
268                # Note: Spark 3.4+ Only API
269                spark.catalog.setCurrentCatalog(catalog_name)
270            except (Py4JError, SparkConnectGrpcException):
271                pass
272
273        # Since Databricks splits commands across the Dataframe API and the SQL Connector
274        # (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both
275        # are set to the same catalog since they maintain their default catalog separately
276        self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG"))
277        if self.is_spark_session_connection:
278            _set_spark_session_current_catalog(self.connection.spark)
279
280        if self._spark_engine_adapter:
281            _set_spark_session_current_catalog(self._spark_engine_adapter.spark)
282
283    def _get_data_objects(
284        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
285    ) -> t.List[DataObject]:
286        """
287        Returns all the data objects that exist in the given schema and catalog.
288        """
289        schema = to_schema(schema_name)
290        catalog_name = schema.catalog or self.get_current_catalog()
291        query = (
292            exp.select(
293                exp.column("table_name").as_("name"),
294                exp.column("table_schema").as_("schema"),
295                exp.column("table_catalog").as_("catalog"),
296                exp.case(exp.column("table_type"))
297                .when(exp.Literal.string("VIEW"), exp.Literal.string("view"))
298                .when(
299                    exp.Literal.string("MATERIALIZED_VIEW"), exp.Literal.string("materialized_view")
300                )
301                .else_(exp.Literal.string("table"))
302                .as_("type"),
303            )
304            .from_(
305                # always query `system` information_schema
306                exp.table_("tables", "information_schema", "system")
307            )
308            .where(exp.column("table_catalog").eq(catalog_name))
309            .where(exp.column("table_schema").eq(schema.db))
310        )
311
312        if object_names:
313            query = query.where(exp.column("table_name").isin(*object_names))
314
315        df = self.fetchdf(query)
316        return [
317            DataObject(
318                catalog=row.catalog,  # type: ignore
319                schema=row.schema,  # type: ignore
320                name=row.name,  # type: ignore
321                type=DataObjectType.from_str(row.type),  # type: ignore
322            )
323            for row in df.itertuples()
324        ]
325
326    def clone_table(
327        self,
328        target_table_name: TableName,
329        source_table_name: TableName,
330        replace: bool = False,
331        exists: bool = True,
332        clone_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
333        **kwargs: t.Any,
334    ) -> None:
335        clone_kwargs = clone_kwargs or {}
336        clone_kwargs["shallow"] = True
337        super().clone_table(
338            target_table_name,
339            source_table_name,
340            replace=replace,
341            clone_kwargs=clone_kwargs,
342            **kwargs,
343        )
344
345    def wap_supported(self, table_name: TableName) -> bool:
346        return False
347
348    def close(self) -> t.Any:
349        """Closes all open connections and releases all allocated resources."""
350        super().close()
351        if self._spark_engine_adapter:
352            self._spark_engine_adapter.close()
353
354    @property
355    def default_catalog(self) -> t.Optional[str]:
356        try:
357            return super().default_catalog
358        except MissingDefaultCatalogError as e:
359            raise MissingDefaultCatalogError(
360                "Could not determine default catalog. Define the connection property `catalog` since it can't be inferred from your connection. See SQLMesh Databricks documentation for details"
361            ) from e
362
363    def _build_table_properties_exp(
364        self,
365        catalog_name: t.Optional[str] = None,
366        table_format: t.Optional[str] = None,
367        storage_format: t.Optional[str] = None,
368        partitioned_by: t.Optional[t.List[exp.Expr]] = None,
369        partition_interval_unit: t.Optional[IntervalUnit] = None,
370        clustered_by: t.Optional[t.List[exp.Expr]] = None,
371        table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
372        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
373        table_description: t.Optional[str] = None,
374        table_kind: t.Optional[str] = None,
375        **kwargs: t.Any,
376    ) -> t.Optional[exp.Properties]:
377        properties = super()._build_table_properties_exp(
378            catalog_name=catalog_name,
379            table_format=table_format,
380            storage_format=storage_format,
381            partitioned_by=partitioned_by,
382            partition_interval_unit=partition_interval_unit,
383            clustered_by=clustered_by,
384            table_properties=table_properties,
385            target_columns_to_types=target_columns_to_types,
386            table_description=table_description,
387            table_kind=table_kind,
388        )
389        if clustered_by:
390            # Databricks expects wrapped CLUSTER BY expressions
391            clustered_by_exp = exp.Cluster(
392                expressions=[exp.Tuple(expressions=[c.copy() for c in clustered_by])]
393            )
394            expressions = properties.expressions if properties else []
395            expressions.append(clustered_by_exp)
396            properties = exp.Properties(expressions=expressions)
397        return properties
398
399    def _build_column_defs(
400        self,
401        target_columns_to_types: t.Dict[str, exp.DataType],
402        column_descriptions: t.Optional[t.Dict[str, str]] = None,
403        is_view: bool = False,
404        materialized: bool = False,
405    ) -> t.List[exp.ColumnDef]:
406        # Databricks requires column types to be specified when adding column comments
407        # in CREATE MATERIALIZED VIEW statements. Override is_view to False to force
408        # column types to be included when comments are present.
409        if is_view and materialized and column_descriptions:
410            is_view = False
411
412        return super()._build_column_defs(
413            target_columns_to_types, column_descriptions, is_view, materialized
414        )

Base class wrapping a Database API-compliant connection.

The EngineAdapter is an easily subclassable interface that interacts with the underlying engine and data store.

Arguments:
  • connection_factory_or_pool: either a callable that produces a new Database API-compliant connection on every call, or an existing connection pool.
  • dialect: The dialect with which this adapter is associated.
  • multithreaded: Indicates whether this adapter will be used by more than one thread.
DatabricksEngineAdapter(*args: Any, **kwargs: Any)
53    def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
54        super().__init__(*args, **kwargs)
55        self._set_spark_engine_adapter_if_needed()
DIALECT = 'databricks'
INSERT_OVERWRITE_STRATEGY = <InsertOverwriteStrategy.REPLACE_WHERE: 3>
SUPPORTS_CLONING = True
SUPPORTS_MATERIALIZED_VIEWS = True
SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True
SUPPORTS_GRANTS = True
USE_CATALOG_IN_GRANTS = True
QUOTE_IDENTIFIERS_IN_VIEWS = True
SCHEMA_DIFFER_KWARGS = {'support_positional_add': True, 'nested_support': <NestedSupport.ALL: 'ALL'>, 'array_element_selector': 'element', 'parameterized_type_defaults': {<DType.DECIMAL: 'DECIMAL'>: [(10, 0), (0,)]}}
@classmethod
def can_access_spark_session(cls, disable_spark_session: bool) -> bool:
57    @classmethod
58    def can_access_spark_session(cls, disable_spark_session: bool) -> bool:
59        from sqlmesh import RuntimeEnv
60
61        if disable_spark_session:
62            return False
63
64        return RuntimeEnv.get().is_databricks
@classmethod
def can_access_databricks_connect(cls, disable_databricks_connect: bool) -> bool:
66    @classmethod
67    def can_access_databricks_connect(cls, disable_databricks_connect: bool) -> bool:
68        if disable_databricks_connect:
69            return False
70
71        try:
72            from databricks.connect import DatabricksSession  # noqa
73
74            return True
75        except ImportError:
76            return False
is_spark_session_connection: bool
 98    @property
 99    def is_spark_session_connection(self) -> bool:
100        return isinstance(self.connection, SparkSessionConnection)
cursor: Any
134    @property
135    def cursor(self) -> t.Any:
136        if (
137            self._connection_pool.get_attribute("use_spark_engine_adapter")
138            and not self.is_spark_session_connection
139        ):
140            return self._spark_engine_adapter.cursor  # type: ignore
141        return super().cursor
spark: PySparkSession
143    @property
144    def spark(self) -> PySparkSession:
145        if not self._use_spark_session:
146            raise SQLMeshError(
147                "SparkSession is not available. "
148                "Either run from a Databricks Notebook or "
149                "install `databricks-connect` and configure it to connect to your Databricks cluster."
150            )
151        if self.is_spark_session_connection:
152            return self.connection.spark
153        return self._spark_engine_adapter.spark  # type: ignore
catalog_support: sqlmesh.core.engine_adapter.shared.CatalogSupport
155    @property
156    def catalog_support(self) -> CatalogSupport:
157        return CatalogSupport.FULL_SUPPORT
def fetchdf(self, query: Union[sqlglot.expressions.core.Expr, str], quote_identifiers: bool = False) -> pandas.core.frame.DataFrame:
226    def fetchdf(
227        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
228    ) -> pd.DataFrame:
229        """
230        Returns a Pandas DataFrame from a query or expression.
231        """
232        import pandas as pd
233
234        df = self._fetch_native_df(query, quote_identifiers=quote_identifiers)
235        if not isinstance(df, pd.DataFrame):
236            return df.toPandas()
237        return df

Returns a Pandas DataFrame from a query or expression.

def get_current_catalog(self) -> Optional[str]:
239    def get_current_catalog(self) -> t.Optional[str]:
240        pyspark_catalog = None
241        sql_connector_catalog = None
242        if self._spark_engine_adapter:
243            from py4j.protocol import Py4JError
244            from pyspark.errors.exceptions.connect import SparkConnectGrpcException
245
246            try:
247                # Note: Spark 3.4+ Only API
248                pyspark_catalog = self._spark_engine_adapter.get_current_catalog()
249            except (Py4JError, SparkConnectGrpcException):
250                pass
251        elif self.is_spark_session_connection:
252            pyspark_catalog = self.connection.spark.catalog.currentCatalog()
253        if not self.is_spark_session_connection:
254            result = self.fetchone(exp.select(self.CURRENT_CATALOG_EXPRESSION))
255            sql_connector_catalog = result[0] if result else None
256        if self._spark_engine_adapter and pyspark_catalog != sql_connector_catalog:
257            logger.warning(
258                f"Current catalog mismatch between Databricks SQL Connector and Databricks-Connect: `{sql_connector_catalog}` != `{pyspark_catalog}`. Set `catalog` connection property to make them the same."
259            )
260        return pyspark_catalog or sql_connector_catalog

Returns the catalog name of the current connection.

def set_current_catalog(self, catalog_name: str) -> None:
262    def set_current_catalog(self, catalog_name: str) -> None:
263        def _set_spark_session_current_catalog(spark: PySparkSession) -> None:
264            from py4j.protocol import Py4JError
265            from pyspark.errors.exceptions.connect import SparkConnectGrpcException
266
267            try:
268                # Note: Spark 3.4+ Only API
269                spark.catalog.setCurrentCatalog(catalog_name)
270            except (Py4JError, SparkConnectGrpcException):
271                pass
272
273        # Since Databricks splits commands across the Dataframe API and the SQL Connector
274        # (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both
275        # are set to the same catalog since they maintain their default catalog separately
276        self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG"))
277        if self.is_spark_session_connection:
278            _set_spark_session_current_catalog(self.connection.spark)
279
280        if self._spark_engine_adapter:
281            _set_spark_session_current_catalog(self._spark_engine_adapter.spark)

Sets the catalog name of the current connection.

def clone_table(self, target_table_name: Union[str, sqlglot.expressions.query.Table], source_table_name: Union[str, sqlglot.expressions.query.Table], replace: bool = False, exists: bool = True, clone_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
326    def clone_table(
327        self,
328        target_table_name: TableName,
329        source_table_name: TableName,
330        replace: bool = False,
331        exists: bool = True,
332        clone_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
333        **kwargs: t.Any,
334    ) -> None:
335        clone_kwargs = clone_kwargs or {}
336        clone_kwargs["shallow"] = True
337        super().clone_table(
338            target_table_name,
339            source_table_name,
340            replace=replace,
341            clone_kwargs=clone_kwargs,
342            **kwargs,
343        )

Creates a table with the target name by cloning the source table.

Arguments:
  • target_table_name: The name of the table that should be created.
  • source_table_name: The name of the source table that should be cloned.
  • replace: Whether or not to replace an existing table.
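  • exists: Indicates whether to include the IF NOT EXISTS check.

Note: the Databricks implementation always performs a shallow clone. It sets `shallow=True` in `clone_kwargs`, modifying the dict in place when the caller passes a non-empty one. It also does not forward `exists` to the base implementation, so the caller's `exists` value is currently ignored.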
def wap_supported(self, table_name: Union[str, sqlglot.expressions.query.Table]) -> bool:
345    def wap_supported(self, table_name: TableName) -> bool:
346        return False

Returns whether WAP for the target table is supported.

def close(self) -> Any:
348    def close(self) -> t.Any:
349        """Closes all open connections and releases all allocated resources."""
350        super().close()
351        if self._spark_engine_adapter:
352            self._spark_engine_adapter.close()

Closes all open connections and releases all allocated resources.

default_catalog: Optional[str]
354    @property
355    def default_catalog(self) -> t.Optional[str]:
356        try:
357            return super().default_catalog
358        except MissingDefaultCatalogError as e:
359            raise MissingDefaultCatalogError(
360                "Could not determine default catalog. Define the connection property `catalog` since it can't be inferred from your connection. See SQLMesh Databricks documentation for details"
361            ) from e
Inherited Members
sqlmesh.core.engine_adapter.spark.SparkEngineAdapter
SUPPORTS_TRANSACTIONS
COMMENT_CREATION_TABLE
COMMENT_CREATION_VIEW
SUPPORTS_REPLACE_TABLE
SUPPORTED_DROP_CASCADE_OBJECT_KINDS
WAP_PREFIX
BRANCH_PREFIX
connection
use_serverless
spark_to_sqlglot_types
sqlglot_to_spark_types
is_pyspark_df
try_get_pyspark_df
try_get_pandas_df
fetch_pyspark_df
get_data_object
create_state_table
wap_table_name
wap_prepare
wap_publish
sqlmesh.core.engine_adapter.mixins.GetCurrentCatalogFromFunctionMixin
CURRENT_CATALOG_EXPRESSION
sqlmesh.core.engine_adapter.mixins.HiveMetastoreTablePropertiesMixin
MAX_TABLE_COMMENT_LENGTH
MAX_COLUMN_COMMENT_LENGTH
sqlmesh.core.engine_adapter.mixins.RowDiffMixin
MAX_TIMESTAMP_PRECISION
concat_columns
normalize_value
sqlmesh.core.engine_adapter.mixins.GrantsFromInfoSchemaMixin
CURRENT_USER_OR_ROLE_EXPRESSION
SUPPORTS_MULTIPLE_GRANT_PRINCIPALS
GRANT_INFORMATION_SCHEMA_TABLE_NAME
sqlmesh.core.engine_adapter.base.EngineAdapter
DEFAULT_BATCH_SIZE
DATA_OBJECT_FILTER_BATCH_SIZE
SUPPORTS_INDEXES
SUPPORTS_VIEW_SCHEMA
SUPPORTS_MANAGED_MODELS
SUPPORTS_CREATE_DROP_CATALOG
SUPPORTS_TUPLE_IN
HAS_VIEW_BINDING
DEFAULT_CATALOG_TYPE
MAX_IDENTIFIER_LENGTH
ATTACH_CORRELATION_ID
SUPPORTS_QUERY_EXECUTION_TRACKING
SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS
dialect
correlation_id
with_settings
snowpark
bigframe
comments_enabled
schema_differ
engine_run_mode
recycle
get_catalog_type
get_catalog_type_from_table
current_catalog_type
replace_query
create_index
create_table
create_managed_table
ctas
create_table_like
drop_data_object
drop_table
drop_managed_table
get_alter_operations
alter_table
create_view
create_schema
drop_schema
drop_view
create_catalog
drop_catalog
columns
table_exists
delete_from
insert_append
insert_overwrite_by_partition
insert_overwrite_by_time_partition
update_table
scd_type_2_by_time
scd_type_2_by_column
merge
rename_table
get_data_objects
fetchone
fetchall
wap_enabled
sync_grants_config
transaction
session
execute
temp_table
drop_data_object_on_type_mismatch
ensure_nulls_for_unmatched_after_join
use_server_nulls_for_unmatched_after_join
ping
get_table_last_modified_ts