sqlmesh.core.config.connection
from __future__ import annotations

import abc
import base64
import logging
import os
import importlib
import pathlib
import re
import typing as t
from enum import Enum
from functools import partial

import pydantic
from pydantic import Field
from pydantic_core import from_json
from packaging import version
from sqlglot import exp
from sqlglot.helper import subclasses
from sqlglot.errors import ParseError

from sqlmesh.core import engine_adapter
from sqlmesh.core.config.base import BaseConfig
from sqlmesh.core.config.common import (
    concurrent_tasks_validator,
    http_headers_validator,
    compile_regex_mapping,
)
from sqlmesh.core.engine_adapter.shared import CatalogSupport
from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.utils import debug_mode_enabled, str_to_bool
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.pydantic import (
    ValidationInfo,
    field_validator,
    model_validator,
    validation_error_message,
    get_concrete_types_from_typehint,
)
from sqlmesh.utils.aws import validate_s3_uri

if t.TYPE_CHECKING:
    from sqlmesh.core._typing import Self

logger = logging.getLogger(__name__)

RECOMMENDED_STATE_SYNC_ENGINES = {
    "postgres",
    "gcp_postgres",
    "mysql",
    "mssql",
    "azuresql",
}
FORBIDDEN_STATE_SYNC_ENGINES = {
    # Do not support row-level operations
    "spark",
    "trino",
    # Nullable types are problematic
    "clickhouse",
}
MOTHERDUCK_TOKEN_REGEX = re.compile(r"(\?|\&)(motherduck_token=)(\S*)")
PASSWORD_REGEX = re.compile(r"(password=)(\S+)")


def _get_engine_import_validator(
    import_name: str, engine_type: str, extra_name: t.Optional[str] = None, decorate: bool = True
) -> t.Callable:
    """Build a validator that checks that an engine's client library can be imported.

    Args:
        import_name: The module to import, e.g. "snowflake.connector.network".
        engine_type: The engine name used in the error message.
        extra_name: The `sqlmesh[...]` extra suggested in the error message. Defaults to `engine_type`.
        decorate: If True, wrap the check with `model_validator(mode="before")`. Otherwise return
            the plain function.

    Returns:
        The validator, decorated or plain. When the validated input is a dict, the validator pops
        the `check_import` key from it. A falsy value for that key skips the import check.

    Raises:
        ConfigError: Raised by the validator when the import fails and debug mode is off.
            In debug mode the original ImportError is re-raised instead.
    """
    extra_name = extra_name or engine_type

    def validate(cls: t.Any, data: t.Any) -> t.Any:
        check_import = (
            str_to_bool(str(data.pop("check_import", True))) if isinstance(data, dict) else True
        )
        if not check_import:
            return data
        try:
            importlib.import_module(import_name)
        except ImportError:
            if debug_mode_enabled():
                raise

            logger.exception("Failed to import the engine library")

            raise ConfigError(
                f"Failed to import the '{engine_type}' engine library. This may be due to a missing "
                "or incompatible installation. Please ensure the required dependency is installed by "
                f'running: `pip install "sqlmesh[{extra_name}]"`. For more details, check the logs '
                "in the 'logs/' folder, or rerun the command with the '--debug' flag."
            )

        return data

    return model_validator(mode="before")(validate) if decorate else validate


class ConnectionConfig(abc.ABC, BaseConfig):
    """Base class for engine connection configurations.

    Subclasses provide three things:
    - the engine adapter class;
    - the connection factory;
    - the set of config fields forwarded to that factory as keyword arguments.
    """

    type_: str
    DIALECT: t.ClassVar[str]
    DISPLAY_NAME: t.ClassVar[str]
    DISPLAY_ORDER: t.ClassVar[int]
    concurrent_tasks: int
    register_comments: bool
    pre_ping: bool
    pretty_sql: bool = False
    schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None
    catalog_type_overrides: t.Optional[t.Dict[str, str]] = None

    # Whether to share a single connection across threads or create a new connection per thread.
    shared_connection: t.ClassVar[bool] = False

    @property
    @abc.abstractmethod
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """keywords that should be passed into the connection"""

    @property
    @abc.abstractmethod
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """The engine adapter for this connection"""

    @property
    @abc.abstractmethod
    def _connection_factory(self) -> t.Callable:
        """A function that is called to return a connection object for the given Engine Adapter"""

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """The static connection kwargs for this connection"""
        return {}

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        """kwargs that are for execution config only"""
        return {}

    @property
    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
        """A function that is called to initialize the cursor"""
        return None

    @property
    def is_recommended_for_state_sync(self) -> bool:
        """Whether this engine is recommended for being used as a state sync for production state syncs"""
        return self.type_ in RECOMMENDED_STATE_SYNC_ENGINES

    @property
    def is_forbidden_for_state_sync(self) -> bool:
        """Whether this engine is forbidden from being used as a state sync"""
        return self.type_ in FORBIDDEN_STATE_SYNC_ENGINES

    @property
    def _connection_factory_with_kwargs(self) -> t.Callable[[], t.Any]:
        """A function that is called to return a connection object for the given Engine Adapter"""
        # Field values listed in `_connection_kwargs_keys` override same-named static kwargs.
        return partial(
            self._connection_factory,
            **{
                **self._static_connection_kwargs,
                **{k: v for k, v in self.dict().items() if k in self._connection_kwargs_keys},
            },
        )

    def connection_validator(self) -> t.Callable[[], None]:
        """A function that validates the connection configuration"""
        return self.create_engine_adapter().ping

    def create_engine_adapter(
        self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None
    ) -> EngineAdapter:
        """Returns a new instance of the Engine Adapter.

        Args:
            register_comments_override: If True, forces comment registration on.
            concurrent_tasks: Overrides `self.concurrent_tasks`. A falsy value falls back to the
                configured value. More than one task makes the adapter multithreaded.
        """

        concurrent_tasks = concurrent_tasks or self.concurrent_tasks
        return self._engine_adapter(
            self._connection_factory_with_kwargs,
            multithreaded=concurrent_tasks > 1,
            default_catalog=self.get_catalog(),
            cursor_init=self._cursor_init,
            register_comments=register_comments_override or self.register_comments,
            pre_ping=self.pre_ping,
            pretty_sql=self.pretty_sql,
            shared_connection=self.shared_connection,
            schema_differ_overrides=self.schema_differ_overrides,
            catalog_type_overrides=self.catalog_type_overrides,
            **self._extra_engine_config,
        )

    def get_catalog(self) -> t.Optional[str]:
        """The catalog for this connection.

        Returns the first attribute that exists among `catalog`, `database` and `db`, or None.
        """
        if hasattr(self, "catalog"):
            return self.catalog
        if hasattr(self, "database"):
            return self.database
        if hasattr(self, "db"):
            return self.db
        return None

    @model_validator(mode="before")
    @classmethod
    def _expand_json_strings_to_concrete_types(cls, data: t.Any) -> t.Any:
        """
        There are situations where a connection config class has a field that is some kind of complex type
        (eg a list of strings or a dict) but the value is being supplied from a source such as an environment variable

        When this happens, the value is supplied as a string rather than a Python object. We need some way
        of turning this string into the corresponding Python list or dict.

        Rather than doing this piecemeal on every config subclass, this provides a generic implementation
        to identify fields that may be supplied as JSON strings and handle them transparently.

        Note: the input dict is updated in place.
        """
        if data and isinstance(data, dict):
            for maybe_json_field_name in cls._get_list_and_dict_field_names():
                if (value := data.get(maybe_json_field_name)) and isinstance(value, str):
                    # crude JSON check as we don't want to try and parse every string we get
                    value = value.strip()
                    if value.startswith(("{", "[")):
                        data[maybe_json_field_name] = from_json(value)

        return data

    @classmethod
    def _get_list_and_dict_field_names(cls) -> t.Set[str]:
        """Return the names of fields whose annotation resolves to a list, tuple, set or dict type."""
        field_names = set()
        for name, field in cls.model_fields.items():
            if field.annotation:
                field_types = get_concrete_types_from_typehint(field.annotation)

                # check if the field type is something that could conceivably be supplied as a json string
                # (the loop variable is deliberately not named `t`, which would shadow the typing alias)
                if any(
                    ft is container
                    for container in (list, tuple, set, dict)
                    for ft in field_types
                ):
                    field_names.add(name)

        return field_names


class DuckDBAttachOptions(BaseConfig):
    """Options used to build a DuckDB `ATTACH` statement for a single catalog."""

    type: str
    path: str
    read_only: bool = False

    # DuckLake specific options
    data_path: t.Optional[str] = None
    encrypted: bool = False
    data_inlining_row_limit: t.Optional[int] = None
    metadata_schema: t.Optional[str] = None

    def to_sql(self, alias: str) -> str:
        """Render the `ATTACH IF NOT EXISTS` statement for this catalog.

        Args:
            alias: The catalog alias. It is omitted for MotherDuck, which does not support aliasing.

        Returns:
            The ATTACH SQL statement.
        """
        options = []
        # 'duckdb' is actually not a supported type, but we'd like to allow it for
        # fully qualified attach options or integration testing, similar to duckdb-dbt
        if self.type not in ("duckdb", "ducklake", "motherduck"):
            options.append(f"TYPE {self.type.upper()}")
        if self.read_only:
            options.append("READ_ONLY")

        # DuckLake specific options
        path = self.path
        if self.type == "ducklake":
            if not path.startswith("ducklake:"):
                path = f"ducklake:{path}"
            if self.data_path is not None:
                options.append(f"DATA_PATH '{self.data_path}'")
            if self.encrypted:
                options.append("ENCRYPTED")
            if self.data_inlining_row_limit is not None:
                options.append(f"DATA_INLINING_ROW_LIMIT {self.data_inlining_row_limit}")
            if self.metadata_schema is not None:
                options.append(f"METADATA_SCHEMA '{self.metadata_schema}'")

        options_sql = f" ({', '.join(options)})" if options else ""
        # TODO: Add support for Postgres schema. Currently adding it blocks access to the information_schema

        # MotherDuck does not support aliasing
        alias_sql = (
            f" AS {alias}" if not (self.type == "motherduck" or self.path.startswith("md:")) else ""
        )
        return f"ATTACH IF NOT EXISTS '{path}'{alias_sql}{options_sql}"


class BaseDuckDBConnectionConfig(ConnectionConfig):
    """Common configuration for the DuckDB-based connections.

    Args:
        database: The optional database name. If not specified, the in-memory database will be used.
        catalogs: Key is the name of the catalog and value is the path.
        extensions: A list of autoloadable extensions to load.
        connector_config: A dictionary of configuration to pass into the duckdb connector.
        secrets: A list of dictionaries used to generate DuckDB secrets for authenticating with external services (e.g. S3).
        filesystems: A list of dictionaries used to register `fsspec` filesystems to the DuckDB cursor.
        concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
        register_comments: Whether or not to register model comments with the SQL engine.
        pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
        token: The optional MotherDuck token. If not specified and a MotherDuck path is in the catalog, the user will be prompted to login with their web browser.
    """

    database: t.Optional[str] = None
    catalogs: t.Optional[t.Dict[str, t.Union[str, DuckDBAttachOptions]]] = None
    extensions: t.List[t.Union[str, t.Dict[str, t.Any]]] = []
    connector_config: t.Dict[str, t.Any] = {}
    secrets: t.Union[t.List[t.Dict[str, t.Any]], t.Dict[str, t.Dict[str, t.Any]]] = []
    filesystems: t.List[t.Dict[str, t.Any]] = []

    concurrent_tasks: int = 1
    register_comments: bool = True
    pre_ping: t.Literal[False] = False

    token: t.Optional[str] = None

    shared_connection: t.ClassVar[bool] = True

    # Process-wide cache of adapters keyed by data file path, shared by every DuckDB-based config.
    _data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {}

    @model_validator(mode="before")
    def _validate_database_catalogs(cls, data: t.Any) -> t.Any:
        """Reject configs that set both `database` and `catalogs`, or that use an `md:` path as `database`."""
        if not isinstance(data, dict):
            return data

        db_path = data.get("database")
        if db_path and data.get("catalogs"):
            raise ConfigError(
                "Cannot specify both `database` and `catalogs`. Define all your catalogs in `catalogs` and have the first entry be the default catalog"
            )
        if isinstance(db_path, str) and db_path.startswith("md:"):
            raise ConfigError(
                "Please use connection type 'motherduck' without the `md:` prefix if you want to use a MotherDuck database as the single `database`."
            )

        return data

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.DuckDBEngineAdapter

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {"database"}

    @property
    def _connection_factory(self) -> t.Callable:
        import duckdb

        return duckdb.connect

    @property
    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
        """A function that is called to initialize the cursor.

        The returned callback performs these steps in order:
        1. Installs and loads `extensions`.
        2. Applies `connector_config` settings whose values differ from the current ones.
        3. Creates `secrets` (requires DuckDB 0.10.0 or later).
        4. Registers fsspec `filesystems`.
        5. Attaches `catalogs`. The first catalog becomes the default when `database` is unset.
        """
        import duckdb
        from duckdb import BinderException

        def init(cursor: duckdb.DuckDBPyConnection) -> None:
            for extension in self.extensions:
                extension = extension if isinstance(extension, dict) else {"name": extension}

                install_command = f"INSTALL {extension['name']}"

                if extension.get("repository"):
                    install_command = f"{install_command} FROM {extension['repository']}"

                if extension.get("force_install"):
                    install_command = f"FORCE {install_command}"

                try:
                    cursor.execute(install_command)
                    cursor.execute(f"LOAD {extension['name']}")
                except Exception as e:
                    raise ConfigError(f"Failed to load extension {extension['name']}: {e}")

            if self.connector_config:
                option_names = list(self.connector_config)
                in_part = ",".join("?" for _ in range(len(option_names)))

                cursor.execute(
                    f"SELECT name, value FROM duckdb_settings() WHERE name IN ({in_part})",
                    option_names,
                )

                existing_values = {field: setting for field, setting in cursor.fetchall()}

                # only set connector_config items if the values differ from what is already set
                # trying to set options like 'temp_directory' even to the same value can throw errors like:
                # Not implemented Error: Cannot switch temporary directory after the current one has been used
                for field, setting in self.connector_config.items():
                    if existing_values.get(field) != setting:
                        try:
                            cursor.execute(f"SET {field} = '{setting}'")
                        except Exception as e:
                            raise ConfigError(
                                f"Failed to set connector config {field} to {setting}: {e}"
                            )

            if self.secrets:
                duckdb_version = duckdb.__version__
                if version.parse(duckdb_version) < version.parse("0.10.0"):
                    from sqlmesh.core.console import get_console

                    get_console().log_warning(
                        f"DuckDB version {duckdb_version} does not support secrets-based authentication (requires 0.10.0 or later).\n"
                        "To use secrets, please upgrade DuckDB. For older versions, configure legacy authentication via `connector_config`.\n"
                        "More info: https://duckdb.org/docs/stable/extensions/httpfs/s3api_legacy_authentication.html"
                    )
                else:
                    # A list produces unnamed secrets. A dict maps each secret name to its settings.
                    if isinstance(self.secrets, list):
                        secrets_items = [(secret_dict, "") for secret_dict in self.secrets]
                    else:
                        secrets_items = [
                            (secret_dict, secret_name)
                            for secret_name, secret_dict in self.secrets.items()
                        ]

                    for secret_dict, secret_name in secrets_items:
                        secret_settings: t.List[str] = []
                        for field, setting in secret_dict.items():
                            secret_settings.append(f"{field} '{setting}'")
                        if secret_settings:
                            secret_clause = ", ".join(secret_settings)
                            try:
                                cursor.execute(
                                    f"CREATE OR REPLACE SECRET {secret_name} ({secret_clause});"
                                )
                            except Exception as e:
                                raise ConfigError(f"Failed to create secret: {e}")

            if self.filesystems:
                from fsspec import filesystem  # type: ignore

                for file_system in self.filesystems:
                    options = file_system.copy()
                    fs = options.pop("fs")
                    fs = filesystem(fs, **options)
                    cursor.register_filesystem(fs)

            for i, (alias, path_options) in enumerate(
                (getattr(self, "catalogs", None) or {}).items()
            ):
                # we parse_identifier and generate to ensure that `alias` has exactly one set of quotes
                # regardless of whether it comes in quoted or not
                alias = exp.parse_identifier(alias, dialect="duckdb").sql(
                    identify=True, dialect="duckdb"
                )
                try:
                    if isinstance(path_options, DuckDBAttachOptions):
                        query = path_options.to_sql(alias)
                    else:
                        query = f"ATTACH IF NOT EXISTS '{path_options}'"
                        if not path_options.startswith("md:"):
                            query += f" AS {alias}"
                    cursor.execute(query)
                except BinderException as e:
                    # Normalize to the attach path string. Catalogs may be configured either as plain
                    # strings or as DuckDBAttachOptions, and plain strings have no `.path` attribute.
                    attach_path = (
                        path_options.path
                        if isinstance(path_options, DuckDBAttachOptions)
                        else path_options
                    )
                    error_message = str(e)
                    # If a user tries to create a catalog pointing at `:memory:` and with the name `memory`
                    # then we don't want to raise since this happens by default. They are just doing this to
                    # set it as the default catalog.
                    is_default_memory_catalog = (
                        'database with name "memory" already exists' in error_message
                        and attach_path == ":memory:"
                    )
                    # If a user tried to attach a MotherDuck database/share which has already been attached via
                    # `ATTACH 'md:'`, then we don't want to raise since this is expected.
                    attached_name = attach_path.replace("md:", "")
                    is_already_attached = (
                        f'database with name "{attached_name}" already exists' in error_message
                    )
                    if not (is_default_memory_catalog or is_already_attached):
                        raise e
                if i == 0 and not getattr(self, "database", None):
                    cursor.execute(f"USE {alias}")

        return init

    def create_engine_adapter(
        self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None
    ) -> EngineAdapter:
        """Checks if another engine adapter has already been created that shares a catalog that points to the same data
        file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration
        associated with the new adapter will be ignored."""
        data_files = set((self.catalogs or {}).values())
        if self.database:
            if isinstance(self, MotherDuckConnectionConfig):
                data_files.add(
                    f"md:{self.database}"
                    + (f"?motherduck_token={self.token}" if self.token else "")
                )
            else:
                data_files.add(self.database)
        data_files.discard(":memory:")
        for data_file in data_files:
            key = data_file if isinstance(data_file, str) else data_file.path
            adapter = BaseDuckDBConnectionConfig._data_file_to_adapter.get(key)
            if adapter is not None:
                logger.info(
                    f"Using existing DuckDB adapter due to overlapping data file: {self._mask_sensitive_data(key)}"
                )
                return adapter

        if data_files:
            masked_files = {
                self._mask_sensitive_data(file if isinstance(file, str) else file.path)
                for file in data_files
            }
            logger.info(f"Creating new DuckDB adapter for data files: {masked_files}")
        else:
            logger.info("Creating new DuckDB adapter for in-memory database")
        adapter = super().create_engine_adapter(
            register_comments_override, concurrent_tasks=concurrent_tasks
        )
        for data_file in data_files:
            key = data_file if isinstance(data_file, str) else data_file.path
            BaseDuckDBConnectionConfig._data_file_to_adapter[key] = adapter
        return adapter

    def get_catalog(self) -> t.Optional[str]:
        """Return the default catalog name.

        This is the file stem of `database` (with `:memory:` mapped to `memory`), otherwise the
        first key of `catalogs`, otherwise None.
        """
        if self.database:
            # Remove `:` from the database name in order to handle if `:memory:` is passed in
            return pathlib.Path(self.database.replace(":memory:", "memory")).stem
        if self.catalogs:
            return list(self.catalogs)[0]
        return None

    def _mask_sensitive_data(self, string: str) -> str:
        """Return `string` with MotherDuck tokens and `password=` values replaced by asterisks, for logging."""
        # Mask MotherDuck tokens with fixed number of asterisks
        result = MOTHERDUCK_TOKEN_REGEX.sub(
            lambda m: f"{m.group(1)}{m.group(2)}{'*' * 8 if m.group(3) else ''}", string
        )
        # Mask PostgreSQL/MySQL passwords with fixed number of asterisks
        result = PASSWORD_REGEX.sub(lambda m: f"{m.group(1)}{'*' * 8}", result)
        return result


class MotherDuckConnectionConfig(BaseDuckDBConnectionConfig):
    """Configuration for the MotherDuck connection."""

    type_: t.Literal["motherduck"] = Field(alias="type", default="motherduck")
    DIALECT: t.ClassVar[t.Literal["duckdb"]] = "duckdb"
    DISPLAY_NAME: t.ClassVar[t.Literal["MotherDuck"]] = "MotherDuck"
    DISPLAY_ORDER: t.ClassVar[t.Literal[5]] = 5

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return set()

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """Connection kwargs: an `md:` connection string (optionally scoped to a single database and
        carrying the token) plus a config dict that sets a SQLMesh custom user agent."""
        from sqlmesh import __version__

        custom_user_agent_config = {"custom_user_agent": f"SQLMesh/{__version__}"}
        connection_str = "md:"
        if self.database:
            # Attach single MD database instead of all databases on the account
            connection_str += f"{self.database}?attach_mode=single"
        if self.token:
            connection_str += f"{'&' if self.database else '?'}motherduck_token={self.token}"
        return {"database": connection_str, "config": custom_user_agent_config}

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return {"is_motherduck": True}


class DuckDBConnectionConfig(BaseDuckDBConnectionConfig):
    """Configuration for the DuckDB connection."""

    type_: t.Literal["duckdb"] = Field(alias="type", default="duckdb")
    DIALECT: t.ClassVar[t.Literal["duckdb"]] = "duckdb"
    DISPLAY_NAME: t.ClassVar[t.Literal["DuckDB"]] = "DuckDB"
    DISPLAY_ORDER: t.ClassVar[t.Literal[1]] = 1


class SnowflakeConnectionConfig(ConnectionConfig):
    """Configuration for the Snowflake connection.

    Args:
        account: The Snowflake account name.
        user: The Snowflake username.
        password: The Snowflake password.
        warehouse: The optional warehouse name.
        database: The optional database name.
        role: The optional role name.
        concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
        authenticator: The optional authenticator name. Defaults to username/password authentication ("snowflake").
                       Options: https://github.com/snowflakedb/snowflake-connector-python/blob/e937591356c067a77f34a0a42328907fda792c23/src/snowflake/connector/network.py#L178-L183
        token: The optional oauth access token to use for authentication when authenticator is set to "oauth".
        private_key: The optional private key to use for authentication. Key can be Base64-encoded DER format (representing the key bytes), a plain-text PEM format, or bytes (Python config only). https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#using-key-pair-authentication-key-pair-rotation
        private_key_path: The optional path to the private key to use for authentication. This would be used instead of `private_key`.
        private_key_passphrase: The optional passphrase to use to decrypt `private_key` or `private_key_path`. Keys can be created without encryption so only provide this if needed.
        register_comments: Whether or not to register model comments with the SQL engine.
        pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
        session_parameters: The optional session parameters to set for the connection.
        host: Host address for the connection.
        port: Port for the connection.
    """

    account: str
    user: t.Optional[str] = None
    password: t.Optional[str] = None
    warehouse: t.Optional[str] = None
    database: t.Optional[str] = None
    role: t.Optional[str] = None
    authenticator: t.Optional[str] = None
    token: t.Optional[str] = None
    host: t.Optional[str] = None
    port: t.Optional[int] = None
    application: t.Literal["Tobiko_SQLMesh"] = "Tobiko_SQLMesh"

    # Private Key Auth
    private_key: t.Optional[t.Union[str, bytes]] = None
    private_key_path: t.Optional[str] = None
    private_key_passphrase: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = False

    session_parameters: t.Optional[dict] = None

    type_: t.Literal["snowflake"] = Field(alias="type", default="snowflake")
    DIALECT: t.ClassVar[t.Literal["snowflake"]] = "snowflake"
    DISPLAY_NAME: t.ClassVar[t.Literal["Snowflake"]] = "Snowflake"
    DISPLAY_ORDER: t.ClassVar[t.Literal[2]] = 2

    _concurrent_tasks_validator = concurrent_tasks_validator

    @model_validator(mode="before")
    def _validate_authenticator(cls, data: t.Any) -> t.Any:
        """Resolve `private_key` to DER bytes and check the credentials required by the authenticator.

        Raises:
            ConfigError: If the default authenticator is used without a private key and without
                both user and password, or if oauth is used without a token.
        """
        if not isinstance(data, dict):
            return data

        from snowflake.connector.network import DEFAULT_AUTHENTICATOR, OAUTH_AUTHENTICATOR

        auth = data.get("authenticator")
        auth = auth.upper() if auth else DEFAULT_AUTHENTICATOR
        user = data.get("user")
        password = data.get("password")
        data["private_key"] = cls._get_private_key(data, auth)  # type: ignore

        if (
            auth == DEFAULT_AUTHENTICATOR
            and not data.get("private_key")
            and (not user or not password)
        ):
            raise ConfigError("User and password must be provided if using default authentication")

        if auth == OAUTH_AUTHENTICATOR and not data.get("token"):
            raise ConfigError("Token must be provided if using oauth authentication")

        return data

    _engine_import_validator = _get_engine_import_validator(
        "snowflake.connector.network", "snowflake"
    )

    @classmethod
    def _get_private_key(cls, values: t.Dict[str, t.Optional[str]], auth: str) -> t.Optional[bytes]:
        """
        source: https://github.com/dbt-labs/dbt-snowflake/blob/0374b4ec948982f2ac8ec0c95d53d672ad19e09c/dbt/adapters/snowflake/connections.py#L247C5-L285C1

        Overall code change: Use local variables instead of class attributes + Validation

        Returns the private key as unencrypted PKCS8 DER bytes, or None when no key is configured.
        Bytes supplied directly are returned unchanged.
        """
        # Start custom code
        from cryptography.hazmat.backends import default_backend
        from cryptography.hazmat.primitives import serialization
        from snowflake.connector.network import (
            DEFAULT_AUTHENTICATOR,
            KEY_PAIR_AUTHENTICATOR,
        )

        private_key = values.get("private_key")
        private_key_path = values.get("private_key_path")
        private_key_passphrase = values.get("private_key_passphrase")
        user = values.get("user")
        password = values.get("password")
        # An unset/default authenticator is treated as key-pair auth when a key is supplied.
        auth = auth if auth and auth != DEFAULT_AUTHENTICATOR else KEY_PAIR_AUTHENTICATOR

        if not private_key and not private_key_path:
            return None
        if private_key and private_key_path:
            raise ConfigError("Cannot specify both `private_key` and `private_key_path`")
        if auth != KEY_PAIR_AUTHENTICATOR:
            raise ConfigError(
                f"Private key or private key path can only be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
            )
        if not user:
            raise ConfigError(
                f"User must be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
            )
        if password:
            raise ConfigError(
                f"Password cannot be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
            )

        if isinstance(private_key, bytes):
            return private_key
        # End Custom Code

        if private_key_passphrase:
            encoded_passphrase = private_key_passphrase.encode()
        else:
            encoded_passphrase = None

        if private_key:
            # PEM text starts with "-----BEGIN"; anything else is treated as base64-encoded DER.
            if private_key.startswith("-"):
                p_key = serialization.load_pem_private_key(
                    data=bytes(private_key, "utf-8"),
                    password=encoded_passphrase,
                    backend=default_backend(),
                )

            else:
                p_key = serialization.load_der_private_key(
                    data=base64.b64decode(private_key),
                    password=encoded_passphrase,
                    backend=default_backend(),
                )

        elif private_key_path:
            with open(private_key_path, "rb") as key:
                p_key = serialization.load_pem_private_key(
                    key.read(), password=encoded_passphrase, backend=default_backend()
                )
        else:
            return None

        return p_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {
            "user",
            "password",
            "account",
            "warehouse",
            "database",
            "role",
            "authenticator",
            "token",
            "private_key",
            "session_parameters",
            "application",
            "host",
            "port",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.SnowflakeEngineAdapter

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        return {"autocommit": False}

    @property
    def _connection_factory(self) -> t.Callable:
        from snowflake import connector

        return connector.connect


class DatabricksConnectionConfig(ConnectionConfig):
    """
    Databricks connection that uses the SQL connector for SQL models and then Databricks Connect for Dataframe operations

    Arg Source: https://github.com/databricks/databricks-sql-python/blob/main/src/databricks/sql/client.py#L39
    OAuth ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication

    Args:
        server_hostname: Databricks instance host name.
        http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef)
            or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
        access_token: Http Bearer access token, e.g. Databricks Personal Access Token.
        auth_type: Set to 'databricks-oauth' or 'azure-oauth' to trigger OAuth (or dont set at all to use `access_token`)
        oauth_client_id: Client ID to use when auth_type is set to one of the 'oauth' types
        oauth_client_secret: Client Secret to use when auth_type is set to one of the 'oauth' types
        catalog: Default catalog to use for SQL models. Defaults to None which means it will use the default set in
            the Databricks cluster (most likely `hive_metastore`).
        http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request
        session_configuration: An optional dictionary of Spark session parameters.
            Execute the SQL command `SET -v` to get a full list of available commands.
        databricks_connect_server_hostname: The hostname to use when establishing a connecting using Databricks Connect.
            Defaults to the `server_hostname` value.
        databricks_connect_access_token: The access token to use when establishing a connecting using Databricks Connect.
            Defaults to the `access_token` value.
        databricks_connect_cluster_id: The cluster id to use when establishing a connecting using Databricks Connect.
            Defaults to deriving the cluster id from the `http_path` value.
        force_databricks_connect: Force all queries to run using Databricks Connect instead of the SQL connector.
        disable_databricks_connect: Even if databricks connect is installed, do not use it.
        disable_spark_session: Do not use SparkSession if it is available (like when running in a notebook).
        pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
    """

    server_hostname: t.Optional[str] = None
    http_path: t.Optional[str] = None
    access_token: t.Optional[str] = None
    auth_type: t.Optional[str] = None
    oauth_client_id: t.Optional[str] = None
    oauth_client_secret: t.Optional[str] = None
    catalog: t.Optional[str] = None
    http_headers: t.Optional[t.List[t.Tuple[str, str]]] = None
    session_configuration: t.Optional[t.Dict[str, t.Any]] = None
    databricks_connect_server_hostname: t.Optional[str] = None
    databricks_connect_access_token: t.Optional[str] = None
    databricks_connect_cluster_id: t.Optional[str] = None
    databricks_connect_use_serverless: bool = False
    force_databricks_connect: bool = False
    disable_databricks_connect: bool = False
    disable_spark_session: bool = False

    concurrent_tasks: int = 1
    register_comments: bool = True
    pre_ping: t.Literal[False] = False

    type_: t.Literal["databricks"] = Field(alias="type", default="databricks")
    DIALECT: t.ClassVar[t.Literal["databricks"]] = "databricks"
    DISPLAY_NAME: t.ClassVar[t.Literal["Databricks"]] = "Databricks"
    DISPLAY_ORDER: t.ClassVar[t.Literal[3]] = 3

    _concurrent_tasks_validator = concurrent_tasks_validator
    _http_headers_validator = http_headers_validator

    @model_validator(mode="before")
    def _databricks_connect_validator(cls, data: t.Any) -> t.Any:
        # SQLQueryContextLogger will output any error SQL queries even if they are in a try/except block.
        # Disabling this allows SQLMesh to determine what should be shown to the user.
        # Ex: We describe a table to see if it exists and therefore that execution can fail but we don't need to show
        # the user since it is expected if the table doesn't exist. Without this change the user would see the error.
        logging.getLogger("SQLQueryContextLogger").setLevel(logging.CRITICAL)

        if not isinstance(data, dict):
            return data

        from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter

        if DatabricksEngineAdapter.can_access_spark_session(
            bool(data.get("disable_spark_session"))
        ):
            return data

        databricks_connect_use_serverless = data.get("databricks_connect_use_serverless")
        server_hostname, http_path, access_token, auth_type = (
            data.get("server_hostname"),
            data.get("http_path"),
            data.get("access_token"),
            data.get("auth_type"),
        )

        if (not server_hostname or not http_path or not access_token) and (
            not databricks_connect_use_serverless and not auth_type
        ):
            raise ValueError(
                "`server_hostname`, `http_path`, and `access_token` are required for Databricks connections when not running in a notebook"
            )
        if (
            databricks_connect_use_serverless
            and not server_hostname
            and not data.get("databricks_connect_server_hostname")
        ):
            raise ValueError(
                "`server_hostname` or `databricks_connect_server_hostname` is required when `databricks_connect_use_serverless` is set"
            )
        if DatabricksEngineAdapter.can_access_databricks_connect(
            bool(data.get("disable_databricks_connect"))
        ):
            if not data.get("databricks_connect_access_token"):
                data["databricks_connect_access_token"] = access_token
            if not data.get("databricks_connect_server_hostname"):
                data["databricks_connect_server_hostname"] = f"https://{server_hostname}"
            if not databricks_connect_use_serverless and not data.get(
                "databricks_connect_cluster_id"
            ):
                if t.TYPE_CHECKING:
                    assert http_path is not None
                data["databricks_connect_cluster_id"] = http_path.split("/")[-1]

        if auth_type:
            from databricks.sql.auth.auth import AuthType

            all_data = [m.value for m in AuthType]
            if auth_type not in all_data:
                raise ValueError(
                    f"`auth_type` {auth_type} does not match a valid option: {all_data}"
                )

            client_id = data.get("oauth_client_id")
            client_secret = data.get("oauth_client_secret")

            if client_secret and not client_id:
                raise ValueError(
                    "`oauth_client_id` is required when `oauth_client_secret` is specified"
                )

            if not http_path:
                raise ValueError("`http_path` is still required when using `auth_type`")

        return data

    _engine_import_validator = _get_engine_import_validator("databricks", "databricks")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        if self.use_spark_session_only:
            return set()
        return {
            "server_hostname",
            "http_path",
            "access_token",
            "http_headers",
            "session_configuration",
            "catalog",
        }

    @property
    def _engine_adapter(self) -> t.Type[engine_adapter.DatabricksEngineAdapter]:
        return engine_adapter.DatabricksEngineAdapter

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return {
            k: v
            for k, v in self.dict().items()
            if k.startswith("databricks_connect_")
            or k in ("catalog", "disable_databricks_connect", "disable_spark_session")
        }

    @property
    def use_spark_session_only(self) -> bool:
        from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter

        return (
            DatabricksEngineAdapter.can_access_spark_session(self.disable_spark_session)
            or self.force_databricks_connect
        )

    @property
    def _connection_factory(self) -> t.Callable:
        if self.use_spark_session_only:
            from sqlmesh.engines.spark.db_api.spark_session import connection

            return connection

        from databricks import sql  # type: ignore
941 return sql.connect 942 943 @property 944 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 945 from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter 946 947 if not self.use_spark_session_only: 948 conn_kwargs: t.Dict[str, t.Any] = { 949 "_user_agent_entry": "sqlmesh", 950 } 951 952 if self.auth_type and "oauth" in self.auth_type: 953 # there are two types of oauth: User-to-Machine (U2M) and Machine-to-Machine (M2M) 954 if self.oauth_client_secret: 955 # if a client_secret exists, then a client_id also exists and we are using M2M 956 # ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication 957 # ref: https://github.com/databricks/databricks-sql-python/blob/main/examples/m2m_oauth.py 958 from databricks.sdk.core import oauth_service_principal, Config 959 960 config = Config( 961 host=f"https://{self.server_hostname}", 962 client_id=self.oauth_client_id, 963 client_secret=self.oauth_client_secret, 964 ) 965 conn_kwargs["credentials_provider"] = lambda: oauth_service_principal(config) 966 else: 967 # if auth_type is set to an 'oauth' type but no client_id/secret are set, then we are using U2M 968 # ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-user-to-machine-u2m-authentication 969 conn_kwargs["auth_type"] = self.auth_type 970 971 return conn_kwargs 972 973 if DatabricksEngineAdapter.can_access_spark_session(self.disable_spark_session): 974 from pyspark.sql import SparkSession 975 976 return dict( 977 spark=SparkSession.getActiveSession(), 978 catalog=self.catalog, 979 ) 980 981 from databricks.connect import DatabricksSession 982 983 if t.TYPE_CHECKING: 984 assert self.databricks_connect_server_hostname is not None 985 assert self.databricks_connect_access_token is not None 986 987 if self.databricks_connect_use_serverless: 988 builder = DatabricksSession.builder.remote( 989 host=self.databricks_connect_server_hostname, 990 
token=self.databricks_connect_access_token, 991 serverless=True, 992 ) 993 else: 994 if t.TYPE_CHECKING: 995 assert self.databricks_connect_cluster_id is not None 996 builder = DatabricksSession.builder.remote( 997 host=self.databricks_connect_server_hostname, 998 token=self.databricks_connect_access_token, 999 cluster_id=self.databricks_connect_cluster_id, 1000 ) 1001 1002 return dict( 1003 spark=builder.userAgent("sqlmesh").getOrCreate(), 1004 catalog=self.catalog, 1005 ) 1006 1007 1008class BigQueryConnectionMethod(str, Enum): 1009 OAUTH = "oauth" 1010 OAUTH_SECRETS = "oauth-secrets" 1011 SERVICE_ACCOUNT = "service-account" 1012 SERVICE_ACCOUNT_JSON = "service-account-json" 1013 1014 1015class BigQueryPriority(str, Enum): 1016 BATCH = "batch" 1017 INTERACTIVE = "interactive" 1018 1019 @property 1020 def is_batch(self) -> bool: 1021 return self == self.BATCH 1022 1023 @property 1024 def is_interactive(self) -> bool: 1025 return self == self.INTERACTIVE 1026 1027 @property 1028 def bigquery_constant(self) -> str: 1029 from google.cloud.bigquery import QueryPriority 1030 1031 if self.is_batch: 1032 return QueryPriority.BATCH 1033 return QueryPriority.INTERACTIVE 1034 1035 1036class BigQueryConnectionConfig(ConnectionConfig): 1037 """ 1038 BigQuery Connection Configuration. 1039 """ 1040 1041 method: BigQueryConnectionMethod = BigQueryConnectionMethod.OAUTH 1042 1043 project: t.Optional[str] = None 1044 execution_project: t.Optional[str] = None 1045 quota_project: t.Optional[str] = None 1046 location: t.Optional[str] = None 1047 # Keyfile Auth 1048 keyfile: t.Optional[str] = None 1049 keyfile_json: t.Optional[t.Dict[str, t.Any]] = None 1050 # Oath Secret Auth 1051 token: t.Optional[str] = None 1052 refresh_token: t.Optional[str] = None 1053 client_id: t.Optional[str] = None 1054 client_secret: t.Optional[str] = None 1055 token_uri: t.Optional[str] = None 1056 scopes: t.Tuple[str, ...] 
= ("https://www.googleapis.com/auth/bigquery",) 1057 impersonated_service_account: t.Optional[str] = None 1058 # Extra Engine Config 1059 job_creation_timeout_seconds: t.Optional[int] = None 1060 job_execution_timeout_seconds: t.Optional[int] = None 1061 job_retries: t.Optional[int] = 1 1062 job_retry_deadline_seconds: t.Optional[int] = None 1063 priority: t.Optional[BigQueryPriority] = None 1064 maximum_bytes_billed: t.Optional[int] = None 1065 1066 concurrent_tasks: int = 1 1067 register_comments: bool = True 1068 pre_ping: t.Literal[False] = False 1069 1070 type_: t.Literal["bigquery"] = Field(alias="type", default="bigquery") 1071 DIALECT: t.ClassVar[t.Literal["bigquery"]] = "bigquery" 1072 DISPLAY_NAME: t.ClassVar[t.Literal["BigQuery"]] = "BigQuery" 1073 DISPLAY_ORDER: t.ClassVar[t.Literal[4]] = 4 1074 1075 _engine_import_validator = _get_engine_import_validator("google.cloud.bigquery", "bigquery") 1076 1077 @field_validator("execution_project") 1078 def validate_execution_project( 1079 cls, 1080 v: t.Optional[str], 1081 info: ValidationInfo, 1082 ) -> t.Optional[str]: 1083 if v and not info.data.get("project"): 1084 raise ConfigError( 1085 "If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location." 1086 ) 1087 return v 1088 1089 @field_validator("quota_project") 1090 def validate_quota_project( 1091 cls, 1092 v: t.Optional[str], 1093 info: ValidationInfo, 1094 ) -> t.Optional[str]: 1095 if v and not info.data.get("project"): 1096 raise ConfigError( 1097 "If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location." 
1098 ) 1099 return v 1100 1101 @property 1102 def _connection_kwargs_keys(self) -> t.Set[str]: 1103 return set() 1104 1105 @property 1106 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1107 return engine_adapter.BigQueryEngineAdapter 1108 1109 @property 1110 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 1111 """The static connection kwargs for this connection""" 1112 import google.auth 1113 from google.auth import impersonated_credentials 1114 from google.api_core import client_info, client_options 1115 from google.oauth2 import credentials, service_account 1116 1117 if self.method == BigQueryConnectionMethod.OAUTH: 1118 creds, _ = google.auth.default(scopes=self.scopes) 1119 elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT: 1120 creds = service_account.Credentials.from_service_account_file( 1121 self.keyfile, scopes=self.scopes 1122 ) 1123 elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT_JSON: 1124 creds = service_account.Credentials.from_service_account_info( 1125 self.keyfile_json, scopes=self.scopes 1126 ) 1127 elif self.method == BigQueryConnectionMethod.OAUTH_SECRETS: 1128 creds = credentials.Credentials( 1129 token=self.token, 1130 refresh_token=self.refresh_token, 1131 client_id=self.client_id, 1132 client_secret=self.client_secret, 1133 token_uri=self.token_uri, 1134 scopes=self.scopes, 1135 ) 1136 else: 1137 raise ConfigError("Invalid BigQuery Connection Method") 1138 1139 if self.impersonated_service_account: 1140 creds = impersonated_credentials.Credentials( 1141 source_credentials=creds, 1142 target_principal=self.impersonated_service_account, 1143 target_scopes=self.scopes, 1144 ) 1145 1146 options = client_options.ClientOptions(quota_project_id=self.quota_project) 1147 project = self.execution_project or self.project or None 1148 1149 client = google.cloud.bigquery.Client( 1150 project=project and exp.parse_identifier(project, dialect="bigquery").name, 1151 credentials=creds, 1152 location=self.location, 1153 
client_info=client_info.ClientInfo(user_agent="sqlmesh"), 1154 client_options=options, 1155 ) 1156 1157 return { 1158 "client": client, 1159 } 1160 1161 @property 1162 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 1163 return { 1164 k: v 1165 for k, v in self.dict().items() 1166 if k 1167 in { 1168 "job_creation_timeout_seconds", 1169 "job_execution_timeout_seconds", 1170 "job_retries", 1171 "job_retry_deadline_seconds", 1172 "priority", 1173 "maximum_bytes_billed", 1174 } 1175 } 1176 1177 @property 1178 def _connection_factory(self) -> t.Callable: 1179 from google.cloud.bigquery.dbapi import connect 1180 1181 return connect 1182 1183 def get_catalog(self) -> t.Optional[str]: 1184 return self.project 1185 1186 1187class GCPPostgresConnectionConfig(ConnectionConfig): 1188 """ 1189 Postgres Connection Configuration for GCP. 1190 1191 Args: 1192 instance_connection_string: Connection name for the postgres instance. 1193 user: Postgres or IAM user's name 1194 password: The postgres user's password. Only needed when the user is a postgres user. 1195 enable_iam_auth: Set to True when user is an IAM user. 1196 db: Name of the db to connect to. 1197 keyfile: string path to json service account credentials file 1198 keyfile_json: dict service account credentials info 1199 pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. 1200 """ 1201 1202 instance_connection_string: str 1203 user: str 1204 password: t.Optional[str] = None 1205 enable_iam_auth: t.Optional[bool] = None 1206 db: str 1207 ip_type: t.Union[t.Literal["public"], t.Literal["private"], t.Literal["psc"]] = "public" 1208 # Keyfile Auth 1209 keyfile: t.Optional[str] = None 1210 keyfile_json: t.Optional[t.Dict[str, t.Any]] = None 1211 timeout: t.Optional[int] = None 1212 scopes: t.Tuple[str, ...] 
= ("https://www.googleapis.com/auth/sqlservice.admin",) 1213 driver: str = "pg8000" 1214 1215 type_: t.Literal["gcp_postgres"] = Field(alias="type", default="gcp_postgres") 1216 DIALECT: t.ClassVar[t.Literal["postgres"]] = "postgres" 1217 DISPLAY_NAME: t.ClassVar[t.Literal["GCP Postgres"]] = "GCP Postgres" 1218 DISPLAY_ORDER: t.ClassVar[t.Literal[13]] = 13 1219 1220 concurrent_tasks: int = 4 1221 register_comments: bool = True 1222 pre_ping: bool = True 1223 1224 _engine_import_validator = _get_engine_import_validator( 1225 "google.cloud.sql", "gcp_postgres", "gcppostgres" 1226 ) 1227 1228 @model_validator(mode="before") 1229 def _validate_auth_method(cls, data: t.Any) -> t.Any: 1230 if not isinstance(data, dict): 1231 return data 1232 1233 password = data.get("password") 1234 enable_iam_auth = data.get("enable_iam_auth") 1235 1236 if not password and not enable_iam_auth: 1237 raise ConfigError( 1238 "GCP Postgres connection configuration requires either password set" 1239 " for a postgres user account or enable_iam_auth set to 'True'" 1240 " for an IAM user account." 
1241 ) 1242 1243 return data 1244 1245 @property 1246 def _connection_kwargs_keys(self) -> t.Set[str]: 1247 return { 1248 "instance_connection_string", 1249 "driver", 1250 "user", 1251 "password", 1252 "db", 1253 "enable_iam_auth", 1254 "timeout", 1255 } 1256 1257 @property 1258 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1259 return engine_adapter.PostgresEngineAdapter 1260 1261 @property 1262 def _connection_factory(self) -> t.Callable: 1263 from google.cloud.sql.connector import Connector 1264 from google.oauth2 import service_account 1265 1266 creds = None 1267 if self.keyfile: 1268 creds = service_account.Credentials.from_service_account_file( 1269 self.keyfile, scopes=self.scopes 1270 ) 1271 elif self.keyfile_json: 1272 creds = service_account.Credentials.from_service_account_info( 1273 self.keyfile_json, scopes=self.scopes 1274 ) 1275 1276 kwargs = { 1277 "credentials": creds, 1278 "ip_type": self.ip_type, 1279 } 1280 1281 if self.timeout: 1282 kwargs["timeout"] = self.timeout 1283 1284 return Connector(**kwargs).connect # type: ignore 1285 1286 1287class RedshiftConnectionConfig(ConnectionConfig): 1288 """ 1289 Redshift Connection Configuration. 1290 1291 Arg Source: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/__init__.py#L146 1292 Note: A subset of properties were selected. Please open an issue/PR if you want to see more supported. 1293 1294 Args: 1295 user: The username to use for authentication with the Amazon Redshift cluster. 1296 password: The password to use for authentication with the Amazon Redshift cluster. 1297 database: The name of the database instance to connect to. 1298 host: The hostname of the Amazon Redshift cluster. 1299 port: The port number of the Amazon Redshift cluster. Default value is 5439. 1300 source_address: No description provided 1301 unix_sock: No description provided 1302 ssl: Is SSL enabled. Default value is ``True``. SSL must be enabled when authenticating using IAM. 
1303 sslmode: The security of the connection to the Amazon Redshift cluster. 'verify-ca' and 'verify-full' are supported. 1304 timeout: The number of seconds before the connection to the server will timeout. By default there is no timeout. 1305 tcp_keepalive: Is `TCP keepalive <https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive>`_ used. The default value is ``True``. 1306 application_name: Sets the application name. The default value is None. 1307 preferred_role: The IAM role preferred for the current connection. 1308 principal_arn: The ARN of the IAM entity (user or role) for which you are generating a policy. 1309 credentials_provider: The class name of the IdP that will be used for authenticating with the Amazon Redshift cluster. 1310 region: The AWS region where the Amazon Redshift cluster is located. 1311 cluster_identifier: The cluster identifier of the Amazon Redshift cluster. 1312 iam: If IAM authentication is enabled. Default value is False. IAM must be True when authenticating using an IdP. 1313 is_serverless: Redshift end-point is serverless or provisional. Default value false. 1314 serverless_acct_id: The account ID of the serverless. Default value None 1315 serverless_work_group: The name of work group for serverless end point. Default value None. 1316 pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. 1317 enable_merge: Whether to use the Redshift merge operation instead of the SQLMesh logical merge. 
1318 """ 1319 1320 user: t.Optional[str] = None 1321 password: t.Optional[str] = None 1322 database: t.Optional[str] = None 1323 host: t.Optional[str] = None 1324 port: t.Optional[int] = None 1325 source_address: t.Optional[str] = None 1326 unix_sock: t.Optional[str] = None 1327 ssl: t.Optional[bool] = None 1328 sslmode: t.Optional[str] = None 1329 timeout: t.Optional[int] = None 1330 tcp_keepalive: t.Optional[bool] = None 1331 application_name: t.Optional[str] = None 1332 preferred_role: t.Optional[str] = None 1333 principal_arn: t.Optional[str] = None 1334 credentials_provider: t.Optional[str] = None 1335 region: t.Optional[str] = None 1336 cluster_identifier: t.Optional[str] = None 1337 iam: t.Optional[bool] = None 1338 is_serverless: t.Optional[bool] = None 1339 serverless_acct_id: t.Optional[str] = None 1340 serverless_work_group: t.Optional[str] = None 1341 enable_merge: t.Optional[bool] = None 1342 1343 concurrent_tasks: int = 4 1344 register_comments: bool = True 1345 pre_ping: bool = False 1346 1347 type_: t.Literal["redshift"] = Field(alias="type", default="redshift") 1348 DIALECT: t.ClassVar[t.Literal["redshift"]] = "redshift" 1349 DISPLAY_NAME: t.ClassVar[t.Literal["Redshift"]] = "Redshift" 1350 DISPLAY_ORDER: t.ClassVar[t.Literal[7]] = 7 1351 1352 _engine_import_validator = _get_engine_import_validator("redshift_connector", "redshift") 1353 1354 @property 1355 def _connection_kwargs_keys(self) -> t.Set[str]: 1356 return { 1357 "user", 1358 "password", 1359 "database", 1360 "host", 1361 "port", 1362 "source_address", 1363 "unix_sock", 1364 "ssl", 1365 "sslmode", 1366 "timeout", 1367 "tcp_keepalive", 1368 "application_name", 1369 "preferred_role", 1370 "principal_arn", 1371 "credentials_provider", 1372 "region", 1373 "cluster_identifier", 1374 "iam", 1375 "is_serverless", 1376 "serverless_acct_id", 1377 "serverless_work_group", 1378 } 1379 1380 @property 1381 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1382 return 
engine_adapter.RedshiftEngineAdapter 1383 1384 @property 1385 def _connection_factory(self) -> t.Callable: 1386 from redshift_connector import connect 1387 1388 return connect 1389 1390 @property 1391 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 1392 return {"enable_merge": self.enable_merge} 1393 1394 1395class PostgresConnectionConfig(ConnectionConfig): 1396 host: str 1397 user: str 1398 password: str 1399 port: int 1400 database: str 1401 keepalives_idle: t.Optional[int] = None 1402 connect_timeout: int = 10 1403 role: t.Optional[str] = None 1404 sslmode: t.Optional[str] = None 1405 application_name: t.Optional[str] = None 1406 1407 concurrent_tasks: int = 4 1408 register_comments: bool = True 1409 pre_ping: bool = True 1410 1411 type_: t.Literal["postgres"] = Field(alias="type", default="postgres") 1412 DIALECT: t.ClassVar[t.Literal["postgres"]] = "postgres" 1413 DISPLAY_NAME: t.ClassVar[t.Literal["Postgres"]] = "Postgres" 1414 DISPLAY_ORDER: t.ClassVar[t.Literal[12]] = 12 1415 1416 _engine_import_validator = _get_engine_import_validator("psycopg2", "postgres") 1417 1418 @property 1419 def _connection_kwargs_keys(self) -> t.Set[str]: 1420 return { 1421 "host", 1422 "user", 1423 "password", 1424 "port", 1425 "database", 1426 "keepalives_idle", 1427 "connect_timeout", 1428 "sslmode", 1429 "application_name", 1430 } 1431 1432 @property 1433 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1434 return engine_adapter.PostgresEngineAdapter 1435 1436 @property 1437 def _connection_factory(self) -> t.Callable: 1438 from psycopg2 import connect 1439 1440 return connect 1441 1442 @property 1443 def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]: 1444 if not self.role: 1445 return None 1446 1447 def init(cursor: t.Any) -> None: 1448 cursor.execute(f"SET ROLE {self.role}") 1449 1450 return init 1451 1452 1453class MySQLConnectionConfig(ConnectionConfig): 1454 host: str 1455 user: str 1456 password: str 1457 port: t.Optional[int] = None 1458 
database: t.Optional[str] = None 1459 charset: t.Optional[str] = None 1460 collation: t.Optional[str] = None 1461 ssl_disabled: t.Optional[bool] = None 1462 1463 concurrent_tasks: int = 4 1464 register_comments: bool = True 1465 pre_ping: bool = True 1466 1467 type_: t.Literal["mysql"] = Field(alias="type", default="mysql") 1468 DIALECT: t.ClassVar[t.Literal["mysql"]] = "mysql" 1469 DISPLAY_NAME: t.ClassVar[t.Literal["MySQL"]] = "MySQL" 1470 DISPLAY_ORDER: t.ClassVar[t.Literal[14]] = 14 1471 1472 _engine_import_validator = _get_engine_import_validator("pymysql", "mysql") 1473 1474 @property 1475 def _connection_kwargs_keys(self) -> t.Set[str]: 1476 connection_keys = { 1477 "host", 1478 "user", 1479 "password", 1480 } 1481 if self.port is not None: 1482 connection_keys.add("port") 1483 if self.database is not None: 1484 connection_keys.add("database") 1485 if self.charset is not None: 1486 connection_keys.add("charset") 1487 if self.collation is not None: 1488 connection_keys.add("collation") 1489 if self.ssl_disabled is not None: 1490 connection_keys.add("ssl_disabled") 1491 return connection_keys 1492 1493 @property 1494 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1495 return engine_adapter.MySQLEngineAdapter 1496 1497 @property 1498 def _connection_factory(self) -> t.Callable: 1499 from pymysql import connect 1500 1501 return connect 1502 1503 1504class MSSQLConnectionConfig(ConnectionConfig): 1505 host: str 1506 user: t.Optional[str] = None 1507 password: t.Optional[str] = None 1508 database: t.Optional[str] = "" 1509 timeout: t.Optional[int] = 0 1510 login_timeout: t.Optional[int] = 60 1511 charset: t.Optional[str] = "UTF-8" 1512 appname: t.Optional[str] = None 1513 port: t.Optional[int] = 1433 1514 conn_properties: t.Optional[t.Union[t.List[str], str]] = None 1515 autocommit: t.Optional[bool] = False 1516 tds_version: t.Optional[str] = None 1517 1518 # Driver options 1519 driver: t.Literal["pymssql", "pyodbc"] = "pymssql" 1520 # PyODBC specific options 
1521 driver_name: t.Optional[str] = None # e.g. "ODBC Driver 18 for SQL Server" 1522 trust_server_certificate: t.Optional[bool] = None 1523 encrypt: t.Optional[bool] = None 1524 # Dictionary of arbitrary ODBC connection properties 1525 # See: https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute 1526 odbc_properties: t.Optional[t.Dict[str, t.Any]] = None 1527 1528 concurrent_tasks: int = 4 1529 register_comments: bool = True 1530 pre_ping: bool = True 1531 1532 type_: t.Literal["mssql"] = Field(alias="type", default="mssql") 1533 DIALECT: t.ClassVar[t.Literal["tsql"]] = "tsql" 1534 DISPLAY_NAME: t.ClassVar[t.Literal["MSSQL"]] = "MSSQL" 1535 DISPLAY_ORDER: t.ClassVar[t.Literal[11]] = 11 1536 1537 @model_validator(mode="before") 1538 @classmethod 1539 def _mssql_engine_import_validator(cls, data: t.Any) -> t.Any: 1540 if not isinstance(data, dict): 1541 return data 1542 1543 driver = data.get("driver", "pymssql") 1544 1545 # Define the mapping of driver to import module and extra name 1546 driver_configs = {"pymssql": ("pymssql", "mssql"), "pyodbc": ("pyodbc", "mssql-odbc")} 1547 1548 if driver not in driver_configs: 1549 raise ValueError(f"Unsupported driver: {driver}") 1550 1551 import_module, extra_name = driver_configs[driver] 1552 1553 # Use _get_engine_import_validator with decorate=False to get the raw validation function 1554 # This avoids the __wrapped__ issue in Python 3.9 1555 validator_func = _get_engine_import_validator( 1556 import_module, driver, extra_name, decorate=False 1557 ) 1558 1559 # Call the raw validation function directly 1560 return validator_func(cls, data) 1561 1562 @property 1563 def _connection_kwargs_keys(self) -> t.Set[str]: 1564 base_keys = { 1565 "host", 1566 "user", 1567 "password", 1568 "database", 1569 "timeout", 1570 "login_timeout", 1571 "charset", 1572 "appname", 1573 "port", 1574 "conn_properties", 1575 "autocommit", 1576 "tds_version", 1577 } 1578 1579 if self.driver == "pyodbc": 1580 
base_keys.update( 1581 { 1582 "driver_name", 1583 "trust_server_certificate", 1584 "encrypt", 1585 "odbc_properties", 1586 } 1587 ) 1588 # Remove pymssql-specific parameters 1589 base_keys.discard("tds_version") 1590 base_keys.discard("conn_properties") 1591 1592 return base_keys 1593 1594 @property 1595 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1596 return engine_adapter.MSSQLEngineAdapter 1597 1598 @property 1599 def _connection_factory(self) -> t.Callable: 1600 if self.driver == "pymssql": 1601 import pymssql 1602 1603 return pymssql.connect 1604 1605 import pyodbc 1606 1607 def connect(**kwargs: t.Any) -> t.Callable: 1608 # Extract parameters for connection string 1609 host = kwargs.pop("host") 1610 port = kwargs.pop("port", 1433) 1611 database = kwargs.pop("database", "") 1612 user = kwargs.pop("user", None) 1613 password = kwargs.pop("password", None) 1614 driver_name = kwargs.pop("driver_name", "ODBC Driver 18 for SQL Server") 1615 trust_server_certificate = kwargs.pop("trust_server_certificate", False) 1616 encrypt = kwargs.pop("encrypt", True) 1617 login_timeout = kwargs.pop("login_timeout", 60) 1618 1619 # Build connection string 1620 conn_str_parts = [ 1621 f"DRIVER={{{driver_name}}}", 1622 f"SERVER={host},{port}", 1623 ] 1624 1625 if database: 1626 conn_str_parts.append(f"DATABASE={database}") 1627 1628 # Add security options 1629 conn_str_parts.append(f"Encrypt={'YES' if encrypt else 'NO'}") 1630 if trust_server_certificate: 1631 conn_str_parts.append("TrustServerCertificate=YES") 1632 1633 conn_str_parts.append(f"Connection Timeout={login_timeout}") 1634 1635 # Standard SQL Server authentication 1636 if user: 1637 conn_str_parts.append(f"UID={user}") 1638 if password: 1639 conn_str_parts.append(f"PWD={password}") 1640 1641 # Add any additional ODBC properties from the odbc_properties dictionary 1642 if self.odbc_properties: 1643 for key, value in self.odbc_properties.items(): 1644 # Skip properties that we've already set above 1645 if 
key.lower() in ( 1646 "driver", 1647 "server", 1648 "database", 1649 "uid", 1650 "pwd", 1651 "encrypt", 1652 "trustservercertificate", 1653 "connection timeout", 1654 ): 1655 continue 1656 1657 # Handle boolean values properly 1658 if isinstance(value, bool): 1659 conn_str_parts.append(f"{key}={'YES' if value else 'NO'}") 1660 else: 1661 conn_str_parts.append(f"{key}={value}") 1662 1663 # Create the connection string 1664 conn_str = ";".join(conn_str_parts) 1665 1666 conn = pyodbc.connect(conn_str, autocommit=kwargs.get("autocommit", False)) 1667 1668 # Set up output converters for MSSQL-specific data types 1669 # Handle SQL type -155 (DATETIMEOFFSET) which is not yet supported by pyodbc 1670 # ref: https://github.com/mkleehammer/pyodbc/issues/134#issuecomment-281739794 1671 def handle_datetimeoffset(dto_value: t.Any) -> t.Any: 1672 from datetime import datetime, timedelta, timezone 1673 import struct 1674 1675 # Unpack the DATETIMEOFFSET binary format: 1676 # Format: <6hI2h = (year, month, day, hour, minute, second, nanoseconds, tz_hour_offset, tz_minute_offset) 1677 tup = struct.unpack("<6hI2h", dto_value) 1678 return datetime( 1679 tup[0], 1680 tup[1], 1681 tup[2], 1682 tup[3], 1683 tup[4], 1684 tup[5], 1685 tup[6] // 1000, 1686 timezone(timedelta(hours=tup[7], minutes=tup[8])), 1687 ) 1688 1689 conn.add_output_converter(-155, handle_datetimeoffset) 1690 1691 return conn 1692 1693 return connect 1694 1695 @property 1696 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 1697 return {"catalog_support": CatalogSupport.REQUIRES_SET_CATALOG} 1698 1699 1700class AzureSQLConnectionConfig(MSSQLConnectionConfig): 1701 type_: t.Literal["azuresql"] = Field(alias="type", default="azuresql") # type: ignore 1702 DISPLAY_NAME: t.ClassVar[t.Literal["Azure SQL"]] = "Azure SQL" # type: ignore 1703 DISPLAY_ORDER: t.ClassVar[t.Literal[10]] = 10 # type: ignore 1704 1705 @property 1706 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 1707 return {"catalog_support": 
CatalogSupport.SINGLE_CATALOG_ONLY} 1708 1709 1710class FabricConnectionConfig(MSSQLConnectionConfig): 1711 """ 1712 Fabric Connection Configuration. 1713 Inherits most settings from MSSQLConnectionConfig and sets the type to 'fabric'. 1714 It is recommended to use the 'pyodbc' driver for Fabric. 1715 """ 1716 1717 type_: t.Literal["fabric"] = Field(alias="type", default="fabric") # type: ignore 1718 DIALECT: t.ClassVar[t.Literal["fabric"]] = "fabric" # type: ignore 1719 DISPLAY_NAME: t.ClassVar[t.Literal["Fabric"]] = "Fabric" # type: ignore 1720 DISPLAY_ORDER: t.ClassVar[t.Literal[17]] = 17 # type: ignore 1721 driver: t.Literal["pyodbc"] = "pyodbc" 1722 workspace_id: str 1723 tenant_id: str 1724 autocommit: t.Optional[bool] = True 1725 1726 @property 1727 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1728 from sqlmesh.core.engine_adapter.fabric import FabricEngineAdapter 1729 1730 return FabricEngineAdapter 1731 1732 @property 1733 def _connection_factory(self) -> t.Callable: 1734 # Override to support catalog switching for Fabric 1735 base_factory = super()._connection_factory 1736 1737 def create_fabric_connection( 1738 target_catalog: t.Optional[str] = None, *args: t.Any, **kwargs: t.Any 1739 ) -> t.Callable: 1740 kwargs["database"] = target_catalog or self.database 1741 return base_factory(*args, **kwargs) 1742 1743 return create_fabric_connection 1744 1745 @property 1746 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 1747 return { 1748 "database": self.database, 1749 # more operations than not require a specific catalog to be already active 1750 # in particular, create/drop view, create/drop schema and querying information_schema 1751 "catalog_support": CatalogSupport.REQUIRES_SET_CATALOG, 1752 "workspace_id": self.workspace_id, 1753 "tenant_id": self.tenant_id, 1754 "user": self.user, 1755 "password": self.password, 1756 } 1757 1758 1759class SparkConnectionConfig(ConnectionConfig): 1760 """ 1761 Vanilla Spark Connection Configuration. 
Use `DatabricksConnectionConfig` for Databricks.

    Args:
        config_dir: If set, exported as the `SPARK_CONF_DIR` environment variable before the session is created.
        catalog: Forwarded to the Spark session connection.
        config: Key/value pairs applied to the `SparkConf` used to build the session.
        wap_enabled: Passed to the engine adapter as `wap_enabled`.
        pre_ping: Not supported for Spark; only `False` is accepted.
    """

    config_dir: t.Optional[str] = None
    catalog: t.Optional[str] = None
    config: t.Dict[str, t.Any] = {}
    wap_enabled: bool = False

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: t.Literal[False] = False

    type_: t.Literal["spark"] = Field(alias="type", default="spark")
    DIALECT: t.ClassVar[t.Literal["spark"]] = "spark"
    DISPLAY_NAME: t.ClassVar[t.Literal["Spark"]] = "Spark"
    DISPLAY_ORDER: t.ClassVar[t.Literal[8]] = 8

    _engine_import_validator = _get_engine_import_validator("pyspark", "spark")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {
            "catalog",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.SparkEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        from sqlmesh.engines.spark.db_api.spark_session import connection

        return connection

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """Build (or reuse, via `getOrCreate`) a Hive-enabled SparkSession from `config`."""
        from pyspark.conf import SparkConf
        from pyspark.sql import SparkSession

        spark_config = SparkConf()
        if self.config:
            for k, v in self.config.items():
                spark_config.set(k, v)

        # NOTE: this mutates the process-wide environment.
        if self.config_dir:
            os.environ["SPARK_CONF_DIR"] = self.config_dir
        return {
            "spark": SparkSession.builder.config(conf=spark_config)
            .enableHiveSupport()
            .getOrCreate(),
        }

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return {"wap_enabled": self.wap_enabled}


class TrinoAuthenticationMethod(str, Enum):
    # Authentication methods selectable via `TrinoConnectionConfig.method`.
    NO_AUTH = "no-auth"
    BASIC = "basic"
    LDAP = "ldap"
    KERBEROS = "kerberos"
    JWT = "jwt"
    CERTIFICATE = "certificate"
    OAUTH = "oauth"

    @property
    def is_no_auth(self) -> bool:
        return self == self.NO_AUTH

    @property
@property 1833 def is_basic(self) -> bool: 1834 return self == self.BASIC 1835 1836 @property 1837 def is_ldap(self) -> bool: 1838 return self == self.LDAP 1839 1840 @property 1841 def is_kerberos(self) -> bool: 1842 return self == self.KERBEROS 1843 1844 @property 1845 def is_jwt(self) -> bool: 1846 return self == self.JWT 1847 1848 @property 1849 def is_certificate(self) -> bool: 1850 return self == self.CERTIFICATE 1851 1852 @property 1853 def is_oauth(self) -> bool: 1854 return self == self.OAUTH 1855 1856 1857class TrinoConnectionConfig(ConnectionConfig): 1858 method: TrinoAuthenticationMethod = TrinoAuthenticationMethod.NO_AUTH 1859 host: str 1860 user: str 1861 catalog: str 1862 port: t.Optional[int] = None 1863 http_scheme: t.Literal["http", "https"] = "https" 1864 # General Optional 1865 roles: t.Optional[t.Dict[str, str]] = None 1866 http_headers: t.Optional[t.Dict[str, str]] = None 1867 session_properties: t.Optional[t.Dict[str, str]] = None 1868 retries: int = 3 1869 timezone: t.Optional[str] = None 1870 # Basic/LDAP 1871 password: t.Optional[str] = None 1872 verify: t.Optional[bool] = None # disable SSL verification (ignored if `cert` is provided) 1873 # LDAP 1874 impersonation_user: t.Optional[str] = None 1875 # Kerberos 1876 keytab: t.Optional[str] = None 1877 krb5_config: t.Optional[str] = None 1878 principal: t.Optional[str] = None 1879 service_name: str = "trino" 1880 hostname_override: t.Optional[str] = None 1881 mutual_authentication: bool = False 1882 force_preemptive: bool = False 1883 sanitize_mutual_error_response: bool = True 1884 delegate: bool = False 1885 # JWT 1886 jwt_token: t.Optional[str] = None 1887 # Certificate 1888 client_certificate: t.Optional[str] = None 1889 client_private_key: t.Optional[str] = None 1890 cert: t.Optional[str] = None 1891 source: str = "sqlmesh" 1892 1893 # SQLMesh options 1894 schema_location_mapping: t.Optional[dict[re.Pattern, str]] = None 1895 timestamp_mapping: t.Optional[dict[exp.DataType, 
exp.DataType]] = None 1896 concurrent_tasks: int = 4 1897 register_comments: bool = True 1898 pre_ping: t.Literal[False] = False 1899 1900 type_: t.Literal["trino"] = Field(alias="type", default="trino") 1901 DIALECT: t.ClassVar[t.Literal["trino"]] = "trino" 1902 DISPLAY_NAME: t.ClassVar[t.Literal["Trino"]] = "Trino" 1903 DISPLAY_ORDER: t.ClassVar[t.Literal[9]] = 9 1904 1905 _engine_import_validator = _get_engine_import_validator("trino", "trino") 1906 1907 @field_validator("schema_location_mapping", mode="before") 1908 @classmethod 1909 def _validate_regex_keys( 1910 cls, value: t.Dict[str | re.Pattern, str] 1911 ) -> t.Dict[re.Pattern, t.Any]: 1912 compiled = compile_regex_mapping(value) 1913 for replacement in compiled.values(): 1914 if "@{schema_name}" not in replacement: 1915 raise ConfigError( 1916 "schema_location_mapping needs to include the '@{schema_name}' placeholder in the value so SQLMesh knows where to substitute the schema name" 1917 ) 1918 return compiled 1919 1920 @field_validator("timestamp_mapping", mode="before") 1921 @classmethod 1922 def _validate_timestamp_mapping( 1923 cls, value: t.Optional[dict[str, str]] 1924 ) -> t.Optional[dict[exp.DataType, exp.DataType]]: 1925 if value is None: 1926 return value 1927 1928 result: dict[exp.DataType, exp.DataType] = {} 1929 for source_type, target_type in value.items(): 1930 try: 1931 source_datatype = exp.DataType.build(source_type) 1932 except ParseError: 1933 raise ConfigError( 1934 f"Invalid SQL type string in timestamp_mapping: " 1935 f"'{source_type}' is not a valid SQL data type." 1936 ) 1937 try: 1938 target_datatype = exp.DataType.build(target_type) 1939 except ParseError: 1940 raise ConfigError( 1941 f"Invalid SQL type string in timestamp_mapping: " 1942 f"'{target_type}' is not a valid SQL data type." 
1943 ) 1944 result[source_datatype] = target_datatype 1945 1946 return result 1947 1948 @model_validator(mode="after") 1949 def _root_validator(self) -> Self: 1950 port = self.port 1951 if self.http_scheme == "http" and not self.method.is_no_auth and not self.method.is_basic: 1952 raise ConfigError("HTTP scheme can only be used with no-auth or basic method") 1953 1954 if port is None: 1955 self.port = 80 if self.http_scheme == "http" else 443 1956 1957 if (self.method.is_ldap or self.method.is_basic) and (not self.password or not self.user): 1958 raise ConfigError( 1959 f"Username and Password must be provided if using {self.method.value} authentication" 1960 ) 1961 1962 if self.method.is_kerberos and ( 1963 not self.principal or not self.keytab or not self.krb5_config 1964 ): 1965 raise ConfigError( 1966 "Kerberos requires the following fields: principal, keytab, and krb5_config" 1967 ) 1968 1969 if self.method.is_jwt and not self.jwt_token: 1970 raise ConfigError("JWT requires `jwt_token` to be set") 1971 1972 if self.method.is_certificate and ( 1973 not self.cert or not self.client_certificate or not self.client_private_key 1974 ): 1975 raise ConfigError( 1976 "Certificate requires the following fields: cert, client_certificate, and client_private_key" 1977 ) 1978 1979 return self 1980 1981 @property 1982 def _connection_kwargs_keys(self) -> t.Set[str]: 1983 kwargs = { 1984 "host", 1985 "port", 1986 "catalog", 1987 "roles", 1988 "source", 1989 "http_scheme", 1990 "http_headers", 1991 "session_properties", 1992 "timezone", 1993 } 1994 return kwargs 1995 1996 @property 1997 def _engine_adapter(self) -> t.Type[EngineAdapter]: 1998 return engine_adapter.TrinoEngineAdapter 1999 2000 @property 2001 def _connection_factory(self) -> t.Callable: 2002 from trino.dbapi import connect 2003 2004 return connect 2005 2006 @property 2007 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 2008 from trino.auth import ( 2009 BasicAuthentication, 2010 
CertificateAuthentication, 2011 JWTAuthentication, 2012 KerberosAuthentication, 2013 OAuth2Authentication, 2014 ) 2015 2016 if self.method.is_basic or self.method.is_ldap: 2017 auth = BasicAuthentication(self.user, self.password) 2018 elif self.method.is_kerberos: 2019 if self.keytab: 2020 os.environ["KRB5_CLIENT_KTNAME"] = self.keytab 2021 auth = KerberosAuthentication( 2022 config=self.krb5_config, 2023 service_name=self.service_name, 2024 principal=self.principal, 2025 mutual_authentication=self.mutual_authentication, 2026 ca_bundle=self.cert, 2027 force_preemptive=self.force_preemptive, 2028 hostname_override=self.hostname_override, 2029 sanitize_mutual_error_response=self.sanitize_mutual_error_response, 2030 delegate=self.delegate, 2031 ) 2032 elif self.method.is_oauth: 2033 auth = OAuth2Authentication() 2034 elif self.method.is_jwt: 2035 auth = JWTAuthentication(self.jwt_token) 2036 elif self.method.is_certificate: 2037 auth = CertificateAuthentication(self.client_certificate, self.client_private_key) 2038 else: 2039 auth = None 2040 2041 return { 2042 "auth": auth, 2043 "user": self.impersonation_user or self.user, 2044 "max_attempts": self.retries, 2045 "verify": self.cert if self.cert is not None else self.verify, 2046 "source": self.source, 2047 } 2048 2049 @property 2050 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 2051 return { 2052 "schema_location_mapping": self.schema_location_mapping, 2053 "timestamp_mapping": self.timestamp_mapping, 2054 } 2055 2056 2057class ClickhouseConnectionConfig(ConnectionConfig): 2058 """ 2059 Clickhouse Connection Configuration. 
2060 2061 Property reference: https://clickhouse.com/docs/en/integrations/python#client-initialization 2062 """ 2063 2064 host: str 2065 username: str 2066 password: t.Optional[str] = None 2067 port: t.Optional[int] = None 2068 cluster: t.Optional[str] = None 2069 connect_timeout: int = 10 2070 send_receive_timeout: int = 300 2071 query_limit: int = 0 2072 use_compression: bool = True 2073 compression_method: t.Optional[str] = None 2074 connection_settings: t.Optional[t.Dict[str, t.Any]] = None 2075 http_proxy: t.Optional[str] = None 2076 # HTTPS/TLS settings 2077 verify: bool = True 2078 ca_cert: t.Optional[str] = None 2079 client_cert: t.Optional[str] = None 2080 client_cert_key: t.Optional[str] = None 2081 https_proxy: t.Optional[str] = None 2082 server_host_name: t.Optional[str] = None 2083 tls_mode: t.Optional[str] = None 2084 2085 concurrent_tasks: int = 1 2086 register_comments: bool = True 2087 pre_ping: bool = False 2088 2089 # This object expects options from urllib3 and also from clickhouse-connect 2090 # See: 2091 # * https://urllib3.readthedocs.io/en/stable/advanced-usage.html 2092 # * https://clickhouse.com/docs/en/integrations/python#customizing-the-http-connection-pool 2093 connection_pool_options: t.Optional[t.Dict[str, t.Any]] = None 2094 2095 type_: t.Literal["clickhouse"] = Field(alias="type", default="clickhouse") 2096 DIALECT: t.ClassVar[t.Literal["clickhouse"]] = "clickhouse" 2097 DISPLAY_NAME: t.ClassVar[t.Literal["ClickHouse"]] = "ClickHouse" 2098 DISPLAY_ORDER: t.ClassVar[t.Literal[6]] = 6 2099 2100 _engine_import_validator = _get_engine_import_validator("clickhouse_connect", "clickhouse") 2101 2102 @property 2103 def _connection_kwargs_keys(self) -> t.Set[str]: 2104 kwargs = { 2105 "host", 2106 "username", 2107 "port", 2108 "password", 2109 "connect_timeout", 2110 "send_receive_timeout", 2111 "query_limit", 2112 "http_proxy", 2113 "verify", 2114 "ca_cert", 2115 "client_cert", 2116 "client_cert_key", 2117 "https_proxy", 2118 
"server_host_name", 2119 "tls_mode", 2120 } 2121 return kwargs 2122 2123 @property 2124 def _engine_adapter(self) -> t.Type[EngineAdapter]: 2125 return engine_adapter.ClickhouseEngineAdapter 2126 2127 @property 2128 def _connection_factory(self) -> t.Callable: 2129 from clickhouse_connect.dbapi import connect # type: ignore 2130 from clickhouse_connect.driver import httputil # type: ignore 2131 from functools import partial 2132 2133 pool_manager_options: t.Dict[str, t.Any] = dict( 2134 # Match the maxsize to the number of concurrent tasks 2135 maxsize=self.concurrent_tasks, 2136 # Block if there are no free connections 2137 block=True, 2138 verify=self.verify, 2139 ca_cert=self.ca_cert, 2140 client_cert=self.client_cert, 2141 client_cert_key=self.client_cert_key, 2142 https_proxy=self.https_proxy, 2143 ) 2144 # this doesn't happen automatically because we always supply our own pool manager to the connection 2145 # https://github.com/ClickHouse/clickhouse-connect/blob/3a7f4b04cad29c7c2536661b831fb744248e2ec0/clickhouse_connect/driver/httpclient.py#L109 2146 if self.server_host_name: 2147 pool_manager_options["server_hostname"] = self.server_host_name 2148 if self.verify: 2149 pool_manager_options["assert_hostname"] = self.server_host_name 2150 if self.connection_pool_options: 2151 pool_manager_options.update(self.connection_pool_options) 2152 pool_mgr = httputil.get_pool_manager(**pool_manager_options) 2153 2154 return partial(connect, pool_mgr=pool_mgr) 2155 2156 @property 2157 def cloud_mode(self) -> bool: 2158 return "clickhouse.cloud" in self.host 2159 2160 @property 2161 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 2162 return {"cluster": self.cluster, "cloud_mode": self.cloud_mode} 2163 2164 @property 2165 def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: 2166 from sqlmesh import __version__ 2167 2168 # False = no compression 2169 # True = Clickhouse default compression method 2170 # string = specific compression method 2171 compress: bool | 
str = self.use_compression 2172 if compress and self.compression_method: 2173 compress = self.compression_method 2174 2175 # Clickhouse system settings passed to connection 2176 # https://clickhouse.com/docs/en/operations/settings/settings 2177 # - below are set to align with dbt-clickhouse 2178 # - https://github.com/ClickHouse/dbt-clickhouse/blob/44d26308ea6a3c8ead25c280164aa88191f05f47/dbt/adapters/clickhouse/dbclient.py#L77 2179 settings = self.connection_settings or {} 2180 # mutations_sync = 2: "The query waits for all mutations [ALTER statements] to complete on all replicas (if they exist)" 2181 settings["mutations_sync"] = "2" 2182 # insert_distributed_sync = 1: "INSERT operation succeeds only after all the data is saved on all shards" 2183 settings["insert_distributed_sync"] = "1" 2184 if self.cluster or self.cloud_mode: 2185 # database_replicated_enforce_synchronous_settings = 1: 2186 # - "Enforces synchronous waiting for some queries" 2187 # - https://github.com/ClickHouse/ClickHouse/blob/ccaa8d03a9351efc16625340268b9caffa8a22ba/src/Core/Settings.h#L709 2188 settings["database_replicated_enforce_synchronous_settings"] = "1" 2189 # insert_quorum = auto: 2190 # - "INSERT succeeds only when ClickHouse manages to correctly write data to the insert_quorum of replicas during 2191 # the insert_quorum_timeout" 2192 # - "use majority number (number_of_replicas / 2 + 1) as quorum number" 2193 settings["insert_quorum"] = "auto" 2194 2195 return { 2196 "compress": compress, 2197 "client_name": f"SQLMesh/{__version__}", 2198 **settings, 2199 } 2200 2201 2202class AthenaConnectionConfig(ConnectionConfig): 2203 # PyAthena connection options 2204 aws_access_key_id: t.Optional[str] = None 2205 aws_secret_access_key: t.Optional[str] = None 2206 role_arn: t.Optional[str] = None 2207 role_session_name: t.Optional[str] = None 2208 region_name: t.Optional[str] = None 2209 work_group: t.Optional[str] = None 2210 s3_staging_dir: t.Optional[str] = None 2211 schema_name: 
t.Optional[str] = None 2212 catalog_name: t.Optional[str] = None 2213 2214 # SQLMesh options 2215 s3_warehouse_location: t.Optional[str] = None 2216 concurrent_tasks: int = 4 2217 register_comments: t.Literal[False] = ( 2218 False # because Athena doesnt support comments in most cases 2219 ) 2220 pre_ping: t.Literal[False] = False 2221 2222 type_: t.Literal["athena"] = Field(alias="type", default="athena") 2223 DIALECT: t.ClassVar[t.Literal["athena"]] = "athena" 2224 DISPLAY_NAME: t.ClassVar[t.Literal["Athena"]] = "Athena" 2225 DISPLAY_ORDER: t.ClassVar[t.Literal[15]] = 15 2226 2227 _engine_import_validator = _get_engine_import_validator("pyathena", "athena") 2228 2229 @model_validator(mode="after") 2230 def _root_validator(self) -> Self: 2231 work_group = self.work_group 2232 s3_staging_dir = self.s3_staging_dir 2233 s3_warehouse_location = self.s3_warehouse_location 2234 2235 if not work_group and not s3_staging_dir: 2236 raise ConfigError("At least one of work_group or s3_staging_dir must be set") 2237 2238 if s3_staging_dir: 2239 self.s3_staging_dir = validate_s3_uri(s3_staging_dir, base=True, error_type=ConfigError) 2240 2241 if s3_warehouse_location: 2242 self.s3_warehouse_location = validate_s3_uri( 2243 s3_warehouse_location, base=True, error_type=ConfigError 2244 ) 2245 2246 return self 2247 2248 @property 2249 def _connection_kwargs_keys(self) -> t.Set[str]: 2250 return { 2251 "aws_access_key_id", 2252 "aws_secret_access_key", 2253 "role_arn", 2254 "role_session_name", 2255 "region_name", 2256 "work_group", 2257 "s3_staging_dir", 2258 "schema_name", 2259 "catalog_name", 2260 } 2261 2262 @property 2263 def _engine_adapter(self) -> t.Type[EngineAdapter]: 2264 return engine_adapter.AthenaEngineAdapter 2265 2266 @property 2267 def _extra_engine_config(self) -> t.Dict[str, t.Any]: 2268 return {"s3_warehouse_location": self.s3_warehouse_location} 2269 2270 @property 2271 def _connection_factory(self) -> t.Callable: 2272 from pyathena import connect # type: 
ignore 2273 2274 return connect 2275 2276 def get_catalog(self) -> t.Optional[str]: 2277 return self.catalog_name 2278 2279 2280class RisingwaveConnectionConfig(ConnectionConfig): 2281 host: str 2282 user: str 2283 password: t.Optional[str] = None 2284 port: int 2285 database: str 2286 role: t.Optional[str] = None 2287 sslmode: t.Optional[str] = None 2288 2289 concurrent_tasks: int = 4 2290 register_comments: bool = True 2291 pre_ping: bool = True 2292 2293 type_: t.Literal["risingwave"] = Field(alias="type", default="risingwave") 2294 DIALECT: t.ClassVar[t.Literal["risingwave"]] = "risingwave" 2295 DISPLAY_NAME: t.ClassVar[t.Literal["RisingWave"]] = "RisingWave" 2296 DISPLAY_ORDER: t.ClassVar[t.Literal[16]] = 16 2297 2298 _engine_import_validator = _get_engine_import_validator("psycopg2", "risingwave") 2299 2300 @property 2301 def _connection_kwargs_keys(self) -> t.Set[str]: 2302 return { 2303 "host", 2304 "user", 2305 "password", 2306 "port", 2307 "database", 2308 "role", 2309 "sslmode", 2310 } 2311 2312 @property 2313 def _engine_adapter(self) -> t.Type[EngineAdapter]: 2314 return engine_adapter.RisingwaveEngineAdapter 2315 2316 @property 2317 def _connection_factory(self) -> t.Callable: 2318 from psycopg2 import connect 2319 2320 return connect 2321 2322 @property 2323 def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]: 2324 def init(cursor: t.Any) -> None: 2325 sql = "SET RW_IMPLICIT_FLUSH TO true;" 2326 cursor.execute(sql) 2327 2328 return init 2329 2330 2331CONNECTION_CONFIG_TO_TYPE = { 2332 # Map all subclasses of ConnectionConfig to the value of their `type_` field. 
2333 tpe.all_field_infos()["type_"].default: tpe 2334 for tpe in subclasses( 2335 __name__, 2336 ConnectionConfig, 2337 exclude={ConnectionConfig, BaseDuckDBConnectionConfig}, 2338 ) 2339} 2340 2341DIALECT_TO_TYPE = { 2342 tpe.all_field_infos()["type_"].default: tpe.DIALECT 2343 for tpe in subclasses( 2344 __name__, 2345 ConnectionConfig, 2346 exclude={ConnectionConfig, BaseDuckDBConnectionConfig}, 2347 ) 2348} 2349 2350INIT_DISPLAY_INFO_TO_TYPE = { 2351 tpe.all_field_infos()["type_"].default: ( 2352 tpe.DISPLAY_ORDER, 2353 tpe.DISPLAY_NAME, 2354 ) 2355 for tpe in subclasses( 2356 __name__, 2357 ConnectionConfig, 2358 exclude={ConnectionConfig, BaseDuckDBConnectionConfig}, 2359 ) 2360} 2361 2362 2363def parse_connection_config(v: t.Dict[str, t.Any]) -> ConnectionConfig: 2364 if "type" not in v: 2365 raise ConfigError("Missing connection type.") 2366 2367 connection_type = v["type"] 2368 if connection_type not in CONNECTION_CONFIG_TO_TYPE: 2369 raise ConfigError(f"Unknown connection type '{connection_type}'.") 2370 2371 return CONNECTION_CONFIG_TO_TYPE[connection_type](**v) 2372 2373 2374def _connection_config_validator( 2375 cls: t.Type, v: ConnectionConfig | t.Dict[str, t.Any] | None 2376) -> ConnectionConfig | None: 2377 if v is None or isinstance(v, ConnectionConfig): 2378 return v 2379 2380 check_config_and_vars_msg = "\n\nVerify your config.yaml and environment variables." 
2381 2382 try: 2383 return parse_connection_config(v) 2384 except pydantic.ValidationError as e: 2385 raise ConfigError( 2386 validation_error_message(e, f"Invalid '{v['type']}' connection config:") 2387 + check_config_and_vars_msg 2388 ) 2389 except ConfigError as e: 2390 raise ConfigError(str(e) + check_config_and_vars_msg) 2391 2392 2393connection_config_validator: t.Callable = field_validator( 2394 "connection", 2395 "state_connection", 2396 "test_connection", 2397 "default_connection", 2398 "default_test_connection", 2399 mode="before", 2400 check_fields=False, 2401)(_connection_config_validator) 2402 2403 2404if t.TYPE_CHECKING: 2405 # TypeAlias hasn't been introduced until Python 3.10 which means that we can't use it 2406 # outside the TYPE_CHECKING guard. 2407 SerializableConnectionConfig: t.TypeAlias = ConnectionConfig # type: ignore 2408else: 2409 import pydantic 2410 2411 # Workaround for https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing 2412 SerializableConnectionConfig = pydantic.SerializeAsAny[ConnectionConfig] # type: ignore
class ConnectionConfig(abc.ABC, BaseConfig):
    """Abstract base for engine connection configurations.

    Subclasses declare their connection fields and implement the abstract properties below;
    `create_engine_adapter` combines them into an `EngineAdapter` instance.
    """

    type_: str
    DIALECT: t.ClassVar[str]
    DISPLAY_NAME: t.ClassVar[str]
    DISPLAY_ORDER: t.ClassVar[int]
    concurrent_tasks: int
    register_comments: bool
    pre_ping: bool
    pretty_sql: bool = False
    schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None
    catalog_type_overrides: t.Optional[t.Dict[str, str]] = None

    # Whether to share a single connection across threads or create a new connection per thread.
    shared_connection: t.ClassVar[bool] = False

    @property
    @abc.abstractmethod
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """keywords that should be passed into the connection"""

    @property
    @abc.abstractmethod
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """The engine adapter for this connection"""

    @property
    @abc.abstractmethod
    def _connection_factory(self) -> t.Callable:
        """A function that is called to return a connection object for the given Engine Adapter"""

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """The static connection kwargs for this connection"""
        return {}

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        """kwargs that are for execution config only"""
        return {}

    @property
    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
        """A function that is called to initialize the cursor"""
        return None

    @property
    def is_recommended_for_state_sync(self) -> bool:
        """Whether this engine is recommended for being used as a state sync for production state syncs"""
        return self.type_ in RECOMMENDED_STATE_SYNC_ENGINES

    @property
    def is_forbidden_for_state_sync(self) -> bool:
        """Whether this engine is forbidden from being used as a state sync"""
        return self.type_ in FORBIDDEN_STATE_SYNC_ENGINES

    @property
    def _connection_factory_with_kwargs(self) -> t.Callable[[], t.Any]:
        """A function that is called to return a connection object for the given Engine Adapter"""
        # Field values listed in `_connection_kwargs_keys` override same-named static kwargs.
        return partial(
            self._connection_factory,
            **{
                **self._static_connection_kwargs,
                **{k: v for k, v in self.dict().items() if k in self._connection_kwargs_keys},
            },
        )

    def connection_validator(self) -> t.Callable[[], None]:
        """A function that validates the connection configuration"""
        return self.create_engine_adapter().ping

    def create_engine_adapter(
        self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None
    ) -> EngineAdapter:
        """Returns a new instance of the Engine Adapter.

        Args:
            register_comments_override: Enables comment registration even if the config disables it.
            concurrent_tasks: Overrides the configured task count when truthy; more than one
                task makes the adapter multithreaded.
        """

        concurrent_tasks = concurrent_tasks or self.concurrent_tasks
        return self._engine_adapter(
            self._connection_factory_with_kwargs,
            multithreaded=concurrent_tasks > 1,
            default_catalog=self.get_catalog(),
            cursor_init=self._cursor_init,
            register_comments=register_comments_override or self.register_comments,
            pre_ping=self.pre_ping,
            pretty_sql=self.pretty_sql,
            shared_connection=self.shared_connection,
            schema_differ_overrides=self.schema_differ_overrides,
            catalog_type_overrides=self.catalog_type_overrides,
            **self._extra_engine_config,
        )

    def get_catalog(self) -> t.Optional[str]:
        """The catalog for this connection"""
        # First attribute found wins, even if its value is None.
        if hasattr(self, "catalog"):
            return self.catalog
        if hasattr(self, "database"):
            return self.database
        if hasattr(self, "db"):
            return self.db
        return None

    @model_validator(mode="before")
    @classmethod
    def _expand_json_strings_to_concrete_types(cls, data: t.Any) -> t.Any:
        """
        There are situations where a connection config class has a field that is some kind of complex type
        (eg a list of strings or a dict) but the value is being supplied from a source such as an environment variable.

        When this happens, the value is supplied as a string rather than a Python object. We need some way
        of turning this string into the corresponding Python list or dict.

        Rather than doing this piecemeal on every config subclass, this provides a generic implementation
        to identify fields that may be supplied as JSON strings and handle them transparently.

        The input mapping is copied before any value is replaced, so the caller's dict is not mutated.
        """
        if not data or not isinstance(data, dict):
            return data

        data = dict(data)
        for maybe_json_field_name in cls._get_list_and_dict_field_names():
            value = data.get(maybe_json_field_name)
            if not value or not isinstance(value, str):
                continue
            # crude JSON check as we don't want to try and parse every string we get
            value = value.strip()
            if value.startswith(("{", "[")):
                data[maybe_json_field_name] = from_json(value)

        return data

    @classmethod
    def _get_list_and_dict_field_names(cls) -> t.Set[str]:
        """Names of fields whose annotation includes a list, tuple, set or dict type."""
        container_types = (list, tuple, set, dict)
        field_names = set()
        for name, field in cls.model_fields.items():
            if not field.annotation:
                continue
            field_types = get_concrete_types_from_typehint(field.annotation)

            # check if the field type is something that could conceivably be supplied as a JSON string
            if any(ft is container for container in container_types for ft in field_types):
                field_names.add(name)

        return field_names
Helper class that provides a standard way to create an ABC using inheritance.
@property
def is_recommended_for_state_sync(self) -> bool:
    """True when this connection's engine type is one recommended for production state syncs."""
    engine_type = self.type_
    return engine_type in RECOMMENDED_STATE_SYNC_ENGINES
Whether this engine is recommended for being used as a state sync for production state syncs
@property
def is_forbidden_for_state_sync(self) -> bool:
    """True when this connection's engine type must never back a state sync."""
    engine_type = self.type_
    return engine_type in FORBIDDEN_STATE_SYNC_ENGINES
Whether this engine is forbidden from being used as a state sync
def connection_validator(self) -> t.Callable[[], None]:
    """Create an engine adapter and return its `ping` method for checking the connection."""
    adapter = self.create_engine_adapter()
    return adapter.ping
A function that validates the connection configuration
def create_engine_adapter(
    self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None
) -> EngineAdapter:
    """Build a fresh engine adapter from this configuration.

    Args:
        register_comments_override: Turns comment registration on even if the config has it off.
        concurrent_tasks: Replaces the configured task count when truthy; a count above one
            yields a multithreaded adapter.
    """
    task_count = concurrent_tasks or self.concurrent_tasks
    # Attributes are read in the same order as a single direct call would read them.
    adapter_cls = self._engine_adapter
    connection_factory = self._connection_factory_with_kwargs
    default_catalog = self.get_catalog()
    cursor_init = self._cursor_init
    register_comments = register_comments_override or self.register_comments
    return adapter_cls(
        connection_factory,
        multithreaded=task_count > 1,
        default_catalog=default_catalog,
        cursor_init=cursor_init,
        register_comments=register_comments,
        pre_ping=self.pre_ping,
        pretty_sql=self.pretty_sql,
        shared_connection=self.shared_connection,
        schema_differ_overrides=self.schema_differ_overrides,
        catalog_type_overrides=self.catalog_type_overrides,
        **self._extra_engine_config,
    )
Returns a new instance of the Engine Adapter.
def get_catalog(self) -> t.Optional[str]:
    """Return the value of the first defined attribute among `catalog`, `database`, `db`; else None."""
    for attr_name in ("catalog", "database", "db"):
        if hasattr(self, attr_name):
            return getattr(self, attr_name)
    return None
The catalog for this connection
Configuration for the model; should be a dictionary conforming to `ConfigDict` (`pydantic.config.ConfigDict`).
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class DuckDBAttachOptions(BaseConfig):
    """Options describing a database to attach to a DuckDB connection with `ATTACH`."""

    type: str
    path: str
    read_only: bool = False

    # DuckLake specific options
    data_path: t.Optional[str] = None
    encrypted: bool = False
    data_inlining_row_limit: t.Optional[int] = None
    metadata_schema: t.Optional[str] = None

    def to_sql(self, alias: str) -> str:
        """Render an `ATTACH IF NOT EXISTS` statement for these options.

        Args:
            alias: Name to attach the database under. It is omitted for MotherDuck, which does
                not support aliasing.

        Returns:
            The SQL statement. Single quotes inside string-literal values (path, data path,
            metadata schema) are doubled so they cannot terminate the literal early.
        """

        def quote(value: str) -> str:
            # SQL string literal: an embedded single quote is escaped by doubling it
            return "'" + value.replace("'", "''") + "'"

        options = []
        # 'duckdb' is actually not a supported type, but we'd like to allow it for
        # fully qualified attach options or integration testing, similar to duckdb-dbt
        if self.type not in ("duckdb", "ducklake", "motherduck"):
            options.append(f"TYPE {self.type.upper()}")
        if self.read_only:
            options.append("READ_ONLY")

        # DuckLake specific options
        path = self.path
        if self.type == "ducklake":
            if not path.startswith("ducklake:"):
                path = f"ducklake:{path}"
            if self.data_path is not None:
                options.append(f"DATA_PATH {quote(self.data_path)}")
            if self.encrypted:
                options.append("ENCRYPTED")
            if self.data_inlining_row_limit is not None:
                options.append(f"DATA_INLINING_ROW_LIMIT {self.data_inlining_row_limit}")
            if self.metadata_schema is not None:
                options.append(f"METADATA_SCHEMA {quote(self.metadata_schema)}")

        options_sql = f" ({', '.join(options)})" if options else ""
        # TODO: Add support for Postgres schema. Currently adding it blocks access to the information_schema

        # MotherDuck does not support aliasing
        alias_sql = (
            f" AS {alias}" if not (self.type == "motherduck" or self.path.startswith("md:")) else ""
        )
        return f"ATTACH IF NOT EXISTS {quote(path)}{alias_sql}{options_sql}"
Base configuration functionality for configuration classes.
def to_sql(self, alias: str) -> str:
    """Render an `ATTACH IF NOT EXISTS` statement for these options.

    Args:
        alias: Name to attach the database under. It is omitted for MotherDuck, which does
            not support aliasing.

    Returns:
        The SQL statement. Single quotes inside string-literal values (path, data path,
        metadata schema) are doubled so they cannot terminate the literal early.
    """

    def quote(value: str) -> str:
        # SQL string literal: an embedded single quote is escaped by doubling it
        return "'" + value.replace("'", "''") + "'"

    options = []
    # 'duckdb' is actually not a supported type, but we'd like to allow it for
    # fully qualified attach options or integration testing, similar to duckdb-dbt
    if self.type not in ("duckdb", "ducklake", "motherduck"):
        options.append(f"TYPE {self.type.upper()}")
    if self.read_only:
        options.append("READ_ONLY")

    # DuckLake specific options
    path = self.path
    if self.type == "ducklake":
        if not path.startswith("ducklake:"):
            path = f"ducklake:{path}"
        if self.data_path is not None:
            options.append(f"DATA_PATH {quote(self.data_path)}")
        if self.encrypted:
            options.append("ENCRYPTED")
        if self.data_inlining_row_limit is not None:
            options.append(f"DATA_INLINING_ROW_LIMIT {self.data_inlining_row_limit}")
        if self.metadata_schema is not None:
            options.append(f"METADATA_SCHEMA {quote(self.metadata_schema)}")

    options_sql = f" ({', '.join(options)})" if options else ""
    # TODO: Add support for Postgres schema. Currently adding it blocks access to the information_schema

    # MotherDuck does not support aliasing
    alias_sql = (
        f" AS {alias}" if not (self.type == "motherduck" or self.path.startswith("md:")) else ""
    )
    return f"ATTACH IF NOT EXISTS {quote(path)}{alias_sql}{options_sql}"
Configuration for the model; should be a dictionary conforming to `ConfigDict` (`pydantic.config.ConfigDict`).
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class BaseDuckDBConnectionConfig(ConnectionConfig):
    """Common configuration for the DuckDB-based connections.

    Args:
        database: The optional database name. If not specified, the in-memory database will be used.
        catalogs: Key is the name of the catalog and value is the path.
        extensions: A list of autoloadable extensions to load.
        connector_config: A dictionary of configuration to pass into the duckdb connector.
        secrets: A list of dictionaries used to generate DuckDB secrets for authenticating with external services (e.g. S3).
        filesystems: A list of dictionaries used to register `fsspec` filesystems to the DuckDB cursor.
        concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
        register_comments: Whether or not to register model comments with the SQL engine.
        pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
        token: The optional MotherDuck token. If not specified and a MotherDuck path is in the catalog, the user will be prompted to log in with their web browser.
    """

    database: t.Optional[str] = None
    catalogs: t.Optional[t.Dict[str, t.Union[str, DuckDBAttachOptions]]] = None
    extensions: t.List[t.Union[str, t.Dict[str, t.Any]]] = []
    connector_config: t.Dict[str, t.Any] = {}
    secrets: t.Union[t.List[t.Dict[str, t.Any]], t.Dict[str, t.Dict[str, t.Any]]] = []
    filesystems: t.List[t.Dict[str, t.Any]] = []

    concurrent_tasks: int = 1
    register_comments: bool = True
    pre_ping: t.Literal[False] = False

    token: t.Optional[str] = None

    shared_connection: t.ClassVar[bool] = True

    # Process-wide cache: data file path (or MotherDuck connection string) -> adapter that owns it
    _data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {}

    @model_validator(mode="before")
    def _validate_database_catalogs(cls, data: t.Any) -> t.Any:
        if not isinstance(data, dict):
            return data

        db_path = data.get("database")
        if db_path and data.get("catalogs"):
            raise ConfigError(
                "Cannot specify both `database` and `catalogs`. Define all your catalogs in `catalogs` and have the first entry be the default catalog"
            )
        if isinstance(db_path, str) and db_path.startswith("md:"):
            raise ConfigError(
                "Please use connection type 'motherduck' without the `md:` prefix if you want to use a MotherDuck database as the single `database`."
            )

        return data

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.DuckDBEngineAdapter

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {"database"}

    @property
    def _connection_factory(self) -> t.Callable:
        import duckdb

        return duckdb.connect

    @property
    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
        """A function that is called to initialize the cursor"""
        import duckdb
        from duckdb import BinderException

        def init(cursor: duckdb.DuckDBPyConnection) -> None:
            for extension in self.extensions:
                extension = extension if isinstance(extension, dict) else {"name": extension}

                install_command = f"INSTALL {extension['name']}"

                if extension.get("repository"):
                    install_command = f"{install_command} FROM {extension['repository']}"

                if extension.get("force_install"):
                    install_command = f"FORCE {install_command}"

                try:
                    cursor.execute(install_command)
                    cursor.execute(f"LOAD {extension['name']}")
                except Exception as e:
                    raise ConfigError(f"Failed to load extension {extension['name']}: {e}") from e

            if self.connector_config:
                option_names = list(self.connector_config)
                in_part = ",".join("?" for _ in range(len(option_names)))

                cursor.execute(
                    f"SELECT name, value FROM duckdb_settings() WHERE name IN ({in_part})",
                    option_names,
                )

                existing_values = {field: setting for field, setting in cursor.fetchall()}

                # only set connector_config items if the values differ from what is already set
                # trying to set options like 'temp_directory' even to the same value can throw errors like:
                # Not implemented Error: Cannot switch temporary directory after the current one has been used
                for field, setting in self.connector_config.items():
                    if existing_values.get(field) != setting:
                        try:
                            cursor.execute(f"SET {field} = '{setting}'")
                        except Exception as e:
                            raise ConfigError(
                                f"Failed to set connector config {field} to {setting}: {e}"
                            ) from e

            if self.secrets:
                duckdb_version = duckdb.__version__
                if version.parse(duckdb_version) < version.parse("0.10.0"):
                    from sqlmesh.core.console import get_console

                    get_console().log_warning(
                        f"DuckDB version {duckdb_version} does not support secrets-based authentication (requires 0.10.0 or later).\n"
                        "To use secrets, please upgrade DuckDB. For older versions, configure legacy authentication via `connector_config`.\n"
                        "More info: https://duckdb.org/docs/stable/extensions/httpfs/s3api_legacy_authentication.html"
                    )
                else:
                    if isinstance(self.secrets, list):
                        secrets_items = [(secret_dict, "") for secret_dict in self.secrets]
                    else:
                        secrets_items = [
                            (secret_dict, secret_name)
                            for secret_name, secret_dict in self.secrets.items()
                        ]

                    for secret_dict, secret_name in secrets_items:
                        secret_settings: t.List[str] = []
                        for field, setting in secret_dict.items():
                            secret_settings.append(f"{field} '{setting}'")
                        if secret_settings:
                            secret_clause = ", ".join(secret_settings)
                            try:
                                cursor.execute(
                                    f"CREATE OR REPLACE SECRET {secret_name} ({secret_clause});"
                                )
                            except Exception as e:
                                raise ConfigError(f"Failed to create secret: {e}") from e

            if self.filesystems:
                from fsspec import filesystem  # type: ignore

                for file_system in self.filesystems:
                    options = file_system.copy()
                    fs = options.pop("fs")
                    fs = filesystem(fs, **options)
                    cursor.register_filesystem(fs)

            for i, (alias, path_options) in enumerate(
                (getattr(self, "catalogs", None) or {}).items()
            ):
                # we parse_identifier and generate to ensure that `alias` has exactly one set of quotes
                # regardless of whether it comes in quoted or not
                alias = exp.parse_identifier(alias, dialect="duckdb").sql(
                    identify=True, dialect="duckdb"
                )
                try:
                    if isinstance(path_options, DuckDBAttachOptions):
                        query = path_options.to_sql(alias)
                    else:
                        query = f"ATTACH IF NOT EXISTS '{path_options}'"
                        if not path_options.startswith("md:"):
                            query += f" AS {alias}"
                    cursor.execute(query)
                except BinderException as e:
                    error_message = str(e)
                    # Catalog values may be plain path strings or attach options; both expose a path.
                    # Reading `.path` on a plain string used to raise AttributeError and hide `e`.
                    attached_path = (
                        path_options if isinstance(path_options, str) else path_options.path
                    )
                    attached_name = attached_path.replace("md:", "")
                    # If a user tries to create a catalog pointing at `:memory:` and with the name `memory`
                    # then we don't want to raise since this happens by default. They are just doing this to
                    # set it as the default catalog.
                    is_default_memory_catalog = (
                        'database with name "memory" already exists' in error_message
                        and path_options == ":memory:"
                    )
                    # If a user tried to attach a MotherDuck database/share which has already been attached via
                    # `ATTACH 'md:'`, then we don't want to raise since this is expected.
                    is_already_attached = (
                        f'database with name "{attached_name}" already exists' in error_message
                    )
                    if not is_default_memory_catalog and not is_already_attached:
                        raise
                if i == 0 and not getattr(self, "database", None):
                    cursor.execute(f"USE {alias}")

        return init

    def create_engine_adapter(
        self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None
    ) -> EngineAdapter:
        """Checks if another engine adapter has already been created that shares a catalog that points to the same data
        file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration
        associated with the new adapter will be ignored.

        In-memory databases (`:memory:`) are never shared, whether they are configured as a plain path string or
        through attach options.
        """
        # Key every data file by its path string so plain paths and attach options are treated alike
        data_file_keys: t.Set[str] = {
            catalog if isinstance(catalog, str) else catalog.path
            for catalog in (self.catalogs or {}).values()
        }
        if self.database:
            if isinstance(self, MotherDuckConnectionConfig):
                data_file_keys.add(
                    f"md:{self.database}"
                    + (f"?motherduck_token={self.token}" if self.token else "")
                )
            else:
                data_file_keys.add(self.database)
        data_file_keys.discard(":memory:")

        for key in data_file_keys:
            adapter = BaseDuckDBConnectionConfig._data_file_to_adapter.get(key)
            if adapter is not None:
                logger.info(
                    f"Using existing DuckDB adapter due to overlapping data file: {self._mask_sensitive_data(key)}"
                )
                return adapter

        if data_file_keys:
            masked_files = {self._mask_sensitive_data(key) for key in data_file_keys}
            logger.info(f"Creating new DuckDB adapter for data files: {masked_files}")
        else:
            logger.info("Creating new DuckDB adapter for in-memory database")
        adapter = super().create_engine_adapter(
            register_comments_override, concurrent_tasks=concurrent_tasks
        )
        for key in data_file_keys:
            BaseDuckDBConnectionConfig._data_file_to_adapter[key] = adapter
        return adapter

    def get_catalog(self) -> t.Optional[str]:
        if self.database:
            # Remove `:` from the database name in order to handle if `:memory:` is passed in
            return pathlib.Path(self.database.replace(":memory:", "memory")).stem
        if self.catalogs:
            return list(self.catalogs)[0]
        return None

    def _mask_sensitive_data(self, string: str) -> str:
        # Mask MotherDuck tokens with fixed number of asterisks
        result = MOTHERDUCK_TOKEN_REGEX.sub(
            lambda m: f"{m.group(1)}{m.group(2)}{'*' * 8 if m.group(3) else ''}", string
        )
        # Mask PostgreSQL/MySQL passwords with fixed number of asterisks
        result = PASSWORD_REGEX.sub(lambda m: f"{m.group(1)}{'*' * 8}", result)
        return result
Common configuration for the DuckDB-based connections.
Arguments:
- database: The optional database name. If not specified, the in-memory database will be used.
- catalogs: Key is the name of the catalog and value is the path.
- extensions: A list of autoloadable extensions to load.
- connector_config: A dictionary of configuration to pass into the duckdb connector.
- secrets: A list of dictionaries used to generate DuckDB secrets for authenticating with external services (e.g. S3).
- filesystems: A list of dictionaries used to register `fsspec` filesystems to the DuckDB cursor.
- concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
- register_comments: Whether or not to register model comments with the SQL engine.
- pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
- token: The optional MotherDuck token. If not specified and a MotherDuck path is in the catalog, the user will be prompted to login with their web browser.
def create_engine_adapter(
    self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None
) -> EngineAdapter:
    """Checks if another engine adapter has already been created that shares a catalog that points to the same data
    file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration
    associated with the new adapter will be ignored.

    In-memory databases (`:memory:`) are never shared, whether they are configured as a plain path string or
    through attach options.
    """
    # Key every data file by its path string so plain paths and attach options are treated alike
    data_file_keys: t.Set[str] = {
        catalog if isinstance(catalog, str) else catalog.path
        for catalog in (self.catalogs or {}).values()
    }
    if self.database:
        if isinstance(self, MotherDuckConnectionConfig):
            data_file_keys.add(
                f"md:{self.database}"
                + (f"?motherduck_token={self.token}" if self.token else "")
            )
        else:
            data_file_keys.add(self.database)
    data_file_keys.discard(":memory:")

    for key in data_file_keys:
        adapter = BaseDuckDBConnectionConfig._data_file_to_adapter.get(key)
        if adapter is not None:
            logger.info(
                f"Using existing DuckDB adapter due to overlapping data file: {self._mask_sensitive_data(key)}"
            )
            return adapter

    if data_file_keys:
        masked_files = {self._mask_sensitive_data(key) for key in data_file_keys}
        logger.info(f"Creating new DuckDB adapter for data files: {masked_files}")
    else:
        logger.info("Creating new DuckDB adapter for in-memory database")
    adapter = super().create_engine_adapter(
        register_comments_override, concurrent_tasks=concurrent_tasks
    )
    for key in data_file_keys:
        BaseDuckDBConnectionConfig._data_file_to_adapter[key] = adapter
    return adapter
Checks if another engine adapter has already been created that shares a catalog that points to the same data file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration associated with the new adapter will be ignored.
def get_catalog(self) -> t.Optional[str]:
    """Return the default catalog name for this connection.

    With `database` set, the catalog is the file stem of that path (so `:memory:` maps to `memory`).
    Otherwise the first key of `catalogs` is used, and None is returned when neither is configured.
    """
    if not self.database:
        return next(iter(self.catalogs)) if self.catalogs else None
    # `:memory:` is rewritten first so its colons don't end up in the derived stem
    normalized_path = self.database.replace(":memory:", "memory")
    return pathlib.Path(normalized_path).stem
The catalog for this connection
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class MotherDuckConnectionConfig(BaseDuckDBConnectionConfig):
    """Configuration for the MotherDuck connection.

    Connects through an `md:` connection string built from `database` and `token`.
    """

    type_: t.Literal["motherduck"] = Field(alias="type", default="motherduck")
    DISPLAY_ORDER: t.ClassVar[t.Literal[5]] = 5
    DISPLAY_NAME: t.ClassVar[t.Literal["MotherDuck"]] = "MotherDuck"
    DIALECT: t.ClassVar[t.Literal["duckdb"]] = "duckdb"

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        # No fields are forwarded directly; the connection string is assembled in `_static_connection_kwargs`
        return set()

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """kwargs that are for execution config only"""
        from sqlmesh import __version__

        query_params: t.List[str] = []
        if self.database:
            # Attach single MD database instead of all databases on the account
            query_params.append("attach_mode=single")
        if self.token:
            query_params.append(f"motherduck_token={self.token}")

        query_string = f"?{'&'.join(query_params)}" if query_params else ""
        database_part = self.database or ""
        return {
            "database": f"md:{database_part}{query_string}",
            "config": {"custom_user_agent": f"SQLMesh/{__version__}"},
        }

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return dict(is_motherduck=True)
Configuration for the MotherDuck connection.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
- BaseDuckDBConnectionConfig
- database
- catalogs
- extensions
- connector_config
- secrets
- filesystems
- concurrent_tasks
- register_comments
- pre_ping
- token
- create_engine_adapter
- get_catalog
class DuckDBConnectionConfig(BaseDuckDBConnectionConfig):
    """Configuration for the DuckDB connection.

    All connection settings are inherited from `BaseDuckDBConnectionConfig`; this class only
    pins the connection `type` and its display metadata.
    """

    type_: t.Literal["duckdb"] = Field(alias="type", default="duckdb")
    DISPLAY_ORDER: t.ClassVar[t.Literal[1]] = 1
    DISPLAY_NAME: t.ClassVar[t.Literal["DuckDB"]] = "DuckDB"
    DIALECT: t.ClassVar[t.Literal["duckdb"]] = "duckdb"
Configuration for the DuckDB connection.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
- BaseDuckDBConnectionConfig
- database
- catalogs
- extensions
- connector_config
- secrets
- filesystems
- concurrent_tasks
- register_comments
- pre_ping
- token
- create_engine_adapter
- get_catalog
class SnowflakeConnectionConfig(ConnectionConfig):
    """Connection settings for Snowflake.

    Args:
        account: Name of the Snowflake account.
        user: Snowflake user name.
        password: Snowflake password.
        warehouse: Optional warehouse name.
        database: Optional database name.
        role: Optional role name.
        concurrent_tasks: Upper bound on how many tasks may use this connection concurrently.
        authenticator: Optional authenticator name. When omitted, username/password authentication ("snowflake") is used.
            Options: https://github.com/snowflakedb/snowflake-connector-python/blob/e937591356c067a77f34a0a42328907fda792c23/src/snowflake/connector/network.py#L178-L183
        token: Optional OAuth access token; required when `authenticator` is "oauth".
        private_key: Optional private key for key-pair authentication. Accepted forms: Base64-encoded DER (the key bytes),
            plain-text PEM, or raw bytes (Python config only).
            https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#using-key-pair-authentication-key-pair-rotation
        private_key_path: Optional path to a PEM private key file, used instead of `private_key`.
        private_key_passphrase: Optional passphrase for decrypting `private_key` or `private_key_path`. Only needed
            for encrypted keys.
        register_comments: Whether model comments are registered with the SQL engine.
        pre_ping: Whether the connection is pinged before each new transaction to make sure it is still alive.
        session_parameters: Optional session parameters applied to the connection.
        host: Host address for the connection.
        port: Port for the connection.
    """

    account: str
    user: t.Optional[str] = None
    password: t.Optional[str] = None
    warehouse: t.Optional[str] = None
    database: t.Optional[str] = None
    role: t.Optional[str] = None
    authenticator: t.Optional[str] = None
    token: t.Optional[str] = None
    host: t.Optional[str] = None
    port: t.Optional[int] = None
    application: t.Literal["Tobiko_SQLMesh"] = "Tobiko_SQLMesh"

    # Private Key Auth
    private_key: t.Optional[t.Union[str, bytes]] = None
    private_key_path: t.Optional[str] = None
    private_key_passphrase: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = False

    session_parameters: t.Optional[dict] = None

    type_: t.Literal["snowflake"] = Field(alias="type", default="snowflake")
    DIALECT: t.ClassVar[t.Literal["snowflake"]] = "snowflake"
    DISPLAY_NAME: t.ClassVar[t.Literal["Snowflake"]] = "Snowflake"
    DISPLAY_ORDER: t.ClassVar[t.Literal[2]] = 2

    _concurrent_tasks_validator = concurrent_tasks_validator

    @model_validator(mode="before")
    def _validate_authenticator(cls, data: t.Any) -> t.Any:
        """Resolve `private_key` to DER bytes and check that the chosen authenticator has its credentials."""
        if not isinstance(data, dict):
            return data

        from snowflake.connector.network import DEFAULT_AUTHENTICATOR, OAUTH_AUTHENTICATOR

        raw_authenticator = data.get("authenticator")
        authenticator = raw_authenticator.upper() if raw_authenticator else DEFAULT_AUTHENTICATOR

        # Normalize whichever key source was configured into the `private_key` field
        data["private_key"] = cls._get_private_key(data, authenticator)  # type: ignore

        uses_default_auth = authenticator == DEFAULT_AUTHENTICATOR
        has_user_and_password = bool(data.get("user") and data.get("password"))
        if uses_default_auth and not data.get("private_key") and not has_user_and_password:
            raise ConfigError("User and password must be provided if using default authentication")

        if authenticator == OAUTH_AUTHENTICATOR and not data.get("token"):
            raise ConfigError("Token must be provided if using oauth authentication")

        return data

    _engine_import_validator = _get_engine_import_validator(
        "snowflake.connector.network", "snowflake"
    )

    @classmethod
    def _get_private_key(cls, values: t.Dict[str, t.Optional[str]], auth: str) -> t.Optional[bytes]:
        """
        source: https://github.com/dbt-labs/dbt-snowflake/blob/0374b4ec948982f2ac8ec0c95d53d672ad19e09c/dbt/adapters/snowflake/connections.py#L247C5-L285C1

        Overall code change: Use local variables instead of class attributes + Validation

        Returns the key as unencrypted PKCS#8 DER bytes, or None when neither `private_key`
        nor `private_key_path` is configured. Byte keys are returned unchanged.
        """
        # Start custom code
        from cryptography.hazmat.backends import default_backend
        from cryptography.hazmat.primitives import serialization
        from snowflake.connector.network import (
            DEFAULT_AUTHENTICATOR,
            KEY_PAIR_AUTHENTICATOR,
        )

        key = values.get("private_key")
        key_path = values.get("private_key_path")
        passphrase = values.get("private_key_passphrase")

        if not key and not key_path:
            return None
        if key and key_path:
            raise ConfigError("Cannot specify both `private_key` and `private_key_path`")

        # A missing or default authenticator is treated as key-pair auth once a key is supplied
        uses_key_pair = not auth or auth in (DEFAULT_AUTHENTICATOR, KEY_PAIR_AUTHENTICATOR)
        if not uses_key_pair:
            raise ConfigError(
                f"Private key or private key path can only be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
            )
        if not values.get("user"):
            raise ConfigError(
                f"User must be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
            )
        if values.get("password"):
            raise ConfigError(
                f"Password cannot be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
            )

        if isinstance(key, bytes):
            return key
        # End Custom Code

        passphrase_bytes = passphrase.encode() if passphrase else None

        if key:
            if key.startswith("-"):
                # Plain-text PEM key
                loaded_key = serialization.load_pem_private_key(
                    data=key.encode("utf-8"),
                    password=passphrase_bytes,
                    backend=default_backend(),
                )
            else:
                # Base64-encoded DER key
                loaded_key = serialization.load_der_private_key(
                    data=base64.b64decode(key),
                    password=passphrase_bytes,
                    backend=default_backend(),
                )
        elif key_path:
            with open(key_path, "rb") as key_file:
                loaded_key = serialization.load_pem_private_key(
                    key_file.read(), password=passphrase_bytes, backend=default_backend()
                )
        else:
            return None

        return loaded_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        forwarded = (
            "account",
            "user",
            "password",
            "authenticator",
            "token",
            "private_key",
            "role",
            "warehouse",
            "database",
            "session_parameters",
            "application",
            "host",
            "port",
        )
        return set(forwarded)

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.SnowflakeEngineAdapter

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        return dict(autocommit=False)

    @property
    def _connection_factory(self) -> t.Callable:
        from snowflake.connector import connect

        return connect
Configuration for the Snowflake connection.
Arguments:
- account: The Snowflake account name.
- user: The Snowflake username.
- password: The Snowflake password.
- warehouse: The optional warehouse name.
- database: The optional database name.
- role: The optional role name.
- concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
- authenticator: The optional authenticator name. Defaults to username/password authentication ("snowflake"). Options: https://github.com/snowflakedb/snowflake-connector-python/blob/e937591356c067a77f34a0a42328907fda792c23/src/snowflake/connector/network.py#L178-L183
- token: The optional oauth access token to use for authentication when authenticator is set to "oauth".
- private_key: The optional private key to use for authentication. Key can be Base64-encoded DER format (representing the key bytes), a plain-text PEM format, or bytes (Python config only). https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#using-key-pair-authentication-key-pair-rotation
- private_key_path: The optional path to the private key to use for authentication. This would be used instead of `private_key`.
- private_key_passphrase: The optional passphrase to use to decrypt `private_key` or `private_key_path`. Keys can be created without encryption so only provide this if needed.
- register_comments: Whether or not to register model comments with the SQL engine.
- pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
- session_parameters: The optional session parameters to set for the connection.
- host: Host address for the connection.
- port: Port for the connection.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class DatabricksConnectionConfig(ConnectionConfig):
    """
    Databricks connection that uses the SQL connector for SQL models and then Databricks Connect for Dataframe operations

    Arg Source: https://github.com/databricks/databricks-sql-python/blob/main/src/databricks/sql/client.py#L39
    OAuth ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication

    Args:
        server_hostname: Databricks instance host name.
        http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef)
            or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
        access_token: Http Bearer access token, e.g. Databricks Personal Access Token.
        auth_type: Set to 'databricks-oauth' or 'azure-oauth' to trigger OAuth (or don't set it at all to use `access_token`)
        oauth_client_id: Client ID to use when auth_type is set to one of the 'oauth' types
        oauth_client_secret: Client Secret to use when auth_type is set to one of the 'oauth' types
        catalog: Default catalog to use for SQL models. Defaults to None which means it will use the default set in
            the Databricks cluster (most likely `hive_metastore`).
        http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request
        session_configuration: An optional dictionary of Spark session parameters.
            Execute the SQL command `SET -v` to get a full list of available commands.
        databricks_connect_server_hostname: The hostname to use when establishing a connection using Databricks Connect.
            Defaults to the `server_hostname` value.
        databricks_connect_access_token: The access token to use when establishing a connection using Databricks Connect.
            Defaults to the `access_token` value.
        databricks_connect_cluster_id: The cluster id to use when establishing a connection using Databricks Connect.
            Defaults to deriving the cluster id from the `http_path` value.
        databricks_connect_use_serverless: Build the Databricks Connect session with `serverless=True` instead of
            a cluster id.
        force_databricks_connect: Force all queries to run using Databricks Connect instead of the SQL connector.
        disable_databricks_connect: Even if databricks connect is installed, do not use it.
        disable_spark_session: Do not use SparkSession if it is available (like when running in a notebook).
        pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
    """

    server_hostname: t.Optional[str] = None
    http_path: t.Optional[str] = None
    access_token: t.Optional[str] = None
    auth_type: t.Optional[str] = None
    oauth_client_id: t.Optional[str] = None
    oauth_client_secret: t.Optional[str] = None
    catalog: t.Optional[str] = None
    http_headers: t.Optional[t.List[t.Tuple[str, str]]] = None
    session_configuration: t.Optional[t.Dict[str, t.Any]] = None
    databricks_connect_server_hostname: t.Optional[str] = None
    databricks_connect_access_token: t.Optional[str] = None
    databricks_connect_cluster_id: t.Optional[str] = None
    databricks_connect_use_serverless: bool = False
    force_databricks_connect: bool = False
    disable_databricks_connect: bool = False
    disable_spark_session: bool = False

    concurrent_tasks: int = 1
    register_comments: bool = True
    pre_ping: t.Literal[False] = False

    type_: t.Literal["databricks"] = Field(alias="type", default="databricks")
    DIALECT: t.ClassVar[t.Literal["databricks"]] = "databricks"
    DISPLAY_NAME: t.ClassVar[t.Literal["Databricks"]] = "Databricks"
    DISPLAY_ORDER: t.ClassVar[t.Literal[3]] = 3

    _concurrent_tasks_validator = concurrent_tasks_validator
    _http_headers_validator = http_headers_validator

    @model_validator(mode="before")
    def _databricks_connect_validator(cls, data: t.Any) -> t.Any:
        # SQLQueryContextLogger will output any error SQL queries even if they are in a try/except block.
        # Disabling this allows SQLMesh to determine what should be shown to the user.
        # Ex: We describe a table to see if it exists and therefore that execution can fail but we don't need to show
        # the user since it is expected if the table doesn't exist. Without this change the user would see the error.
        logging.getLogger("SQLQueryContextLogger").setLevel(logging.CRITICAL)

        if not isinstance(data, dict):
            return data

        from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter

        if DatabricksEngineAdapter.can_access_spark_session(
            bool(data.get("disable_spark_session"))
        ):
            return data

        databricks_connect_use_serverless = data.get("databricks_connect_use_serverless")
        server_hostname, http_path, access_token, auth_type = (
            data.get("server_hostname"),
            data.get("http_path"),
            data.get("access_token"),
            data.get("auth_type"),
        )

        if (not server_hostname or not http_path or not access_token) and (
            not databricks_connect_use_serverless and not auth_type
        ):
            raise ValueError(
                "`server_hostname`, `http_path`, and `access_token` are required for Databricks connections when not running in a notebook"
            )
        if (
            databricks_connect_use_serverless
            and not server_hostname
            and not data.get("databricks_connect_server_hostname")
        ):
            raise ValueError(
                "`server_hostname` or `databricks_connect_server_hostname` is required when `databricks_connect_use_serverless` is set"
            )

        # Validate `auth_type` before deriving Databricks Connect defaults below. The cluster id
        # derivation splits `http_path`, so a missing `http_path` must be reported here first
        # instead of surfacing as an AttributeError on `None`.
        if auth_type:
            from databricks.sql.auth.auth import AuthType

            all_data = [m.value for m in AuthType]
            if auth_type not in all_data:
                raise ValueError(
                    f"`auth_type` {auth_type} does not match a valid option: {all_data}"
                )

            client_id = data.get("oauth_client_id")
            client_secret = data.get("oauth_client_secret")

            if client_secret and not client_id:
                raise ValueError(
                    "`oauth_client_id` is required when `oauth_client_secret` is specified"
                )

            if not http_path:
                raise ValueError("`http_path` is still required when using `auth_type`")

        if DatabricksEngineAdapter.can_access_databricks_connect(
            bool(data.get("disable_databricks_connect"))
        ):
            if not data.get("databricks_connect_access_token"):
                data["databricks_connect_access_token"] = access_token
            if not data.get("databricks_connect_server_hostname"):
                data["databricks_connect_server_hostname"] = f"https://{server_hostname}"
            if not databricks_connect_use_serverless and not data.get(
                "databricks_connect_cluster_id"
            ):
                if t.TYPE_CHECKING:
                    assert http_path is not None
                data["databricks_connect_cluster_id"] = http_path.split("/")[-1]

        return data

    _engine_import_validator = _get_engine_import_validator("databricks", "databricks")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        if self.use_spark_session_only:
            return set()
        return {
            "server_hostname",
            "http_path",
            "access_token",
            "http_headers",
            "session_configuration",
            "catalog",
        }

    @property
    def _engine_adapter(self) -> t.Type[engine_adapter.DatabricksEngineAdapter]:
        return engine_adapter.DatabricksEngineAdapter

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return {
            k: v
            for k, v in self.dict().items()
            if k.startswith("databricks_connect_")
            or k in ("catalog", "disable_databricks_connect", "disable_spark_session")
        }

    @property
    def use_spark_session_only(self) -> bool:
        from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter

        return (
            DatabricksEngineAdapter.can_access_spark_session(self.disable_spark_session)
            or self.force_databricks_connect
        )

    @property
    def _connection_factory(self) -> t.Callable:
        if self.use_spark_session_only:
            from sqlmesh.engines.spark.db_api.spark_session import connection

            return connection

        from databricks import sql  # type: ignore

        return sql.connect

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter

        if not self.use_spark_session_only:
            conn_kwargs: t.Dict[str, t.Any] = {
                "_user_agent_entry": "sqlmesh",
            }

            if self.auth_type and "oauth" in self.auth_type:
                # there are two types of oauth: User-to-Machine (U2M) and Machine-to-Machine (M2M)
                if self.oauth_client_secret:
                    # if a client_secret exists, then a client_id also exists and we are using M2M
                    # ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication
                    # ref: https://github.com/databricks/databricks-sql-python/blob/main/examples/m2m_oauth.py
                    from databricks.sdk.core import oauth_service_principal, Config

                    config = Config(
                        host=f"https://{self.server_hostname}",
                        client_id=self.oauth_client_id,
                        client_secret=self.oauth_client_secret,
                    )
                    conn_kwargs["credentials_provider"] = lambda: oauth_service_principal(config)
                else:
                    # if auth_type is set to an 'oauth' type but no client_id/secret are set, then we are using U2M
                    # ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-user-to-machine-u2m-authentication
                    conn_kwargs["auth_type"] = self.auth_type

            return conn_kwargs

        if DatabricksEngineAdapter.can_access_spark_session(self.disable_spark_session):
            from pyspark.sql import SparkSession

            return dict(
                spark=SparkSession.getActiveSession(),
                catalog=self.catalog,
            )

        from databricks.connect import DatabricksSession

        if t.TYPE_CHECKING:
            assert self.databricks_connect_server_hostname is not None
            assert self.databricks_connect_access_token is not None

        if self.databricks_connect_use_serverless:
            builder = DatabricksSession.builder.remote(
                host=self.databricks_connect_server_hostname,
                token=self.databricks_connect_access_token,
                serverless=True,
            )
        else:
            if t.TYPE_CHECKING:
                assert self.databricks_connect_cluster_id is not None
            builder = DatabricksSession.builder.remote(
                host=self.databricks_connect_server_hostname,
                token=self.databricks_connect_access_token,
                cluster_id=self.databricks_connect_cluster_id,
            )

        return dict(
            spark=builder.userAgent("sqlmesh").getOrCreate(),
            catalog=self.catalog,
        )
Databricks connection that uses the SQL connector for SQL models and then Databricks Connect for Dataframe operations
Arg Source: https://github.com/databricks/databricks-sql-python/blob/main/src/databricks/sql/client.py#L39 OAuth ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication
Arguments:
- server_hostname: Databricks instance host name.
- http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef) or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
- access_token: Http Bearer access token, e.g. Databricks Personal Access Token.
- auth_type: Set to 'databricks-oauth' or 'azure-oauth' to trigger OAuth (or don't set it at all to use `access_token`).
- oauth_client_id: Client ID to use when `auth_type` is set to one of the 'oauth' types.
- oauth_client_secret: Client Secret to use when auth_type is set to one of the 'oauth' types
- catalog: Default catalog to use for SQL models. Defaults to None, which means it will use the default set in the Databricks cluster (most likely `hive_metastore`).
- http_headers: An optional list of (k, v) pairs that will be set as HTTP headers on every request.
- session_configuration: An optional dictionary of Spark session parameters. Execute the SQL command `SET -v` to get a full list of available parameters.
- databricks_connect_server_hostname: The hostname to use when establishing a connection using Databricks Connect. Defaults to the `server_hostname` value.
- databricks_connect_access_token: The access token to use when establishing a connection using Databricks Connect. Defaults to the `access_token` value.
- databricks_connect_cluster_id: The cluster ID to use when establishing a connection using Databricks Connect. Defaults to deriving the cluster ID from the `http_path` value.
- force_databricks_connect: Force all queries to run using Databricks Connect instead of the SQL connector.
- disable_databricks_connect: Even if databricks connect is installed, do not use it.
- disable_spark_session: Do not use SparkSession if it is available (like when running in a notebook).
- pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class BigQueryConnectionMethod(str, Enum):
    """Authentication methods supported by the BigQuery connection configuration."""

    # Application Default Credentials, resolved via `google.auth.default`.
    OAUTH = "oauth"
    # Explicit OAuth credentials built from token / refresh_token / client_id / client_secret.
    OAUTH_SECRETS = "oauth-secrets"
    # Service-account key read from the file path given in `keyfile`.
    SERVICE_ACCOUNT = "service-account"
    # Service-account key supplied inline as a dict in `keyfile_json`.
    SERVICE_ACCOUNT_JSON = "service-account-json"
Authentication methods supported by the BigQuery connection configuration.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class BigQueryPriority(str, Enum):
    """Priority with which BigQuery query jobs are submitted."""

    BATCH = "batch"
    INTERACTIVE = "interactive"

    @property
    def is_batch(self) -> bool:
        """Whether this is the ``batch`` priority."""
        return self == BigQueryPriority.BATCH

    @property
    def is_interactive(self) -> bool:
        """Whether this is the ``interactive`` priority."""
        return self == BigQueryPriority.INTERACTIVE

    @property
    def bigquery_constant(self) -> str:
        """The matching ``google.cloud.bigquery.QueryPriority`` constant."""
        # Imported lazily so the enum is usable without the BigQuery client installed.
        from google.cloud.bigquery import QueryPriority

        return QueryPriority.BATCH if self.is_batch else QueryPriority.INTERACTIVE
Priority with which BigQuery query jobs are submitted: batch or interactive.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class BigQueryConnectionConfig(ConnectionConfig):
    """
    BigQuery Connection Configuration.

    Args:
        method: How credentials are obtained (see `BigQueryConnectionMethod`). Defaults to `oauth`.
        project: Default project for object locations; also returned by `get_catalog`.
        execution_project: Project the BigQuery client is created with; takes precedence over
            `project` for that purpose. Requires `project` to be set.
        quota_project: Quota project passed to the client through `ClientOptions`.
            Requires `project` to be set.
        location: Location passed to the BigQuery client.
        keyfile: Path to a service-account key file (`service-account` method).
        keyfile_json: Service-account key info as a dict (`service-account-json` method).
        token / refresh_token / client_id / client_secret / token_uri: OAuth credentials
            (`oauth-secrets` method).
        scopes: Scopes requested for the credentials.
        impersonated_service_account: If set, the resolved credentials are used to impersonate
            this principal.
        job_creation_timeout_seconds / job_execution_timeout_seconds / job_retries /
            job_retry_deadline_seconds / priority / maximum_bytes_billed: Forwarded to the
            engine adapter as extra engine configuration.
    """

    method: BigQueryConnectionMethod = BigQueryConnectionMethod.OAUTH

    project: t.Optional[str] = None
    execution_project: t.Optional[str] = None
    quota_project: t.Optional[str] = None
    location: t.Optional[str] = None
    # Keyfile Auth
    keyfile: t.Optional[str] = None
    keyfile_json: t.Optional[t.Dict[str, t.Any]] = None
    # OAuth Secret Auth
    token: t.Optional[str] = None
    refresh_token: t.Optional[str] = None
    client_id: t.Optional[str] = None
    client_secret: t.Optional[str] = None
    token_uri: t.Optional[str] = None
    scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/bigquery",)
    impersonated_service_account: t.Optional[str] = None
    # Extra Engine Config
    job_creation_timeout_seconds: t.Optional[int] = None
    job_execution_timeout_seconds: t.Optional[int] = None
    job_retries: t.Optional[int] = 1
    job_retry_deadline_seconds: t.Optional[int] = None
    priority: t.Optional[BigQueryPriority] = None
    maximum_bytes_billed: t.Optional[int] = None

    concurrent_tasks: int = 1
    register_comments: bool = True
    pre_ping: t.Literal[False] = False

    type_: t.Literal["bigquery"] = Field(alias="type", default="bigquery")
    DIALECT: t.ClassVar[t.Literal["bigquery"]] = "bigquery"
    DISPLAY_NAME: t.ClassVar[t.Literal["BigQuery"]] = "BigQuery"
    DISPLAY_ORDER: t.ClassVar[t.Literal[4]] = 4

    _engine_import_validator = _get_engine_import_validator("google.cloud.bigquery", "bigquery")

    @field_validator("execution_project")
    def validate_execution_project(
        cls,
        v: t.Optional[str],
        info: ValidationInfo,
    ) -> t.Optional[str]:
        # `project` is declared before `execution_project`, so it is already in `info.data`.
        if v and not info.data.get("project"):
            raise ConfigError(
                "If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location."
            )
        return v

    @field_validator("quota_project")
    def validate_quota_project(
        cls,
        v: t.Optional[str],
        info: ValidationInfo,
    ) -> t.Optional[str]:
        if v and not info.data.get("project"):
            raise ConfigError(
                "If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location."
            )
        return v

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        # Everything the DB-API `connect` needs is provided through `_static_connection_kwargs`.
        return set()

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.BigQueryEngineAdapter

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """The static connection kwargs for this connection.

        Resolves credentials according to `method`, optionally wraps them for
        impersonation, and builds a `google.cloud.bigquery.Client`.

        Raises:
            ConfigError: If `method` is not a known `BigQueryConnectionMethod`.
        """
        import google.auth
        from google.api_core import client_info, client_options
        from google.auth import impersonated_credentials
        # Import the client module explicitly instead of reaching it as an attribute of the
        # `google` namespace package, which only works if something else already imported it
        # (e.g. the engine import validator, which can be disabled with `check_import`).
        from google.cloud import bigquery
        from google.oauth2 import credentials, service_account

        if self.method == BigQueryConnectionMethod.OAUTH:
            creds, _ = google.auth.default(scopes=self.scopes)
        elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT:
            creds = service_account.Credentials.from_service_account_file(
                self.keyfile, scopes=self.scopes
            )
        elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT_JSON:
            creds = service_account.Credentials.from_service_account_info(
                self.keyfile_json, scopes=self.scopes
            )
        elif self.method == BigQueryConnectionMethod.OAUTH_SECRETS:
            creds = credentials.Credentials(
                token=self.token,
                refresh_token=self.refresh_token,
                client_id=self.client_id,
                client_secret=self.client_secret,
                token_uri=self.token_uri,
                scopes=self.scopes,
            )
        else:
            raise ConfigError("Invalid BigQuery Connection Method")

        if self.impersonated_service_account:
            creds = impersonated_credentials.Credentials(
                source_credentials=creds,
                target_principal=self.impersonated_service_account,
                target_scopes=self.scopes,
            )

        options = client_options.ClientOptions(quota_project_id=self.quota_project)
        project = self.execution_project or self.project or None

        client = bigquery.Client(
            # Parse as a BigQuery identifier so quoted project names are unquoted.
            project=project and exp.parse_identifier(project, dialect="bigquery").name,
            credentials=creds,
            location=self.location,
            client_info=client_info.ClientInfo(user_agent="sqlmesh"),
            client_options=options,
        )

        return {
            "client": client,
        }

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return {
            k: v
            for k, v in self.dict().items()
            if k
            in {
                "job_creation_timeout_seconds",
                "job_execution_timeout_seconds",
                "job_retries",
                "job_retry_deadline_seconds",
                "priority",
                "maximum_bytes_billed",
            }
        }

    @property
    def _connection_factory(self) -> t.Callable:
        from google.cloud.bigquery.dbapi import connect

        return connect

    def get_catalog(self) -> t.Optional[str]:
        return self.project
BigQuery Connection Configuration.
@field_validator("execution_project")
def validate_execution_project(
    cls,
    v: t.Optional[str],
    info: ValidationInfo,
) -> t.Optional[str]:
    """Allow `execution_project` only when `project` has also been provided."""
    has_project = bool(info.data.get("project"))
    if not v or has_project:
        return v
    raise ConfigError(
        "If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location."
    )
@field_validator("quota_project")
def validate_quota_project(
    cls,
    v: t.Optional[str],
    info: ValidationInfo,
) -> t.Optional[str]:
    """Allow `quota_project` only when `project` has also been provided."""
    has_project = bool(info.data.get("project"))
    if not v or has_project:
        return v
    raise ConfigError(
        "If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location."
    )
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class GCPPostgresConnectionConfig(ConnectionConfig):
    """
    Postgres Connection Configuration for GCP.

    Args:
        instance_connection_string: Connection name for the postgres instance.
        user: Postgres or IAM user's name
        password: The postgres user's password. Only needed when the user is a postgres user.
        enable_iam_auth: Set to True when user is an IAM user.
        db: Name of the db to connect to.
        ip_type: IP type passed to the Cloud SQL `Connector`: "public" (default), "private" or "psc".
        keyfile: string path to json service account credentials file
        keyfile_json: dict service account credentials info
        timeout: Optional timeout. When truthy it is passed to the `Connector`; it is also
            forwarded as a keyword argument to `Connector.connect`.
        scopes: Scopes applied to credentials loaded from `keyfile` / `keyfile_json`.
        driver: Driver name forwarded to `Connector.connect`. Defaults to "pg8000".
        pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.

    Either `password` or `enable_iam_auth` must be set.
    """

    instance_connection_string: str
    user: str
    password: t.Optional[str] = None
    enable_iam_auth: t.Optional[bool] = None
    db: str
    ip_type: t.Union[t.Literal["public"], t.Literal["private"], t.Literal["psc"]] = "public"
    # Keyfile Auth
    keyfile: t.Optional[str] = None
    keyfile_json: t.Optional[t.Dict[str, t.Any]] = None
    timeout: t.Optional[int] = None
    scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/sqlservice.admin",)
    driver: str = "pg8000"

    type_: t.Literal["gcp_postgres"] = Field(alias="type", default="gcp_postgres")
    DIALECT: t.ClassVar[t.Literal["postgres"]] = "postgres"
    DISPLAY_NAME: t.ClassVar[t.Literal["GCP Postgres"]] = "GCP Postgres"
    DISPLAY_ORDER: t.ClassVar[t.Literal[13]] = 13

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = True

    _engine_import_validator = _get_engine_import_validator(
        "google.cloud.sql", "gcp_postgres", "gcppostgres"
    )

    @model_validator(mode="before")
    def _validate_auth_method(cls, data: t.Any) -> t.Any:
        """Require either a password (postgres user) or IAM auth (IAM user)."""
        if not isinstance(data, dict):
            return data

        password = data.get("password")
        enable_iam_auth = data.get("enable_iam_auth")

        if not password and not enable_iam_auth:
            raise ConfigError(
                "GCP Postgres connection configuration requires either password set"
                " for a postgres user account or enable_iam_auth set to 'True'"
                " for an IAM user account."
            )

        return data

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {
            "instance_connection_string",
            "driver",
            "user",
            "password",
            "db",
            "enable_iam_auth",
            "timeout",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.PostgresEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """Return the bound `connect` method of a newly created Cloud SQL `Connector`."""
        from google.cloud.sql.connector import Connector
        from google.oauth2 import service_account

        # Explicit credentials are optional; with neither keyfile option set,
        # `credentials=None` is passed (presumably the Connector then uses its defaults).
        creds = None
        if self.keyfile:
            creds = service_account.Credentials.from_service_account_file(
                self.keyfile, scopes=self.scopes
            )
        elif self.keyfile_json:
            creds = service_account.Credentials.from_service_account_info(
                self.keyfile_json, scopes=self.scopes
            )

        kwargs: t.Dict[str, t.Any] = {
            "credentials": creds,
            "ip_type": self.ip_type,
        }

        if self.timeout:
            kwargs["timeout"] = self.timeout

        return Connector(**kwargs).connect
Postgres Connection Configuration for GCP.
Arguments:
- instance_connection_string: Connection name for the postgres instance.
- user: Postgres or IAM user's name
- password: The postgres user's password. Only needed when the user is a postgres user.
- enable_iam_auth: Set to True when user is an IAM user.
- db: Name of the db to connect to.
- keyfile: string path to json service account credentials file
- keyfile_json: dict service account credentials info
- pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class RedshiftConnectionConfig(ConnectionConfig):
    """
    Redshift Connection Configuration.

    Arg Source: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/__init__.py#L146
    Note: A subset of properties were selected. Please open an issue/PR if you want to see more supported.

    Args:
        user: The username to use for authentication with the Amazon Redshift cluster.
        password: The password to use for authentication with the Amazon Redshift cluster.
        database: The name of the database instance to connect to.
        host: The hostname of the Amazon Redshift cluster.
        port: The port number of the Amazon Redshift cluster. Default value is 5439.
        source_address: No description provided
        unix_sock: No description provided
        ssl: Whether SSL is enabled. Default value is ``True``. SSL must be enabled when authenticating using IAM.
        sslmode: The security of the connection to the Amazon Redshift cluster. 'verify-ca' and 'verify-full' are supported.
        timeout: The number of seconds before the connection to the server will timeout. By default there is no timeout.
        tcp_keepalive: Whether `TCP keepalive <https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive>`_ is used. The default value is ``True``.
        application_name: Sets the application name. The default value is None.
        preferred_role: The IAM role preferred for the current connection.
        principal_arn: The ARN of the IAM entity (user or role) for which you are generating a policy.
        credentials_provider: The class name of the IdP that will be used for authenticating with the Amazon Redshift cluster.
        region: The AWS region where the Amazon Redshift cluster is located.
        cluster_identifier: The cluster identifier of the Amazon Redshift cluster.
        iam: If IAM authentication is enabled. Default value is False. IAM must be True when authenticating using an IdP.
        is_serverless: Whether the Redshift endpoint is serverless or provisioned. Default value false.
        serverless_acct_id: The account ID of the serverless. Default value None
        serverless_work_group: The name of work group for serverless end point. Default value None.
        pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
        enable_merge: Whether to use the Redshift merge operation instead of the SQLMesh logical merge.
    """

    user: t.Optional[str] = None
    password: t.Optional[str] = None
    database: t.Optional[str] = None
    host: t.Optional[str] = None
    port: t.Optional[int] = None
    source_address: t.Optional[str] = None
    unix_sock: t.Optional[str] = None
    ssl: t.Optional[bool] = None
    sslmode: t.Optional[str] = None
    timeout: t.Optional[int] = None
    tcp_keepalive: t.Optional[bool] = None
    application_name: t.Optional[str] = None
    preferred_role: t.Optional[str] = None
    principal_arn: t.Optional[str] = None
    credentials_provider: t.Optional[str] = None
    region: t.Optional[str] = None
    cluster_identifier: t.Optional[str] = None
    iam: t.Optional[bool] = None
    is_serverless: t.Optional[bool] = None
    serverless_acct_id: t.Optional[str] = None
    serverless_work_group: t.Optional[str] = None
    enable_merge: t.Optional[bool] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = False

    type_: t.Literal["redshift"] = Field(alias="type", default="redshift")
    DIALECT: t.ClassVar[t.Literal["redshift"]] = "redshift"
    DISPLAY_NAME: t.ClassVar[t.Literal["Redshift"]] = "Redshift"
    DISPLAY_ORDER: t.ClassVar[t.Literal[7]] = 7

    _engine_import_validator = _get_engine_import_validator("redshift_connector", "redshift")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        # Driver-facing fields only. `enable_merge` is not passed to the driver; it is
        # forwarded to the engine adapter through `_extra_engine_config` instead.
        driver_arguments = (
            # target & credentials
            "user",
            "password",
            "database",
            "host",
            "port",
            "source_address",
            "unix_sock",
            # transport
            "ssl",
            "sslmode",
            "timeout",
            "tcp_keepalive",
            "application_name",
            # IAM / IdP
            "preferred_role",
            "principal_arn",
            "credentials_provider",
            "region",
            "cluster_identifier",
            "iam",
            # serverless
            "is_serverless",
            "serverless_acct_id",
            "serverless_work_group",
        )
        return set(driver_arguments)

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.RedshiftEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        import redshift_connector

        return redshift_connector.connect

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return dict(enable_merge=self.enable_merge)
Redshift Connection Configuration.
Arg Source: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/__init__.py#L146 Note: A subset of properties were selected. Please open an issue/PR if you want to see more supported.
Arguments:
- user: The username to use for authentication with the Amazon Redshift cluster.
- password: The password to use for authentication with the Amazon Redshift cluster.
- database: The name of the database instance to connect to.
- host: The hostname of the Amazon Redshift cluster.
- port: The port number of the Amazon Redshift cluster. Default value is 5439.
- source_address: No description provided
- unix_sock: No description provided
- ssl: Whether SSL is enabled. Default value is `True`. SSL must be enabled when authenticating using IAM.
- sslmode: The security of the connection to the Amazon Redshift cluster. 'verify-ca' and 'verify-full' are supported.
- timeout: The number of seconds before the connection to the server will timeout. By default there is no timeout.
- tcp_keepalive: Whether TCP keepalive is used. The default value is `True`.
- application_name: Sets the application name. The default value is None.
- preferred_role: The IAM role preferred for the current connection.
- principal_arn: The ARN of the IAM entity (user or role) for which you are generating a policy.
- credentials_provider: The class name of the IdP that will be used for authenticating with the Amazon Redshift cluster.
- region: The AWS region where the Amazon Redshift cluster is located.
- cluster_identifier: The cluster identifier of the Amazon Redshift cluster.
- iam: If IAM authentication is enabled. Default value is False. IAM must be True when authenticating using an IdP.
- is_serverless: Whether the Redshift endpoint is serverless or provisioned. Default value false.
- serverless_acct_id: The account ID of the serverless. Default value None
- serverless_work_group: The name of work group for serverless end point. Default value None.
- pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
- enable_merge: Whether to use the Redshift merge operation instead of the SQLMesh logical merge.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class PostgresConnectionConfig(ConnectionConfig):
    """
    Postgres connection configuration using the ``psycopg2`` driver.

    Args:
        host: Passed to ``psycopg2.connect``.
        user: Passed to ``psycopg2.connect``.
        password: Passed to ``psycopg2.connect``.
        port: Passed to ``psycopg2.connect``.
        database: Passed to ``psycopg2.connect``.
        keepalives_idle: Passed to ``psycopg2.connect``.
        connect_timeout: Passed to ``psycopg2.connect``. Defaults to 10.
        role: Optional role; when set, ``SET ROLE <role>`` is executed through the cursor-init hook.
        sslmode: Passed to ``psycopg2.connect``.
        application_name: Passed to ``psycopg2.connect``.
        pre_ping: Whether to pre-ping the connection before starting a new transaction.
    """

    host: str
    user: str
    password: str
    port: int
    database: str
    keepalives_idle: t.Optional[int] = None
    connect_timeout: int = 10
    role: t.Optional[str] = None
    sslmode: t.Optional[str] = None
    application_name: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = True

    type_: t.Literal["postgres"] = Field(alias="type", default="postgres")
    DIALECT: t.ClassVar[t.Literal["postgres"]] = "postgres"
    DISPLAY_NAME: t.ClassVar[t.Literal["Postgres"]] = "Postgres"
    DISPLAY_ORDER: t.ClassVar[t.Literal[12]] = 12

    _engine_import_validator = _get_engine_import_validator("psycopg2", "postgres")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        # `role` is intentionally absent: it is applied with SET ROLE in `_cursor_init`.
        forwarded = (
            "host",
            "port",
            "user",
            "password",
            "database",
            "connect_timeout",
            "keepalives_idle",
            "sslmode",
            "application_name",
        )
        return set(forwarded)

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.PostgresEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        import psycopg2

        return psycopg2.connect

    @property
    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
        if self.role:

            def apply_role(cursor: t.Any) -> None:
                # NOTE(review): the role is interpolated unquoted; it comes from config,
                # but names needing quoting (special characters, mixed case) would fail.
                cursor.execute(f"SET ROLE {self.role}")

            return apply_role
        return None
Postgres connection configuration (uses the `psycopg2` driver).
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class MySQLConnectionConfig(ConnectionConfig):
    """
    MySQL connection configuration using the ``pymysql`` driver.

    Args:
        host: Passed to ``pymysql.connect``.
        user: Passed to ``pymysql.connect``.
        password: Passed to ``pymysql.connect``.
        port: Passed to ``pymysql.connect`` only when set.
        database: Passed to ``pymysql.connect`` only when set.
        charset: Passed to ``pymysql.connect`` only when set.
        collation: Passed to ``pymysql.connect`` only when set.
        ssl_disabled: Passed to ``pymysql.connect`` only when set.
        pre_ping: Whether to pre-ping the connection before starting a new transaction.
    """

    host: str
    user: str
    password: str
    port: t.Optional[int] = None
    database: t.Optional[str] = None
    charset: t.Optional[str] = None
    collation: t.Optional[str] = None
    ssl_disabled: t.Optional[bool] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = True

    type_: t.Literal["mysql"] = Field(alias="type", default="mysql")
    DIALECT: t.ClassVar[t.Literal["mysql"]] = "mysql"
    DISPLAY_NAME: t.ClassVar[t.Literal["MySQL"]] = "MySQL"
    DISPLAY_ORDER: t.ClassVar[t.Literal[14]] = 14

    _engine_import_validator = _get_engine_import_validator("pymysql", "mysql")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        # host/user/password are always sent; the optional settings are only included
        # when explicitly configured (presumably so pymysql's own defaults apply otherwise).
        optional_settings = ("port", "database", "charset", "collation", "ssl_disabled")
        configured = {name for name in optional_settings if getattr(self, name) is not None}
        return {"host", "user", "password"} | configured

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.MySQLEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        import pymysql

        return pymysql.connect
MySQL connection configuration (uses the `pymysql` driver).
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class MSSQLConnectionConfig(ConnectionConfig):
    """Connection configuration for Microsoft SQL Server.

    ``driver`` selects the client library:

    * ``pymssql`` (default): connections are created by ``pymssql.connect``.
    * ``pyodbc``: an ODBC connection string is assembled from the config fields (plus any
      ``odbc_properties``) and passed to ``pyodbc.connect``.
    """

    host: str
    user: t.Optional[str] = None
    password: t.Optional[str] = None
    database: t.Optional[str] = ""
    timeout: t.Optional[int] = 0
    login_timeout: t.Optional[int] = 60
    charset: t.Optional[str] = "UTF-8"
    appname: t.Optional[str] = None
    port: t.Optional[int] = 1433
    conn_properties: t.Optional[t.Union[t.List[str], str]] = None
    autocommit: t.Optional[bool] = False
    tds_version: t.Optional[str] = None

    # Driver options
    driver: t.Literal["pymssql", "pyodbc"] = "pymssql"
    # PyODBC specific options
    driver_name: t.Optional[str] = None  # e.g. "ODBC Driver 18 for SQL Server"
    trust_server_certificate: t.Optional[bool] = None
    encrypt: t.Optional[bool] = None
    # Dictionary of arbitrary ODBC connection properties
    # See: https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute
    odbc_properties: t.Optional[t.Dict[str, t.Any]] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = True

    type_: t.Literal["mssql"] = Field(alias="type", default="mssql")
    DIALECT: t.ClassVar[t.Literal["tsql"]] = "tsql"
    DISPLAY_NAME: t.ClassVar[t.Literal["MSSQL"]] = "MSSQL"
    DISPLAY_ORDER: t.ClassVar[t.Literal[11]] = 11

    @model_validator(mode="before")
    @classmethod
    def _mssql_engine_import_validator(cls, data: t.Any) -> t.Any:
        """Check that the client library for the selected driver is importable.

        When ``driver`` is not supplied, the default declared on ``cls`` is used. A
        subclass that overrides that default (``FabricConnectionConfig`` defaults to
        ``pyodbc``) therefore validates the library it will actually load.

        Raises:
            ValueError: if ``driver`` is neither ``pymssql`` nor ``pyodbc``.
        """
        if not isinstance(data, dict):
            return data

        # Resolve the class-level default for ``driver`` rather than hard-coding it.
        fields = getattr(cls, "model_fields", None) or getattr(cls, "__fields__", {})
        driver_field = fields.get("driver")
        default_driver = getattr(driver_field, "default", None) or "pymssql"
        driver = data.get("driver", default_driver)

        # Define the mapping of driver to import module and extra name
        driver_configs = {"pymssql": ("pymssql", "mssql"), "pyodbc": ("pyodbc", "mssql-odbc")}

        if driver not in driver_configs:
            raise ValueError(f"Unsupported driver: {driver}")

        import_module, extra_name = driver_configs[driver]

        # Use _get_engine_import_validator with decorate=False to get the raw validation function
        # This avoids the __wrapped__ issue in Python 3.9
        validator_func = _get_engine_import_validator(
            import_module, driver, extra_name, decorate=False
        )

        # Call the raw validation function directly
        return validator_func(cls, data)

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        base_keys = {
            "host",
            "user",
            "password",
            "database",
            "timeout",
            "login_timeout",
            "charset",
            "appname",
            "port",
            "conn_properties",
            "autocommit",
            "tds_version",
        }

        if self.driver == "pyodbc":
            base_keys.update(
                {
                    "driver_name",
                    "trust_server_certificate",
                    "encrypt",
                    "odbc_properties",
                }
            )
            # Remove pymssql-specific parameters
            base_keys.discard("tds_version")
            base_keys.discard("conn_properties")

        return base_keys

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.MSSQLEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """Return a callable that opens a connection with the configured driver.

        For ``pyodbc``, the returned ``connect(**kwargs)`` builds an ODBC connection
        string from the keyword arguments and ``self.odbc_properties``. It then
        registers a DATETIMEOFFSET output converter on the new connection.
        """
        if self.driver == "pymssql":
            import pymssql

            return pymssql.connect

        import pyodbc

        # Attributes set explicitly by ``connect``. Entries in ``odbc_properties`` with
        # these names (compared case-insensitively) are skipped to avoid duplicates.
        reserved_odbc_keys = {
            "driver",
            "server",
            "database",
            "uid",
            "pwd",
            "encrypt",
            "trustservercertificate",
            "connection timeout",
        }

        def odbc_value(value: t.Any) -> str:
            """Render ``value`` as an ODBC connection-string attribute value.

            A value containing ``;``, ``{`` or ``}``, or one with leading/trailing
            whitespace, would corrupt or be altered by connection-string parsing.
            Such a value is wrapped in braces with every ``}`` doubled, as the ODBC
            connection-string grammar requires. All other values are returned
            unchanged.
            """
            text = str(value)
            if text != text.strip() or any(char in text for char in ";{}"):
                return "{" + text.replace("}", "}}") + "}"
            return text

        # Set up output converters for MSSQL-specific data types
        # Handle SQL type -155 (DATETIMEOFFSET) which is not yet supported by pyodbc
        # ref: https://github.com/mkleehammer/pyodbc/issues/134#issuecomment-281739794
        def handle_datetimeoffset(dto_value: t.Any) -> t.Any:
            from datetime import datetime, timedelta, timezone
            import struct

            # Unpack the DATETIMEOFFSET binary format:
            # Format: <6hI2h = (year, month, day, hour, minute, second, nanoseconds, tz_hour_offset, tz_minute_offset)
            tup = struct.unpack("<6hI2h", dto_value)
            return datetime(
                tup[0],
                tup[1],
                tup[2],
                tup[3],
                tup[4],
                tup[5],
                tup[6] // 1000,  # nanoseconds -> microseconds
                timezone(timedelta(hours=tup[7], minutes=tup[8])),
            )

        def connect(**kwargs: t.Any) -> t.Any:
            # A key that is missing *or* explicitly None falls back to its default.
            # An unset optional field is then never rendered as the text "None".
            host = kwargs.pop("host")
            port = kwargs.pop("port", None)
            port = 1433 if port is None else port
            database = kwargs.pop("database", None) or ""
            user = kwargs.pop("user", None)
            password = kwargs.pop("password", None)
            driver_name = kwargs.pop("driver_name", None) or "ODBC Driver 18 for SQL Server"
            trust_server_certificate = kwargs.pop("trust_server_certificate", None) or False
            encrypt = kwargs.pop("encrypt", None)
            encrypt = True if encrypt is None else encrypt
            login_timeout = kwargs.pop("login_timeout", None)
            login_timeout = 60 if login_timeout is None else login_timeout

            # Build connection string; DRIVER is always brace-quoted.
            conn_str_parts = [
                "DRIVER={" + str(driver_name).replace("}", "}}") + "}",
                f"SERVER={host},{port}",
            ]

            if database:
                conn_str_parts.append(f"DATABASE={odbc_value(database)}")

            # Add security options
            conn_str_parts.append(f"Encrypt={'YES' if encrypt else 'NO'}")
            if trust_server_certificate:
                conn_str_parts.append("TrustServerCertificate=YES")

            conn_str_parts.append(f"Connection Timeout={login_timeout}")

            # Standard SQL Server authentication
            if user:
                conn_str_parts.append(f"UID={odbc_value(user)}")
            if password:
                conn_str_parts.append(f"PWD={odbc_value(password)}")

            # Add any additional ODBC properties from the odbc_properties dictionary
            for key, value in (self.odbc_properties or {}).items():
                if key.lower() in reserved_odbc_keys:
                    continue
                # Booleans are rendered as the ODBC-style YES/NO
                if isinstance(value, bool):
                    value = "YES" if value else "NO"
                conn_str_parts.append(f"{key}={odbc_value(value)}")

            conn = pyodbc.connect(
                ";".join(conn_str_parts), autocommit=kwargs.get("autocommit", False)
            )
            conn.add_output_converter(-155, handle_datetimeoffset)

            return conn

        return connect

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return {"catalog_support": CatalogSupport.REQUIRES_SET_CATALOG}
Connection configuration for Microsoft SQL Server, supporting the `pymssql` (default) and `pyodbc` drivers.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class AzureSQLConnectionConfig(MSSQLConnectionConfig):
    """Azure SQL connection configuration.

    Reuses every setting of ``MSSQLConnectionConfig``, registers under the connection
    type ``azuresql`` and reports ``CatalogSupport.SINGLE_CATALOG_ONLY`` to the engine
    adapter.
    """

    type_: t.Literal["azuresql"] = Field(alias="type", default="azuresql")  # type: ignore
    DISPLAY_ORDER: t.ClassVar[t.Literal[10]] = 10  # type: ignore
    DISPLAY_NAME: t.ClassVar[t.Literal["Azure SQL"]] = "Azure SQL"  # type: ignore

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        engine_config: t.Dict[str, t.Any] = dict(
            catalog_support=CatalogSupport.SINGLE_CATALOG_ONLY
        )
        return engine_config
Connection configuration for Azure SQL. Shares all settings with MSSQLConnectionConfig and is registered under the connection type `azuresql`.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
- MSSQLConnectionConfig
- host
- user
- password
- database
- timeout
- login_timeout
- charset
- appname
- port
- conn_properties
- autocommit
- tds_version
- driver
- driver_name
- trust_server_certificate
- encrypt
- odbc_properties
- concurrent_tasks
- register_comments
- pre_ping
- DIALECT
class FabricConnectionConfig(MSSQLConnectionConfig):
    """
    Microsoft Fabric connection configuration.

    Builds on MSSQLConnectionConfig, registers under the connection type 'fabric' and
    only allows the 'pyodbc' driver (which is also the default). ``workspace_id`` and
    ``tenant_id`` are required and forwarded to the engine adapter.
    """

    type_: t.Literal["fabric"] = Field(alias="type", default="fabric")  # type: ignore
    DISPLAY_ORDER: t.ClassVar[t.Literal[17]] = 17  # type: ignore
    DISPLAY_NAME: t.ClassVar[t.Literal["Fabric"]] = "Fabric"  # type: ignore
    DIALECT: t.ClassVar[t.Literal["fabric"]] = "fabric"  # type: ignore
    driver: t.Literal["pyodbc"] = "pyodbc"
    workspace_id: str
    tenant_id: str
    autocommit: t.Optional[bool] = True

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        from sqlmesh.core.engine_adapter.fabric import FabricEngineAdapter

        adapter_class = FabricEngineAdapter
        return adapter_class

    @property
    def _connection_factory(self) -> t.Callable:
        # Wrap the parent factory so a caller can ask for a specific catalog (database).
        parent_factory = super()._connection_factory

        def create_fabric_connection(
            target_catalog: t.Optional[str] = None, *args: t.Any, **kwargs: t.Any
        ) -> t.Callable:
            # A truthy target_catalog wins; otherwise fall back to the configured database.
            kwargs["database"] = target_catalog if target_catalog else self.database
            return parent_factory(*args, **kwargs)

        return create_fabric_connection

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return dict(
            database=self.database,
            # more operations than not require a specific catalog to be already active
            # in particular, create/drop view, create/drop schema and querying information_schema
            catalog_support=CatalogSupport.REQUIRES_SET_CATALOG,
            workspace_id=self.workspace_id,
            tenant_id=self.tenant_id,
            user=self.user,
            password=self.password,
        )
Fabric Connection Configuration. Inherits most settings from MSSQLConnectionConfig and sets the type to 'fabric'. Only the 'pyodbc' driver is supported for Fabric, and it is the default.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
- MSSQLConnectionConfig
- host
- user
- password
- database
- timeout
- login_timeout
- charset
- appname
- port
- conn_properties
- tds_version
- driver_name
- trust_server_certificate
- encrypt
- odbc_properties
- concurrent_tasks
- register_comments
- pre_ping
class SparkConnectionConfig(ConnectionConfig):
    """
    Connection configuration for a plain (vanilla) Spark session.
    Use `DatabricksConnectionConfig` for Databricks.
    """

    config_dir: t.Optional[str] = None
    catalog: t.Optional[str] = None
    config: t.Dict[str, t.Any] = {}
    wap_enabled: bool = False

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: t.Literal[False] = False

    type_: t.Literal["spark"] = Field(alias="type", default="spark")
    DIALECT: t.ClassVar[t.Literal["spark"]] = "spark"
    DISPLAY_NAME: t.ClassVar[t.Literal["Spark"]] = "Spark"
    DISPLAY_ORDER: t.ClassVar[t.Literal[8]] = 8

    _engine_import_validator = _get_engine_import_validator("pyspark", "spark")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {"catalog"}

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.SparkEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        from sqlmesh.engines.spark.db_api.spark_session import connection

        return connection

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        from pyspark.conf import SparkConf
        from pyspark.sql import SparkSession

        # Copy every user-provided key/value into the SparkConf.
        conf = SparkConf()
        for key, value in (self.config or {}).items():
            conf.set(key, value)

        # Side effect: exported to the process environment before the session is built.
        if self.config_dir:
            os.environ["SPARK_CONF_DIR"] = self.config_dir

        session = SparkSession.builder.config(conf=conf).enableHiveSupport().getOrCreate()
        return {"spark": session}

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return dict(wap_enabled=self.wap_enabled)
Vanilla Spark Connection Configuration. Use DatabricksConnectionConfig for Databricks.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class TrinoAuthenticationMethod(str, Enum):
    """Authentication schemes accepted by the Trino connection's ``method`` field."""

    NO_AUTH = "no-auth"
    BASIC = "basic"
    LDAP = "ldap"
    KERBEROS = "kerberos"
    JWT = "jwt"
    CERTIFICATE = "certificate"
    OAUTH = "oauth"

    # Each predicate compares against the member looked up on the class itself.

    @property
    def is_no_auth(self) -> bool:
        return self == TrinoAuthenticationMethod.NO_AUTH

    @property
    def is_basic(self) -> bool:
        return self == TrinoAuthenticationMethod.BASIC

    @property
    def is_ldap(self) -> bool:
        return self == TrinoAuthenticationMethod.LDAP

    @property
    def is_kerberos(self) -> bool:
        return self == TrinoAuthenticationMethod.KERBEROS

    @property
    def is_jwt(self) -> bool:
        return self == TrinoAuthenticationMethod.JWT

    @property
    def is_certificate(self) -> bool:
        return self == TrinoAuthenticationMethod.CERTIFICATE

    @property
    def is_oauth(self) -> bool:
        return self == TrinoAuthenticationMethod.OAUTH
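Authentication methods supported by the Trino connection: `no-auth`, `basic`, `ldap`, `kerberos`, `jwt`, `certificate` and `oauth`.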
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class TrinoConnectionConfig(ConnectionConfig):
    """Connection configuration for Trino.

    ``method`` selects the authentication scheme. ``_root_validator`` checks that the
    fields required by that scheme are set. ``_static_connection_kwargs`` builds the
    matching ``trino.auth`` object.
    """

    method: TrinoAuthenticationMethod = TrinoAuthenticationMethod.NO_AUTH
    host: str
    user: str
    catalog: str
    port: t.Optional[int] = None
    http_scheme: t.Literal["http", "https"] = "https"
    # General Optional
    roles: t.Optional[t.Dict[str, str]] = None
    http_headers: t.Optional[t.Dict[str, str]] = None
    session_properties: t.Optional[t.Dict[str, str]] = None
    retries: int = 3
    timezone: t.Optional[str] = None
    # Basic/LDAP
    password: t.Optional[str] = None
    verify: t.Optional[bool] = None  # disable SSL verification (ignored if `cert` is provided)
    # LDAP
    impersonation_user: t.Optional[str] = None
    # Kerberos
    keytab: t.Optional[str] = None
    krb5_config: t.Optional[str] = None
    principal: t.Optional[str] = None
    service_name: str = "trino"
    hostname_override: t.Optional[str] = None
    mutual_authentication: bool = False
    force_preemptive: bool = False
    sanitize_mutual_error_response: bool = True
    delegate: bool = False
    # JWT
    jwt_token: t.Optional[str] = None
    # Certificate
    client_certificate: t.Optional[str] = None
    client_private_key: t.Optional[str] = None
    cert: t.Optional[str] = None
    source: str = "sqlmesh"

    # SQLMesh options
    schema_location_mapping: t.Optional[dict[re.Pattern, str]] = None
    timestamp_mapping: t.Optional[dict[exp.DataType, exp.DataType]] = None
    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: t.Literal[False] = False

    type_: t.Literal["trino"] = Field(alias="type", default="trino")
    DIALECT: t.ClassVar[t.Literal["trino"]] = "trino"
    DISPLAY_NAME: t.ClassVar[t.Literal["Trino"]] = "Trino"
    DISPLAY_ORDER: t.ClassVar[t.Literal[9]] = 9

    _engine_import_validator = _get_engine_import_validator("trino", "trino")

    @field_validator("schema_location_mapping", mode="before")
    @classmethod
    def _validate_regex_keys(
        cls, value: t.Optional[t.Dict[str | re.Pattern, str]]
    ) -> t.Optional[t.Dict[re.Pattern, t.Any]]:
        """Compile the mapping keys to regex patterns and check each replacement value.

        Raises:
            ConfigError: if a replacement lacks the ``@{schema_name}`` placeholder.
        """
        # The field is optional: let an explicit null through unchanged, as
        # ``_validate_timestamp_mapping`` does, instead of passing it to the compiler.
        if value is None:
            return None

        compiled = compile_regex_mapping(value)
        for replacement in compiled.values():
            if "@{schema_name}" not in replacement:
                raise ConfigError(
                    "schema_location_mapping needs to include the '@{schema_name}' placeholder in the value so SQLMesh knows where to substitute the schema name"
                )
        return compiled

    @field_validator("timestamp_mapping", mode="before")
    @classmethod
    def _validate_timestamp_mapping(
        cls, value: t.Optional[dict[str, str]]
    ) -> t.Optional[dict[exp.DataType, exp.DataType]]:
        """Parse both sides of each ``source -> target`` entry into ``exp.DataType``.

        Raises:
            ConfigError: if either type string cannot be parsed; the ``ParseError`` is
                chained as the cause.
        """
        if value is None:
            return value

        def build(type_string: str) -> exp.DataType:
            try:
                return exp.DataType.build(type_string)
            except ParseError as e:
                raise ConfigError(
                    f"Invalid SQL type string in timestamp_mapping: "
                    f"'{type_string}' is not a valid SQL data type."
                ) from e

        # Keys are evaluated before values, so the source type is still validated first.
        return {build(source_type): build(target_type) for source_type, target_type in value.items()}

    @model_validator(mode="after")
    def _root_validator(self) -> Self:
        """Validate cross-field requirements of the chosen auth method and default the port."""
        port = self.port
        if self.http_scheme == "http" and not self.method.is_no_auth and not self.method.is_basic:
            raise ConfigError("HTTP scheme can only be used with no-auth or basic method")

        if port is None:
            self.port = 80 if self.http_scheme == "http" else 443

        if (self.method.is_ldap or self.method.is_basic) and (not self.password or not self.user):
            raise ConfigError(
                f"Username and Password must be provided if using {self.method.value} authentication"
            )

        if self.method.is_kerberos and (
            not self.principal or not self.keytab or not self.krb5_config
        ):
            raise ConfigError(
                "Kerberos requires the following fields: principal, keytab, and krb5_config"
            )

        if self.method.is_jwt and not self.jwt_token:
            raise ConfigError("JWT requires `jwt_token` to be set")

        if self.method.is_certificate and (
            not self.cert or not self.client_certificate or not self.client_private_key
        ):
            raise ConfigError(
                "Certificate requires the following fields: cert, client_certificate, and client_private_key"
            )

        return self

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        kwargs = {
            "host",
            "port",
            "catalog",
            "roles",
            "source",
            "http_scheme",
            "http_headers",
            "session_properties",
            "timezone",
        }
        return kwargs

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.TrinoEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        from trino.dbapi import connect

        return connect

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        from trino.auth import (
            BasicAuthentication,
            CertificateAuthentication,
            JWTAuthentication,
            KerberosAuthentication,
            OAuth2Authentication,
        )

        if self.method.is_basic or self.method.is_ldap:
            auth = BasicAuthentication(self.user, self.password)
        elif self.method.is_kerberos:
            # Side effect: exported to the process environment for the Kerberos client.
            if self.keytab:
                os.environ["KRB5_CLIENT_KTNAME"] = self.keytab
            auth = KerberosAuthentication(
                config=self.krb5_config,
                service_name=self.service_name,
                principal=self.principal,
                mutual_authentication=self.mutual_authentication,
                ca_bundle=self.cert,
                force_preemptive=self.force_preemptive,
                hostname_override=self.hostname_override,
                sanitize_mutual_error_response=self.sanitize_mutual_error_response,
                delegate=self.delegate,
            )
        elif self.method.is_oauth:
            auth = OAuth2Authentication()
        elif self.method.is_jwt:
            auth = JWTAuthentication(self.jwt_token)
        elif self.method.is_certificate:
            auth = CertificateAuthentication(self.client_certificate, self.client_private_key)
        else:
            auth = None

        return {
            "auth": auth,
            "user": self.impersonation_user or self.user,
            "max_attempts": self.retries,
            # A CA bundle path in `cert` takes precedence over the boolean `verify` flag.
            "verify": self.cert if self.cert is not None else self.verify,
            "source": self.source,
        }

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return {
            "schema_location_mapping": self.schema_location_mapping,
            "timestamp_mapping": self.timestamp_mapping,
        }
Connection configuration for Trino. The `method` field selects the authentication scheme, and the fields that scheme requires are validated.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class ClickhouseConnectionConfig(ConnectionConfig):
    """
    Clickhouse Connection Configuration.

    Property reference: https://clickhouse.com/docs/en/integrations/python#client-initialization
    """

    host: str
    username: str
    password: t.Optional[str] = None
    port: t.Optional[int] = None
    cluster: t.Optional[str] = None
    connect_timeout: int = 10
    send_receive_timeout: int = 300
    query_limit: int = 0
    use_compression: bool = True
    compression_method: t.Optional[str] = None
    connection_settings: t.Optional[t.Dict[str, t.Any]] = None
    http_proxy: t.Optional[str] = None
    # HTTPS/TLS settings
    verify: bool = True
    ca_cert: t.Optional[str] = None
    client_cert: t.Optional[str] = None
    client_cert_key: t.Optional[str] = None
    https_proxy: t.Optional[str] = None
    server_host_name: t.Optional[str] = None
    tls_mode: t.Optional[str] = None

    concurrent_tasks: int = 1
    register_comments: bool = True
    pre_ping: bool = False

    # This object expects options from urllib3 and also from clickhouse-connect
    # See:
    # * https://urllib3.readthedocs.io/en/stable/advanced-usage.html
    # * https://clickhouse.com/docs/en/integrations/python#customizing-the-http-connection-pool
    connection_pool_options: t.Optional[t.Dict[str, t.Any]] = None

    type_: t.Literal["clickhouse"] = Field(alias="type", default="clickhouse")
    DIALECT: t.ClassVar[t.Literal["clickhouse"]] = "clickhouse"
    DISPLAY_NAME: t.ClassVar[t.Literal["ClickHouse"]] = "ClickHouse"
    DISPLAY_ORDER: t.ClassVar[t.Literal[6]] = 6

    _engine_import_validator = _get_engine_import_validator("clickhouse_connect", "clickhouse")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        kwargs = {
            "host",
            "username",
            "port",
            "password",
            "connect_timeout",
            "send_receive_timeout",
            "query_limit",
            "http_proxy",
            "verify",
            "ca_cert",
            "client_cert",
            "client_cert_key",
            "https_proxy",
            "server_host_name",
            "tls_mode",
        }
        return kwargs

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.ClickhouseEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """Return ``clickhouse_connect.dbapi.connect`` bound to a shared HTTP pool manager."""
        from clickhouse_connect.dbapi import connect  # type: ignore
        from clickhouse_connect.driver import httputil  # type: ignore

        pool_manager_options: t.Dict[str, t.Any] = dict(
            # Match the maxsize to the number of concurrent tasks
            maxsize=self.concurrent_tasks,
            # Block if there are no free connections
            block=True,
            verify=self.verify,
            ca_cert=self.ca_cert,
            client_cert=self.client_cert,
            client_cert_key=self.client_cert_key,
            https_proxy=self.https_proxy,
        )
        # this doesn't happen automatically because we always supply our own pool manager to the connection
        # https://github.com/ClickHouse/clickhouse-connect/blob/3a7f4b04cad29c7c2536661b831fb744248e2ec0/clickhouse_connect/driver/httpclient.py#L109
        if self.server_host_name:
            pool_manager_options["server_hostname"] = self.server_host_name
            if self.verify:
                pool_manager_options["assert_hostname"] = self.server_host_name
        # User-supplied pool options override everything set above.
        if self.connection_pool_options:
            pool_manager_options.update(self.connection_pool_options)
        pool_mgr = httputil.get_pool_manager(**pool_manager_options)

        return partial(connect, pool_mgr=pool_mgr)

    @property
    def cloud_mode(self) -> bool:
        return "clickhouse.cloud" in self.host

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return {"cluster": self.cluster, "cloud_mode": self.cloud_mode}

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """Build the compression option, client name and ClickHouse settings for connections.

        The settings are built on a copy of ``connection_settings``; the configured
        field itself is never modified.
        """
        from sqlmesh import __version__

        # False = no compression
        # True = Clickhouse default compression method
        # string = specific compression method
        compress: bool | str = self.use_compression
        if compress and self.compression_method:
            compress = self.compression_method

        # Clickhouse system settings passed to connection
        # https://clickhouse.com/docs/en/operations/settings/settings
        # - below are set to align with dbt-clickhouse
        # - https://github.com/ClickHouse/dbt-clickhouse/blob/44d26308ea6a3c8ead25c280164aa88191f05f47/dbt/adapters/clickhouse/dbclient.py#L77
        # Copy so the settings injected below are not written back into the config field.
        settings = dict(self.connection_settings or {})
        # mutations_sync = 2: "The query waits for all mutations [ALTER statements] to complete on all replicas (if they exist)"
        settings["mutations_sync"] = "2"
        # insert_distributed_sync = 1: "INSERT operation succeeds only after all the data is saved on all shards"
        settings["insert_distributed_sync"] = "1"
        if self.cluster or self.cloud_mode:
            # database_replicated_enforce_synchronous_settings = 1:
            # - "Enforces synchronous waiting for some queries"
            # - https://github.com/ClickHouse/ClickHouse/blob/ccaa8d03a9351efc16625340268b9caffa8a22ba/src/Core/Settings.h#L709
            settings["database_replicated_enforce_synchronous_settings"] = "1"
            # insert_quorum = auto:
            # - "INSERT succeeds only when ClickHouse manages to correctly write data to the insert_quorum of replicas during
            #    the insert_quorum_timeout"
            # - "use majority number (number_of_replicas / 2 + 1) as quorum number"
            settings["insert_quorum"] = "auto"

        return {
            "compress": compress,
            "client_name": f"SQLMesh/{__version__}",
            **settings,
        }
Clickhouse Connection Configuration.
Property reference: https://clickhouse.com/docs/en/integrations/python#client-initialization
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class AthenaConnectionConfig(ConnectionConfig):
    """Connection configuration for Amazon Athena, connecting through PyAthena.

    At least one of ``work_group`` or ``s3_staging_dir`` must be set. The S3 locations
    are checked with ``validate_s3_uri``, and each field is replaced by the value that
    function returns.
    """

    # PyAthena connection options
    aws_access_key_id: t.Optional[str] = None
    aws_secret_access_key: t.Optional[str] = None
    role_arn: t.Optional[str] = None
    role_session_name: t.Optional[str] = None
    region_name: t.Optional[str] = None
    work_group: t.Optional[str] = None
    s3_staging_dir: t.Optional[str] = None
    schema_name: t.Optional[str] = None
    catalog_name: t.Optional[str] = None

    # SQLMesh options
    s3_warehouse_location: t.Optional[str] = None
    concurrent_tasks: int = 4
    # Athena does not support comments in most cases
    register_comments: t.Literal[False] = False
    pre_ping: t.Literal[False] = False

    type_: t.Literal["athena"] = Field(alias="type", default="athena")
    DIALECT: t.ClassVar[t.Literal["athena"]] = "athena"
    DISPLAY_NAME: t.ClassVar[t.Literal["Athena"]] = "Athena"
    DISPLAY_ORDER: t.ClassVar[t.Literal[15]] = 15

    _engine_import_validator = _get_engine_import_validator("pyathena", "athena")

    @model_validator(mode="after")
    def _root_validator(self) -> Self:
        # Snapshot the inputs before any field is reassigned below.
        group = self.work_group
        staging_dir = self.s3_staging_dir
        warehouse_location = self.s3_warehouse_location

        if not (group or staging_dir):
            raise ConfigError("At least one of work_group or s3_staging_dir must be set")

        if staging_dir:
            self.s3_staging_dir = validate_s3_uri(staging_dir, base=True, error_type=ConfigError)

        if warehouse_location:
            self.s3_warehouse_location = validate_s3_uri(
                warehouse_location, base=True, error_type=ConfigError
            )

        return self

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {
            "aws_access_key_id",
            "aws_secret_access_key",
            "role_arn",
            "role_session_name",
            "region_name",
            "work_group",
            "s3_staging_dir",
            "schema_name",
            "catalog_name",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.AthenaEngineAdapter

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return dict(s3_warehouse_location=self.s3_warehouse_location)

    @property
    def _connection_factory(self) -> t.Callable:
        from pyathena import connect  # type: ignore

        return connect

    def get_catalog(self) -> t.Optional[str]:
        return self.catalog_name
Connection configuration for Amazon Athena (via PyAthena). At least one of `work_group` or `s3_staging_dir` must be set.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class RisingwaveConnectionConfig(ConnectionConfig):
    """Connection configuration for RisingWave, connecting through ``psycopg2``."""

    host: str
    user: str
    password: t.Optional[str] = None
    port: int
    database: str
    role: t.Optional[str] = None
    sslmode: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = True

    type_: t.Literal["risingwave"] = Field(alias="type", default="risingwave")
    DIALECT: t.ClassVar[t.Literal["risingwave"]] = "risingwave"
    DISPLAY_NAME: t.ClassVar[t.Literal["RisingWave"]] = "RisingWave"
    DISPLAY_ORDER: t.ClassVar[t.Literal[16]] = 16

    _engine_import_validator = _get_engine_import_validator("psycopg2", "risingwave")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return set(("host", "user", "password", "port", "database", "role", "sslmode"))

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.RisingwaveEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        from psycopg2 import connect

        return connect

    @property
    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
        def init(cursor: t.Any) -> None:
            # Enables RisingWave's implicit flush on every new cursor; presumably so that
            # writes are visible to subsequent reads (confirm against RisingWave docs).
            cursor.execute("SET RW_IMPLICIT_FLUSH TO true;")

        return init
Connection configuration for RisingWave (connects via psycopg2).
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
def parse_connection_config(v: t.Dict[str, t.Any]) -> ConnectionConfig:
    """Instantiate the ``ConnectionConfig`` subclass registered for ``v["type"]``.

    Raises:
        ConfigError: if ``type`` is missing or no config class is registered for it.
    """
    if "type" not in v:
        raise ConfigError("Missing connection type.")

    connection_type = v["type"]
    config_class = CONNECTION_CONFIG_TO_TYPE.get(connection_type)
    if config_class is None:
        raise ConfigError(f"Unknown connection type '{connection_type}'.")

    return config_class(**v)
def _connection_config_validator(
    cls: t.Type, v: ConnectionConfig | t.Dict[str, t.Any] | None
) -> ConnectionConfig | None:
    """Coerce a raw mapping into a ``ConnectionConfig``.

    ``None`` and existing ``ConnectionConfig`` instances are returned unchanged. Parsing
    failures are re-raised as ``ConfigError``, with a hint to check the config file and
    environment variables appended to the message.
    """
    if v is None or isinstance(v, ConnectionConfig):
        return v

    hint = "\n\nVerify your config.yaml and environment variables."

    try:
        return parse_connection_config(v)
    except pydantic.ValidationError as e:
        header = f"Invalid '{v['type']}' connection config:"
        raise ConfigError(validation_error_message(e, header) + hint)
    except ConfigError as e:
        raise ConfigError(str(e) + hint)
Wrap a classmethod, staticmethod, property or unbound function and act as a descriptor that allows us to detect decorated items from the class' attributes.
This class' __get__ returns the wrapped item's __get__ result, which makes it transparent for classmethods and staticmethods.
Attributes:
- wrapped: The decorator that has to be wrapped.
- decorator_info: The decorator info.
- shim: A wrapper function to wrap V1 style function.