sqlmesh.core.config.connection
from __future__ import annotations

import abc
import base64
import logging
import os
import pathlib
import sys
import typing as t
from enum import Enum
from functools import partial

from pydantic import Field
from sqlglot import exp
from sqlglot.helper import subclasses

from sqlmesh.core import engine_adapter
from sqlmesh.core.config.base import BaseConfig
from sqlmesh.core.config.common import (
    concurrent_tasks_validator,
    http_headers_validator,
)
from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.pydantic import (
    PYDANTIC_MAJOR_VERSION,
    field_validator,
    model_validator,
    model_validator_v1_args,
)

if sys.version_info >= (3, 9):
    from typing import Literal
else:
    from typing_extensions import Literal


logger = logging.getLogger(__name__)

RECOMMENDED_STATE_SYNC_ENGINES = {"postgres", "gcp_postgres", "mysql", "duckdb"}


class ConnectionConfig(abc.ABC, BaseConfig):
    type_: str
    concurrent_tasks: int
    register_comments: bool

    @property
    @abc.abstractmethod
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Keywords that should be passed into the connection."""

    @property
    @abc.abstractmethod
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """The engine adapter for this connection."""

    @property
    @abc.abstractmethod
    def _connection_factory(self) -> t.Callable:
        """A function that is called to return a connection object for the given Engine Adapter."""

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """The static connection kwargs for this connection."""
        return {}

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        """Kwargs that are for execution config only."""
        return {}

    @property
    def _cursor_kwargs(self) -> t.Optional[t.Dict[str, t.Any]]:
        """Key-value arguments that will be passed during cursor construction."""
        return None

    @property
    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
        """A function that is called to initialize the cursor."""
        return None

    @property
    def is_recommended_for_state_sync(self) -> bool:
        """Whether this connection is recommended for being used as a state sync for production state syncs."""
        return self.type_ in RECOMMENDED_STATE_SYNC_ENGINES

    @property
    def _connection_factory_with_kwargs(self) -> t.Callable[[], t.Any]:
        """A function that is called to return a connection object for the given Engine Adapter."""
        return partial(
            self._connection_factory,
            **{
                **self._static_connection_kwargs,
                **{k: v for k, v in self.dict().items() if k in self._connection_kwargs_keys},
            },
        )

    def create_engine_adapter(self, register_comments_override: bool = False) -> EngineAdapter:
        """Returns a new instance of the Engine Adapter."""
        return self._engine_adapter(
            self._connection_factory_with_kwargs,
            multithreaded=self.concurrent_tasks > 1,
            cursor_kwargs=self._cursor_kwargs,
            default_catalog=self.get_catalog(),
            cursor_init=self._cursor_init,
            register_comments=register_comments_override or self.register_comments,
            **self._extra_engine_config,
        )

    def get_catalog(self) -> t.Optional[str]:
        """The catalog for this connection."""
        if hasattr(self, "catalog"):
            return self.catalog
        if hasattr(self, "database"):
            return self.database
        if hasattr(self, "db"):
            return self.db
        return None


class BaseDuckDBConnectionConfig(ConnectionConfig):
    """Common configuration for the DuckDB-based connections.

    Args:
        extensions: A list of autoloadable extensions to load.
        connector_config: A dictionary of configuration to pass into the duckdb connector.
        concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
        register_comments: Whether or not to register model comments with the SQL engine.
    """

    extensions: t.List[str] = []
    connector_config: t.Dict[str, t.Any] = {}

    concurrent_tasks: Literal[1] = 1
    register_comments: bool = True

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.DuckDBEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        import duckdb

        return duckdb.connect

    @property
    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
        """A function that is called to initialize the cursor."""
        import duckdb
        from duckdb import BinderException

        def init(cursor: duckdb.DuckDBPyConnection) -> None:
            for extension in self.extensions:
                try:
                    cursor.execute(f"INSTALL {extension}")
                    cursor.execute(f"LOAD {extension}")
                except Exception as e:
                    raise ConfigError(f"Failed to load extension {extension}: {e}")

            for field, setting in self.connector_config.items():
                try:
                    cursor.execute(f"SET {field} = '{setting}'")
                except Exception as e:
                    raise ConfigError(f"Failed to set connector config {field} to {setting}: {e}")

            for i, (alias, path_options) in enumerate(
                (getattr(self, "catalogs", None) or {}).items()
            ):
                # we parse_identifier and generate to ensure that `alias` has exactly one set of quotes
                # regardless of whether it comes in quoted or not
                alias = exp.parse_identifier(alias, dialect="duckdb").sql(
                    identify=True, dialect="duckdb"
                )
                try:
                    query = (
                        path_options.to_sql(alias)
                        if isinstance(path_options, DuckDBAttachOptions)
                        else f"ATTACH '{path_options}' AS {alias}"
                    )
                    cursor.execute(query)
                except BinderException as e:
                    # If a user tries to create a catalog pointing at `:memory:` and with the name `memory`
                    # then we don't want to raise since this happens by default. They are just doing this to
                    # set it as the default catalog.
                    if not (
                        'database with name "memory" already exists' in str(e)
                        and path_options == ":memory:"
                    ):
                        raise e
                if i == 0 and not getattr(self, "database", None):
                    cursor.execute(f"USE {alias}")

        return init


class MotherDuckConnectionConfig(BaseDuckDBConnectionConfig):
    """Configuration for the MotherDuck connection.

    Args:
        database: The database name.
        token: The optional MotherDuck token. If not specified, the user will be prompted to login with their web browser.
    """

    database: str
    token: t.Optional[str] = None

    type_: Literal["motherduck"] = Field(alias="type", default="motherduck")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return set()

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """Kwargs that are for execution config only."""
        connection_str = f"md:{self.database}"
        if self.token:
            connection_str += f"?motherduck_token={self.token}"
        return {"database": connection_str}


class DuckDBAttachOptions(BaseConfig):
    type: str
    path: str
    read_only: bool = False

    def to_sql(self, alias: str) -> str:
        options = []
        # 'duckdb' is actually not a supported type, but we'd like to allow it for
        # fully qualified attach options or integration testing, similar to duckdb-dbt
        if self.type != "duckdb":
            options.append(f"TYPE {self.type.upper()}")
        if self.read_only:
            options.append("READ_ONLY")
        options_sql = f" ({', '.join(options)})" if options else ""
        return f"ATTACH '{self.path}' AS {alias}{options_sql}"


class DuckDBConnectionConfig(BaseDuckDBConnectionConfig):
    """Configuration for the DuckDB connection.

    Args:
        database: The optional database name. If not specified, the in-memory database will be used.
        catalogs: Key is the name of the catalog and value is the path.
    """

    database: t.Optional[str] = None
    catalogs: t.Optional[t.Dict[str, t.Union[str, DuckDBAttachOptions]]] = None

    type_: Literal["duckdb"] = Field(alias="type", default="duckdb")

    _data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {}

    @model_validator(mode="before")
    @model_validator_v1_args
    def _validate_database_catalogs(
        cls, values: t.Dict[str, t.Optional[str]]
    ) -> t.Dict[str, t.Optional[str]]:
        if values.get("database") and values.get("catalogs"):
            raise ConfigError(
                "Cannot specify both `database` and `catalogs`. Define all your catalogs in `catalogs` and have the first entry be the default catalog"
            )
        return values

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {"database"}

    def create_engine_adapter(self, register_comments_override: bool = False) -> EngineAdapter:
        """Checks if another engine adapter has already been created that shares a catalog that points to the same data
        file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration
        associated with the new adapter will be ignored."""
        data_files = set((self.catalogs or {}).values())
        if self.database:
            data_files.add(self.database)
        data_files.discard(":memory:")
        for data_file in data_files:
            key = data_file if isinstance(data_file, str) else data_file.path
            if adapter := DuckDBConnectionConfig._data_file_to_adapter.get(key):
                logger.info(f"Using existing DuckDB adapter due to overlapping data file: {key}")
                return adapter

        if data_files:
            logger.info(f"Creating new DuckDB adapter for data files: {data_files}")
        else:
            logger.info("Creating new DuckDB adapter for in-memory database")
        adapter = super().create_engine_adapter(register_comments_override)
        for data_file in data_files:
            key = data_file if isinstance(data_file, str) else data_file.path
            DuckDBConnectionConfig._data_file_to_adapter[key] = adapter
        return adapter

    def get_catalog(self) -> t.Optional[str]:
        if self.database:
            # Remove `:` from the database name in order to handle if `:memory:` is passed in
            return pathlib.Path(self.database.replace(":", "")).stem
        if self.catalogs:
            return list(self.catalogs)[0]
        return None


class SnowflakeConnectionConfig(ConnectionConfig):
    """Configuration for the Snowflake connection.

    Args:
        account: The Snowflake account name.
        user: The Snowflake username.
        password: The Snowflake password.
        warehouse: The optional warehouse name.
        database: The optional database name.
        role: The optional role name.
        concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
        authenticator: The optional authenticator name. Defaults to username/password authentication ("snowflake").
            Options: https://github.com/snowflakedb/snowflake-connector-python/blob/e937591356c067a77f34a0a42328907fda792c23/src/snowflake/connector/network.py#L178-L183
        token: The optional oauth access token to use for authentication when authenticator is set to "oauth".
        private_key: The optional private key to use for authentication. Key can be Base64-encoded DER format (representing the key bytes), a plain-text PEM format, or bytes (Python config only). https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#using-key-pair-authentication-key-pair-rotation
        private_key_path: The optional path to the private key to use for authentication. This would be used instead of `private_key`.
        private_key_passphrase: The optional passphrase to use to decrypt `private_key` or `private_key_path`. Keys can be created without encryption, so only provide this if needed.
        register_comments: Whether or not to register model comments with the SQL engine.
    """

    account: str
    user: t.Optional[str] = None
    password: t.Optional[str] = None
    warehouse: t.Optional[str] = None
    database: t.Optional[str] = None
    role: t.Optional[str] = None
    authenticator: t.Optional[str] = None
    token: t.Optional[str] = None

    # Private Key Auth
    private_key: t.Optional[t.Union[str, bytes]] = None
    private_key_path: t.Optional[str] = None
    private_key_passphrase: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True

    type_: Literal["snowflake"] = Field(alias="type", default="snowflake")

    _concurrent_tasks_validator = concurrent_tasks_validator

    @model_validator(mode="before")
    @model_validator_v1_args
    def _validate_authenticator(
        cls, values: t.Dict[str, t.Optional[str]]
    ) -> t.Dict[str, t.Optional[str]]:
        from snowflake.connector.network import (
            DEFAULT_AUTHENTICATOR,
            OAUTH_AUTHENTICATOR,
        )

        auth = values.get("authenticator")
        auth = auth.upper() if auth else DEFAULT_AUTHENTICATOR
        user = values.get("user")
        password = values.get("password")
        values["private_key"] = cls._get_private_key(values, auth)  # type: ignore
        if (
            auth == DEFAULT_AUTHENTICATOR
            and not values.get("private_key")
            and (not user or not password)
        ):
            raise ConfigError("User and password must be provided if using default authentication")
        if auth == OAUTH_AUTHENTICATOR and not values.get("token"):
            raise ConfigError("Token must be provided if using oauth authentication")
        return values

    @classmethod
    def _get_private_key(cls, values: t.Dict[str, t.Optional[str]], auth: str) -> t.Optional[bytes]:
        """
        source: https://github.com/dbt-labs/dbt-snowflake/blob/0374b4ec948982f2ac8ec0c95d53d672ad19e09c/dbt/adapters/snowflake/connections.py#L247C5-L285C1

        Overall code change: Use local variables instead of class attributes + Validation
        """
        # Start custom code
        from cryptography.hazmat.backends import default_backend
        from cryptography.hazmat.primitives import serialization
        from snowflake.connector.network import (
            DEFAULT_AUTHENTICATOR,
            KEY_PAIR_AUTHENTICATOR,
        )

        private_key = values.get("private_key")
        private_key_path = values.get("private_key_path")
        private_key_passphrase = values.get("private_key_passphrase")
        user = values.get("user")
        password = values.get("password")
        auth = auth if auth and auth != DEFAULT_AUTHENTICATOR else KEY_PAIR_AUTHENTICATOR

        if not private_key and not private_key_path:
            return None
        if private_key and private_key_path:
            raise ConfigError("Cannot specify both `private_key` and `private_key_path`")
        if auth != KEY_PAIR_AUTHENTICATOR:
            raise ConfigError(
                f"Private key or private key path can only be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
            )
        if not user:
            raise ConfigError(
                f"User must be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
            )
        if password:
            raise ConfigError(
                f"Password cannot be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
            )

        if isinstance(private_key, bytes):
            return private_key
        # End Custom Code

        if private_key_passphrase:
            encoded_passphrase = private_key_passphrase.encode()
        else:
            encoded_passphrase = None

        if private_key:
            if private_key.startswith("-"):
                p_key = serialization.load_pem_private_key(
                    data=bytes(private_key, "utf-8"),
                    password=encoded_passphrase,
                    backend=default_backend(),
                )

            else:
                p_key = serialization.load_der_private_key(
                    data=base64.b64decode(private_key),
                    password=encoded_passphrase,
                    backend=default_backend(),
                )

        elif private_key_path:
            with open(private_key_path, "rb") as key:
                p_key = serialization.load_pem_private_key(
                    key.read(), password=encoded_passphrase, backend=default_backend()
                )
        else:
            return None

        return p_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {
            "user",
            "password",
            "account",
            "warehouse",
            "database",
            "role",
            "authenticator",
            "token",
            "private_key",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.SnowflakeEngineAdapter

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        return {"autocommit": False}

    @property
    def _connection_factory(self) -> t.Callable:
        from snowflake import connector

        return connector.connect


class DatabricksConnectionConfig(ConnectionConfig):
    """
    Databricks connection that uses the SQL connector for SQL models and Databricks Connect for DataFrame operations.

    Arg Source: https://github.com/databricks/databricks-sql-python/blob/main/src/databricks/sql/client.py#L39
    Args:
        server_hostname: Databricks instance host name.
        http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef)
            or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123).
        access_token: Http Bearer access token, e.g. Databricks Personal Access Token.
        catalog: Default catalog to use for SQL models. Defaults to None, which means it will use the default set in
            the Databricks cluster (most likely `hive_metastore`).
        http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request.
        session_configuration: An optional dictionary of Spark session parameters.
            Execute the SQL command `SET -v` to get a full list of available commands.
        databricks_connect_server_hostname: The hostname to use when establishing a connection using Databricks Connect.
            Defaults to the `server_hostname` value.
        databricks_connect_access_token: The access token to use when establishing a connection using Databricks Connect.
            Defaults to the `access_token` value.
        databricks_connect_cluster_id: The cluster id to use when establishing a connection using Databricks Connect.
            Defaults to deriving the cluster id from the `http_path` value.
        force_databricks_connect: Force all queries to run using Databricks Connect instead of the SQL connector.
        disable_databricks_connect: Even if Databricks Connect is installed, do not use it.
    """

    server_hostname: t.Optional[str] = None
    http_path: t.Optional[str] = None
    access_token: t.Optional[str] = None
    catalog: t.Optional[str] = None
    http_headers: t.Optional[t.List[t.Tuple[str, str]]] = None
    session_configuration: t.Optional[t.Dict[str, t.Any]] = None
    databricks_connect_server_hostname: t.Optional[str] = None
    databricks_connect_access_token: t.Optional[str] = None
    databricks_connect_cluster_id: t.Optional[str] = None
    force_databricks_connect: bool = False
    disable_databricks_connect: bool = False

    concurrent_tasks: int = 1
    register_comments: bool = True

    type_: Literal["databricks"] = Field(alias="type", default="databricks")

    _concurrent_tasks_validator = concurrent_tasks_validator
    _http_headers_validator = http_headers_validator

    @model_validator(mode="before")
    @model_validator_v1_args
    def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
        from sqlmesh import RuntimeEnv
        from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter

        runtime_env = RuntimeEnv.get()

        if runtime_env.is_databricks:
            return values
        server_hostname, http_path, access_token = (
            values.get("server_hostname"),
            values.get("http_path"),
            values.get("access_token"),
        )
        if not server_hostname or not http_path or not access_token:
            raise ValueError(
                "`server_hostname`, `http_path`, and `access_token` are required for Databricks connections when not running in a notebook"
            )
        if DatabricksEngineAdapter.can_access_spark_session:
            if not values.get("databricks_connect_server_hostname"):
                values["databricks_connect_server_hostname"] = f"https://{server_hostname}"
            if not values.get("databricks_connect_access_token"):
                values["databricks_connect_access_token"] = access_token
            if not values.get("databricks_connect_cluster_id"):
                values["databricks_connect_cluster_id"] = http_path.split("/")[-1]
        if not values.get("session_configuration"):
            values["session_configuration"] = {}
        values["session_configuration"]["spark.sql.sources.partitionOverwriteMode"] = "dynamic"
        return values

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        if self.use_spark_session_only:
            return set()
        return {
            "server_hostname",
            "http_path",
            "access_token",
            "http_headers",
            "session_configuration",
            "catalog",
        }

    @property
    def _engine_adapter(self) -> t.Type[engine_adapter.DatabricksEngineAdapter]:
        return engine_adapter.DatabricksEngineAdapter

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return {
            k: v
            for k, v in self.dict().items()
            if k.startswith("databricks_connect_") or k in ("catalog", "disable_databricks_connect")
        }

    @property
    def use_spark_session_only(self) -> bool:
        from sqlmesh import RuntimeEnv

        return RuntimeEnv.get().is_databricks or self.force_databricks_connect

    @property
    def _connection_factory(self) -> t.Callable:
        if self.use_spark_session_only:
            from sqlmesh.engines.spark.db_api.spark_session import connection

            return connection

        from databricks import sql

        return sql.connect

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        from sqlmesh import RuntimeEnv

        if not self.use_spark_session_only:
            return {}

        if RuntimeEnv.get().is_databricks:
            from pyspark.sql import SparkSession

            return dict(
                spark=SparkSession.getActiveSession(),
                catalog=self.catalog,
            )

        from databricks.connect import DatabricksSession

        return dict(
            spark=DatabricksSession.builder.remote(
                host=self.databricks_connect_server_hostname,
                token=self.databricks_connect_access_token,
                cluster_id=self.databricks_connect_cluster_id,
            ).getOrCreate(),
            catalog=self.catalog,
        )


class BigQueryConnectionMethod(str, Enum):
    OAUTH = "oauth"
    OAUTH_SECRETS = "oauth-secrets"
    SERVICE_ACCOUNT = "service-account"
    SERVICE_ACCOUNT_JSON = "service-account-json"


class BigQueryPriority(str, Enum):
    BATCH = "batch"
    INTERACTIVE = "interactive"

    @property
    def is_batch(self) -> bool:
        return self == self.BATCH

    @property
    def is_interactive(self) -> bool:
        return self == self.INTERACTIVE

    @property
    def bigquery_constant(self) -> str:
        from google.cloud.bigquery import QueryPriority

        if self.is_batch:
            return QueryPriority.BATCH
        return QueryPriority.INTERACTIVE


class BigQueryConnectionConfig(ConnectionConfig):
    """
    BigQuery Connection Configuration.
    """

    method: BigQueryConnectionMethod = BigQueryConnectionMethod.OAUTH

    project: t.Optional[str] = None
    execution_project: t.Optional[str] = None
    location: t.Optional[str] = None
    # Keyfile Auth
    keyfile: t.Optional[str] = None
    keyfile_json: t.Optional[t.Dict[str, t.Any]] = None
    # OAuth Secrets Auth
    token: t.Optional[str] = None
    refresh_token: t.Optional[str] = None
    client_id: t.Optional[str] = None
    client_secret: t.Optional[str] = None
    token_uri: t.Optional[str] = None
    scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/bigquery",)
    job_creation_timeout_seconds: t.Optional[int] = None
    # Extra Engine Config
    job_execution_timeout_seconds: t.Optional[int] = None
    job_retries: t.Optional[int] = 1
    job_retry_deadline_seconds: t.Optional[int] = None
    priority: t.Optional[BigQueryPriority] = None
    maximum_bytes_billed: t.Optional[int] = None

    concurrent_tasks: int = 1
    register_comments: bool = True

    type_: Literal["bigquery"] = Field(alias="type", default="bigquery")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return set()

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.BigQueryEngineAdapter

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """The static connection kwargs for this connection."""
        import google.auth
        from google.api_core import client_info
        from google.oauth2 import credentials, service_account

        if self.method == BigQueryConnectionMethod.OAUTH:
            creds, _ = google.auth.default(scopes=self.scopes)
        elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT:
            creds = service_account.Credentials.from_service_account_file(
                self.keyfile, scopes=self.scopes
            )
        elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT_JSON:
            creds = service_account.Credentials.from_service_account_info(
                self.keyfile_json, scopes=self.scopes
            )
        elif self.method == BigQueryConnectionMethod.OAUTH_SECRETS:
            creds = credentials.Credentials(
                token=self.token,
                refresh_token=self.refresh_token,
                client_id=self.client_id,
                client_secret=self.client_secret,
                token_uri=self.token_uri,
                scopes=self.scopes,
            )
        else:
            raise ConfigError("Invalid BigQuery Connection Method")
        client = google.cloud.bigquery.Client(
            project=self.execution_project or self.project,
            credentials=creds,
            location=self.location,
            client_info=client_info.ClientInfo(user_agent="sqlmesh"),
        )

        return {
            "client": client,
        }

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return {
            k: v
            for k, v in self.dict().items()
            if k
            in {
                "job_creation_timeout_seconds",
                "job_execution_timeout_seconds",
                "job_retries",
                "job_retry_deadline_seconds",
                "priority",
                "maximum_bytes_billed",
            }
        }

    @property
    def _connection_factory(self) -> t.Callable:
        from google.cloud.bigquery.dbapi import connect

        return connect

    def get_catalog(self) -> t.Optional[str]:
        return self.project


class GCPPostgresConnectionConfig(ConnectionConfig):
    """
    Postgres Connection Configuration for GCP.

    Args:
        instance_connection_string: Connection name for the postgres instance.
        user: Postgres or IAM user's name.
        password: The postgres user's password. Only needed when the user is a postgres user.
        enable_iam_auth: Set to True when the user is an IAM user.
        db: Name of the db to connect to.
    """

    instance_connection_string: str
    user: str
    password: t.Optional[str] = None
    enable_iam_auth: t.Optional[bool] = None
    db: str
    timeout: t.Optional[int] = None

    driver: str = "pg8000"
    type_: Literal["gcp_postgres"] = Field(alias="type", default="gcp_postgres")
    concurrent_tasks: int = 4
    register_comments: bool = True

    @model_validator(mode="before")
    @model_validator_v1_args
    def _validate_auth_method(
        cls, values: t.Dict[str, t.Optional[str]]
    ) -> t.Dict[str, t.Optional[str]]:
        password = values.get("password")
        enable_iam_auth = values.get("enable_iam_auth")
        if password and enable_iam_auth:
            raise ConfigError(
                "Invalid GCP Postgres connection configuration - both password and"
                " enable_iam_auth set. Use password when connecting to a postgres"
                " user and enable_iam_auth 'True' when connecting to an IAM user."
            )
        if not password and not enable_iam_auth:
            raise ConfigError(
                "GCP Postgres connection configuration requires either password set"
                " for a postgres user account or enable_iam_auth set to 'True'"
                " for an IAM user account."
            )
        return values

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {
            "instance_connection_string",
            "driver",
            "user",
            "password",
            "db",
            "enable_iam_auth",
            "timeout",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.PostgresEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        from google.cloud.sql.connector import Connector

        return Connector().connect


class RedshiftConnectionConfig(ConnectionConfig):
    """
    Redshift Connection Configuration.

    Arg Source: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/__init__.py#L146
    Note: A subset of properties were selected. Please open an issue/PR if you want to see more supported.

    Args:
        user: The username to use for authentication with the Amazon Redshift cluster.
        password: The password to use for authentication with the Amazon Redshift cluster.
        database: The name of the database instance to connect to.
        host: The hostname of the Amazon Redshift cluster.
        port: The port number of the Amazon Redshift cluster. Default value is 5439.
        source_address: No description provided.
        unix_sock: No description provided.
        ssl: Is SSL enabled. Default value is ``True``. SSL must be enabled when authenticating using IAM.
        sslmode: The security of the connection to the Amazon Redshift cluster. 'verify-ca' and 'verify-full' are supported.
        timeout: The number of seconds before the connection to the server will timeout. By default there is no timeout.
        tcp_keepalive: Is `TCP keepalive <https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive>`_ used. The default value is ``True``.
        application_name: Sets the application name. The default value is None.
        preferred_role: The IAM role preferred for the current connection.
        principal_arn: The ARN of the IAM entity (user or role) for which you are generating a policy.
        credentials_provider: The class name of the IdP that will be used for authenticating with the Amazon Redshift cluster.
        region: The AWS region where the Amazon Redshift cluster is located.
        cluster_identifier: The cluster identifier of the Amazon Redshift cluster.
        iam: If IAM authentication is enabled. Default value is False. IAM must be True when authenticating using an IdP.
        is_serverless: Whether the Redshift endpoint is serverless or provisioned. Default value is False.
        serverless_acct_id: The account ID of the serverless endpoint. Default value is None.
        serverless_work_group: The name of the workgroup for the serverless endpoint. Default value is None.
    """

    user: t.Optional[str] = None
    password: t.Optional[str] = None
    database: t.Optional[str] = None
    host: t.Optional[str] = None
    port: t.Optional[int] = None
    source_address: t.Optional[str] = None
    unix_sock: t.Optional[str] = None
    ssl: t.Optional[bool] = None
    sslmode: t.Optional[str] = None
    timeout: t.Optional[int] = None
    tcp_keepalive: t.Optional[bool] = None
    application_name: t.Optional[str] = None
    preferred_role: t.Optional[str] = None
    principal_arn: t.Optional[str] = None
    credentials_provider: t.Optional[str] = None
    region: t.Optional[str] = None
    cluster_identifier: t.Optional[str] = None
    iam: t.Optional[bool] = None
    is_serverless: t.Optional[bool] = None
    serverless_acct_id: t.Optional[str] = None
    serverless_work_group: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True

    type_: Literal["redshift"] = Field(alias="type", default="redshift")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {
            "user",
            "password",
            "database",
            "host",
            "port",
            "source_address",
            "unix_sock",
            "ssl",
            "sslmode",
            "timeout",
            "tcp_keepalive",
            "application_name",
            "preferred_role",
            "principal_arn",
            "credentials_provider",
            "region",
            "cluster_identifier",
            "iam",
            "is_serverless",
            "serverless_acct_id",
            "serverless_work_group",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.RedshiftEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        from redshift_connector import connect

        return connect


class PostgresConnectionConfig(ConnectionConfig):
    host: str
    user: str
    password: str
    port: int
    database: str
    keepalives_idle: t.Optional[int] = None
    connect_timeout: int = 10
    role: t.Optional[str] = None
    sslmode: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True

    type_: Literal["postgres"] = Field(alias="type", default="postgres")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {
            "host",
            "user",
            "password",
            "port",
            "database",
            "keepalives_idle",
            "connect_timeout",
            "role",
            "sslmode",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.PostgresEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        from psycopg2 import connect

        return connect


class MySQLConnectionConfig(ConnectionConfig):
    host: str
    user: str
    password: str
    port: t.Optional[int] = None
    charset: t.Optional[str] = None
    ssl_disabled: t.Optional[bool] = None

    concurrent_tasks: int = 4
    register_comments: bool = True

    type_: Literal["mysql"] = Field(alias="type", default="mysql")

    @property
    def _cursor_kwargs(self) -> t.Optional[t.Dict[str, t.Any]]:
        """Key-value arguments that will be passed during cursor construction."""
        return {"buffered": True}

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        connection_keys = {
            "host",
            "user",
            "password",
            "port",
            "database",
        }
        if self.port is not None:
            connection_keys.add("port")
        if self.charset is not None:
            connection_keys.add("charset")
        if self.ssl_disabled is not None:
            connection_keys.add("ssl_disabled")
        return connection_keys

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.MySQLEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        from mysql.connector import connect

        return connect


class MSSQLConnectionConfig(ConnectionConfig):
    host: str
    user: t.Optional[str] = None
    password: t.Optional[str] = None
    database: t.Optional[str] = ""
    timeout: t.Optional[int] = 0
    login_timeout: t.Optional[int] = 60
    charset: t.Optional[str] = "UTF-8"
    appname: t.Optional[str] = None
    port: t.Optional[int] = 1433
    conn_properties: t.Optional[t.Union[t.Iterable[str], str]] = None
    autocommit: t.Optional[bool] = False
    tds_version: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True

    type_: Literal["mssql"] = Field(alias="type", default="mssql")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {
            "host",
            "user",
            "password",
            "database",
            "timeout",
            "login_timeout",
            "charset",
            "appname",
            "port",
            "conn_properties",
            "autocommit",
            "tds_version",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.MSSQLEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        import pymssql

        return pymssql.connect


class SparkConnectionConfig(ConnectionConfig):
    """
    Vanilla Spark Connection Configuration. Use `DatabricksConnectionConfig` for Databricks.
    """

    config_dir: t.Optional[str] = None
    catalog: t.Optional[str] = None
    config: t.Dict[str, t.Any] = {}

    concurrent_tasks: int = 4
    register_comments: bool = True

    type_: Literal["spark"] = Field(alias="type", default="spark")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {
            "catalog",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.SparkEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        from sqlmesh.engines.spark.db_api.spark_session import connection

        return connection

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        from pyspark.conf import SparkConf
        from pyspark.sql import SparkSession

        spark_config = SparkConf()
        if self.config:
            for k, v in self.config.items():
                spark_config.set(k, v)

        if self.config_dir:
            os.environ["SPARK_CONF_DIR"] = self.config_dir
        return {
            "spark": SparkSession.builder.config(conf=spark_config)
            .enableHiveSupport()
            .getOrCreate(),
        }


class TrinoAuthenticationMethod(str, Enum):
    NO_AUTH = "no-auth"
    BASIC = "basic"
    LDAP = "ldap"
    KERBEROS = "kerberos"
    JWT = "jwt"
    CERTIFICATE = "certificate"
    OAUTH = "oauth"

    @property
    def is_no_auth(self) -> bool:
        return self == self.NO_AUTH

    @property
    def is_basic(self) -> bool:
        return self == self.BASIC

    @property
    def is_ldap(self) -> bool:
        return self == self.LDAP

    @property
    def is_kerberos(self) -> bool:
        return self == self.KERBEROS

    @property
    def is_jwt(self) -> bool:
        return self == self.JWT

    @property
    def is_certificate(self) -> bool:
        return self == self.CERTIFICATE

    @property
    def is_oauth(self) -> bool:
        return self == self.OAUTH


class TrinoConnectionConfig(ConnectionConfig):
    method: TrinoAuthenticationMethod = TrinoAuthenticationMethod.NO_AUTH
    host: str
    user: str
    catalog: str
    port: t.Optional[int] = None
    http_scheme: Literal["http", "https"] = "https"
    # General Optional
    roles: t.Optional[t.Dict[str, str]] = None
    http_headers: t.Optional[t.Dict[str, str]] = None
    session_properties: t.Optional[t.Dict[str, str]] = None
    retries: int = 3
    timezone: t.Optional[str] = None
    # Basic/LDAP
    password: t.Optional[str] = None
    # LDAP
    impersonation_user: t.Optional[str] = None
    # Kerberos
    keytab: t.Optional[str] = None
    krb5_config: t.Optional[str] = None
    principal: t.Optional[str] = None
    service_name: str = "trino"
    hostname_override: t.Optional[str] = None
    mutual_authentication: bool = False
    force_preemptive: bool = False
    sanitize_mutual_error_response: bool = True
    delegate: bool = False
    # JWT
    jwt_token: t.Optional[str] = None
    # Certificate
    client_certificate: t.Optional[str] = None
    client_private_key: t.Optional[str] = None
    cert: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True

    type_: Literal["trino"] = Field(alias="type", default="trino")

    @model_validator(mode="after")
    @model_validator_v1_args
    def _root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
        port = values.get("port")
        if (
            values["http_scheme"] == "http"
            and not values["method"].is_no_auth
            and not values["method"].is_basic
        ):
            raise ConfigError("HTTP scheme can only be used with no-auth or basic method")
        if port is None:
            values["port"] = 80 if values["http_scheme"] == "http" else 443
        if (values["method"].is_ldap or values["method"].is_basic) and (
            not values["password"] or not values["user"]
        ):
            raise ConfigError(
                f"Username and Password must be provided if using {values['method'].value} authentication"
            )
        if values["method"].is_kerberos and (
            not values["principal"] or not values["keytab"] or not values["krb5_config"]
        ):
            raise ConfigError(
                "Kerberos requires the following fields: principal, keytab, and krb5_config"
            )
        if values["method"].is_jwt and not values["jwt_token"]:
            raise ConfigError("JWT requires `jwt_token` to be set")
        if values["method"].is_certificate and (
            not values["cert"]
            or not values["client_certificate"]
            or not values["client_private_key"]
        ):
            raise ConfigError(
                "Certificate requires the following fields: cert, client_certificate, and client_private_key"
            )
        return values

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        kwargs = {
            "host",
            "port",
            "catalog",
            "roles",
            "http_scheme",
            "http_headers",
            "session_properties",
            "timezone",
        }
        return kwargs

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.TrinoEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        from trino.dbapi import connect

        return connect

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        from trino.auth import (
            BasicAuthentication,
            CertificateAuthentication,
            JWTAuthentication,
            KerberosAuthentication,
            OAuth2Authentication,
        )

        if self.method.is_basic or self.method.is_ldap:
            auth = BasicAuthentication(self.user, self.password)
        elif self.method.is_kerberos:
            if self.keytab:
                os.environ["KRB5_CLIENT_KTNAME"] = self.keytab
            auth = KerberosAuthentication(
                config=self.krb5_config,
                service_name=self.service_name,
                principal=self.principal,
                mutual_authentication=self.mutual_authentication,
                ca_bundle=self.cert,
                force_preemptive=self.force_preemptive,
                hostname_override=self.hostname_override,
                sanitize_mutual_error_response=self.sanitize_mutual_error_response,
                delegate=self.delegate,
            )
        elif self.method.is_oauth:
            auth = OAuth2Authentication()
        elif self.method.is_jwt:
            auth = JWTAuthentication(self.jwt_token)
        elif self.method.is_certificate:
            auth = CertificateAuthentication(self.client_certificate, self.client_private_key)
        else:
            auth = None

        return {
            "auth": auth,
            "user": self.impersonation_user or self.user,
            "max_attempts": self.retries,
            "verify": self.cert,
            "source": "sqlmesh",
        }


CONNECTION_CONFIG_TO_TYPE = {
    # Map all subclasses of ConnectionConfig to the value of their `type_` field.
    tpe.all_field_infos()["type_"].default: tpe
    for tpe in subclasses(
        __name__, ConnectionConfig, exclude=(ConnectionConfig, BaseDuckDBConnectionConfig)
    )
}


def parse_connection_config(v: t.Dict[str, t.Any]) -> ConnectionConfig:
    if "type" not in v:
        raise ConfigError("Missing connection type.")

    connection_type = v["type"]
    if connection_type not in CONNECTION_CONFIG_TO_TYPE:
        raise ConfigError(f"Unknown connection type '{connection_type}'.")

    return CONNECTION_CONFIG_TO_TYPE[connection_type](**v)


def _connection_config_validator(
    cls: t.Type, v: ConnectionConfig | t.Dict[str, t.Any] | None
) -> ConnectionConfig | None:
    if v is None or isinstance(v, ConnectionConfig):
        return v
    return parse_connection_config(v)


connection_config_validator = field_validator(
    "connection",
    "state_connection",
    "test_connection",
    "default_connection",
    "default_test_connection",
    mode="before",
    check_fields=False,
)(_connection_config_validator)


if t.TYPE_CHECKING:
    # TypeAlias wasn't introduced until Python 3.10, which means that we can't use it
    # outside the TYPE_CHECKING guard.
    SerializableConnectionConfig: t.TypeAlias = ConnectionConfig  # type: ignore
elif PYDANTIC_MAJOR_VERSION >= 2:
    import pydantic

    # Workaround for https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing
    SerializableConnectionConfig = pydantic.SerializeAsAny[ConnectionConfig]  # type: ignore
else:
    SerializableConnectionConfig = ConnectionConfig  # type: ignore
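The three abstract members of `ConnectionConfig` (`_connection_kwargs_keys`, `_engine_adapter`, and `_connection_factory`), together with optional hooks like `_static_connection_kwargs` and `_cursor_init`, are the extension points a concrete connection type fills in. As a minimal sketch of how they fit together, here is a hypothetical SQLite-backed config; SQLite is not a connection type provided by this module, and the generic `EngineAdapter` base class stands in for a real adapter subclass:

import sqlite3
import typing as t
from typing import Literal

from pydantic import Field

from sqlmesh.core.config.connection import ConnectionConfig
from sqlmesh.core.engine_adapter import EngineAdapter


class SQLiteConnectionConfig(ConnectionConfig):  # hypothetical; not part of sqlmesh
    database: str = ":memory:"

    concurrent_tasks: int = 1
    register_comments: bool = True
    type_: Literal["sqlite"] = Field(alias="type", default="sqlite")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        # Only `database` is forwarded from the model fields to the factory.
        return {"database"}

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        # A real integration would return a dedicated adapter subclass.
        return EngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        return sqlite3.connect

Note that `CONNECTION_CONFIG_TO_TYPE` at the bottom of the module is built by scanning subclasses defined in this module at import time, so a subclass sketched outside of it would not be picked up by `parse_connection_config`.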
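To make the DuckDB cursor initializer above concrete: a config with multiple catalogs attaches each one at connection time, and because the first catalog becomes the default when no `database` is set, the initializer also issues a `USE` for it. A small sketch with placeholder paths:

from sqlmesh.core.config.connection import (
    DuckDBAttachOptions,
    DuckDBConnectionConfig,
)

config = DuckDBConnectionConfig(
    catalogs={
        "local": "local.duckdb",  # plain path: ATTACH 'local.duckdb' AS "local"
        "sqlite_db": DuckDBAttachOptions(type="sqlite", path="data.db", read_only=True),
    }
)

# to_sql renders the ATTACH statement that _cursor_init executes; the alias is
# passed in already quoted via sqlglot's parse_identifier.
print(DuckDBAttachOptions(type="sqlite", path="data.db", read_only=True).to_sql('"sqlite_db"'))
# ATTACH 'data.db' AS "sqlite_db" (TYPE SQLITE, READ_ONLY)

Note also that `DuckDBConnectionConfig.create_engine_adapter` reuses an existing adapter whenever a config points at a data file that is already attached elsewhere, so two configs sharing `local.duckdb` would share one adapter.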
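MotherDuck reuses the DuckDB machinery; per `_static_connection_kwargs` above, the whole config collapses into a single `md:` connection string handed to `duckdb.connect`. With a placeholder token:

from sqlmesh.core.config.connection import MotherDuckConnectionConfig

config = MotherDuckConnectionConfig(database="my_db", token="<motherduck-token>")
# The connection factory is effectively called as:
#   duckdb.connect(database="md:my_db?motherduck_token=<motherduck-token>")
# With no token, MotherDuck falls back to browser-based login.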
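For Snowflake key-pair authentication, the rules enforced by `_get_private_key` above are: `user` is required, `password` must be absent, and only one of `private_key` or `private_key_path` may be set. A sketch with placeholder values; note that the validator imports `snowflake-connector-python` and `cryptography` and reads the key file at construction time, so both packages and the file must exist:

from sqlmesh.core.config.connection import SnowflakeConnectionConfig

config = SnowflakeConnectionConfig(
    account="my_account",
    user="my_user",
    # A PEM key on disk; decrypted with `private_key_passphrase` if encrypted.
    private_key_path="/path/to/rsa_key.p8",
)
# When a key is provided, the default authenticator is treated as key-pair auth
# and the key is normalized into DER bytes on the `private_key` field.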
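For Databricks outside of a notebook, `_databricks_connect_validator` above requires the three SQL-connector fields and, if a Spark session can be reached through Databricks Connect, derives the connect settings from them. A sketch with placeholder values, assuming the Databricks dependencies are installed:

from sqlmesh.core.config.connection import DatabricksConnectionConfig

config = DatabricksConnectionConfig(
    server_hostname="dbc-1234abcd-5678.cloud.databricks.com",
    http_path="/sql/protocolv1/o/1234567890123456/1234-123456-abcd1234",
    access_token="<personal-access-token>",
)
# When Databricks Connect is usable, the validator fills in:
#   databricks_connect_server_hostname = "https://" + server_hostname
#   databricks_connect_access_token    = access_token
#   databricks_connect_cluster_id      = http_path.split("/")[-1]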
242class DuckDBConnectionConfig(BaseDuckDBConnectionConfig): 243 """Configuration for the DuckDB connection. 244 245 Args: 246 database: The optional database name. If not specified, the in-memory database will be used. 247 catalogs: Key is the name of the catalog and value is the path. 248 """ 249 250 database: t.Optional[str] = None 251 catalogs: t.Optional[t.Dict[str, t.Union[str, DuckDBAttachOptions]]] = None 252 253 type_: Literal["duckdb"] = Field(alias="type", default="duckdb") 254 255 _data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {} 256 257 @model_validator(mode="before") 258 @model_validator_v1_args 259 def _validate_database_catalogs( 260 cls, values: t.Dict[str, t.Optional[str]] 261 ) -> t.Dict[str, t.Optional[str]]: 262 if values.get("database") and values.get("catalogs"): 263 raise ConfigError( 264 "Cannot specify both `database` and `catalogs`. Define all your catalogs in `catalogs` and have the first entry be the default catalog" 265 ) 266 return values 267 268 @property 269 def _connection_kwargs_keys(self) -> t.Set[str]: 270 return {"database"} 271 272 def create_engine_adapter(self, register_comments_override: bool = False) -> EngineAdapter: 273 """Checks if another engine adapter has already been created that shares a catalog that points to the same data 274 file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration 275 associated with the new adapter will be ignored.""" 276 data_files = set((self.catalogs or {}).values()) 277 if self.database: 278 data_files.add(self.database) 279 data_files.discard(":memory:") 280 for data_file in data_files: 281 key = data_file if isinstance(data_file, str) else data_file.path 282 if adapter := DuckDBConnectionConfig._data_file_to_adapter.get(key): 283 logger.info(f"Using existing DuckDB adapter due to overlapping data file: {key}") 284 return adapter 285 286 if data_files: 287 logger.info(f"Creating new DuckDB adapter for data files: {data_files}") 288 else: 289 logger.info("Creating new DuckDB adapter for in-memory database") 290 adapter = super().create_engine_adapter(register_comments_override) 291 for data_file in data_files: 292 key = data_file if isinstance(data_file, str) else data_file.path 293 DuckDBConnectionConfig._data_file_to_adapter[key] = adapter 294 return adapter 295 296 def get_catalog(self) -> t.Optional[str]: 297 if self.database: 298 # Remove `:` from the database name in order to handle if `:memory:` is passed in 299 return pathlib.Path(self.database.replace(":", "")).stem 300 if self.catalogs: 301 return list(self.catalogs)[0] 302 return None
Configuration for the DuckDB connection.
Arguments:
- database: The optional database name. If not specified, the in-memory database will be used.
- catalogs: Key is the name of the catalog and value is the path.
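A minimal construction sketch, assuming made-up local file paths; per the validator above, `database` and `catalogs` are mutually exclusive and the first `catalogs` entry becomes the default catalog:

```python
from sqlmesh.core.config.connection import DuckDBConnectionConfig

# Hypothetical multi-catalog setup; "analytics" is the default catalog.
config = DuckDBConnectionConfig(
    catalogs={
        "analytics": "analytics.duckdb",
        "raw": "raw.duckdb",
    }
)
assert config.get_catalog() == "analytics"

# Specifying both `database` and `catalogs` raises a ConfigError.
```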
```python
def create_engine_adapter(self, register_comments_override: bool = False) -> EngineAdapter:
    """Checks if another engine adapter has already been created that shares a catalog that points to the same data
    file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration
    associated with the new adapter will be ignored."""
    data_files = set((self.catalogs or {}).values())
    if self.database:
        data_files.add(self.database)
    data_files.discard(":memory:")
    for data_file in data_files:
        key = data_file if isinstance(data_file, str) else data_file.path
        if adapter := DuckDBConnectionConfig._data_file_to_adapter.get(key):
            logger.info(f"Using existing DuckDB adapter due to overlapping data file: {key}")
            return adapter

    if data_files:
        logger.info(f"Creating new DuckDB adapter for data files: {data_files}")
    else:
        logger.info("Creating new DuckDB adapter for in-memory database")
    adapter = super().create_engine_adapter(register_comments_override)
    for data_file in data_files:
        key = data_file if isinstance(data_file, str) else data_file.path
        DuckDBConnectionConfig._data_file_to_adapter[key] = adapter
    return adapter
```
Checks if another engine adapter has already been created that shares a catalog that points to the same data file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration associated with the new adapter will be ignored.
```python
def get_catalog(self) -> t.Optional[str]:
    if self.database:
        # Remove `:` from the database name in order to handle if `:memory:` is passed in
        return pathlib.Path(self.database.replace(":", "")).stem
    if self.catalogs:
        return list(self.catalogs)[0]
    return None
```
The catalog for this connection
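The catalog name follows directly from the method above: it is the file stem of the database path, with colons stripped so that `:memory:` normalizes cleanly. A short sketch with made-up paths:

```python
from sqlmesh.core.config.connection import DuckDBConnectionConfig

# The catalog name is the file stem of the database path.
assert DuckDBConnectionConfig(database="db/analytics.duckdb").get_catalog() == "analytics"

# `:memory:` is normalized by stripping the colons.
assert DuckDBConnectionConfig(database=":memory:").get_catalog() == "memory"
```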
```python
class SnowflakeConnectionConfig(ConnectionConfig):
    """Configuration for the Snowflake connection.

    Args:
        account: The Snowflake account name.
        user: The Snowflake username.
        password: The Snowflake password.
        warehouse: The optional warehouse name.
        database: The optional database name.
        role: The optional role name.
        concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
        authenticator: The optional authenticator name. Defaults to username/password authentication ("snowflake").
            Options: https://github.com/snowflakedb/snowflake-connector-python/blob/e937591356c067a77f34a0a42328907fda792c23/src/snowflake/connector/network.py#L178-L183
        token: The optional oauth access token to use for authentication when authenticator is set to "oauth".
        private_key: The optional private key to use for authentication. Key can be Base64-encoded DER format (representing the key bytes), a plain-text PEM format, or bytes (Python config only). https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#using-key-pair-authentication-key-pair-rotation
        private_key_path: The optional path to the private key to use for authentication. This would be used instead of `private_key`.
        private_key_passphrase: The optional passphrase to use to decrypt `private_key` or `private_key_path`. Keys can be created without encryption so only provide this if needed.
        register_comments: Whether or not to register model comments with the SQL engine.
    """

    account: str
    user: t.Optional[str] = None
    password: t.Optional[str] = None
    warehouse: t.Optional[str] = None
    database: t.Optional[str] = None
    role: t.Optional[str] = None
    authenticator: t.Optional[str] = None
    token: t.Optional[str] = None

    # Private Key Auth
    private_key: t.Optional[t.Union[str, bytes]] = None
    private_key_path: t.Optional[str] = None
    private_key_passphrase: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True

    type_: Literal["snowflake"] = Field(alias="type", default="snowflake")

    _concurrent_tasks_validator = concurrent_tasks_validator

    @model_validator(mode="before")
    @model_validator_v1_args
    def _validate_authenticator(
        cls, values: t.Dict[str, t.Optional[str]]
    ) -> t.Dict[str, t.Optional[str]]:
        from snowflake.connector.network import (
            DEFAULT_AUTHENTICATOR,
            OAUTH_AUTHENTICATOR,
        )

        auth = values.get("authenticator")
        auth = auth.upper() if auth else DEFAULT_AUTHENTICATOR
        user = values.get("user")
        password = values.get("password")
        values["private_key"] = cls._get_private_key(values, auth)  # type: ignore
        if (
            auth == DEFAULT_AUTHENTICATOR
            and not values.get("private_key")
            and (not user or not password)
        ):
            raise ConfigError("User and password must be provided if using default authentication")
        if auth == OAUTH_AUTHENTICATOR and not values.get("token"):
            raise ConfigError("Token must be provided if using oauth authentication")
        return values

    @classmethod
    def _get_private_key(cls, values: t.Dict[str, t.Optional[str]], auth: str) -> t.Optional[bytes]:
        """
        source: https://github.com/dbt-labs/dbt-snowflake/blob/0374b4ec948982f2ac8ec0c95d53d672ad19e09c/dbt/adapters/snowflake/connections.py#L247C5-L285C1

        Overall code change: Use local variables instead of class attributes + Validation
        """
        # Start custom code
        from cryptography.hazmat.backends import default_backend
        from cryptography.hazmat.primitives import serialization
        from snowflake.connector.network import (
            DEFAULT_AUTHENTICATOR,
            KEY_PAIR_AUTHENTICATOR,
        )

        private_key = values.get("private_key")
        private_key_path = values.get("private_key_path")
        private_key_passphrase = values.get("private_key_passphrase")
        user = values.get("user")
        password = values.get("password")
        auth = auth if auth and auth != DEFAULT_AUTHENTICATOR else KEY_PAIR_AUTHENTICATOR

        if not private_key and not private_key_path:
            return None
        if private_key and private_key_path:
            raise ConfigError("Cannot specify both `private_key` and `private_key_path`")
        if auth != KEY_PAIR_AUTHENTICATOR:
            raise ConfigError(
                f"Private key or private key path can only be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
            )
        if not user:
            raise ConfigError(
                f"User must be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
            )
        if password:
            raise ConfigError(
                f"Password cannot be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
            )

        if isinstance(private_key, bytes):
            return private_key
        # End Custom Code

        if private_key_passphrase:
            encoded_passphrase = private_key_passphrase.encode()
        else:
            encoded_passphrase = None

        if private_key:
            if private_key.startswith("-"):
                p_key = serialization.load_pem_private_key(
                    data=bytes(private_key, "utf-8"),
                    password=encoded_passphrase,
                    backend=default_backend(),
                )
            else:
                p_key = serialization.load_der_private_key(
                    data=base64.b64decode(private_key),
                    password=encoded_passphrase,
                    backend=default_backend(),
                )
        elif private_key_path:
            with open(private_key_path, "rb") as key:
                p_key = serialization.load_pem_private_key(
                    key.read(), password=encoded_passphrase, backend=default_backend()
                )
        else:
            return None

        return p_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {
            "user",
            "password",
            "account",
            "warehouse",
            "database",
            "role",
            "authenticator",
            "token",
            "private_key",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.SnowflakeEngineAdapter

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        return {"autocommit": False}

    @property
    def _connection_factory(self) -> t.Callable:
        from snowflake import connector

        return connector.connect
```
Configuration for the Snowflake connection.
Arguments:
- account: The Snowflake account name.
- user: The Snowflake username.
- password: The Snowflake password.
- warehouse: The optional warehouse name.
- database: The optional database name.
- role: The optional role name.
- concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
- authenticator: The optional authenticator name. Defaults to username/password authentication ("snowflake"). Options: https://github.com/snowflakedb/snowflake-connector-python/blob/e937591356c067a77f34a0a42328907fda792c23/src/snowflake/connector/network.py#L178-L183
- token: The optional oauth access token to use for authentication when authenticator is set to "oauth".
- private_key: The optional private key to use for authentication. Key can be Base64-encoded DER format (representing the key bytes), a plain-text PEM format, or bytes (Python config only). https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#using-key-pair-authentication-key-pair-rotation
- private_key_path: The optional path to the private key to use for authentication. This would be used instead of `private_key`.
- private_key_passphrase: The optional passphrase to use to decrypt `private_key` or `private_key_path`. Keys can be created without encryption so only provide this if needed.
- register_comments: Whether or not to register model comments with the SQL engine.
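A construction sketch for key-pair authentication; the account, user, and key path are placeholders, and the validator imports `snowflake-connector-python` and `cryptography` (and reads the key file), so both packages and the file must be available:

```python
from sqlmesh.core.config.connection import SnowflakeConnectionConfig

# Hypothetical key-pair setup: a user plus a PEM key file, no password.
config = SnowflakeConnectionConfig(
    account="my_org-my_account",
    user="ETL_USER",
    private_key_path="/secrets/snowflake_rsa_key.p8",
    warehouse="COMPUTE_WH",
    database="ANALYTICS",
    role="ETL_ROLE",
)
```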
```python
class DatabricksConnectionConfig(ConnectionConfig):
    """
    Databricks connection that uses the SQL connector for SQL models and then Databricks Connect for Dataframe operations

    Arg Source: https://github.com/databricks/databricks-sql-python/blob/main/src/databricks/sql/client.py#L39
    Args:
        server_hostname: Databricks instance host name.
        http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef)
            or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
        access_token: Http Bearer access token, e.g. Databricks Personal Access Token.
        catalog: Default catalog to use for SQL models. Defaults to None which means it will use the default set in
            the Databricks cluster (most likely `hive_metastore`).
        http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request
        session_configuration: An optional dictionary of Spark session parameters.
            Execute the SQL command `SET -v` to get a full list of available commands.
        databricks_connect_server_hostname: The hostname to use when establishing a connection using Databricks Connect.
            Defaults to the `server_hostname` value.
        databricks_connect_access_token: The access token to use when establishing a connection using Databricks Connect.
            Defaults to the `access_token` value.
        databricks_connect_cluster_id: The cluster id to use when establishing a connection using Databricks Connect.
            Defaults to deriving the cluster id from the `http_path` value.
        force_databricks_connect: Force all queries to run using Databricks Connect instead of the SQL connector.
        disable_databricks_connect: Even if databricks connect is installed, do not use it.
    """

    server_hostname: t.Optional[str] = None
    http_path: t.Optional[str] = None
    access_token: t.Optional[str] = None
    catalog: t.Optional[str] = None
    http_headers: t.Optional[t.List[t.Tuple[str, str]]] = None
    session_configuration: t.Optional[t.Dict[str, t.Any]] = None
    databricks_connect_server_hostname: t.Optional[str] = None
    databricks_connect_access_token: t.Optional[str] = None
    databricks_connect_cluster_id: t.Optional[str] = None
    force_databricks_connect: bool = False
    disable_databricks_connect: bool = False

    concurrent_tasks: int = 1
    register_comments: bool = True

    type_: Literal["databricks"] = Field(alias="type", default="databricks")

    _concurrent_tasks_validator = concurrent_tasks_validator
    _http_headers_validator = http_headers_validator

    @model_validator(mode="before")
    @model_validator_v1_args
    def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
        from sqlmesh import RuntimeEnv
        from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter

        runtime_env = RuntimeEnv.get()

        if runtime_env.is_databricks:
            return values
        server_hostname, http_path, access_token = (
            values.get("server_hostname"),
            values.get("http_path"),
            values.get("access_token"),
        )
        if not server_hostname or not http_path or not access_token:
            raise ValueError(
                "`server_hostname`, `http_path`, and `access_token` are required for Databricks connections when not running in a notebook"
            )
        if DatabricksEngineAdapter.can_access_spark_session:
            if not values.get("databricks_connect_server_hostname"):
                values["databricks_connect_server_hostname"] = f"https://{server_hostname}"
            if not values.get("databricks_connect_access_token"):
                values["databricks_connect_access_token"] = access_token
            if not values.get("databricks_connect_cluster_id"):
                values["databricks_connect_cluster_id"] = http_path.split("/")[-1]
        if not values.get("session_configuration"):
            values["session_configuration"] = {}
        values["session_configuration"]["spark.sql.sources.partitionOverwriteMode"] = "dynamic"
        return values

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        if self.use_spark_session_only:
            return set()
        return {
            "server_hostname",
            "http_path",
            "access_token",
            "http_headers",
            "session_configuration",
            "catalog",
        }

    @property
    def _engine_adapter(self) -> t.Type[engine_adapter.DatabricksEngineAdapter]:
        return engine_adapter.DatabricksEngineAdapter

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return {
            k: v
            for k, v in self.dict().items()
            if k.startswith("databricks_connect_") or k in ("catalog", "disable_databricks_connect")
        }

    @property
    def use_spark_session_only(self) -> bool:
        from sqlmesh import RuntimeEnv

        return RuntimeEnv.get().is_databricks or self.force_databricks_connect

    @property
    def _connection_factory(self) -> t.Callable:
        if self.use_spark_session_only:
            from sqlmesh.engines.spark.db_api.spark_session import connection

            return connection

        from databricks import sql

        return sql.connect

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        from sqlmesh import RuntimeEnv

        if not self.use_spark_session_only:
            return {}

        if RuntimeEnv.get().is_databricks:
            from pyspark.sql import SparkSession

            return dict(
                spark=SparkSession.getActiveSession(),
                catalog=self.catalog,
            )

        from databricks.connect import DatabricksSession

        return dict(
            spark=DatabricksSession.builder.remote(
                host=self.databricks_connect_server_hostname,
                token=self.databricks_connect_access_token,
                cluster_id=self.databricks_connect_cluster_id,
            ).getOrCreate(),
            catalog=self.catalog,
        )
```
Databricks connection that uses the SQL connector for SQL models and then Databricks Connect for Dataframe operations
Arg Source: https://github.com/databricks/databricks-sql-python/blob/main/src/databricks/sql/client.py#L39
Arguments:
- server_hostname: Databricks instance host name.
- http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef) or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
- access_token: Http Bearer access token, e.g. Databricks Personal Access Token.
- catalog: Default catalog to use for SQL models. Defaults to None which means it will use the default set in the Databricks cluster (most likely `hive_metastore`).
- http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request
- session_configuration: An optional dictionary of Spark session parameters. Execute the SQL command `SET -v` to get a full list of available commands.
- databricks_connect_server_hostname: The hostname to use when establishing a connection using Databricks Connect. Defaults to the `server_hostname` value.
- databricks_connect_access_token: The access token to use when establishing a connection using Databricks Connect. Defaults to the `access_token` value.
- databricks_connect_cluster_id: The cluster id to use when establishing a connection using Databricks Connect. Defaults to deriving the cluster id from the `http_path` value.
- force_databricks_connect: Force all queries to run using Databricks Connect instead of the SQL connector.
- disable_databricks_connect: Even if databricks connect is installed, do not use it.
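A construction sketch with placeholder hostname, endpoint, and token; per the validator above, all three credentials are required when not running inside a Databricks notebook:

```python
from sqlmesh.core.config.connection import DatabricksConnectionConfig

# Hypothetical SQL-warehouse connection.
config = DatabricksConnectionConfig(
    server_hostname="adb-1234567890123456.7.azuredatabricks.net",
    http_path="/sql/1.0/endpoints/1234567890abcdef",
    access_token="dapi...",  # placeholder personal access token
    catalog="analytics",
)
```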
```python
class BigQueryConnectionMethod(str, Enum):
    OAUTH = "oauth"
    OAUTH_SECRETS = "oauth-secrets"
    SERVICE_ACCOUNT = "service-account"
    SERVICE_ACCOUNT_JSON = "service-account-json"
```
The authentication methods supported for BigQuery connections.
```python
class BigQueryPriority(str, Enum):
    BATCH = "batch"
    INTERACTIVE = "interactive"

    @property
    def is_batch(self) -> bool:
        return self == self.BATCH

    @property
    def is_interactive(self) -> bool:
        return self == self.INTERACTIVE

    @property
    def bigquery_constant(self) -> str:
        from google.cloud.bigquery import QueryPriority

        if self.is_batch:
            return QueryPriority.BATCH
        return QueryPriority.INTERACTIVE
```
The BigQuery job priority, mapped to the client library's `QueryPriority` constants.
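A tiny usage sketch; only the last line needs the `google-cloud-bigquery` package installed:

```python
from sqlmesh.core.config.connection import BigQueryPriority

priority = BigQueryPriority("batch")
assert priority.is_batch and not priority.is_interactive
# `priority.bigquery_constant` resolves to QueryPriority.BATCH, but importing it
# requires the google-cloud-bigquery package.
```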
```python
class BigQueryConnectionConfig(ConnectionConfig):
    """
    BigQuery Connection Configuration.
    """

    method: BigQueryConnectionMethod = BigQueryConnectionMethod.OAUTH

    project: t.Optional[str] = None
    execution_project: t.Optional[str] = None
    location: t.Optional[str] = None
    # Keyfile Auth
    keyfile: t.Optional[str] = None
    keyfile_json: t.Optional[t.Dict[str, t.Any]] = None
    # OAuth Secrets Auth
    token: t.Optional[str] = None
    refresh_token: t.Optional[str] = None
    client_id: t.Optional[str] = None
    client_secret: t.Optional[str] = None
    token_uri: t.Optional[str] = None
    scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/bigquery",)
    job_creation_timeout_seconds: t.Optional[int] = None
    # Extra Engine Config
    job_execution_timeout_seconds: t.Optional[int] = None
    job_retries: t.Optional[int] = 1
    job_retry_deadline_seconds: t.Optional[int] = None
    priority: t.Optional[BigQueryPriority] = None
    maximum_bytes_billed: t.Optional[int] = None

    concurrent_tasks: int = 1
    register_comments: bool = True

    type_: Literal["bigquery"] = Field(alias="type", default="bigquery")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return set()

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.BigQueryEngineAdapter

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """The static connection kwargs for this connection"""
        import google.auth
        import google.cloud.bigquery  # required so the `google.cloud.bigquery.Client` reference below resolves
        from google.api_core import client_info
        from google.oauth2 import credentials, service_account

        if self.method == BigQueryConnectionMethod.OAUTH:
            creds, _ = google.auth.default(scopes=self.scopes)
        elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT:
            creds = service_account.Credentials.from_service_account_file(
                self.keyfile, scopes=self.scopes
            )
        elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT_JSON:
            creds = service_account.Credentials.from_service_account_info(
                self.keyfile_json, scopes=self.scopes
            )
        elif self.method == BigQueryConnectionMethod.OAUTH_SECRETS:
            creds = credentials.Credentials(
                token=self.token,
                refresh_token=self.refresh_token,
                client_id=self.client_id,
                client_secret=self.client_secret,
                token_uri=self.token_uri,
                scopes=self.scopes,
            )
        else:
            raise ConfigError("Invalid BigQuery Connection Method")
        client = google.cloud.bigquery.Client(
            project=self.execution_project or self.project,
            credentials=creds,
            location=self.location,
            client_info=client_info.ClientInfo(user_agent="sqlmesh"),
        )

        return {
            "client": client,
        }

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return {
            k: v
            for k, v in self.dict().items()
            if k
            in {
                "job_creation_timeout_seconds",
                "job_execution_timeout_seconds",
                "job_retries",
                "job_retry_deadline_seconds",
                "priority",
                "maximum_bytes_billed",
            }
        }

    @property
    def _connection_factory(self) -> t.Callable:
        from google.cloud.bigquery.dbapi import connect

        return connect

    def get_catalog(self) -> t.Optional[str]:
        return self.project
```
BigQuery Connection Configuration.
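A construction sketch for the service-account method; the project and keyfile path are placeholders, and the BigQuery client is only built lazily when the connection kwargs are resolved:

```python
from sqlmesh.core.config.connection import (
    BigQueryConnectionConfig,
    BigQueryConnectionMethod,
    BigQueryPriority,
)

# Hypothetical service-account setup.
config = BigQueryConnectionConfig(
    method=BigQueryConnectionMethod.SERVICE_ACCOUNT,
    keyfile="/secrets/bq-service-account.json",
    project="my-project",
    location="EU",
    priority=BigQueryPriority.BATCH,
    maximum_bytes_billed=10 * 1024**4,  # 10 TiB cost guardrail
)
assert config.get_catalog() == "my-project"
```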
```python
class GCPPostgresConnectionConfig(ConnectionConfig):
    """
    Postgres Connection Configuration for GCP.

    Args:
        instance_connection_string: Connection name for the postgres instance.
        user: Postgres or IAM user's name
        password: The postgres user's password. Only needed when the user is a postgres user.
        enable_iam_auth: Set to True when user is an IAM user.
        db: Name of the db to connect to.
    """

    instance_connection_string: str
    user: str
    password: t.Optional[str] = None
    enable_iam_auth: t.Optional[bool] = None
    db: str
    timeout: t.Optional[int] = None

    driver: str = "pg8000"
    type_: Literal["gcp_postgres"] = Field(alias="type", default="gcp_postgres")
    concurrent_tasks: int = 4
    register_comments: bool = True

    @model_validator(mode="before")
    @model_validator_v1_args
    def _validate_auth_method(
        cls, values: t.Dict[str, t.Optional[str]]
    ) -> t.Dict[str, t.Optional[str]]:
        password = values.get("password")
        enable_iam_auth = values.get("enable_iam_auth")
        if password and enable_iam_auth:
            raise ConfigError(
                "Invalid GCP Postgres connection configuration - both password and"
                " enable_iam_auth set. Use password when connecting to a postgres"
                " user and enable_iam_auth 'True' when connecting to an IAM user."
            )
        if not password and not enable_iam_auth:
            raise ConfigError(
                "GCP Postgres connection configuration requires either password set"
                " for a postgres user account or enable_iam_auth set to 'True'"
                " for an IAM user account."
            )
        return values

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {
            "instance_connection_string",
            "driver",
            "user",
            "password",
            "db",
            "enable_iam_auth",
            "timeout",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.PostgresEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        from google.cloud.sql.connector import Connector

        return Connector().connect
```
Postgres Connection Configuration for GCP.
Arguments:
- instance_connection_string: Connection name for the postgres instance.
- user: Postgres or IAM user's name
- password: The postgres user's password. Only needed when the user is a postgres user.
- enable_iam_auth: Set to True when user is an IAM user.
- db: Name of the db to connect to.
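A construction sketch for the IAM-user path; the instance string and user are placeholders, and per the validator above `password` and `enable_iam_auth` are mutually exclusive:

```python
from sqlmesh.core.config.connection import GCPPostgresConnectionConfig

# Hypothetical IAM-user setup: `enable_iam_auth` replaces `password`.
config = GCPPostgresConnectionConfig(
    instance_connection_string="my-project:europe-west1:my-instance",
    user="etl-sa@my-project.iam",
    enable_iam_auth=True,
    db="analytics",
)
# Setting both `password` and `enable_iam_auth` raises a ConfigError.
```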
```python
class RedshiftConnectionConfig(ConnectionConfig):
    """
    Redshift Connection Configuration.

    Arg Source: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/__init__.py#L146
    Note: A subset of properties were selected. Please open an issue/PR if you want to see more supported.

    Args:
        user: The username to use for authentication with the Amazon Redshift cluster.
        password: The password to use for authentication with the Amazon Redshift cluster.
        database: The name of the database instance to connect to.
        host: The hostname of the Amazon Redshift cluster.
        port: The port number of the Amazon Redshift cluster. Default value is 5439.
        source_address: No description provided
        unix_sock: No description provided
        ssl: Is SSL enabled. Default value is ``True``. SSL must be enabled when authenticating using IAM.
        sslmode: The security of the connection to the Amazon Redshift cluster. 'verify-ca' and 'verify-full' are supported.
        timeout: The number of seconds before the connection to the server will timeout. By default there is no timeout.
        tcp_keepalive: Is `TCP keepalive <https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive>`_ used. The default value is ``True``.
        application_name: Sets the application name. The default value is None.
        preferred_role: The IAM role preferred for the current connection.
        principal_arn: The ARN of the IAM entity (user or role) for which you are generating a policy.
        credentials_provider: The class name of the IdP that will be used for authenticating with the Amazon Redshift cluster.
        region: The AWS region where the Amazon Redshift cluster is located.
        cluster_identifier: The cluster identifier of the Amazon Redshift cluster.
        iam: If IAM authentication is enabled. Default value is False. IAM must be True when authenticating using an IdP.
        is_serverless: Whether the Redshift endpoint is serverless or provisioned. Default value is False.
        serverless_acct_id: The account ID of the serverless endpoint. Default value is None.
        serverless_work_group: The name of the work group for the serverless endpoint. Default value is None.
    """

    user: t.Optional[str] = None
    password: t.Optional[str] = None
    database: t.Optional[str] = None
    host: t.Optional[str] = None
    port: t.Optional[int] = None
    source_address: t.Optional[str] = None
    unix_sock: t.Optional[str] = None
    ssl: t.Optional[bool] = None
    sslmode: t.Optional[str] = None
    timeout: t.Optional[int] = None
    tcp_keepalive: t.Optional[bool] = None
    application_name: t.Optional[str] = None
    preferred_role: t.Optional[str] = None
    principal_arn: t.Optional[str] = None
    credentials_provider: t.Optional[str] = None
    region: t.Optional[str] = None
    cluster_identifier: t.Optional[str] = None
    iam: t.Optional[bool] = None
    is_serverless: t.Optional[bool] = None
    serverless_acct_id: t.Optional[str] = None
    serverless_work_group: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True

    type_: Literal["redshift"] = Field(alias="type", default="redshift")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {
            "user",
            "password",
            "database",
            "host",
            "port",
            "source_address",
            "unix_sock",
            "ssl",
            "sslmode",
            "timeout",
            "tcp_keepalive",
            "application_name",
            "preferred_role",
            "principal_arn",
            "credentials_provider",
            "region",
            "cluster_identifier",
            "iam",
            "is_serverless",
            "serverless_acct_id",
            "serverless_work_group",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.RedshiftEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        from redshift_connector import connect

        return connect
```
Redshift Connection Configuration.
Arg Source: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/__init__.py#L146 Note: A subset of properties were selected. Please open an issue/PR if you want to see more supported.
Arguments:
- user: The username to use for authentication with the Amazon Redshift cluster.
- password: The password to use for authentication with the Amazon Redshift cluster.
- database: The name of the database instance to connect to.
- host: The hostname of the Amazon Redshift cluster.
- port: The port number of the Amazon Redshift cluster. Default value is 5439.
- source_address: No description provided
- unix_sock: No description provided
- ssl: Is SSL enabled. Default value is `True`. SSL must be enabled when authenticating using IAM.
- sslmode: The security of the connection to the Amazon Redshift cluster. 'verify-ca' and 'verify-full' are supported.
- timeout: The number of seconds before the connection to the server will timeout. By default there is no timeout.
- tcp_keepalive: Is [TCP keepalive](https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive) used. The default value is `True`.
- application_name: Sets the application name. The default value is None.
- preferred_role: The IAM role preferred for the current connection.
- principal_arn: The ARN of the IAM entity (user or role) for which you are generating a policy.
- credentials_provider: The class name of the IdP that will be used for authenticating with the Amazon Redshift cluster.
- region: The AWS region where the Amazon Redshift cluster is located.
- cluster_identifier: The cluster identifier of the Amazon Redshift cluster.
- iam: If IAM authentication is enabled. Default value is False. IAM must be True when authenticating using an IdP.
- is_serverless: Whether the Redshift endpoint is serverless or provisioned. Default value is False.
- serverless_acct_id: The account ID of the serverless endpoint. Default value is None.
- serverless_work_group: The name of the work group for the serverless endpoint. Default value is None.
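A construction sketch for a provisioned cluster with plain user/password authentication; the host and credentials are placeholders:

```python
from sqlmesh.core.config.connection import RedshiftConnectionConfig

# Hypothetical provisioned-cluster setup.
config = RedshiftConnectionConfig(
    user="etl_user",
    password="change-me",
    database="analytics",
    host="my-cluster.abc123xyz.eu-west-1.redshift.amazonaws.com",
    port=5439,
)
```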
```python
class PostgresConnectionConfig(ConnectionConfig):
    host: str
    user: str
    password: str
    port: int
    database: str
    keepalives_idle: t.Optional[int] = None
    connect_timeout: int = 10
    role: t.Optional[str] = None
    sslmode: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True

    type_: Literal["postgres"] = Field(alias="type", default="postgres")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {
            "host",
            "user",
            "password",
            "port",
            "database",
            "keepalives_idle",
            "connect_timeout",
            "role",
            "sslmode",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.PostgresEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        from psycopg2 import connect

        return connect
```
Configuration for the Postgres connection.
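A construction sketch with placeholder credentials; `host`, `user`, `password`, `port`, and `database` are all required fields:

```python
from sqlmesh.core.config.connection import PostgresConnectionConfig

# Hypothetical local connection.
config = PostgresConnectionConfig(
    host="localhost",
    port=5432,
    user="sqlmesh",
    password="change-me",
    database="analytics",
    sslmode="require",
)
```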
```python
class MySQLConnectionConfig(ConnectionConfig):
    host: str
    user: str
    password: str
    port: t.Optional[int] = None
    charset: t.Optional[str] = None
    ssl_disabled: t.Optional[bool] = None

    concurrent_tasks: int = 4
    register_comments: bool = True

    type_: Literal["mysql"] = Field(alias="type", default="mysql")

    @property
    def _cursor_kwargs(self) -> t.Optional[t.Dict[str, t.Any]]:
        """Key-value arguments that will be passed during cursor construction."""
        return {"buffered": True}

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        connection_keys = {
            "host",
            "user",
            "password",
            "port",
            "database",
        }
        if self.port is not None:
            connection_keys.add("port")
        if self.charset is not None:
            connection_keys.add("charset")
        if self.ssl_disabled is not None:
            connection_keys.add("ssl_disabled")
        return connection_keys

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.MySQLEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        from mysql.connector import connect

        return connect
```
Configuration for the MySQL connection.
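A construction sketch with placeholder credentials; `charset` and `ssl_disabled` are only forwarded to the connector when explicitly set:

```python
from sqlmesh.core.config.connection import MySQLConnectionConfig

# Hypothetical local connection.
config = MySQLConnectionConfig(
    host="localhost",
    user="sqlmesh",
    password="change-me",
    port=3306,
)
```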
```python
class MSSQLConnectionConfig(ConnectionConfig):
    host: str
    user: t.Optional[str] = None
    password: t.Optional[str] = None
    database: t.Optional[str] = ""
    timeout: t.Optional[int] = 0
    login_timeout: t.Optional[int] = 60
    charset: t.Optional[str] = "UTF-8"
    appname: t.Optional[str] = None
    port: t.Optional[int] = 1433
    conn_properties: t.Optional[t.Union[t.Iterable[str], str]] = None
    autocommit: t.Optional[bool] = False
    tds_version: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True

    type_: Literal["mssql"] = Field(alias="type", default="mssql")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {
            "host",
            "user",
            "password",
            "database",
            "timeout",
            "login_timeout",
            "charset",
            "appname",
            "port",
            "conn_properties",
            "autocommit",
            "tds_version",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.MSSQLEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        import pymssql

        return pymssql.connect
```
Configuration for the MSSQL (SQL Server) connection.
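A construction sketch with a placeholder host and credentials; only `host` is required, and `pymssql` is imported lazily when the connection is actually created:

```python
from sqlmesh.core.config.connection import MSSQLConnectionConfig

# Hypothetical SQL Server setup.
config = MSSQLConnectionConfig(
    host="sqlserver.internal",
    user="sqlmesh",
    password="change-me",
    database="analytics",
    port=1433,
)
```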
```python
class SparkConnectionConfig(ConnectionConfig):
    """
    Vanilla Spark Connection Configuration. Use `DatabricksConnectionConfig` for Databricks.
    """

    config_dir: t.Optional[str] = None
    catalog: t.Optional[str] = None
    config: t.Dict[str, t.Any] = {}

    concurrent_tasks: int = 4
    register_comments: bool = True

    type_: Literal["spark"] = Field(alias="type", default="spark")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {
            "catalog",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.SparkEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        from sqlmesh.engines.spark.db_api.spark_session import connection

        return connection

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        from pyspark.conf import SparkConf
        from pyspark.sql import SparkSession

        spark_config = SparkConf()
        if self.config:
            for k, v in self.config.items():
                spark_config.set(k, v)

        if self.config_dir:
            os.environ["SPARK_CONF_DIR"] = self.config_dir
        return {
            "spark": SparkSession.builder.config(conf=spark_config)
            .enableHiveSupport()
            .getOrCreate(),
        }
```
Vanilla Spark Connection Configuration. Use `DatabricksConnectionConfig` for Databricks.
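A construction sketch for a made-up local-mode setup; every `config` entry is applied to the `SparkConf` before the session is created, and `pyspark` is only imported lazily:

```python
from sqlmesh.core.config.connection import SparkConnectionConfig

# Hypothetical local-mode setup.
config = SparkConnectionConfig(
    catalog="spark_catalog",
    config={
        "spark.master": "local[4]",
        "spark.sql.shuffle.partitions": "8",
    },
)
```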
```python
class TrinoAuthenticationMethod(str, Enum):
    NO_AUTH = "no-auth"
    BASIC = "basic"
    LDAP = "ldap"
    KERBEROS = "kerberos"
    JWT = "jwt"
    CERTIFICATE = "certificate"
    OAUTH = "oauth"

    @property
    def is_no_auth(self) -> bool:
        return self == self.NO_AUTH

    @property
    def is_basic(self) -> bool:
        return self == self.BASIC

    @property
    def is_ldap(self) -> bool:
        return self == self.LDAP

    @property
    def is_kerberos(self) -> bool:
        return self == self.KERBEROS

    @property
    def is_jwt(self) -> bool:
        return self == self.JWT

    @property
    def is_certificate(self) -> bool:
        return self == self.CERTIFICATE

    @property
    def is_oauth(self) -> bool:
        return self == self.OAUTH
```
The authentication methods supported for Trino connections.
```python
class TrinoConnectionConfig(ConnectionConfig):
    method: TrinoAuthenticationMethod = TrinoAuthenticationMethod.NO_AUTH
    host: str
    user: str
    catalog: str
    port: t.Optional[int] = None
    http_scheme: Literal["http", "https"] = "https"
    # General Optional
    roles: t.Optional[t.Dict[str, str]] = None
    http_headers: t.Optional[t.Dict[str, str]] = None
    session_properties: t.Optional[t.Dict[str, str]] = None
    retries: int = 3
    timezone: t.Optional[str] = None
    # Basic/LDAP
    password: t.Optional[str] = None
    # LDAP
    impersonation_user: t.Optional[str] = None
    # Kerberos
    keytab: t.Optional[str] = None
    krb5_config: t.Optional[str] = None
    principal: t.Optional[str] = None
    service_name: str = "trino"
    hostname_override: t.Optional[str] = None
    mutual_authentication: bool = False
    force_preemptive: bool = False
    sanitize_mutual_error_response: bool = True
    delegate: bool = False
    # JWT
    jwt_token: t.Optional[str] = None
    # Certificate
    client_certificate: t.Optional[str] = None
    client_private_key: t.Optional[str] = None
    cert: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True

    type_: Literal["trino"] = Field(alias="type", default="trino")

    @model_validator(mode="after")
    @model_validator_v1_args
    def _root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
        port = values.get("port")
        if (
            values["http_scheme"] == "http"
            and not values["method"].is_no_auth
            and not values["method"].is_basic
        ):
            raise ConfigError("HTTP scheme can only be used with no-auth or basic method")
        if port is None:
            values["port"] = 80 if values["http_scheme"] == "http" else 443
        if (values["method"].is_ldap or values["method"].is_basic) and (
            not values["password"] or not values["user"]
        ):
            raise ConfigError(
                f"Username and Password must be provided if using {values['method'].value} authentication"
            )
        if values["method"].is_kerberos and (
            not values["principal"] or not values["keytab"] or not values["krb5_config"]
        ):
            raise ConfigError(
                "Kerberos requires the following fields: principal, keytab, and krb5_config"
            )
        if values["method"].is_jwt and not values["jwt_token"]:
            raise ConfigError("JWT requires `jwt_token` to be set")
        if values["method"].is_certificate and (
            not values["cert"]
            or not values["client_certificate"]
            or not values["client_private_key"]
        ):
            raise ConfigError(
                "Certificate requires the following fields: cert, client_certificate, and client_private_key"
            )
        return values

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        kwargs = {
            "host",
            "port",
            "catalog",
            "roles",
            "http_scheme",
            "http_headers",
            "session_properties",
            "timezone",
        }
        return kwargs

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.TrinoEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        from trino.dbapi import connect

        return connect

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        from trino.auth import (
            BasicAuthentication,
            CertificateAuthentication,
            JWTAuthentication,
            KerberosAuthentication,
            OAuth2Authentication,
        )

        if self.method.is_basic or self.method.is_ldap:
            auth = BasicAuthentication(self.user, self.password)
        elif self.method.is_kerberos:
            if self.keytab:
                os.environ["KRB5_CLIENT_KTNAME"] = self.keytab
            auth = KerberosAuthentication(
                config=self.krb5_config,
                service_name=self.service_name,
                principal=self.principal,
                mutual_authentication=self.mutual_authentication,
                ca_bundle=self.cert,
                force_preemptive=self.force_preemptive,
                hostname_override=self.hostname_override,
                sanitize_mutual_error_response=self.sanitize_mutual_error_response,
                delegate=self.delegate,
            )
        elif self.method.is_oauth:
            auth = OAuth2Authentication()
        elif self.method.is_jwt:
            auth = JWTAuthentication(self.jwt_token)
        elif self.method.is_certificate:
            auth = CertificateAuthentication(self.client_certificate, self.client_private_key)
        else:
            auth = None

        return {
            "auth": auth,
            "user": self.impersonation_user or self.user,
            "max_attempts": self.retries,
            "verify": self.cert,
            "source": "sqlmesh",
        }
```
Configuration for the Trino connection.
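A construction sketch for basic authentication with placeholder host and credentials; per the validator above, basic auth requires both `user` and `password`, and an unset `port` is filled in from the scheme (443 for the default "https"):

```python
from sqlmesh.core.config.connection import (
    TrinoAuthenticationMethod,
    TrinoConnectionConfig,
)

# Hypothetical basic-auth setup against an internal coordinator.
config = TrinoConnectionConfig(
    method=TrinoAuthenticationMethod.BASIC,
    host="trino.internal",
    user="sqlmesh",
    password="change-me",
    catalog="hive",
)
assert config.port == 443  # derived from the default "https" scheme
```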
```python
def parse_connection_config(v: t.Dict[str, t.Any]) -> ConnectionConfig:
    if "type" not in v:
        raise ConfigError("Missing connection type.")

    connection_type = v["type"]
    if connection_type not in CONNECTION_CONFIG_TO_TYPE:
        raise ConfigError(f"Unknown connection type '{connection_type}'.")

    return CONNECTION_CONFIG_TO_TYPE[connection_type](**v)
```
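A short usage sketch: the `type` key selects the concrete `ConnectionConfig` subclass, and the remaining keys are passed through as its fields:

```python
from sqlmesh.core.config.connection import parse_connection_config

# The `type` key dispatches to the matching config class.
config = parse_connection_config({"type": "duckdb", "database": "analytics.duckdb"})
assert type(config).__name__ == "DuckDBConnectionConfig"

# A missing or unknown `type` raises a ConfigError.
```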