Edit on GitHub

sqlmesh.engines.spark.db_api.spark_session

  1from __future__ import annotations
  2
  3import logging
  4import typing as t
  5from threading import get_ident
  6
  7from sqlmesh.engines.spark.db_api.errors import NotSupportedError, ProgrammingError
  8
  9if t.TYPE_CHECKING:
 10    from pyspark.sql import DataFrame, SparkSession
 11    from pyspark.sql.types import Row
 12
 13logger = logging.getLogger(__name__)
 14
 15
 16class SparkSessionCursor:
 17    def __init__(self, spark: SparkSession):
 18        self._spark = spark
 19        self._last_df: t.Optional[DataFrame] = None
 20        self._last_output: t.Optional[t.List[t.Tuple]] = None
 21        self._last_output_cursor: int = 0
 22
 23    def execute(self, query: str, parameters: t.Optional[t.Any] = None) -> None:
 24        if parameters:
 25            raise NotSupportedError("Parameterized queries are not supported")
 26
 27        self._last_df = self._spark.sql(query)
 28        self._last_output = None
 29        self._last_output_cursor = 0
 30
 31    def fetchone(self) -> t.Optional[t.Tuple]:
 32        result = self._fetch(size=1)
 33        return result[0] if result else None
 34
 35    def fetchmany(self, size: int = 1) -> t.List[t.Tuple]:
 36        return self._fetch(size=size)
 37
 38    def fetchall(self) -> t.List[t.Tuple]:
 39        return self._fetch()
 40
 41    def close(self) -> None:
 42        pass
 43
 44    def fetchdf(self) -> t.Optional[DataFrame]:
 45        return self._last_df
 46
 47    def _fetch(self, size: t.Optional[int] = None) -> t.List[t.Tuple]:
 48        if size and size < 0:
 49            raise ProgrammingError("The size argument can't be negative")
 50
 51        if self._last_df is None:
 52            raise ProgrammingError("No call to .execute() has been issued")
 53
 54        if self._last_output is None:
 55            self._last_output = _normalize_rows(self._last_df.collect())
 56
 57        if self._last_output_cursor >= len(self._last_output):
 58            return []
 59
 60        if size is None:
 61            size = len(self._last_output) - self._last_output_cursor
 62
 63        output = self._last_output[self._last_output_cursor : self._last_output_cursor + size]
 64        self._last_output_cursor += size
 65
 66        return output
 67
 68
 69class SparkSessionConnection:
 70    def __init__(self, spark: SparkSession, catalog: t.Optional[str] = None):
 71        self.spark = spark
 72        self.catalog = catalog
 73
 74    @property
 75    def _spark_major_minor(self) -> t.Tuple[int, int]:
 76        return tuple(int(x) for x in self.spark.version.split(".")[:2])  # type: ignore
 77
 78    def get_current_catalog(self) -> t.Optional[str]:
 79        if self._spark_major_minor >= (3, 4):
 80            return self.spark.catalog.currentCatalog()
 81        return self.catalog or "spark_catalog"
 82
 83    def set_current_catalog(self, catalog_name: str) -> None:
 84        if self._spark_major_minor >= (3, 4):
 85            return self.spark.catalog.setCurrentCatalog(catalog_name)
 86        current_catalog = self.get_current_catalog()
 87        if current_catalog != catalog_name:
 88            logger.warning(
 89                "Spark <3.4 does not support certain cross catalog queries since the default catalog cannot be set <3.4"
 90            )
 91
 92    def cursor(self) -> SparkSessionCursor:
 93        from pyspark.errors import PySparkAttributeError
 94
 95        try:
 96            self.spark.sparkContext.setLocalProperty("spark.scheduler.pool", f"pool_{get_ident()}")
 97            self.spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
 98            self.spark.conf.set("hive.exec.dynamic.partition", "true")
 99            self.spark.conf.set("hive.exec.dynamic.partition.mode", "nonstrict")
100        except (NotImplementedError, PySparkAttributeError):
101            # Databricks Connect does not support accessing the SparkContext nor does it support
102            # setting dynamic partition overwrite since it uses replace where
103            # Also Serverless jobs don't support access to spark context so we pass for that too
104            pass
105        if self.catalog:
106            from py4j.protocol import Py4JError
107
108            try:
109                self.set_current_catalog(self.catalog)
110            # Databricks does not support `setCurrentCatalog` with Unity catalog
111            # and shared clusters so we use the Databricks Unity only SQL command instead
112            except Py4JError:
113                self.spark.sql(f"USE CATALOG {self.catalog}")
114        return SparkSessionCursor(self.spark)
115
116    def commit(self) -> None:
117        pass
118
119    def rollback(self) -> None:
120        pass
121
122    def close(self) -> None:
123        pass
124
125
def connection(spark: SparkSession, catalog: t.Optional[str] = None) -> SparkSessionConnection:
    """Create a :class:`SparkSessionConnection` over ``spark``.

    Args:
        spark: The active Spark session to run queries on.
        catalog: Optional catalog that each new cursor switches to.
    """
    conn = SparkSessionConnection(spark, catalog=catalog)
    return conn
128
129
130def _normalize_rows(rows: t.Sequence[Row]) -> t.List[t.Tuple]:
131    return [tuple(r) for r in rows]
logger = <Logger sqlmesh.engines.spark.db_api.spark_session (WARNING)>
class SparkSessionCursor:
17class SparkSessionCursor:
18    def __init__(self, spark: SparkSession):
19        self._spark = spark
20        self._last_df: t.Optional[DataFrame] = None
21        self._last_output: t.Optional[t.List[t.Tuple]] = None
22        self._last_output_cursor: int = 0
23
24    def execute(self, query: str, parameters: t.Optional[t.Any] = None) -> None:
25        if parameters:
26            raise NotSupportedError("Parameterized queries are not supported")
27
28        self._last_df = self._spark.sql(query)
29        self._last_output = None
30        self._last_output_cursor = 0
31
32    def fetchone(self) -> t.Optional[t.Tuple]:
33        result = self._fetch(size=1)
34        return result[0] if result else None
35
36    def fetchmany(self, size: int = 1) -> t.List[t.Tuple]:
37        return self._fetch(size=size)
38
39    def fetchall(self) -> t.List[t.Tuple]:
40        return self._fetch()
41
42    def close(self) -> None:
43        pass
44
45    def fetchdf(self) -> t.Optional[DataFrame]:
46        return self._last_df
47
48    def _fetch(self, size: t.Optional[int] = None) -> t.List[t.Tuple]:
49        if size and size < 0:
50            raise ProgrammingError("The size argument can't be negative")
51
52        if self._last_df is None:
53            raise ProgrammingError("No call to .execute() has been issued")
54
55        if self._last_output is None:
56            self._last_output = _normalize_rows(self._last_df.collect())
57
58        if self._last_output_cursor >= len(self._last_output):
59            return []
60
61        if size is None:
62            size = len(self._last_output) - self._last_output_cursor
63
64        output = self._last_output[self._last_output_cursor : self._last_output_cursor + size]
65        self._last_output_cursor += size
66
67        return output
SparkSessionCursor(spark: SparkSession)
18    def __init__(self, spark: SparkSession):
19        self._spark = spark
20        self._last_df: t.Optional[DataFrame] = None
21        self._last_output: t.Optional[t.List[t.Tuple]] = None
22        self._last_output_cursor: int = 0
def execute(self, query: str, parameters: Optional[Any] = None) -> None:
24    def execute(self, query: str, parameters: t.Optional[t.Any] = None) -> None:
25        if parameters:
26            raise NotSupportedError("Parameterized queries are not supported")
27
28        self._last_df = self._spark.sql(query)
29        self._last_output = None
30        self._last_output_cursor = 0
def fetchone(self) -> Optional[Tuple]:
32    def fetchone(self) -> t.Optional[t.Tuple]:
33        result = self._fetch(size=1)
34        return result[0] if result else None
def fetchmany(self, size: int = 1) -> List[Tuple]:
36    def fetchmany(self, size: int = 1) -> t.List[t.Tuple]:
37        return self._fetch(size=size)
def fetchall(self) -> List[Tuple]:
39    def fetchall(self) -> t.List[t.Tuple]:
40        return self._fetch()
def close(self) -> None:
42    def close(self) -> None:
43        pass
def fetchdf(self) -> Optional[DataFrame]:
45    def fetchdf(self) -> t.Optional[DataFrame]:
46        return self._last_df
class SparkSessionConnection:
 70class SparkSessionConnection:
 71    def __init__(self, spark: SparkSession, catalog: t.Optional[str] = None):
 72        self.spark = spark
 73        self.catalog = catalog
 74
 75    @property
 76    def _spark_major_minor(self) -> t.Tuple[int, int]:
 77        return tuple(int(x) for x in self.spark.version.split(".")[:2])  # type: ignore
 78
 79    def get_current_catalog(self) -> t.Optional[str]:
 80        if self._spark_major_minor >= (3, 4):
 81            return self.spark.catalog.currentCatalog()
 82        return self.catalog or "spark_catalog"
 83
 84    def set_current_catalog(self, catalog_name: str) -> None:
 85        if self._spark_major_minor >= (3, 4):
 86            return self.spark.catalog.setCurrentCatalog(catalog_name)
 87        current_catalog = self.get_current_catalog()
 88        if current_catalog != catalog_name:
 89            logger.warning(
 90                "Spark <3.4 does not support certain cross catalog queries since the default catalog cannot be set <3.4"
 91            )
 92
 93    def cursor(self) -> SparkSessionCursor:
 94        from pyspark.errors import PySparkAttributeError
 95
 96        try:
 97            self.spark.sparkContext.setLocalProperty("spark.scheduler.pool", f"pool_{get_ident()}")
 98            self.spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
 99            self.spark.conf.set("hive.exec.dynamic.partition", "true")
100            self.spark.conf.set("hive.exec.dynamic.partition.mode", "nonstrict")
101        except (NotImplementedError, PySparkAttributeError):
102            # Databricks Connect does not support accessing the SparkContext nor does it support
103            # setting dynamic partition overwrite since it uses replace where
104            # Also Serverless jobs don't support access to spark context so we pass for that too
105            pass
106        if self.catalog:
107            from py4j.protocol import Py4JError
108
109            try:
110                self.set_current_catalog(self.catalog)
111            # Databricks does not support `setCurrentCatalog` with Unity catalog
112            # and shared clusters so we use the Databricks Unity only SQL command instead
113            except Py4JError:
114                self.spark.sql(f"USE CATALOG {self.catalog}")
115        return SparkSessionCursor(self.spark)
116
117    def commit(self) -> None:
118        pass
119
120    def rollback(self) -> None:
121        pass
122
123    def close(self) -> None:
124        pass
SparkSessionConnection(spark: SparkSession, catalog: Optional[str] = None)
71    def __init__(self, spark: SparkSession, catalog: t.Optional[str] = None):
72        self.spark = spark
73        self.catalog = catalog
spark
catalog
def get_current_catalog(self) -> Optional[str]:
79    def get_current_catalog(self) -> t.Optional[str]:
80        if self._spark_major_minor >= (3, 4):
81            return self.spark.catalog.currentCatalog()
82        return self.catalog or "spark_catalog"
def set_current_catalog(self, catalog_name: str) -> None:
84    def set_current_catalog(self, catalog_name: str) -> None:
85        if self._spark_major_minor >= (3, 4):
86            return self.spark.catalog.setCurrentCatalog(catalog_name)
87        current_catalog = self.get_current_catalog()
88        if current_catalog != catalog_name:
89            logger.warning(
90                "Spark <3.4 does not support certain cross catalog queries since the default catalog cannot be set <3.4"
91            )
def cursor(self) -> SparkSessionCursor:
 93    def cursor(self) -> SparkSessionCursor:
 94        from pyspark.errors import PySparkAttributeError
 95
 96        try:
 97            self.spark.sparkContext.setLocalProperty("spark.scheduler.pool", f"pool_{get_ident()}")
 98            self.spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
 99            self.spark.conf.set("hive.exec.dynamic.partition", "true")
100            self.spark.conf.set("hive.exec.dynamic.partition.mode", "nonstrict")
101        except (NotImplementedError, PySparkAttributeError):
102            # Databricks Connect does not support accessing the SparkContext nor does it support
103            # setting dynamic partition overwrite since it uses replace where
104            # Also Serverless jobs don't support access to spark context so we pass for that too
105            pass
106        if self.catalog:
107            from py4j.protocol import Py4JError
108
109            try:
110                self.set_current_catalog(self.catalog)
111            # Databricks does not support `setCurrentCatalog` with Unity catalog
112            # and shared clusters so we use the Databricks Unity only SQL command instead
113            except Py4JError:
114                self.spark.sql(f"USE CATALOG {self.catalog}")
115        return SparkSessionCursor(self.spark)
def commit(self) -> None:
117    def commit(self) -> None:
118        pass
def rollback(self) -> None:
120    def rollback(self) -> None:
121        pass
def close(self) -> None:
123    def close(self) -> None:
124        pass
def connection(spark: SparkSession, catalog: Optional[str] = None) -> SparkSessionConnection:
127def connection(spark: SparkSession, catalog: t.Optional[str] = None) -> SparkSessionConnection:
128    return SparkSessionConnection(spark, catalog)