# Contains MSSQLEngineAdapter.
1"""Contains MSSQLEngineAdapter.""" 2 3from __future__ import annotations 4 5import typing as t 6import logging 7 8from sqlglot import exp 9 10from sqlmesh.core.dialect import to_schema, add_table 11from sqlmesh.core.engine_adapter.base import ( 12 EngineAdapterWithIndexSupport, 13 EngineAdapter, 14 InsertOverwriteStrategy, 15 MERGE_SOURCE_ALIAS, 16 MERGE_TARGET_ALIAS, 17 _get_data_object_cache_key, 18) 19from sqlmesh.core.engine_adapter.mixins import ( 20 GetCurrentCatalogFromFunctionMixin, 21 PandasNativeFetchDFSupportMixin, 22 VarcharSizeWorkaroundMixin, 23 RowDiffMixin, 24) 25from sqlmesh.core.engine_adapter.shared import ( 26 CatalogSupport, 27 CommentCreationTable, 28 CommentCreationView, 29 DataObject, 30 DataObjectType, 31 SourceQuery, 32 set_catalog, 33) 34from sqlmesh.utils import get_source_columns_to_types 35 36if t.TYPE_CHECKING: 37 from sqlmesh.core._typing import SchemaName, TableName 38 from sqlmesh.core.engine_adapter._typing import DF, Query, QueryOrDF 39 40 41logger = logging.getLogger(__name__) 42 43 44@set_catalog() 45class MSSQLEngineAdapter( 46 EngineAdapterWithIndexSupport, 47 PandasNativeFetchDFSupportMixin, 48 GetCurrentCatalogFromFunctionMixin, 49 VarcharSizeWorkaroundMixin, 50 RowDiffMixin, 51): 52 DIALECT: str = "tsql" 53 SUPPORTS_TUPLE_IN = False 54 SUPPORTS_MATERIALIZED_VIEWS = False 55 CURRENT_CATALOG_EXPRESSION = exp.func("db_name") 56 COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED 57 COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED 58 SUPPORTS_REPLACE_TABLE = False 59 MAX_IDENTIFIER_LENGTH = 128 60 SUPPORTS_QUERY_EXECUTION_TRACKING = True 61 SCHEMA_DIFFER_KWARGS = { 62 "parameterized_type_defaults": { 63 exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 0), (0,)], 64 exp.DataType.build("BINARY", dialect=DIALECT).this: [(1,)], 65 exp.DataType.build("VARBINARY", dialect=DIALECT).this: [(1,)], 66 exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)], 67 exp.DataType.build("VARCHAR", dialect=DIALECT).this: 
[(1,)], 68 exp.DataType.build("NCHAR", dialect=DIALECT).this: [(1,)], 69 exp.DataType.build("NVARCHAR", dialect=DIALECT).this: [(1,)], 70 exp.DataType.build("TIME", dialect=DIALECT).this: [(7,)], 71 exp.DataType.build("DATETIME2", dialect=DIALECT).this: [(7,)], 72 exp.DataType.build("DATETIMEOFFSET", dialect=DIALECT).this: [(7,)], 73 }, 74 "max_parameter_length": { 75 exp.DataType.build("VARBINARY", dialect=DIALECT).this: 2147483647, # 2 GB 76 exp.DataType.build("VARCHAR", dialect=DIALECT).this: 2147483647, 77 exp.DataType.build("NVARCHAR", dialect=DIALECT).this: 2147483647, 78 }, 79 } 80 VARIABLE_LENGTH_DATA_TYPES = {"binary", "varbinary", "char", "varchar", "nchar", "nvarchar"} 81 INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE 82 83 @property 84 def catalog_support(self) -> CatalogSupport: 85 # MSSQL and AzureSQL both use this engine adapter, but they differ in catalog support. 86 # Therefore, we specify the catalog support in the connection config `_extra_engine_config` 87 # instead of in the adapter itself. 
88 return self._extra_config["catalog_support"] 89 90 def columns( 91 self, 92 table_name: TableName, 93 include_pseudo_columns: bool = True, 94 ) -> t.Dict[str, exp.DataType]: 95 """MsSql doesn't support describe so we query information_schema.""" 96 97 table = exp.to_table(table_name) 98 99 sql = ( 100 exp.select( 101 "COLUMN_NAME", 102 "DATA_TYPE", 103 "CHARACTER_MAXIMUM_LENGTH", 104 "NUMERIC_PRECISION", 105 "NUMERIC_SCALE", 106 ) 107 .from_("INFORMATION_SCHEMA.COLUMNS") 108 .where(f"TABLE_NAME = '{table.name}'") 109 ) 110 database_name = table.db 111 if database_name: 112 sql = sql.where(f"TABLE_SCHEMA = '{database_name}'") 113 114 columns_raw = self.fetchall(sql, quote_identifiers=True) 115 116 def build_var_length_col( 117 column_name: str, 118 data_type: str, 119 character_maximum_length: t.Optional[int] = None, 120 numeric_precision: t.Optional[int] = None, 121 numeric_scale: t.Optional[int] = None, 122 ) -> tuple: 123 data_type = data_type.lower() 124 if ( 125 data_type in self.VARIABLE_LENGTH_DATA_TYPES 126 and character_maximum_length is not None 127 and character_maximum_length > 0 128 ): 129 return (column_name, f"{data_type}({character_maximum_length})") 130 if ( 131 data_type in ("varbinary", "varchar", "nvarchar") 132 and character_maximum_length is not None 133 and character_maximum_length == -1 134 ): 135 return (column_name, f"{data_type}(max)") 136 if data_type in ("decimal", "numeric"): 137 return (column_name, f"{data_type}({numeric_precision}, {numeric_scale})") 138 if data_type == "float": 139 return (column_name, f"{data_type}({numeric_precision})") 140 141 return (column_name, data_type) 142 143 columns = [build_var_length_col(*row) for row in columns_raw] 144 145 return { 146 column_name: exp.DataType.build(data_type, dialect=self.dialect) 147 for column_name, data_type in columns 148 } 149 150 def table_exists(self, table_name: TableName) -> bool: 151 """MsSql doesn't support describe so we query information_schema.""" 152 table = 
exp.to_table(table_name) 153 data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name) 154 if data_object_cache_key in self._data_object_cache: 155 logger.debug("Table existence cache hit: %s", data_object_cache_key) 156 return self._data_object_cache[data_object_cache_key] is not None 157 158 sql = ( 159 exp.select("1") 160 .from_("INFORMATION_SCHEMA.TABLES") 161 .where(f"TABLE_NAME = '{table.alias_or_name}'") 162 ) 163 database_name = table.db 164 if database_name: 165 sql = sql.where(f"TABLE_SCHEMA = '{database_name}'") 166 167 result = self.fetchone(sql, quote_identifiers=True) 168 169 return result[0] == 1 if result else False 170 171 def set_current_catalog(self, catalog_name: str) -> None: 172 self.execute(exp.Use(this=exp.to_identifier(catalog_name))) 173 174 def drop_schema( 175 self, 176 schema_name: SchemaName, 177 ignore_if_not_exists: bool = True, 178 cascade: bool = False, 179 **drop_args: t.Dict[str, exp.Expr], 180 ) -> None: 181 """ 182 MsSql doesn't support CASCADE clause and drops schemas unconditionally. 
183 """ 184 if cascade: 185 objects = self._get_data_objects(schema_name) 186 for obj in objects: 187 # Build properly quoted table for MSSQL using square brackets when needed 188 object_table = exp.table_(obj.name, obj.schema_name) 189 190 # _get_data_objects is catalog-specific, so these can't accidentally drop view/tables in another catalog 191 if obj.type == DataObjectType.VIEW: 192 self.drop_view( 193 object_table, 194 ignore_if_not_exists=ignore_if_not_exists, 195 ) 196 else: 197 self.drop_table( 198 object_table, 199 exists=ignore_if_not_exists, 200 ) 201 super().drop_schema(schema_name, ignore_if_not_exists=ignore_if_not_exists, cascade=False) 202 203 def merge( 204 self, 205 target_table: TableName, 206 source_table: QueryOrDF, 207 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], 208 unique_key: t.Sequence[exp.Expr], 209 when_matched: t.Optional[exp.Whens] = None, 210 merge_filter: t.Optional[exp.Expr] = None, 211 source_columns: t.Optional[t.List[str]] = None, 212 **kwargs: t.Any, 213 ) -> None: 214 mssql_merge_exists = kwargs.get("physical_properties", {}).get("mssql_merge_exists") 215 216 source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( 217 source_table, 218 target_columns_to_types, 219 target_table=target_table, 220 source_columns=source_columns, 221 ) 222 target_columns_to_types = target_columns_to_types or self.columns(target_table) 223 on = exp.and_( 224 *( 225 add_table(part, MERGE_TARGET_ALIAS).eq(add_table(part, MERGE_SOURCE_ALIAS)) 226 for part in unique_key 227 ) 228 ) 229 if merge_filter: 230 on = exp.and_(merge_filter, on) 231 232 match_expressions = [] 233 if not when_matched: 234 unique_key_names = [y.name for y in unique_key] 235 columns_to_types_no_keys = [ 236 c for c in target_columns_to_types if c not in unique_key_names 237 ] 238 239 target_columns_no_keys = [ 240 exp.column(c, MERGE_TARGET_ALIAS) for c in columns_to_types_no_keys 241 ] 242 source_columns_no_keys = [ 243 
exp.column(c, MERGE_SOURCE_ALIAS) for c in columns_to_types_no_keys 244 ] 245 246 match_condition = ( 247 exp.Exists( 248 this=exp.select(*target_columns_no_keys).except_( 249 exp.select(*source_columns_no_keys) 250 ) 251 ) 252 if mssql_merge_exists 253 else None 254 ) 255 256 if target_columns_no_keys: 257 match_expressions.append( 258 exp.When( 259 matched=True, 260 source=False, 261 condition=match_condition, 262 then=exp.Update( 263 expressions=[ 264 exp.column(col, MERGE_TARGET_ALIAS).eq( 265 exp.column(col, MERGE_SOURCE_ALIAS) 266 ) 267 for col in columns_to_types_no_keys 268 ], 269 ), 270 ) 271 ) 272 else: 273 match_expressions.extend(when_matched.copy().expressions) 274 275 match_expressions.append( 276 exp.When( 277 matched=False, 278 source=False, 279 then=exp.Insert( 280 this=exp.Tuple( 281 expressions=[exp.column(col) for col in target_columns_to_types] 282 ), 283 expression=exp.Tuple( 284 expressions=[ 285 exp.column(col, MERGE_SOURCE_ALIAS) for col in target_columns_to_types 286 ] 287 ), 288 ), 289 ) 290 ) 291 for source_query in source_queries: 292 with source_query as query: 293 self._merge( 294 target_table=target_table, 295 query=query, 296 on=on, 297 whens=exp.Whens(expressions=match_expressions), 298 ) 299 300 def _convert_df_datetime(self, df: DF, columns_to_types: t.Dict[str, exp.DataType]) -> None: 301 import pandas as pd 302 from pandas.api.types import is_datetime64_any_dtype # type: ignore 303 304 # pymssql doesn't convert Pandas Timestamp (datetime64) types 305 # - this code is based on snowflake adapter implementation 306 for column, kind in columns_to_types.items(): 307 # pymssql errors if the column contains a datetime.date object 308 if kind.is_type("date"): # type: ignore 309 df[column] = pd.to_datetime(df[column]).dt.strftime("%Y-%m-%d") # type: ignore 310 elif is_datetime64_any_dtype(df.dtypes[column]): # type: ignore 311 if getattr(df.dtypes[column], "tz", None) is not None: # type: ignore 312 # MSSQL requires a colon in the 
offset (+00:00) so we use isoformat() instead of strftime() 313 df[column] = pd.to_datetime(df[column]).map(lambda x: x.isoformat(" ")) # type: ignore 314 315 # bulk_copy() doesn't work with TZ timestamp, so load into string column and cast to 316 # timestamp in SELECT statement 317 columns_to_types[column] = exp.DataType.build("TEXT") 318 else: 319 df[column] = pd.to_datetime(df[column]).dt.strftime("%Y-%m-%d %H:%M:%S.%f") # type: ignore 320 321 def _df_to_source_queries( 322 self, 323 df: DF, 324 target_columns_to_types: t.Dict[str, exp.DataType], 325 batch_size: int, 326 target_table: TableName, 327 source_columns: t.Optional[t.List[str]] = None, 328 ) -> t.List[SourceQuery]: 329 import pandas as pd 330 import numpy as np 331 332 assert isinstance(df, pd.DataFrame) 333 temp_table = self._get_temp_table(target_table or "pandas") 334 335 # Return the superclass implementation if the connection pool doesn't support bulk_copy 336 if not hasattr(self._connection_pool.get(), "bulk_copy"): 337 return super()._df_to_source_queries( 338 df, target_columns_to_types, batch_size, target_table, source_columns=source_columns 339 ) 340 341 def query_factory() -> Query: 342 # It is possible for the factory to be called multiple times and if so then the temp table will already 343 # be created so we skip creating again. This means we are assuming the first call is the same result 344 # as later calls. 
345 if not self.table_exists(temp_table): 346 source_columns_to_types = get_source_columns_to_types( 347 target_columns_to_types, source_columns 348 ) 349 ordered_df = df[ 350 list(source_columns_to_types) 351 ] # reorder DataFrame so it matches columns_to_types 352 self._convert_df_datetime(ordered_df, source_columns_to_types) 353 self.create_table(temp_table, source_columns_to_types) 354 rows: t.List[t.Tuple[t.Any, ...]] = list( 355 ordered_df.replace({np.nan: None}).itertuples(index=False, name=None) # type: ignore 356 ) 357 conn = self._connection_pool.get() 358 conn.bulk_copy(temp_table.sql(dialect=self.dialect), rows) 359 return exp.select( 360 *self._casted_columns(target_columns_to_types, source_columns=source_columns) 361 ).from_(temp_table) # type: ignore 362 363 return [ 364 SourceQuery( 365 query_factory=query_factory, 366 cleanup_func=lambda: self.drop_table(temp_table), 367 ) 368 ] 369 370 def _get_data_objects( 371 self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None 372 ) -> t.List[DataObject]: 373 """ 374 Returns all the data objects that exist in the given schema and catalog. 
375 """ 376 import pandas as pd 377 378 catalog = self.get_current_catalog() 379 query = ( 380 exp.select( 381 exp.column("TABLE_NAME").as_("name"), 382 exp.column("TABLE_SCHEMA").as_("schema_name"), 383 exp.case() 384 .when(exp.column("TABLE_TYPE").eq("BASE TABLE"), exp.Literal.string("TABLE")) 385 .else_(exp.column("TABLE_TYPE")) 386 .as_("type"), 387 ) 388 .from_(exp.table_("TABLES", db="INFORMATION_SCHEMA")) 389 .where(exp.column("TABLE_SCHEMA").eq(to_schema(schema_name).db)) 390 ) 391 if object_names: 392 query = query.where(exp.column("TABLE_NAME").isin(*object_names)) 393 dataframe: pd.DataFrame = self.fetchdf(query) 394 return [ 395 DataObject( 396 catalog=catalog, # type: ignore 397 schema=row.schema_name, # type: ignore 398 name=row.name, # type: ignore 399 type=DataObjectType.from_str(row.type), # type: ignore 400 ) 401 for row in dataframe.itertuples() 402 ] 403 404 def _to_sql(self, expression: exp.Expr, quote: bool = True, **kwargs: t.Any) -> str: 405 sql = super()._to_sql(expression, quote=quote, **kwargs) 406 return f"{sql};" 407 408 def _rename_table( 409 self, 410 old_table_name: TableName, 411 new_table_name: TableName, 412 ) -> None: 413 # The function that renames tables in MSSQL takes string literals as arguments instead of identifiers, 414 # so we shouldn't quote the identifiers. 
415 self.execute(exp.rename_table(old_table_name, new_table_name), quote_identifiers=False) 416 417 def _insert_overwrite_by_condition( 418 self, 419 table_name: TableName, 420 source_queries: t.List[SourceQuery], 421 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 422 where: t.Optional[exp.Condition] = None, 423 insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None, 424 **kwargs: t.Any, 425 ) -> None: 426 # note that this is passed as table_properties here rather than physical_properties 427 use_merge_strategy = kwargs.get("table_properties", {}).get("mssql_merge_exists") 428 if (not where or where == exp.true()) and not use_merge_strategy: 429 # this is a full table replacement, call the base strategy to do DELETE+INSERT 430 # which will result in TRUNCATE+INSERT due to how we have overridden self.delete_from() 431 return EngineAdapter._insert_overwrite_by_condition( 432 self, 433 table_name=table_name, 434 source_queries=source_queries, 435 target_columns_to_types=target_columns_to_types, 436 where=where, 437 insert_overwrite_strategy_override=InsertOverwriteStrategy.DELETE_INSERT, 438 **kwargs, 439 ) 440 441 # For conditional overwrites or when mssql_merge_exists is set use MERGE 442 return super()._insert_overwrite_by_condition( 443 table_name=table_name, 444 source_queries=source_queries, 445 target_columns_to_types=target_columns_to_types, 446 where=where, 447 insert_overwrite_strategy_override=insert_overwrite_strategy_override, 448 **kwargs, 449 ) 450 451 def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None: 452 if where == exp.true(): 453 # "A TRUNCATE TABLE operation can be rolled back within a transaction." 
454 # ref: https://learn.microsoft.com/en-us/sql/t-sql/statements/truncate-table-transact-sql?view=sql-server-ver15#remarks 455 return self.execute( 456 exp.TruncateTable(expressions=[exp.to_table(table_name, dialect=self.dialect)]) 457 ) 458 459 return super().delete_from(table_name, where)
logger =
<Logger sqlmesh.core.engine_adapter.mssql (WARNING)>
@set_catalog()
class MSSQLEngineAdapter(
    EngineAdapterWithIndexSupport,
    PandasNativeFetchDFSupportMixin,
    GetCurrentCatalogFromFunctionMixin,
    VarcharSizeWorkaroundMixin,
    RowDiffMixin,
):
    """Engine adapter that targets the T-SQL dialect.

    The adapter is shared by MSSQL and AzureSQL connections (see `catalog_support`).
    Metadata lookups go through INFORMATION_SCHEMA, full-table deletes are issued as
    TRUNCATE TABLE, and insert-overwrite defaults to the MERGE strategy.
    """

    DIALECT: str = "tsql"
    SUPPORTS_TUPLE_IN = False
    SUPPORTS_MATERIALIZED_VIEWS = False
    CURRENT_CATALOG_EXPRESSION = exp.func("db_name")
    COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED
    COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED
    SUPPORTS_REPLACE_TABLE = False
    MAX_IDENTIFIER_LENGTH = 128
    SUPPORTS_QUERY_EXECUTION_TRACKING = True
    SCHEMA_DIFFER_KWARGS = {
        "parameterized_type_defaults": {
            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 0), (0,)],
            exp.DataType.build("BINARY", dialect=DIALECT).this: [(1,)],
            exp.DataType.build("VARBINARY", dialect=DIALECT).this: [(1,)],
            exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)],
            exp.DataType.build("VARCHAR", dialect=DIALECT).this: [(1,)],
            exp.DataType.build("NCHAR", dialect=DIALECT).this: [(1,)],
            exp.DataType.build("NVARCHAR", dialect=DIALECT).this: [(1,)],
            exp.DataType.build("TIME", dialect=DIALECT).this: [(7,)],
            exp.DataType.build("DATETIME2", dialect=DIALECT).this: [(7,)],
            exp.DataType.build("DATETIMEOFFSET", dialect=DIALECT).this: [(7,)],
        },
        "max_parameter_length": {
            exp.DataType.build("VARBINARY", dialect=DIALECT).this: 2147483647,  # 2 GB
            exp.DataType.build("VARCHAR", dialect=DIALECT).this: 2147483647,
            exp.DataType.build("NVARCHAR", dialect=DIALECT).this: 2147483647,
        },
    }
    VARIABLE_LENGTH_DATA_TYPES = {"binary", "varbinary", "char", "varchar", "nchar", "nvarchar"}
    INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE

    @property
    def catalog_support(self) -> CatalogSupport:
        """Return the catalog support level configured for this connection."""
        # MSSQL and AzureSQL both use this engine adapter, but they differ in catalog support.
        # Therefore, we specify the catalog support in the connection config `_extra_engine_config`
        # instead of in the adapter itself.
        return self._extra_config["catalog_support"]

    def columns(
        self,
        table_name: TableName,
        include_pseudo_columns: bool = True,
    ) -> t.Dict[str, exp.DataType]:
        """MsSql doesn't support describe so we query information_schema.

        Args:
            table_name: The table whose columns are looked up. Only its name and, when
                present, its schema (`db`) are used to filter INFORMATION_SCHEMA.COLUMNS.
            include_pseudo_columns: Not used by this implementation.

        Returns:
            A mapping of column name to the data type parsed with this adapter's dialect.
        """

        table = exp.to_table(table_name)

        # The names are compared as string literals built by sqlglot rather than being
        # interpolated into a SQL snippet, so a name containing a single quote is escaped
        # by the generator instead of terminating the literal early.
        sql = (
            exp.select(
                "COLUMN_NAME",
                "DATA_TYPE",
                "CHARACTER_MAXIMUM_LENGTH",
                "NUMERIC_PRECISION",
                "NUMERIC_SCALE",
            )
            .from_("INFORMATION_SCHEMA.COLUMNS")
            .where(exp.column("TABLE_NAME").eq(exp.Literal.string(table.name)))
        )
        database_name = table.db
        if database_name:
            sql = sql.where(exp.column("TABLE_SCHEMA").eq(exp.Literal.string(database_name)))

        columns_raw = self.fetchall(sql, quote_identifiers=True)

        def build_var_length_col(
            column_name: str,
            data_type: str,
            character_maximum_length: t.Optional[int] = None,
            numeric_precision: t.Optional[int] = None,
            numeric_scale: t.Optional[int] = None,
        ) -> tuple:
            """Return (column name, type string) with length/precision re-attached."""
            data_type = data_type.lower()
            if (
                data_type in self.VARIABLE_LENGTH_DATA_TYPES
                and character_maximum_length is not None
                and character_maximum_length > 0
            ):
                return (column_name, f"{data_type}({character_maximum_length})")
            # A reported length of -1 is rendered as the `max` length specifier.
            if (
                data_type in ("varbinary", "varchar", "nvarchar")
                and character_maximum_length is not None
                and character_maximum_length == -1
            ):
                return (column_name, f"{data_type}(max)")
            if data_type in ("decimal", "numeric"):
                return (column_name, f"{data_type}({numeric_precision}, {numeric_scale})")
            if data_type == "float":
                return (column_name, f"{data_type}({numeric_precision})")

            return (column_name, data_type)

        columns = [build_var_length_col(*row) for row in columns_raw]

        return {
            column_name: exp.DataType.build(data_type, dialect=self.dialect)
            for column_name, data_type in columns
        }

    def table_exists(self, table_name: TableName) -> bool:
        """MsSql doesn't support describe so we query information_schema.

        The data object cache is consulted first; on a miss INFORMATION_SCHEMA.TABLES is
        queried by table name and, when present, by schema.
        """
        table = exp.to_table(table_name)
        data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name)
        if data_object_cache_key in self._data_object_cache:
            logger.debug("Table existence cache hit: %s", data_object_cache_key)
            return self._data_object_cache[data_object_cache_key] is not None

        # Built with the expression API so that quotes inside a name are escaped by the
        # generator rather than breaking the predicate.
        # NOTE(review): the lookup uses `alias_or_name` while the cache key above uses
        # `name`; these differ for aliased tables - confirm which one is intended.
        sql = (
            exp.select("1")
            .from_("INFORMATION_SCHEMA.TABLES")
            .where(exp.column("TABLE_NAME").eq(exp.Literal.string(table.alias_or_name)))
        )
        database_name = table.db
        if database_name:
            sql = sql.where(exp.column("TABLE_SCHEMA").eq(exp.Literal.string(database_name)))

        result = self.fetchone(sql, quote_identifiers=True)

        return result[0] == 1 if result else False

    def set_current_catalog(self, catalog_name: str) -> None:
        """Execute a USE statement for `catalog_name`."""
        self.execute(exp.Use(this=exp.to_identifier(catalog_name)))

    def drop_schema(
        self,
        schema_name: SchemaName,
        ignore_if_not_exists: bool = True,
        cascade: bool = False,
        **drop_args: t.Dict[str, exp.Expr],
    ) -> None:
        """
        MsSql doesn't support CASCADE clause and drops schemas unconditionally.

        When `cascade` is set, every object returned by `_get_data_objects` for the schema
        is dropped first (views through `drop_view`, everything else through `drop_table`).
        The schema itself is then dropped through the base class with `cascade=False`.
        `drop_args` is accepted but not forwarded.
        """
        if cascade:
            objects = self._get_data_objects(schema_name)
            for obj in objects:
                # Build properly quoted table for MSSQL using square brackets when needed
                object_table = exp.table_(obj.name, obj.schema_name)

                # _get_data_objects is catalog-specific, so these can't accidentally drop view/tables in another catalog
                if obj.type == DataObjectType.VIEW:
                    self.drop_view(
                        object_table,
                        ignore_if_not_exists=ignore_if_not_exists,
                    )
                else:
                    self.drop_table(
                        object_table,
                        exists=ignore_if_not_exists,
                    )
        super().drop_schema(schema_name, ignore_if_not_exists=ignore_if_not_exists, cascade=False)

    def merge(
        self,
        target_table: TableName,
        source_table: QueryOrDF,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
        unique_key: t.Sequence[exp.Expr],
        when_matched: t.Optional[exp.Whens] = None,
        merge_filter: t.Optional[exp.Expr] = None,
        source_columns: t.Optional[t.List[str]] = None,
        **kwargs: t.Any,
    ) -> None:
        """Merge `source_table` into `target_table`, matching rows on `unique_key`.

        Args:
            target_table: The table that receives the rows.
            source_table: A query or DataFrame providing the rows.
            target_columns_to_types: Target columns and types; looked up with `columns()`
                when not available.
            unique_key: Expressions compared between target and source in the ON clause.
            when_matched: Custom WHEN MATCHED clauses. When omitted, a default clause that
                updates every non-key column is generated.
            merge_filter: Optional extra predicate AND-ed in front of the key comparison.
            source_columns: Optional list of the columns present in the source.
            **kwargs: `physical_properties["mssql_merge_exists"]` is read from here. When
                truthy, the default update clause is guarded by an
                EXISTS(SELECT target cols EXCEPT SELECT source cols) condition.
        """
        # `physical_properties` may be absent or explicitly None; treat both as "no properties".
        mssql_merge_exists = (kwargs.get("physical_properties") or {}).get("mssql_merge_exists")

        source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types(
            source_table,
            target_columns_to_types,
            target_table=target_table,
            source_columns=source_columns,
        )
        target_columns_to_types = target_columns_to_types or self.columns(target_table)
        on = exp.and_(
            *(
                add_table(part, MERGE_TARGET_ALIAS).eq(add_table(part, MERGE_SOURCE_ALIAS))
                for part in unique_key
            )
        )
        if merge_filter:
            on = exp.and_(merge_filter, on)

        match_expressions = []
        if not when_matched:
            unique_key_names = {y.name for y in unique_key}
            columns_to_types_no_keys = [
                c for c in target_columns_to_types if c not in unique_key_names
            ]

            # With no non-key columns there is nothing to update, so no WHEN MATCHED
            # clause is emitted at all.
            if columns_to_types_no_keys:
                match_condition = None
                if mssql_merge_exists:
                    target_columns_no_keys = [
                        exp.column(c, MERGE_TARGET_ALIAS) for c in columns_to_types_no_keys
                    ]
                    source_columns_no_keys = [
                        exp.column(c, MERGE_SOURCE_ALIAS) for c in columns_to_types_no_keys
                    ]
                    match_condition = exp.Exists(
                        this=exp.select(*target_columns_no_keys).except_(
                            exp.select(*source_columns_no_keys)
                        )
                    )

                match_expressions.append(
                    exp.When(
                        matched=True,
                        source=False,
                        condition=match_condition,
                        then=exp.Update(
                            expressions=[
                                exp.column(col, MERGE_TARGET_ALIAS).eq(
                                    exp.column(col, MERGE_SOURCE_ALIAS)
                                )
                                for col in columns_to_types_no_keys
                            ],
                        ),
                    )
                )
        else:
            match_expressions.extend(when_matched.copy().expressions)

        match_expressions.append(
            exp.When(
                matched=False,
                source=False,
                then=exp.Insert(
                    this=exp.Tuple(
                        expressions=[exp.column(col) for col in target_columns_to_types]
                    ),
                    expression=exp.Tuple(
                        expressions=[
                            exp.column(col, MERGE_SOURCE_ALIAS) for col in target_columns_to_types
                        ]
                    ),
                ),
            )
        )
        for source_query in source_queries:
            with source_query as query:
                self._merge(
                    target_table=target_table,
                    query=query,
                    on=on,
                    whens=exp.Whens(expressions=match_expressions),
                )

    def _convert_df_datetime(self, df: DF, columns_to_types: t.Dict[str, exp.DataType]) -> None:
        """Convert date/datetime columns of `df` to strings, in place.

        `columns_to_types` is also updated in place: timezone-aware datetime columns are
        re-typed as TEXT.
        """
        import pandas as pd
        from pandas.api.types import is_datetime64_any_dtype  # type: ignore

        # pymssql doesn't convert Pandas Timestamp (datetime64) types
        # - this code is based on snowflake adapter implementation
        for column, kind in columns_to_types.items():
            # pymssql errors if the column contains a datetime.date object
            if kind.is_type("date"):  # type: ignore
                df[column] = pd.to_datetime(df[column]).dt.strftime("%Y-%m-%d")  # type: ignore
            elif is_datetime64_any_dtype(df.dtypes[column]):  # type: ignore
                if getattr(df.dtypes[column], "tz", None) is not None:  # type: ignore
                    # MSSQL requires a colon in the offset (+00:00) so we use isoformat() instead of strftime()
                    df[column] = pd.to_datetime(df[column]).map(lambda x: x.isoformat(" "))  # type: ignore

                    # bulk_copy() doesn't work with TZ timestamp, so load into string column and cast to
                    # timestamp in SELECT statement
                    columns_to_types[column] = exp.DataType.build("TEXT")
                else:
                    df[column] = pd.to_datetime(df[column]).dt.strftime("%Y-%m-%d %H:%M:%S.%f")  # type: ignore

    def _df_to_source_queries(
        self,
        df: DF,
        target_columns_to_types: t.Dict[str, exp.DataType],
        batch_size: int,
        target_table: TableName,
        source_columns: t.Optional[t.List[str]] = None,
    ) -> t.List[SourceQuery]:
        """Turn a pandas DataFrame into source queries, using `bulk_copy` when available.

        When the connection has no `bulk_copy` attribute the base implementation is used.
        Otherwise a single SourceQuery is returned whose factory loads the frame into a
        temp table and selects the casted columns from it; its cleanup drops that table.
        `batch_size` is only forwarded to the base implementation.
        """
        import pandas as pd
        import numpy as np

        assert isinstance(df, pd.DataFrame)
        temp_table = self._get_temp_table(target_table or "pandas")

        # Return the superclass implementation if the connection pool doesn't support bulk_copy
        if not hasattr(self._connection_pool.get(), "bulk_copy"):
            return super()._df_to_source_queries(
                df, target_columns_to_types, batch_size, target_table, source_columns=source_columns
            )

        def query_factory() -> Query:
            # It is possible for the factory to be called multiple times and if so then the temp table will already
            # be created so we skip creating again. This means we are assuming the first call is the same result
            # as later calls.
            if not self.table_exists(temp_table):
                source_columns_to_types = get_source_columns_to_types(
                    target_columns_to_types, source_columns
                )
                ordered_df = df[
                    list(source_columns_to_types)
                ]  # reorder DataFrame so it matches columns_to_types
                self._convert_df_datetime(ordered_df, source_columns_to_types)
                self.create_table(temp_table, source_columns_to_types)
                rows: t.List[t.Tuple[t.Any, ...]] = list(
                    ordered_df.replace({np.nan: None}).itertuples(index=False, name=None)  # type: ignore
                )
                conn = self._connection_pool.get()
                conn.bulk_copy(temp_table.sql(dialect=self.dialect), rows)
            return exp.select(
                *self._casted_columns(target_columns_to_types, source_columns=source_columns)
            ).from_(temp_table)  # type: ignore

        return [
            SourceQuery(
                query_factory=query_factory,
                cleanup_func=lambda: self.drop_table(temp_table),
            )
        ]

    def _get_data_objects(
        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
    ) -> t.List[DataObject]:
        """
        Returns all the data objects that exist in the given schema and catalog.

        The objects are read from INFORMATION_SCHEMA.TABLES, optionally restricted to
        `object_names`, and are tagged with the current catalog.
        """
        import pandas as pd

        catalog = self.get_current_catalog()
        query = (
            exp.select(
                exp.column("TABLE_NAME").as_("name"),
                exp.column("TABLE_SCHEMA").as_("schema_name"),
                # 'BASE TABLE' is reported as 'TABLE'; other type values pass through unchanged.
                exp.case()
                .when(exp.column("TABLE_TYPE").eq("BASE TABLE"), exp.Literal.string("TABLE"))
                .else_(exp.column("TABLE_TYPE"))
                .as_("type"),
            )
            .from_(exp.table_("TABLES", db="INFORMATION_SCHEMA"))
            .where(exp.column("TABLE_SCHEMA").eq(to_schema(schema_name).db))
        )
        if object_names:
            query = query.where(exp.column("TABLE_NAME").isin(*object_names))
        dataframe: pd.DataFrame = self.fetchdf(query)
        return [
            DataObject(
                catalog=catalog,  # type: ignore
                schema=row.schema_name,  # type: ignore
                name=row.name,  # type: ignore
                type=DataObjectType.from_str(row.type),  # type: ignore
            )
            for row in dataframe.itertuples()
        ]

    def _to_sql(self, expression: exp.Expr, quote: bool = True, **kwargs: t.Any) -> str:
        """Render `expression` through the base class and append a trailing semicolon."""
        sql = super()._to_sql(expression, quote=quote, **kwargs)
        return f"{sql};"

    def _rename_table(
        self,
        old_table_name: TableName,
        new_table_name: TableName,
    ) -> None:
        """Rename `old_table_name` to `new_table_name` without quoting identifiers."""
        # The function that renames tables in MSSQL takes string literals as arguments instead of identifiers,
        # so we shouldn't quote the identifiers.
        self.execute(exp.rename_table(old_table_name, new_table_name), quote_identifiers=False)

    def _insert_overwrite_by_condition(
        self,
        table_name: TableName,
        source_queries: t.List[SourceQuery],
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        where: t.Optional[exp.Condition] = None,
        insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None,
        **kwargs: t.Any,
    ) -> None:
        """Overwrite rows of `table_name`, choosing between DELETE+INSERT and MERGE.

        A full replacement (no `where`, or a `where` equal to TRUE) without the
        `mssql_merge_exists` table property goes through `EngineAdapter`'s implementation
        with the DELETE_INSERT strategy. Every other case is delegated to the next
        implementation in the MRO with the caller's strategy override.
        """
        # note that this is passed as table_properties here rather than physical_properties
        # `table_properties` may be absent or explicitly None; treat both as "no properties".
        use_merge_strategy = (kwargs.get("table_properties") or {}).get("mssql_merge_exists")
        if (not where or where == exp.true()) and not use_merge_strategy:
            # this is a full table replacement, call the base strategy to do DELETE+INSERT
            # which will result in TRUNCATE+INSERT due to how we have overridden self.delete_from()
            return EngineAdapter._insert_overwrite_by_condition(
                self,
                table_name=table_name,
                source_queries=source_queries,
                target_columns_to_types=target_columns_to_types,
                where=where,
                insert_overwrite_strategy_override=InsertOverwriteStrategy.DELETE_INSERT,
                **kwargs,
            )

        # For conditional overwrites or when mssql_merge_exists is set use MERGE
        return super()._insert_overwrite_by_condition(
            table_name=table_name,
            source_queries=source_queries,
            target_columns_to_types=target_columns_to_types,
            where=where,
            insert_overwrite_strategy_override=insert_overwrite_strategy_override,
            **kwargs,
        )

    def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None:
        """Delete rows matching `where`; a `where` equal to TRUE becomes TRUNCATE TABLE."""
        if where == exp.true():
            # "A TRUNCATE TABLE operation can be rolled back within a transaction."
            # ref: https://learn.microsoft.com/en-us/sql/t-sql/statements/truncate-table-transact-sql?view=sql-server-ver15#remarks
            return self.execute(
                exp.TruncateTable(expressions=[exp.to_table(table_name, dialect=self.dialect)])
            )

        return super().delete_from(table_name, where)
Base class wrapping a Database API compliant connection.
The EngineAdapter is an easily-subclassable interface that interacts with the underlying engine and data store.
Arguments:
- connection_factory_or_pool: a callable which produces a new Database API-compliant connection on every call.
- dialect: The dialect with which this adapter is associated.
- multithreaded: Indicates whether this adapter will be used by more than one thread.
SCHEMA_DIFFER_KWARGS =
{'parameterized_type_defaults': {<DType.DECIMAL: 'DECIMAL'>: [(18, 0), (0,)], <DType.BINARY: 'BINARY'>: [(1,)], <DType.VARBINARY: 'VARBINARY'>: [(1,)], <DType.CHAR: 'CHAR'>: [(1,)], <DType.VARCHAR: 'VARCHAR'>: [(1,)], <DType.NCHAR: 'NCHAR'>: [(1,)], <DType.NVARCHAR: 'NVARCHAR'>: [(1,)], <DType.TIME: 'TIME'>: [(7,)], <DType.DATETIME2: 'DATETIME2'>: [(7,)], <DType.TIMESTAMPTZ: 'TIMESTAMPTZ'>: [(7,)]}, 'max_parameter_length': {<DType.VARBINARY: 'VARBINARY'>: 2147483647, <DType.VARCHAR: 'VARCHAR'>: 2147483647, <DType.NVARCHAR: 'NVARCHAR'>: 2147483647}}
catalog_support: sqlmesh.core.engine_adapter.shared.CatalogSupport
@property
def catalog_support(self) -> CatalogSupport:
    """Return the catalog support level stored in this adapter's extra config."""
    # The same adapter serves MSSQL and AzureSQL, whose catalog support differs, so the
    # value is supplied by the connection config (`_extra_engine_config`) rather than
    # being fixed on the adapter.
    extra_config = self._extra_config
    return extra_config["catalog_support"]
def
columns( self, table_name: Union[str, sqlglot.expressions.query.Table], include_pseudo_columns: bool = True) -> Dict[str, sqlglot.expressions.datatypes.DataType]:
def columns(
    self,
    table_name: TableName,
    include_pseudo_columns: bool = True,
) -> t.Dict[str, exp.DataType]:
    """MsSql doesn't support describe so we query information_schema.

    Args:
        table_name: The table whose columns are looked up. Only its name and, when
            present, its schema (`db`) are used to filter INFORMATION_SCHEMA.COLUMNS.
        include_pseudo_columns: Not used by this implementation.

    Returns:
        A mapping of column name to the data type parsed with this adapter's dialect.
    """
    table = exp.to_table(table_name)

    # The names are compared as string literals built by sqlglot rather than being
    # interpolated into a SQL snippet, so a name containing a single quote is escaped
    # by the generator instead of terminating the literal early.
    sql = (
        exp.select(
            "COLUMN_NAME",
            "DATA_TYPE",
            "CHARACTER_MAXIMUM_LENGTH",
            "NUMERIC_PRECISION",
            "NUMERIC_SCALE",
        )
        .from_("INFORMATION_SCHEMA.COLUMNS")
        .where(exp.column("TABLE_NAME").eq(exp.Literal.string(table.name)))
    )
    database_name = table.db
    if database_name:
        sql = sql.where(exp.column("TABLE_SCHEMA").eq(exp.Literal.string(database_name)))

    columns_raw = self.fetchall(sql, quote_identifiers=True)

    def build_var_length_col(
        column_name: str,
        data_type: str,
        character_maximum_length: t.Optional[int] = None,
        numeric_precision: t.Optional[int] = None,
        numeric_scale: t.Optional[int] = None,
    ) -> tuple:
        """Return (column name, type string) with length/precision re-attached."""
        data_type = data_type.lower()
        if (
            data_type in self.VARIABLE_LENGTH_DATA_TYPES
            and character_maximum_length is not None
            and character_maximum_length > 0
        ):
            return (column_name, f"{data_type}({character_maximum_length})")
        # A reported length of -1 is rendered as the `max` length specifier.
        if (
            data_type in ("varbinary", "varchar", "nvarchar")
            and character_maximum_length is not None
            and character_maximum_length == -1
        ):
            return (column_name, f"{data_type}(max)")
        if data_type in ("decimal", "numeric"):
            return (column_name, f"{data_type}({numeric_precision}, {numeric_scale})")
        if data_type == "float":
            return (column_name, f"{data_type}({numeric_precision})")

        return (column_name, data_type)

    columns = [build_var_length_col(*row) for row in columns_raw]

    return {
        column_name: exp.DataType.build(data_type, dialect=self.dialect)
        for column_name, data_type in columns
    }
MsSql doesn't support describe so we query information_schema.
def
table_exists(self, table_name: Union[str, sqlglot.expressions.query.Table]) -> bool:
def table_exists(self, table_name: TableName) -> bool:
    """MsSql doesn't support describe so we query information_schema.

    The data object cache is consulted first; on a miss INFORMATION_SCHEMA.TABLES is
    queried by table name and, when present, by schema.

    Returns:
        True when the cache holds a non-None entry for the table, or when the lookup
        query returns a row whose first value is 1; False otherwise.
    """
    table = exp.to_table(table_name)
    data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name)
    if data_object_cache_key in self._data_object_cache:
        logger.debug("Table existence cache hit: %s", data_object_cache_key)
        return self._data_object_cache[data_object_cache_key] is not None

    # Built with the expression API so that quotes inside a name are escaped by the
    # generator rather than breaking the predicate.
    # NOTE(review): the lookup uses `alias_or_name` while the cache key above uses
    # `name`; these differ for aliased tables - confirm which one is intended.
    sql = (
        exp.select("1")
        .from_("INFORMATION_SCHEMA.TABLES")
        .where(exp.column("TABLE_NAME").eq(exp.Literal.string(table.alias_or_name)))
    )
    database_name = table.db
    if database_name:
        sql = sql.where(exp.column("TABLE_SCHEMA").eq(exp.Literal.string(database_name)))

    result = self.fetchone(sql, quote_identifiers=True)

    return result[0] == 1 if result else False
MsSql doesn't support describe so we query information_schema.
def
set_current_catalog(self, catalog_name: str) -> None:
def set_current_catalog(self, catalog_name: str) -> None:
    """Execute a USE statement that targets `catalog_name`."""
    catalog_identifier = exp.to_identifier(catalog_name)
    use_statement = exp.Use(this=catalog_identifier)
    self.execute(use_statement)
Sets the catalog name of the current connection.
def
drop_schema( self, schema_name: Union[str, sqlglot.expressions.query.Table], ignore_if_not_exists: bool = True, cascade: bool = False, **drop_args: Dict[str, sqlglot.expressions.core.Expr]) -> None:
def drop_schema(
    self,
    schema_name: SchemaName,
    ignore_if_not_exists: bool = True,
    cascade: bool = False,
    **drop_args: t.Dict[str, exp.Expr],
) -> None:
    """
    Drop a schema, emulating CASCADE since MsSql has no CASCADE clause.

    With `cascade` set, each object found in the schema is dropped individually before
    the schema itself is dropped through the base class with `cascade=False`.
    `drop_args` is accepted but not forwarded.
    """
    if cascade:
        # _get_data_objects is catalog-specific, so these can't accidentally drop
        # view/tables in another catalog
        for data_object in self._get_data_objects(schema_name):
            # Build properly quoted table for MSSQL using square brackets when needed
            qualified_name = exp.table_(data_object.name, data_object.schema_name)
            is_view = data_object.type == DataObjectType.VIEW

            if not is_view:
                self.drop_table(qualified_name, exists=ignore_if_not_exists)
                continue

            self.drop_view(qualified_name, ignore_if_not_exists=ignore_if_not_exists)

    super().drop_schema(schema_name, ignore_if_not_exists=ignore_if_not_exists, cascade=False)
MsSql doesn't support CASCADE clause and drops schemas unconditionally.
def
merge( self, target_table: Union[str, sqlglot.expressions.query.Table], source_table: QueryOrDF, target_columns_to_types: Optional[Dict[str, sqlglot.expressions.datatypes.DataType]], unique_key: Sequence[sqlglot.expressions.core.Expr], when_matched: Optional[sqlglot.expressions.dml.Whens] = None, merge_filter: Optional[sqlglot.expressions.core.Expr] = None, source_columns: Optional[List[str]] = None, **kwargs: Any) -> None:
def merge(
    self,
    target_table: TableName,
    source_table: QueryOrDF,
    target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
    unique_key: t.Sequence[exp.Expr],
    when_matched: t.Optional[exp.Whens] = None,
    merge_filter: t.Optional[exp.Expr] = None,
    source_columns: t.Optional[t.List[str]] = None,
    **kwargs: t.Any,
) -> None:
    """Merge `source_table` into `target_table`, matching rows on `unique_key`.

    Args:
        target_table: The table that receives the rows.
        source_table: A query or DataFrame providing the rows.
        target_columns_to_types: Target columns and types; looked up with `columns()`
            when not available.
        unique_key: Expressions compared between target and source in the ON clause.
        when_matched: Custom WHEN MATCHED clauses. When omitted, a default clause that
            updates every non-key column is generated.
        merge_filter: Optional extra predicate AND-ed in front of the key comparison.
        source_columns: Optional list of the columns present in the source.
        **kwargs: `physical_properties["mssql_merge_exists"]` is read from here. When
            truthy, the default update clause is guarded by an
            EXISTS(SELECT target cols EXCEPT SELECT source cols) condition.
    """
    # `physical_properties` may be absent or explicitly None; treat both as "no properties".
    mssql_merge_exists = (kwargs.get("physical_properties") or {}).get("mssql_merge_exists")

    source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types(
        source_table,
        target_columns_to_types,
        target_table=target_table,
        source_columns=source_columns,
    )
    target_columns_to_types = target_columns_to_types or self.columns(target_table)
    on = exp.and_(
        *(
            add_table(part, MERGE_TARGET_ALIAS).eq(add_table(part, MERGE_SOURCE_ALIAS))
            for part in unique_key
        )
    )
    if merge_filter:
        on = exp.and_(merge_filter, on)

    match_expressions = []
    if not when_matched:
        unique_key_names = {y.name for y in unique_key}
        columns_to_types_no_keys = [
            c for c in target_columns_to_types if c not in unique_key_names
        ]

        # With no non-key columns there is nothing to update, so no WHEN MATCHED
        # clause is emitted at all.
        if columns_to_types_no_keys:
            match_condition = None
            if mssql_merge_exists:
                target_columns_no_keys = [
                    exp.column(c, MERGE_TARGET_ALIAS) for c in columns_to_types_no_keys
                ]
                source_columns_no_keys = [
                    exp.column(c, MERGE_SOURCE_ALIAS) for c in columns_to_types_no_keys
                ]
                match_condition = exp.Exists(
                    this=exp.select(*target_columns_no_keys).except_(
                        exp.select(*source_columns_no_keys)
                    )
                )

            match_expressions.append(
                exp.When(
                    matched=True,
                    source=False,
                    condition=match_condition,
                    then=exp.Update(
                        expressions=[
                            exp.column(col, MERGE_TARGET_ALIAS).eq(
                                exp.column(col, MERGE_SOURCE_ALIAS)
                            )
                            for col in columns_to_types_no_keys
                        ],
                    ),
                )
            )
    else:
        match_expressions.extend(when_matched.copy().expressions)

    # Rows present only in the source are always inserted.
    match_expressions.append(
        exp.When(
            matched=False,
            source=False,
            then=exp.Insert(
                this=exp.Tuple(
                    expressions=[exp.column(col) for col in target_columns_to_types]
                ),
                expression=exp.Tuple(
                    expressions=[
                        exp.column(col, MERGE_SOURCE_ALIAS) for col in target_columns_to_types
                    ]
                ),
            ),
        )
    )
    for source_query in source_queries:
        with source_query as query:
            self._merge(
                target_table=target_table,
                query=query,
                on=on,
                whens=exp.Whens(expressions=match_expressions),
            )
def
delete_from( self, table_name: Union[str, sqlglot.expressions.query.Table], where: Union[str, sqlglot.expressions.core.Expr]) -> None:
def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None:
    """Delete rows matching `where`; a `where` equal to TRUE is issued as TRUNCATE TABLE."""
    deletes_everything = where == exp.true()
    if not deletes_everything:
        return super().delete_from(table_name, where)

    # "A TRUNCATE TABLE operation can be rolled back within a transaction."
    # ref: https://learn.microsoft.com/en-us/sql/t-sql/statements/truncate-table-transact-sql?view=sql-server-ver15#remarks
    target = exp.to_table(table_name, dialect=self.dialect)
    truncate_statement = exp.TruncateTable(expressions=[target])
    return self.execute(truncate_statement)
Inherited Members
- sqlmesh.core.engine_adapter.base.EngineAdapter
- EngineAdapter
- DEFAULT_BATCH_SIZE
- DATA_OBJECT_FILTER_BATCH_SIZE
- SUPPORTS_TRANSACTIONS
- MAX_TABLE_COMMENT_LENGTH
- MAX_COLUMN_COMMENT_LENGTH
- SUPPORTS_MATERIALIZED_VIEW_SCHEMA
- SUPPORTS_VIEW_SCHEMA
- SUPPORTS_CLONING
- SUPPORTS_MANAGED_MODELS
- SUPPORTS_CREATE_DROP_CATALOG
- SUPPORTED_DROP_CASCADE_OBJECT_KINDS
- HAS_VIEW_BINDING
- SUPPORTS_GRANTS
- DEFAULT_CATALOG_TYPE
- QUOTE_IDENTIFIERS_IN_VIEWS
- ATTACH_CORRELATION_ID
- SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS
- dialect
- correlation_id
- with_settings
- cursor
- connection
- spark
- snowpark
- bigframe
- comments_enabled
- schema_differ
- default_catalog
- engine_run_mode
- recycle
- close
- get_catalog_type
- get_catalog_type_from_table
- current_catalog_type
- replace_query
- create_index
- create_table
- create_managed_table
- ctas
- create_state_table
- create_table_like
- clone_table
- drop_data_object
- drop_table
- drop_managed_table
- get_alter_operations
- alter_table
- create_view
- create_schema
- drop_view
- create_catalog
- drop_catalog
- insert_append
- insert_overwrite_by_partition
- insert_overwrite_by_time_partition
- update_table
- scd_type_2_by_time
- scd_type_2_by_column
- rename_table
- get_data_object
- get_data_objects
- fetchone
- fetchall
- fetchdf
- fetch_pyspark_df
- wap_enabled
- wap_supported
- wap_table_name
- wap_prepare
- wap_publish
- sync_grants_config
- transaction
- session
- execute
- temp_table
- drop_data_object_on_type_mismatch
- ensure_nulls_for_unmatched_after_join
- use_server_nulls_for_unmatched_after_join
- ping
- get_table_last_modified_ts