sqlmesh.core.config.connection
1from __future__ import annotations 2 3import abc 4import base64 5import logging 6import os 7import importlib 8import pathlib 9import re 10import typing as t 11from enum import Enum 12from functools import partial 13 14import pydantic 15from pydantic import Field 16from pydantic_core import from_json 17from packaging import version 18from sqlglot import exp 19from sqlglot.helper import subclasses 20from sqlglot.errors import ParseError 21 22from sqlmesh.core import engine_adapter 23from sqlmesh.core.config.base import BaseConfig 24from sqlmesh.core.config.common import ( 25 concurrent_tasks_validator, 26 http_headers_validator, 27 compile_regex_mapping, 28) 29from sqlmesh.core.engine_adapter.shared import CatalogSupport 30from sqlmesh.core.engine_adapter import EngineAdapter 31from sqlmesh.utils import debug_mode_enabled, str_to_bool 32from sqlmesh.utils.errors import ConfigError 33from sqlmesh.utils.pydantic import ( 34 ValidationInfo, 35 field_validator, 36 model_validator, 37 validation_error_message, 38 get_concrete_types_from_typehint, 39) 40from sqlmesh.utils.aws import validate_s3_uri 41 42if t.TYPE_CHECKING: 43 from sqlmesh.core._typing import Self 44 45logger = logging.getLogger(__name__) 46 47RECOMMENDED_STATE_SYNC_ENGINES = { 48 "postgres", 49 "gcp_postgres", 50 "mysql", 51 "mssql", 52 "azuresql", 53} 54FORBIDDEN_STATE_SYNC_ENGINES = { 55 # Do not support row-level operations 56 "spark", 57 "trino", 58 # Nullable types are problematic 59 "clickhouse", 60} 61MOTHERDUCK_TOKEN_REGEX = re.compile(r"(\?|\&)(motherduck_token=)(\S*)") 62PASSWORD_REGEX = re.compile(r"(password=)(\S+)") 63 64 65def _get_engine_import_validator( 66 import_name: str, engine_type: str, extra_name: t.Optional[str] = None, decorate: bool = True 67) -> t.Callable: 68 extra_name = extra_name or engine_type 69 70 def validate(cls: t.Any, data: t.Any) -> t.Any: 71 check_import = ( 72 str_to_bool(str(data.pop("check_import", True))) if isinstance(data, dict) else True 73 ) 74 if not check_import: 75 return data 76 try: 77 importlib.import_module(import_name) 78 except ImportError: 79 if debug_mode_enabled(): 80 raise 81 82 logger.exception("Failed to import the engine library") 83 84 raise ConfigError( 85 f"Failed to import the '{engine_type}' engine library. This may be due to a missing " 86 "or incompatible installation. Please ensure the required dependency is installed by " 87 f'running: `pip install "sqlmesh[{extra_name}]"`. For more details, check the logs ' 88 "in the 'logs/' folder, or rerun the command with the '--debug' flag." 89 ) 90 91 return data 92 93 return model_validator(mode="before")(validate) if decorate else validate 94 95 96class ConnectionConfig(abc.ABC, BaseConfig): 97 type_: str 98 DIALECT: t.ClassVar[str] 99 DISPLAY_NAME: t.ClassVar[str] 100 DISPLAY_ORDER: t.ClassVar[int] 101 concurrent_tasks: int 102 register_comments: bool 103 pre_ping: bool 104 pretty_sql: bool = False 105 schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None 106 catalog_type_overrides: t.Optional[t.Dict[str, str]] = None 107 108 # Whether to share a single connection across threads or create a new connection per thread. 109 shared_connection: t.ClassVar[bool] = False 110 111 @property 112 @abc.abstractmethod 113 def _connection_kwargs_keys(self) -> t.Set[str]: 114 """keywords that should be passed into the connection""" 115 116 @property 117 @abc.abstractmethod 118 def _engine_adapter(self) -> t.Type[EngineAdapter]: 119 """The engine adapter for this connection""" 120 121 @property 122 @abc.abstractmethod 123 def _connection_factory(self) -> t.Callable: 124 """A function that is called to return a connection object for the given Engine Adapter""" 125 126 @property 127 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 128 """The static connection kwargs for this connection""" 129 return {} 130 131 @property 132 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 133 """kwargs that are for execution config only""" 134 return {} 135 136 @property 137 def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]: 138 """A function that is called to initialize the cursor""" 139 return None 140 141 @property 142 def is_recommended_for_state_sync(self) -> bool: 143 """Whether this engine is recommended for being used as a state sync for production state syncs""" 144 return self.type_ in RECOMMENDED_STATE_SYNC_ENGINES 145 146 @property 147 def is_forbidden_for_state_sync(self) -> bool: 148 """Whether this engine is forbidden from being used as a state sync""" 149 return self.type_ in FORBIDDEN_STATE_SYNC_ENGINES 150 151 @property 152 def _connection_factory_with_kwargs(self) -> t.Callable[[], t.Any]: 153 """A function that is called to return a connection object for the given Engine Adapter""" 154 return partial( 155 self._connection_factory, 156 **{ 157 **self._static_connection_kwargs, 158 **{k: v for k, v in self.dict().items() if k in self._connection_kwargs_keys}, 159 }, 160 ) 161 162 def connection_validator(self) -> t.Callable[[], None]: 163 """A function that validates the connection configuration""" 164 return self.create_engine_adapter().ping 165 166 def create_engine_adapter( 167 self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None 168 ) -> EngineAdapter: 169 """Returns a new instance of the Engine Adapter.""" 170 171 concurrent_tasks = concurrent_tasks or self.concurrent_tasks 172 return self._engine_adapter( 173 self._connection_factory_with_kwargs, 174 multithreaded=concurrent_tasks > 1, 175 default_catalog=self.get_catalog(), 176 cursor_init=self._cursor_init, 177 register_comments=register_comments_override or self.register_comments, 178 pre_ping=self.pre_ping, 179 pretty_sql=self.pretty_sql, 180 shared_connection=self.shared_connection, 181 schema_differ_overrides=self.schema_differ_overrides, 182 catalog_type_overrides=self.catalog_type_overrides, 183 **self._extra_engine_config, 184 ) 185 186 def get_catalog(self) -> t.Optional[str]: 187 """The catalog for this connection""" 188 if hasattr(self, "catalog"): 189 return self.catalog 190 if hasattr(self, "database"): 191 return self.database 192 if hasattr(self, "db"): 193 return self.db 194 return None 195 196 @model_validator(mode="before") 197 @classmethod 198 def _expand_json_strings_to_concrete_types(cls, data: t.Any) -> t.Any: 199 """ 200 There are situations where a connection config class has a field that is some kind of complex type 201 (eg a list of strings or a dict) but the value is being supplied from a source such as an environment variable 202 203 When this happens, the value is supplied as a string rather than a Python object. We need some way 204 of turning this string into the corresponding Python list or dict. 205 206 Rather than doing this piecemeal on every config subclass, this provides a generic implementatation 207 to identify fields that may be be supplied as JSON strings and handle them transparently 208 """ 209 if data and isinstance(data, dict): 210 for maybe_json_field_name in cls._get_list_and_dict_field_names(): 211 if (value := data.get(maybe_json_field_name)) and isinstance(value, str): 212 # crude JSON check as we dont want to try and parse every string we get 213 value = value.strip() 214 if value.startswith("{") or value.startswith("["): 215 data[maybe_json_field_name] = from_json(value) 216 217 return data 218 219 @classmethod 220 def _get_list_and_dict_field_names(cls) -> t.Set[str]: 221 field_names = set() 222 for name, field in cls.model_fields.items(): 223 if field.annotation: 224 field_types = get_concrete_types_from_typehint(field.annotation) 225 226 # check if the field type is something that could concievably be supplied as a json string 227 if any(ft is t for t in (list, tuple, set, dict) for ft in field_types): 228 field_names.add(name) 229 230 return field_names 231 232 233class DuckDBAttachOptions(BaseConfig): 234 type: str 235 path: str 236 read_only: bool = False 237 238 # DuckLake specific options 239 data_path: t.Optional[str] = None 240 encrypted: bool = False 241 data_inlining_row_limit: t.Optional[int] = None 242 metadata_schema: t.Optional[str] = None 243 244 def to_sql(self, alias: str) -> str: 245 options = [] 246 # 'duckdb' is actually not a supported type, but we'd like to allow it for 247 # fully qualified attach options or integration testing, similar to duckdb-dbt 248 if self.type not in ("duckdb", "ducklake", "motherduck"): 249 options.append(f"TYPE {self.type.upper()}") 250 if self.read_only: 251 options.append("READ_ONLY") 252 253 # DuckLake specific options 254 path = self.path 255 if self.type == "ducklake": 256 if not path.startswith("ducklake:"): 257 path = f"ducklake:{path}" 258 if self.data_path is not None: 259 options.append(f"DATA_PATH '{self.data_path}'") 260 if self.encrypted: 261 options.append("ENCRYPTED") 262 if self.data_inlining_row_limit is not None: 263 options.append(f"DATA_INLINING_ROW_LIMIT {self.data_inlining_row_limit}") 264 if self.metadata_schema is not None: 265 options.append(f"METADATA_SCHEMA '{self.metadata_schema}'") 266 267 options_sql = f" ({', '.join(options)})" if options else "" 268 alias_sql = "" 269 # TODO: Add support for Postgres schema. Currently adding it blocks access to the information_schema 270 271 # MotherDuck does not support aliasing 272 alias_sql = ( 273 f" AS {alias}" if not (self.type == "motherduck" or self.path.startswith("md:")) else "" 274 ) 275 return f"ATTACH IF NOT EXISTS '{path}'{alias_sql}{options_sql}" 276 277 278class BaseDuckDBConnectionConfig(ConnectionConfig): 279 """Common configuration for the DuckDB-based connections. 280 281 Args: 282 database: The optional database name. If not specified, the in-memory database will be used. 283 catalogs: Key is the name of the catalog and value is the path. 284 extensions: A list of autoloadable extensions to load. 285 connector_config: A dictionary of configuration to pass into the duckdb connector. 286 secrets: A list of dictionaries used to generate DuckDB secrets for authenticating with external services (e.g. S3). 287 filesystems: A list of dictionaries used to register `fsspec` filesystems to the DuckDB cursor. 288 concurrent_tasks: The maximum number of tasks that can use this connection concurrently. 289 register_comments: Whether or not to register model comments with the SQL engine. 290 pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. 291 token: The optional MotherDuck token. If not specified and a MotherDuck path is in the catalog, the user will be prompted to login with their web browser. 292 """ 293 294 database: t.Optional[str] = None 295 catalogs: t.Optional[t.Dict[str, t.Union[str, DuckDBAttachOptions]]] = None 296 extensions: t.List[t.Union[str, t.Dict[str, t.Any]]] = [] 297 connector_config: t.Dict[str, t.Any] = {} 298 secrets: t.Union[t.List[t.Dict[str, t.Any]], t.Dict[str, t.Dict[str, t.Any]]] = [] 299 filesystems: t.List[t.Dict[str, t.Any]] = [] 300 301 concurrent_tasks: int = 1 302 register_comments: bool = True 303 pre_ping: t.Literal[False] = False 304 305 token: t.Optional[str] = None 306 307 shared_connection: t.ClassVar[bool] = True 308 309 _data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {} 310 311 @model_validator(mode="before") 312 def _validate_database_catalogs(cls, data: t.Any) -> t.Any: 313 if not isinstance(data, dict): 314 return data 315 316 db_path = data.get("database") 317 if db_path and data.get("catalogs"): 318 raise ConfigError( 319 "Cannot specify both `database` and `catalogs`. Define all your catalogs in `catalogs` and have the first entry be the default catalog" 320 ) 321 if isinstance(db_path, str) and db_path.startswith("md:"): 322 raise ConfigError( 323 "Please use connection type 'motherduck' without the `md:` prefix if you want to use a MotherDuck database as the single `database`." 324 ) 325 326 return data 327 328 @property 329 def _engine_adapter(self) -> t.Type[EngineAdapter]: 330 return engine_adapter.DuckDBEngineAdapter 331 332 @property 333 def _connection_kwargs_keys(self) -> t.Set[str]: 334 return {"database"} 335 336 @property 337 def _connection_factory(self) -> t.Callable: 338 import duckdb 339 340 return duckdb.connect 341 342 @property 343 def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]: 344 """A function that is called to initialize the cursor""" 345 import duckdb 346 from duckdb import BinderException 347 348 def init(cursor: duckdb.DuckDBPyConnection) -> None: 349 for extension in self.extensions: 350 extension = extension if isinstance(extension, dict) else {"name": extension} 351 352 install_command = f"INSTALL {extension['name']}" 353 354 if extension.get("repository"): 355 install_command = f"{install_command} FROM {extension['repository']}" 356 357 if extension.get("force_install"): 358 install_command = f"FORCE {install_command}" 359 360 try: 361 cursor.execute(install_command) 362 cursor.execute(f"LOAD {extension['name']}") 363 except Exception as e: 364 raise ConfigError(f"Failed to load extension {extension['name']}: {e}") 365 366 if self.connector_config: 367 option_names = list(self.connector_config) 368 in_part = ",".join("?" for _ in range(len(option_names))) 369 370 cursor.execute( 371 f"SELECT name, value FROM duckdb_settings() WHERE name IN ({in_part})", 372 option_names, 373 ) 374 375 existing_values = {field: setting for field, setting in cursor.fetchall()} 376 377 # only set connector_config items if the values differ from what is already set 378 # trying to set options like 'temp_directory' even to the same value can throw errors like: 379 # Not implemented Error: Cannot switch temporary directory after the current one has been used 380 for field, setting in self.connector_config.items(): 381 if existing_values.get(field) != setting: 382 try: 383 cursor.execute(f"SET {field} = '{setting}'") 384 except Exception as e: 385 raise ConfigError( 386 f"Failed to set connector config {field} to {setting}: {e}" 387 ) 388 389 if self.secrets: 390 duckdb_version = duckdb.__version__ 391 if version.parse(duckdb_version) < version.parse("0.10.0"): 392 from sqlmesh.core.console import get_console 393 394 get_console().log_warning( 395 f"DuckDB version {duckdb_version} does not support secrets-based authentication (requires 0.10.0 or later).\n" 396 "To use secrets, please upgrade DuckDB. For older versions, configure legacy authentication via `connector_config`.\n" 397 "More info: https://duckdb.org/docs/stable/extensions/httpfs/s3api_legacy_authentication.html" 398 ) 399 else: 400 if isinstance(self.secrets, list): 401 secrets_items = [(secret_dict, "") for secret_dict in self.secrets] 402 else: 403 secrets_items = [ 404 (secret_dict, secret_name) 405 for secret_name, secret_dict in self.secrets.items() 406 ] 407 408 for secret_dict, secret_name in secrets_items: 409 secret_settings: t.List[str] = [] 410 for field, setting in secret_dict.items(): 411 secret_settings.append(f"{field} '{setting}'") 412 if secret_settings: 413 secret_clause = ", ".join(secret_settings) 414 try: 415 cursor.execute( 416 f"CREATE OR REPLACE SECRET {secret_name} ({secret_clause});" 417 ) 418 except Exception as e: 419 raise ConfigError(f"Failed to create secret: {e}") 420 421 if self.filesystems: 422 from fsspec import filesystem # type: ignore 423 424 for file_system in self.filesystems: 425 options = file_system.copy() 426 fs = options.pop("fs") 427 fs = filesystem(fs, **options) 428 cursor.register_filesystem(fs) 429 430 for i, (alias, path_options) in enumerate( 431 (getattr(self, "catalogs", None) or {}).items() 432 ): 433 # we parse_identifier and generate to ensure that `alias` has exactly one set of quotes 434 # regardless of whether it comes in quoted or not 435 alias = exp.parse_identifier(alias, dialect="duckdb").sql( 436 identify=True, dialect="duckdb" 437 ) 438 try: 439 if isinstance(path_options, DuckDBAttachOptions): 440 query = path_options.to_sql(alias) 441 else: 442 query = f"ATTACH IF NOT EXISTS '{path_options}'" 443 if not path_options.startswith("md:"): 444 query += f" AS {alias}" 445 cursor.execute(query) 446 except BinderException as e: 447 # If a user tries to create a catalog pointing at `:memory:` and with the name `memory` 448 # then we don't want to raise since this happens by default. They are just doing this to 449 # set it as the default catalog. 450 # If a user tried to attach a MotherDuck database/share which has already by attached via 451 # `ATTACH 'md:'`, then we don't want to raise since this is expected. 452 if ( 453 not ( 454 'database with name "memory" already exists' in str(e) 455 and path_options == ":memory:" 456 ) 457 and f"""database with name "{path_options.path.replace("md:", "")}" already exists""" 458 not in str(e) 459 ): 460 raise e 461 if i == 0 and not getattr(self, "database", None): 462 cursor.execute(f"USE {alias}") 463 464 return init 465 466 def create_engine_adapter( 467 self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None 468 ) -> EngineAdapter: 469 """Checks if another engine adapter has already been created that shares a catalog that points to the same data 470 file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration 471 associated with the new adapter will be ignored.""" 472 data_files = set((self.catalogs or {}).values()) 473 if self.database: 474 if isinstance(self, MotherDuckConnectionConfig): 475 data_files.add( 476 f"md:{self.database}" 477 + (f"?motherduck_token={self.token}" if self.token else "") 478 ) 479 else: 480 data_files.add(self.database) 481 data_files.discard(":memory:") 482 for data_file in data_files: 483 key = data_file if isinstance(data_file, str) else data_file.path 484 adapter = BaseDuckDBConnectionConfig._data_file_to_adapter.get(key) 485 if adapter is not None: 486 logger.info( 487 f"Using existing DuckDB adapter due to overlapping data file: {self._mask_sensitive_data(key)}" 488 ) 489 return adapter 490 491 if data_files: 492 masked_files = { 493 self._mask_sensitive_data(file if isinstance(file, str) else file.path) 494 for file in data_files 495 } 496 logger.info(f"Creating new DuckDB adapter for data files: {masked_files}") 497 else: 498 logger.info("Creating new DuckDB adapter for in-memory database") 499 adapter = super().create_engine_adapter( 500 register_comments_override, concurrent_tasks=concurrent_tasks 501 ) 502 for data_file in data_files: 503 key = data_file if isinstance(data_file, str) else data_file.path 504 BaseDuckDBConnectionConfig._data_file_to_adapter[key] = adapter 505 return adapter 506 507 def get_catalog(self) -> t.Optional[str]: 508 if self.database: 509 # Remove `:` from the database name in order to handle if `:memory:` is passed in 510 return pathlib.Path(self.database.replace(":memory:", "memory")).stem 511 if self.catalogs: 512 return list(self.catalogs)[0] 513 return None 514 515 def _mask_sensitive_data(self, string: str) -> str: 516 # Mask MotherDuck tokens with fixed number of asterisks 517 result = MOTHERDUCK_TOKEN_REGEX.sub( 518 lambda m: f"{m.group(1)}{m.group(2)}{'*' * 8 if m.group(3) else ''}", string 519 ) 520 # Mask PostgreSQL/MySQL passwords with fixed number of asterisks 521 result = PASSWORD_REGEX.sub(lambda m: f"{m.group(1)}{'*' * 8}", result) 522 return result 523 524 525class MotherDuckConnectionConfig(BaseDuckDBConnectionConfig): 526 """Configuration for the MotherDuck connection.""" 527 528 type_: t.Literal["motherduck"] = Field(alias="type", default="motherduck") 529 DIALECT: t.ClassVar[t.Literal["duckdb"]] = "duckdb" 530 DISPLAY_NAME: t.ClassVar[t.Literal["MotherDuck"]] = "MotherDuck" 531 DISPLAY_ORDER: t.ClassVar[t.Literal[5]] = 5 532 533 @property 534 def _connection_kwargs_keys(self) -> t.Set[str]: 535 return set() 536 537 @property 538 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 539 """kwargs that are for execution config only""" 540 from sqlmesh import __version__ 541 542 custom_user_agent_config = {"custom_user_agent": f"SQLMesh/{__version__}"} 543 connection_str = "md:" 544 if self.database: 545 # Attach single MD database instead of all databases on the account 546 connection_str += f"{self.database}?attach_mode=single" 547 if self.token: 548 connection_str += f"{'&' if self.database else '?'}motherduck_token={self.token}" 549 return {"database": connection_str, "config": custom_user_agent_config} 550 551 @property 552 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 553 return {"is_motherduck": True} 554 555 556class DuckDBConnectionConfig(BaseDuckDBConnectionConfig): 557 """Configuration for the DuckDB connection.""" 558 559 type_: t.Literal["duckdb"] = Field(alias="type", default="duckdb") 560 DIALECT: t.ClassVar[t.Literal["duckdb"]] = "duckdb" 561 DISPLAY_NAME: t.ClassVar[t.Literal["DuckDB"]] = "DuckDB" 562 DISPLAY_ORDER: t.ClassVar[t.Literal[1]] = 1 563 564 565class SnowflakeConnectionConfig(ConnectionConfig): 566 """Configuration for the Snowflake connection. 567 568 Args: 569 account: The Snowflake account name. 570 user: The Snowflake username. 571 password: The Snowflake password. 572 warehouse: The optional warehouse name. 573 database: The optional database name. 574 role: The optional role name. 575 concurrent_tasks: The maximum number of tasks that can use this connection concurrently. 576 authenticator: The optional authenticator name. Defaults to username/password authentication ("snowflake"). 577 Options: https://github.com/snowflakedb/snowflake-connector-python/blob/e937591356c067a77f34a0a42328907fda792c23/src/snowflake/connector/network.py#L178-L183 578 token: The optional oauth access token to use for authentication when authenticator is set to "oauth". 579 private_key: The optional private key to use for authentication. Key can be Base64-encoded DER format (representing the key bytes), a plain-text PEM format, or bytes (Python config only). https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#using-key-pair-authentication-key-pair-rotation 580 private_key_path: The optional path to the private key to use for authentication. This would be used instead of `private_key`. 581 private_key_passphrase: The optional passphrase to use to decrypt `private_key` or `private_key_path`. Keys can be created without encryption so only provide this if needed. 582 register_comments: Whether or not to register model comments with the SQL engine. 583 pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. 584 session_parameters: The optional session parameters to set for the connection. 585 host: Host address for the connection. 586 port: Port for the connection. 587 """ 588 589 account: str 590 user: t.Optional[str] = None 591 password: t.Optional[str] = None 592 warehouse: t.Optional[str] = None 593 database: t.Optional[str] = None 594 role: t.Optional[str] = None 595 authenticator: t.Optional[str] = None 596 token: t.Optional[str] = None 597 host: t.Optional[str] = None 598 port: t.Optional[int] = None 599 application: t.Literal["Tobiko_SQLMesh"] = "Tobiko_SQLMesh" 600 601 # Private Key Auth 602 private_key: t.Optional[t.Union[str, bytes]] = None 603 private_key_path: t.Optional[str] = None 604 private_key_passphrase: t.Optional[str] = None 605 606 concurrent_tasks: int = 4 607 register_comments: bool = True 608 pre_ping: bool = False 609 610 session_parameters: t.Optional[dict] = None 611 612 type_: t.Literal["snowflake"] = Field(alias="type", default="snowflake") 613 DIALECT: t.ClassVar[t.Literal["snowflake"]] = "snowflake" 614 DISPLAY_NAME: t.ClassVar[t.Literal["Snowflake"]] = "Snowflake" 615 DISPLAY_ORDER: t.ClassVar[t.Literal[2]] = 2 616 617 _concurrent_tasks_validator = concurrent_tasks_validator 618 619 @model_validator(mode="before") 620 def _validate_authenticator(cls, data: t.Any) -> t.Any: 621 if not isinstance(data, dict): 622 return data 623 624 from snowflake.connector.network import DEFAULT_AUTHENTICATOR, OAUTH_AUTHENTICATOR 625 626 auth = data.get("authenticator") 627 auth = auth.upper() if auth else DEFAULT_AUTHENTICATOR 628 user = data.get("user") 629 password = data.get("password") 630 data["private_key"] = cls._get_private_key(data, auth) # type: ignore 631 632 if ( 633 auth == DEFAULT_AUTHENTICATOR 634 and not data.get("private_key") 635 and (not user or not password) 636 ): 637 raise ConfigError("User and password must be provided if using default authentication") 638 639 if auth == OAUTH_AUTHENTICATOR and not data.get("token"): 640 raise ConfigError("Token must be provided if using oauth authentication") 641 642 return data 643 644 _engine_import_validator = _get_engine_import_validator( 645 "snowflake.connector.network", "snowflake" 646 ) 647 648 @classmethod 649 def _get_private_key(cls, values: t.Dict[str, t.Optional[str]], auth: str) -> t.Optional[bytes]: 650 """ 651 source: https://github.com/dbt-labs/dbt-snowflake/blob/0374b4ec948982f2ac8ec0c95d53d672ad19e09c/dbt/adapters/snowflake/connections.py#L247C5-L285C1 652 653 Overall code change: Use local variables instead of class attributes + Validation 654 """ 655 # Start custom code 656 from cryptography.hazmat.backends import default_backend 657 from cryptography.hazmat.primitives import serialization 658 from snowflake.connector.network import ( 659 DEFAULT_AUTHENTICATOR, 660 KEY_PAIR_AUTHENTICATOR, 661 ) 662 663 private_key = values.get("private_key") 664 private_key_path = values.get("private_key_path") 665 private_key_passphrase = values.get("private_key_passphrase") 666 user = values.get("user") 667 password = values.get("password") 668 auth = auth if auth and auth != DEFAULT_AUTHENTICATOR else KEY_PAIR_AUTHENTICATOR 669 670 if not private_key and not private_key_path: 671 return None 672 if private_key and private_key_path: 673 raise ConfigError("Cannot specify both `private_key` and `private_key_path`") 674 if auth != KEY_PAIR_AUTHENTICATOR: 675 raise ConfigError( 676 f"Private key or private key path can only be provided when using {KEY_PAIR_AUTHENTICATOR} authentication" 677 ) 678 if not user: 679 raise ConfigError( 680 f"User must be provided when using {KEY_PAIR_AUTHENTICATOR} authentication" 681 ) 682 if password: 683 raise ConfigError( 684 f"Password cannot be provided when using {KEY_PAIR_AUTHENTICATOR} authentication" 685 ) 686 687 if isinstance(private_key, bytes): 688 return private_key 689 # End Custom Code 690 691 if private_key_passphrase: 692 encoded_passphrase = private_key_passphrase.encode() 693 else: 694 encoded_passphrase = None 695 696 if private_key: 697 if private_key.startswith("-"): 698 p_key = serialization.load_pem_private_key( 699 data=bytes(private_key, "utf-8"), 700 password=encoded_passphrase, 701 backend=default_backend(), 702 ) 703 704 else: 705 p_key = serialization.load_der_private_key( 706 data=base64.b64decode(private_key), 707 password=encoded_passphrase, 708 backend=default_backend(), 709 ) 710 711 elif private_key_path: 712 with open(private_key_path, "rb") as key: 713 p_key = serialization.load_pem_private_key( 714 key.read(), password=encoded_passphrase, backend=default_backend() 715 ) 716 else: 717 return None 718 719 return p_key.private_bytes( 720 encoding=serialization.Encoding.DER, 721 format=serialization.PrivateFormat.PKCS8, 722 encryption_algorithm=serialization.NoEncryption(), 723 ) 724 725 @property 726 def _connection_kwargs_keys(self) -> t.Set[str]: 727 return { 728 "user", 729 "password", 730 "account", 731 "warehouse", 732 "database", 733 "role", 734 "authenticator", 735 "token", 736 "private_key", 737 "session_parameters", 738 "application", 739 "host", 740 "port", 741 } 742 743 @property 744 def _engine_adapter(self) -> t.Type[EngineAdapter]: 745 return engine_adapter.SnowflakeEngineAdapter 746 747 @property 748 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 749 return {"autocommit": False} 750 751 @property 752 def _connection_factory(self) -> t.Callable: 753 from snowflake import connector 754 755 return connector.connect 756 757 758class DatabricksConnectionConfig(ConnectionConfig): 759 """ 760 Databricks connection that uses the SQL connector for SQL models and then Databricks Connect for Dataframe operations 761 762 Arg Source: https://github.com/databricks/databricks-sql-python/blob/main/src/databricks/sql/client.py#L39 763 OAuth ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication 764 765 Args: 766 server_hostname: Databricks instance host name. 767 http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef) 768 or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123) 769 access_token: Http Bearer access token, e.g. Databricks Personal Access Token. 770 auth_type: Set to 'databricks-oauth' or 'azure-oauth' to trigger OAuth (or dont set at all to use `access_token`) 771 oauth_client_id: Client ID to use when auth_type is set to one of the 'oauth' types 772 oauth_client_secret: Client Secret to use when auth_type is set to one of the 'oauth' types 773 catalog: Default catalog to use for SQL models. Defaults to None which means it will use the default set in 774 the Databricks cluster (most likely `hive_metastore`). 775 http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request 776 session_configuration: An optional dictionary of Spark session parameters. 777 Execute the SQL command `SET -v` to get a full list of available commands. 778 databricks_connect_server_hostname: The hostname to use when establishing a connecting using Databricks Connect. 779 Defaults to the `server_hostname` value. 780 databricks_connect_access_token: The access token to use when establishing a connecting using Databricks Connect. 781 Defaults to the `access_token` value. 782 databricks_connect_cluster_id: The cluster id to use when establishing a connecting using Databricks Connect. 783 Defaults to deriving the cluster id from the `http_path` value. 784 force_databricks_connect: Force all queries to run using Databricks Connect instead of the SQL connector. 785 disable_databricks_connect: Even if databricks connect is installed, do not use it. 786 disable_spark_session: Do not use SparkSession if it is available (like when running in a notebook). 787 pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. 788 """ 789 790 server_hostname: t.Optional[str] = None 791 http_path: t.Optional[str] = None 792 access_token: t.Optional[str] = None 793 auth_type: t.Optional[str] = None 794 oauth_client_id: t.Optional[str] = None 795 oauth_client_secret: t.Optional[str] = None 796 catalog: t.Optional[str] = None 797 http_headers: t.Optional[t.List[t.Tuple[str, str]]] = None 798 session_configuration: t.Optional[t.Dict[str, t.Any]] = None 799 databricks_connect_server_hostname: t.Optional[str] = None 800 databricks_connect_access_token: t.Optional[str] = None 801 databricks_connect_cluster_id: t.Optional[str] = None 802 databricks_connect_use_serverless: bool = False 803 force_databricks_connect: bool = False 804 disable_databricks_connect: bool = False 805 disable_spark_session: bool = False 806 807 concurrent_tasks: int = 1 808 register_comments: bool = True 809 pre_ping: t.Literal[False] = False 810 811 type_: t.Literal["databricks"] = Field(alias="type", default="databricks") 812 DIALECT: t.ClassVar[t.Literal["databricks"]] = "databricks" 813 DISPLAY_NAME: t.ClassVar[t.Literal["Databricks"]] = "Databricks" 814 DISPLAY_ORDER: t.ClassVar[t.Literal[3]] = 3 815 816 _concurrent_tasks_validator = concurrent_tasks_validator 817 _http_headers_validator = http_headers_validator 818 819 @model_validator(mode="before") 820 def _databricks_connect_validator(cls, data: t.Any) -> t.Any: 821 # SQLQueryContextLogger will output any error SQL queries even if they are in a try/except block. 822 # Disabling this allows SQLMesh to determine what should be shown to the user. 823 # Ex: We describe a table to see if it exists and therefore that execution can fail but we don't need to show 824 # the user since it is expected if the table doesn't exist. Without this change the user would see the error. 825 logging.getLogger("SQLQueryContextLogger").setLevel(logging.CRITICAL) 826 827 if not isinstance(data, dict): 828 return data 829 830 from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter 831 832 if DatabricksEngineAdapter.can_access_spark_session( 833 bool(data.get("disable_spark_session")) 834 ): 835 return data 836 837 databricks_connect_use_serverless = data.get("databricks_connect_use_serverless") 838 server_hostname, http_path, access_token, auth_type = ( 839 data.get("server_hostname"), 840 data.get("http_path"), 841 data.get("access_token"), 842 data.get("auth_type"), 843 ) 844 845 if (not server_hostname or not http_path or not access_token) and ( 846 not databricks_connect_use_serverless and not auth_type 847 ): 848 raise ValueError( 849 "`server_hostname`, `http_path`, and `access_token` are required for Databricks connections when not running in a notebook" 850 ) 851 if ( 852 databricks_connect_use_serverless 853 and not server_hostname 854 and not data.get("databricks_connect_server_hostname") 855 ): 856 raise ValueError( 857 "`server_hostname` or `databricks_connect_server_hostname` is required when `databricks_connect_use_serverless` is set" 858 ) 859 if DatabricksEngineAdapter.can_access_databricks_connect( 860 bool(data.get("disable_databricks_connect")) 861 ): 862 if not data.get("databricks_connect_access_token"): 863 data["databricks_connect_access_token"] = access_token 864 if not data.get("databricks_connect_server_hostname"): 865 data["databricks_connect_server_hostname"] = f"https://{server_hostname}" 866 if not databricks_connect_use_serverless and not data.get( 867 "databricks_connect_cluster_id" 868 ): 869 if t.TYPE_CHECKING: 870 assert http_path is not None 871 data["databricks_connect_cluster_id"] = http_path.split("/")[-1] 872 873 if auth_type: 874 from databricks.sql.auth.auth import AuthType 875 876 all_data = [m.value for m in AuthType] 877 if auth_type not in all_data: 878 raise ValueError( 879 f"`auth_type` {auth_type} does not match a valid option: {all_data}" 880 ) 881 882 client_id = data.get("oauth_client_id") 883 client_secret = data.get("oauth_client_secret") 884 885 if client_secret and not client_id: 886 raise ValueError( 887 "`oauth_client_id` is required when `oauth_client_secret` is specified" 888 ) 889 890 if not http_path: 891 raise ValueError("`http_path` is still required when using `auth_type`") 892 893 return data 894 895 _engine_import_validator = _get_engine_import_validator("databricks", "databricks") 896 897 @property 898 def _connection_kwargs_keys(self) -> t.Set[str]: 899 if self.use_spark_session_only: 900 return set() 901 return { 902 "server_hostname", 903 "http_path", 904 "access_token", 905 "http_headers", 906 "session_configuration", 907 "catalog", 908 } 909 910 @property 911 def _engine_adapter(self) -> t.Type[engine_adapter.DatabricksEngineAdapter]: 912 return engine_adapter.DatabricksEngineAdapter 913 914 @property 915 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 916 return { 917 k: v 918 for k, v in self.dict().items() 919 if k.startswith("databricks_connect_") 920 or k in ("catalog", "disable_databricks_connect", "disable_spark_session") 921 } 922 923 @property 924 def use_spark_session_only(self) -> bool: 925 from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter 926 927 return ( 928 DatabricksEngineAdapter.can_access_spark_session(self.disable_spark_session) 929 or self.force_databricks_connect 930 ) 931 932 @property 933 def _connection_factory(self) -> t.Callable: 934 if self.use_spark_session_only: 935 from sqlmesh.engines.spark.db_api.spark_session import connection 936 937 return connection 938 939 from databricks import sql # type: ignore 940 941 return sql.connect 942 943 @property 944 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 945 from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter 946 947 if not self.use_spark_session_only: 948 conn_kwargs: t.Dict[str, t.Any] = { 949 "_user_agent_entry": "sqlmesh", 950 } 951 952 if self.auth_type and "oauth" in self.auth_type: 953 # there are two types of oauth: User-to-Machine (U2M) and Machine-to-Machine (M2M) 954 if self.oauth_client_secret: 955 # if a client_secret exists, then a client_id also exists and we are using M2M 956 # ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication 957 # ref: https://github.com/databricks/databricks-sql-python/blob/main/examples/m2m_oauth.py 958 from databricks.sdk.core import oauth_service_principal, Config 959 960 config = Config( 961 host=f"https://{self.server_hostname}", 962 client_id=self.oauth_client_id, 963 client_secret=self.oauth_client_secret, 964 ) 965 conn_kwargs["credentials_provider"] = lambda: oauth_service_principal(config) 966 else: 967 # if auth_type is set to an 'oauth' type but no client_id/secret are set, then we are using U2M 968 # ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-user-to-machine-u2m-authentication 969 conn_kwargs["auth_type"] = self.auth_type 970 971 return conn_kwargs 972 973 if DatabricksEngineAdapter.can_access_spark_session(self.disable_spark_session): 974 from pyspark.sql import SparkSession 975 976 return dict( 977 spark=SparkSession.getActiveSession(), 978 catalog=self.catalog, 979 ) 980 981 from databricks.connect import DatabricksSession 982 983 if t.TYPE_CHECKING: 984 assert self.databricks_connect_server_hostname is not None 985 assert self.databricks_connect_access_token is not None 986 987 if self.databricks_connect_use_serverless: 988 builder = DatabricksSession.builder.remote( 989 host=self.databricks_connect_server_hostname, 990 token=self.databricks_connect_access_token, 991 serverless=True, 992 ) 993 else: 994 if t.TYPE_CHECKING: 995 assert self.databricks_connect_cluster_id is not None 996 builder = DatabricksSession.builder.remote( 997 host=self.databricks_connect_server_hostname, 998 token=self.databricks_connect_access_token, 999 cluster_id=self.databricks_connect_cluster_id, 1000 ) 1001 1002 return dict( 1003 spark=builder.userAgent("sqlmesh").getOrCreate(), 1004 catalog=self.catalog, 1005 ) 1006 1007 1008class BigQueryConnectionMethod(str, Enum): 1009 OAUTH = "oauth" 1010 OAUTH_SECRETS = "oauth-secrets" 1011 SERVICE_ACCOUNT = "service-account" 1012 SERVICE_ACCOUNT_JSON = "service-account-json" 1013 1014 1015class BigQueryPriority(str, Enum): 1016 BATCH = "batch" 1017 INTERACTIVE = "interactive" 1018 1019 @property 1020 def is_batch(self) -> bool: 1021 return self == self.BATCH 1022 1023 @property 1024 def is_interactive(self) -> bool: 1025 return self == self.INTERACTIVE 1026 1027 @property 1028 def bigquery_constant(self) -> str: 1029 from google.cloud.bigquery import QueryPriority 1030 1031 if self.is_batch: 1032 return QueryPriority.BATCH 1033 return QueryPriority.INTERACTIVE 1034 1035 1036class BigQueryConnectionConfig(ConnectionConfig): 1037 """ 1038 BigQuery Connection Configuration. 1039 """ 1040 1041 method: BigQueryConnectionMethod = BigQueryConnectionMethod.OAUTH 1042 1043 project: t.Optional[str] = None 1044 execution_project: t.Optional[str] = None 1045 quota_project: t.Optional[str] = None 1046 location: t.Optional[str] = None 1047 # Keyfile Auth 1048 keyfile: t.Optional[str] = None 1049 keyfile_json: t.Optional[t.Dict[str, t.Any]] = None 1050 # Oath Secret Auth 1051 token: t.Optional[str] = None 1052 refresh_token: t.Optional[str] = None 1053 client_id: t.Optional[str] = None 1054 client_secret: t.Optional[str] = None 1055 token_uri: t.Optional[str] = None 1056 scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/bigquery",) 1057 impersonated_service_account: t.Optional[str] = None 1058 # Extra Engine Config 1059 job_creation_timeout_seconds: t.Optional[int] = None 1060 job_execution_timeout_seconds: t.Optional[int] = None 1061 job_retries: t.Optional[int] = 1 1062 job_retry_deadline_seconds: t.Optional[int] = None 1063 priority: t.Optional[BigQueryPriority] = None 1064 maximum_bytes_billed: t.Optional[int] = None 1065 reservation: t.Optional[str] = None 1066 1067 concurrent_tasks: int = 1 1068 register_comments: bool = True 1069 pre_ping: t.Literal[False] = False 1070 1071 type_: t.Literal["bigquery"] = Field(alias="type", default="bigquery") 1072 DIALECT: t.ClassVar[t.Literal["bigquery"]] = "bigquery" 1073 DISPLAY_NAME: t.ClassVar[t.Literal["BigQuery"]] = "BigQuery" 1074 DISPLAY_ORDER: t.ClassVar[t.Literal[4]] = 4 1075 1076 _engine_import_validator = _get_engine_import_validator("google.cloud.bigquery", "bigquery") 1077 1078 @field_validator("execution_project") 1079 def validate_execution_project( 1080 cls, 1081 v: t.Optional[str], 1082 info: ValidationInfo, 1083 ) -> t.Optional[str]: 1084 if v and not info.data.get("project"): 1085 raise ConfigError( 1086 "If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location." 1087 ) 1088 return v 1089 1090 @field_validator("quota_project") 1091 def validate_quota_project( 1092 cls, 1093 v: t.Optional[str], 1094 info: ValidationInfo, 1095 ) -> t.Optional[str]: 1096 if v and not info.data.get("project"): 1097 raise ConfigError( 1098 "If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location." 1099 ) 1100 return v 1101 1102 @property 1103 def _connection_kwargs_keys(self) -> t.Set[str]: 1104 return set() 1105 1106 @property 1107 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1108 return engine_adapter.BigQueryEngineAdapter 1109 1110 @property 1111 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 1112 """The static connection kwargs for this connection""" 1113 import google.auth 1114 from google.auth import impersonated_credentials 1115 from google.api_core import client_info, client_options 1116 from google.oauth2 import credentials, service_account 1117 1118 if self.method == BigQueryConnectionMethod.OAUTH: 1119 creds, _ = google.auth.default(scopes=self.scopes) 1120 elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT: 1121 creds = service_account.Credentials.from_service_account_file( 1122 self.keyfile, scopes=self.scopes 1123 ) 1124 elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT_JSON: 1125 creds = service_account.Credentials.from_service_account_info( 1126 self.keyfile_json, scopes=self.scopes 1127 ) 1128 elif self.method == BigQueryConnectionMethod.OAUTH_SECRETS: 1129 creds = credentials.Credentials( 1130 token=self.token, 1131 refresh_token=self.refresh_token, 1132 client_id=self.client_id, 1133 client_secret=self.client_secret, 1134 token_uri=self.token_uri, 1135 scopes=self.scopes, 1136 ) 1137 else: 1138 raise ConfigError("Invalid BigQuery Connection Method") 1139 1140 if self.impersonated_service_account: 1141 creds = impersonated_credentials.Credentials( 1142 source_credentials=creds, 1143 target_principal=self.impersonated_service_account, 1144 target_scopes=self.scopes, 1145 ) 1146 1147 options = client_options.ClientOptions(quota_project_id=self.quota_project) 1148 project = self.execution_project or self.project or None 1149 1150 client = google.cloud.bigquery.Client( 1151 project=project and exp.parse_identifier(project, dialect="bigquery").name, 1152 credentials=creds, 1153 location=self.location, 1154 client_info=client_info.ClientInfo(user_agent="sqlmesh"), 1155 client_options=options, 1156 ) 1157 1158 return { 1159 "client": client, 1160 } 1161 1162 @property 1163 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 1164 return { 1165 k: v 1166 for k, v in self.dict().items() 1167 if k 1168 in { 1169 "job_creation_timeout_seconds", 1170 "job_execution_timeout_seconds", 1171 "job_retries", 1172 "job_retry_deadline_seconds", 1173 "priority", 1174 "maximum_bytes_billed", 1175 "reservation", 1176 } 1177 } 1178 1179 @property 1180 def _connection_factory(self) -> t.Callable: 1181 from google.cloud.bigquery.dbapi import connect 1182 1183 return connect 1184 1185 def get_catalog(self) -> t.Optional[str]: 1186 return self.project 1187 1188 1189class GCPPostgresConnectionConfig(ConnectionConfig): 1190 """ 1191 Postgres Connection Configuration for GCP. 1192 1193 Args: 1194 instance_connection_string: Connection name for the postgres instance. 1195 user: Postgres or IAM user's name 1196 password: The postgres user's password. Only needed when the user is a postgres user. 1197 enable_iam_auth: Set to True when user is an IAM user. 1198 db: Name of the db to connect to. 1199 keyfile: string path to json service account credentials file 1200 keyfile_json: dict service account credentials info 1201 pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. 1202 """ 1203 1204 instance_connection_string: str 1205 user: str 1206 password: t.Optional[str] = None 1207 enable_iam_auth: t.Optional[bool] = None 1208 db: str 1209 ip_type: t.Union[t.Literal["public"], t.Literal["private"], t.Literal["psc"]] = "public" 1210 # Keyfile Auth 1211 keyfile: t.Optional[str] = None 1212 keyfile_json: t.Optional[t.Dict[str, t.Any]] = None 1213 timeout: t.Optional[int] = None 1214 scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/sqlservice.admin",) 1215 driver: str = "pg8000" 1216 1217 type_: t.Literal["gcp_postgres"] = Field(alias="type", default="gcp_postgres") 1218 DIALECT: t.ClassVar[t.Literal["postgres"]] = "postgres" 1219 DISPLAY_NAME: t.ClassVar[t.Literal["GCP Postgres"]] = "GCP Postgres" 1220 DISPLAY_ORDER: t.ClassVar[t.Literal[13]] = 13 1221 1222 concurrent_tasks: int = 4 1223 register_comments: bool = True 1224 pre_ping: bool = True 1225 1226 _engine_import_validator = _get_engine_import_validator( 1227 "google.cloud.sql", "gcp_postgres", "gcppostgres" 1228 ) 1229 1230 @model_validator(mode="before") 1231 def _validate_auth_method(cls, data: t.Any) -> t.Any: 1232 if not isinstance(data, dict): 1233 return data 1234 1235 password = data.get("password") 1236 enable_iam_auth = data.get("enable_iam_auth") 1237 1238 if not password and not enable_iam_auth: 1239 raise ConfigError( 1240 "GCP Postgres connection configuration requires either password set" 1241 " for a postgres user account or enable_iam_auth set to 'True'" 1242 " for an IAM user account." 1243 ) 1244 1245 return data 1246 1247 @property 1248 def _connection_kwargs_keys(self) -> t.Set[str]: 1249 return { 1250 "instance_connection_string", 1251 "driver", 1252 "user", 1253 "password", 1254 "db", 1255 "enable_iam_auth", 1256 "timeout", 1257 } 1258 1259 @property 1260 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1261 return engine_adapter.PostgresEngineAdapter 1262 1263 @property 1264 def _connection_factory(self) -> t.Callable: 1265 from google.cloud.sql.connector import Connector 1266 from google.oauth2 import service_account 1267 1268 creds = None 1269 if self.keyfile: 1270 creds = service_account.Credentials.from_service_account_file( 1271 self.keyfile, scopes=self.scopes 1272 ) 1273 elif self.keyfile_json: 1274 creds = service_account.Credentials.from_service_account_info( 1275 self.keyfile_json, scopes=self.scopes 1276 ) 1277 1278 kwargs = { 1279 "credentials": creds, 1280 "ip_type": self.ip_type, 1281 } 1282 1283 if self.timeout: 1284 kwargs["timeout"] = self.timeout 1285 1286 return Connector(**kwargs).connect # type: ignore 1287 1288 1289class RedshiftConnectionConfig(ConnectionConfig): 1290 """ 1291 Redshift Connection Configuration. 1292 1293 Arg Source: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/__init__.py#L146 1294 Note: A subset of properties were selected. Please open an issue/PR if you want to see more supported. 1295 1296 Args: 1297 user: The username to use for authentication with the Amazon Redshift cluster. 1298 password: The password to use for authentication with the Amazon Redshift cluster. 1299 database: The name of the database instance to connect to. 1300 host: The hostname of the Amazon Redshift cluster. 1301 port: The port number of the Amazon Redshift cluster. Default value is 5439. 1302 source_address: No description provided 1303 unix_sock: No description provided 1304 ssl: Is SSL enabled. Default value is ``True``. SSL must be enabled when authenticating using IAM. 1305 sslmode: The security of the connection to the Amazon Redshift cluster. 'verify-ca' and 'verify-full' are supported. 1306 timeout: The number of seconds before the connection to the server will timeout. By default there is no timeout. 1307 tcp_keepalive: Is `TCP keepalive <https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive>`_ used. The default value is ``True``. 1308 application_name: Sets the application name. The default value is None. 1309 preferred_role: The IAM role preferred for the current connection. 1310 principal_arn: The ARN of the IAM entity (user or role) for which you are generating a policy. 1311 credentials_provider: The class name of the IdP that will be used for authenticating with the Amazon Redshift cluster. 1312 region: The AWS region where the Amazon Redshift cluster is located. 1313 cluster_identifier: The cluster identifier of the Amazon Redshift cluster. 1314 iam: If IAM authentication is enabled. Default value is False. IAM must be True when authenticating using an IdP. 1315 is_serverless: Redshift end-point is serverless or provisional. Default value false. 1316 serverless_acct_id: The account ID of the serverless. Default value None 1317 serverless_work_group: The name of work group for serverless end point. Default value None. 1318 pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. 1319 enable_merge: Whether to use the Redshift merge operation instead of the SQLMesh logical merge. 1320 """ 1321 1322 user: t.Optional[str] = None 1323 password: t.Optional[str] = None 1324 database: t.Optional[str] = None 1325 host: t.Optional[str] = None 1326 port: t.Optional[int] = None 1327 source_address: t.Optional[str] = None 1328 unix_sock: t.Optional[str] = None 1329 ssl: t.Optional[bool] = None 1330 sslmode: t.Optional[str] = None 1331 timeout: t.Optional[int] = None 1332 tcp_keepalive: t.Optional[bool] = None 1333 application_name: t.Optional[str] = None 1334 preferred_role: t.Optional[str] = None 1335 principal_arn: t.Optional[str] = None 1336 credentials_provider: t.Optional[str] = None 1337 region: t.Optional[str] = None 1338 cluster_identifier: t.Optional[str] = None 1339 iam: t.Optional[bool] = None 1340 is_serverless: t.Optional[bool] = None 1341 serverless_acct_id: t.Optional[str] = None 1342 serverless_work_group: t.Optional[str] = None 1343 enable_merge: t.Optional[bool] = None 1344 1345 concurrent_tasks: int = 4 1346 register_comments: bool = True 1347 pre_ping: bool = False 1348 1349 type_: t.Literal["redshift"] = Field(alias="type", default="redshift") 1350 DIALECT: t.ClassVar[t.Literal["redshift"]] = "redshift" 1351 DISPLAY_NAME: t.ClassVar[t.Literal["Redshift"]] = "Redshift" 1352 DISPLAY_ORDER: t.ClassVar[t.Literal[7]] = 7 1353 1354 _engine_import_validator = _get_engine_import_validator("redshift_connector", "redshift") 1355 1356 @property 1357 def _connection_kwargs_keys(self) -> t.Set[str]: 1358 return { 1359 "user", 1360 "password", 1361 "database", 1362 "host", 1363 "port", 1364 "source_address", 1365 "unix_sock", 1366 "ssl", 1367 "sslmode", 1368 "timeout", 1369 "tcp_keepalive", 1370 "application_name", 1371 "preferred_role", 1372 "principal_arn", 1373 "credentials_provider", 1374 "region", 1375 "cluster_identifier", 1376 "iam", 1377 "is_serverless", 1378 "serverless_acct_id", 1379 "serverless_work_group", 1380 } 1381 1382 @property 1383 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1384 return engine_adapter.RedshiftEngineAdapter 1385 1386 @property 1387 def _connection_factory(self) -> t.Callable: 1388 from redshift_connector import connect 1389 1390 return connect 1391 1392 @property 1393 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 1394 return {"enable_merge": self.enable_merge} 1395 1396 1397class PostgresConnectionConfig(ConnectionConfig): 1398 host: str 1399 user: str 1400 password: str 1401 port: int 1402 database: str 1403 keepalives_idle: t.Optional[int] = None 1404 connect_timeout: int = 10 1405 role: t.Optional[str] = None 1406 sslmode: t.Optional[str] = None 1407 application_name: t.Optional[str] = None 1408 1409 concurrent_tasks: int = 4 1410 register_comments: bool = True 1411 pre_ping: bool = True 1412 1413 type_: t.Literal["postgres"] = Field(alias="type", default="postgres") 1414 DIALECT: t.ClassVar[t.Literal["postgres"]] = "postgres" 1415 DISPLAY_NAME: t.ClassVar[t.Literal["Postgres"]] = "Postgres" 1416 DISPLAY_ORDER: t.ClassVar[t.Literal[12]] = 12 1417 1418 _engine_import_validator = _get_engine_import_validator("psycopg2", "postgres") 1419 1420 @property 1421 def _connection_kwargs_keys(self) -> t.Set[str]: 1422 return { 1423 "host", 1424 "user", 1425 "password", 1426 "port", 1427 "database", 1428 "keepalives_idle", 1429 "connect_timeout", 1430 "sslmode", 1431 "application_name", 1432 } 1433 1434 @property 1435 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1436 return engine_adapter.PostgresEngineAdapter 1437 1438 @property 1439 def _connection_factory(self) -> t.Callable: 1440 from psycopg2 import connect 1441 1442 return connect 1443 1444 @property 1445 def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]: 1446 if not self.role: 1447 return None 1448 1449 def init(cursor: t.Any) -> None: 1450 cursor.execute(f"SET ROLE {self.role}") 1451 1452 return init 1453 1454 1455class MySQLConnectionConfig(ConnectionConfig): 1456 host: str 1457 user: str 1458 password: str 1459 port: t.Optional[int] = None 1460 database: t.Optional[str] = None 1461 charset: t.Optional[str] = None 1462 collation: t.Optional[str] = None 1463 ssl_disabled: t.Optional[bool] = None 1464 1465 concurrent_tasks: int = 4 1466 register_comments: bool = True 1467 pre_ping: bool = True 1468 1469 type_: t.Literal["mysql"] = Field(alias="type", default="mysql") 1470 DIALECT: t.ClassVar[t.Literal["mysql"]] = "mysql" 1471 DISPLAY_NAME: t.ClassVar[t.Literal["MySQL"]] = "MySQL" 1472 DISPLAY_ORDER: t.ClassVar[t.Literal[14]] = 14 1473 1474 _engine_import_validator = _get_engine_import_validator("pymysql", "mysql") 1475 1476 @property 1477 def _connection_kwargs_keys(self) -> t.Set[str]: 1478 connection_keys = { 1479 "host", 1480 "user", 1481 "password", 1482 } 1483 if self.port is not None: 1484 connection_keys.add("port") 1485 if self.database is not None: 1486 connection_keys.add("database") 1487 if self.charset is not None: 1488 connection_keys.add("charset") 1489 if self.collation is not None: 1490 connection_keys.add("collation") 1491 if self.ssl_disabled is not None: 1492 connection_keys.add("ssl_disabled") 1493 return connection_keys 1494 1495 @property 1496 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1497 return engine_adapter.MySQLEngineAdapter 1498 1499 @property 1500 def _connection_factory(self) -> t.Callable: 1501 from pymysql import connect 1502 1503 return connect 1504 1505 1506class MSSQLConnectionConfig(ConnectionConfig): 1507 host: str 1508 user: t.Optional[str] = None 1509 password: t.Optional[str] = None 1510 database: t.Optional[str] = "" 1511 timeout: t.Optional[int] = 0 1512 login_timeout: t.Optional[int] = 60 1513 charset: t.Optional[str] = "UTF-8" 1514 appname: t.Optional[str] = None 1515 port: t.Optional[int] = 1433 1516 conn_properties: t.Optional[t.Union[t.List[str], str]] = None 1517 autocommit: t.Optional[bool] = False 1518 tds_version: t.Optional[str] = None 1519 1520 # Driver options 1521 driver: t.Literal["pymssql", "pyodbc"] = "pymssql" 1522 # PyODBC specific options 1523 driver_name: t.Optional[str] = None # e.g. "ODBC Driver 18 for SQL Server" 1524 trust_server_certificate: t.Optional[bool] = None 1525 encrypt: t.Optional[bool] = None 1526 # Dictionary of arbitrary ODBC connection properties 1527 # See: https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute 1528 odbc_properties: t.Optional[t.Dict[str, t.Any]] = None 1529 1530 concurrent_tasks: int = 4 1531 register_comments: bool = True 1532 pre_ping: bool = True 1533 1534 type_: t.Literal["mssql"] = Field(alias="type", default="mssql") 1535 DIALECT: t.ClassVar[t.Literal["tsql"]] = "tsql" 1536 DISPLAY_NAME: t.ClassVar[t.Literal["MSSQL"]] = "MSSQL" 1537 DISPLAY_ORDER: t.ClassVar[t.Literal[11]] = 11 1538 1539 @model_validator(mode="before") 1540 @classmethod 1541 def _mssql_engine_import_validator(cls, data: t.Any) -> t.Any: 1542 if not isinstance(data, dict): 1543 return data 1544 1545 driver = data.get("driver", "pymssql") 1546 1547 # Define the mapping of driver to import module and extra name 1548 driver_configs = {"pymssql": ("pymssql", "mssql"), "pyodbc": ("pyodbc", "mssql-odbc")} 1549 1550 if driver not in driver_configs: 1551 raise ValueError(f"Unsupported driver: {driver}") 1552 1553 import_module, extra_name = driver_configs[driver] 1554 1555 # Use _get_engine_import_validator with decorate=False to get the raw validation function 1556 # This avoids the __wrapped__ issue in Python 3.9 1557 validator_func = _get_engine_import_validator( 1558 import_module, driver, extra_name, decorate=False 1559 ) 1560 1561 # Call the raw validation function directly 1562 return validator_func(cls, data) 1563 1564 @property 1565 def _connection_kwargs_keys(self) -> t.Set[str]: 1566 base_keys = { 1567 "host", 1568 "user", 1569 "password", 1570 "database", 1571 "timeout", 1572 "login_timeout", 1573 "charset", 1574 "appname", 1575 "port", 1576 "conn_properties", 1577 "autocommit", 1578 "tds_version", 1579 } 1580 1581 if self.driver == "pyodbc": 1582 base_keys.update( 1583 { 1584 "driver_name", 1585 "trust_server_certificate", 1586 "encrypt", 1587 "odbc_properties", 1588 } 1589 ) 1590 # Remove pymssql-specific parameters 1591 base_keys.discard("tds_version") 1592 base_keys.discard("conn_properties") 1593 1594 return base_keys 1595 1596 @property 1597 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1598 return engine_adapter.MSSQLEngineAdapter 1599 1600 @property 1601 def _connection_factory(self) -> t.Callable: 1602 if self.driver == "pymssql": 1603 import pymssql 1604 1605 return pymssql.connect 1606 1607 import pyodbc 1608 1609 def connect(**kwargs: t.Any) -> t.Callable: 1610 # Extract parameters for connection string 1611 host = kwargs.pop("host") 1612 port = kwargs.pop("port", 1433) 1613 database = kwargs.pop("database", "") 1614 user = kwargs.pop("user", None) 1615 password = kwargs.pop("password", None) 1616 driver_name = kwargs.pop("driver_name", "ODBC Driver 18 for SQL Server") 1617 trust_server_certificate = kwargs.pop("trust_server_certificate", False) 1618 encrypt = kwargs.pop("encrypt", True) 1619 login_timeout = kwargs.pop("login_timeout", 60) 1620 1621 # Build connection string 1622 conn_str_parts = [ 1623 f"DRIVER={{{driver_name}}}", 1624 f"SERVER={host},{port}", 1625 ] 1626 1627 if database: 1628 conn_str_parts.append(f"DATABASE={database}") 1629 1630 # Add security options 1631 conn_str_parts.append(f"Encrypt={'YES' if encrypt else 'NO'}") 1632 if trust_server_certificate: 1633 conn_str_parts.append("TrustServerCertificate=YES") 1634 1635 conn_str_parts.append(f"Connection Timeout={login_timeout}") 1636 1637 # Standard SQL Server authentication 1638 if user: 1639 conn_str_parts.append(f"UID={user}") 1640 if password: 1641 conn_str_parts.append(f"PWD={password}") 1642 1643 # Add any additional ODBC properties from the odbc_properties dictionary 1644 if self.odbc_properties: 1645 for key, value in self.odbc_properties.items(): 1646 # Skip properties that we've already set above 1647 if key.lower() in ( 1648 "driver", 1649 "server", 1650 "database", 1651 "uid", 1652 "pwd", 1653 "encrypt", 1654 "trustservercertificate", 1655 "connection timeout", 1656 ): 1657 continue 1658 1659 # Handle boolean values properly 1660 if isinstance(value, bool): 1661 conn_str_parts.append(f"{key}={'YES' if value else 'NO'}") 1662 else: 1663 conn_str_parts.append(f"{key}={value}") 1664 1665 # Create the connection string 1666 conn_str = ";".join(conn_str_parts) 1667 1668 conn = pyodbc.connect(conn_str, autocommit=kwargs.get("autocommit", False)) 1669 1670 # Set up output converters for MSSQL-specific data types 1671 # Handle SQL type -155 (DATETIMEOFFSET) which is not yet supported by pyodbc 1672 # ref: https://github.com/mkleehammer/pyodbc/issues/134#issuecomment-281739794 1673 def handle_datetimeoffset(dto_value: t.Any) -> t.Any: 1674 from datetime import datetime, timedelta, timezone 1675 import struct 1676 1677 # Unpack the DATETIMEOFFSET binary format: 1678 # Format: <6hI2h = (year, month, day, hour, minute, second, nanoseconds, tz_hour_offset, tz_minute_offset) 1679 tup = struct.unpack("<6hI2h", dto_value) 1680 return datetime( 1681 tup[0], 1682 tup[1], 1683 tup[2], 1684 tup[3], 1685 tup[4], 1686 tup[5], 1687 tup[6] // 1000, 1688 timezone(timedelta(hours=tup[7], minutes=tup[8])), 1689 ) 1690 1691 conn.add_output_converter(-155, handle_datetimeoffset) 1692 1693 return conn 1694 1695 return connect 1696 1697 @property 1698 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 1699 return {"catalog_support": CatalogSupport.REQUIRES_SET_CATALOG} 1700 1701 1702class AzureSQLConnectionConfig(MSSQLConnectionConfig): 1703 type_: t.Literal["azuresql"] = Field(alias="type", default="azuresql") # type: ignore 1704 DISPLAY_NAME: t.ClassVar[t.Literal["Azure SQL"]] = "Azure SQL" # type: ignore 1705 DISPLAY_ORDER: t.ClassVar[t.Literal[10]] = 10 # type: ignore 1706 1707 @property 1708 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 1709 return {"catalog_support": CatalogSupport.SINGLE_CATALOG_ONLY} 1710 1711 1712class FabricConnectionConfig(MSSQLConnectionConfig): 1713 """ 1714 Fabric Connection Configuration. 1715 Inherits most settings from MSSQLConnectionConfig and sets the type to 'fabric'. 1716 It is recommended to use the 'pyodbc' driver for Fabric. 1717 """ 1718 1719 type_: t.Literal["fabric"] = Field(alias="type", default="fabric") # type: ignore 1720 DIALECT: t.ClassVar[t.Literal["fabric"]] = "fabric" # type: ignore 1721 DISPLAY_NAME: t.ClassVar[t.Literal["Fabric"]] = "Fabric" # type: ignore 1722 DISPLAY_ORDER: t.ClassVar[t.Literal[17]] = 17 # type: ignore 1723 driver: t.Literal["pyodbc"] = "pyodbc" 1724 workspace_id: str 1725 tenant_id: str 1726 autocommit: t.Optional[bool] = True 1727 1728 @property 1729 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1730 from sqlmesh.core.engine_adapter.fabric import FabricEngineAdapter 1731 1732 return FabricEngineAdapter 1733 1734 @property 1735 def _connection_factory(self) -> t.Callable: 1736 # Override to support catalog switching for Fabric 1737 base_factory = super()._connection_factory 1738 1739 def create_fabric_connection( 1740 target_catalog: t.Optional[str] = None, *args: t.Any, **kwargs: t.Any 1741 ) -> t.Callable: 1742 kwargs["database"] = target_catalog or self.database 1743 return base_factory(*args, **kwargs) 1744 1745 return create_fabric_connection 1746 1747 @property 1748 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 1749 return { 1750 "database": self.database, 1751 # more operations than not require a specific catalog to be already active 1752 # in particular, create/drop view, create/drop schema and querying information_schema 1753 "catalog_support": CatalogSupport.REQUIRES_SET_CATALOG, 1754 "workspace_id": self.workspace_id, 1755 "tenant_id": self.tenant_id, 1756 "user": self.user, 1757 "password": self.password, 1758 } 1759 1760 1761class SparkConnectionConfig(ConnectionConfig): 1762 """ 1763 Vanilla Spark Connection Configuration. Use `DatabricksConnectionConfig` for Databricks. 1764 """ 1765 1766 config_dir: t.Optional[str] = None 1767 catalog: t.Optional[str] = None 1768 config: t.Dict[str, t.Any] = {} 1769 wap_enabled: bool = False 1770 1771 concurrent_tasks: int = 4 1772 register_comments: bool = True 1773 pre_ping: t.Literal[False] = False 1774 1775 type_: t.Literal["spark"] = Field(alias="type", default="spark") 1776 DIALECT: t.ClassVar[t.Literal["spark"]] = "spark" 1777 DISPLAY_NAME: t.ClassVar[t.Literal["Spark"]] = "Spark" 1778 DISPLAY_ORDER: t.ClassVar[t.Literal[8]] = 8 1779 1780 _engine_import_validator = _get_engine_import_validator("pyspark", "spark") 1781 1782 @property 1783 def _connection_kwargs_keys(self) -> t.Set[str]: 1784 return { 1785 "catalog", 1786 } 1787 1788 @property 1789 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1790 return engine_adapter.SparkEngineAdapter 1791 1792 @property 1793 def _connection_factory(self) -> t.Callable: 1794 from sqlmesh.engines.spark.db_api.spark_session import connection 1795 1796 return connection 1797 1798 @property 1799 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 1800 from pyspark.conf import SparkConf 1801 from pyspark.sql import SparkSession 1802 1803 spark_config = SparkConf() 1804 if self.config: 1805 for k, v in self.config.items(): 1806 spark_config.set(k, v) 1807 1808 if self.config_dir: 1809 os.environ["SPARK_CONF_DIR"] = self.config_dir 1810 return { 1811 "spark": SparkSession.builder.config(conf=spark_config) 1812 .enableHiveSupport() 1813 .getOrCreate(), 1814 } 1815 1816 @property 1817 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 1818 return {"wap_enabled": self.wap_enabled} 1819 1820 1821class TrinoAuthenticationMethod(str, Enum): 1822 NO_AUTH = "no-auth" 1823 BASIC = "basic" 1824 LDAP = "ldap" 1825 KERBEROS = "kerberos" 1826 JWT = "jwt" 1827 CERTIFICATE = "certificate" 1828 OAUTH = "oauth" 1829 1830 @property 1831 def is_no_auth(self) -> bool: 1832 return self == self.NO_AUTH 1833 1834 @property 1835 def is_basic(self) -> bool: 1836 return self == self.BASIC 1837 1838 @property 1839 def is_ldap(self) -> bool: 1840 return self == self.LDAP 1841 1842 @property 1843 def is_kerberos(self) -> bool: 1844 return self == self.KERBEROS 1845 1846 @property 1847 def is_jwt(self) -> bool: 1848 return self == self.JWT 1849 1850 @property 1851 def is_certificate(self) -> bool: 1852 return self == self.CERTIFICATE 1853 1854 @property 1855 def is_oauth(self) -> bool: 1856 return self == self.OAUTH 1857 1858 1859class TrinoConnectionConfig(ConnectionConfig): 1860 method: TrinoAuthenticationMethod = TrinoAuthenticationMethod.NO_AUTH 1861 host: str 1862 user: str 1863 catalog: str 1864 port: t.Optional[int] = None 1865 http_scheme: t.Literal["http", "https"] = "https" 1866 # General Optional 1867 roles: t.Optional[t.Dict[str, str]] = None 1868 http_headers: t.Optional[t.Dict[str, str]] = None 1869 session_properties: t.Optional[t.Dict[str, str]] = None 1870 retries: int = 3 1871 timezone: t.Optional[str] = None 1872 # Basic/LDAP 1873 password: t.Optional[str] = None 1874 verify: t.Optional[bool] = None # disable SSL verification (ignored if `cert` is provided) 1875 # LDAP 1876 impersonation_user: t.Optional[str] = None 1877 # Kerberos 1878 keytab: t.Optional[str] = None 1879 krb5_config: t.Optional[str] = None 1880 principal: t.Optional[str] = None 1881 service_name: str = "trino" 1882 hostname_override: t.Optional[str] = None 1883 mutual_authentication: bool = False 1884 force_preemptive: bool = False 1885 sanitize_mutual_error_response: bool = True 1886 delegate: bool = False 1887 # JWT 1888 jwt_token: t.Optional[str] = None 1889 # Certificate 1890 client_certificate: t.Optional[str] = None 1891 client_private_key: t.Optional[str] = None 1892 cert: t.Optional[str] = None 1893 source: str = "sqlmesh" 1894 1895 # SQLMesh options 1896 schema_location_mapping: t.Optional[dict[re.Pattern, str]] = None 1897 timestamp_mapping: t.Optional[dict[exp.DataType, exp.DataType]] = None 1898 concurrent_tasks: int = 4 1899 register_comments: bool = True 1900 pre_ping: t.Literal[False] = False 1901 1902 type_: t.Literal["trino"] = Field(alias="type", default="trino") 1903 DIALECT: t.ClassVar[t.Literal["trino"]] = "trino" 1904 DISPLAY_NAME: t.ClassVar[t.Literal["Trino"]] = "Trino" 1905 DISPLAY_ORDER: t.ClassVar[t.Literal[9]] = 9 1906 1907 _engine_import_validator = _get_engine_import_validator("trino", "trino") 1908 1909 @field_validator("schema_location_mapping", mode="before") 1910 @classmethod 1911 def _validate_regex_keys( 1912 cls, value: t.Dict[str | re.Pattern, str] 1913 ) -> t.Dict[re.Pattern, t.Any]: 1914 compiled = compile_regex_mapping(value) 1915 for replacement in compiled.values(): 1916 if "@{schema_name}" not in replacement: 1917 raise ConfigError( 1918 "schema_location_mapping needs to include the '@{schema_name}' placeholder in the value so SQLMesh knows where to substitute the schema name" 1919 ) 1920 return compiled 1921 1922 @field_validator("timestamp_mapping", mode="before") 1923 @classmethod 1924 def _validate_timestamp_mapping( 1925 cls, value: t.Optional[dict[str, str]] 1926 ) -> t.Optional[dict[exp.DataType, exp.DataType]]: 1927 if value is None: 1928 return value 1929 1930 result: dict[exp.DataType, exp.DataType] = {} 1931 for source_type, target_type in value.items(): 1932 try: 1933 source_datatype = exp.DataType.build(source_type) 1934 except ParseError: 1935 raise ConfigError( 1936 f"Invalid SQL type string in timestamp_mapping: " 1937 f"'{source_type}' is not a valid SQL data type." 1938 ) 1939 try: 1940 target_datatype = exp.DataType.build(target_type) 1941 except ParseError: 1942 raise ConfigError( 1943 f"Invalid SQL type string in timestamp_mapping: " 1944 f"'{target_type}' is not a valid SQL data type." 1945 ) 1946 result[source_datatype] = target_datatype 1947 1948 return result 1949 1950 @model_validator(mode="after") 1951 def _root_validator(self) -> Self: 1952 port = self.port 1953 if self.http_scheme == "http" and not self.method.is_no_auth and not self.method.is_basic: 1954 raise ConfigError("HTTP scheme can only be used with no-auth or basic method") 1955 1956 if port is None: 1957 self.port = 80 if self.http_scheme == "http" else 443 1958 1959 if (self.method.is_ldap or self.method.is_basic) and (not self.password or not self.user): 1960 raise ConfigError( 1961 f"Username and Password must be provided if using {self.method.value} authentication" 1962 ) 1963 1964 if self.method.is_kerberos and ( 1965 not self.principal or not self.keytab or not self.krb5_config 1966 ): 1967 raise ConfigError( 1968 "Kerberos requires the following fields: principal, keytab, and krb5_config" 1969 ) 1970 1971 if self.method.is_jwt and not self.jwt_token: 1972 raise ConfigError("JWT requires `jwt_token` to be set") 1973 1974 if self.method.is_certificate and ( 1975 not self.cert or not self.client_certificate or not self.client_private_key 1976 ): 1977 raise ConfigError( 1978 "Certificate requires the following fields: cert, client_certificate, and client_private_key" 1979 ) 1980 1981 return self 1982 1983 @property 1984 def _connection_kwargs_keys(self) -> t.Set[str]: 1985 kwargs = { 1986 "host", 1987 "port", 1988 "catalog", 1989 "roles", 1990 "source", 1991 "http_scheme", 1992 "http_headers", 1993 "session_properties", 1994 "timezone", 1995 } 1996 return kwargs 1997 1998 @property 1999 def _engine_adapter(self) -> t.Type[EngineAdapter]: 2000 return engine_adapter.TrinoEngineAdapter 2001 2002 @property 2003 def _connection_factory(self) -> t.Callable: 2004 from trino.dbapi import connect 2005 2006 return connect 2007 2008 @property 2009 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 2010 from trino.auth import ( 2011 BasicAuthentication, 2012 CertificateAuthentication, 2013 JWTAuthentication, 2014 KerberosAuthentication, 2015 OAuth2Authentication, 2016 ) 2017 2018 auth: t.Optional[ 2019 t.Union[ 2020 BasicAuthentication, 2021 KerberosAuthentication, 2022 OAuth2Authentication, 2023 JWTAuthentication, 2024 CertificateAuthentication, 2025 ] 2026 ] = None 2027 if self.method.is_basic or self.method.is_ldap: 2028 assert self.password is not None # for mypy since validator already checks this 2029 auth = BasicAuthentication(self.user, self.password) 2030 elif self.method.is_kerberos: 2031 if self.keytab: 2032 os.environ["KRB5_CLIENT_KTNAME"] = self.keytab 2033 auth = KerberosAuthentication( 2034 config=self.krb5_config, 2035 service_name=self.service_name, 2036 principal=self.principal, 2037 mutual_authentication=self.mutual_authentication, 2038 ca_bundle=self.cert, 2039 force_preemptive=self.force_preemptive, 2040 hostname_override=self.hostname_override, 2041 sanitize_mutual_error_response=self.sanitize_mutual_error_response, 2042 delegate=self.delegate, 2043 ) 2044 elif self.method.is_oauth: 2045 auth = OAuth2Authentication() 2046 elif self.method.is_jwt: 2047 assert self.jwt_token is not None 2048 auth = JWTAuthentication(self.jwt_token) 2049 elif self.method.is_certificate: 2050 assert self.client_certificate is not None 2051 assert self.client_private_key is not None 2052 auth = CertificateAuthentication(self.client_certificate, self.client_private_key) 2053 2054 return { 2055 "auth": auth, 2056 "user": self.impersonation_user or self.user, 2057 "max_attempts": self.retries, 2058 "verify": self.cert if self.cert is not None else self.verify, 2059 "source": self.source, 2060 } 2061 2062 @property 2063 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 2064 return { 2065 "schema_location_mapping": self.schema_location_mapping, 2066 "timestamp_mapping": self.timestamp_mapping, 2067 } 2068 2069 2070class ClickhouseConnectionConfig(ConnectionConfig): 2071 """ 2072 Clickhouse Connection Configuration. 2073 2074 Property reference: https://clickhouse.com/docs/en/integrations/python#client-initialization 2075 """ 2076 2077 host: str 2078 username: str 2079 password: t.Optional[str] = None 2080 port: t.Optional[int] = None 2081 cluster: t.Optional[str] = None 2082 connect_timeout: int = 10 2083 send_receive_timeout: int = 300 2084 query_limit: int = 0 2085 use_compression: bool = True 2086 compression_method: t.Optional[str] = None 2087 connection_settings: t.Optional[t.Dict[str, t.Any]] = None 2088 http_proxy: t.Optional[str] = None 2089 # HTTPS/TLS settings 2090 verify: bool = True 2091 ca_cert: t.Optional[str] = None 2092 client_cert: t.Optional[str] = None 2093 client_cert_key: t.Optional[str] = None 2094 https_proxy: t.Optional[str] = None 2095 server_host_name: t.Optional[str] = None 2096 tls_mode: t.Optional[str] = None 2097 2098 concurrent_tasks: int = 1 2099 register_comments: bool = True 2100 pre_ping: bool = False 2101 2102 # This object expects options from urllib3 and also from clickhouse-connect 2103 # See: 2104 # * https://urllib3.readthedocs.io/en/stable/advanced-usage.html 2105 # * https://clickhouse.com/docs/en/integrations/python#customizing-the-http-connection-pool 2106 connection_pool_options: t.Optional[t.Dict[str, t.Any]] = None 2107 2108 type_: t.Literal["clickhouse"] = Field(alias="type", default="clickhouse") 2109 DIALECT: t.ClassVar[t.Literal["clickhouse"]] = "clickhouse" 2110 DISPLAY_NAME: t.ClassVar[t.Literal["ClickHouse"]] = "ClickHouse" 2111 DISPLAY_ORDER: t.ClassVar[t.Literal[6]] = 6 2112 2113 _engine_import_validator = _get_engine_import_validator("clickhouse_connect", "clickhouse") 2114 2115 @property 2116 def _connection_kwargs_keys(self) -> t.Set[str]: 2117 kwargs = { 2118 "host", 2119 "username", 2120 "port", 2121 "password", 2122 "connect_timeout", 2123 "send_receive_timeout", 2124 "query_limit", 2125 "http_proxy", 2126 "verify", 2127 "ca_cert", 2128 "client_cert", 2129 "client_cert_key", 2130 "https_proxy", 2131 "server_host_name", 2132 "tls_mode", 2133 } 2134 return kwargs 2135 2136 @property 2137 def _engine_adapter(self) -> t.Type[EngineAdapter]: 2138 return engine_adapter.ClickhouseEngineAdapter 2139 2140 @property 2141 def _connection_factory(self) -> t.Callable: 2142 from clickhouse_connect.dbapi import connect # type: ignore 2143 from clickhouse_connect.driver import httputil # type: ignore 2144 from functools import partial 2145 2146 pool_manager_options: t.Dict[str, t.Any] = dict( 2147 # Match the maxsize to the number of concurrent tasks 2148 maxsize=self.concurrent_tasks, 2149 # Block if there are no free connections 2150 block=True, 2151 verify=self.verify, 2152 ca_cert=self.ca_cert, 2153 client_cert=self.client_cert, 2154 client_cert_key=self.client_cert_key, 2155 https_proxy=self.https_proxy, 2156 ) 2157 # this doesn't happen automatically because we always supply our own pool manager to the connection 2158 # https://github.com/ClickHouse/clickhouse-connect/blob/3a7f4b04cad29c7c2536661b831fb744248e2ec0/clickhouse_connect/driver/httpclient.py#L109 2159 if self.server_host_name: 2160 pool_manager_options["server_hostname"] = self.server_host_name 2161 if self.verify: 2162 pool_manager_options["assert_hostname"] = self.server_host_name 2163 if self.connection_pool_options: 2164 pool_manager_options.update(self.connection_pool_options) 2165 pool_mgr = httputil.get_pool_manager(**pool_manager_options) 2166 2167 return partial(connect, pool_mgr=pool_mgr) 2168 2169 @property 2170 def cloud_mode(self) -> bool: 2171 return "clickhouse.cloud" in self.host 2172 2173 @property 2174 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 2175 return {"cluster": self.cluster, "cloud_mode": self.cloud_mode} 2176 2177 @property 2178 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 2179 from sqlmesh import __version__ 2180 2181 # False = no compression 2182 # True = Clickhouse default compression method 2183 # string = specific compression method 2184 compress: bool | str = self.use_compression 2185 if compress and self.compression_method: 2186 compress = self.compression_method 2187 2188 # Clickhouse system settings passed to connection 2189 # https://clickhouse.com/docs/en/operations/settings/settings 2190 # - below are set to align with dbt-clickhouse 2191 # - https://github.com/ClickHouse/dbt-clickhouse/blob/44d26308ea6a3c8ead25c280164aa88191f05f47/dbt/adapters/clickhouse/dbclient.py#L77 2192 settings = self.connection_settings or {} 2193 # mutations_sync = 2: "The query waits for all mutations [ALTER statements] to complete on all replicas (if they exist)" 2194 settings["mutations_sync"] = "2" 2195 # insert_distributed_sync = 1: "INSERT operation succeeds only after all the data is saved on all shards" 2196 settings["insert_distributed_sync"] = "1" 2197 if self.cluster or self.cloud_mode: 2198 # database_replicated_enforce_synchronous_settings = 1: 2199 # - "Enforces synchronous waiting for some queries" 2200 # - https://github.com/ClickHouse/ClickHouse/blob/ccaa8d03a9351efc16625340268b9caffa8a22ba/src/Core/Settings.h#L709 2201 settings["database_replicated_enforce_synchronous_settings"] = "1" 2202 # insert_quorum = auto: 2203 # - "INSERT succeeds only when ClickHouse manages to correctly write data to the insert_quorum of replicas during 2204 # the insert_quorum_timeout" 2205 # - "use majority number (number_of_replicas / 2 + 1) as quorum number" 2206 settings["insert_quorum"] = "auto" 2207 2208 return { 2209 "compress": compress, 2210 "client_name": f"SQLMesh/{__version__}", 2211 **settings, 2212 } 2213 2214 2215class AthenaConnectionConfig(ConnectionConfig): 2216 # PyAthena connection options 2217 aws_access_key_id: t.Optional[str] = None 2218 aws_secret_access_key: t.Optional[str] = None 2219 role_arn: t.Optional[str] = None 2220 role_session_name: t.Optional[str] = None 2221 region_name: t.Optional[str] = None 2222 work_group: t.Optional[str] = None 2223 s3_staging_dir: t.Optional[str] = None 2224 schema_name: t.Optional[str] = None 2225 catalog_name: t.Optional[str] = None 2226 2227 # SQLMesh options 2228 s3_warehouse_location: t.Optional[str] = None 2229 concurrent_tasks: int = 4 2230 register_comments: t.Literal[False] = ( 2231 False # because Athena doesnt support comments in most cases 2232 ) 2233 pre_ping: t.Literal[False] = False 2234 2235 type_: t.Literal["athena"] = Field(alias="type", default="athena") 2236 DIALECT: t.ClassVar[t.Literal["athena"]] = "athena" 2237 DISPLAY_NAME: t.ClassVar[t.Literal["Athena"]] = "Athena" 2238 DISPLAY_ORDER: t.ClassVar[t.Literal[15]] = 15 2239 2240 _engine_import_validator = _get_engine_import_validator("pyathena", "athena") 2241 2242 @model_validator(mode="after") 2243 def _root_validator(self) -> Self: 2244 work_group = self.work_group 2245 s3_staging_dir = self.s3_staging_dir 2246 s3_warehouse_location = self.s3_warehouse_location 2247 2248 if not work_group and not s3_staging_dir: 2249 raise ConfigError("At least one of work_group or s3_staging_dir must be set") 2250 2251 if s3_staging_dir: 2252 self.s3_staging_dir = validate_s3_uri(s3_staging_dir, base=True, error_type=ConfigError) 2253 2254 if s3_warehouse_location: 2255 self.s3_warehouse_location = validate_s3_uri( 2256 s3_warehouse_location, base=True, error_type=ConfigError 2257 ) 2258 2259 return self 2260 2261 @property 2262 def _connection_kwargs_keys(self) -> t.Set[str]: 2263 return { 2264 "aws_access_key_id", 2265 "aws_secret_access_key", 2266 "role_arn", 2267 "role_session_name", 2268 "region_name", 2269 "work_group", 2270 "s3_staging_dir", 2271 "schema_name", 2272 "catalog_name", 2273 } 2274 2275 @property 2276 def _engine_adapter(self) -> t.Type[EngineAdapter]: 2277 return engine_adapter.AthenaEngineAdapter 2278 2279 @property 2280 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 2281 return {"s3_warehouse_location": self.s3_warehouse_location} 2282 2283 @property 2284 def _connection_factory(self) -> t.Callable: 2285 from pyathena import connect # type: ignore 2286 2287 return connect 2288 2289 def get_catalog(self) -> t.Optional[str]: 2290 return self.catalog_name 2291 2292 2293class RisingwaveConnectionConfig(ConnectionConfig): 2294 host: str 2295 user: str 2296 password: t.Optional[str] = None 2297 port: int 2298 database: str 2299 role: t.Optional[str] = None 2300 sslmode: t.Optional[str] = None 2301 2302 concurrent_tasks: int = 4 2303 register_comments: bool = True 2304 pre_ping: bool = True 2305 2306 type_: t.Literal["risingwave"] = Field(alias="type", default="risingwave") 2307 DIALECT: t.ClassVar[t.Literal["risingwave"]] = "risingwave" 2308 DISPLAY_NAME: t.ClassVar[t.Literal["RisingWave"]] = "RisingWave" 2309 DISPLAY_ORDER: t.ClassVar[t.Literal[16]] = 16 2310 2311 _engine_import_validator = _get_engine_import_validator("psycopg2", "risingwave") 2312 2313 @property 2314 def _connection_kwargs_keys(self) -> t.Set[str]: 2315 return { 2316 "host", 2317 "user", 2318 "password", 2319 "port", 2320 "database", 2321 "role", 2322 "sslmode", 2323 } 2324 2325 @property 2326 def _engine_adapter(self) -> t.Type[EngineAdapter]: 2327 return engine_adapter.RisingwaveEngineAdapter 2328 2329 @property 2330 def _connection_factory(self) -> t.Callable: 2331 from psycopg2 import connect 2332 2333 return connect 2334 2335 @property 2336 def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]: 2337 def init(cursor: t.Any) -> None: 2338 sql = "SET RW_IMPLICIT_FLUSH TO true;" 2339 cursor.execute(sql) 2340 2341 return init 2342 2343 2344CONNECTION_CONFIG_TO_TYPE = { 2345 # Map all subclasses of ConnectionConfig to the value of their `type_` field. 2346 tpe.all_field_infos()["type_"].default: tpe 2347 for tpe in subclasses( 2348 __name__, 2349 ConnectionConfig, 2350 exclude={ConnectionConfig, BaseDuckDBConnectionConfig}, 2351 ) 2352} 2353 2354DIALECT_TO_TYPE = { 2355 tpe.all_field_infos()["type_"].default: tpe.DIALECT 2356 for tpe in subclasses( 2357 __name__, 2358 ConnectionConfig, 2359 exclude={ConnectionConfig, BaseDuckDBConnectionConfig}, 2360 ) 2361} 2362 2363INIT_DISPLAY_INFO_TO_TYPE = { 2364 tpe.all_field_infos()["type_"].default: ( 2365 tpe.DISPLAY_ORDER, 2366 tpe.DISPLAY_NAME, 2367 ) 2368 for tpe in subclasses( 2369 __name__, 2370 ConnectionConfig, 2371 exclude={ConnectionConfig, BaseDuckDBConnectionConfig}, 2372 ) 2373} 2374 2375 2376def parse_connection_config(v: t.Dict[str, t.Any]) -> ConnectionConfig: 2377 if "type" not in v: 2378 raise ConfigError("Missing connection type.") 2379 2380 connection_type = v["type"] 2381 if connection_type not in CONNECTION_CONFIG_TO_TYPE: 2382 raise ConfigError(f"Unknown connection type '{connection_type}'.") 2383 2384 return CONNECTION_CONFIG_TO_TYPE[connection_type](**v) 2385 2386 2387def _connection_config_validator( 2388 cls: t.Type, v: ConnectionConfig | t.Dict[str, t.Any] | None 2389) -> ConnectionConfig | None: 2390 if v is None or isinstance(v, ConnectionConfig): 2391 return v 2392 2393 check_config_and_vars_msg = "\n\nVerify your config.yaml and environment variables." 2394 2395 try: 2396 return parse_connection_config(v) 2397 except pydantic.ValidationError as e: 2398 raise ConfigError( 2399 validation_error_message(e, f"Invalid '{v['type']}' connection config:") 2400 + check_config_and_vars_msg 2401 ) 2402 except ConfigError as e: 2403 raise ConfigError(str(e) + check_config_and_vars_msg) 2404 2405 2406connection_config_validator: t.Callable = field_validator( 2407 "connection", 2408 "state_connection", 2409 "test_connection", 2410 "default_connection", 2411 "default_test_connection", 2412 mode="before", 2413 check_fields=False, 2414)(_connection_config_validator) 2415 2416 2417if t.TYPE_CHECKING: 2418 # TypeAlias hasn't been introduced until Python 3.10 which means that we can't use it 2419 # outside the TYPE_CHECKING guard. 2420 SerializableConnectionConfig: t.TypeAlias = ConnectionConfig # type: ignore 2421else: 2422 import pydantic 2423 2424 # Workaround for https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing 2425 SerializableConnectionConfig = pydantic.SerializeAsAny[ConnectionConfig] # type: ignore
97class ConnectionConfig(abc.ABC, BaseConfig): 98 type_: str 99 DIALECT: t.ClassVar[str] 100 DISPLAY_NAME: t.ClassVar[str] 101 DISPLAY_ORDER: t.ClassVar[int] 102 concurrent_tasks: int 103 register_comments: bool 104 pre_ping: bool 105 pretty_sql: bool = False 106 schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None 107 catalog_type_overrides: t.Optional[t.Dict[str, str]] = None 108 109 # Whether to share a single connection across threads or create a new connection per thread. 110 shared_connection: t.ClassVar[bool] = False 111 112 @property 113 @abc.abstractmethod 114 def _connection_kwargs_keys(self) -> t.Set[str]: 115 """keywords that should be passed into the connection""" 116 117 @property 118 @abc.abstractmethod 119 def _engine_adapter(self) -> t.Type[EngineAdapter]: 120 """The engine adapter for this connection""" 121 122 @property 123 @abc.abstractmethod 124 def _connection_factory(self) -> t.Callable: 125 """A function that is called to return a connection object for the given Engine Adapter""" 126 127 @property 128 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 129 """The static connection kwargs for this connection""" 130 return {} 131 132 @property 133 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 134 """kwargs that are for execution config only""" 135 return {} 136 137 @property 138 def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]: 139 """A function that is called to initialize the cursor""" 140 return None 141 142 @property 143 def is_recommended_for_state_sync(self) -> bool: 144 """Whether this engine is recommended for being used as a state sync for production state syncs""" 145 return self.type_ in RECOMMENDED_STATE_SYNC_ENGINES 146 147 @property 148 def is_forbidden_for_state_sync(self) -> bool: 149 """Whether this engine is forbidden from being used as a state sync""" 150 return self.type_ in FORBIDDEN_STATE_SYNC_ENGINES 151 152 @property 153 def _connection_factory_with_kwargs(self) -> t.Callable[[], t.Any]: 154 """A function that is called to return a connection object for the given Engine Adapter""" 155 return partial( 156 self._connection_factory, 157 **{ 158 **self._static_connection_kwargs, 159 **{k: v for k, v in self.dict().items() if k in self._connection_kwargs_keys}, 160 }, 161 ) 162 163 def connection_validator(self) -> t.Callable[[], None]: 164 """A function that validates the connection configuration""" 165 return self.create_engine_adapter().ping 166 167 def create_engine_adapter( 168 self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None 169 ) -> EngineAdapter: 170 """Returns a new instance of the Engine Adapter.""" 171 172 concurrent_tasks = concurrent_tasks or self.concurrent_tasks 173 return self._engine_adapter( 174 self._connection_factory_with_kwargs, 175 multithreaded=concurrent_tasks > 1, 176 default_catalog=self.get_catalog(), 177 cursor_init=self._cursor_init, 178 register_comments=register_comments_override or self.register_comments, 179 pre_ping=self.pre_ping, 180 pretty_sql=self.pretty_sql, 181 shared_connection=self.shared_connection, 182 schema_differ_overrides=self.schema_differ_overrides, 183 catalog_type_overrides=self.catalog_type_overrides, 184 **self._extra_engine_config, 185 ) 186 187 def get_catalog(self) -> t.Optional[str]: 188 """The catalog for this connection""" 189 if hasattr(self, "catalog"): 190 return self.catalog 191 if hasattr(self, "database"): 192 return self.database 193 if hasattr(self, "db"): 194 return self.db 195 return None 196 197 @model_validator(mode="before") 198 @classmethod 199 def _expand_json_strings_to_concrete_types(cls, data: t.Any) -> t.Any: 200 """ 201 There are situations where a connection config class has a field that is some kind of complex type 202 (eg a list of strings or a dict) but the value is being supplied from a source such as an environment variable 203 204 When this happens, the value is supplied as a string rather than a Python object. We need some way 205 of turning this string into the corresponding Python list or dict. 206 207 Rather than doing this piecemeal on every config subclass, this provides a generic implementatation 208 to identify fields that may be be supplied as JSON strings and handle them transparently 209 """ 210 if data and isinstance(data, dict): 211 for maybe_json_field_name in cls._get_list_and_dict_field_names(): 212 if (value := data.get(maybe_json_field_name)) and isinstance(value, str): 213 # crude JSON check as we dont want to try and parse every string we get 214 value = value.strip() 215 if value.startswith("{") or value.startswith("["): 216 data[maybe_json_field_name] = from_json(value) 217 218 return data 219 220 @classmethod 221 def _get_list_and_dict_field_names(cls) -> t.Set[str]: 222 field_names = set() 223 for name, field in cls.model_fields.items(): 224 if field.annotation: 225 field_types = get_concrete_types_from_typehint(field.annotation) 226 227 # check if the field type is something that could concievably be supplied as a json string 228 if any(ft is t for t in (list, tuple, set, dict) for ft in field_types): 229 field_names.add(name) 230 231 return field_names
Helper class that provides a standard way to create an ABC using inheritance.
142 @property 143 def is_recommended_for_state_sync(self) -> bool: 144 """Whether this engine is recommended for being used as a state sync for production state syncs""" 145 return self.type_ in RECOMMENDED_STATE_SYNC_ENGINES
Whether this engine is recommended for being used as a state sync for production state syncs
147 @property 148 def is_forbidden_for_state_sync(self) -> bool: 149 """Whether this engine is forbidden from being used as a state sync""" 150 return self.type_ in FORBIDDEN_STATE_SYNC_ENGINES
Whether this engine is forbidden from being used as a state sync
163 def connection_validator(self) -> t.Callable[[], None]: 164 """A function that validates the connection configuration""" 165 return self.create_engine_adapter().ping
A function that validates the connection configuration
167 def create_engine_adapter( 168 self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None 169 ) -> EngineAdapter: 170 """Returns a new instance of the Engine Adapter.""" 171 172 concurrent_tasks = concurrent_tasks or self.concurrent_tasks 173 return self._engine_adapter( 174 self._connection_factory_with_kwargs, 175 multithreaded=concurrent_tasks > 1, 176 default_catalog=self.get_catalog(), 177 cursor_init=self._cursor_init, 178 register_comments=register_comments_override or self.register_comments, 179 pre_ping=self.pre_ping, 180 pretty_sql=self.pretty_sql, 181 shared_connection=self.shared_connection, 182 schema_differ_overrides=self.schema_differ_overrides, 183 catalog_type_overrides=self.catalog_type_overrides, 184 **self._extra_engine_config, 185 )
Returns a new instance of the Engine Adapter.
187 def get_catalog(self) -> t.Optional[str]: 188 """The catalog for this connection""" 189 if hasattr(self, "catalog"): 190 return self.catalog 191 if hasattr(self, "database"): 192 return self.database 193 if hasattr(self, "db"): 194 return self.db 195 return None
The catalog for this connection
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
234class DuckDBAttachOptions(BaseConfig): 235 type: str 236 path: str 237 read_only: bool = False 238 239 # DuckLake specific options 240 data_path: t.Optional[str] = None 241 encrypted: bool = False 242 data_inlining_row_limit: t.Optional[int] = None 243 metadata_schema: t.Optional[str] = None 244 245 def to_sql(self, alias: str) -> str: 246 options = [] 247 # 'duckdb' is actually not a supported type, but we'd like to allow it for 248 # fully qualified attach options or integration testing, similar to duckdb-dbt 249 if self.type not in ("duckdb", "ducklake", "motherduck"): 250 options.append(f"TYPE {self.type.upper()}") 251 if self.read_only: 252 options.append("READ_ONLY") 253 254 # DuckLake specific options 255 path = self.path 256 if self.type == "ducklake": 257 if not path.startswith("ducklake:"): 258 path = f"ducklake:{path}" 259 if self.data_path is not None: 260 options.append(f"DATA_PATH '{self.data_path}'") 261 if self.encrypted: 262 options.append("ENCRYPTED") 263 if self.data_inlining_row_limit is not None: 264 options.append(f"DATA_INLINING_ROW_LIMIT {self.data_inlining_row_limit}") 265 if self.metadata_schema is not None: 266 options.append(f"METADATA_SCHEMA '{self.metadata_schema}'") 267 268 options_sql = f" ({', '.join(options)})" if options else "" 269 alias_sql = "" 270 # TODO: Add support for Postgres schema. Currently adding it blocks access to the information_schema 271 272 # MotherDuck does not support aliasing 273 alias_sql = ( 274 f" AS {alias}" if not (self.type == "motherduck" or self.path.startswith("md:")) else "" 275 ) 276 return f"ATTACH IF NOT EXISTS '{path}'{alias_sql}{options_sql}"
Base configuration functionality for configuration classes.
245 def to_sql(self, alias: str) -> str: 246 options = [] 247 # 'duckdb' is actually not a supported type, but we'd like to allow it for 248 # fully qualified attach options or integration testing, similar to duckdb-dbt 249 if self.type not in ("duckdb", "ducklake", "motherduck"): 250 options.append(f"TYPE {self.type.upper()}") 251 if self.read_only: 252 options.append("READ_ONLY") 253 254 # DuckLake specific options 255 path = self.path 256 if self.type == "ducklake": 257 if not path.startswith("ducklake:"): 258 path = f"ducklake:{path}" 259 if self.data_path is not None: 260 options.append(f"DATA_PATH '{self.data_path}'") 261 if self.encrypted: 262 options.append("ENCRYPTED") 263 if self.data_inlining_row_limit is not None: 264 options.append(f"DATA_INLINING_ROW_LIMIT {self.data_inlining_row_limit}") 265 if self.metadata_schema is not None: 266 options.append(f"METADATA_SCHEMA '{self.metadata_schema}'") 267 268 options_sql = f" ({', '.join(options)})" if options else "" 269 alias_sql = "" 270 # TODO: Add support for Postgres schema. Currently adding it blocks access to the information_schema 271 272 # MotherDuck does not support aliasing 273 alias_sql = ( 274 f" AS {alias}" if not (self.type == "motherduck" or self.path.startswith("md:")) else "" 275 ) 276 return f"ATTACH IF NOT EXISTS '{path}'{alias_sql}{options_sql}"
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
279class BaseDuckDBConnectionConfig(ConnectionConfig): 280 """Common configuration for the DuckDB-based connections. 281 282 Args: 283 database: The optional database name. If not specified, the in-memory database will be used. 284 catalogs: Key is the name of the catalog and value is the path. 285 extensions: A list of autoloadable extensions to load. 286 connector_config: A dictionary of configuration to pass into the duckdb connector. 287 secrets: A list of dictionaries used to generate DuckDB secrets for authenticating with external services (e.g. S3). 288 filesystems: A list of dictionaries used to register `fsspec` filesystems to the DuckDB cursor. 289 concurrent_tasks: The maximum number of tasks that can use this connection concurrently. 290 register_comments: Whether or not to register model comments with the SQL engine. 291 pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. 292 token: The optional MotherDuck token. If not specified and a MotherDuck path is in the catalog, the user will be prompted to login with their web browser. 293 """ 294 295 database: t.Optional[str] = None 296 catalogs: t.Optional[t.Dict[str, t.Union[str, DuckDBAttachOptions]]] = None 297 extensions: t.List[t.Union[str, t.Dict[str, t.Any]]] = [] 298 connector_config: t.Dict[str, t.Any] = {} 299 secrets: t.Union[t.List[t.Dict[str, t.Any]], t.Dict[str, t.Dict[str, t.Any]]] = [] 300 filesystems: t.List[t.Dict[str, t.Any]] = [] 301 302 concurrent_tasks: int = 1 303 register_comments: bool = True 304 pre_ping: t.Literal[False] = False 305 306 token: t.Optional[str] = None 307 308 shared_connection: t.ClassVar[bool] = True 309 310 _data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {} 311 312 @model_validator(mode="before") 313 def _validate_database_catalogs(cls, data: t.Any) -> t.Any: 314 if not isinstance(data, dict): 315 return data 316 317 db_path = data.get("database") 318 if db_path and data.get("catalogs"): 319 raise ConfigError( 320 "Cannot specify both `database` and `catalogs`. Define all your catalogs in `catalogs` and have the first entry be the default catalog" 321 ) 322 if isinstance(db_path, str) and db_path.startswith("md:"): 323 raise ConfigError( 324 "Please use connection type 'motherduck' without the `md:` prefix if you want to use a MotherDuck database as the single `database`." 325 ) 326 327 return data 328 329 @property 330 def _engine_adapter(self) -> t.Type[EngineAdapter]: 331 return engine_adapter.DuckDBEngineAdapter 332 333 @property 334 def _connection_kwargs_keys(self) -> t.Set[str]: 335 return {"database"} 336 337 @property 338 def _connection_factory(self) -> t.Callable: 339 import duckdb 340 341 return duckdb.connect 342 343 @property 344 def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]: 345 """A function that is called to initialize the cursor""" 346 import duckdb 347 from duckdb import BinderException 348 349 def init(cursor: duckdb.DuckDBPyConnection) -> None: 350 for extension in self.extensions: 351 extension = extension if isinstance(extension, dict) else {"name": extension} 352 353 install_command = f"INSTALL {extension['name']}" 354 355 if extension.get("repository"): 356 install_command = f"{install_command} FROM {extension['repository']}" 357 358 if extension.get("force_install"): 359 install_command = f"FORCE {install_command}" 360 361 try: 362 cursor.execute(install_command) 363 cursor.execute(f"LOAD {extension['name']}") 364 except Exception as e: 365 raise ConfigError(f"Failed to load extension {extension['name']}: {e}") 366 367 if self.connector_config: 368 option_names = list(self.connector_config) 369 in_part = ",".join("?" for _ in range(len(option_names))) 370 371 cursor.execute( 372 f"SELECT name, value FROM duckdb_settings() WHERE name IN ({in_part})", 373 option_names, 374 ) 375 376 existing_values = {field: setting for field, setting in cursor.fetchall()} 377 378 # only set connector_config items if the values differ from what is already set 379 # trying to set options like 'temp_directory' even to the same value can throw errors like: 380 # Not implemented Error: Cannot switch temporary directory after the current one has been used 381 for field, setting in self.connector_config.items(): 382 if existing_values.get(field) != setting: 383 try: 384 cursor.execute(f"SET {field} = '{setting}'") 385 except Exception as e: 386 raise ConfigError( 387 f"Failed to set connector config {field} to {setting}: {e}" 388 ) 389 390 if self.secrets: 391 duckdb_version = duckdb.__version__ 392 if version.parse(duckdb_version) < version.parse("0.10.0"): 393 from sqlmesh.core.console import get_console 394 395 get_console().log_warning( 396 f"DuckDB version {duckdb_version} does not support secrets-based authentication (requires 0.10.0 or later).\n" 397 "To use secrets, please upgrade DuckDB. For older versions, configure legacy authentication via `connector_config`.\n" 398 "More info: https://duckdb.org/docs/stable/extensions/httpfs/s3api_legacy_authentication.html" 399 ) 400 else: 401 if isinstance(self.secrets, list): 402 secrets_items = [(secret_dict, "") for secret_dict in self.secrets] 403 else: 404 secrets_items = [ 405 (secret_dict, secret_name) 406 for secret_name, secret_dict in self.secrets.items() 407 ] 408 409 for secret_dict, secret_name in secrets_items: 410 secret_settings: t.List[str] = [] 411 for field, setting in secret_dict.items(): 412 secret_settings.append(f"{field} '{setting}'") 413 if secret_settings: 414 secret_clause = ", ".join(secret_settings) 415 try: 416 cursor.execute( 417 f"CREATE OR REPLACE SECRET {secret_name} ({secret_clause});" 418 ) 419 except Exception as e: 420 raise ConfigError(f"Failed to create secret: {e}") 421 422 if self.filesystems: 423 from fsspec import filesystem # type: ignore 424 425 for file_system in self.filesystems: 426 options = file_system.copy() 427 fs = options.pop("fs") 428 fs = filesystem(fs, **options) 429 cursor.register_filesystem(fs) 430 431 for i, (alias, path_options) in enumerate( 432 (getattr(self, "catalogs", None) or {}).items() 433 ): 434 # we parse_identifier and generate to ensure that `alias` has exactly one set of quotes 435 # regardless of whether it comes in quoted or not 436 alias = exp.parse_identifier(alias, dialect="duckdb").sql( 437 identify=True, dialect="duckdb" 438 ) 439 try: 440 if isinstance(path_options, DuckDBAttachOptions): 441 query = path_options.to_sql(alias) 442 else: 443 query = f"ATTACH IF NOT EXISTS '{path_options}'" 444 if not path_options.startswith("md:"): 445 query += f" AS {alias}" 446 cursor.execute(query) 447 except BinderException as e: 448 # If a user tries to create a catalog pointing at `:memory:` and with the name `memory` 449 # then we don't want to raise since this happens by default. They are just doing this to 450 # set it as the default catalog. 451 # If a user tried to attach a MotherDuck database/share which has already by attached via 452 # `ATTACH 'md:'`, then we don't want to raise since this is expected. 453 if ( 454 not ( 455 'database with name "memory" already exists' in str(e) 456 and path_options == ":memory:" 457 ) 458 and f"""database with name "{path_options.path.replace("md:", "")}" already exists""" 459 not in str(e) 460 ): 461 raise e 462 if i == 0 and not getattr(self, "database", None): 463 cursor.execute(f"USE {alias}") 464 465 return init 466 467 def create_engine_adapter( 468 self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None 469 ) -> EngineAdapter: 470 """Checks if another engine adapter has already been created that shares a catalog that points to the same data 471 file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration 472 associated with the new adapter will be ignored.""" 473 data_files = set((self.catalogs or {}).values()) 474 if self.database: 475 if isinstance(self, MotherDuckConnectionConfig): 476 data_files.add( 477 f"md:{self.database}" 478 + (f"?motherduck_token={self.token}" if self.token else "") 479 ) 480 else: 481 data_files.add(self.database) 482 data_files.discard(":memory:") 483 for data_file in data_files: 484 key = data_file if isinstance(data_file, str) else data_file.path 485 adapter = BaseDuckDBConnectionConfig._data_file_to_adapter.get(key) 486 if adapter is not None: 487 logger.info( 488 f"Using existing DuckDB adapter due to overlapping data file: {self._mask_sensitive_data(key)}" 489 ) 490 return adapter 491 492 if data_files: 493 masked_files = { 494 self._mask_sensitive_data(file if isinstance(file, str) else file.path) 495 for file in data_files 496 } 497 logger.info(f"Creating new DuckDB adapter for data files: {masked_files}") 498 else: 499 logger.info("Creating new DuckDB adapter for in-memory database") 500 adapter = super().create_engine_adapter( 501 register_comments_override, concurrent_tasks=concurrent_tasks 502 ) 503 for data_file in data_files: 504 key = data_file if isinstance(data_file, str) else data_file.path 505 BaseDuckDBConnectionConfig._data_file_to_adapter[key] = adapter 506 return adapter 507 508 def get_catalog(self) -> t.Optional[str]: 509 if self.database: 510 # Remove `:` from the database name in order to handle if `:memory:` is passed in 511 return pathlib.Path(self.database.replace(":memory:", "memory")).stem 512 if self.catalogs: 513 return list(self.catalogs)[0] 514 return None 515 516 def _mask_sensitive_data(self, string: str) -> str: 517 # Mask MotherDuck tokens with fixed number of asterisks 518 result = MOTHERDUCK_TOKEN_REGEX.sub( 519 lambda m: f"{m.group(1)}{m.group(2)}{'*' * 8 if m.group(3) else ''}", string 520 ) 521 # Mask PostgreSQL/MySQL passwords with fixed number of asterisks 522 result = PASSWORD_REGEX.sub(lambda m: f"{m.group(1)}{'*' * 8}", result) 523 return result
Common configuration for the DuckDB-based connections.
Arguments:
- database: The optional database name. If not specified, the in-memory database will be used.
- catalogs: Key is the name of the catalog and value is the path.
- extensions: A list of autoloadable extensions to load.
- connector_config: A dictionary of configuration to pass into the duckdb connector.
- secrets: A list of dictionaries used to generate DuckDB secrets for authenticating with external services (e.g. S3).
- filesystems: A list of dictionaries used to register
fsspecfilesystems to the DuckDB cursor. - concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
- register_comments: Whether or not to register model comments with the SQL engine.
- pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
- token: The optional MotherDuck token. If not specified and a MotherDuck path is in the catalog, the user will be prompted to login with their web browser.
467 def create_engine_adapter( 468 self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None 469 ) -> EngineAdapter: 470 """Checks if another engine adapter has already been created that shares a catalog that points to the same data 471 file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration 472 associated with the new adapter will be ignored.""" 473 data_files = set((self.catalogs or {}).values()) 474 if self.database: 475 if isinstance(self, MotherDuckConnectionConfig): 476 data_files.add( 477 f"md:{self.database}" 478 + (f"?motherduck_token={self.token}" if self.token else "") 479 ) 480 else: 481 data_files.add(self.database) 482 data_files.discard(":memory:") 483 for data_file in data_files: 484 key = data_file if isinstance(data_file, str) else data_file.path 485 adapter = BaseDuckDBConnectionConfig._data_file_to_adapter.get(key) 486 if adapter is not None: 487 logger.info( 488 f"Using existing DuckDB adapter due to overlapping data file: {self._mask_sensitive_data(key)}" 489 ) 490 return adapter 491 492 if data_files: 493 masked_files = { 494 self._mask_sensitive_data(file if isinstance(file, str) else file.path) 495 for file in data_files 496 } 497 logger.info(f"Creating new DuckDB adapter for data files: {masked_files}") 498 else: 499 logger.info("Creating new DuckDB adapter for in-memory database") 500 adapter = super().create_engine_adapter( 501 register_comments_override, concurrent_tasks=concurrent_tasks 502 ) 503 for data_file in data_files: 504 key = data_file if isinstance(data_file, str) else data_file.path 505 BaseDuckDBConnectionConfig._data_file_to_adapter[key] = adapter 506 return adapter
Checks if another engine adapter has already been created that shares a catalog that points to the same data file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration associated with the new adapter will be ignored.
508 def get_catalog(self) -> t.Optional[str]: 509 if self.database: 510 # Remove `:` from the database name in order to handle if `:memory:` is passed in 511 return pathlib.Path(self.database.replace(":memory:", "memory")).stem 512 if self.catalogs: 513 return list(self.catalogs)[0] 514 return None
The catalog for this connection
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
526class MotherDuckConnectionConfig(BaseDuckDBConnectionConfig): 527 """Configuration for the MotherDuck connection.""" 528 529 type_: t.Literal["motherduck"] = Field(alias="type", default="motherduck") 530 DIALECT: t.ClassVar[t.Literal["duckdb"]] = "duckdb" 531 DISPLAY_NAME: t.ClassVar[t.Literal["MotherDuck"]] = "MotherDuck" 532 DISPLAY_ORDER: t.ClassVar[t.Literal[5]] = 5 533 534 @property 535 def _connection_kwargs_keys(self) -> t.Set[str]: 536 return set() 537 538 @property 539 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 540 """kwargs that are for execution config only""" 541 from sqlmesh import __version__ 542 543 custom_user_agent_config = {"custom_user_agent": f"SQLMesh/{__version__}"} 544 connection_str = "md:" 545 if self.database: 546 # Attach single MD database instead of all databases on the account 547 connection_str += f"{self.database}?attach_mode=single" 548 if self.token: 549 connection_str += f"{'&' if self.database else '?'}motherduck_token={self.token}" 550 return {"database": connection_str, "config": custom_user_agent_config} 551 552 @property 553 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 554 return {"is_motherduck": True}
Configuration for the MotherDuck connection.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
- BaseDuckDBConnectionConfig
- database
- catalogs
- extensions
- connector_config
- secrets
- filesystems
- concurrent_tasks
- register_comments
- pre_ping
- token
- create_engine_adapter
- get_catalog
557class DuckDBConnectionConfig(BaseDuckDBConnectionConfig): 558 """Configuration for the DuckDB connection.""" 559 560 type_: t.Literal["duckdb"] = Field(alias="type", default="duckdb") 561 DIALECT: t.ClassVar[t.Literal["duckdb"]] = "duckdb" 562 DISPLAY_NAME: t.ClassVar[t.Literal["DuckDB"]] = "DuckDB" 563 DISPLAY_ORDER: t.ClassVar[t.Literal[1]] = 1
Configuration for the DuckDB connection.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
- BaseDuckDBConnectionConfig
- database
- catalogs
- extensions
- connector_config
- secrets
- filesystems
- concurrent_tasks
- register_comments
- pre_ping
- token
- create_engine_adapter
- get_catalog
566class SnowflakeConnectionConfig(ConnectionConfig): 567 """Configuration for the Snowflake connection. 568 569 Args: 570 account: The Snowflake account name. 571 user: The Snowflake username. 572 password: The Snowflake password. 573 warehouse: The optional warehouse name. 574 database: The optional database name. 575 role: The optional role name. 576 concurrent_tasks: The maximum number of tasks that can use this connection concurrently. 577 authenticator: The optional authenticator name. Defaults to username/password authentication ("snowflake"). 578 Options: https://github.com/snowflakedb/snowflake-connector-python/blob/e937591356c067a77f34a0a42328907fda792c23/src/snowflake/connector/network.py#L178-L183 579 token: The optional oauth access token to use for authentication when authenticator is set to "oauth". 580 private_key: The optional private key to use for authentication. Key can be Base64-encoded DER format (representing the key bytes), a plain-text PEM format, or bytes (Python config only). https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#using-key-pair-authentication-key-pair-rotation 581 private_key_path: The optional path to the private key to use for authentication. This would be used instead of `private_key`. 582 private_key_passphrase: The optional passphrase to use to decrypt `private_key` or `private_key_path`. Keys can be created without encryption so only provide this if needed. 583 register_comments: Whether or not to register model comments with the SQL engine. 584 pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. 585 session_parameters: The optional session parameters to set for the connection. 586 host: Host address for the connection. 587 port: Port for the connection. 588 """ 589 590 account: str 591 user: t.Optional[str] = None 592 password: t.Optional[str] = None 593 warehouse: t.Optional[str] = None 594 database: t.Optional[str] = None 595 role: t.Optional[str] = None 596 authenticator: t.Optional[str] = None 597 token: t.Optional[str] = None 598 host: t.Optional[str] = None 599 port: t.Optional[int] = None 600 application: t.Literal["Tobiko_SQLMesh"] = "Tobiko_SQLMesh" 601 602 # Private Key Auth 603 private_key: t.Optional[t.Union[str, bytes]] = None 604 private_key_path: t.Optional[str] = None 605 private_key_passphrase: t.Optional[str] = None 606 607 concurrent_tasks: int = 4 608 register_comments: bool = True 609 pre_ping: bool = False 610 611 session_parameters: t.Optional[dict] = None 612 613 type_: t.Literal["snowflake"] = Field(alias="type", default="snowflake") 614 DIALECT: t.ClassVar[t.Literal["snowflake"]] = "snowflake" 615 DISPLAY_NAME: t.ClassVar[t.Literal["Snowflake"]] = "Snowflake" 616 DISPLAY_ORDER: t.ClassVar[t.Literal[2]] = 2 617 618 _concurrent_tasks_validator = concurrent_tasks_validator 619 620 @model_validator(mode="before") 621 def _validate_authenticator(cls, data: t.Any) -> t.Any: 622 if not isinstance(data, dict): 623 return data 624 625 from snowflake.connector.network import DEFAULT_AUTHENTICATOR, OAUTH_AUTHENTICATOR 626 627 auth = data.get("authenticator") 628 auth = auth.upper() if auth else DEFAULT_AUTHENTICATOR 629 user = data.get("user") 630 password = data.get("password") 631 data["private_key"] = cls._get_private_key(data, auth) # type: ignore 632 633 if ( 634 auth == DEFAULT_AUTHENTICATOR 635 and not data.get("private_key") 636 and (not user or not password) 637 ): 638 raise ConfigError("User and password must be provided if using default authentication") 639 640 if auth == OAUTH_AUTHENTICATOR and not data.get("token"): 641 raise ConfigError("Token must be provided if using oauth authentication") 642 643 return data 644 645 _engine_import_validator = _get_engine_import_validator( 646 "snowflake.connector.network", "snowflake" 647 ) 648 649 @classmethod 650 def _get_private_key(cls, values: t.Dict[str, t.Optional[str]], auth: str) -> t.Optional[bytes]: 651 """ 652 source: https://github.com/dbt-labs/dbt-snowflake/blob/0374b4ec948982f2ac8ec0c95d53d672ad19e09c/dbt/adapters/snowflake/connections.py#L247C5-L285C1 653 654 Overall code change: Use local variables instead of class attributes + Validation 655 """ 656 # Start custom code 657 from cryptography.hazmat.backends import default_backend 658 from cryptography.hazmat.primitives import serialization 659 from snowflake.connector.network import ( 660 DEFAULT_AUTHENTICATOR, 661 KEY_PAIR_AUTHENTICATOR, 662 ) 663 664 private_key = values.get("private_key") 665 private_key_path = values.get("private_key_path") 666 private_key_passphrase = values.get("private_key_passphrase") 667 user = values.get("user") 668 password = values.get("password") 669 auth = auth if auth and auth != DEFAULT_AUTHENTICATOR else KEY_PAIR_AUTHENTICATOR 670 671 if not private_key and not private_key_path: 672 return None 673 if private_key and private_key_path: 674 raise ConfigError("Cannot specify both `private_key` and `private_key_path`") 675 if auth != KEY_PAIR_AUTHENTICATOR: 676 raise ConfigError( 677 f"Private key or private key path can only be provided when using {KEY_PAIR_AUTHENTICATOR} authentication" 678 ) 679 if not user: 680 raise ConfigError( 681 f"User must be provided when using {KEY_PAIR_AUTHENTICATOR} authentication" 682 ) 683 if password: 684 raise ConfigError( 685 f"Password cannot be provided when using {KEY_PAIR_AUTHENTICATOR} authentication" 686 ) 687 688 if isinstance(private_key, bytes): 689 return private_key 690 # End Custom Code 691 692 if private_key_passphrase: 693 encoded_passphrase = private_key_passphrase.encode() 694 else: 695 encoded_passphrase = None 696 697 if private_key: 698 if private_key.startswith("-"): 699 p_key = serialization.load_pem_private_key( 700 data=bytes(private_key, "utf-8"), 701 password=encoded_passphrase, 702 backend=default_backend(), 703 ) 704 705 else: 706 p_key = serialization.load_der_private_key( 707 data=base64.b64decode(private_key), 708 password=encoded_passphrase, 709 backend=default_backend(), 710 ) 711 712 elif private_key_path: 713 with open(private_key_path, "rb") as key: 714 p_key = serialization.load_pem_private_key( 715 key.read(), password=encoded_passphrase, backend=default_backend() 716 ) 717 else: 718 return None 719 720 return p_key.private_bytes( 721 encoding=serialization.Encoding.DER, 722 format=serialization.PrivateFormat.PKCS8, 723 encryption_algorithm=serialization.NoEncryption(), 724 ) 725 726 @property 727 def _connection_kwargs_keys(self) -> t.Set[str]: 728 return { 729 "user", 730 "password", 731 "account", 732 "warehouse", 733 "database", 734 "role", 735 "authenticator", 736 "token", 737 "private_key", 738 "session_parameters", 739 "application", 740 "host", 741 "port", 742 } 743 744 @property 745 def _engine_adapter(self) -> t.Type[EngineAdapter]: 746 return engine_adapter.SnowflakeEngineAdapter 747 748 @property 749 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 750 return {"autocommit": False} 751 752 @property 753 def _connection_factory(self) -> t.Callable: 754 from snowflake import connector 755 756 return connector.connect
Configuration for the Snowflake connection.
Arguments:
- account: The Snowflake account name.
- user: The Snowflake username.
- password: The Snowflake password.
- warehouse: The optional warehouse name.
- database: The optional database name.
- role: The optional role name.
- concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
- authenticator: The optional authenticator name. Defaults to username/password authentication ("snowflake"). Options: https://github.com/snowflakedb/snowflake-connector-python/blob/e937591356c067a77f34a0a42328907fda792c23/src/snowflake/connector/network.py#L178-L183
- token: The optional oauth access token to use for authentication when authenticator is set to "oauth".
- private_key: The optional private key to use for authentication. Key can be Base64-encoded DER format (representing the key bytes), a plain-text PEM format, or bytes (Python config only). https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#using-key-pair-authentication-key-pair-rotation
- private_key_path: The optional path to the private key to use for authentication. This would be used instead of
private_key. - private_key_passphrase: The optional passphrase to use to decrypt
private_keyorprivate_key_path. Keys can be created without encryption so only provide this if needed. - register_comments: Whether or not to register model comments with the SQL engine.
- pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
- session_parameters: The optional session parameters to set for the connection.
- host: Host address for the connection.
- port: Port for the connection.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
759class DatabricksConnectionConfig(ConnectionConfig): 760 """ 761 Databricks connection that uses the SQL connector for SQL models and then Databricks Connect for Dataframe operations 762 763 Arg Source: https://github.com/databricks/databricks-sql-python/blob/main/src/databricks/sql/client.py#L39 764 OAuth ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication 765 766 Args: 767 server_hostname: Databricks instance host name. 768 http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef) 769 or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123) 770 access_token: Http Bearer access token, e.g. Databricks Personal Access Token. 771 auth_type: Set to 'databricks-oauth' or 'azure-oauth' to trigger OAuth (or dont set at all to use `access_token`) 772 oauth_client_id: Client ID to use when auth_type is set to one of the 'oauth' types 773 oauth_client_secret: Client Secret to use when auth_type is set to one of the 'oauth' types 774 catalog: Default catalog to use for SQL models. Defaults to None which means it will use the default set in 775 the Databricks cluster (most likely `hive_metastore`). 776 http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request 777 session_configuration: An optional dictionary of Spark session parameters. 778 Execute the SQL command `SET -v` to get a full list of available commands. 779 databricks_connect_server_hostname: The hostname to use when establishing a connecting using Databricks Connect. 780 Defaults to the `server_hostname` value. 781 databricks_connect_access_token: The access token to use when establishing a connecting using Databricks Connect. 782 Defaults to the `access_token` value. 783 databricks_connect_cluster_id: The cluster id to use when establishing a connecting using Databricks Connect. 784 Defaults to deriving the cluster id from the `http_path` value. 785 force_databricks_connect: Force all queries to run using Databricks Connect instead of the SQL connector. 786 disable_databricks_connect: Even if databricks connect is installed, do not use it. 787 disable_spark_session: Do not use SparkSession if it is available (like when running in a notebook). 788 pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. 789 """ 790 791 server_hostname: t.Optional[str] = None 792 http_path: t.Optional[str] = None 793 access_token: t.Optional[str] = None 794 auth_type: t.Optional[str] = None 795 oauth_client_id: t.Optional[str] = None 796 oauth_client_secret: t.Optional[str] = None 797 catalog: t.Optional[str] = None 798 http_headers: t.Optional[t.List[t.Tuple[str, str]]] = None 799 session_configuration: t.Optional[t.Dict[str, t.Any]] = None 800 databricks_connect_server_hostname: t.Optional[str] = None 801 databricks_connect_access_token: t.Optional[str] = None 802 databricks_connect_cluster_id: t.Optional[str] = None 803 databricks_connect_use_serverless: bool = False 804 force_databricks_connect: bool = False 805 disable_databricks_connect: bool = False 806 disable_spark_session: bool = False 807 808 concurrent_tasks: int = 1 809 register_comments: bool = True 810 pre_ping: t.Literal[False] = False 811 812 type_: t.Literal["databricks"] = Field(alias="type", default="databricks") 813 DIALECT: t.ClassVar[t.Literal["databricks"]] = "databricks" 814 DISPLAY_NAME: t.ClassVar[t.Literal["Databricks"]] = "Databricks" 815 DISPLAY_ORDER: t.ClassVar[t.Literal[3]] = 3 816 817 _concurrent_tasks_validator = concurrent_tasks_validator 818 _http_headers_validator = http_headers_validator 819 820 @model_validator(mode="before") 821 def _databricks_connect_validator(cls, data: t.Any) -> t.Any: 822 # SQLQueryContextLogger will output any error SQL queries even if they are in a try/except block. 823 # Disabling this allows SQLMesh to determine what should be shown to the user. 824 # Ex: We describe a table to see if it exists and therefore that execution can fail but we don't need to show 825 # the user since it is expected if the table doesn't exist. Without this change the user would see the error. 826 logging.getLogger("SQLQueryContextLogger").setLevel(logging.CRITICAL) 827 828 if not isinstance(data, dict): 829 return data 830 831 from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter 832 833 if DatabricksEngineAdapter.can_access_spark_session( 834 bool(data.get("disable_spark_session")) 835 ): 836 return data 837 838 databricks_connect_use_serverless = data.get("databricks_connect_use_serverless") 839 server_hostname, http_path, access_token, auth_type = ( 840 data.get("server_hostname"), 841 data.get("http_path"), 842 data.get("access_token"), 843 data.get("auth_type"), 844 ) 845 846 if (not server_hostname or not http_path or not access_token) and ( 847 not databricks_connect_use_serverless and not auth_type 848 ): 849 raise ValueError( 850 "`server_hostname`, `http_path`, and `access_token` are required for Databricks connections when not running in a notebook" 851 ) 852 if ( 853 databricks_connect_use_serverless 854 and not server_hostname 855 and not data.get("databricks_connect_server_hostname") 856 ): 857 raise ValueError( 858 "`server_hostname` or `databricks_connect_server_hostname` is required when `databricks_connect_use_serverless` is set" 859 ) 860 if DatabricksEngineAdapter.can_access_databricks_connect( 861 bool(data.get("disable_databricks_connect")) 862 ): 863 if not data.get("databricks_connect_access_token"): 864 data["databricks_connect_access_token"] = access_token 865 if not data.get("databricks_connect_server_hostname"): 866 data["databricks_connect_server_hostname"] = f"https://{server_hostname}" 867 if not databricks_connect_use_serverless and not data.get( 868 "databricks_connect_cluster_id" 869 ): 870 if t.TYPE_CHECKING: 871 assert http_path is not None 872 data["databricks_connect_cluster_id"] = http_path.split("/")[-1] 873 874 if auth_type: 875 from databricks.sql.auth.auth import AuthType 876 877 all_data = [m.value for m in AuthType] 878 if auth_type not in all_data: 879 raise ValueError( 880 f"`auth_type` {auth_type} does not match a valid option: {all_data}" 881 ) 882 883 client_id = data.get("oauth_client_id") 884 client_secret = data.get("oauth_client_secret") 885 886 if client_secret and not client_id: 887 raise ValueError( 888 "`oauth_client_id` is required when `oauth_client_secret` is specified" 889 ) 890 891 if not http_path: 892 raise ValueError("`http_path` is still required when using `auth_type`") 893 894 return data 895 896 _engine_import_validator = _get_engine_import_validator("databricks", "databricks") 897 898 @property 899 def _connection_kwargs_keys(self) -> t.Set[str]: 900 if self.use_spark_session_only: 901 return set() 902 return { 903 "server_hostname", 904 "http_path", 905 "access_token", 906 "http_headers", 907 "session_configuration", 908 "catalog", 909 } 910 911 @property 912 def _engine_adapter(self) -> t.Type[engine_adapter.DatabricksEngineAdapter]: 913 return engine_adapter.DatabricksEngineAdapter 914 915 @property 916 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 917 return { 918 k: v 919 for k, v in self.dict().items() 920 if k.startswith("databricks_connect_") 921 or k in ("catalog", "disable_databricks_connect", "disable_spark_session") 922 } 923 924 @property 925 def use_spark_session_only(self) -> bool: 926 from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter 927 928 return ( 929 DatabricksEngineAdapter.can_access_spark_session(self.disable_spark_session) 930 or self.force_databricks_connect 931 ) 932 933 @property 934 def _connection_factory(self) -> t.Callable: 935 if self.use_spark_session_only: 936 from sqlmesh.engines.spark.db_api.spark_session import connection 937 938 return connection 939 940 from databricks import sql # type: ignore 941 942 return sql.connect 943 944 @property 945 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 946 from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter 947 948 if not self.use_spark_session_only: 949 conn_kwargs: t.Dict[str, t.Any] = { 950 "_user_agent_entry": "sqlmesh", 951 } 952 953 if self.auth_type and "oauth" in self.auth_type: 954 # there are two types of oauth: User-to-Machine (U2M) and Machine-to-Machine (M2M) 955 if self.oauth_client_secret: 956 # if a client_secret exists, then a client_id also exists and we are using M2M 957 # ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication 958 # ref: https://github.com/databricks/databricks-sql-python/blob/main/examples/m2m_oauth.py 959 from databricks.sdk.core import oauth_service_principal, Config 960 961 config = Config( 962 host=f"https://{self.server_hostname}", 963 client_id=self.oauth_client_id, 964 client_secret=self.oauth_client_secret, 965 ) 966 conn_kwargs["credentials_provider"] = lambda: oauth_service_principal(config) 967 else: 968 # if auth_type is set to an 'oauth' type but no client_id/secret are set, then we are using U2M 969 # ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-user-to-machine-u2m-authentication 970 conn_kwargs["auth_type"] = self.auth_type 971 972 return conn_kwargs 973 974 if DatabricksEngineAdapter.can_access_spark_session(self.disable_spark_session): 975 from pyspark.sql import SparkSession 976 977 return dict( 978 spark=SparkSession.getActiveSession(), 979 catalog=self.catalog, 980 ) 981 982 from databricks.connect import DatabricksSession 983 984 if t.TYPE_CHECKING: 985 assert self.databricks_connect_server_hostname is not None 986 assert self.databricks_connect_access_token is not None 987 988 if self.databricks_connect_use_serverless: 989 builder = DatabricksSession.builder.remote( 990 host=self.databricks_connect_server_hostname, 991 token=self.databricks_connect_access_token, 992 serverless=True, 993 ) 994 else: 995 if t.TYPE_CHECKING: 996 assert self.databricks_connect_cluster_id is not None 997 builder = DatabricksSession.builder.remote( 998 host=self.databricks_connect_server_hostname, 999 token=self.databricks_connect_access_token, 1000 cluster_id=self.databricks_connect_cluster_id, 1001 ) 1002 1003 return dict( 1004 spark=builder.userAgent("sqlmesh").getOrCreate(), 1005 catalog=self.catalog, 1006 )
Databricks connection that uses the SQL connector for SQL models and then Databricks Connect for Dataframe operations
Arg Source: https://github.com/databricks/databricks-sql-python/blob/main/src/databricks/sql/client.py#L39 OAuth ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication
Arguments:
- server_hostname: Databricks instance host name.
- http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef) or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
- access_token: Http Bearer access token, e.g. Databricks Personal Access Token.
- auth_type: Set to 'databricks-oauth' or 'azure-oauth' to trigger OAuth (or dont set at all to use
access_token) - oauth_client_id: Client ID to use when auth_type is set to one of the 'oauth' types
- oauth_client_secret: Client Secret to use when auth_type is set to one of the 'oauth' types
- catalog: Default catalog to use for SQL models. Defaults to None which means it will use the default set in
the Databricks cluster (most likely
hive_metastore). - http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request
- session_configuration: An optional dictionary of Spark session parameters.
Execute the SQL command
SET -vto get a full list of available commands. - databricks_connect_server_hostname: The hostname to use when establishing a connecting using Databricks Connect.
Defaults to the
server_hostnamevalue. - databricks_connect_access_token: The access token to use when establishing a connecting using Databricks Connect.
Defaults to the
access_tokenvalue. - databricks_connect_cluster_id: The cluster id to use when establishing a connecting using Databricks Connect.
Defaults to deriving the cluster id from the
http_pathvalue. - force_databricks_connect: Force all queries to run using Databricks Connect instead of the SQL connector.
- disable_databricks_connect: Even if databricks connect is installed, do not use it.
- disable_spark_session: Do not use SparkSession if it is available (like when running in a notebook).
- pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
1009class BigQueryConnectionMethod(str, Enum): 1010 OAUTH = "oauth" 1011 OAUTH_SECRETS = "oauth-secrets" 1012 SERVICE_ACCOUNT = "service-account" 1013 SERVICE_ACCOUNT_JSON = "service-account-json"
An enumeration.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
1016class BigQueryPriority(str, Enum): 1017 BATCH = "batch" 1018 INTERACTIVE = "interactive" 1019 1020 @property 1021 def is_batch(self) -> bool: 1022 return self == self.BATCH 1023 1024 @property 1025 def is_interactive(self) -> bool: 1026 return self == self.INTERACTIVE 1027 1028 @property 1029 def bigquery_constant(self) -> str: 1030 from google.cloud.bigquery import QueryPriority 1031 1032 if self.is_batch: 1033 return QueryPriority.BATCH 1034 return QueryPriority.INTERACTIVE
An enumeration.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
1037class BigQueryConnectionConfig(ConnectionConfig): 1038 """ 1039 BigQuery Connection Configuration. 1040 """ 1041 1042 method: BigQueryConnectionMethod = BigQueryConnectionMethod.OAUTH 1043 1044 project: t.Optional[str] = None 1045 execution_project: t.Optional[str] = None 1046 quota_project: t.Optional[str] = None 1047 location: t.Optional[str] = None 1048 # Keyfile Auth 1049 keyfile: t.Optional[str] = None 1050 keyfile_json: t.Optional[t.Dict[str, t.Any]] = None 1051 # Oath Secret Auth 1052 token: t.Optional[str] = None 1053 refresh_token: t.Optional[str] = None 1054 client_id: t.Optional[str] = None 1055 client_secret: t.Optional[str] = None 1056 token_uri: t.Optional[str] = None 1057 scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/bigquery",) 1058 impersonated_service_account: t.Optional[str] = None 1059 # Extra Engine Config 1060 job_creation_timeout_seconds: t.Optional[int] = None 1061 job_execution_timeout_seconds: t.Optional[int] = None 1062 job_retries: t.Optional[int] = 1 1063 job_retry_deadline_seconds: t.Optional[int] = None 1064 priority: t.Optional[BigQueryPriority] = None 1065 maximum_bytes_billed: t.Optional[int] = None 1066 reservation: t.Optional[str] = None 1067 1068 concurrent_tasks: int = 1 1069 register_comments: bool = True 1070 pre_ping: t.Literal[False] = False 1071 1072 type_: t.Literal["bigquery"] = Field(alias="type", default="bigquery") 1073 DIALECT: t.ClassVar[t.Literal["bigquery"]] = "bigquery" 1074 DISPLAY_NAME: t.ClassVar[t.Literal["BigQuery"]] = "BigQuery" 1075 DISPLAY_ORDER: t.ClassVar[t.Literal[4]] = 4 1076 1077 _engine_import_validator = _get_engine_import_validator("google.cloud.bigquery", "bigquery") 1078 1079 @field_validator("execution_project") 1080 def validate_execution_project( 1081 cls, 1082 v: t.Optional[str], 1083 info: ValidationInfo, 1084 ) -> t.Optional[str]: 1085 if v and not info.data.get("project"): 1086 raise ConfigError( 1087 "If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location." 1088 ) 1089 return v 1090 1091 @field_validator("quota_project") 1092 def validate_quota_project( 1093 cls, 1094 v: t.Optional[str], 1095 info: ValidationInfo, 1096 ) -> t.Optional[str]: 1097 if v and not info.data.get("project"): 1098 raise ConfigError( 1099 "If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location." 1100 ) 1101 return v 1102 1103 @property 1104 def _connection_kwargs_keys(self) -> t.Set[str]: 1105 return set() 1106 1107 @property 1108 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1109 return engine_adapter.BigQueryEngineAdapter 1110 1111 @property 1112 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 1113 """The static connection kwargs for this connection""" 1114 import google.auth 1115 from google.auth import impersonated_credentials 1116 from google.api_core import client_info, client_options 1117 from google.oauth2 import credentials, service_account 1118 1119 if self.method == BigQueryConnectionMethod.OAUTH: 1120 creds, _ = google.auth.default(scopes=self.scopes) 1121 elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT: 1122 creds = service_account.Credentials.from_service_account_file( 1123 self.keyfile, scopes=self.scopes 1124 ) 1125 elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT_JSON: 1126 creds = service_account.Credentials.from_service_account_info( 1127 self.keyfile_json, scopes=self.scopes 1128 ) 1129 elif self.method == BigQueryConnectionMethod.OAUTH_SECRETS: 1130 creds = credentials.Credentials( 1131 token=self.token, 1132 refresh_token=self.refresh_token, 1133 client_id=self.client_id, 1134 client_secret=self.client_secret, 1135 token_uri=self.token_uri, 1136 scopes=self.scopes, 1137 ) 1138 else: 1139 raise ConfigError("Invalid BigQuery Connection Method") 1140 1141 if self.impersonated_service_account: 1142 creds = impersonated_credentials.Credentials( 1143 source_credentials=creds, 1144 target_principal=self.impersonated_service_account, 1145 target_scopes=self.scopes, 1146 ) 1147 1148 options = client_options.ClientOptions(quota_project_id=self.quota_project) 1149 project = self.execution_project or self.project or None 1150 1151 client = google.cloud.bigquery.Client( 1152 project=project and exp.parse_identifier(project, dialect="bigquery").name, 1153 credentials=creds, 1154 location=self.location, 1155 client_info=client_info.ClientInfo(user_agent="sqlmesh"), 1156 client_options=options, 1157 ) 1158 1159 return { 1160 "client": client, 1161 } 1162 1163 @property 1164 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 1165 return { 1166 k: v 1167 for k, v in self.dict().items() 1168 if k 1169 in { 1170 "job_creation_timeout_seconds", 1171 "job_execution_timeout_seconds", 1172 "job_retries", 1173 "job_retry_deadline_seconds", 1174 "priority", 1175 "maximum_bytes_billed", 1176 "reservation", 1177 } 1178 } 1179 1180 @property 1181 def _connection_factory(self) -> t.Callable: 1182 from google.cloud.bigquery.dbapi import connect 1183 1184 return connect 1185 1186 def get_catalog(self) -> t.Optional[str]: 1187 return self.project
BigQuery Connection Configuration.
1079 @field_validator("execution_project") 1080 def validate_execution_project( 1081 cls, 1082 v: t.Optional[str], 1083 info: ValidationInfo, 1084 ) -> t.Optional[str]: 1085 if v and not info.data.get("project"): 1086 raise ConfigError( 1087 "If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location." 1088 ) 1089 return v
1091 @field_validator("quota_project") 1092 def validate_quota_project( 1093 cls, 1094 v: t.Optional[str], 1095 info: ValidationInfo, 1096 ) -> t.Optional[str]: 1097 if v and not info.data.get("project"): 1098 raise ConfigError( 1099 "If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location." 1100 ) 1101 return v
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
1190class GCPPostgresConnectionConfig(ConnectionConfig): 1191 """ 1192 Postgres Connection Configuration for GCP. 1193 1194 Args: 1195 instance_connection_string: Connection name for the postgres instance. 1196 user: Postgres or IAM user's name 1197 password: The postgres user's password. Only needed when the user is a postgres user. 1198 enable_iam_auth: Set to True when user is an IAM user. 1199 db: Name of the db to connect to. 1200 keyfile: string path to json service account credentials file 1201 keyfile_json: dict service account credentials info 1202 pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. 1203 """ 1204 1205 instance_connection_string: str 1206 user: str 1207 password: t.Optional[str] = None 1208 enable_iam_auth: t.Optional[bool] = None 1209 db: str 1210 ip_type: t.Union[t.Literal["public"], t.Literal["private"], t.Literal["psc"]] = "public" 1211 # Keyfile Auth 1212 keyfile: t.Optional[str] = None 1213 keyfile_json: t.Optional[t.Dict[str, t.Any]] = None 1214 timeout: t.Optional[int] = None 1215 scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/sqlservice.admin",) 1216 driver: str = "pg8000" 1217 1218 type_: t.Literal["gcp_postgres"] = Field(alias="type", default="gcp_postgres") 1219 DIALECT: t.ClassVar[t.Literal["postgres"]] = "postgres" 1220 DISPLAY_NAME: t.ClassVar[t.Literal["GCP Postgres"]] = "GCP Postgres" 1221 DISPLAY_ORDER: t.ClassVar[t.Literal[13]] = 13 1222 1223 concurrent_tasks: int = 4 1224 register_comments: bool = True 1225 pre_ping: bool = True 1226 1227 _engine_import_validator = _get_engine_import_validator( 1228 "google.cloud.sql", "gcp_postgres", "gcppostgres" 1229 ) 1230 1231 @model_validator(mode="before") 1232 def _validate_auth_method(cls, data: t.Any) -> t.Any: 1233 if not isinstance(data, dict): 1234 return data 1235 1236 password = data.get("password") 1237 enable_iam_auth = data.get("enable_iam_auth") 1238 1239 if not password and not enable_iam_auth: 1240 raise ConfigError( 1241 "GCP Postgres connection configuration requires either password set" 1242 " for a postgres user account or enable_iam_auth set to 'True'" 1243 " for an IAM user account." 1244 ) 1245 1246 return data 1247 1248 @property 1249 def _connection_kwargs_keys(self) -> t.Set[str]: 1250 return { 1251 "instance_connection_string", 1252 "driver", 1253 "user", 1254 "password", 1255 "db", 1256 "enable_iam_auth", 1257 "timeout", 1258 } 1259 1260 @property 1261 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1262 return engine_adapter.PostgresEngineAdapter 1263 1264 @property 1265 def _connection_factory(self) -> t.Callable: 1266 from google.cloud.sql.connector import Connector 1267 from google.oauth2 import service_account 1268 1269 creds = None 1270 if self.keyfile: 1271 creds = service_account.Credentials.from_service_account_file( 1272 self.keyfile, scopes=self.scopes 1273 ) 1274 elif self.keyfile_json: 1275 creds = service_account.Credentials.from_service_account_info( 1276 self.keyfile_json, scopes=self.scopes 1277 ) 1278 1279 kwargs = { 1280 "credentials": creds, 1281 "ip_type": self.ip_type, 1282 } 1283 1284 if self.timeout: 1285 kwargs["timeout"] = self.timeout 1286 1287 return Connector(**kwargs).connect # type: ignore
Postgres Connection Configuration for GCP.
Arguments:
- instance_connection_string: Connection name for the postgres instance.
- user: Postgres or IAM user's name
- password: The postgres user's password. Only needed when the user is a postgres user.
- enable_iam_auth: Set to True when user is an IAM user.
- db: Name of the db to connect to.
- keyfile: string path to json service account credentials file
- keyfile_json: dict service account credentials info
- pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
1290class RedshiftConnectionConfig(ConnectionConfig): 1291 """ 1292 Redshift Connection Configuration. 1293 1294 Arg Source: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/__init__.py#L146 1295 Note: A subset of properties were selected. Please open an issue/PR if you want to see more supported. 1296 1297 Args: 1298 user: The username to use for authentication with the Amazon Redshift cluster. 1299 password: The password to use for authentication with the Amazon Redshift cluster. 1300 database: The name of the database instance to connect to. 1301 host: The hostname of the Amazon Redshift cluster. 1302 port: The port number of the Amazon Redshift cluster. Default value is 5439. 1303 source_address: No description provided 1304 unix_sock: No description provided 1305 ssl: Is SSL enabled. Default value is ``True``. SSL must be enabled when authenticating using IAM. 1306 sslmode: The security of the connection to the Amazon Redshift cluster. 'verify-ca' and 'verify-full' are supported. 1307 timeout: The number of seconds before the connection to the server will timeout. By default there is no timeout. 1308 tcp_keepalive: Is `TCP keepalive <https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive>`_ used. The default value is ``True``. 1309 application_name: Sets the application name. The default value is None. 1310 preferred_role: The IAM role preferred for the current connection. 1311 principal_arn: The ARN of the IAM entity (user or role) for which you are generating a policy. 1312 credentials_provider: The class name of the IdP that will be used for authenticating with the Amazon Redshift cluster. 1313 region: The AWS region where the Amazon Redshift cluster is located. 1314 cluster_identifier: The cluster identifier of the Amazon Redshift cluster. 1315 iam: If IAM authentication is enabled. Default value is False. IAM must be True when authenticating using an IdP. 1316 is_serverless: Redshift end-point is serverless or provisional. Default value false. 1317 serverless_acct_id: The account ID of the serverless. Default value None 1318 serverless_work_group: The name of work group for serverless end point. Default value None. 1319 pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. 1320 enable_merge: Whether to use the Redshift merge operation instead of the SQLMesh logical merge. 1321 """ 1322 1323 user: t.Optional[str] = None 1324 password: t.Optional[str] = None 1325 database: t.Optional[str] = None 1326 host: t.Optional[str] = None 1327 port: t.Optional[int] = None 1328 source_address: t.Optional[str] = None 1329 unix_sock: t.Optional[str] = None 1330 ssl: t.Optional[bool] = None 1331 sslmode: t.Optional[str] = None 1332 timeout: t.Optional[int] = None 1333 tcp_keepalive: t.Optional[bool] = None 1334 application_name: t.Optional[str] = None 1335 preferred_role: t.Optional[str] = None 1336 principal_arn: t.Optional[str] = None 1337 credentials_provider: t.Optional[str] = None 1338 region: t.Optional[str] = None 1339 cluster_identifier: t.Optional[str] = None 1340 iam: t.Optional[bool] = None 1341 is_serverless: t.Optional[bool] = None 1342 serverless_acct_id: t.Optional[str] = None 1343 serverless_work_group: t.Optional[str] = None 1344 enable_merge: t.Optional[bool] = None 1345 1346 concurrent_tasks: int = 4 1347 register_comments: bool = True 1348 pre_ping: bool = False 1349 1350 type_: t.Literal["redshift"] = Field(alias="type", default="redshift") 1351 DIALECT: t.ClassVar[t.Literal["redshift"]] = "redshift" 1352 DISPLAY_NAME: t.ClassVar[t.Literal["Redshift"]] = "Redshift" 1353 DISPLAY_ORDER: t.ClassVar[t.Literal[7]] = 7 1354 1355 _engine_import_validator = _get_engine_import_validator("redshift_connector", "redshift") 1356 1357 @property 1358 def _connection_kwargs_keys(self) -> t.Set[str]: 1359 return { 1360 "user", 1361 "password", 1362 "database", 1363 "host", 1364 "port", 1365 "source_address", 1366 "unix_sock", 1367 "ssl", 1368 "sslmode", 1369 "timeout", 1370 "tcp_keepalive", 1371 "application_name", 1372 "preferred_role", 1373 "principal_arn", 1374 "credentials_provider", 1375 "region", 1376 "cluster_identifier", 1377 "iam", 1378 "is_serverless", 1379 "serverless_acct_id", 1380 "serverless_work_group", 1381 } 1382 1383 @property 1384 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1385 return engine_adapter.RedshiftEngineAdapter 1386 1387 @property 1388 def _connection_factory(self) -> t.Callable: 1389 from redshift_connector import connect 1390 1391 return connect 1392 1393 @property 1394 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 1395 return {"enable_merge": self.enable_merge}
Redshift Connection Configuration.
Arg Source: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/__init__.py#L146 Note: A subset of properties were selected. Please open an issue/PR if you want to see more supported.
Arguments:
- user: The username to use for authentication with the Amazon Redshift cluster.
- password: The password to use for authentication with the Amazon Redshift cluster.
- database: The name of the database instance to connect to.
- host: The hostname of the Amazon Redshift cluster.
- port: The port number of the Amazon Redshift cluster. Default value is 5439.
- source_address: No description provided
- unix_sock: No description provided
- ssl: Is SSL enabled. Default value is
True. SSL must be enabled when authenticating using IAM. - sslmode: The security of the connection to the Amazon Redshift cluster. 'verify-ca' and 'verify-full' are supported.
- timeout: The number of seconds before the connection to the server will timeout. By default there is no timeout.
- tcp_keepalive: Is TCP keepalive used. The default value is
True. - application_name: Sets the application name. The default value is None.
- preferred_role: The IAM role preferred for the current connection.
- principal_arn: The ARN of the IAM entity (user or role) for which you are generating a policy.
- credentials_provider: The class name of the IdP that will be used for authenticating with the Amazon Redshift cluster.
- region: The AWS region where the Amazon Redshift cluster is located.
- cluster_identifier: The cluster identifier of the Amazon Redshift cluster.
- iam: If IAM authentication is enabled. Default value is False. IAM must be True when authenticating using an IdP.
- is_serverless: Redshift end-point is serverless or provisional. Default value false.
- serverless_acct_id: The account ID of the serverless. Default value None
- serverless_work_group: The name of work group for serverless end point. Default value None.
- pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
- enable_merge: Whether to use the Redshift merge operation instead of the SQLMesh logical merge.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
1398class PostgresConnectionConfig(ConnectionConfig): 1399 host: str 1400 user: str 1401 password: str 1402 port: int 1403 database: str 1404 keepalives_idle: t.Optional[int] = None 1405 connect_timeout: int = 10 1406 role: t.Optional[str] = None 1407 sslmode: t.Optional[str] = None 1408 application_name: t.Optional[str] = None 1409 1410 concurrent_tasks: int = 4 1411 register_comments: bool = True 1412 pre_ping: bool = True 1413 1414 type_: t.Literal["postgres"] = Field(alias="type", default="postgres") 1415 DIALECT: t.ClassVar[t.Literal["postgres"]] = "postgres" 1416 DISPLAY_NAME: t.ClassVar[t.Literal["Postgres"]] = "Postgres" 1417 DISPLAY_ORDER: t.ClassVar[t.Literal[12]] = 12 1418 1419 _engine_import_validator = _get_engine_import_validator("psycopg2", "postgres") 1420 1421 @property 1422 def _connection_kwargs_keys(self) -> t.Set[str]: 1423 return { 1424 "host", 1425 "user", 1426 "password", 1427 "port", 1428 "database", 1429 "keepalives_idle", 1430 "connect_timeout", 1431 "sslmode", 1432 "application_name", 1433 } 1434 1435 @property 1436 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1437 return engine_adapter.PostgresEngineAdapter 1438 1439 @property 1440 def _connection_factory(self) -> t.Callable: 1441 from psycopg2 import connect 1442 1443 return connect 1444 1445 @property 1446 def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]: 1447 if not self.role: 1448 return None 1449 1450 def init(cursor: t.Any) -> None: 1451 cursor.execute(f"SET ROLE {self.role}") 1452 1453 return init
Helper class that provides a standard way to create an ABC using inheritance.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
1456class MySQLConnectionConfig(ConnectionConfig): 1457 host: str 1458 user: str 1459 password: str 1460 port: t.Optional[int] = None 1461 database: t.Optional[str] = None 1462 charset: t.Optional[str] = None 1463 collation: t.Optional[str] = None 1464 ssl_disabled: t.Optional[bool] = None 1465 1466 concurrent_tasks: int = 4 1467 register_comments: bool = True 1468 pre_ping: bool = True 1469 1470 type_: t.Literal["mysql"] = Field(alias="type", default="mysql") 1471 DIALECT: t.ClassVar[t.Literal["mysql"]] = "mysql" 1472 DISPLAY_NAME: t.ClassVar[t.Literal["MySQL"]] = "MySQL" 1473 DISPLAY_ORDER: t.ClassVar[t.Literal[14]] = 14 1474 1475 _engine_import_validator = _get_engine_import_validator("pymysql", "mysql") 1476 1477 @property 1478 def _connection_kwargs_keys(self) -> t.Set[str]: 1479 connection_keys = { 1480 "host", 1481 "user", 1482 "password", 1483 } 1484 if self.port is not None: 1485 connection_keys.add("port") 1486 if self.database is not None: 1487 connection_keys.add("database") 1488 if self.charset is not None: 1489 connection_keys.add("charset") 1490 if self.collation is not None: 1491 connection_keys.add("collation") 1492 if self.ssl_disabled is not None: 1493 connection_keys.add("ssl_disabled") 1494 return connection_keys 1495 1496 @property 1497 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1498 return engine_adapter.MySQLEngineAdapter 1499 1500 @property 1501 def _connection_factory(self) -> t.Callable: 1502 from pymysql import connect 1503 1504 return connect
Helper class that provides a standard way to create an ABC using inheritance.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
1507class MSSQLConnectionConfig(ConnectionConfig): 1508 host: str 1509 user: t.Optional[str] = None 1510 password: t.Optional[str] = None 1511 database: t.Optional[str] = "" 1512 timeout: t.Optional[int] = 0 1513 login_timeout: t.Optional[int] = 60 1514 charset: t.Optional[str] = "UTF-8" 1515 appname: t.Optional[str] = None 1516 port: t.Optional[int] = 1433 1517 conn_properties: t.Optional[t.Union[t.List[str], str]] = None 1518 autocommit: t.Optional[bool] = False 1519 tds_version: t.Optional[str] = None 1520 1521 # Driver options 1522 driver: t.Literal["pymssql", "pyodbc"] = "pymssql" 1523 # PyODBC specific options 1524 driver_name: t.Optional[str] = None # e.g. "ODBC Driver 18 for SQL Server" 1525 trust_server_certificate: t.Optional[bool] = None 1526 encrypt: t.Optional[bool] = None 1527 # Dictionary of arbitrary ODBC connection properties 1528 # See: https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute 1529 odbc_properties: t.Optional[t.Dict[str, t.Any]] = None 1530 1531 concurrent_tasks: int = 4 1532 register_comments: bool = True 1533 pre_ping: bool = True 1534 1535 type_: t.Literal["mssql"] = Field(alias="type", default="mssql") 1536 DIALECT: t.ClassVar[t.Literal["tsql"]] = "tsql" 1537 DISPLAY_NAME: t.ClassVar[t.Literal["MSSQL"]] = "MSSQL" 1538 DISPLAY_ORDER: t.ClassVar[t.Literal[11]] = 11 1539 1540 @model_validator(mode="before") 1541 @classmethod 1542 def _mssql_engine_import_validator(cls, data: t.Any) -> t.Any: 1543 if not isinstance(data, dict): 1544 return data 1545 1546 driver = data.get("driver", "pymssql") 1547 1548 # Define the mapping of driver to import module and extra name 1549 driver_configs = {"pymssql": ("pymssql", "mssql"), "pyodbc": ("pyodbc", "mssql-odbc")} 1550 1551 if driver not in driver_configs: 1552 raise ValueError(f"Unsupported driver: {driver}") 1553 1554 import_module, extra_name = driver_configs[driver] 1555 1556 # Use _get_engine_import_validator with decorate=False to get the raw validation function 1557 # This avoids the __wrapped__ issue in Python 3.9 1558 validator_func = _get_engine_import_validator( 1559 import_module, driver, extra_name, decorate=False 1560 ) 1561 1562 # Call the raw validation function directly 1563 return validator_func(cls, data) 1564 1565 @property 1566 def _connection_kwargs_keys(self) -> t.Set[str]: 1567 base_keys = { 1568 "host", 1569 "user", 1570 "password", 1571 "database", 1572 "timeout", 1573 "login_timeout", 1574 "charset", 1575 "appname", 1576 "port", 1577 "conn_properties", 1578 "autocommit", 1579 "tds_version", 1580 } 1581 1582 if self.driver == "pyodbc": 1583 base_keys.update( 1584 { 1585 "driver_name", 1586 "trust_server_certificate", 1587 "encrypt", 1588 "odbc_properties", 1589 } 1590 ) 1591 # Remove pymssql-specific parameters 1592 base_keys.discard("tds_version") 1593 base_keys.discard("conn_properties") 1594 1595 return base_keys 1596 1597 @property 1598 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1599 return engine_adapter.MSSQLEngineAdapter 1600 1601 @property 1602 def _connection_factory(self) -> t.Callable: 1603 if self.driver == "pymssql": 1604 import pymssql 1605 1606 return pymssql.connect 1607 1608 import pyodbc 1609 1610 def connect(**kwargs: t.Any) -> t.Callable: 1611 # Extract parameters for connection string 1612 host = kwargs.pop("host") 1613 port = kwargs.pop("port", 1433) 1614 database = kwargs.pop("database", "") 1615 user = kwargs.pop("user", None) 1616 password = kwargs.pop("password", None) 1617 driver_name = kwargs.pop("driver_name", "ODBC Driver 18 for SQL Server") 1618 trust_server_certificate = kwargs.pop("trust_server_certificate", False) 1619 encrypt = kwargs.pop("encrypt", True) 1620 login_timeout = kwargs.pop("login_timeout", 60) 1621 1622 # Build connection string 1623 conn_str_parts = [ 1624 f"DRIVER={{{driver_name}}}", 1625 f"SERVER={host},{port}", 1626 ] 1627 1628 if database: 1629 conn_str_parts.append(f"DATABASE={database}") 1630 1631 # Add security options 1632 conn_str_parts.append(f"Encrypt={'YES' if encrypt else 'NO'}") 1633 if trust_server_certificate: 1634 conn_str_parts.append("TrustServerCertificate=YES") 1635 1636 conn_str_parts.append(f"Connection Timeout={login_timeout}") 1637 1638 # Standard SQL Server authentication 1639 if user: 1640 conn_str_parts.append(f"UID={user}") 1641 if password: 1642 conn_str_parts.append(f"PWD={password}") 1643 1644 # Add any additional ODBC properties from the odbc_properties dictionary 1645 if self.odbc_properties: 1646 for key, value in self.odbc_properties.items(): 1647 # Skip properties that we've already set above 1648 if key.lower() in ( 1649 "driver", 1650 "server", 1651 "database", 1652 "uid", 1653 "pwd", 1654 "encrypt", 1655 "trustservercertificate", 1656 "connection timeout", 1657 ): 1658 continue 1659 1660 # Handle boolean values properly 1661 if isinstance(value, bool): 1662 conn_str_parts.append(f"{key}={'YES' if value else 'NO'}") 1663 else: 1664 conn_str_parts.append(f"{key}={value}") 1665 1666 # Create the connection string 1667 conn_str = ";".join(conn_str_parts) 1668 1669 conn = pyodbc.connect(conn_str, autocommit=kwargs.get("autocommit", False)) 1670 1671 # Set up output converters for MSSQL-specific data types 1672 # Handle SQL type -155 (DATETIMEOFFSET) which is not yet supported by pyodbc 1673 # ref: https://github.com/mkleehammer/pyodbc/issues/134#issuecomment-281739794 1674 def handle_datetimeoffset(dto_value: t.Any) -> t.Any: 1675 from datetime import datetime, timedelta, timezone 1676 import struct 1677 1678 # Unpack the DATETIMEOFFSET binary format: 1679 # Format: <6hI2h = (year, month, day, hour, minute, second, nanoseconds, tz_hour_offset, tz_minute_offset) 1680 tup = struct.unpack("<6hI2h", dto_value) 1681 return datetime( 1682 tup[0], 1683 tup[1], 1684 tup[2], 1685 tup[3], 1686 tup[4], 1687 tup[5], 1688 tup[6] // 1000, 1689 timezone(timedelta(hours=tup[7], minutes=tup[8])), 1690 ) 1691 1692 conn.add_output_converter(-155, handle_datetimeoffset) 1693 1694 return conn 1695 1696 return connect 1697 1698 @property 1699 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 1700 return {"catalog_support": CatalogSupport.REQUIRES_SET_CATALOG}
Helper class that provides a standard way to create an ABC using inheritance.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
1703class AzureSQLConnectionConfig(MSSQLConnectionConfig): 1704 type_: t.Literal["azuresql"] = Field(alias="type", default="azuresql") # type: ignore 1705 DISPLAY_NAME: t.ClassVar[t.Literal["Azure SQL"]] = "Azure SQL" # type: ignore 1706 DISPLAY_ORDER: t.ClassVar[t.Literal[10]] = 10 # type: ignore 1707 1708 @property 1709 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 1710 return {"catalog_support": CatalogSupport.SINGLE_CATALOG_ONLY}
Helper class that provides a standard way to create an ABC using inheritance.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
- MSSQLConnectionConfig
- host
- user
- password
- database
- timeout
- login_timeout
- charset
- appname
- port
- conn_properties
- autocommit
- tds_version
- driver
- driver_name
- trust_server_certificate
- encrypt
- odbc_properties
- concurrent_tasks
- register_comments
- pre_ping
- DIALECT
1713class FabricConnectionConfig(MSSQLConnectionConfig): 1714 """ 1715 Fabric Connection Configuration. 1716 Inherits most settings from MSSQLConnectionConfig and sets the type to 'fabric'. 1717 It is recommended to use the 'pyodbc' driver for Fabric. 1718 """ 1719 1720 type_: t.Literal["fabric"] = Field(alias="type", default="fabric") # type: ignore 1721 DIALECT: t.ClassVar[t.Literal["fabric"]] = "fabric" # type: ignore 1722 DISPLAY_NAME: t.ClassVar[t.Literal["Fabric"]] = "Fabric" # type: ignore 1723 DISPLAY_ORDER: t.ClassVar[t.Literal[17]] = 17 # type: ignore 1724 driver: t.Literal["pyodbc"] = "pyodbc" 1725 workspace_id: str 1726 tenant_id: str 1727 autocommit: t.Optional[bool] = True 1728 1729 @property 1730 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1731 from sqlmesh.core.engine_adapter.fabric import FabricEngineAdapter 1732 1733 return FabricEngineAdapter 1734 1735 @property 1736 def _connection_factory(self) -> t.Callable: 1737 # Override to support catalog switching for Fabric 1738 base_factory = super()._connection_factory 1739 1740 def create_fabric_connection( 1741 target_catalog: t.Optional[str] = None, *args: t.Any, **kwargs: t.Any 1742 ) -> t.Callable: 1743 kwargs["database"] = target_catalog or self.database 1744 return base_factory(*args, **kwargs) 1745 1746 return create_fabric_connection 1747 1748 @property 1749 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 1750 return { 1751 "database": self.database, 1752 # more operations than not require a specific catalog to be already active 1753 # in particular, create/drop view, create/drop schema and querying information_schema 1754 "catalog_support": CatalogSupport.REQUIRES_SET_CATALOG, 1755 "workspace_id": self.workspace_id, 1756 "tenant_id": self.tenant_id, 1757 "user": self.user, 1758 "password": self.password, 1759 }
Fabric Connection Configuration. Inherits most settings from MSSQLConnectionConfig and sets the type to 'fabric'. It is recommended to use the 'pyodbc' driver for Fabric.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
- MSSQLConnectionConfig
- host
- user
- password
- database
- timeout
- login_timeout
- charset
- appname
- port
- conn_properties
- tds_version
- driver_name
- trust_server_certificate
- encrypt
- odbc_properties
- concurrent_tasks
- register_comments
- pre_ping
1762class SparkConnectionConfig(ConnectionConfig): 1763 """ 1764 Vanilla Spark Connection Configuration. Use `DatabricksConnectionConfig` for Databricks. 1765 """ 1766 1767 config_dir: t.Optional[str] = None 1768 catalog: t.Optional[str] = None 1769 config: t.Dict[str, t.Any] = {} 1770 wap_enabled: bool = False 1771 1772 concurrent_tasks: int = 4 1773 register_comments: bool = True 1774 pre_ping: t.Literal[False] = False 1775 1776 type_: t.Literal["spark"] = Field(alias="type", default="spark") 1777 DIALECT: t.ClassVar[t.Literal["spark"]] = "spark" 1778 DISPLAY_NAME: t.ClassVar[t.Literal["Spark"]] = "Spark" 1779 DISPLAY_ORDER: t.ClassVar[t.Literal[8]] = 8 1780 1781 _engine_import_validator = _get_engine_import_validator("pyspark", "spark") 1782 1783 @property 1784 def _connection_kwargs_keys(self) -> t.Set[str]: 1785 return { 1786 "catalog", 1787 } 1788 1789 @property 1790 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1791 return engine_adapter.SparkEngineAdapter 1792 1793 @property 1794 def _connection_factory(self) -> t.Callable: 1795 from sqlmesh.engines.spark.db_api.spark_session import connection 1796 1797 return connection 1798 1799 @property 1800 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 1801 from pyspark.conf import SparkConf 1802 from pyspark.sql import SparkSession 1803 1804 spark_config = SparkConf() 1805 if self.config: 1806 for k, v in self.config.items(): 1807 spark_config.set(k, v) 1808 1809 if self.config_dir: 1810 os.environ["SPARK_CONF_DIR"] = self.config_dir 1811 return { 1812 "spark": SparkSession.builder.config(conf=spark_config) 1813 .enableHiveSupport() 1814 .getOrCreate(), 1815 } 1816 1817 @property 1818 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 1819 return {"wap_enabled": self.wap_enabled}
Vanilla Spark Connection Configuration. Use DatabricksConnectionConfig for Databricks.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
1822class TrinoAuthenticationMethod(str, Enum): 1823 NO_AUTH = "no-auth" 1824 BASIC = "basic" 1825 LDAP = "ldap" 1826 KERBEROS = "kerberos" 1827 JWT = "jwt" 1828 CERTIFICATE = "certificate" 1829 OAUTH = "oauth" 1830 1831 @property 1832 def is_no_auth(self) -> bool: 1833 return self == self.NO_AUTH 1834 1835 @property 1836 def is_basic(self) -> bool: 1837 return self == self.BASIC 1838 1839 @property 1840 def is_ldap(self) -> bool: 1841 return self == self.LDAP 1842 1843 @property 1844 def is_kerberos(self) -> bool: 1845 return self == self.KERBEROS 1846 1847 @property 1848 def is_jwt(self) -> bool: 1849 return self == self.JWT 1850 1851 @property 1852 def is_certificate(self) -> bool: 1853 return self == self.CERTIFICATE 1854 1855 @property 1856 def is_oauth(self) -> bool: 1857 return self == self.OAUTH
An enumeration.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
1860class TrinoConnectionConfig(ConnectionConfig): 1861 method: TrinoAuthenticationMethod = TrinoAuthenticationMethod.NO_AUTH 1862 host: str 1863 user: str 1864 catalog: str 1865 port: t.Optional[int] = None 1866 http_scheme: t.Literal["http", "https"] = "https" 1867 # General Optional 1868 roles: t.Optional[t.Dict[str, str]] = None 1869 http_headers: t.Optional[t.Dict[str, str]] = None 1870 session_properties: t.Optional[t.Dict[str, str]] = None 1871 retries: int = 3 1872 timezone: t.Optional[str] = None 1873 # Basic/LDAP 1874 password: t.Optional[str] = None 1875 verify: t.Optional[bool] = None # disable SSL verification (ignored if `cert` is provided) 1876 # LDAP 1877 impersonation_user: t.Optional[str] = None 1878 # Kerberos 1879 keytab: t.Optional[str] = None 1880 krb5_config: t.Optional[str] = None 1881 principal: t.Optional[str] = None 1882 service_name: str = "trino" 1883 hostname_override: t.Optional[str] = None 1884 mutual_authentication: bool = False 1885 force_preemptive: bool = False 1886 sanitize_mutual_error_response: bool = True 1887 delegate: bool = False 1888 # JWT 1889 jwt_token: t.Optional[str] = None 1890 # Certificate 1891 client_certificate: t.Optional[str] = None 1892 client_private_key: t.Optional[str] = None 1893 cert: t.Optional[str] = None 1894 source: str = "sqlmesh" 1895 1896 # SQLMesh options 1897 schema_location_mapping: t.Optional[dict[re.Pattern, str]] = None 1898 timestamp_mapping: t.Optional[dict[exp.DataType, exp.DataType]] = None 1899 concurrent_tasks: int = 4 1900 register_comments: bool = True 1901 pre_ping: t.Literal[False] = False 1902 1903 type_: t.Literal["trino"] = Field(alias="type", default="trino") 1904 DIALECT: t.ClassVar[t.Literal["trino"]] = "trino" 1905 DISPLAY_NAME: t.ClassVar[t.Literal["Trino"]] = "Trino" 1906 DISPLAY_ORDER: t.ClassVar[t.Literal[9]] = 9 1907 1908 _engine_import_validator = _get_engine_import_validator("trino", "trino") 1909 1910 @field_validator("schema_location_mapping", mode="before") 1911 @classmethod 1912 def _validate_regex_keys( 1913 cls, value: t.Dict[str | re.Pattern, str] 1914 ) -> t.Dict[re.Pattern, t.Any]: 1915 compiled = compile_regex_mapping(value) 1916 for replacement in compiled.values(): 1917 if "@{schema_name}" not in replacement: 1918 raise ConfigError( 1919 "schema_location_mapping needs to include the '@{schema_name}' placeholder in the value so SQLMesh knows where to substitute the schema name" 1920 ) 1921 return compiled 1922 1923 @field_validator("timestamp_mapping", mode="before") 1924 @classmethod 1925 def _validate_timestamp_mapping( 1926 cls, value: t.Optional[dict[str, str]] 1927 ) -> t.Optional[dict[exp.DataType, exp.DataType]]: 1928 if value is None: 1929 return value 1930 1931 result: dict[exp.DataType, exp.DataType] = {} 1932 for source_type, target_type in value.items(): 1933 try: 1934 source_datatype = exp.DataType.build(source_type) 1935 except ParseError: 1936 raise ConfigError( 1937 f"Invalid SQL type string in timestamp_mapping: " 1938 f"'{source_type}' is not a valid SQL data type." 1939 ) 1940 try: 1941 target_datatype = exp.DataType.build(target_type) 1942 except ParseError: 1943 raise ConfigError( 1944 f"Invalid SQL type string in timestamp_mapping: " 1945 f"'{target_type}' is not a valid SQL data type." 1946 ) 1947 result[source_datatype] = target_datatype 1948 1949 return result 1950 1951 @model_validator(mode="after") 1952 def _root_validator(self) -> Self: 1953 port = self.port 1954 if self.http_scheme == "http" and not self.method.is_no_auth and not self.method.is_basic: 1955 raise ConfigError("HTTP scheme can only be used with no-auth or basic method") 1956 1957 if port is None: 1958 self.port = 80 if self.http_scheme == "http" else 443 1959 1960 if (self.method.is_ldap or self.method.is_basic) and (not self.password or not self.user): 1961 raise ConfigError( 1962 f"Username and Password must be provided if using {self.method.value} authentication" 1963 ) 1964 1965 if self.method.is_kerberos and ( 1966 not self.principal or not self.keytab or not self.krb5_config 1967 ): 1968 raise ConfigError( 1969 "Kerberos requires the following fields: principal, keytab, and krb5_config" 1970 ) 1971 1972 if self.method.is_jwt and not self.jwt_token: 1973 raise ConfigError("JWT requires `jwt_token` to be set") 1974 1975 if self.method.is_certificate and ( 1976 not self.cert or not self.client_certificate or not self.client_private_key 1977 ): 1978 raise ConfigError( 1979 "Certificate requires the following fields: cert, client_certificate, and client_private_key" 1980 ) 1981 1982 return self 1983 1984 @property 1985 def _connection_kwargs_keys(self) -> t.Set[str]: 1986 kwargs = { 1987 "host", 1988 "port", 1989 "catalog", 1990 "roles", 1991 "source", 1992 "http_scheme", 1993 "http_headers", 1994 "session_properties", 1995 "timezone", 1996 } 1997 return kwargs 1998 1999 @property 2000 def _engine_adapter(self) -> t.Type[EngineAdapter]: 2001 return engine_adapter.TrinoEngineAdapter 2002 2003 @property 2004 def _connection_factory(self) -> t.Callable: 2005 from trino.dbapi import connect 2006 2007 return connect 2008 2009 @property 2010 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 2011 from trino.auth import ( 2012 BasicAuthentication, 2013 CertificateAuthentication, 2014 JWTAuthentication, 2015 KerberosAuthentication, 2016 OAuth2Authentication, 2017 ) 2018 2019 auth: t.Optional[ 2020 t.Union[ 2021 BasicAuthentication, 2022 KerberosAuthentication, 2023 OAuth2Authentication, 2024 JWTAuthentication, 2025 CertificateAuthentication, 2026 ] 2027 ] = None 2028 if self.method.is_basic or self.method.is_ldap: 2029 assert self.password is not None # for mypy since validator already checks this 2030 auth = BasicAuthentication(self.user, self.password) 2031 elif self.method.is_kerberos: 2032 if self.keytab: 2033 os.environ["KRB5_CLIENT_KTNAME"] = self.keytab 2034 auth = KerberosAuthentication( 2035 config=self.krb5_config, 2036 service_name=self.service_name, 2037 principal=self.principal, 2038 mutual_authentication=self.mutual_authentication, 2039 ca_bundle=self.cert, 2040 force_preemptive=self.force_preemptive, 2041 hostname_override=self.hostname_override, 2042 sanitize_mutual_error_response=self.sanitize_mutual_error_response, 2043 delegate=self.delegate, 2044 ) 2045 elif self.method.is_oauth: 2046 auth = OAuth2Authentication() 2047 elif self.method.is_jwt: 2048 assert self.jwt_token is not None 2049 auth = JWTAuthentication(self.jwt_token) 2050 elif self.method.is_certificate: 2051 assert self.client_certificate is not None 2052 assert self.client_private_key is not None 2053 auth = CertificateAuthentication(self.client_certificate, self.client_private_key) 2054 2055 return { 2056 "auth": auth, 2057 "user": self.impersonation_user or self.user, 2058 "max_attempts": self.retries, 2059 "verify": self.cert if self.cert is not None else self.verify, 2060 "source": self.source, 2061 } 2062 2063 @property 2064 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 2065 return { 2066 "schema_location_mapping": self.schema_location_mapping, 2067 "timestamp_mapping": self.timestamp_mapping, 2068 }
Helper class that provides a standard way to create an ABC using inheritance.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
2071class ClickhouseConnectionConfig(ConnectionConfig): 2072 """ 2073 Clickhouse Connection Configuration. 2074 2075 Property reference: https://clickhouse.com/docs/en/integrations/python#client-initialization 2076 """ 2077 2078 host: str 2079 username: str 2080 password: t.Optional[str] = None 2081 port: t.Optional[int] = None 2082 cluster: t.Optional[str] = None 2083 connect_timeout: int = 10 2084 send_receive_timeout: int = 300 2085 query_limit: int = 0 2086 use_compression: bool = True 2087 compression_method: t.Optional[str] = None 2088 connection_settings: t.Optional[t.Dict[str, t.Any]] = None 2089 http_proxy: t.Optional[str] = None 2090 # HTTPS/TLS settings 2091 verify: bool = True 2092 ca_cert: t.Optional[str] = None 2093 client_cert: t.Optional[str] = None 2094 client_cert_key: t.Optional[str] = None 2095 https_proxy: t.Optional[str] = None 2096 server_host_name: t.Optional[str] = None 2097 tls_mode: t.Optional[str] = None 2098 2099 concurrent_tasks: int = 1 2100 register_comments: bool = True 2101 pre_ping: bool = False 2102 2103 # This object expects options from urllib3 and also from clickhouse-connect 2104 # See: 2105 # * https://urllib3.readthedocs.io/en/stable/advanced-usage.html 2106 # * https://clickhouse.com/docs/en/integrations/python#customizing-the-http-connection-pool 2107 connection_pool_options: t.Optional[t.Dict[str, t.Any]] = None 2108 2109 type_: t.Literal["clickhouse"] = Field(alias="type", default="clickhouse") 2110 DIALECT: t.ClassVar[t.Literal["clickhouse"]] = "clickhouse" 2111 DISPLAY_NAME: t.ClassVar[t.Literal["ClickHouse"]] = "ClickHouse" 2112 DISPLAY_ORDER: t.ClassVar[t.Literal[6]] = 6 2113 2114 _engine_import_validator = _get_engine_import_validator("clickhouse_connect", "clickhouse") 2115 2116 @property 2117 def _connection_kwargs_keys(self) -> t.Set[str]: 2118 kwargs = { 2119 "host", 2120 "username", 2121 "port", 2122 "password", 2123 "connect_timeout", 2124 "send_receive_timeout", 2125 "query_limit", 2126 "http_proxy", 2127 "verify", 2128 "ca_cert", 2129 "client_cert", 2130 "client_cert_key", 2131 "https_proxy", 2132 "server_host_name", 2133 "tls_mode", 2134 } 2135 return kwargs 2136 2137 @property 2138 def _engine_adapter(self) -> t.Type[EngineAdapter]: 2139 return engine_adapter.ClickhouseEngineAdapter 2140 2141 @property 2142 def _connection_factory(self) -> t.Callable: 2143 from clickhouse_connect.dbapi import connect # type: ignore 2144 from clickhouse_connect.driver import httputil # type: ignore 2145 from functools import partial 2146 2147 pool_manager_options: t.Dict[str, t.Any] = dict( 2148 # Match the maxsize to the number of concurrent tasks 2149 maxsize=self.concurrent_tasks, 2150 # Block if there are no free connections 2151 block=True, 2152 verify=self.verify, 2153 ca_cert=self.ca_cert, 2154 client_cert=self.client_cert, 2155 client_cert_key=self.client_cert_key, 2156 https_proxy=self.https_proxy, 2157 ) 2158 # this doesn't happen automatically because we always supply our own pool manager to the connection 2159 # https://github.com/ClickHouse/clickhouse-connect/blob/3a7f4b04cad29c7c2536661b831fb744248e2ec0/clickhouse_connect/driver/httpclient.py#L109 2160 if self.server_host_name: 2161 pool_manager_options["server_hostname"] = self.server_host_name 2162 if self.verify: 2163 pool_manager_options["assert_hostname"] = self.server_host_name 2164 if self.connection_pool_options: 2165 pool_manager_options.update(self.connection_pool_options) 2166 pool_mgr = httputil.get_pool_manager(**pool_manager_options) 2167 2168 return partial(connect, pool_mgr=pool_mgr) 2169 2170 @property 2171 def cloud_mode(self) -> bool: 2172 return "clickhouse.cloud" in self.host 2173 2174 @property 2175 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 2176 return {"cluster": self.cluster, "cloud_mode": self.cloud_mode} 2177 2178 @property 2179 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 2180 from sqlmesh import __version__ 2181 2182 # False = no compression 2183 # True = Clickhouse default compression method 2184 # string = specific compression method 2185 compress: bool | str = self.use_compression 2186 if compress and self.compression_method: 2187 compress = self.compression_method 2188 2189 # Clickhouse system settings passed to connection 2190 # https://clickhouse.com/docs/en/operations/settings/settings 2191 # - below are set to align with dbt-clickhouse 2192 # - https://github.com/ClickHouse/dbt-clickhouse/blob/44d26308ea6a3c8ead25c280164aa88191f05f47/dbt/adapters/clickhouse/dbclient.py#L77 2193 settings = self.connection_settings or {} 2194 # mutations_sync = 2: "The query waits for all mutations [ALTER statements] to complete on all replicas (if they exist)" 2195 settings["mutations_sync"] = "2" 2196 # insert_distributed_sync = 1: "INSERT operation succeeds only after all the data is saved on all shards" 2197 settings["insert_distributed_sync"] = "1" 2198 if self.cluster or self.cloud_mode: 2199 # database_replicated_enforce_synchronous_settings = 1: 2200 # - "Enforces synchronous waiting for some queries" 2201 # - https://github.com/ClickHouse/ClickHouse/blob/ccaa8d03a9351efc16625340268b9caffa8a22ba/src/Core/Settings.h#L709 2202 settings["database_replicated_enforce_synchronous_settings"] = "1" 2203 # insert_quorum = auto: 2204 # - "INSERT succeeds only when ClickHouse manages to correctly write data to the insert_quorum of replicas during 2205 # the insert_quorum_timeout" 2206 # - "use majority number (number_of_replicas / 2 + 1) as quorum number" 2207 settings["insert_quorum"] = "auto" 2208 2209 return { 2210 "compress": compress, 2211 "client_name": f"SQLMesh/{__version__}", 2212 **settings, 2213 }
Clickhouse Connection Configuration.
Property reference: https://clickhouse.com/docs/en/integrations/python#client-initialization
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
2216class AthenaConnectionConfig(ConnectionConfig): 2217 # PyAthena connection options 2218 aws_access_key_id: t.Optional[str] = None 2219 aws_secret_access_key: t.Optional[str] = None 2220 role_arn: t.Optional[str] = None 2221 role_session_name: t.Optional[str] = None 2222 region_name: t.Optional[str] = None 2223 work_group: t.Optional[str] = None 2224 s3_staging_dir: t.Optional[str] = None 2225 schema_name: t.Optional[str] = None 2226 catalog_name: t.Optional[str] = None 2227 2228 # SQLMesh options 2229 s3_warehouse_location: t.Optional[str] = None 2230 concurrent_tasks: int = 4 2231 register_comments: t.Literal[False] = ( 2232 False # because Athena doesnt support comments in most cases 2233 ) 2234 pre_ping: t.Literal[False] = False 2235 2236 type_: t.Literal["athena"] = Field(alias="type", default="athena") 2237 DIALECT: t.ClassVar[t.Literal["athena"]] = "athena" 2238 DISPLAY_NAME: t.ClassVar[t.Literal["Athena"]] = "Athena" 2239 DISPLAY_ORDER: t.ClassVar[t.Literal[15]] = 15 2240 2241 _engine_import_validator = _get_engine_import_validator("pyathena", "athena") 2242 2243 @model_validator(mode="after") 2244 def _root_validator(self) -> Self: 2245 work_group = self.work_group 2246 s3_staging_dir = self.s3_staging_dir 2247 s3_warehouse_location = self.s3_warehouse_location 2248 2249 if not work_group and not s3_staging_dir: 2250 raise ConfigError("At least one of work_group or s3_staging_dir must be set") 2251 2252 if s3_staging_dir: 2253 self.s3_staging_dir = validate_s3_uri(s3_staging_dir, base=True, error_type=ConfigError) 2254 2255 if s3_warehouse_location: 2256 self.s3_warehouse_location = validate_s3_uri( 2257 s3_warehouse_location, base=True, error_type=ConfigError 2258 ) 2259 2260 return self 2261 2262 @property 2263 def _connection_kwargs_keys(self) -> t.Set[str]: 2264 return { 2265 "aws_access_key_id", 2266 "aws_secret_access_key", 2267 "role_arn", 2268 "role_session_name", 2269 "region_name", 2270 "work_group", 2271 "s3_staging_dir", 2272 "schema_name", 2273 "catalog_name", 2274 } 2275 2276 @property 2277 def _engine_adapter(self) -> t.Type[EngineAdapter]: 2278 return engine_adapter.AthenaEngineAdapter 2279 2280 @property 2281 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 2282 return {"s3_warehouse_location": self.s3_warehouse_location} 2283 2284 @property 2285 def _connection_factory(self) -> t.Callable: 2286 from pyathena import connect # type: ignore 2287 2288 return connect 2289 2290 def get_catalog(self) -> t.Optional[str]: 2291 return self.catalog_name
Helper class that provides a standard way to create an ABC using inheritance.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
2294class RisingwaveConnectionConfig(ConnectionConfig): 2295 host: str 2296 user: str 2297 password: t.Optional[str] = None 2298 port: int 2299 database: str 2300 role: t.Optional[str] = None 2301 sslmode: t.Optional[str] = None 2302 2303 concurrent_tasks: int = 4 2304 register_comments: bool = True 2305 pre_ping: bool = True 2306 2307 type_: t.Literal["risingwave"] = Field(alias="type", default="risingwave") 2308 DIALECT: t.ClassVar[t.Literal["risingwave"]] = "risingwave" 2309 DISPLAY_NAME: t.ClassVar[t.Literal["RisingWave"]] = "RisingWave" 2310 DISPLAY_ORDER: t.ClassVar[t.Literal[16]] = 16 2311 2312 _engine_import_validator = _get_engine_import_validator("psycopg2", "risingwave") 2313 2314 @property 2315 def _connection_kwargs_keys(self) -> t.Set[str]: 2316 return { 2317 "host", 2318 "user", 2319 "password", 2320 "port", 2321 "database", 2322 "role", 2323 "sslmode", 2324 } 2325 2326 @property 2327 def _engine_adapter(self) -> t.Type[EngineAdapter]: 2328 return engine_adapter.RisingwaveEngineAdapter 2329 2330 @property 2331 def _connection_factory(self) -> t.Callable: 2332 from psycopg2 import connect 2333 2334 return connect 2335 2336 @property 2337 def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]: 2338 def init(cursor: t.Any) -> None: 2339 sql = "SET RW_IMPLICIT_FLUSH TO true;" 2340 cursor.execute(sql) 2341 2342 return init
Helper class that provides a standard way to create an ABC using inheritance.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
2377def parse_connection_config(v: t.Dict[str, t.Any]) -> ConnectionConfig: 2378 if "type" not in v: 2379 raise ConfigError("Missing connection type.") 2380 2381 connection_type = v["type"] 2382 if connection_type not in CONNECTION_CONFIG_TO_TYPE: 2383 raise ConfigError(f"Unknown connection type '{connection_type}'.") 2384 2385 return CONNECTION_CONFIG_TO_TYPE[connection_type](**v)
2388def _connection_config_validator( 2389 cls: t.Type, v: ConnectionConfig | t.Dict[str, t.Any] | None 2390) -> ConnectionConfig | None: 2391 if v is None or isinstance(v, ConnectionConfig): 2392 return v 2393 2394 check_config_and_vars_msg = "\n\nVerify your config.yaml and environment variables." 2395 2396 try: 2397 return parse_connection_config(v) 2398 except pydantic.ValidationError as e: 2399 raise ConfigError( 2400 validation_error_message(e, f"Invalid '{v['type']}' connection config:") 2401 + check_config_and_vars_msg 2402 ) 2403 except ConfigError as e: 2404 raise ConfigError(str(e) + check_config_and_vars_msg)
Wrap a classmethod, staticmethod, property or unbound function and act as a descriptor that allows us to detect decorated items from the class' attributes.
This class' __get__ returns the wrapped item's __get__ result, which makes it transparent for classmethods and staticmethods.
Attributes:
- wrapped: The decorator that has to be wrapped.
- decorator_info: The decorator info.
- shim: A wrapper function to wrap V1 style function.