sqlmesh.core.engine_adapter.spark
1from __future__ import annotations 2 3import logging 4import typing as t 5from functools import partial 6 7from sqlglot import exp 8 9from sqlmesh.core.dialect import to_schema 10from sqlmesh.core.engine_adapter.mixins import ( 11 GetCurrentCatalogFromFunctionMixin, 12 HiveMetastoreTablePropertiesMixin, 13 RowDiffMixin, 14) 15from sqlmesh.core.engine_adapter.shared import ( 16 CatalogSupport, 17 CommentCreationTable, 18 CommentCreationView, 19 DataObject, 20 DataObjectType, 21 InsertOverwriteStrategy, 22 SourceQuery, 23 set_catalog, 24) 25from sqlmesh.utils import classproperty, get_source_columns_to_types 26from sqlmesh.utils.errors import SQLMeshError 27 28if t.TYPE_CHECKING: 29 import pandas as pd 30 from pyspark.sql import types as spark_types 31 32 from sqlmesh.core._typing import SchemaName, TableName 33 from sqlmesh.core.engine_adapter._typing import ( 34 DF, 35 PySparkDataFrame, 36 PySparkSession, 37 Query, 38 ) 39 from sqlmesh.core.engine_adapter.base import QueryOrDF 40 from sqlmesh.engines.spark.db_api.spark_session import SparkSessionConnection 41 42 43logger = logging.getLogger(__name__) 44 45 46@set_catalog() 47class SparkEngineAdapter( 48 GetCurrentCatalogFromFunctionMixin, HiveMetastoreTablePropertiesMixin, RowDiffMixin 49): 50 DIALECT = "spark" 51 SUPPORTS_TRANSACTIONS = False 52 INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.INSERT_OVERWRITE 53 COMMENT_CREATION_TABLE = CommentCreationTable.IN_SCHEMA_DEF_NO_CTAS 54 COMMENT_CREATION_VIEW = CommentCreationView.IN_SCHEMA_DEF_NO_COMMANDS 55 # Note: Some formats (like Delta and Iceberg) support REPLACE TABLE but since we don't 56 # currently check for storage formats we say we don't support REPLACE TABLE 57 SUPPORTS_REPLACE_TABLE = False 58 QUOTE_IDENTIFIERS_IN_VIEWS = False 59 SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA"] 60 61 WAP_PREFIX = "wap_" 62 BRANCH_PREFIX = "branch_" 63 SCHEMA_DIFFER_KWARGS = { 64 "parameterized_type_defaults": { 65 # default decimal precision varies 
across backends 66 exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(), (0,)], 67 }, 68 } 69 70 @property 71 def connection(self) -> SparkSessionConnection: 72 return self._connection_pool.get() 73 74 @property 75 def spark(self) -> PySparkSession: 76 return self.connection.spark 77 78 @property 79 def _use_spark_session(self) -> bool: 80 return True 81 82 @property 83 def use_serverless(self) -> bool: 84 return False 85 86 @property 87 def catalog_support(self) -> CatalogSupport: 88 return CatalogSupport.FULL_SUPPORT 89 90 @classproperty 91 def _sqlglot_to_spark_primitive_mapping(self) -> t.Dict[t.Any, t.Any]: 92 from pyspark.sql import types as spark_types 93 94 return { 95 exp.DataType.Type.TINYINT: spark_types.ByteType, 96 exp.DataType.Type.SMALLINT: spark_types.ShortType, 97 exp.DataType.Type.INT: spark_types.IntegerType, 98 exp.DataType.Type.BIGINT: spark_types.LongType, 99 exp.DataType.Type.FLOAT: spark_types.FloatType, 100 exp.DataType.Type.DOUBLE: spark_types.DoubleType, 101 exp.DataType.Type.DECIMAL: spark_types.DecimalType, 102 exp.DataType.Type.VARCHAR: spark_types.StringType, 103 exp.DataType.Type.CHAR: spark_types.StringType, 104 exp.DataType.Type.TEXT: spark_types.StringType, 105 exp.DataType.Type.BINARY: spark_types.BinaryType, 106 exp.DataType.Type.BOOLEAN: spark_types.BooleanType, 107 exp.DataType.Type.DATE: spark_types.DateType, 108 exp.DataType.Type.TIMESTAMPNTZ: spark_types.TimestampNTZType, 109 exp.DataType.Type.DATETIME: spark_types.TimestampNTZType, 110 exp.DataType.Type.TIMESTAMPLTZ: spark_types.TimestampType, 111 exp.DataType.Type.TIMESTAMP: spark_types.TimestampType, 112 exp.DataType.Type.TIMESTAMPTZ: spark_types.TimestampType, 113 } 114 115 @classproperty 116 def _sqlglot_to_spark_complex_mapping(self) -> t.Dict[t.Any, t.Any]: 117 from pyspark.sql import types as spark_types 118 119 return { 120 exp.DataType.Type.ARRAY: spark_types.ArrayType, 121 exp.DataType.Type.MAP: spark_types.MapType, 122 exp.DataType.Type.STRUCT: 
spark_types.StructType, 123 } 124 125 @classproperty 126 def _spark_to_sqlglot_primitive_mapping(self) -> t.Dict[t.Any, t.Any]: 127 return {v: k for k, v in self._sqlglot_to_spark_primitive_mapping.items()} 128 129 @classproperty 130 def _spark_to_sqlglot_complex_mapping(self) -> t.Dict[t.Any, t.Any]: 131 return {v: k for k, v in self._sqlglot_to_spark_complex_mapping.items()} 132 133 @classmethod 134 def spark_to_sqlglot_types(cls, input: spark_types.StructType) -> t.Dict[str, exp.DataType]: 135 from pyspark.sql import types as spark_types 136 137 def spark_complex_to_sqlglot_complex( 138 complex_type: t.Union[ 139 spark_types.StructType, spark_types.ArrayType, spark_types.MapType 140 ], 141 ) -> exp.DataType: 142 def get_fields( 143 complex_type: t.Union[ 144 spark_types.StructType, spark_types.ArrayType, spark_types.MapType 145 ], 146 ) -> t.Sequence[spark_types.DataType]: 147 if isinstance(complex_type, spark_types.StructType): 148 return complex_type.fields 149 if isinstance(complex_type, spark_types.ArrayType): 150 return [complex_type.elementType] 151 if isinstance(complex_type, spark_types.MapType): 152 return [complex_type.keyType, complex_type.valueType] 153 raise SQLMeshError(f"Unsupported complex type: {complex_type}") 154 155 expressions: t.List[t.Union[exp.ColumnDef, exp.DataType]] = [] 156 fields = get_fields(complex_type) 157 for field in fields: 158 if isinstance(field, (spark_types.StructType, spark_types.MapType)): 159 expressions.append(spark_complex_to_sqlglot_complex(field)) 160 elif isinstance(field, spark_types.StructField): 161 sqlglot_data_type = cls._spark_to_sqlglot_primitive_mapping.get( 162 type(field.dataType) 163 ) or spark_complex_to_sqlglot_complex( 164 field.dataType # type: ignore 165 ) 166 kind = ( 167 sqlglot_data_type 168 if isinstance(sqlglot_data_type, exp.DataType) 169 else exp.DataType(this=sqlglot_data_type) 170 ) 171 expressions.append(exp.ColumnDef(this=exp.to_identifier(field.name), kind=kind)) 172 else: 173 kind = 
exp.DataType(this=cls._spark_to_sqlglot_primitive_mapping[type(field)]) 174 expressions.append(kind) 175 dtype = cls._spark_to_sqlglot_complex_mapping[type(complex_type)] 176 return exp.DataType( 177 this=dtype, 178 expressions=expressions, 179 nested=True, 180 ) 181 182 resp = spark_complex_to_sqlglot_complex(input) 183 return {column_def.this.name: column_def.args["kind"] for column_def in resp.expressions} 184 185 @classmethod 186 def sqlglot_to_spark_types(cls, input: t.Dict[str, exp.DataType]) -> spark_types.StructType: 187 from pyspark.sql import types as spark_types 188 189 def sqlglot_complex_to_spark_complex(complex_type: exp.DataType) -> spark_types.DataType: 190 is_struct = complex_type.is_type(exp.DataType.Type.STRUCT) 191 expressions = [] 192 for column_def in complex_type.expressions: 193 col_name = column_def.this.name if is_struct else None 194 data_type = column_def.args["kind"] if is_struct else column_def 195 primitive_func = cls._sqlglot_to_spark_primitive_mapping.get(data_type.this) 196 type_func = ( 197 primitive_func 198 if primitive_func 199 else partial(sqlglot_complex_to_spark_complex, data_type) 200 ) 201 if is_struct: 202 expressions.append(spark_types.StructField(col_name, type_func())) # type: ignore 203 else: 204 expressions.append(type_func()) # type: ignore 205 klass = cls._sqlglot_to_spark_complex_mapping[complex_type.this] 206 if is_struct: 207 return klass(expressions) 208 return klass(*expressions) 209 210 return t.cast( 211 spark_types.StructType, 212 sqlglot_complex_to_spark_complex( 213 exp.DataType( 214 this=exp.DataType.Type.STRUCT, 215 expressions=[ 216 exp.ColumnDef(this=exp.to_identifier(column), kind=data_type) 217 for column, data_type in input.items() 218 ], 219 ) 220 ), 221 ) 222 223 @classmethod 224 def is_pyspark_df(cls, value: t.Any) -> bool: 225 return hasattr(value, "sparkSession") 226 227 @classmethod 228 def try_get_pyspark_df(cls, value: t.Any) -> t.Optional[PySparkDataFrame]: 229 if cls.is_pyspark_df(value): 
230 return value 231 return None 232 233 @classmethod 234 def try_get_pandas_df(cls, value: t.Any) -> t.Optional[pd.DataFrame]: 235 import pandas as pd 236 237 if isinstance(value, pd.DataFrame): 238 return value 239 return None 240 241 @t.overload 242 def _columns_to_types( 243 self, 244 query_or_df: DF, 245 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 246 source_columns: t.Optional[t.List[str]] = None, 247 ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ... 248 249 @t.overload 250 def _columns_to_types( 251 self, 252 query_or_df: Query, 253 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 254 source_columns: t.Optional[t.List[str]] = None, 255 ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ... 256 257 def _columns_to_types( 258 self, 259 query_or_df: QueryOrDF, 260 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 261 source_columns: t.Optional[t.List[str]] = None, 262 ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: 263 if target_columns_to_types: 264 return target_columns_to_types, list(source_columns or target_columns_to_types) 265 if self.is_pyspark_df(query_or_df): 266 from pyspark.sql import DataFrame 267 268 target_columns_to_types = self.spark_to_sqlglot_types( 269 t.cast(DataFrame, query_or_df).schema 270 ) 271 return target_columns_to_types, list(source_columns or target_columns_to_types) 272 return super()._columns_to_types( 273 query_or_df, target_columns_to_types, source_columns=source_columns 274 ) 275 276 def _df_to_source_queries( 277 self, 278 df: DF, 279 target_columns_to_types: t.Dict[str, exp.DataType], 280 batch_size: int, 281 target_table: TableName, 282 source_columns: t.Optional[t.List[str]] = None, 283 ) -> t.List[SourceQuery]: 284 df = self._ensure_pyspark_df(df, target_columns_to_types, source_columns=source_columns) 285 286 def query_factory() -> Query: 287 temp_table = self._get_temp_table(target_table 
or "spark", table_only=True) 288 df.createOrReplaceGlobalTempView(temp_table.sql(dialect=self.dialect)) # type: ignore 289 temp_table.set("db", "global_temp") 290 return exp.select(*self._select_columns(target_columns_to_types)).from_(temp_table) 291 292 return [SourceQuery(query_factory=query_factory)] 293 294 def _ensure_pyspark_df( 295 self, 296 generic_df: DF, 297 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 298 source_columns: t.Optional[t.List[str]] = None, 299 ) -> PySparkDataFrame: 300 pyspark_df = self.try_get_pyspark_df(generic_df) 301 if not pyspark_df: 302 df = self.try_get_pandas_df(generic_df) 303 if df is None: 304 raise SQLMeshError( 305 "Ensure PySpark DF can only be run on a PySpark or Pandas DataFrame" 306 ) 307 308 if target_columns_to_types: 309 source_columns_to_types = get_source_columns_to_types( 310 target_columns_to_types, source_columns 311 ) 312 # ensure Pandas dataframe column order matches columns_to_types 313 df = df[list(source_columns_to_types)] 314 else: 315 source_columns_to_types = None 316 kwargs = ( 317 dict(schema=self.sqlglot_to_spark_types(source_columns_to_types)) 318 if source_columns_to_types 319 else {} 320 ) 321 pyspark_df = self.spark.createDataFrame(df, **kwargs) # type: ignore 322 if target_columns_to_types: 323 select_columns = self._casted_columns( 324 target_columns_to_types, source_columns=source_columns 325 ) 326 pyspark_df = pyspark_df.selectExpr(*[x.sql(self.dialect) for x in select_columns]) # type: ignore 327 return pyspark_df 328 329 def _get_temp_table( 330 self, table: TableName, table_only: bool = False, quoted: bool = True 331 ) -> exp.Table: 332 """ 333 Returns the name of the temp table that should be used for the given table name. 
334 """ 335 table = super()._get_temp_table(table, table_only=table_only) 336 table_name_id = table.args["this"] 337 # Spark with local filesystem has an issue with temp tables that start with __temp so 338 # we update here to remove the leading double underscore 339 table_name_id.set("this", table_name_id.this.replace("__temp_", "temp_")) 340 return table 341 342 def fetchdf( 343 self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False 344 ) -> pd.DataFrame: 345 return self.fetch_pyspark_df(query, quote_identifiers=quote_identifiers).toPandas() 346 347 def fetch_pyspark_df( 348 self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False 349 ) -> PySparkDataFrame: 350 return self._ensure_pyspark_df( 351 self._fetch_native_df(query, quote_identifiers=quote_identifiers) 352 ) 353 354 def _get_data_objects( 355 self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None 356 ) -> t.List[DataObject]: 357 schema_name = to_schema(schema_name).sql(dialect=self.dialect) 358 pattern = "*" if object_names is None else "|".join(object_names) 359 sql = f"SHOW TABLE EXTENDED IN {schema_name} LIKE '{pattern}'" 360 try: 361 results = ( 362 self.fetch_pyspark_df(sql).collect() 363 if self._use_spark_session 364 else self.fetchdf(sql).to_dict("records") 365 ) 366 # Improvement: Figure out all the different exceptions we could get from executing a query either with or 367 # without a Spark Session. In addition Databricks would need to be updated to handle it's own exceptions. 368 # Therefore just doing except Exception for now. 
369 except Exception: 370 return [] 371 data_objects = [] 372 catalog = self.get_current_catalog() 373 for row in results: # type: ignore 374 row_dict = row.asDict() if not isinstance(row, dict) else row 375 if row_dict.get("isTemporary"): 376 continue 377 schema = row_dict.get("namespace") or row_dict.get("database") 378 data_objects.append( 379 DataObject( 380 catalog=catalog, 381 schema=schema, 382 name=row_dict["tableName"], 383 type=( 384 DataObjectType.VIEW 385 if "Type: VIEW" in row_dict["information"] 386 else DataObjectType.TABLE 387 ), 388 ) 389 ) 390 return data_objects 391 392 def get_current_catalog(self) -> t.Optional[str]: 393 if self._use_spark_session: 394 return self.connection.get_current_catalog() 395 return super().get_current_catalog() 396 397 def set_current_catalog(self, catalog_name: str) -> None: 398 self.connection.set_current_catalog(catalog_name) 399 400 def _get_current_schema(self) -> str: 401 if self._use_spark_session: 402 return self.spark.catalog.currentDatabase() 403 return self.fetchone(exp.select(exp.func("current_database")))[0] # type: ignore 404 405 def get_data_object( 406 self, target_name: TableName, safe_to_cache: bool = False 407 ) -> t.Optional[DataObject]: 408 target_table = exp.to_table(target_name) 409 if isinstance(target_table.this, exp.Dot) and target_table.this.expression.name.startswith( 410 f"{self.BRANCH_PREFIX}{self.WAP_PREFIX}" 411 ): 412 # Exclude the branch name 413 target_table.set("this", target_table.this.this) 414 return super().get_data_object(target_table, safe_to_cache=safe_to_cache) 415 416 def create_state_table( 417 self, 418 table_name: str, 419 target_columns_to_types: t.Dict[str, exp.DataType], 420 primary_key: t.Optional[t.Tuple[str, ...]] = None, 421 ) -> None: 422 self.create_table( 423 table_name, 424 target_columns_to_types, 425 partitioned_by=[exp.column(x) for x in primary_key] if primary_key else None, 426 ) 427 428 def _native_df_to_pandas_df( 429 self, 430 query_or_df: QueryOrDF, 
431 ) -> t.Union[Query, pd.DataFrame]: 432 if pyspark_df := self.try_get_pyspark_df(query_or_df): 433 return pyspark_df.toPandas() 434 435 return super()._native_df_to_pandas_df(query_or_df) 436 437 def _create_table( 438 self, 439 table_name_or_schema: t.Union[exp.Schema, TableName], 440 expression: t.Optional[exp.Expr], 441 exists: bool = True, 442 replace: bool = False, 443 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 444 table_description: t.Optional[str] = None, 445 column_descriptions: t.Optional[t.Dict[str, str]] = None, 446 table_kind: t.Optional[str] = None, 447 track_rows_processed: bool = True, 448 **kwargs: t.Any, 449 ) -> None: 450 table_name = ( 451 table_name_or_schema.this 452 if isinstance(table_name_or_schema, exp.Schema) 453 else exp.to_table(table_name_or_schema) 454 ) 455 # Spark doesn't support creating a wap table DDL. Therefore we check if this is a wap table and if it is, 456 # this is not a replace, and the table already exists then we can safely just return. Otherwise we let it error. 
457 if not expression and isinstance(table_name.this, exp.Dot): 458 wap_id = table_name.this.parts[-1].name 459 if wap_id.startswith(f"{self.BRANCH_PREFIX}{self.WAP_PREFIX}"): 460 table_name.set("this", table_name.this.this) 461 462 do_dummy_insert = False 463 if self.wap_enabled: 464 wap_supported = ( 465 kwargs.get("storage_format") or "" 466 ).lower() == "iceberg" or self.wap_supported(table_name) 467 do_dummy_insert = ( 468 False if not wap_supported or not exists else not self.table_exists(table_name) 469 ) 470 super()._create_table( 471 table_name_or_schema, 472 expression, 473 exists=exists, 474 replace=replace, 475 target_columns_to_types=target_columns_to_types, 476 table_description=table_description, 477 column_descriptions=column_descriptions, 478 track_rows_processed=track_rows_processed, 479 **kwargs, 480 ) 481 table_name = ( 482 table_name_or_schema.this 483 if isinstance(table_name_or_schema, exp.Schema) 484 else exp.to_table(table_name_or_schema) 485 ) 486 if do_dummy_insert: 487 # Performing a dummy insert to create a dummy snapshot for Iceberg tables 488 # to workaround https://github.com/apache/iceberg/issues/8849. 
489 dummy_insert = exp.insert(exp.select("*").from_(table_name), table_name) 490 self.execute(dummy_insert) 491 492 def wap_supported(self, table_name: TableName) -> bool: 493 fqn = self._ensure_fqn(table_name) 494 return ( 495 self.spark.conf.get(f"spark.sql.catalog.{fqn.catalog}") 496 == "org.apache.iceberg.spark.SparkCatalog" 497 ) 498 499 def wap_table_name(self, table_name: TableName, wap_id: str) -> str: 500 branch_name = self._wap_branch_name(wap_id) 501 fqn = self._ensure_fqn(table_name) 502 return exp.Dot.build([fqn, exp.to_identifier(f"{self.BRANCH_PREFIX}{branch_name}")]).sql( 503 dialect=self.dialect 504 ) 505 506 def wap_prepare(self, table_name: TableName, wap_id: str) -> str: 507 branch_name = self._wap_branch_name(wap_id) 508 fqn = self._ensure_fqn(table_name) 509 self.execute(f"ALTER TABLE {fqn.sql(dialect=self.dialect)} CREATE BRANCH {branch_name}") 510 return self.wap_table_name(table_name, wap_id) 511 512 def wap_publish(self, table_name: TableName, wap_id: str) -> None: 513 branch_name = self._wap_branch_name(wap_id) 514 fqn = self._ensure_fqn(table_name) 515 516 get_snapshot_id_query = ( 517 exp.select("snapshot_id") 518 .from_(exp.Dot.build([fqn, exp.to_identifier("refs")])) 519 .where(exp.column("name").eq(branch_name)) 520 ) 521 iceberg_snapshot_ids = self.fetchall(get_snapshot_id_query) 522 if not iceberg_snapshot_ids: 523 raise SQLMeshError(f"Could not find Iceberg branch '{branch_name}'.") 524 iceberg_snapshot_id = iceberg_snapshot_ids[0][0] 525 526 logger.info( 527 "Cherry-picking Iceberg snapshot %s into table '%s'...", iceberg_snapshot_id, fqn 528 ) 529 530 self.execute( 531 f"CALL {fqn.catalog}.system.cherrypick_snapshot('{fqn.db}.{fqn.name}', {iceberg_snapshot_id})" 532 ) 533 self.execute(f"ALTER TABLE {fqn.sql(dialect=self.dialect)} DROP BRANCH {branch_name}") 534 535 def _ensure_fqn(self, table_name: TableName) -> exp.Table: 536 if isinstance(table_name, exp.Table): 537 table_name = table_name.copy() 538 table = 
exp.to_table(table_name, dialect=self.dialect) 539 if not table.catalog: 540 table.set("catalog", self.get_current_catalog()) 541 if not table.db: 542 table.set("db", self._get_current_schema()) 543 return table 544 545 def _build_create_comment_column_exp( 546 self, table: exp.Table, column_name: str, column_comment: str, table_kind: str = "TABLE" 547 ) -> exp.Comment | str: 548 table_sql = table.sql(dialect=self.dialect, identify=True) 549 column_sql = exp.column(column_name).sql(dialect=self.dialect, identify=True) 550 551 truncated_comment = self._truncate_column_comment(column_comment) 552 comment_sql = exp.Literal.string(truncated_comment).sql(dialect=self.dialect) 553 554 return f"ALTER TABLE {table_sql} ALTER COLUMN {column_sql} COMMENT {comment_sql}" 555 556 @classmethod 557 def _wap_branch_name(cls, wap_id: str) -> str: 558 return f"{cls.WAP_PREFIX}{wap_id}"
logger =
<Logger sqlmesh.core.engine_adapter.spark (WARNING)>
@set_catalog()
class SparkEngineAdapter(
    GetCurrentCatalogFromFunctionMixin, HiveMetastoreTablePropertiesMixin, RowDiffMixin
):
    """Engine adapter for the Spark dialect, backed by a Spark session connection.

    Besides the overrides of the base adapter hooks, this class converts schemas between
    sqlglot and PySpark, loads Pandas/PySpark DataFrames through global temp views, and
    implements the WAP hooks using Iceberg branches.
    """

    DIALECT = "spark"
    SUPPORTS_TRANSACTIONS = False
    INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.INSERT_OVERWRITE
    COMMENT_CREATION_TABLE = CommentCreationTable.IN_SCHEMA_DEF_NO_CTAS
    COMMENT_CREATION_VIEW = CommentCreationView.IN_SCHEMA_DEF_NO_COMMANDS
    # Note: Some formats (like Delta and Iceberg) support REPLACE TABLE but since we don't
    # currently check for storage formats we say we don't support REPLACE TABLE
    SUPPORTS_REPLACE_TABLE = False
    QUOTE_IDENTIFIERS_IN_VIEWS = False
    SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA"]

    # A WAP branch identifier is rendered as f"{BRANCH_PREFIX}{WAP_PREFIX}{wap_id}".
    WAP_PREFIX = "wap_"
    BRANCH_PREFIX = "branch_"
    SCHEMA_DIFFER_KWARGS = {
        "parameterized_type_defaults": {
            # default decimal precision varies across backends
            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(), (0,)],
        },
    }

    @property
    def connection(self) -> SparkSessionConnection:
        """The connection handed out by the adapter's connection pool."""
        return self._connection_pool.get()

    @property
    def spark(self) -> PySparkSession:
        """The Spark session exposed by the current connection."""
        return self.connection.spark

    @property
    def _use_spark_session(self) -> bool:
        """Whether session APIs are used instead of plain SQL; always True for this class."""
        return True

    @property
    def use_serverless(self) -> bool:
        """Always False for this adapter."""
        return False

    @property
    def catalog_support(self) -> CatalogSupport:
        """This adapter reports full catalog support."""
        return CatalogSupport.FULL_SUPPORT

    @classproperty
    def _sqlglot_to_spark_primitive_mapping(self) -> t.Dict[t.Any, t.Any]:
        """Mapping from sqlglot primitive type enums to PySpark type classes."""
        # Imported lazily so the module can be loaded without pyspark installed.
        from pyspark.sql import types as spark_types

        return {
            exp.DataType.Type.TINYINT: spark_types.ByteType,
            exp.DataType.Type.SMALLINT: spark_types.ShortType,
            exp.DataType.Type.INT: spark_types.IntegerType,
            exp.DataType.Type.BIGINT: spark_types.LongType,
            exp.DataType.Type.FLOAT: spark_types.FloatType,
            exp.DataType.Type.DOUBLE: spark_types.DoubleType,
            exp.DataType.Type.DECIMAL: spark_types.DecimalType,
            exp.DataType.Type.VARCHAR: spark_types.StringType,
            exp.DataType.Type.CHAR: spark_types.StringType,
            exp.DataType.Type.TEXT: spark_types.StringType,
            exp.DataType.Type.BINARY: spark_types.BinaryType,
            exp.DataType.Type.BOOLEAN: spark_types.BooleanType,
            exp.DataType.Type.DATE: spark_types.DateType,
            exp.DataType.Type.TIMESTAMPNTZ: spark_types.TimestampNTZType,
            exp.DataType.Type.DATETIME: spark_types.TimestampNTZType,
            exp.DataType.Type.TIMESTAMPLTZ: spark_types.TimestampType,
            exp.DataType.Type.TIMESTAMP: spark_types.TimestampType,
            exp.DataType.Type.TIMESTAMPTZ: spark_types.TimestampType,
        }

    @classproperty
    def _sqlglot_to_spark_complex_mapping(self) -> t.Dict[t.Any, t.Any]:
        """Mapping from sqlglot nested type enums to PySpark container type classes."""
        from pyspark.sql import types as spark_types

        return {
            exp.DataType.Type.ARRAY: spark_types.ArrayType,
            exp.DataType.Type.MAP: spark_types.MapType,
            exp.DataType.Type.STRUCT: spark_types.StructType,
        }

    @classproperty
    def _spark_to_sqlglot_primitive_mapping(self) -> t.Dict[t.Any, t.Any]:
        """Inverse of the primitive mapping.

        Several sqlglot types share one Spark class, so the last entry of the forward
        mapping wins for each Spark class.
        """
        return {v: k for k, v in self._sqlglot_to_spark_primitive_mapping.items()}

    @classproperty
    def _spark_to_sqlglot_complex_mapping(self) -> t.Dict[t.Any, t.Any]:
        """Inverse of the complex (container) mapping."""
        return {v: k for k, v in self._sqlglot_to_spark_complex_mapping.items()}

    @classmethod
    def spark_to_sqlglot_types(cls, input: spark_types.StructType) -> t.Dict[str, exp.DataType]:
        """Convert a PySpark struct schema into a column-name -> sqlglot type mapping.

        Args:
            input: The PySpark ``StructType`` describing the columns.

        Returns:
            A dictionary keyed by top-level field name whose values are sqlglot data types.

        Raises:
            SQLMeshError: If a container other than struct, array or map is encountered.
            KeyError: If a leaf type has no entry in the primitive mapping.
        """
        from pyspark.sql import types as spark_types

        def spark_complex_to_sqlglot_complex(
            complex_type: t.Union[
                spark_types.StructType, spark_types.ArrayType, spark_types.MapType
            ],
        ) -> exp.DataType:
            def get_fields(
                complex_type: t.Union[
                    spark_types.StructType, spark_types.ArrayType, spark_types.MapType
                ],
            ) -> t.Sequence[spark_types.DataType]:
                if isinstance(complex_type, spark_types.StructType):
                    return complex_type.fields
                if isinstance(complex_type, spark_types.ArrayType):
                    return [complex_type.elementType]
                if isinstance(complex_type, spark_types.MapType):
                    return [complex_type.keyType, complex_type.valueType]
                raise SQLMeshError(f"Unsupported complex type: {complex_type}")

            expressions: t.List[t.Union[exp.ColumnDef, exp.DataType]] = []
            fields = get_fields(complex_type)
            for field in fields:
                # ArrayType has to be listed too: an array that is the element of another
                # array (or a map key/value) would otherwise hit the primitive lookup in
                # the final branch and raise KeyError.
                if isinstance(
                    field,
                    (spark_types.StructType, spark_types.MapType, spark_types.ArrayType),
                ):
                    expressions.append(spark_complex_to_sqlglot_complex(field))
                elif isinstance(field, spark_types.StructField):
                    sqlglot_data_type = cls._spark_to_sqlglot_primitive_mapping.get(
                        type(field.dataType)
                    ) or spark_complex_to_sqlglot_complex(
                        field.dataType  # type: ignore
                    )
                    kind = (
                        sqlglot_data_type
                        if isinstance(sqlglot_data_type, exp.DataType)
                        else exp.DataType(this=sqlglot_data_type)
                    )
                    expressions.append(exp.ColumnDef(this=exp.to_identifier(field.name), kind=kind))
                else:
                    kind = exp.DataType(this=cls._spark_to_sqlglot_primitive_mapping[type(field)])
                    expressions.append(kind)
            dtype = cls._spark_to_sqlglot_complex_mapping[type(complex_type)]
            return exp.DataType(
                this=dtype,
                expressions=expressions,
                nested=True,
            )

        resp = spark_complex_to_sqlglot_complex(input)
        return {column_def.this.name: column_def.args["kind"] for column_def in resp.expressions}

    @classmethod
    def sqlglot_to_spark_types(cls, input: t.Dict[str, exp.DataType]) -> spark_types.StructType:
        """Convert a column-name -> sqlglot type mapping into a PySpark ``StructType``.

        Primitive Spark types are instantiated without arguments, so type parameters on
        the sqlglot side are not carried over.
        """
        from pyspark.sql import types as spark_types

        def sqlglot_complex_to_spark_complex(complex_type: exp.DataType) -> spark_types.DataType:
            is_struct = complex_type.is_type(exp.DataType.Type.STRUCT)
            expressions = []
            for column_def in complex_type.expressions:
                # Struct members are column definitions (name + kind); array/map members
                # are bare data types.
                col_name = column_def.this.name if is_struct else None
                data_type = column_def.args["kind"] if is_struct else column_def
                primitive_func = cls._sqlglot_to_spark_primitive_mapping.get(data_type.this)
                type_func = (
                    primitive_func
                    if primitive_func
                    else partial(sqlglot_complex_to_spark_complex, data_type)
                )
                if is_struct:
                    expressions.append(spark_types.StructField(col_name, type_func()))  # type: ignore
                else:
                    expressions.append(type_func())  # type: ignore
            klass = cls._sqlglot_to_spark_complex_mapping[complex_type.this]
            if is_struct:
                return klass(expressions)
            return klass(*expressions)

        return t.cast(
            spark_types.StructType,
            sqlglot_complex_to_spark_complex(
                exp.DataType(
                    this=exp.DataType.Type.STRUCT,
                    expressions=[
                        exp.ColumnDef(this=exp.to_identifier(column), kind=data_type)
                        for column, data_type in input.items()
                    ],
                )
            ),
        )

    @classmethod
    def is_pyspark_df(cls, value: t.Any) -> bool:
        """Duck-typed check: anything exposing a ``sparkSession`` attribute counts."""
        return hasattr(value, "sparkSession")

    @classmethod
    def try_get_pyspark_df(cls, value: t.Any) -> t.Optional[PySparkDataFrame]:
        """Return ``value`` if it passes ``is_pyspark_df``, otherwise None."""
        if cls.is_pyspark_df(value):
            return value
        return None

    @classmethod
    def try_get_pandas_df(cls, value: t.Any) -> t.Optional[pd.DataFrame]:
        """Return ``value`` if it is a Pandas DataFrame, otherwise None."""
        import pandas as pd

        if isinstance(value, pd.DataFrame):
            return value
        return None

    @t.overload
    def _columns_to_types(
        self,
        query_or_df: DF,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        source_columns: t.Optional[t.List[str]] = None,
    ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ...

    @t.overload
    def _columns_to_types(
        self,
        query_or_df: Query,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        source_columns: t.Optional[t.List[str]] = None,
    ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ...

    def _columns_to_types(
        self,
        query_or_df: QueryOrDF,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        source_columns: t.Optional[t.List[str]] = None,
    ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]:
        """Resolve the target column types and the source column list.

        An explicit ``target_columns_to_types`` wins; otherwise a PySpark DataFrame's own
        schema is converted, and everything else is delegated to the parent class.
        """
        if target_columns_to_types:
            return target_columns_to_types, list(source_columns or target_columns_to_types)
        if self.is_pyspark_df(query_or_df):
            from pyspark.sql import DataFrame

            target_columns_to_types = self.spark_to_sqlglot_types(
                t.cast(DataFrame, query_or_df).schema
            )
            return target_columns_to_types, list(source_columns or target_columns_to_types)
        return super()._columns_to_types(
            query_or_df, target_columns_to_types, source_columns=source_columns
        )

    def _df_to_source_queries(
        self,
        df: DF,
        target_columns_to_types: t.Dict[str, exp.DataType],
        batch_size: int,
        target_table: TableName,
        source_columns: t.Optional[t.List[str]] = None,
    ) -> t.List[SourceQuery]:
        """Expose a DataFrame as a single source query reading from a global temp view.

        ``batch_size`` is not used by this implementation.
        """
        df = self._ensure_pyspark_df(df, target_columns_to_types, source_columns=source_columns)

        def query_factory() -> Query:
            # The view is only registered when the factory is invoked.
            temp_table = self._get_temp_table(target_table or "spark", table_only=True)
            df.createOrReplaceGlobalTempView(temp_table.sql(dialect=self.dialect))  # type: ignore
            temp_table.set("db", "global_temp")
            return exp.select(*self._select_columns(target_columns_to_types)).from_(temp_table)

        return [SourceQuery(query_factory=query_factory)]

    def _ensure_pyspark_df(
        self,
        generic_df: DF,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        source_columns: t.Optional[t.List[str]] = None,
    ) -> PySparkDataFrame:
        """Return a PySpark DataFrame for a PySpark or Pandas input.

        Pandas input is reordered to the source column order and converted with an
        explicit schema when ``target_columns_to_types`` is given. In that case the
        result (PySpark input included) is additionally projected through the casted
        target columns.

        Raises:
            SQLMeshError: If the input is neither a PySpark nor a Pandas DataFrame.
        """
        pyspark_df = self.try_get_pyspark_df(generic_df)
        if not pyspark_df:
            df = self.try_get_pandas_df(generic_df)
            if df is None:
                raise SQLMeshError(
                    "Ensure PySpark DF can only be run on a PySpark or Pandas DataFrame"
                )

            if target_columns_to_types:
                source_columns_to_types = get_source_columns_to_types(
                    target_columns_to_types, source_columns
                )
                # ensure Pandas dataframe column order matches columns_to_types
                df = df[list(source_columns_to_types)]
            else:
                source_columns_to_types = None
            kwargs = (
                dict(schema=self.sqlglot_to_spark_types(source_columns_to_types))
                if source_columns_to_types
                else {}
            )
            pyspark_df = self.spark.createDataFrame(df, **kwargs)  # type: ignore
        if target_columns_to_types:
            select_columns = self._casted_columns(
                target_columns_to_types, source_columns=source_columns
            )
            pyspark_df = pyspark_df.selectExpr(*[x.sql(self.dialect) for x in select_columns])  # type: ignore
        return pyspark_df

    def _get_temp_table(
        self, table: TableName, table_only: bool = False, quoted: bool = True
    ) -> exp.Table:
        """
        Returns the name of the temp table that should be used for the given table name.
        """
        # Forward `quoted` as well; it used to be accepted here but silently dropped.
        table = super()._get_temp_table(table, table_only=table_only, quoted=quoted)
        table_name_id = table.args["this"]
        # Spark with local filesystem has an issue with temp tables that start with __temp so
        # we update here to remove the leading double underscore
        table_name_id.set("this", table_name_id.this.replace("__temp_", "temp_"))
        return table

    def fetchdf(
        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
    ) -> pd.DataFrame:
        """Run the query and return the result as a Pandas DataFrame."""
        return self.fetch_pyspark_df(query, quote_identifiers=quote_identifiers).toPandas()

    def fetch_pyspark_df(
        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
    ) -> PySparkDataFrame:
        """Run the query and return the result as a PySpark DataFrame."""
        return self._ensure_pyspark_df(
            self._fetch_native_df(query, quote_identifiers=quote_identifiers)
        )

    def _get_data_objects(
        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
    ) -> t.List[DataObject]:
        """List the non-temporary tables and views of a schema via SHOW TABLE EXTENDED.

        Returns an empty list when the statement fails for any reason.
        """
        schema_name = to_schema(schema_name).sql(dialect=self.dialect)
        pattern = "*" if object_names is None else "|".join(object_names)
        sql = f"SHOW TABLE EXTENDED IN {schema_name} LIKE '{pattern}'"
        try:
            results = (
                self.fetch_pyspark_df(sql).collect()
                if self._use_spark_session
                else self.fetchdf(sql).to_dict("records")
            )
        # Improvement: Figure out all the different exceptions we could get from executing a query either with or
        # without a Spark Session. In addition Databricks would need to be updated to handle its own exceptions.
        # Therefore just doing except Exception for now.
        except Exception:
            return []
        data_objects = []
        catalog = self.get_current_catalog()
        for row in results:  # type: ignore
            row_dict = row.asDict() if not isinstance(row, dict) else row
            if row_dict.get("isTemporary"):
                continue
            # The schema column is read from "namespace" first, then "database".
            schema = row_dict.get("namespace") or row_dict.get("database")
            data_objects.append(
                DataObject(
                    catalog=catalog,
                    schema=schema,
                    name=row_dict["tableName"],
                    type=(
                        DataObjectType.VIEW
                        if "Type: VIEW" in row_dict["information"]
                        else DataObjectType.TABLE
                    ),
                )
            )
        return data_objects

    def get_current_catalog(self) -> t.Optional[str]:
        """Return the current catalog, asking the connection when a Spark session is used."""
        if self._use_spark_session:
            return self.connection.get_current_catalog()
        return super().get_current_catalog()

    def set_current_catalog(self, catalog_name: str) -> None:
        """Switch the connection's current catalog to ``catalog_name``."""
        self.connection.set_current_catalog(catalog_name)

    def _get_current_schema(self) -> str:
        """Return the current database, via the session catalog API or a SQL query."""
        if self._use_spark_session:
            return self.spark.catalog.currentDatabase()
        return self.fetchone(exp.select(exp.func("current_database")))[0]  # type: ignore

    def get_data_object(
        self, target_name: TableName, safe_to_cache: bool = False
    ) -> t.Optional[DataObject]:
        """Look up a data object, ignoring a trailing WAP branch part in the name."""
        target_table = exp.to_table(target_name)
        if isinstance(target_table.this, exp.Dot) and target_table.this.expression.name.startswith(
            f"{self.BRANCH_PREFIX}{self.WAP_PREFIX}"
        ):
            # Work on a copy: when the caller passes an exp.Table, `to_table` may hand the
            # very same instance back, and stripping the branch in place would change the
            # caller's table.
            target_table = target_table.copy()
            # Exclude the branch name
            target_table.set("this", target_table.this.this)
        return super().get_data_object(target_table, safe_to_cache=safe_to_cache)

    def create_state_table(
        self,
        table_name: str,
        target_columns_to_types: t.Dict[str, exp.DataType],
        primary_key: t.Optional[t.Tuple[str, ...]] = None,
    ) -> None:
        """Create a state table, using the primary key columns as partition columns."""
        self.create_table(
            table_name,
            target_columns_to_types,
            partitioned_by=[exp.column(x) for x in primary_key] if primary_key else None,
        )

    def _native_df_to_pandas_df(
        self,
        query_or_df: QueryOrDF,
    ) -> t.Union[Query, pd.DataFrame]:
        """Convert a PySpark DataFrame to Pandas; defer to the parent class otherwise."""
        if pyspark_df := self.try_get_pyspark_df(query_or_df):
            return pyspark_df.toPandas()

        return super()._native_df_to_pandas_df(query_or_df)

    def _create_table(
        self,
        table_name_or_schema: t.Union[exp.Schema, TableName],
        expression: t.Optional[exp.Expr],
        exists: bool = True,
        replace: bool = False,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        table_description: t.Optional[str] = None,
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        table_kind: t.Optional[str] = None,
        track_rows_processed: bool = True,
        **kwargs: t.Any,
    ) -> None:
        """Create a table, with extra handling for WAP branch names and Iceberg tables.

        A WAP branch suffix is removed from the target name for plain DDL (no
        ``expression``). When WAP is enabled and the table is an Iceberg table that does
        not exist yet, a no-op insert is issued after creation.
        """
        table_name = (
            table_name_or_schema.this
            if isinstance(table_name_or_schema, exp.Schema)
            else exp.to_table(table_name_or_schema)
        )
        # Spark doesn't support creating a wap table DDL. Therefore we check if this is a wap table and if it is,
        # this is not a replace, and the table already exists then we can safely just return. Otherwise we let it error.
        if not expression and isinstance(table_name.this, exp.Dot):
            wap_id = table_name.this.parts[-1].name
            if wap_id.startswith(f"{self.BRANCH_PREFIX}{self.WAP_PREFIX}"):
                table_name.set("this", table_name.this.this)

        do_dummy_insert = False
        if self.wap_enabled:
            wap_supported = (
                kwargs.get("storage_format") or ""
            ).lower() == "iceberg" or self.wap_supported(table_name)
            do_dummy_insert = (
                False if not wap_supported or not exists else not self.table_exists(table_name)
            )
        # NOTE(review): `table_kind` is accepted but not forwarded to the parent call;
        # confirm whether that is intentional.
        super()._create_table(
            table_name_or_schema,
            expression,
            exists=exists,
            replace=replace,
            target_columns_to_types=target_columns_to_types,
            table_description=table_description,
            column_descriptions=column_descriptions,
            track_rows_processed=track_rows_processed,
            **kwargs,
        )
        table_name = (
            table_name_or_schema.this
            if isinstance(table_name_or_schema, exp.Schema)
            else exp.to_table(table_name_or_schema)
        )
        if do_dummy_insert:
            # Performing a dummy insert to create a dummy snapshot for Iceberg tables
            # to workaround https://github.com/apache/iceberg/issues/8849.
            dummy_insert = exp.insert(exp.select("*").from_(table_name), table_name)
            self.execute(dummy_insert)

    def wap_supported(self, table_name: TableName) -> bool:
        """Return whether the table's catalog is configured as an Iceberg Spark catalog."""
        fqn = self._ensure_fqn(table_name)
        # An explicit default makes an unconfigured catalog compare unequal (-> False)
        # instead of raising from the config lookup.
        catalog_impl = self.spark.conf.get(f"spark.sql.catalog.{fqn.catalog}", None)
        return catalog_impl == "org.apache.iceberg.spark.SparkCatalog"

    def wap_table_name(self, table_name: TableName, wap_id: str) -> str:
        """Return the fully qualified table name with the WAP branch part appended."""
        branch_name = self._wap_branch_name(wap_id)
        fqn = self._ensure_fqn(table_name)
        return exp.Dot.build([fqn, exp.to_identifier(f"{self.BRANCH_PREFIX}{branch_name}")]).sql(
            dialect=self.dialect
        )

    def wap_prepare(self, table_name: TableName, wap_id: str) -> str:
        """Create the WAP branch on the table and return the branch-qualified name."""
        branch_name = self._wap_branch_name(wap_id)
        fqn = self._ensure_fqn(table_name)
        self.execute(f"ALTER TABLE {fqn.sql(dialect=self.dialect)} CREATE BRANCH {branch_name}")
        return self.wap_table_name(table_name, wap_id)

    def wap_publish(self, table_name: TableName, wap_id: str) -> None:
        """Cherry-pick the WAP branch's snapshot into the table, then drop the branch.

        Raises:
            SQLMeshError: If the table's ``refs`` metadata has no entry for the branch.
        """
        branch_name = self._wap_branch_name(wap_id)
        fqn = self._ensure_fqn(table_name)

        get_snapshot_id_query = (
            exp.select("snapshot_id")
            .from_(exp.Dot.build([fqn, exp.to_identifier("refs")]))
            .where(exp.column("name").eq(branch_name))
        )
        iceberg_snapshot_ids = self.fetchall(get_snapshot_id_query)
        if not iceberg_snapshot_ids:
            raise SQLMeshError(f"Could not find Iceberg branch '{branch_name}'.")
        iceberg_snapshot_id = iceberg_snapshot_ids[0][0]

        logger.info(
            "Cherry-picking Iceberg snapshot %s into table '%s'...", iceberg_snapshot_id, fqn
        )

        self.execute(
            f"CALL {fqn.catalog}.system.cherrypick_snapshot('{fqn.db}.{fqn.name}', {iceberg_snapshot_id})"
        )
        self.execute(f"ALTER TABLE {fqn.sql(dialect=self.dialect)} DROP BRANCH {branch_name}")

    def _ensure_fqn(self, table_name: TableName) -> exp.Table:
        """Return a table expression with catalog and db filled in from the session.

        An ``exp.Table`` argument is copied first so the caller's object is not modified.
        """
        if isinstance(table_name, exp.Table):
            table_name = table_name.copy()
        table = exp.to_table(table_name, dialect=self.dialect)
        if not table.catalog:
            table.set("catalog", self.get_current_catalog())
        if not table.db:
            table.set("db", self._get_current_schema())
        return table

    def _build_create_comment_column_exp(
        self, table: exp.Table, column_name: str, column_comment: str, table_kind: str = "TABLE"
    ) -> exp.Comment | str:
        """Build the ALTER TABLE ... ALTER COLUMN ... COMMENT statement as a string.

        ``table_kind`` is not used: the statement always says ALTER TABLE.
        """
        table_sql = table.sql(dialect=self.dialect, identify=True)
        column_sql = exp.column(column_name).sql(dialect=self.dialect, identify=True)

        truncated_comment = self._truncate_column_comment(column_comment)
        comment_sql = exp.Literal.string(truncated_comment).sql(dialect=self.dialect)

        return f"ALTER TABLE {table_sql} ALTER COLUMN {column_sql} COMMENT {comment_sql}"

    @classmethod
    def _wap_branch_name(cls, wap_id: str) -> str:
        """Return the branch name for a WAP ID: the WAP prefix followed by the ID."""
        return f"{cls.WAP_PREFIX}{wap_id}"
Base class wrapping a Database API-compliant connection.
The EngineAdapter is an easily-subclassable interface that interacts with the underlying engine and data store.
Arguments:
- connection_factory_or_pool: a callable which produces a new Database API-compliant connection on every call.
- dialect: The dialect with which this adapter is associated.
- multithreaded: Indicates whether this adapter will be used by more than one thread.
catalog_support: sqlmesh.core.engine_adapter.shared.CatalogSupport
@classmethod
def spark_to_sqlglot_types(cls, input: spark_types.StructType) -> t.Dict[str, exp.DataType]:
    """Convert a PySpark struct schema into a column-name -> sqlglot type mapping.

    Args:
        input: The PySpark ``StructType`` describing the columns.

    Returns:
        A dictionary keyed by top-level field name whose values are sqlglot data types.

    Raises:
        SQLMeshError: If a container other than struct, array or map is encountered.
        KeyError: If a leaf type has no entry in the primitive mapping.
    """
    from pyspark.sql import types as spark_types

    def spark_complex_to_sqlglot_complex(
        complex_type: t.Union[
            spark_types.StructType, spark_types.ArrayType, spark_types.MapType
        ],
    ) -> exp.DataType:
        def get_fields(
            complex_type: t.Union[
                spark_types.StructType, spark_types.ArrayType, spark_types.MapType
            ],
        ) -> t.Sequence[spark_types.DataType]:
            if isinstance(complex_type, spark_types.StructType):
                return complex_type.fields
            if isinstance(complex_type, spark_types.ArrayType):
                return [complex_type.elementType]
            if isinstance(complex_type, spark_types.MapType):
                return [complex_type.keyType, complex_type.valueType]
            raise SQLMeshError(f"Unsupported complex type: {complex_type}")

        expressions: t.List[t.Union[exp.ColumnDef, exp.DataType]] = []
        fields = get_fields(complex_type)
        for field in fields:
            # ArrayType has to be listed too: an array that is the element of another
            # array (or a map key/value) would otherwise hit the primitive lookup in
            # the final branch and raise KeyError.
            if isinstance(
                field,
                (spark_types.StructType, spark_types.MapType, spark_types.ArrayType),
            ):
                expressions.append(spark_complex_to_sqlglot_complex(field))
            elif isinstance(field, spark_types.StructField):
                # Struct members carry a name; their type is either a primitive enum
                # from the mapping or a recursively converted container.
                sqlglot_data_type = cls._spark_to_sqlglot_primitive_mapping.get(
                    type(field.dataType)
                ) or spark_complex_to_sqlglot_complex(
                    field.dataType  # type: ignore
                )
                kind = (
                    sqlglot_data_type
                    if isinstance(sqlglot_data_type, exp.DataType)
                    else exp.DataType(this=sqlglot_data_type)
                )
                expressions.append(exp.ColumnDef(this=exp.to_identifier(field.name), kind=kind))
            else:
                kind = exp.DataType(this=cls._spark_to_sqlglot_primitive_mapping[type(field)])
                expressions.append(kind)
        dtype = cls._spark_to_sqlglot_complex_mapping[type(complex_type)]
        return exp.DataType(
            this=dtype,
            expressions=expressions,
            nested=True,
        )

    resp = spark_complex_to_sqlglot_complex(input)
    return {column_def.this.name: column_def.args["kind"] for column_def in resp.expressions}
@classmethod
def sqlglot_to_spark_types(cls, input: t.Dict[str, exp.DataType]) -> spark_types.StructType:
    """Build a PySpark ``StructType`` from a column-name -> sqlglot type mapping.

    Primitive Spark types are instantiated without arguments; nested struct, array and
    map types are converted recursively.
    """
    from pyspark.sql import types as spark_types

    def convert(node: exp.DataType) -> spark_types.DataType:
        struct_like = node.is_type(exp.DataType.Type.STRUCT)
        members = []
        for child in node.expressions:
            # Struct children are column definitions; array/map children are plain types.
            if struct_like:
                child_name = child.this.name
                child_type = child.args["kind"]
            else:
                child_name = None
                child_type = child
            primitive_cls = cls._sqlglot_to_spark_primitive_mapping.get(child_type.this)
            spark_type = primitive_cls() if primitive_cls else convert(child_type)
            if struct_like:
                members.append(spark_types.StructField(child_name, spark_type))  # type: ignore
            else:
                members.append(spark_type)  # type: ignore
        container_cls = cls._sqlglot_to_spark_complex_mapping[node.this]
        # StructType takes the field list as one argument; ArrayType/MapType take the
        # member types positionally.
        return container_cls(members) if struct_like else container_cls(*members)

    column_defs = [
        exp.ColumnDef(this=exp.to_identifier(name), kind=dtype) for name, dtype in input.items()
    ]
    root = exp.DataType(this=exp.DataType.Type.STRUCT, expressions=column_defs)
    return t.cast(spark_types.StructType, convert(root))
def fetchdf(
    self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> pd.DataFrame:
    """Run the query through ``fetch_pyspark_df`` and convert the result to Pandas."""
    spark_df = self.fetch_pyspark_df(query, quote_identifiers=quote_identifiers)
    return spark_df.toPandas()
Fetches a Pandas DataFrame from the cursor
def fetch_pyspark_df(
    self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> PySparkDataFrame:
    """Fetch the query's native DataFrame and make sure it is a PySpark DataFrame."""
    native_df = self._fetch_native_df(query, quote_identifiers=quote_identifiers)
    return self._ensure_pyspark_df(native_df)
Fetches a PySpark DataFrame from the cursor
def get_current_catalog(self) -> t.Optional[str]:
    """Return the current catalog name.

    The connection is asked directly when a Spark session is in use; otherwise the
    parent implementation is used.
    """
    if not self._use_spark_session:
        return super().get_current_catalog()
    return self.connection.get_current_catalog()
Returns the catalog name of the current connection.
def set_current_catalog(self, catalog_name: str) -> None:
    """Make ``catalog_name`` the current catalog of the underlying connection."""
    connection = self.connection
    connection.set_current_catalog(catalog_name)
Sets the catalog name of the current connection.
def get_data_object(
    self, target_name: TableName, safe_to_cache: bool = False
) -> t.Optional[DataObject]:
    """Look up a data object, ignoring a trailing WAP branch part in the name.

    Args:
        target_name: The table name, optionally ending in a WAP branch part.
        safe_to_cache: Passed through unchanged to the parent implementation.

    Returns:
        Whatever the parent implementation returns for the branch-less table name.
    """
    target_table = exp.to_table(target_name)
    if isinstance(target_table.this, exp.Dot) and target_table.this.expression.name.startswith(
        f"{self.BRANCH_PREFIX}{self.WAP_PREFIX}"
    ):
        # Work on a copy: when the caller passes an exp.Table, `to_table` may hand the
        # very same instance back, and stripping the branch in place would change the
        # caller's table.
        target_table = target_table.copy()
        # Exclude the branch name
        target_table.set("this", target_table.this.this)
    return super().get_data_object(target_table, safe_to_cache=safe_to_cache)
def create_state_table(
    self,
    table_name: str,
    target_columns_to_types: t.Dict[str, exp.DataType],
    primary_key: t.Optional[t.Tuple[str, ...]] = None,
) -> None:
    """Create a state table, using the primary key columns as partition columns.

    Without a primary key the table is created with ``partitioned_by=None``.
    """
    partition_columns = None
    if primary_key:
        partition_columns = [exp.column(key) for key in primary_key]
    self.create_table(
        table_name,
        target_columns_to_types,
        partitioned_by=partition_columns,
    )
Create a table to store SQLMesh internal state.
Arguments:
- table_name: The name of the table to create. Can be fully qualified or just table name.
- target_columns_to_types: A mapping between the column name and its data type.
- primary_key: Determines the table primary key.
def wap_supported(self, table_name: TableName) -> bool:
    """Return whether the table's catalog is configured as an Iceberg Spark catalog.

    Args:
        table_name: The table whose (possibly implicit) catalog is checked.

    Returns:
        True when ``spark.sql.catalog.<catalog>`` is set to the Iceberg ``SparkCatalog``
        class; False otherwise, including when that key is not set at all.
    """
    fqn = self._ensure_fqn(table_name)
    # An explicit default makes an unconfigured catalog compare unequal (-> False)
    # instead of raising from the config lookup.
    catalog_impl = self.spark.conf.get(f"spark.sql.catalog.{fqn.catalog}", None)
    return catalog_impl == "org.apache.iceberg.spark.SparkCatalog"
Returns whether WAP for the target table is supported.
def wap_table_name(self, table_name: TableName, wap_id: str) -> str:
    """Return the fully qualified table name with the WAP branch part appended.

    The branch part is the branch prefix followed by the WAP branch name for ``wap_id``.
    """
    branch_name = self._wap_branch_name(wap_id)
    qualified_table = self._ensure_fqn(table_name)
    branch_identifier = exp.to_identifier(f"{self.BRANCH_PREFIX}{branch_name}")
    branch_path = exp.Dot.build([qualified_table, branch_identifier])
    return branch_path.sql(dialect=self.dialect)
Returns the updated table name for the given WAP ID.
Arguments:
- table_name: The name of the target table.
- wap_id: The WAP ID to prepare.
Returns:
The updated table name that should be used for writing.
def wap_prepare(self, table_name: TableName, wap_id: str) -> str:
    """Create the WAP branch on the table and return the branch-qualified table name."""
    branch = self._wap_branch_name(wap_id)
    qualified_sql = self._ensure_fqn(table_name).sql(dialect=self.dialect)
    create_branch_sql = f"ALTER TABLE {qualified_sql} CREATE BRANCH {branch}"
    self.execute(create_branch_sql)
    return self.wap_table_name(table_name, wap_id)
Prepares the target table for WAP and returns the updated table name.
Arguments:
- table_name: The name of the target table.
- wap_id: The WAP ID to prepare.
Returns:
The updated table name that should be used for writing.
def wap_publish(self, table_name: TableName, wap_id: str) -> None:
    """Cherry-pick the WAP branch's snapshot into the table, then drop the branch.

    Raises:
        SQLMeshError: If the table's ``refs`` metadata has no entry for the branch.
    """
    branch_name = self._wap_branch_name(wap_id)
    fqn = self._ensure_fqn(table_name)

    # Look the branch up in the table's "refs" metadata to find its snapshot.
    refs_table = exp.Dot.build([fqn, exp.to_identifier("refs")])
    branch_filter = exp.column("name").eq(branch_name)
    snapshot_lookup = exp.select("snapshot_id").from_(refs_table).where(branch_filter)

    matching_rows = self.fetchall(snapshot_lookup)
    if not matching_rows:
        raise SQLMeshError(f"Could not find Iceberg branch '{branch_name}'.")
    snapshot_id = matching_rows[0][0]

    logger.info("Cherry-picking Iceberg snapshot %s into table '%s'...", snapshot_id, fqn)

    cherrypick_sql = (
        f"CALL {fqn.catalog}.system.cherrypick_snapshot('{fqn.db}.{fqn.name}', {snapshot_id})"
    )
    self.execute(cherrypick_sql)
    drop_branch_sql = f"ALTER TABLE {fqn.sql(dialect=self.dialect)} DROP BRANCH {branch_name}"
    self.execute(drop_branch_sql)
Publishes changes with the given WAP ID to the target table.
Arguments:
- table_name: The name of the target table.
- wap_id: The WAP ID to publish.
Inherited Members
- sqlmesh.core.engine_adapter.base.EngineAdapter
- EngineAdapter
- DEFAULT_BATCH_SIZE
- DATA_OBJECT_FILTER_BATCH_SIZE
- SUPPORTS_INDEXES
- SUPPORTS_MATERIALIZED_VIEWS
- SUPPORTS_MATERIALIZED_VIEW_SCHEMA
- SUPPORTS_VIEW_SCHEMA
- SUPPORTS_CLONING
- SUPPORTS_MANAGED_MODELS
- SUPPORTS_CREATE_DROP_CATALOG
- SUPPORTS_TUPLE_IN
- HAS_VIEW_BINDING
- SUPPORTS_GRANTS
- DEFAULT_CATALOG_TYPE
- MAX_IDENTIFIER_LENGTH
- ATTACH_CORRELATION_ID
- SUPPORTS_QUERY_EXECUTION_TRACKING
- SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS
- dialect
- correlation_id
- with_settings
- cursor
- snowpark
- bigframe
- comments_enabled
- schema_differ
- default_catalog
- engine_run_mode
- recycle
- close
- get_catalog_type
- get_catalog_type_from_table
- current_catalog_type
- replace_query
- create_index
- create_table
- create_managed_table
- ctas
- create_table_like
- clone_table
- drop_data_object
- drop_table
- drop_managed_table
- get_alter_operations
- alter_table
- create_view
- create_schema
- drop_schema
- drop_view
- create_catalog
- drop_catalog
- columns
- table_exists
- delete_from
- insert_append
- insert_overwrite_by_partition
- insert_overwrite_by_time_partition
- update_table
- scd_type_2_by_time
- scd_type_2_by_column
- merge
- rename_table
- get_data_objects
- fetchone
- fetchall
- wap_enabled
- sync_grants_config
- transaction
- session
- execute
- temp_table
- drop_data_object_on_type_mismatch
- ensure_nulls_for_unmatched_after_join
- use_server_nulls_for_unmatched_after_join
- ping
- get_table_last_modified_ts