sqlmesh.core.engine_adapter.duckdb
1from __future__ import annotations 2 3import typing as t 4from sqlglot import exp 5from pathlib import Path 6 7from sqlmesh.core.engine_adapter.mixins import ( 8 GetCurrentCatalogFromFunctionMixin, 9 LogicalMergeMixin, 10 RowDiffMixin, 11) 12from sqlmesh.core.engine_adapter.shared import ( 13 CatalogSupport, 14 CommentCreationTable, 15 CommentCreationView, 16 DataObject, 17 DataObjectType, 18 SourceQuery, 19 set_catalog, 20) 21 22if t.TYPE_CHECKING: 23 from sqlmesh.core._typing import SchemaName, TableName 24 from sqlmesh.core.engine_adapter._typing import DF 25 26 27@set_catalog(override_mapping={"_get_data_objects": CatalogSupport.REQUIRES_SET_CATALOG}) 28class DuckDBEngineAdapter(LogicalMergeMixin, GetCurrentCatalogFromFunctionMixin, RowDiffMixin): 29 DIALECT = "duckdb" 30 SUPPORTS_TRANSACTIONS = False 31 SCHEMA_DIFFER_KWARGS = { 32 "parameterized_type_defaults": { 33 exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 3), (0,)], 34 }, 35 } 36 COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY 37 COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY 38 SUPPORTS_CREATE_DROP_CATALOG = True 39 SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA", "TABLE", "VIEW"] 40 41 @property 42 def catalog_support(self) -> CatalogSupport: 43 return CatalogSupport.FULL_SUPPORT 44 45 def set_current_catalog(self, catalog: str) -> None: 46 """Sets the catalog name of the current connection.""" 47 self.execute(exp.Use(this=exp.to_identifier(catalog))) 48 49 def _create_catalog(self, catalog_name: exp.Identifier) -> None: 50 if not self._is_motherduck: 51 db_filename = f"{catalog_name.output_name}.db" 52 self.execute( 53 exp.Attach( 54 this=exp.alias_(exp.Literal.string(db_filename), catalog_name), exists=True 55 ) 56 ) 57 else: 58 self.execute( 59 exp.Create(this=exp.Table(this=catalog_name), kind="DATABASE", exists=True) 60 ) 61 62 def _drop_catalog(self, catalog_name: exp.Identifier) -> None: 63 if not self._is_motherduck: 64 db_file_path = 
Path(f"{catalog_name.output_name}.db") 65 self.execute(exp.Detach(this=catalog_name, exists=True)) 66 if db_file_path.exists(): 67 db_file_path.unlink() 68 else: 69 self.execute( 70 exp.Drop( 71 this=exp.Table(this=catalog_name), kind="DATABASE", cascade=True, exists=True 72 ) 73 ) 74 75 def _df_to_source_queries( 76 self, 77 df: DF, 78 target_columns_to_types: t.Dict[str, exp.DataType], 79 batch_size: int, 80 target_table: TableName, 81 source_columns: t.Optional[t.List[str]] = None, 82 ) -> t.List[SourceQuery]: 83 temp_table = self._get_temp_table(target_table) 84 temp_table_sql = ( 85 exp.select(*self._casted_columns(target_columns_to_types, source_columns)) 86 .from_("df") 87 .sql(dialect=self.dialect) 88 ) 89 self.cursor.sql(f"CREATE TABLE {temp_table} AS {temp_table_sql}") 90 return [ 91 SourceQuery( 92 query_factory=lambda: self._select_columns(target_columns_to_types).from_( 93 temp_table 94 ), # type: ignore 95 cleanup_func=lambda: self.drop_table(temp_table), 96 ) 97 ] 98 99 def _get_data_objects( 100 self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None 101 ) -> t.List[DataObject]: 102 """ 103 Returns all the data objects that exist in the given schema and optionally catalog. 
104 """ 105 catalog = self.get_current_catalog() 106 107 if isinstance(schema_name, exp.Table): 108 # Ensures we don't generate identifier quotes 109 schema_name = ".".join(part.name for part in schema_name.parts) 110 111 query = ( 112 exp.select( 113 exp.column("table_name").as_("name"), 114 exp.column("table_schema").as_("schema"), 115 exp.case(exp.column("table_type")) 116 .when( 117 exp.Literal.string("BASE TABLE"), 118 exp.Literal.string("table"), 119 ) 120 .when( 121 exp.Literal.string("VIEW"), 122 exp.Literal.string("view"), 123 ) 124 .when( 125 exp.Literal.string("LOCAL TEMPORARY"), 126 exp.Literal.string("table"), 127 ) 128 .as_("type"), 129 ) 130 .from_(exp.to_table("system.information_schema.tables")) 131 .where( 132 exp.column("table_catalog").eq(catalog), exp.column("table_schema").eq(schema_name) 133 ) 134 ) 135 if object_names: 136 query = query.where(exp.column("table_name").isin(*object_names)) 137 df = self.fetchdf(query) 138 return [ 139 DataObject( 140 catalog=catalog, # type: ignore 141 schema=row.schema, # type: ignore 142 name=row.name, # type: ignore 143 type=DataObjectType.from_str(row.type), # type: ignore 144 ) 145 for row in df.itertuples() 146 ] 147 148 def _normalize_decimal_value(self, col: exp.Expr, precision: int) -> exp.Expr: 149 """ 150 duckdb truncates instead of rounding when casting to decimal. 
151 152 other databases: select cast(3.14159 as decimal(38,3)) -> 3.142 153 duckdb: select cast(3.14159 as decimal(38,3)) -> 3.141 154 155 however, we can get the behaviour of other databases by casting to double first: 156 select cast(cast(3.14159 as double) as decimal(38, 3)) -> 3.142 157 """ 158 return exp.cast( 159 exp.cast(col, "DOUBLE"), 160 f"DECIMAL(38, {precision})", 161 ) 162 163 def _create_table( 164 self, 165 table_name_or_schema: t.Union[exp.Schema, TableName], 166 expression: t.Optional[exp.Expr], 167 exists: bool = True, 168 replace: bool = False, 169 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 170 table_description: t.Optional[str] = None, 171 column_descriptions: t.Optional[t.Dict[str, str]] = None, 172 table_kind: t.Optional[str] = None, 173 track_rows_processed: bool = True, 174 **kwargs: t.Any, 175 ) -> None: 176 catalog = self.get_current_catalog() 177 catalog_type_tuple = self.fetchone( 178 exp.select("type") 179 .from_("duckdb_databases()") 180 .where(exp.column("database_name").eq(catalog)) 181 ) 182 catalog_type = catalog_type_tuple[0] if catalog_type_tuple else None 183 184 partitioned_by_exps = None 185 if catalog_type == "ducklake": 186 partitioned_by_exps = kwargs.pop("partitioned_by", None) 187 188 super()._create_table( 189 table_name_or_schema, 190 expression, 191 exists, 192 replace, 193 target_columns_to_types, 194 table_description, 195 column_descriptions, 196 table_kind, 197 track_rows_processed=track_rows_processed, 198 **kwargs, 199 ) 200 201 if partitioned_by_exps: 202 # Schema object contains column definitions, so we extract Table 203 table_name = ( 204 table_name_or_schema.this 205 if isinstance(table_name_or_schema, exp.Schema) 206 else table_name_or_schema 207 ) 208 table_name_str = ( 209 table_name.sql(dialect=self.dialect) 210 if isinstance(table_name, exp.Table) 211 else table_name 212 ) 213 partitioned_by_str = ", ".join( 214 expr.sql(dialect=self.dialect) for expr in partitioned_by_exps 
215 ) 216 self.execute(f"ALTER TABLE {table_name_str} SET PARTITIONED BY ({partitioned_by_str});") 217 218 @property 219 def _is_motherduck(self) -> bool: 220 return self._extra_config.get("is_motherduck", False)
@set_catalog(override_mapping={"_get_data_objects": CatalogSupport.REQUIRES_SET_CATALOG})
class DuckDBEngineAdapter(LogicalMergeMixin, GetCurrentCatalogFromFunctionMixin, RowDiffMixin):
    """Engine adapter that targets the DuckDB dialect.

    Catalog creation and removal switch between attaching/detaching a local
    database file and CREATE/DROP DATABASE statements, depending on the
    ``is_motherduck`` entry of the adapter's extra config (see ``_is_motherduck``).
    """

    DIALECT = "duckdb"
    SUPPORTS_TRANSACTIONS = False
    # Defaults keyed by the DECIMAL type and handed to the schema differ.
    # NOTE(review): presumably (18, 3) is the implicit precision/scale of a bare DECIMAL
    # and (0,) the implicit scale of DECIMAL(p) -- confirm against the schema differ.
    SCHEMA_DIFFER_KWARGS = {
        "parameterized_type_defaults": {
            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 3), (0,)],
        },
    }
    COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY
    COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY
    SUPPORTS_CREATE_DROP_CATALOG = True
    SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA", "TABLE", "VIEW"]

    @property
    def catalog_support(self) -> CatalogSupport:
        """Level of catalog support reported by this adapter (always ``FULL_SUPPORT``)."""
        return CatalogSupport.FULL_SUPPORT

    def set_current_catalog(self, catalog: str) -> None:
        """Sets the catalog name of the current connection."""
        self.execute(exp.Use(this=exp.to_identifier(catalog)))

    def _create_catalog(self, catalog_name: exp.Identifier) -> None:
        """Create the catalog ``catalog_name`` if it does not exist yet.

        Without MotherDuck the catalog is an attached database file named
        ``<catalog_name>.db``; the file name carries no directory component.
        With MotherDuck a ``CREATE DATABASE`` expression is executed instead.

        Args:
            catalog_name: Identifier of the catalog to create.
        """
        if not self._is_motherduck:
            db_filename = f"{catalog_name.output_name}.db"
            self.execute(
                exp.Attach(
                    this=exp.alias_(exp.Literal.string(db_filename), catalog_name), exists=True
                )
            )
        else:
            self.execute(
                exp.Create(this=exp.Table(this=catalog_name), kind="DATABASE", exists=True)
            )

    def _drop_catalog(self, catalog_name: exp.Identifier) -> None:
        """Drop the catalog ``catalog_name`` if it exists.

        Without MotherDuck the catalog is detached and its backing file
        ``<catalog_name>.db`` is deleted when present. With MotherDuck a
        cascading ``DROP DATABASE`` expression is executed instead.

        Args:
            catalog_name: Identifier of the catalog to drop.
        """
        if not self._is_motherduck:
            db_file_path = Path(f"{catalog_name.output_name}.db")
            self.execute(exp.Detach(this=catalog_name, exists=True))
            # Attempt the removal directly instead of checking exists() first, so a file
            # that disappears between the check and the unlink cannot raise.
            try:
                db_file_path.unlink()
            except FileNotFoundError:
                # No backing file to remove; nothing left to clean up.
                pass
        else:
            self.execute(
                exp.Drop(
                    this=exp.Table(this=catalog_name), kind="DATABASE", cascade=True, exists=True
                )
            )

    def _df_to_source_queries(
        self,
        df: DF,
        target_columns_to_types: t.Dict[str, exp.DataType],
        batch_size: int,
        target_table: TableName,
        source_columns: t.Optional[t.List[str]] = None,
    ) -> t.List[SourceQuery]:
        """Materialize ``df`` into a temporary table and return a query reading from it.

        Args:
            df: The DataFrame to load. The generated SQL selects ``FROM df``.
                NOTE(review): presumably DuckDB resolves ``df`` against this local
                variable, so the parameter name must not change -- confirm.
            target_columns_to_types: Column names and types used for the casted select
                list and for the final column selection.
            batch_size: Not used by this implementation.
            target_table: Table passed to ``_get_temp_table`` to obtain the temp table.
            source_columns: Optional source column names forwarded to ``_casted_columns``.

        Returns:
            A single-element list whose ``SourceQuery`` selects the target columns from
            the temporary table and drops that table in its cleanup function.
        """
        temp_table = self._get_temp_table(target_table)
        temp_table_sql = (
            exp.select(*self._casted_columns(target_columns_to_types, source_columns))
            .from_("df")
            .sql(dialect=self.dialect)
        )
        self.cursor.sql(f"CREATE TABLE {temp_table} AS {temp_table_sql}")
        return [
            SourceQuery(
                query_factory=lambda: self._select_columns(target_columns_to_types).from_(
                    temp_table
                ),  # type: ignore
                cleanup_func=lambda: self.drop_table(temp_table),
            )
        ]

    def _get_data_objects(
        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
    ) -> t.List[DataObject]:
        """
        Returns all the data objects that exist in the given schema and optionally catalog.

        The lookup is restricted to the current catalog and reads from
        ``system.information_schema.tables``.

        Args:
            schema_name: Schema to inspect, as a string or a table expression.
            object_names: Optional set of table names to restrict the result to.
        """
        catalog = self.get_current_catalog()

        if isinstance(schema_name, exp.Table):
            # Ensures we don't generate identifier quotes
            schema_name = ".".join(part.name for part in schema_name.parts)

        # NOTE(review): the CASE has no ELSE branch, so any table_type other than the
        # three listed below produces a NULL "type" value.
        query = (
            exp.select(
                exp.column("table_name").as_("name"),
                exp.column("table_schema").as_("schema"),
                exp.case(exp.column("table_type"))
                .when(
                    exp.Literal.string("BASE TABLE"),
                    exp.Literal.string("table"),
                )
                .when(
                    exp.Literal.string("VIEW"),
                    exp.Literal.string("view"),
                )
                .when(
                    exp.Literal.string("LOCAL TEMPORARY"),
                    exp.Literal.string("table"),
                )
                .as_("type"),
            )
            .from_(exp.to_table("system.information_schema.tables"))
            .where(
                exp.column("table_catalog").eq(catalog), exp.column("table_schema").eq(schema_name)
            )
        )
        if object_names:
            query = query.where(exp.column("table_name").isin(*object_names))
        df = self.fetchdf(query)
        return [
            DataObject(
                catalog=catalog,  # type: ignore
                schema=row.schema,  # type: ignore
                name=row.name,  # type: ignore
                type=DataObjectType.from_str(row.type),  # type: ignore
            )
            for row in df.itertuples()
        ]

    def _normalize_decimal_value(self, col: exp.Expr, precision: int) -> exp.Expr:
        """
        duckdb truncates instead of rounding when casting to decimal.

        other databases: select cast(3.14159 as decimal(38,3)) -> 3.142
        duckdb: select cast(3.14159 as decimal(38,3)) -> 3.141

        however, we can get the behaviour of other databases by casting to double first:
        select cast(cast(3.14159 as double) as decimal(38, 3)) -> 3.142
        """
        return exp.cast(
            exp.cast(col, "DOUBLE"),
            f"DECIMAL(38, {precision})",
        )

    def _create_table(
        self,
        table_name_or_schema: t.Union[exp.Schema, TableName],
        expression: t.Optional[exp.Expr],
        exists: bool = True,
        replace: bool = False,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        table_description: t.Optional[str] = None,
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        table_kind: t.Optional[str] = None,
        track_rows_processed: bool = True,
        **kwargs: t.Any,
    ) -> None:
        """Create a table, applying ``partitioned_by`` separately for ducklake catalogs.

        The type of the current catalog is read from ``duckdb_databases()``. When it is
        ``"ducklake"``, ``partitioned_by`` is removed from ``kwargs`` before delegating to
        the parent implementation and applied afterwards with
        ``ALTER TABLE ... SET PARTITIONED BY``. For every other catalog type ``kwargs`` is
        forwarded unchanged.

        NOTE(review): the catalog type is looked up for the *current* catalog, not for a
        catalog that may be part of ``table_name_or_schema`` -- confirm this is intended.
        """
        catalog = self.get_current_catalog()
        catalog_type_tuple = self.fetchone(
            exp.select("type")
            .from_("duckdb_databases()")
            .where(exp.column("database_name").eq(catalog))
        )
        catalog_type = catalog_type_tuple[0] if catalog_type_tuple else None

        partitioned_by_exps = None
        if catalog_type == "ducklake":
            partitioned_by_exps = kwargs.pop("partitioned_by", None)

        super()._create_table(
            table_name_or_schema,
            expression,
            exists,
            replace,
            target_columns_to_types,
            table_description,
            column_descriptions,
            table_kind,
            track_rows_processed=track_rows_processed,
            **kwargs,
        )

        if partitioned_by_exps:
            # Schema object contains column definitions, so we extract Table
            table_name = (
                table_name_or_schema.this
                if isinstance(table_name_or_schema, exp.Schema)
                else table_name_or_schema
            )
            # Anything that is not a Table expression is interpolated into the SQL as-is.
            table_name_str = (
                table_name.sql(dialect=self.dialect)
                if isinstance(table_name, exp.Table)
                else table_name
            )
            partitioned_by_str = ", ".join(
                expr.sql(dialect=self.dialect) for expr in partitioned_by_exps
            )
            self.execute(f"ALTER TABLE {table_name_str} SET PARTITIONED BY ({partitioned_by_str});")

    @property
    def _is_motherduck(self) -> bool:
        """Value of the ``is_motherduck`` extra-config entry, ``False`` when it is absent."""
        return self._extra_config.get("is_motherduck", False)
Base class wrapping a Database API compliant connection.
The EngineAdapter is an easily-subclassable interface that interacts with the underlying engine and data store.
Arguments:
- connection_factory_or_pool: a callable which produces a new Database API-compliant connection on every call.
- dialect: The dialect with which this adapter is associated.
- multithreaded: Indicates whether this adapter will be used by more than one thread.
SCHEMA_DIFFER_KWARGS =
{'parameterized_type_defaults': {<DType.DECIMAL: 'DECIMAL'>: [(18, 3), (0,)]}}
catalog_support: sqlmesh.core.engine_adapter.shared.CatalogSupport
def
set_current_catalog(self, catalog: str) -> None:
def set_current_catalog(self, catalog: str) -> None:
    """Sets the catalog name of the current connection.

    Args:
        catalog: Name of the catalog to switch to. It is converted to an
            identifier and executed as a ``USE`` expression via ``self.execute``.
    """
    self.execute(exp.Use(this=exp.to_identifier(catalog)))
Sets the catalog name of the current connection.
def
merge( self, target_table: Union[str, sqlglot.expressions.query.Table], source_table: QueryOrDF, target_columns_to_types: Optional[Dict[str, sqlglot.expressions.datatypes.DataType]], unique_key: Sequence[sqlglot.expressions.core.Expr], when_matched: Optional[sqlglot.expressions.dml.Whens] = None, merge_filter: Optional[sqlglot.expressions.core.Expr] = None, source_columns: Optional[List[str]] = None, **kwargs: Any) -> None:
def merge(
    self,
    target_table: TableName,
    source_table: QueryOrDF,
    target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
    unique_key: t.Sequence[exp.Expr],
    when_matched: t.Optional[exp.Whens] = None,
    merge_filter: t.Optional[exp.Expr] = None,
    source_columns: t.Optional[t.List[str]] = None,
    **kwargs: t.Any,
) -> None:
    """Merge ``source_table`` into ``target_table`` by delegating to ``logical_merge``.

    Args:
        target_table: Table that receives the merged rows.
        source_table: Query or DataFrame providing the rows to merge.
        target_columns_to_types: Optional mapping of target column names to types.
        unique_key: Expressions identifying a row for matching purposes.
        when_matched: Optional ``WHEN MATCHED`` clauses, forwarded unchanged.
        merge_filter: Optional filter expression, forwarded unchanged.
        source_columns: Optional source column names, forwarded unchanged.
        **kwargs: Accepted for signature compatibility; not forwarded to
            ``logical_merge``.
    """
    logical_merge(
        self,
        target_table,
        source_table,
        target_columns_to_types,
        unique_key,
        when_matched=when_matched,
        merge_filter=merge_filter,
        source_columns=source_columns,
    )
Inherited Members
- sqlmesh.core.engine_adapter.base.EngineAdapter
- EngineAdapter
- DEFAULT_BATCH_SIZE
- DATA_OBJECT_FILTER_BATCH_SIZE
- SUPPORTS_INDEXES
- MAX_TABLE_COMMENT_LENGTH
- MAX_COLUMN_COMMENT_LENGTH
- INSERT_OVERWRITE_STRATEGY
- SUPPORTS_MATERIALIZED_VIEWS
- SUPPORTS_MATERIALIZED_VIEW_SCHEMA
- SUPPORTS_VIEW_SCHEMA
- SUPPORTS_CLONING
- SUPPORTS_MANAGED_MODELS
- SUPPORTS_TUPLE_IN
- HAS_VIEW_BINDING
- SUPPORTS_REPLACE_TABLE
- SUPPORTS_GRANTS
- DEFAULT_CATALOG_TYPE
- QUOTE_IDENTIFIERS_IN_VIEWS
- MAX_IDENTIFIER_LENGTH
- ATTACH_CORRELATION_ID
- SUPPORTS_QUERY_EXECUTION_TRACKING
- SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS
- dialect
- correlation_id
- with_settings
- cursor
- connection
- spark
- snowpark
- bigframe
- comments_enabled
- schema_differ
- default_catalog
- engine_run_mode
- recycle
- close
- get_catalog_type
- get_catalog_type_from_table
- current_catalog_type
- replace_query
- create_index
- create_table
- create_managed_table
- ctas
- create_state_table
- create_table_like
- clone_table
- drop_data_object
- drop_table
- drop_managed_table
- get_alter_operations
- alter_table
- create_view
- create_schema
- drop_schema
- drop_view
- create_catalog
- drop_catalog
- columns
- table_exists
- delete_from
- insert_append
- insert_overwrite_by_partition
- insert_overwrite_by_time_partition
- update_table
- scd_type_2_by_time
- scd_type_2_by_column
- rename_table
- get_data_object
- get_data_objects
- fetchone
- fetchall
- fetchdf
- fetch_pyspark_df
- wap_enabled
- wap_supported
- wap_table_name
- wap_prepare
- wap_publish
- sync_grants_config
- transaction
- session
- execute
- temp_table
- drop_data_object_on_type_mismatch
- ensure_nulls_for_unmatched_after_join
- use_server_nulls_for_unmatched_after_join
- ping
- get_table_last_modified_ts