sqlmesh.engines.spark.db_api.spark_session
"""Minimal DB-API-like connection and cursor wrappers around a ``SparkSession``."""

from __future__ import annotations

import logging
import typing as t
from threading import get_ident

from sqlmesh.engines.spark.db_api.errors import NotSupportedError, ProgrammingError

if t.TYPE_CHECKING:
    from pyspark.sql import DataFrame, SparkSession
    from pyspark.sql.types import Row

logger = logging.getLogger(__name__)


class SparkSessionCursor:
    """Cursor that runs SQL through ``SparkSession.sql`` and buffers the results.

    ``execute`` only builds the DataFrame; all rows are collected on the first fetch
    call and then handed out from an in-memory list.
    """

    def __init__(self, spark: SparkSession):
        self._spark = spark
        # DataFrame from the most recent execute(); None until execute() is called.
        self._last_df: t.Optional[DataFrame] = None
        # Collected rows of _last_df as tuples; populated lazily by _fetch().
        self._last_output: t.Optional[t.List[t.Tuple]] = None
        # Index of the next buffered row to return.
        self._last_output_cursor: int = 0

    def execute(self, query: str, parameters: t.Optional[t.Any] = None) -> None:
        """Run ``query`` via ``spark.sql`` and reset the fetch position.

        Raises:
            NotSupportedError: if any (truthy) ``parameters`` are passed.
        """
        if parameters:
            raise NotSupportedError("Parameterized queries are not supported")

        self._last_df = self._spark.sql(query)
        self._last_output = None
        self._last_output_cursor = 0

    def fetchone(self) -> t.Optional[t.Tuple]:
        """Return the next row, or None when no rows remain."""
        result = self._fetch(size=1)
        return result[0] if result else None

    def fetchmany(self, size: int = 1) -> t.List[t.Tuple]:
        """Return up to ``size`` of the remaining rows."""
        return self._fetch(size=size)

    def fetchall(self) -> t.List[t.Tuple]:
        """Return all remaining rows."""
        return self._fetch()

    def close(self) -> None:
        """Do nothing; present so the cursor can be closed like any other."""

    def fetchdf(self) -> t.Optional[DataFrame]:
        """Return the DataFrame built by the last execute(), or None before any execute()."""
        return self._last_df

    def _fetch(self, size: t.Optional[int] = None) -> t.List[t.Tuple]:
        """Return the next ``size`` buffered rows (all remaining rows when ``size`` is None).

        Raises:
            ProgrammingError: if ``size`` is negative or execute() was never called.
        """
        if size and size < 0:
            raise ProgrammingError("The size argument can't be negative")

        if self._last_df is None:
            raise ProgrammingError("No call to .execute() has been issued")

        if self._last_output is None:
            # First fetch after execute(): collect the whole result once.
            self._last_output = _normalize_rows(self._last_df.collect())

        if self._last_output_cursor >= len(self._last_output):
            return []

        if size is None:
            size = len(self._last_output) - self._last_output_cursor

        start = self._last_output_cursor
        output = self._last_output[start : start + size]
        # Advance by what was actually returned so the position never runs past the end.
        self._last_output_cursor += len(output)

        return output


class SparkSessionConnection:
    """Connection-like wrapper that hands out SparkSessionCursor objects."""

    def __init__(self, spark: SparkSession, catalog: t.Optional[str] = None):
        self.spark = spark
        # Catalog to make current whenever a cursor is created (optional).
        self.catalog = catalog

    @property
    def _spark_major_minor(self) -> t.Tuple[int, int]:
        """(major, minor) parsed from ``spark.version``; assumes both parts are numeric."""
        return tuple(int(x) for x in self.spark.version.split(".")[:2])  # type: ignore

    def get_current_catalog(self) -> t.Optional[str]:
        """Return the session's current catalog.

        On Spark >= 3.4 the session is asked directly; on older versions the catalog
        configured on this connection is returned, defaulting to ``"spark_catalog"``.
        """
        if self._spark_major_minor >= (3, 4):
            return self.spark.catalog.currentCatalog()
        return self.catalog or "spark_catalog"

    def set_current_catalog(self, catalog_name: str) -> None:
        """Make ``catalog_name`` the session's current catalog.

        On Spark < 3.4 the catalog cannot be changed; a warning is logged when the
        requested catalog differs from get_current_catalog() and nothing else happens.
        """
        if self._spark_major_minor >= (3, 4):
            self.spark.catalog.setCurrentCatalog(catalog_name)
            return
        current_catalog = self.get_current_catalog()
        if current_catalog != catalog_name:
            logger.warning(
                "Spark <3.4 does not support certain cross catalog queries since the default catalog cannot be set <3.4"
            )

    def cursor(self) -> SparkSessionCursor:
        """Prepare the session for the calling thread and return a new cursor.

        Best effort: assigns the thread to the ``pool_<thread id>`` scheduler pool and
        enables dynamic partition overwrite. If the session raises NotImplementedError
        (or PySparkAttributeError, when available) the remaining settings are skipped.
        If a catalog was configured, it is made current before returning.
        """
        # `pyspark.errors` / PySparkAttributeError may not exist in older PySpark
        # releases; importing it unconditionally would make every cursor() call fail
        # there, even though this class keeps fallbacks for Spark < 3.4.
        unsupported_errors: t.Tuple[t.Type[BaseException], ...] = (NotImplementedError,)
        try:
            from pyspark.errors import PySparkAttributeError
        except ImportError:
            pass
        else:
            unsupported_errors = (NotImplementedError, PySparkAttributeError)

        try:
            self.spark.sparkContext.setLocalProperty("spark.scheduler.pool", f"pool_{get_ident()}")
            self.spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
            self.spark.conf.set("hive.exec.dynamic.partition", "true")
            self.spark.conf.set("hive.exec.dynamic.partition.mode", "nonstrict")
        except unsupported_errors:
            # Databricks Connect does not support accessing the SparkContext nor does it support
            # setting dynamic partition overwrite since it uses replace where.
            # Serverless jobs don't support access to the SparkContext either, so skip for those too.
            pass

        if self.catalog:
            from py4j.protocol import Py4JError

            try:
                self.set_current_catalog(self.catalog)
            except Py4JError:
                # Databricks does not support `setCurrentCatalog` with Unity Catalog
                # and shared clusters, so use the Databricks Unity-only SQL command instead.
                self.spark.sql(f"USE CATALOG {self.catalog}")
        return SparkSessionCursor(self.spark)

    def commit(self) -> None:
        """Do nothing; this connection does not track transactions."""

    def rollback(self) -> None:
        """Do nothing; this connection does not track transactions."""

    def close(self) -> None:
        """Do nothing; the underlying SparkSession is left open."""


def connection(spark: SparkSession, catalog: t.Optional[str] = None) -> SparkSessionConnection:
    """Create a SparkSessionConnection for ``spark``, optionally bound to ``catalog``."""
    return SparkSessionConnection(spark, catalog)


def _normalize_rows(rows: t.Sequence[Row]) -> t.List[t.Tuple]:
    """Convert Spark ``Row`` objects into plain tuples."""
    return [tuple(r) for r in rows]
logger =
<Logger sqlmesh.engines.spark.db_api.spark_session (WARNING)>
class
SparkSessionCursor:
class SparkSessionCursor:
    """Cursor over a ``SparkSession`` that buffers query results in memory.

    A call to ``execute`` only creates the DataFrame. The first fetch collects every
    row into a list of tuples, and later fetches slice that list.
    """

    def __init__(self, spark: SparkSession):
        self._spark = spark
        # Result of the latest execute(); stays None until a query has run.
        self._last_df: t.Optional[DataFrame] = None
        # Materialized rows of _last_df, filled on the first fetch.
        self._last_output: t.Optional[t.List[t.Tuple]] = None
        # Position of the next row to hand back.
        self._last_output_cursor: int = 0

    def execute(self, query: str, parameters: t.Optional[t.Any] = None) -> None:
        """Build the DataFrame for ``query`` and rewind the fetch state.

        Raises:
            NotSupportedError: when truthy ``parameters`` are supplied.
        """
        if parameters:
            raise NotSupportedError("Parameterized queries are not supported")

        frame = self._spark.sql(query)
        self._last_df, self._last_output, self._last_output_cursor = frame, None, 0

    def fetchone(self) -> t.Optional[t.Tuple]:
        """Give back the next row, or None once everything has been read."""
        rows = self._fetch(size=1)
        if not rows:
            return None
        return rows[0]

    def fetchmany(self, size: int = 1) -> t.List[t.Tuple]:
        """Give back at most ``size`` unread rows."""
        return self._fetch(size)

    def fetchall(self) -> t.List[t.Tuple]:
        """Give back every unread row."""
        return self._fetch(size=None)

    def close(self) -> None:
        """Intentionally a no-op."""

    def fetchdf(self) -> t.Optional[DataFrame]:
        """Expose the DataFrame from the latest execute() (None if nothing ran yet)."""
        return self._last_df

    def _fetch(self, size: t.Optional[int] = None) -> t.List[t.Tuple]:
        """Slice the next ``size`` rows out of the buffer (everything left if None).

        Raises:
            ProgrammingError: for a negative ``size`` or when execute() was never called.
        """
        if size and size < 0:
            raise ProgrammingError("The size argument can't be negative")
        if self._last_df is None:
            raise ProgrammingError("No call to .execute() has been issued")

        buffered = self._last_output
        if buffered is None:
            buffered = self._last_output = _normalize_rows(self._last_df.collect())

        start = self._last_output_cursor
        remaining = len(buffered) - start
        if remaining <= 0:
            return []

        count = remaining if size is None else size
        self._last_output_cursor = start + count
        return buffered[start : start + count]
class
SparkSessionConnection:
class SparkSessionConnection:
    """Connection-like wrapper that hands out SparkSessionCursor objects."""

    def __init__(self, spark: SparkSession, catalog: t.Optional[str] = None):
        self.spark = spark
        # Catalog to make current whenever a cursor is created (optional).
        self.catalog = catalog

    @property
    def _spark_major_minor(self) -> t.Tuple[int, int]:
        """(major, minor) parsed from ``spark.version``; assumes both parts are numeric."""
        return tuple(int(x) for x in self.spark.version.split(".")[:2])  # type: ignore

    def get_current_catalog(self) -> t.Optional[str]:
        """Return the session's current catalog.

        On Spark >= 3.4 the session is asked directly; on older versions the catalog
        configured on this connection is returned, defaulting to ``"spark_catalog"``.
        """
        if self._spark_major_minor >= (3, 4):
            return self.spark.catalog.currentCatalog()
        return self.catalog or "spark_catalog"

    def set_current_catalog(self, catalog_name: str) -> None:
        """Make ``catalog_name`` the session's current catalog.

        On Spark < 3.4 the catalog cannot be changed; a warning is logged when the
        requested catalog differs from get_current_catalog() and nothing else happens.
        """
        if self._spark_major_minor >= (3, 4):
            # Don't leak setCurrentCatalog's return value through a `-> None` method.
            self.spark.catalog.setCurrentCatalog(catalog_name)
            return
        current_catalog = self.get_current_catalog()
        if current_catalog != catalog_name:
            logger.warning(
                "Spark <3.4 does not support certain cross catalog queries since the default catalog cannot be set <3.4"
            )

    def cursor(self) -> SparkSessionCursor:
        """Prepare the session for the calling thread and return a new cursor.

        Best effort: assigns the thread to the ``pool_<thread id>`` scheduler pool and
        enables dynamic partition overwrite. If the session raises NotImplementedError
        (or PySparkAttributeError, when available) the remaining settings are skipped.
        If a catalog was configured, it is made current before returning.
        """
        # `pyspark.errors` / PySparkAttributeError may not exist in older PySpark
        # releases; importing it unconditionally would make every cursor() call fail
        # there, even though this class keeps fallbacks for Spark < 3.4.
        unsupported_errors: t.Tuple[t.Type[BaseException], ...] = (NotImplementedError,)
        try:
            from pyspark.errors import PySparkAttributeError
        except ImportError:
            pass
        else:
            unsupported_errors = (NotImplementedError, PySparkAttributeError)

        try:
            self.spark.sparkContext.setLocalProperty("spark.scheduler.pool", f"pool_{get_ident()}")
            self.spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
            self.spark.conf.set("hive.exec.dynamic.partition", "true")
            self.spark.conf.set("hive.exec.dynamic.partition.mode", "nonstrict")
        except unsupported_errors:
            # Databricks Connect does not support accessing the SparkContext nor does it support
            # setting dynamic partition overwrite since it uses replace where.
            # Serverless jobs don't support access to the SparkContext either, so skip for those too.
            pass

        if self.catalog:
            from py4j.protocol import Py4JError

            try:
                self.set_current_catalog(self.catalog)
            except Py4JError:
                # Databricks does not support `setCurrentCatalog` with Unity Catalog
                # and shared clusters, so use the Databricks Unity-only SQL command instead.
                self.spark.sql(f"USE CATALOG {self.catalog}")
        return SparkSessionCursor(self.spark)

    def commit(self) -> None:
        """Do nothing; this connection does not track transactions."""

    def rollback(self) -> None:
        """Do nothing; this connection does not track transactions."""

    def close(self) -> None:
        """Do nothing; the underlying SparkSession is left open."""
def
set_current_catalog(self, catalog_name: str) -> None:
def set_current_catalog(self, catalog_name: str) -> None:
    """Make ``catalog_name`` the session's current catalog.

    On Spark < 3.4 the catalog cannot be changed; a warning is logged when the
    requested catalog differs from get_current_catalog() and nothing else happens.
    """
    if self._spark_major_minor >= (3, 4):
        # Don't leak setCurrentCatalog's return value through a `-> None` method.
        self.spark.catalog.setCurrentCatalog(catalog_name)
        return
    current_catalog = self.get_current_catalog()
    if current_catalog != catalog_name:
        logger.warning(
            "Spark <3.4 does not support certain cross catalog queries since the default catalog cannot be set <3.4"
        )
def cursor(self) -> SparkSessionCursor:
    """Prepare the session for the calling thread and return a new cursor.

    Best effort: assigns the thread to the ``pool_<thread id>`` scheduler pool and
    enables dynamic partition overwrite. If the session raises NotImplementedError
    (or PySparkAttributeError, when available) the remaining settings are skipped.
    If a catalog was configured, it is made current before returning.
    """
    # `pyspark.errors` / PySparkAttributeError may not exist in older PySpark
    # releases; importing it unconditionally would make every cursor() call fail
    # there, even though get_current_catalog/set_current_catalog keep Spark < 3.4 fallbacks.
    unsupported_errors: t.Tuple[t.Type[BaseException], ...] = (NotImplementedError,)
    try:
        from pyspark.errors import PySparkAttributeError
    except ImportError:
        pass
    else:
        unsupported_errors = (NotImplementedError, PySparkAttributeError)

    try:
        self.spark.sparkContext.setLocalProperty("spark.scheduler.pool", f"pool_{get_ident()}")
        self.spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
        self.spark.conf.set("hive.exec.dynamic.partition", "true")
        self.spark.conf.set("hive.exec.dynamic.partition.mode", "nonstrict")
    except unsupported_errors:
        # Databricks Connect does not support accessing the SparkContext nor does it support
        # setting dynamic partition overwrite since it uses replace where.
        # Serverless jobs don't support access to the SparkContext either, so skip for those too.
        pass

    if self.catalog:
        from py4j.protocol import Py4JError

        try:
            self.set_current_catalog(self.catalog)
        except Py4JError:
            # Databricks does not support `setCurrentCatalog` with Unity Catalog
            # and shared clusters, so use the Databricks Unity-only SQL command instead.
            self.spark.sql(f"USE CATALOG {self.catalog}")
    return SparkSessionCursor(self.spark)
def
connection(spark: SparkSession, catalog: Optional[str] = None) -> SparkSessionConnection: