sqlmesh.core.engine_adapter.databricks
1from __future__ import annotations 2 3import logging 4import typing as t 5from functools import partial 6 7from sqlglot import exp 8 9from sqlmesh.core.dialect import to_schema 10from sqlmesh.core.engine_adapter.mixins import GrantsFromInfoSchemaMixin 11from sqlmesh.core.engine_adapter.shared import ( 12 CatalogSupport, 13 DataObject, 14 DataObjectType, 15 InsertOverwriteStrategy, 16 SourceQuery, 17) 18from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter 19from sqlmesh.core.node import IntervalUnit 20from sqlmesh.core.schema_diff import NestedSupport 21from sqlmesh.engines.spark.db_api.spark_session import connection, SparkSessionConnection 22from sqlmesh.utils.errors import SQLMeshError, MissingDefaultCatalogError 23 24if t.TYPE_CHECKING: 25 import pandas as pd 26 27 from sqlmesh.core._typing import SchemaName, TableName, SessionProperties 28 from sqlmesh.core.engine_adapter._typing import DF, PySparkSession, Query 29 30logger = logging.getLogger(__name__) 31 32 33class DatabricksEngineAdapter(SparkEngineAdapter, GrantsFromInfoSchemaMixin): 34 DIALECT = "databricks" 35 INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE 36 SUPPORTS_CLONING = True 37 SUPPORTS_MATERIALIZED_VIEWS = True 38 SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True 39 SUPPORTS_GRANTS = True 40 USE_CATALOG_IN_GRANTS = True 41 # Spark has this set to false for compatibility when mixing with Trino but that isn't a concern with Databricks 42 QUOTE_IDENTIFIERS_IN_VIEWS = True 43 SCHEMA_DIFFER_KWARGS = { 44 "support_positional_add": True, 45 "nested_support": NestedSupport.ALL, 46 "array_element_selector": "element", 47 "parameterized_type_defaults": { 48 exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(10, 0), (0,)], 49 }, 50 } 51 52 def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: 53 super().__init__(*args, **kwargs) 54 self._set_spark_engine_adapter_if_needed() 55 56 @classmethod 57 def can_access_spark_session(cls, disable_spark_session: bool) -> bool: 58 from 
sqlmesh import RuntimeEnv 59 60 if disable_spark_session: 61 return False 62 63 return RuntimeEnv.get().is_databricks 64 65 @classmethod 66 def can_access_databricks_connect(cls, disable_databricks_connect: bool) -> bool: 67 if disable_databricks_connect: 68 return False 69 70 try: 71 from databricks.connect import DatabricksSession # noqa 72 73 return True 74 except ImportError: 75 return False 76 77 @property 78 def _use_spark_session(self) -> bool: 79 if self.can_access_spark_session(bool(self._extra_config.get("disable_spark_session"))): 80 return True 81 82 if self.can_access_databricks_connect( 83 bool(self._extra_config.get("disable_databricks_connect")) 84 ): 85 if self._extra_config.get("databricks_connect_use_serverless"): 86 return True 87 88 if { 89 "databricks_connect_cluster_id", 90 "databricks_connect_server_hostname", 91 "databricks_connect_access_token", 92 }.issubset(self._extra_config): 93 return True 94 95 return False 96 97 @property 98 def is_spark_session_connection(self) -> bool: 99 return isinstance(self.connection, SparkSessionConnection) 100 101 def _set_spark_engine_adapter_if_needed(self) -> None: 102 self._spark_engine_adapter = None 103 104 if not self._use_spark_session or self.is_spark_session_connection: 105 return 106 107 from databricks.connect import DatabricksSession 108 109 connect_kwargs = dict( 110 host=self._extra_config["databricks_connect_server_hostname"], 111 token=self._extra_config.get("databricks_connect_access_token"), 112 ) 113 if "databricks_connect_use_serverless" in self._extra_config: 114 connect_kwargs["serverless"] = True 115 else: 116 connect_kwargs["cluster_id"] = self._extra_config["databricks_connect_cluster_id"] 117 118 catalog = self._extra_config.get("catalog") 119 spark = ( 120 DatabricksSession.builder.remote(**connect_kwargs).userAgent("sqlmesh").getOrCreate() 121 ) 122 self._spark_engine_adapter = SparkEngineAdapter( 123 partial(connection, spark=spark, catalog=catalog), 124 
default_catalog=catalog, 125 execute_log_level=self._execute_log_level, 126 multithreaded=self._multithreaded, 127 sql_gen_kwargs=self._sql_gen_kwargs, 128 register_comments=self._register_comments, 129 pre_ping=self._pre_ping, 130 pretty_sql=self._pretty_sql, 131 ) 132 133 @property 134 def cursor(self) -> t.Any: 135 if ( 136 self._connection_pool.get_attribute("use_spark_engine_adapter") 137 and not self.is_spark_session_connection 138 ): 139 return self._spark_engine_adapter.cursor # type: ignore 140 return super().cursor 141 142 @property 143 def spark(self) -> PySparkSession: 144 if not self._use_spark_session: 145 raise SQLMeshError( 146 "SparkSession is not available. " 147 "Either run from a Databricks Notebook or " 148 "install `databricks-connect` and configure it to connect to your Databricks cluster." 149 ) 150 if self.is_spark_session_connection: 151 return self.connection.spark 152 return self._spark_engine_adapter.spark # type: ignore 153 154 @property 155 def catalog_support(self) -> CatalogSupport: 156 return CatalogSupport.FULL_SUPPORT 157 158 @staticmethod 159 def _grant_object_kind(table_type: DataObjectType) -> str: 160 if table_type == DataObjectType.VIEW: 161 return "VIEW" 162 if table_type == DataObjectType.MATERIALIZED_VIEW: 163 return "MATERIALIZED VIEW" 164 return "TABLE" 165 166 def _get_grant_expression(self, table: exp.Table) -> exp.Expr: 167 # We only care about explicitly granted privileges and not inherited ones 168 # if this is removed you would see grants inherited from the catalog get returned 169 expression = super()._get_grant_expression(table) 170 expression.args["where"].set( 171 "this", 172 exp.and_( 173 expression.args["where"].this, 174 exp.column("inherited_from").eq(exp.Literal.string("NONE")), 175 wrap=False, 176 ), 177 ) 178 return expression 179 180 def _begin_session(self, properties: SessionProperties) -> t.Any: 181 """Begin a new session.""" 182 # Align the different possible connectors to a single catalog 183 
self.set_current_catalog(self.default_catalog) # type: ignore 184 185 def _end_session(self) -> None: 186 self._connection_pool.set_attribute("use_spark_engine_adapter", False) 187 188 def _df_to_source_queries( 189 self, 190 df: DF, 191 target_columns_to_types: t.Dict[str, exp.DataType], 192 batch_size: int, 193 target_table: TableName, 194 source_columns: t.Optional[t.List[str]] = None, 195 ) -> t.List[SourceQuery]: 196 if not self._use_spark_session: 197 return super(SparkEngineAdapter, self)._df_to_source_queries( 198 df, target_columns_to_types, batch_size, target_table, source_columns=source_columns 199 ) 200 pyspark_df = self._ensure_pyspark_df( 201 df, target_columns_to_types, source_columns=source_columns 202 ) 203 204 def query_factory() -> Query: 205 temp_table = self._get_temp_table(target_table or "spark", table_only=True) 206 pyspark_df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect)) 207 self._connection_pool.set_attribute("use_spark_engine_adapter", True) 208 return exp.select(*self._select_columns(target_columns_to_types)).from_(temp_table) 209 210 return [SourceQuery(query_factory=query_factory)] 211 212 def _fetch_native_df( 213 self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False 214 ) -> DF: 215 """Fetches a DataFrame that can be either Pandas or PySpark from the cursor""" 216 if self.is_spark_session_connection: 217 return super()._fetch_native_df(query, quote_identifiers=quote_identifiers) 218 if self._spark_engine_adapter: 219 return self._spark_engine_adapter._fetch_native_df( # type: ignore 220 query, quote_identifiers=quote_identifiers 221 ) 222 self.execute(query) 223 return self.cursor.fetchall_arrow().to_pandas() 224 225 def fetchdf( 226 self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False 227 ) -> pd.DataFrame: 228 """ 229 Returns a Pandas DataFrame from a query or expression. 
230 """ 231 import pandas as pd 232 233 df = self._fetch_native_df(query, quote_identifiers=quote_identifiers) 234 if not isinstance(df, pd.DataFrame): 235 return df.toPandas() 236 return df 237 238 def get_current_catalog(self) -> t.Optional[str]: 239 pyspark_catalog = None 240 sql_connector_catalog = None 241 if self._spark_engine_adapter: 242 from py4j.protocol import Py4JError 243 from pyspark.errors.exceptions.connect import SparkConnectGrpcException 244 245 try: 246 # Note: Spark 3.4+ Only API 247 pyspark_catalog = self._spark_engine_adapter.get_current_catalog() 248 except (Py4JError, SparkConnectGrpcException): 249 pass 250 elif self.is_spark_session_connection: 251 pyspark_catalog = self.connection.spark.catalog.currentCatalog() 252 if not self.is_spark_session_connection: 253 result = self.fetchone(exp.select(self.CURRENT_CATALOG_EXPRESSION)) 254 sql_connector_catalog = result[0] if result else None 255 if self._spark_engine_adapter and pyspark_catalog != sql_connector_catalog: 256 logger.warning( 257 f"Current catalog mismatch between Databricks SQL Connector and Databricks-Connect: `{sql_connector_catalog}` != `{pyspark_catalog}`. Set `catalog` connection property to make them the same." 
258 ) 259 return pyspark_catalog or sql_connector_catalog 260 261 def set_current_catalog(self, catalog_name: str) -> None: 262 def _set_spark_session_current_catalog(spark: PySparkSession) -> None: 263 from py4j.protocol import Py4JError 264 from pyspark.errors.exceptions.connect import SparkConnectGrpcException 265 266 try: 267 # Note: Spark 3.4+ Only API 268 spark.catalog.setCurrentCatalog(catalog_name) 269 except (Py4JError, SparkConnectGrpcException): 270 pass 271 272 # Since Databricks splits commands across the Dataframe API and the SQL Connector 273 # (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both 274 # are set to the same catalog since they maintain their default catalog separately 275 self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG")) 276 if self.is_spark_session_connection: 277 _set_spark_session_current_catalog(self.connection.spark) 278 279 if self._spark_engine_adapter: 280 _set_spark_session_current_catalog(self._spark_engine_adapter.spark) 281 282 def _get_data_objects( 283 self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None 284 ) -> t.List[DataObject]: 285 """ 286 Returns all the data objects that exist in the given schema and catalog. 
287 """ 288 schema = to_schema(schema_name) 289 catalog_name = schema.catalog or self.get_current_catalog() 290 query = ( 291 exp.select( 292 exp.column("table_name").as_("name"), 293 exp.column("table_schema").as_("schema"), 294 exp.column("table_catalog").as_("catalog"), 295 exp.case(exp.column("table_type")) 296 .when(exp.Literal.string("VIEW"), exp.Literal.string("view")) 297 .when( 298 exp.Literal.string("MATERIALIZED_VIEW"), exp.Literal.string("materialized_view") 299 ) 300 .else_(exp.Literal.string("table")) 301 .as_("type"), 302 ) 303 .from_( 304 # always query `system` information_schema 305 exp.table_("tables", "information_schema", "system") 306 ) 307 .where(exp.column("table_catalog").eq(catalog_name)) 308 .where(exp.column("table_schema").eq(schema.db)) 309 ) 310 311 if object_names: 312 query = query.where(exp.column("table_name").isin(*object_names)) 313 314 df = self.fetchdf(query) 315 return [ 316 DataObject( 317 catalog=row.catalog, # type: ignore 318 schema=row.schema, # type: ignore 319 name=row.name, # type: ignore 320 type=DataObjectType.from_str(row.type), # type: ignore 321 ) 322 for row in df.itertuples() 323 ] 324 325 def clone_table( 326 self, 327 target_table_name: TableName, 328 source_table_name: TableName, 329 replace: bool = False, 330 exists: bool = True, 331 clone_kwargs: t.Optional[t.Dict[str, t.Any]] = None, 332 **kwargs: t.Any, 333 ) -> None: 334 clone_kwargs = clone_kwargs or {} 335 clone_kwargs["shallow"] = True 336 super().clone_table( 337 target_table_name, 338 source_table_name, 339 replace=replace, 340 clone_kwargs=clone_kwargs, 341 **kwargs, 342 ) 343 344 def wap_supported(self, table_name: TableName) -> bool: 345 return False 346 347 def close(self) -> t.Any: 348 """Closes all open connections and releases all allocated resources.""" 349 super().close() 350 if self._spark_engine_adapter: 351 self._spark_engine_adapter.close() 352 353 @property 354 def default_catalog(self) -> t.Optional[str]: 355 try: 356 return 
super().default_catalog 357 except MissingDefaultCatalogError as e: 358 raise MissingDefaultCatalogError( 359 "Could not determine default catalog. Define the connection property `catalog` since it can't be inferred from your connection. See SQLMesh Databricks documentation for details" 360 ) from e 361 362 def _build_table_properties_exp( 363 self, 364 catalog_name: t.Optional[str] = None, 365 table_format: t.Optional[str] = None, 366 storage_format: t.Optional[str] = None, 367 partitioned_by: t.Optional[t.List[exp.Expr]] = None, 368 partition_interval_unit: t.Optional[IntervalUnit] = None, 369 clustered_by: t.Optional[t.List[exp.Expr]] = None, 370 table_properties: t.Optional[t.Dict[str, exp.Expr]] = None, 371 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 372 table_description: t.Optional[str] = None, 373 table_kind: t.Optional[str] = None, 374 **kwargs: t.Any, 375 ) -> t.Optional[exp.Properties]: 376 properties = super()._build_table_properties_exp( 377 catalog_name=catalog_name, 378 table_format=table_format, 379 storage_format=storage_format, 380 partitioned_by=partitioned_by, 381 partition_interval_unit=partition_interval_unit, 382 clustered_by=clustered_by, 383 table_properties=table_properties, 384 target_columns_to_types=target_columns_to_types, 385 table_description=table_description, 386 table_kind=table_kind, 387 ) 388 if clustered_by: 389 # Databricks expects wrapped CLUSTER BY expressions 390 clustered_by_exp = exp.Cluster( 391 expressions=[exp.Tuple(expressions=[c.copy() for c in clustered_by])] 392 ) 393 expressions = properties.expressions if properties else [] 394 expressions.append(clustered_by_exp) 395 properties = exp.Properties(expressions=expressions) 396 return properties 397 398 def _build_column_defs( 399 self, 400 target_columns_to_types: t.Dict[str, exp.DataType], 401 column_descriptions: t.Optional[t.Dict[str, str]] = None, 402 is_view: bool = False, 403 materialized: bool = False, 404 ) -> 
t.List[exp.ColumnDef]: 405 # Databricks requires column types to be specified when adding column comments 406 # in CREATE MATERIALIZED VIEW statements. Override is_view to False to force 407 # column types to be included when comments are present. 408 if is_view and materialized and column_descriptions: 409 is_view = False 410 411 return super()._build_column_defs( 412 target_columns_to_types, column_descriptions, is_view, materialized 413 )
logger =
<Logger sqlmesh.core.engine_adapter.databricks (WARNING)>
class
DatabricksEngineAdapter(sqlmesh.core.engine_adapter.spark.SparkEngineAdapter, sqlmesh.core.engine_adapter.mixins.GrantsFromInfoSchemaMixin):
class DatabricksEngineAdapter(SparkEngineAdapter, GrantsFromInfoSchemaMixin):
    """Engine adapter for Databricks.

    SQL runs through the primary connection. When a Spark session is usable (inside a
    Databricks runtime, or via a configured `databricks-connect`), DataFrame work is
    routed through that session; with Databricks Connect a secondary
    `SparkEngineAdapter` is created for it.
    """

    DIALECT = "databricks"
    INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE
    SUPPORTS_CLONING = True
    SUPPORTS_MATERIALIZED_VIEWS = True
    SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True
    SUPPORTS_GRANTS = True
    USE_CATALOG_IN_GRANTS = True
    # Spark has this set to false for compatibility when mixing with Trino but that isn't a concern with Databricks
    QUOTE_IDENTIFIERS_IN_VIEWS = True
    SCHEMA_DIFFER_KWARGS = {
        "support_positional_add": True,
        "nested_support": NestedSupport.ALL,
        "array_element_selector": "element",
        "parameterized_type_defaults": {
            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(10, 0), (0,)],
        },
    }

    def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
        super().__init__(*args, **kwargs)
        self._set_spark_engine_adapter_if_needed()

    @classmethod
    def can_access_spark_session(cls, disable_spark_session: bool) -> bool:
        """Whether a Spark session is reachable: only inside a Databricks runtime, unless disabled."""
        from sqlmesh import RuntimeEnv

        if disable_spark_session:
            return False

        return RuntimeEnv.get().is_databricks

    @classmethod
    def can_access_databricks_connect(cls, disable_databricks_connect: bool) -> bool:
        """Whether `databricks.connect` is importable, unless explicitly disabled."""
        if disable_databricks_connect:
            return False

        try:
            from databricks.connect import DatabricksSession  # noqa

            return True
        except ImportError:
            return False

    @property
    def _use_spark_session(self) -> bool:
        """Whether DataFrame operations should go through a Spark session.

        True inside a Databricks runtime, or when Databricks Connect is importable and
        configured either for serverless or with cluster id, hostname and access token.
        """
        if self.can_access_spark_session(bool(self._extra_config.get("disable_spark_session"))):
            return True

        if self.can_access_databricks_connect(
            bool(self._extra_config.get("disable_databricks_connect"))
        ):
            if self._extra_config.get("databricks_connect_use_serverless"):
                return True

            if {
                "databricks_connect_cluster_id",
                "databricks_connect_server_hostname",
                "databricks_connect_access_token",
            }.issubset(self._extra_config):
                return True

        return False

    @property
    def is_spark_session_connection(self) -> bool:
        """Whether the primary connection is itself a `SparkSessionConnection`."""
        return isinstance(self.connection, SparkSessionConnection)

    def _set_spark_engine_adapter_if_needed(self) -> None:
        """Create the secondary Databricks Connect `SparkEngineAdapter` when one is needed.

        Nothing is created when no Spark session should be used, or when the primary
        connection already is a Spark session connection.
        """
        self._spark_engine_adapter = None

        if not self._use_spark_session or self.is_spark_session_connection:
            return

        from databricks.connect import DatabricksSession

        connect_kwargs = dict(
            host=self._extra_config["databricks_connect_server_hostname"],
            token=self._extra_config.get("databricks_connect_access_token"),
        )
        # Test truthiness, exactly like `_use_spark_session` does: an explicit
        # `databricks_connect_use_serverless: False` must fall back to the cluster id.
        if self._extra_config.get("databricks_connect_use_serverless"):
            connect_kwargs["serverless"] = True
        else:
            connect_kwargs["cluster_id"] = self._extra_config["databricks_connect_cluster_id"]

        catalog = self._extra_config.get("catalog")
        spark = (
            DatabricksSession.builder.remote(**connect_kwargs).userAgent("sqlmesh").getOrCreate()
        )
        self._spark_engine_adapter = SparkEngineAdapter(
            partial(connection, spark=spark, catalog=catalog),
            default_catalog=catalog,
            execute_log_level=self._execute_log_level,
            multithreaded=self._multithreaded,
            sql_gen_kwargs=self._sql_gen_kwargs,
            register_comments=self._register_comments,
            pre_ping=self._pre_ping,
            pretty_sql=self._pretty_sql,
        )

    @property
    def cursor(self) -> t.Any:
        """Cursor for the current connection.

        While the pool attribute `use_spark_engine_adapter` is set (see
        `_df_to_source_queries`), the Databricks Connect adapter's cursor is returned so
        queries can see the temp views registered on that Spark session.
        """
        if (
            self._connection_pool.get_attribute("use_spark_engine_adapter")
            and not self.is_spark_session_connection
        ):
            return self._spark_engine_adapter.cursor  # type: ignore
        return super().cursor

    @property
    def spark(self) -> PySparkSession:
        """Spark session used for DataFrame operations.

        Raises:
            SQLMeshError: If no Spark session is usable in this configuration.
        """
        if not self._use_spark_session:
            raise SQLMeshError(
                "SparkSession is not available. "
                "Either run from a Databricks Notebook or "
                "install `databricks-connect` and configure it to connect to your Databricks cluster."
            )
        if self.is_spark_session_connection:
            return self.connection.spark
        return self._spark_engine_adapter.spark  # type: ignore

    @property
    def catalog_support(self) -> CatalogSupport:
        return CatalogSupport.FULL_SUPPORT

    @staticmethod
    def _grant_object_kind(table_type: DataObjectType) -> str:
        """Map a `DataObjectType` to the object kind keyword used by the grants logic."""
        if table_type == DataObjectType.VIEW:
            return "VIEW"
        if table_type == DataObjectType.MATERIALIZED_VIEW:
            return "MATERIALIZED VIEW"
        return "TABLE"

    def _get_grant_expression(self, table: exp.Table) -> exp.Expr:
        # We only care about explicitly granted privileges and not inherited ones
        # if this is removed you would see grants inherited from the catalog get returned
        expression = super()._get_grant_expression(table)
        expression.args["where"].set(
            "this",
            exp.and_(
                expression.args["where"].this,
                exp.column("inherited_from").eq(exp.Literal.string("NONE")),
                wrap=False,
            ),
        )
        return expression

    def _begin_session(self, properties: SessionProperties) -> t.Any:
        """Begin a new session."""
        # Align the different possible connectors to a single catalog
        self.set_current_catalog(self.default_catalog)  # type: ignore

    def _end_session(self) -> None:
        # Stop routing the cursor to the Databricks Connect adapter for this connection.
        self._connection_pool.set_attribute("use_spark_engine_adapter", False)

    def _df_to_source_queries(
        self,
        df: DF,
        target_columns_to_types: t.Dict[str, exp.DataType],
        batch_size: int,
        target_table: TableName,
        source_columns: t.Optional[t.List[str]] = None,
    ) -> t.List[SourceQuery]:
        if not self._use_spark_session:
            # Skip SparkEngineAdapter's implementation and use the next one in the MRO.
            return super(SparkEngineAdapter, self)._df_to_source_queries(
                df, target_columns_to_types, batch_size, target_table, source_columns=source_columns
            )
        pyspark_df = self._ensure_pyspark_df(
            df, target_columns_to_types, source_columns=source_columns
        )

        def query_factory() -> Query:
            temp_table = self._get_temp_table(target_table or "spark", table_only=True)
            pyspark_df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect))
            # The temp view lives on the Spark session, so route the cursor there.
            self._connection_pool.set_attribute("use_spark_engine_adapter", True)
            return exp.select(*self._select_columns(target_columns_to_types)).from_(temp_table)

        return [SourceQuery(query_factory=query_factory)]

    def _fetch_native_df(
        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
    ) -> DF:
        """Fetches a DataFrame that can be either Pandas or PySpark from the cursor"""
        if self.is_spark_session_connection:
            return super()._fetch_native_df(query, quote_identifiers=quote_identifiers)
        if self._spark_engine_adapter:
            return self._spark_engine_adapter._fetch_native_df(  # type: ignore
                query, quote_identifiers=quote_identifiers
            )
        # Plain SQL connector path: fetch as Arrow and convert to Pandas.
        self.execute(query)
        return self.cursor.fetchall_arrow().to_pandas()

    def fetchdf(
        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
    ) -> pd.DataFrame:
        """
        Returns a Pandas DataFrame from a query or expression.
        """
        import pandas as pd

        df = self._fetch_native_df(query, quote_identifiers=quote_identifiers)
        if not isinstance(df, pd.DataFrame):
            return df.toPandas()
        return df

    def get_current_catalog(self) -> t.Optional[str]:
        """Returns the catalog name of the current connection.

        When both a Databricks Connect session and the SQL connector are in use, both
        are queried and a warning is logged if they disagree. The Spark-side value wins
        when it is available.
        """
        pyspark_catalog = None
        sql_connector_catalog = None
        if self._spark_engine_adapter:
            from py4j.protocol import Py4JError
            from pyspark.errors.exceptions.connect import SparkConnectGrpcException

            try:
                # Note: Spark 3.4+ Only API
                pyspark_catalog = self._spark_engine_adapter.get_current_catalog()
            except (Py4JError, SparkConnectGrpcException):
                pass
        elif self.is_spark_session_connection:
            pyspark_catalog = self.connection.spark.catalog.currentCatalog()
        if not self.is_spark_session_connection:
            result = self.fetchone(exp.select(self.CURRENT_CATALOG_EXPRESSION))
            sql_connector_catalog = result[0] if result else None
        if self._spark_engine_adapter and pyspark_catalog != sql_connector_catalog:
            logger.warning(
                f"Current catalog mismatch between Databricks SQL Connector and Databricks-Connect: `{sql_connector_catalog}` != `{pyspark_catalog}`. Set `catalog` connection property to make them the same."
            )
        return pyspark_catalog or sql_connector_catalog

    def set_current_catalog(self, catalog_name: str) -> None:
        """Sets the catalog name of the current connection (and of any Spark session in use)."""

        def _set_spark_session_current_catalog(spark: PySparkSession) -> None:
            from py4j.protocol import Py4JError
            from pyspark.errors.exceptions.connect import SparkConnectGrpcException

            try:
                # Note: Spark 3.4+ Only API
                spark.catalog.setCurrentCatalog(catalog_name)
            except (Py4JError, SparkConnectGrpcException):
                pass

        # Since Databricks splits commands across the Dataframe API and the SQL Connector
        # (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both
        # are set to the same catalog since they maintain their default catalog separately
        self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG"))
        if self.is_spark_session_connection:
            _set_spark_session_current_catalog(self.connection.spark)

        if self._spark_engine_adapter:
            _set_spark_session_current_catalog(self._spark_engine_adapter.spark)

    def _get_data_objects(
        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
    ) -> t.List[DataObject]:
        """
        Returns all the data objects that exist in the given schema and catalog.
        """
        schema = to_schema(schema_name)
        catalog_name = schema.catalog or self.get_current_catalog()
        query = (
            exp.select(
                exp.column("table_name").as_("name"),
                exp.column("table_schema").as_("schema"),
                exp.column("table_catalog").as_("catalog"),
                exp.case(exp.column("table_type"))
                .when(exp.Literal.string("VIEW"), exp.Literal.string("view"))
                .when(
                    exp.Literal.string("MATERIALIZED_VIEW"), exp.Literal.string("materialized_view")
                )
                .else_(exp.Literal.string("table"))
                .as_("type"),
            )
            .from_(
                # always query `system` information_schema
                exp.table_("tables", "information_schema", "system")
            )
            .where(exp.column("table_catalog").eq(catalog_name))
            .where(exp.column("table_schema").eq(schema.db))
        )

        if object_names:
            query = query.where(exp.column("table_name").isin(*object_names))

        df = self.fetchdf(query)
        return [
            DataObject(
                catalog=row.catalog,  # type: ignore
                schema=row.schema,  # type: ignore
                name=row.name,  # type: ignore
                type=DataObjectType.from_str(row.type),  # type: ignore
            )
            for row in df.itertuples()
        ]

    def clone_table(
        self,
        target_table_name: TableName,
        source_table_name: TableName,
        replace: bool = False,
        exists: bool = True,
        clone_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
        **kwargs: t.Any,
    ) -> None:
        """Creates a table with the target name by cloning the source table.

        Databricks clones are always shallow: `shallow=True` is forced into the clone
        arguments.

        Args:
            target_table_name: The name of the table that should be created.
            source_table_name: The name of the source table that should be cloned.
            replace: Whether or not to replace an existing table.
            exists: Indicates whether to include the IF NOT EXISTS check.
            clone_kwargs: Extra arguments for the clone expression. The caller's dict is
                not modified.
        """
        # Build a new dict so the caller's `clone_kwargs` is left untouched.
        clone_kwargs = {**(clone_kwargs or {}), "shallow": True}
        super().clone_table(
            target_table_name,
            source_table_name,
            replace=replace,
            exists=exists,
            clone_kwargs=clone_kwargs,
            **kwargs,
        )

    def wap_supported(self, table_name: TableName) -> bool:
        """Returns whether WAP for the target table is supported; always False here."""
        return False

    def close(self) -> t.Any:
        """Closes all open connections and releases all allocated resources."""
        try:
            super().close()
        finally:
            # Release the Databricks Connect adapter even if closing the primary connection fails.
            if self._spark_engine_adapter:
                self._spark_engine_adapter.close()

    @property
    def default_catalog(self) -> t.Optional[str]:
        """Default catalog, re-raising a missing catalog with Databricks-specific guidance."""
        try:
            return super().default_catalog
        except MissingDefaultCatalogError as e:
            raise MissingDefaultCatalogError(
                "Could not determine default catalog. Define the connection property `catalog` since it can't be inferred from your connection. See SQLMesh Databricks documentation for details"
            ) from e

    def _build_table_properties_exp(
        self,
        catalog_name: t.Optional[str] = None,
        table_format: t.Optional[str] = None,
        storage_format: t.Optional[str] = None,
        partitioned_by: t.Optional[t.List[exp.Expr]] = None,
        partition_interval_unit: t.Optional[IntervalUnit] = None,
        clustered_by: t.Optional[t.List[exp.Expr]] = None,
        table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        table_description: t.Optional[str] = None,
        table_kind: t.Optional[str] = None,
        **kwargs: t.Any,
    ) -> t.Optional[exp.Properties]:
        properties = super()._build_table_properties_exp(
            catalog_name=catalog_name,
            table_format=table_format,
            storage_format=storage_format,
            partitioned_by=partitioned_by,
            partition_interval_unit=partition_interval_unit,
            clustered_by=clustered_by,
            table_properties=table_properties,
            target_columns_to_types=target_columns_to_types,
            table_description=table_description,
            table_kind=table_kind,
        )
        if clustered_by:
            # Databricks expects wrapped CLUSTER BY expressions
            clustered_by_exp = exp.Cluster(
                expressions=[exp.Tuple(expressions=[c.copy() for c in clustered_by])]
            )
            expressions = properties.expressions if properties else []
            expressions.append(clustered_by_exp)
            properties = exp.Properties(expressions=expressions)
        return properties

    def _build_column_defs(
        self,
        target_columns_to_types: t.Dict[str, exp.DataType],
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        is_view: bool = False,
        materialized: bool = False,
    ) -> t.List[exp.ColumnDef]:
        # Databricks requires column types to be specified when adding column comments
        # in CREATE MATERIALIZED VIEW statements. Override is_view to False to force
        # column types to be included when comments are present.
        if is_view and materialized and column_descriptions:
            is_view = False

        return super()._build_column_defs(
            target_columns_to_types, column_descriptions, is_view, materialized
        )
Base class wrapping a Database API-compliant connection.
The EngineAdapter is an easily-subclassable interface that interacts with the underlying engine and data store.
Arguments:
- connection_factory_or_pool: a callable which produces a new Database API-compliant connection on every call.
- dialect: The dialect with which this adapter is associated.
- multithreaded: Indicates whether this adapter will be used by more than one thread.
SCHEMA_DIFFER_KWARGS =
{'support_positional_add': True, 'nested_support': <NestedSupport.ALL: 'ALL'>, 'array_element_selector': 'element', 'parameterized_type_defaults': {<DType.DECIMAL: 'DECIMAL'>: [(10, 0), (0,)]}}
spark: PySparkSession
@property
def spark(self) -> PySparkSession:
    """Spark session backing DataFrame operations.

    The primary connection's session is returned when that connection is a Spark
    session connection; otherwise the Databricks Connect adapter's session is returned.

    Raises:
        SQLMeshError: If no Spark session is usable in this configuration.
    """
    if self._use_spark_session:
        if self.is_spark_session_connection:
            return self.connection.spark
        return self._spark_engine_adapter.spark  # type: ignore
    raise SQLMeshError(
        "SparkSession is not available. "
        "Either run from a Databricks Notebook or "
        "install `databricks-connect` and configure it to connect to your Databricks cluster."
    )
catalog_support: sqlmesh.core.engine_adapter.shared.CatalogSupport
def
fetchdf( self, query: Union[sqlglot.expressions.core.Expr, str], quote_identifiers: bool = False) -> pandas.core.frame.DataFrame:
def fetchdf(
    self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> pd.DataFrame:
    """
    Returns a Pandas DataFrame from a query or expression.

    The native result may be a PySpark DataFrame; in that case it is converted with
    `toPandas()`, otherwise it is returned unchanged.
    """
    import pandas as pd

    native = self._fetch_native_df(query, quote_identifiers=quote_identifiers)
    return native if isinstance(native, pd.DataFrame) else native.toPandas()
Returns a Pandas DataFrame from a query or expression.
def
get_current_catalog(self) -> Optional[str]:
def get_current_catalog(self) -> t.Optional[str]:
    """Returns the catalog name of the current connection.

    The Spark side (Databricks Connect adapter or the notebook Spark session) and the
    SQL connector are each asked for their current catalog where applicable. A
    mismatch is logged when a Databricks Connect adapter exists. The Spark-side value
    is preferred when it is truthy.
    """
    pyspark_catalog = None
    sql_connector_catalog = None

    spark_adapter = self._spark_engine_adapter
    if spark_adapter:
        from py4j.protocol import Py4JError
        from pyspark.errors.exceptions.connect import SparkConnectGrpcException

        try:
            # Note: Spark 3.4+ Only API
            pyspark_catalog = spark_adapter.get_current_catalog()
        except (Py4JError, SparkConnectGrpcException):
            # Best effort: leave the Spark-side value unset and rely on the SQL connector.
            pass
    elif self.is_spark_session_connection:
        pyspark_catalog = self.connection.spark.catalog.currentCatalog()

    if not self.is_spark_session_connection:
        row = self.fetchone(exp.select(self.CURRENT_CATALOG_EXPRESSION))
        if row:
            sql_connector_catalog = row[0]

    if spark_adapter and pyspark_catalog != sql_connector_catalog:
        logger.warning(
            f"Current catalog mismatch between Databricks SQL Connector and Databricks-Connect: `{sql_connector_catalog}` != `{pyspark_catalog}`. Set `catalog` connection property to make them the same."
        )
    return pyspark_catalog or sql_connector_catalog
Returns the catalog name of the current connection.
def
set_current_catalog(self, catalog_name: str) -> None:
def set_current_catalog(self, catalog_name: str) -> None:
    """Sets the catalog name of the current connection.

    A `USE CATALOG` statement is executed, and the same catalog is then applied to the
    primary Spark session (if the connection is one) and to the Databricks Connect
    session (if present).
    """

    def _apply_to_session(session: PySparkSession) -> None:
        from py4j.protocol import Py4JError
        from pyspark.errors.exceptions.connect import SparkConnectGrpcException

        try:
            # Note: Spark 3.4+ Only API
            session.catalog.setCurrentCatalog(catalog_name)
        except (Py4JError, SparkConnectGrpcException):
            # Best effort: sessions without this API keep their own catalog.
            pass

    # The SQL connector and the DataFrame API each track their own default catalog,
    # so both sides are pointed at the same one.
    use_catalog = exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG")
    self.execute(use_catalog)

    if self.is_spark_session_connection:
        _apply_to_session(self.connection.spark)

    if self._spark_engine_adapter:
        _apply_to_session(self._spark_engine_adapter.spark)
Sets the catalog name of the current connection.
def
clone_table( self, target_table_name: Union[str, sqlglot.expressions.query.Table], source_table_name: Union[str, sqlglot.expressions.query.Table], replace: bool = False, exists: bool = True, clone_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
def clone_table(
    self,
    target_table_name: TableName,
    source_table_name: TableName,
    replace: bool = False,
    exists: bool = True,
    clone_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
    **kwargs: t.Any,
) -> None:
    """Create ``target_table_name`` by cloning ``source_table_name``.

    On Databricks the clone is always shallow: ``shallow=True`` is forced into the
    clone options regardless of what the caller supplied.

    Args:
        target_table_name: The name of the table that should be created.
        source_table_name: The name of the source table that should be cloned.
        replace: Whether or not to replace an existing table.
        exists: Indicates whether to include the IF NOT EXISTS check. Forwarded to the
            base implementation so an explicit value is honoured.
        clone_kwargs: Extra options for the clone operation. The caller's dict is
            copied, not mutated.
        **kwargs: Forwarded unchanged to the base implementation.
    """
    # Build a new dict instead of writing into the caller's mapping; Databricks clones
    # are always shallow, so this key overrides any caller-provided value.
    merged_clone_kwargs = {**(clone_kwargs or {}), "shallow": True}
    super().clone_table(
        target_table_name,
        source_table_name,
        replace=replace,
        exists=exists,
        clone_kwargs=merged_clone_kwargs,
        **kwargs,
    )
Creates a table with the target name by cloning the source table.
Arguments:
- target_table_name: The name of the table that should be created.
- source_table_name: The name of the source table that should be cloned.
- replace: Whether or not to replace an existing table.
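- exists: Indicates whether to include the IF NOT EXISTS check.
- clone_kwargs: Extra options for the clone operation; on Databricks `shallow` is always set to True.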
def
wap_supported(self, table_name: Union[str, sqlglot.expressions.query.Table]) -> bool:
Returns whether write-audit-publish (WAP) is supported for the target table.
def
close(self) -> Any:
def close(self) -> t.Any:
    """Close this adapter's connection, then the companion Spark engine adapter if one is set."""
    super().close()
    spark_adapter = self._spark_engine_adapter
    if spark_adapter:
        spark_adapter.close()
Closes all open connections and releases all allocated resources.
default_catalog: Optional[str]
@property
def default_catalog(self) -> t.Optional[str]:
    """The connection's default catalog, as resolved by the parent adapter.

    Raises:
        MissingDefaultCatalogError: If the parent cannot determine it. The error is
            re-raised with Databricks-specific guidance to set the `catalog`
            connection property, chained to the original error.
    """
    try:
        catalog = super().default_catalog
    except MissingDefaultCatalogError as err:
        raise MissingDefaultCatalogError(
            "Could not determine default catalog. Define the connection property `catalog` since it can't be inferred from your connection. See SQLMesh Databricks documentation for details"
        ) from err
    return catalog
Inherited Members
- sqlmesh.core.engine_adapter.spark.SparkEngineAdapter
- SUPPORTS_TRANSACTIONS
- COMMENT_CREATION_TABLE
- COMMENT_CREATION_VIEW
- SUPPORTS_REPLACE_TABLE
- SUPPORTED_DROP_CASCADE_OBJECT_KINDS
- WAP_PREFIX
- BRANCH_PREFIX
- connection
- use_serverless
- spark_to_sqlglot_types
- sqlglot_to_spark_types
- is_pyspark_df
- try_get_pyspark_df
- try_get_pandas_df
- fetch_pyspark_df
- get_data_object
- create_state_table
- wap_table_name
- wap_prepare
- wap_publish
- sqlmesh.core.engine_adapter.mixins.HiveMetastoreTablePropertiesMixin
- MAX_TABLE_COMMENT_LENGTH
- MAX_COLUMN_COMMENT_LENGTH
- sqlmesh.core.engine_adapter.mixins.RowDiffMixin
- MAX_TIMESTAMP_PRECISION
- concat_columns
- normalize_value
- sqlmesh.core.engine_adapter.mixins.GrantsFromInfoSchemaMixin
- CURRENT_USER_OR_ROLE_EXPRESSION
- SUPPORTS_MULTIPLE_GRANT_PRINCIPALS
- GRANT_INFORMATION_SCHEMA_TABLE_NAME
- sqlmesh.core.engine_adapter.base.EngineAdapter
- DEFAULT_BATCH_SIZE
- DATA_OBJECT_FILTER_BATCH_SIZE
- SUPPORTS_INDEXES
- SUPPORTS_VIEW_SCHEMA
- SUPPORTS_MANAGED_MODELS
- SUPPORTS_CREATE_DROP_CATALOG
- SUPPORTS_TUPLE_IN
- HAS_VIEW_BINDING
- DEFAULT_CATALOG_TYPE
- MAX_IDENTIFIER_LENGTH
- ATTACH_CORRELATION_ID
- SUPPORTS_QUERY_EXECUTION_TRACKING
- SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS
- dialect
- correlation_id
- with_settings
- snowpark
- bigframe
- comments_enabled
- schema_differ
- engine_run_mode
- recycle
- get_catalog_type
- get_catalog_type_from_table
- current_catalog_type
- replace_query
- create_index
- create_table
- create_managed_table
- ctas
- create_table_like
- drop_data_object
- drop_table
- drop_managed_table
- get_alter_operations
- alter_table
- create_view
- create_schema
- drop_schema
- drop_view
- create_catalog
- drop_catalog
- columns
- table_exists
- delete_from
- insert_append
- insert_overwrite_by_partition
- insert_overwrite_by_time_partition
- update_table
- scd_type_2_by_time
- scd_type_2_by_column
- merge
- rename_table
- get_data_objects
- fetchone
- fetchall
- wap_enabled
- sync_grants_config
- transaction
- session
- execute
- temp_table
- drop_data_object_on_type_mismatch
- ensure_nulls_for_unmatched_after_join
- use_server_nulls_for_unmatched_after_join
- ping
- get_table_last_modified_ts