Edit on GitHub

sqlmesh.core.engine_adapter.duckdb

  1from __future__ import annotations
  2
  3import typing as t
  4from sqlglot import exp
  5from pathlib import Path
  6
  7from sqlmesh.core.engine_adapter.mixins import (
  8    GetCurrentCatalogFromFunctionMixin,
  9    LogicalMergeMixin,
 10    RowDiffMixin,
 11)
 12from sqlmesh.core.engine_adapter.shared import (
 13    CatalogSupport,
 14    CommentCreationTable,
 15    CommentCreationView,
 16    DataObject,
 17    DataObjectType,
 18    SourceQuery,
 19    set_catalog,
 20)
 21
 22if t.TYPE_CHECKING:
 23    from sqlmesh.core._typing import SchemaName, TableName
 24    from sqlmesh.core.engine_adapter._typing import DF
 25
 26
 27@set_catalog(override_mapping={"_get_data_objects": CatalogSupport.REQUIRES_SET_CATALOG})
 28class DuckDBEngineAdapter(LogicalMergeMixin, GetCurrentCatalogFromFunctionMixin, RowDiffMixin):
 29    DIALECT = "duckdb"
 30    SUPPORTS_TRANSACTIONS = False
 31    SCHEMA_DIFFER_KWARGS = {
 32        "parameterized_type_defaults": {
 33            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 3), (0,)],
 34        },
 35    }
 36    COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY
 37    COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY
 38    SUPPORTS_CREATE_DROP_CATALOG = True
 39    SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA", "TABLE", "VIEW"]
 40
 41    @property
 42    def catalog_support(self) -> CatalogSupport:
 43        return CatalogSupport.FULL_SUPPORT
 44
 45    def set_current_catalog(self, catalog: str) -> None:
 46        """Sets the catalog name of the current connection."""
 47        self.execute(exp.Use(this=exp.to_identifier(catalog)))
 48
 49    def _create_catalog(self, catalog_name: exp.Identifier) -> None:
 50        if not self._is_motherduck:
 51            db_filename = f"{catalog_name.output_name}.db"
 52            self.execute(
 53                exp.Attach(
 54                    this=exp.alias_(exp.Literal.string(db_filename), catalog_name), exists=True
 55                )
 56            )
 57        else:
 58            self.execute(
 59                exp.Create(this=exp.Table(this=catalog_name), kind="DATABASE", exists=True)
 60            )
 61
 62    def _drop_catalog(self, catalog_name: exp.Identifier) -> None:
 63        if not self._is_motherduck:
 64            db_file_path = Path(f"{catalog_name.output_name}.db")
 65            self.execute(exp.Detach(this=catalog_name, exists=True))
 66            if db_file_path.exists():
 67                db_file_path.unlink()
 68        else:
 69            self.execute(
 70                exp.Drop(
 71                    this=exp.Table(this=catalog_name), kind="DATABASE", cascade=True, exists=True
 72                )
 73            )
 74
 75    def _df_to_source_queries(
 76        self,
 77        df: DF,
 78        target_columns_to_types: t.Dict[str, exp.DataType],
 79        batch_size: int,
 80        target_table: TableName,
 81        source_columns: t.Optional[t.List[str]] = None,
 82    ) -> t.List[SourceQuery]:
 83        temp_table = self._get_temp_table(target_table)
 84        temp_table_sql = (
 85            exp.select(*self._casted_columns(target_columns_to_types, source_columns))
 86            .from_("df")
 87            .sql(dialect=self.dialect)
 88        )
 89        self.cursor.sql(f"CREATE TABLE {temp_table} AS {temp_table_sql}")
 90        return [
 91            SourceQuery(
 92                query_factory=lambda: self._select_columns(target_columns_to_types).from_(
 93                    temp_table
 94                ),  # type: ignore
 95                cleanup_func=lambda: self.drop_table(temp_table),
 96            )
 97        ]
 98
 99    def _get_data_objects(
100        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
101    ) -> t.List[DataObject]:
102        """
103        Returns all the data objects that exist in the given schema and optionally catalog.
104        """
105        catalog = self.get_current_catalog()
106
107        if isinstance(schema_name, exp.Table):
108            # Ensures we don't generate identifier quotes
109            schema_name = ".".join(part.name for part in schema_name.parts)
110
111        query = (
112            exp.select(
113                exp.column("table_name").as_("name"),
114                exp.column("table_schema").as_("schema"),
115                exp.case(exp.column("table_type"))
116                .when(
117                    exp.Literal.string("BASE TABLE"),
118                    exp.Literal.string("table"),
119                )
120                .when(
121                    exp.Literal.string("VIEW"),
122                    exp.Literal.string("view"),
123                )
124                .when(
125                    exp.Literal.string("LOCAL TEMPORARY"),
126                    exp.Literal.string("table"),
127                )
128                .as_("type"),
129            )
130            .from_(exp.to_table("system.information_schema.tables"))
131            .where(
132                exp.column("table_catalog").eq(catalog), exp.column("table_schema").eq(schema_name)
133            )
134        )
135        if object_names:
136            query = query.where(exp.column("table_name").isin(*object_names))
137        df = self.fetchdf(query)
138        return [
139            DataObject(
140                catalog=catalog,  # type: ignore
141                schema=row.schema,  # type: ignore
142                name=row.name,  # type: ignore
143                type=DataObjectType.from_str(row.type),  # type: ignore
144            )
145            for row in df.itertuples()
146        ]
147
148    def _normalize_decimal_value(self, col: exp.Expr, precision: int) -> exp.Expr:
149        """
150        duckdb truncates instead of rounding when casting to decimal.
151
152        other databases: select cast(3.14159 as decimal(38,3)) -> 3.142
153        duckdb: select cast(3.14159 as decimal(38,3)) -> 3.141
154
155        however, we can get the behaviour of other databases by casting to double first:
156        select cast(cast(3.14159 as double) as decimal(38, 3)) -> 3.142
157        """
158        return exp.cast(
159            exp.cast(col, "DOUBLE"),
160            f"DECIMAL(38, {precision})",
161        )
162
163    def _create_table(
164        self,
165        table_name_or_schema: t.Union[exp.Schema, TableName],
166        expression: t.Optional[exp.Expr],
167        exists: bool = True,
168        replace: bool = False,
169        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
170        table_description: t.Optional[str] = None,
171        column_descriptions: t.Optional[t.Dict[str, str]] = None,
172        table_kind: t.Optional[str] = None,
173        track_rows_processed: bool = True,
174        **kwargs: t.Any,
175    ) -> None:
176        catalog = self.get_current_catalog()
177        catalog_type_tuple = self.fetchone(
178            exp.select("type")
179            .from_("duckdb_databases()")
180            .where(exp.column("database_name").eq(catalog))
181        )
182        catalog_type = catalog_type_tuple[0] if catalog_type_tuple else None
183
184        partitioned_by_exps = None
185        if catalog_type == "ducklake":
186            partitioned_by_exps = kwargs.pop("partitioned_by", None)
187
188        super()._create_table(
189            table_name_or_schema,
190            expression,
191            exists,
192            replace,
193            target_columns_to_types,
194            table_description,
195            column_descriptions,
196            table_kind,
197            track_rows_processed=track_rows_processed,
198            **kwargs,
199        )
200
201        if partitioned_by_exps:
202            # Schema object contains column definitions, so we extract Table
203            table_name = (
204                table_name_or_schema.this
205                if isinstance(table_name_or_schema, exp.Schema)
206                else table_name_or_schema
207            )
208            table_name_str = (
209                table_name.sql(dialect=self.dialect)
210                if isinstance(table_name, exp.Table)
211                else table_name
212            )
213            partitioned_by_str = ", ".join(
214                expr.sql(dialect=self.dialect) for expr in partitioned_by_exps
215            )
216            self.execute(f"ALTER TABLE {table_name_str} SET PARTITIONED BY ({partitioned_by_str});")
217
218    @property
219    def _is_motherduck(self) -> bool:
220        return self._extra_config.get("is_motherduck", False)
@set_catalog(override_mapping={'_get_data_objects': CatalogSupport.REQUIRES_SET_CATALOG})
class DuckDBEngineAdapter(sqlmesh.core.engine_adapter.mixins.LogicalMergeMixin, sqlmesh.core.engine_adapter.mixins.GetCurrentCatalogFromFunctionMixin, sqlmesh.core.engine_adapter.mixins.RowDiffMixin):
 28@set_catalog(override_mapping={"_get_data_objects": CatalogSupport.REQUIRES_SET_CATALOG})
 29class DuckDBEngineAdapter(LogicalMergeMixin, GetCurrentCatalogFromFunctionMixin, RowDiffMixin):
 30    DIALECT = "duckdb"
 31    SUPPORTS_TRANSACTIONS = False
 32    SCHEMA_DIFFER_KWARGS = {
 33        "parameterized_type_defaults": {
 34            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 3), (0,)],
 35        },
 36    }
 37    COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY
 38    COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY
 39    SUPPORTS_CREATE_DROP_CATALOG = True
 40    SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA", "TABLE", "VIEW"]
 41
 42    @property
 43    def catalog_support(self) -> CatalogSupport:
 44        return CatalogSupport.FULL_SUPPORT
 45
 46    def set_current_catalog(self, catalog: str) -> None:
 47        """Sets the catalog name of the current connection."""
 48        self.execute(exp.Use(this=exp.to_identifier(catalog)))
 49
 50    def _create_catalog(self, catalog_name: exp.Identifier) -> None:
 51        if not self._is_motherduck:
 52            db_filename = f"{catalog_name.output_name}.db"
 53            self.execute(
 54                exp.Attach(
 55                    this=exp.alias_(exp.Literal.string(db_filename), catalog_name), exists=True
 56                )
 57            )
 58        else:
 59            self.execute(
 60                exp.Create(this=exp.Table(this=catalog_name), kind="DATABASE", exists=True)
 61            )
 62
 63    def _drop_catalog(self, catalog_name: exp.Identifier) -> None:
 64        if not self._is_motherduck:
 65            db_file_path = Path(f"{catalog_name.output_name}.db")
 66            self.execute(exp.Detach(this=catalog_name, exists=True))
 67            if db_file_path.exists():
 68                db_file_path.unlink()
 69        else:
 70            self.execute(
 71                exp.Drop(
 72                    this=exp.Table(this=catalog_name), kind="DATABASE", cascade=True, exists=True
 73                )
 74            )
 75
 76    def _df_to_source_queries(
 77        self,
 78        df: DF,
 79        target_columns_to_types: t.Dict[str, exp.DataType],
 80        batch_size: int,
 81        target_table: TableName,
 82        source_columns: t.Optional[t.List[str]] = None,
 83    ) -> t.List[SourceQuery]:
 84        temp_table = self._get_temp_table(target_table)
 85        temp_table_sql = (
 86            exp.select(*self._casted_columns(target_columns_to_types, source_columns))
 87            .from_("df")
 88            .sql(dialect=self.dialect)
 89        )
 90        self.cursor.sql(f"CREATE TABLE {temp_table} AS {temp_table_sql}")
 91        return [
 92            SourceQuery(
 93                query_factory=lambda: self._select_columns(target_columns_to_types).from_(
 94                    temp_table
 95                ),  # type: ignore
 96                cleanup_func=lambda: self.drop_table(temp_table),
 97            )
 98        ]
 99
100    def _get_data_objects(
101        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
102    ) -> t.List[DataObject]:
103        """
104        Returns all the data objects that exist in the given schema and optionally catalog.
105        """
106        catalog = self.get_current_catalog()
107
108        if isinstance(schema_name, exp.Table):
109            # Ensures we don't generate identifier quotes
110            schema_name = ".".join(part.name for part in schema_name.parts)
111
112        query = (
113            exp.select(
114                exp.column("table_name").as_("name"),
115                exp.column("table_schema").as_("schema"),
116                exp.case(exp.column("table_type"))
117                .when(
118                    exp.Literal.string("BASE TABLE"),
119                    exp.Literal.string("table"),
120                )
121                .when(
122                    exp.Literal.string("VIEW"),
123                    exp.Literal.string("view"),
124                )
125                .when(
126                    exp.Literal.string("LOCAL TEMPORARY"),
127                    exp.Literal.string("table"),
128                )
129                .as_("type"),
130            )
131            .from_(exp.to_table("system.information_schema.tables"))
132            .where(
133                exp.column("table_catalog").eq(catalog), exp.column("table_schema").eq(schema_name)
134            )
135        )
136        if object_names:
137            query = query.where(exp.column("table_name").isin(*object_names))
138        df = self.fetchdf(query)
139        return [
140            DataObject(
141                catalog=catalog,  # type: ignore
142                schema=row.schema,  # type: ignore
143                name=row.name,  # type: ignore
144                type=DataObjectType.from_str(row.type),  # type: ignore
145            )
146            for row in df.itertuples()
147        ]
148
149    def _normalize_decimal_value(self, col: exp.Expr, precision: int) -> exp.Expr:
150        """
151        duckdb truncates instead of rounding when casting to decimal.
152
153        other databases: select cast(3.14159 as decimal(38,3)) -> 3.142
154        duckdb: select cast(3.14159 as decimal(38,3)) -> 3.141
155
156        however, we can get the behaviour of other databases by casting to double first:
157        select cast(cast(3.14159 as double) as decimal(38, 3)) -> 3.142
158        """
159        return exp.cast(
160            exp.cast(col, "DOUBLE"),
161            f"DECIMAL(38, {precision})",
162        )
163
164    def _create_table(
165        self,
166        table_name_or_schema: t.Union[exp.Schema, TableName],
167        expression: t.Optional[exp.Expr],
168        exists: bool = True,
169        replace: bool = False,
170        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
171        table_description: t.Optional[str] = None,
172        column_descriptions: t.Optional[t.Dict[str, str]] = None,
173        table_kind: t.Optional[str] = None,
174        track_rows_processed: bool = True,
175        **kwargs: t.Any,
176    ) -> None:
177        catalog = self.get_current_catalog()
178        catalog_type_tuple = self.fetchone(
179            exp.select("type")
180            .from_("duckdb_databases()")
181            .where(exp.column("database_name").eq(catalog))
182        )
183        catalog_type = catalog_type_tuple[0] if catalog_type_tuple else None
184
185        partitioned_by_exps = None
186        if catalog_type == "ducklake":
187            partitioned_by_exps = kwargs.pop("partitioned_by", None)
188
189        super()._create_table(
190            table_name_or_schema,
191            expression,
192            exists,
193            replace,
194            target_columns_to_types,
195            table_description,
196            column_descriptions,
197            table_kind,
198            track_rows_processed=track_rows_processed,
199            **kwargs,
200        )
201
202        if partitioned_by_exps:
203            # Schema object contains column definitions, so we extract Table
204            table_name = (
205                table_name_or_schema.this
206                if isinstance(table_name_or_schema, exp.Schema)
207                else table_name_or_schema
208            )
209            table_name_str = (
210                table_name.sql(dialect=self.dialect)
211                if isinstance(table_name, exp.Table)
212                else table_name
213            )
214            partitioned_by_str = ", ".join(
215                expr.sql(dialect=self.dialect) for expr in partitioned_by_exps
216            )
217            self.execute(f"ALTER TABLE {table_name_str} SET PARTITIONED BY ({partitioned_by_str});")
218
219    @property
220    def _is_motherduck(self) -> bool:
221        return self._extra_config.get("is_motherduck", False)

Base class wrapping a Database API-compliant connection.

The EngineAdapter is an easily subclassable interface that interacts with the underlying engine and data store.

Arguments:
  • connection_factory_or_pool: a callable which produces a new Database API-compliant connection on every call.
  • dialect: The dialect with which this adapter is associated.
  • multithreaded: Indicates whether this adapter will be used by more than one thread.
DIALECT = 'duckdb'
SUPPORTS_TRANSACTIONS = False
SCHEMA_DIFFER_KWARGS = {'parameterized_type_defaults': {<DType.DECIMAL: 'DECIMAL'>: [(18, 3), (0,)]}}
COMMENT_CREATION_TABLE = <CommentCreationTable.COMMENT_COMMAND_ONLY: 4>
COMMENT_CREATION_VIEW = <CommentCreationView.COMMENT_COMMAND_ONLY: 4>
SUPPORTS_CREATE_DROP_CATALOG = True
SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ['SCHEMA', 'TABLE', 'VIEW']
@property
def catalog_support(self) -> CatalogSupport:
    """Report that this adapter offers full catalog support."""
    full_support = CatalogSupport.FULL_SUPPORT
    return full_support
def set_current_catalog(self, catalog: str) -> None:
def set_current_catalog(self, catalog: str) -> None:
    """Switch the connection's active catalog by executing ``USE <catalog>``."""
    catalog_identifier = exp.to_identifier(catalog)
    use_statement = exp.Use(this=catalog_identifier)
    self.execute(use_statement)

Sets the catalog name of the current connection.

def merge( self, target_table: Union[str, sqlglot.expressions.query.Table], source_table: QueryOrDF, target_columns_to_types: Optional[Dict[str, sqlglot.expressions.datatypes.DataType]], unique_key: Sequence[sqlglot.expressions.core.Expr], when_matched: Optional[sqlglot.expressions.dml.Whens] = None, merge_filter: Optional[sqlglot.expressions.core.Expr] = None, source_columns: Optional[List[str]] = None, **kwargs: Any) -> None:
def merge(
    self,
    target_table: TableName,
    source_table: QueryOrDF,
    target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
    unique_key: t.Sequence[exp.Expr],
    when_matched: t.Optional[exp.Whens] = None,
    merge_filter: t.Optional[exp.Expr] = None,
    source_columns: t.Optional[t.List[str]] = None,
    **kwargs: t.Any,
) -> None:
    """Merge ``source_table`` into ``target_table`` by delegating to ``logical_merge``.

    Extra ``kwargs`` are accepted but not forwarded to ``logical_merge``.
    """
    optional_arguments = {
        "when_matched": when_matched,
        "merge_filter": merge_filter,
        "source_columns": source_columns,
    }
    logical_merge(
        self, target_table, source_table, target_columns_to_types, unique_key, **optional_arguments
    )
Inherited Members
sqlmesh.core.engine_adapter.base.EngineAdapter
EngineAdapter
DEFAULT_BATCH_SIZE
DATA_OBJECT_FILTER_BATCH_SIZE
SUPPORTS_INDEXES
MAX_TABLE_COMMENT_LENGTH
MAX_COLUMN_COMMENT_LENGTH
INSERT_OVERWRITE_STRATEGY
SUPPORTS_MATERIALIZED_VIEWS
SUPPORTS_MATERIALIZED_VIEW_SCHEMA
SUPPORTS_VIEW_SCHEMA
SUPPORTS_CLONING
SUPPORTS_MANAGED_MODELS
SUPPORTS_TUPLE_IN
HAS_VIEW_BINDING
SUPPORTS_REPLACE_TABLE
SUPPORTS_GRANTS
DEFAULT_CATALOG_TYPE
QUOTE_IDENTIFIERS_IN_VIEWS
MAX_IDENTIFIER_LENGTH
ATTACH_CORRELATION_ID
SUPPORTS_QUERY_EXECUTION_TRACKING
SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS
dialect
correlation_id
with_settings
cursor
connection
spark
snowpark
bigframe
comments_enabled
schema_differ
default_catalog
engine_run_mode
recycle
close
get_catalog_type
get_catalog_type_from_table
current_catalog_type
replace_query
create_index
create_table
create_managed_table
ctas
create_state_table
create_table_like
clone_table
drop_data_object
drop_table
drop_managed_table
get_alter_operations
alter_table
create_view
create_schema
drop_schema
drop_view
create_catalog
drop_catalog
columns
table_exists
delete_from
insert_append
insert_overwrite_by_partition
insert_overwrite_by_time_partition
update_table
scd_type_2_by_time
scd_type_2_by_column
rename_table
get_data_object
get_data_objects
fetchone
fetchall
fetchdf
fetch_pyspark_df
wap_enabled
wap_supported
wap_table_name
wap_prepare
wap_publish
sync_grants_config
transaction
session
execute
temp_table
drop_data_object_on_type_mismatch
ensure_nulls_for_unmatched_after_join
use_server_nulls_for_unmatched_after_join
ping
get_table_last_modified_ts
sqlmesh.core.engine_adapter.mixins.GetCurrentCatalogFromFunctionMixin
CURRENT_CATALOG_EXPRESSION
get_current_catalog
sqlmesh.core.engine_adapter.mixins.RowDiffMixin
MAX_TIMESTAMP_PRECISION
concat_columns
normalize_value