sqlmesh.core.config.connection

   1from __future__ import annotations
   2
   3import abc
   4import base64
   5import logging
   6import os
   7import pathlib
   8import sys
   9import typing as t
  10from enum import Enum
  11from functools import partial
  12
  13from pydantic import Field
  14from sqlglot import exp
  15from sqlglot.helper import subclasses
  16
  17from sqlmesh.core import engine_adapter
  18from sqlmesh.core.config.base import BaseConfig
  19from sqlmesh.core.config.common import (
  20    concurrent_tasks_validator,
  21    http_headers_validator,
  22)
  23from sqlmesh.core.engine_adapter import EngineAdapter
  24from sqlmesh.utils.errors import ConfigError
  25from sqlmesh.utils.pydantic import (
  26    PYDANTIC_MAJOR_VERSION,
  27    field_validator,
  28    model_validator,
  29    model_validator_v1_args,
  30)
  31
  32if sys.version_info >= (3, 9):
  33    from typing import Literal
  34else:
  35    from typing_extensions import Literal
  36
  37
  38logger = logging.getLogger(__name__)
  39
  40RECOMMENDED_STATE_SYNC_ENGINES = {"postgres", "gcp_postgres", "mysql", "duckdb"}
  41
  42
  43class ConnectionConfig(abc.ABC, BaseConfig):
  44    type_: str
  45    concurrent_tasks: int
  46    register_comments: bool
  47
  48    @property
  49    @abc.abstractmethod
  50    def _connection_kwargs_keys(self) -> t.Set[str]:
  51        """keywords that should be passed into the connection"""
  52
  53    @property
  54    @abc.abstractmethod
  55    def _engine_adapter(self) -> t.Type[EngineAdapter]:
  56        """The engine adapter for this connection"""
  57
  58    @property
  59    @abc.abstractmethod
  60    def _connection_factory(self) -> t.Callable:
  61        """A function that is called to return a connection object for the given Engine Adapter"""
  62
  63    @property
  64    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
  65        """The static connection kwargs for this connection"""
  66        return {}
  67
  68    @property
  69    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
  70        """kwargs that are for execution config only"""
  71        return {}
  72
  73    @property
  74    def _cursor_kwargs(self) -> t.Optional[t.Dict[str, t.Any]]:
  75        """Key-value arguments that will be passed during cursor construction."""
  76        return None
  77
  78    @property
  79    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
  80        """A function that is called to initialize the cursor"""
  81        return None
  82
  83    @property
  84    def is_recommended_for_state_sync(self) -> bool:
  85        """Whether this connection is recommended for being used as a state sync for production state syncs"""
  86        return self.type_ in RECOMMENDED_STATE_SYNC_ENGINES
  87
  88    @property
  89    def _connection_factory_with_kwargs(self) -> t.Callable[[], t.Any]:
  90        """A function that is called to return a connection object for the given Engine Adapter"""
  91        return partial(
  92            self._connection_factory,
  93            **{
  94                **self._static_connection_kwargs,
  95                **{k: v for k, v in self.dict().items() if k in self._connection_kwargs_keys},
  96            },
  97        )
  98
  99    def create_engine_adapter(self, register_comments_override: bool = False) -> EngineAdapter:
 100        """Returns a new instance of the Engine Adapter."""
 101        return self._engine_adapter(
 102            self._connection_factory_with_kwargs,
 103            multithreaded=self.concurrent_tasks > 1,
 104            cursor_kwargs=self._cursor_kwargs,
 105            default_catalog=self.get_catalog(),
 106            cursor_init=self._cursor_init,
 107            register_comments=register_comments_override or self.register_comments,
 108            **self._extra_engine_config,
 109        )
 110
 111    def get_catalog(self) -> t.Optional[str]:
 112        """The catalog for this connection"""
 113        if hasattr(self, "catalog"):
 114            return self.catalog
 115        if hasattr(self, "database"):
 116            return self.database
 117        if hasattr(self, "db"):
 118            return self.db
 119        return None
 120
 121
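A minimal usage sketch of how a concrete subclass is typically used (assuming the `duckdb` package is installed; the `fetchone` call on the resulting adapter is illustrative):

    config = DuckDBConnectionConfig()  # defaults to an in-memory database
    adapter = config.create_engine_adapter()
    adapter.fetchone("SELECT 1")  # -> (1,)
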
 122class BaseDuckDBConnectionConfig(ConnectionConfig):
 123    """Common configuration for the DuckDB-based connections.
 124
 125    Args:
 126        extensions: A list of autoloadable extensions to load.
 127        connector_config: A dictionary of configuration to pass into the duckdb connector.
 128        concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
 129        register_comments: Whether or not to register model comments with the SQL engine.
 130    """
 131
 132    extensions: t.List[str] = []
 133    connector_config: t.Dict[str, t.Any] = {}
 134
 135    concurrent_tasks: Literal[1] = 1
 136    register_comments: bool = True
 137
 138    @property
 139    def _engine_adapter(self) -> t.Type[EngineAdapter]:
 140        return engine_adapter.DuckDBEngineAdapter
 141
 142    @property
 143    def _connection_factory(self) -> t.Callable:
 144        import duckdb
 145
 146        return duckdb.connect
 147
 148    @property
 149    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
 150        """A function that is called to initialize the cursor"""
 151        import duckdb
 152        from duckdb import BinderException
 153
 154        def init(cursor: duckdb.DuckDBPyConnection) -> None:
 155            for extension in self.extensions:
 156                try:
 157                    cursor.execute(f"INSTALL {extension}")
 158                    cursor.execute(f"LOAD {extension}")
 159                except Exception as e:
 160                    raise ConfigError(f"Failed to load extension {extension}: {e}")
 161
 162            for field, setting in self.connector_config.items():
 163                try:
 164                    cursor.execute(f"SET {field} = '{setting}'")
 165                except Exception as e:
 166                    raise ConfigError(f"Failed to set connector config {field} to {setting}: {e}")
 167
 168            for i, (alias, path_options) in enumerate(
 169                (getattr(self, "catalogs", None) or {}).items()
 170            ):
 171                # we parse_identifier and generate to ensure that `alias` has exactly one set of quotes
 172                # regardless of whether it comes in quoted or not
 173                alias = exp.parse_identifier(alias, dialect="duckdb").sql(
 174                    identify=True, dialect="duckdb"
 175                )
 176                try:
 177                    query = (
 178                        path_options.to_sql(alias)
 179                        if isinstance(path_options, DuckDBAttachOptions)
 180                        else f"ATTACH '{path_options}' AS {alias}"
 181                    )
 182                    cursor.execute(query)
 183                except BinderException as e:
 184                    # If a user tries to create a catalog pointing at `:memory:` and with the name `memory`
 185                    # then we don't want to raise since this happens by default. They are just doing this to
 186                    # set it as the default catalog.
 187                    if not (
 188                        'database with name "memory" already exists' in str(e)
 189                        and path_options == ":memory:"
 190                    ):
 191                        raise e
 192                if i == 0 and not getattr(self, "database", None):
 193                    cursor.execute(f"USE {alias}")
 194
 195        return init
 196
 197
 198class MotherDuckConnectionConfig(BaseDuckDBConnectionConfig):
 199    """Configuration for the MotherDuck connection.
 200
 201    Args:
 202        database: The database name.
  203        token: The optional MotherDuck token. If not specified, the user will be prompted to log in with their web browser.
 204    """
 205
 206    database: str
 207    token: t.Optional[str] = None
 208
 209    type_: Literal["motherduck"] = Field(alias="type", default="motherduck")
 210
 211    @property
 212    def _connection_kwargs_keys(self) -> t.Set[str]:
 213        return set()
 214
 215    @property
 216    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
 217        """kwargs that are for execution config only"""
 218        connection_str = f"md:{self.database}"
 219        if self.token:
 220            connection_str += f"?motherduck_token={self.token}"
 221        return {"database": connection_str}
 222
 223
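For illustration, a config like the following (placeholder token) builds the connection string `md:my_db?motherduck_token=<token>` via the private helper shown above:

    config = MotherDuckConnectionConfig(database="my_db", token="<token>")
    config._static_connection_kwargs  # -> {"database": "md:my_db?motherduck_token=<token>"}
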
 224class DuckDBAttachOptions(BaseConfig):
 225    type: str
 226    path: str
 227    read_only: bool = False
 228
 229    def to_sql(self, alias: str) -> str:
 230        options = []
 231        # 'duckdb' is actually not a supported type, but we'd like to allow it for
 232        # fully qualified attach options or integration testing, similar to duckdb-dbt
 233        if self.type != "duckdb":
 234            options.append(f"TYPE {self.type.upper()}")
 235        if self.read_only:
 236            options.append("READ_ONLY")
 237        options_sql = f" ({', '.join(options)})" if options else ""
 238        return f"ATTACH '{self.path}' AS {alias}{options_sql}"
 239
 240
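A quick sketch of the SQL that `to_sql` generates (the quoting of the alias is handled by the caller in `BaseDuckDBConnectionConfig._cursor_init`):

    options = DuckDBAttachOptions(type="sqlite", path="data/test.db", read_only=True)
    options.to_sql('"ext"')  # -> ATTACH 'data/test.db' AS "ext" (TYPE SQLITE, READ_ONLY)
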
 241class DuckDBConnectionConfig(BaseDuckDBConnectionConfig):
 242    """Configuration for the DuckDB connection.
 243
 244    Args:
 245        database: The optional database name. If not specified, the in-memory database will be used.
 246        catalogs: Key is the name of the catalog and value is the path.
 247    """
 248
 249    database: t.Optional[str] = None
 250    catalogs: t.Optional[t.Dict[str, t.Union[str, DuckDBAttachOptions]]] = None
 251
 252    type_: Literal["duckdb"] = Field(alias="type", default="duckdb")
 253
 254    _data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {}
 255
 256    @model_validator(mode="before")
 257    @model_validator_v1_args
 258    def _validate_database_catalogs(
 259        cls, values: t.Dict[str, t.Optional[str]]
 260    ) -> t.Dict[str, t.Optional[str]]:
 261        if values.get("database") and values.get("catalogs"):
 262            raise ConfigError(
 263                "Cannot specify both `database` and `catalogs`. Define all your catalogs in `catalogs` and have the first entry be the default catalog"
 264            )
 265        return values
 266
 267    @property
 268    def _connection_kwargs_keys(self) -> t.Set[str]:
 269        return {"database"}
 270
 271    def create_engine_adapter(self, register_comments_override: bool = False) -> EngineAdapter:
 272        """Checks if another engine adapter has already been created that shares a catalog that points to the same data
 273        file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration
 274        associated with the new adapter will be ignored."""
 275        data_files = set((self.catalogs or {}).values())
 276        if self.database:
 277            data_files.add(self.database)
 278        data_files.discard(":memory:")
 279        for data_file in data_files:
 280            key = data_file if isinstance(data_file, str) else data_file.path
 281            if adapter := DuckDBConnectionConfig._data_file_to_adapter.get(key):
 282                logger.info(f"Using existing DuckDB adapter due to overlapping data file: {key}")
 283                return adapter
 284
 285        if data_files:
 286            logger.info(f"Creating new DuckDB adapter for data files: {data_files}")
 287        else:
 288            logger.info("Creating new DuckDB adapter for in-memory database")
 289        adapter = super().create_engine_adapter(register_comments_override)
 290        for data_file in data_files:
 291            key = data_file if isinstance(data_file, str) else data_file.path
 292            DuckDBConnectionConfig._data_file_to_adapter[key] = adapter
 293        return adapter
 294
 295    def get_catalog(self) -> t.Optional[str]:
 296        if self.database:
  297            # Strip `:` from the database name to handle the case where `:memory:` is passed in
 298            return pathlib.Path(self.database.replace(":", "")).stem
 299        if self.catalogs:
 300            return list(self.catalogs)[0]
 301        return None
 302
 303
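A sketch of a multi-catalog configuration; per the validator above, `database` and `catalogs` are mutually exclusive, and the first catalog entry becomes the default:

    config = DuckDBConnectionConfig(
        catalogs={
            "local": "local.duckdb",
            "ext": DuckDBAttachOptions(type="sqlite", path="data/test.db", read_only=True),
        }
    )
    config.get_catalog()  # -> "local"
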
 304class SnowflakeConnectionConfig(ConnectionConfig):
 305    """Configuration for the Snowflake connection.
 306
 307    Args:
 308        account: The Snowflake account name.
 309        user: The Snowflake username.
 310        password: The Snowflake password.
 311        warehouse: The optional warehouse name.
 312        database: The optional database name.
 313        role: The optional role name.
 314        concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
 315        authenticator: The optional authenticator name. Defaults to username/password authentication ("snowflake").
 316                       Options: https://github.com/snowflakedb/snowflake-connector-python/blob/e937591356c067a77f34a0a42328907fda792c23/src/snowflake/connector/network.py#L178-L183
 317        token: The optional oauth access token to use for authentication when authenticator is set to "oauth".
 318        private_key: The optional private key to use for authentication. Key can be Base64-encoded DER format (representing the key bytes), a plain-text PEM format, or bytes (Python config only). https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#using-key-pair-authentication-key-pair-rotation
 319        private_key_path: The optional path to the private key to use for authentication. This would be used instead of `private_key`.
 320        private_key_passphrase: The optional passphrase to use to decrypt `private_key` or `private_key_path`. Keys can be created without encryption so only provide this if needed.
 321        register_comments: Whether or not to register model comments with the SQL engine.
 322    """
 323
 324    account: str
 325    user: t.Optional[str] = None
 326    password: t.Optional[str] = None
 327    warehouse: t.Optional[str] = None
 328    database: t.Optional[str] = None
 329    role: t.Optional[str] = None
 330    authenticator: t.Optional[str] = None
 331    token: t.Optional[str] = None
 332
 333    # Private Key Auth
 334    private_key: t.Optional[t.Union[str, bytes]] = None
 335    private_key_path: t.Optional[str] = None
 336    private_key_passphrase: t.Optional[str] = None
 337
 338    concurrent_tasks: int = 4
 339    register_comments: bool = True
 340
 341    type_: Literal["snowflake"] = Field(alias="type", default="snowflake")
 342
 343    _concurrent_tasks_validator = concurrent_tasks_validator
 344
 345    @model_validator(mode="before")
 346    @model_validator_v1_args
 347    def _validate_authenticator(
 348        cls, values: t.Dict[str, t.Optional[str]]
 349    ) -> t.Dict[str, t.Optional[str]]:
 350        from snowflake.connector.network import (
 351            DEFAULT_AUTHENTICATOR,
 352            OAUTH_AUTHENTICATOR,
 353        )
 354
 355        auth = values.get("authenticator")
 356        auth = auth.upper() if auth else DEFAULT_AUTHENTICATOR
 357        user = values.get("user")
 358        password = values.get("password")
 359        values["private_key"] = cls._get_private_key(values, auth)  # type: ignore
 360        if (
 361            auth == DEFAULT_AUTHENTICATOR
 362            and not values.get("private_key")
 363            and (not user or not password)
 364        ):
 365            raise ConfigError("User and password must be provided if using default authentication")
 366        if auth == OAUTH_AUTHENTICATOR and not values.get("token"):
 367            raise ConfigError("Token must be provided if using oauth authentication")
 368        return values
 369
 370    @classmethod
 371    def _get_private_key(cls, values: t.Dict[str, t.Optional[str]], auth: str) -> t.Optional[bytes]:
 372        """
 373        source: https://github.com/dbt-labs/dbt-snowflake/blob/0374b4ec948982f2ac8ec0c95d53d672ad19e09c/dbt/adapters/snowflake/connections.py#L247C5-L285C1
 374
  375        Changes from the source: local variables are used instead of class attributes, and validation is added.
 376        """
 377        # Start custom code
 378        from cryptography.hazmat.backends import default_backend
 379        from cryptography.hazmat.primitives import serialization
 380        from snowflake.connector.network import (
 381            DEFAULT_AUTHENTICATOR,
 382            KEY_PAIR_AUTHENTICATOR,
 383        )
 384
 385        private_key = values.get("private_key")
 386        private_key_path = values.get("private_key_path")
 387        private_key_passphrase = values.get("private_key_passphrase")
 388        user = values.get("user")
 389        password = values.get("password")
 390        auth = auth if auth and auth != DEFAULT_AUTHENTICATOR else KEY_PAIR_AUTHENTICATOR
 391
 392        if not private_key and not private_key_path:
 393            return None
 394        if private_key and private_key_path:
 395            raise ConfigError("Cannot specify both `private_key` and `private_key_path`")
 396        if auth != KEY_PAIR_AUTHENTICATOR:
 397            raise ConfigError(
 398                f"Private key or private key path can only be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
 399            )
 400        if not user:
 401            raise ConfigError(
 402                f"User must be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
 403            )
 404        if password:
 405            raise ConfigError(
 406                f"Password cannot be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
 407            )
 408
 409        if isinstance(private_key, bytes):
 410            return private_key
 411        # End Custom Code
 412
 413        if private_key_passphrase:
 414            encoded_passphrase = private_key_passphrase.encode()
 415        else:
 416            encoded_passphrase = None
 417
 418        if private_key:
 419            if private_key.startswith("-"):
 420                p_key = serialization.load_pem_private_key(
 421                    data=bytes(private_key, "utf-8"),
 422                    password=encoded_passphrase,
 423                    backend=default_backend(),
 424                )
 425
 426            else:
 427                p_key = serialization.load_der_private_key(
 428                    data=base64.b64decode(private_key),
 429                    password=encoded_passphrase,
 430                    backend=default_backend(),
 431                )
 432
 433        elif private_key_path:
 434            with open(private_key_path, "rb") as key:
 435                p_key = serialization.load_pem_private_key(
 436                    key.read(), password=encoded_passphrase, backend=default_backend()
 437                )
 438        else:
 439            return None
 440
 441        return p_key.private_bytes(
 442            encoding=serialization.Encoding.DER,
 443            format=serialization.PrivateFormat.PKCS8,
 444            encryption_algorithm=serialization.NoEncryption(),
 445        )
 446
 447    @property
 448    def _connection_kwargs_keys(self) -> t.Set[str]:
 449        return {
 450            "user",
 451            "password",
 452            "account",
 453            "warehouse",
 454            "database",
 455            "role",
 456            "authenticator",
 457            "token",
 458            "private_key",
 459        }
 460
 461    @property
 462    def _engine_adapter(self) -> t.Type[EngineAdapter]:
 463        return engine_adapter.SnowflakeEngineAdapter
 464
 465    @property
 466    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
 467        return {"autocommit": False}
 468
 469    @property
 470    def _connection_factory(self) -> t.Callable:
 471        from snowflake import connector
 472
 473        return connector.connect
 474
 475
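A sketch of key-pair authentication (account, user, and paths are placeholders; assumes `snowflake-connector-python` and `cryptography` are installed). Note that the validator above resolves and reads the key at construction time, so the key file must exist:

    config = SnowflakeConnectionConfig(
        account="my_org-my_account",
        user="my_user",
        private_key_path="/path/to/rsa_key.p8",
        private_key_passphrase="<passphrase>",
        warehouse="compute_wh",
    )
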
 476class DatabricksConnectionConfig(ConnectionConfig):
 477    """
 478    Databricks connection that uses the SQL connector for SQL models and then Databricks Connect for Dataframe operations
 479
 480    Arg Source: https://github.com/databricks/databricks-sql-python/blob/main/src/databricks/sql/client.py#L39
 481    Args:
 482        server_hostname: Databricks instance host name.
 483        http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef)
 484            or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
 485        access_token: Http Bearer access token, e.g. Databricks Personal Access Token.
 486        catalog: Default catalog to use for SQL models. Defaults to None which means it will use the default set in
 487            the Databricks cluster (most likely `hive_metastore`).
 488        http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request
 489        session_configuration: An optional dictionary of Spark session parameters.
 490            Execute the SQL command `SET -v` to get a full list of available commands.
  491        databricks_connect_server_hostname: The hostname to use when establishing a connection using Databricks Connect.
  492            Defaults to the `server_hostname` value.
  493        databricks_connect_access_token: The access token to use when establishing a connection using Databricks Connect.
  494            Defaults to the `access_token` value.
  495        databricks_connect_cluster_id: The cluster id to use when establishing a connection using Databricks Connect.
  496            Defaults to deriving the cluster id from the `http_path` value.
 497        force_databricks_connect: Force all queries to run using Databricks Connect instead of the SQL connector.
 498        disable_databricks_connect: Even if databricks connect is installed, do not use it.
 499    """
 500
 501    server_hostname: t.Optional[str] = None
 502    http_path: t.Optional[str] = None
 503    access_token: t.Optional[str] = None
 504    catalog: t.Optional[str] = None
 505    http_headers: t.Optional[t.List[t.Tuple[str, str]]] = None
 506    session_configuration: t.Optional[t.Dict[str, t.Any]] = None
 507    databricks_connect_server_hostname: t.Optional[str] = None
 508    databricks_connect_access_token: t.Optional[str] = None
 509    databricks_connect_cluster_id: t.Optional[str] = None
 510    force_databricks_connect: bool = False
 511    disable_databricks_connect: bool = False
 512
 513    concurrent_tasks: int = 1
 514    register_comments: bool = True
 515
 516    type_: Literal["databricks"] = Field(alias="type", default="databricks")
 517
 518    _concurrent_tasks_validator = concurrent_tasks_validator
 519    _http_headers_validator = http_headers_validator
 520
 521    @model_validator(mode="before")
 522    @model_validator_v1_args
 523    def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
 524        from sqlmesh import RuntimeEnv
 525        from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter
 526
 527        runtime_env = RuntimeEnv.get()
 528
 529        if runtime_env.is_databricks:
 530            return values
 531        server_hostname, http_path, access_token = (
 532            values.get("server_hostname"),
 533            values.get("http_path"),
 534            values.get("access_token"),
 535        )
 536        if not server_hostname or not http_path or not access_token:
 537            raise ValueError(
 538                "`server_hostname`, `http_path`, and `access_token` are required for Databricks connections when not running in a notebook"
 539            )
 540        if DatabricksEngineAdapter.can_access_spark_session:
 541            if not values.get("databricks_connect_server_hostname"):
 542                values["databricks_connect_server_hostname"] = f"https://{server_hostname}"
 543            if not values.get("databricks_connect_access_token"):
 544                values["databricks_connect_access_token"] = access_token
 545            if not values.get("databricks_connect_cluster_id"):
 546                values["databricks_connect_cluster_id"] = http_path.split("/")[-1]
 547        if not values.get("session_configuration"):
 548            values["session_configuration"] = {}
 549        values["session_configuration"]["spark.sql.sources.partitionOverwriteMode"] = "dynamic"
 550        return values
 551
 552    @property
 553    def _connection_kwargs_keys(self) -> t.Set[str]:
 554        if self.use_spark_session_only:
 555            return set()
 556        return {
 557            "server_hostname",
 558            "http_path",
 559            "access_token",
 560            "http_headers",
 561            "session_configuration",
 562            "catalog",
 563        }
 564
 565    @property
 566    def _engine_adapter(self) -> t.Type[engine_adapter.DatabricksEngineAdapter]:
 567        return engine_adapter.DatabricksEngineAdapter
 568
 569    @property
 570    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
 571        return {
 572            k: v
 573            for k, v in self.dict().items()
 574            if k.startswith("databricks_connect_") or k in ("catalog", "disable_databricks_connect")
 575        }
 576
 577    @property
 578    def use_spark_session_only(self) -> bool:
 579        from sqlmesh import RuntimeEnv
 580
 581        return RuntimeEnv.get().is_databricks or self.force_databricks_connect
 582
 583    @property
 584    def _connection_factory(self) -> t.Callable:
 585        if self.use_spark_session_only:
 586            from sqlmesh.engines.spark.db_api.spark_session import connection
 587
 588            return connection
 589
 590        from databricks import sql
 591
 592        return sql.connect
 593
 594    @property
 595    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
 596        from sqlmesh import RuntimeEnv
 597
 598        if not self.use_spark_session_only:
 599            return {}
 600
 601        if RuntimeEnv.get().is_databricks:
 602            from pyspark.sql import SparkSession
 603
 604            return dict(
 605                spark=SparkSession.getActiveSession(),
 606                catalog=self.catalog,
 607            )
 608
 609        from databricks.connect import DatabricksSession
 610
 611        return dict(
 612            spark=DatabricksSession.builder.remote(
 613                host=self.databricks_connect_server_hostname,
 614                token=self.databricks_connect_access_token,
 615                cluster_id=self.databricks_connect_cluster_id,
 616            ).getOrCreate(),
 617            catalog=self.catalog,
 618        )
 619
 620
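A sketch of a typical configuration when running outside of a Databricks notebook (hostname, path, and token are placeholders; assumes `databricks-sql-connector` is installed); the validator above derives the Databricks Connect settings from these values:

    config = DatabricksConnectionConfig(
        server_hostname="dbc-1234abcd-5678.cloud.databricks.com",
        http_path="/sql/1.0/warehouses/1234567890abcdef",
        access_token="<token>",
        catalog="main",
    )
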
 621class BigQueryConnectionMethod(str, Enum):
 622    OAUTH = "oauth"
 623    OAUTH_SECRETS = "oauth-secrets"
 624    SERVICE_ACCOUNT = "service-account"
 625    SERVICE_ACCOUNT_JSON = "service-account-json"
 626
 627
 628class BigQueryPriority(str, Enum):
 629    BATCH = "batch"
 630    INTERACTIVE = "interactive"
 631
 632    @property
 633    def is_batch(self) -> bool:
 634        return self == self.BATCH
 635
 636    @property
 637    def is_interactive(self) -> bool:
 638        return self == self.INTERACTIVE
 639
 640    @property
 641    def bigquery_constant(self) -> str:
 642        from google.cloud.bigquery import QueryPriority
 643
 644        if self.is_batch:
 645            return QueryPriority.BATCH
 646        return QueryPriority.INTERACTIVE
 647
 648
 649class BigQueryConnectionConfig(ConnectionConfig):
 650    """
 651    BigQuery Connection Configuration.
 652    """
 653
 654    method: BigQueryConnectionMethod = BigQueryConnectionMethod.OAUTH
 655
 656    project: t.Optional[str] = None
 657    execution_project: t.Optional[str] = None
 658    location: t.Optional[str] = None
 659    # Keyfile Auth
 660    keyfile: t.Optional[str] = None
 661    keyfile_json: t.Optional[t.Dict[str, t.Any]] = None
  662    # OAuth Secrets Auth
 663    token: t.Optional[str] = None
 664    refresh_token: t.Optional[str] = None
 665    client_id: t.Optional[str] = None
 666    client_secret: t.Optional[str] = None
 667    token_uri: t.Optional[str] = None
 668    scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/bigquery",)
 669    job_creation_timeout_seconds: t.Optional[int] = None
 670    # Extra Engine Config
 671    job_execution_timeout_seconds: t.Optional[int] = None
 672    job_retries: t.Optional[int] = 1
 673    job_retry_deadline_seconds: t.Optional[int] = None
 674    priority: t.Optional[BigQueryPriority] = None
 675    maximum_bytes_billed: t.Optional[int] = None
 676
 677    concurrent_tasks: int = 1
 678    register_comments: bool = True
 679
 680    type_: Literal["bigquery"] = Field(alias="type", default="bigquery")
 681
 682    @property
 683    def _connection_kwargs_keys(self) -> t.Set[str]:
 684        return set()
 685
 686    @property
 687    def _engine_adapter(self) -> t.Type[EngineAdapter]:
 688        return engine_adapter.BigQueryEngineAdapter
 689
 690    @property
 691    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
 692        """The static connection kwargs for this connection"""
  693        import google.auth, google.cloud.bigquery  # bigquery must be imported for the Client call below
 694        from google.api_core import client_info
 695        from google.oauth2 import credentials, service_account
 696
 697        if self.method == BigQueryConnectionMethod.OAUTH:
 698            creds, _ = google.auth.default(scopes=self.scopes)
 699        elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT:
 700            creds = service_account.Credentials.from_service_account_file(
 701                self.keyfile, scopes=self.scopes
 702            )
 703        elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT_JSON:
 704            creds = service_account.Credentials.from_service_account_info(
 705                self.keyfile_json, scopes=self.scopes
 706            )
 707        elif self.method == BigQueryConnectionMethod.OAUTH_SECRETS:
 708            creds = credentials.Credentials(
 709                token=self.token,
 710                refresh_token=self.refresh_token,
 711                client_id=self.client_id,
 712                client_secret=self.client_secret,
 713                token_uri=self.token_uri,
 714                scopes=self.scopes,
 715            )
 716        else:
 717            raise ConfigError("Invalid BigQuery Connection Method")
 718        client = google.cloud.bigquery.Client(
 719            project=self.execution_project or self.project,
 720            credentials=creds,
 721            location=self.location,
 722            client_info=client_info.ClientInfo(user_agent="sqlmesh"),
 723        )
 724
 725        return {
 726            "client": client,
 727        }
 728
 729    @property
 730    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
 731        return {
 732            k: v
 733            for k, v in self.dict().items()
 734            if k
 735            in {
 736                "job_creation_timeout_seconds",
 737                "job_execution_timeout_seconds",
 738                "job_retries",
 739                "job_retry_deadline_seconds",
 740                "priority",
 741                "maximum_bytes_billed",
 742            }
 743        }
 744
 745    @property
 746    def _connection_factory(self) -> t.Callable:
 747        from google.cloud.bigquery.dbapi import connect
 748
 749        return connect
 750
 751    def get_catalog(self) -> t.Optional[str]:
 752        return self.project
 753
 754
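A sketch of keyfile-based authentication (project and path are placeholders); the credentials are only loaded when the connection is actually created:

    config = BigQueryConnectionConfig(
        method=BigQueryConnectionMethod.SERVICE_ACCOUNT,
        keyfile="/path/to/keyfile.json",
        project="my-project",
        location="US",
    )
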
 755class GCPPostgresConnectionConfig(ConnectionConfig):
 756    """
 757    Postgres Connection Configuration for GCP.
 758
 759    Args:
 760        instance_connection_string: Connection name for the postgres instance.
 761        user: Postgres or IAM user's name
 762        password: The postgres user's password. Only needed when the user is a postgres user.
 763        enable_iam_auth: Set to True when user is an IAM user.
 764        db: Name of the db to connect to.
 765    """
 766
 767    instance_connection_string: str
 768    user: str
 769    password: t.Optional[str] = None
 770    enable_iam_auth: t.Optional[bool] = None
 771    db: str
 772    timeout: t.Optional[int] = None
 773
 774    driver: str = "pg8000"
 775    type_: Literal["gcp_postgres"] = Field(alias="type", default="gcp_postgres")
 776    concurrent_tasks: int = 4
 777    register_comments: bool = True
 778
 779    @model_validator(mode="before")
 780    @model_validator_v1_args
 781    def _validate_auth_method(
 782        cls, values: t.Dict[str, t.Optional[str]]
 783    ) -> t.Dict[str, t.Optional[str]]:
 784        password = values.get("password")
 785        enable_iam_auth = values.get("enable_iam_auth")
 786        if password and enable_iam_auth:
 787            raise ConfigError(
 788                "Invalid GCP Postgres connection configuration - both password and"
 789                " enable_iam_auth set. Use password when connecting to a postgres"
 790                " user and enable_iam_auth 'True' when connecting to an IAM user."
 791            )
 792        if not password and not enable_iam_auth:
 793            raise ConfigError(
 794                "GCP Postgres connection configuration requires either password set"
 795                " for a postgres user account or enable_iam_auth set to 'True'"
 796                " for an IAM user account."
 797            )
 798        return values
 799
 800    @property
 801    def _connection_kwargs_keys(self) -> t.Set[str]:
 802        return {
 803            "instance_connection_string",
 804            "driver",
 805            "user",
 806            "password",
 807            "db",
 808            "enable_iam_auth",
 809            "timeout",
 810        }
 811
 812    @property
 813    def _engine_adapter(self) -> t.Type[EngineAdapter]:
 814        return engine_adapter.PostgresEngineAdapter
 815
 816    @property
 817    def _connection_factory(self) -> t.Callable:
 818        from google.cloud.sql.connector import Connector
 819
 820        return Connector().connect
 821
 822
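A sketch of an IAM-user configuration (names are placeholders); per the validator above, `password` and `enable_iam_auth` are mutually exclusive:

    config = GCPPostgresConnectionConfig(
        instance_connection_string="my-project:us-central1:my-instance",
        user="service-account@my-project.iam",
        enable_iam_auth=True,
        db="analytics",
    )
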
 823class RedshiftConnectionConfig(ConnectionConfig):
 824    """
 825    Redshift Connection Configuration.
 826
 827    Arg Source: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/__init__.py#L146
 828    Note: A subset of properties were selected. Please open an issue/PR if you want to see more supported.
 829
 830    Args:
 831        user: The username to use for authentication with the Amazon Redshift cluster.
 832        password: The password to use for authentication with the Amazon Redshift cluster.
 833        database: The name of the database instance to connect to.
 834        host: The hostname of the Amazon Redshift cluster.
 835        port: The port number of the Amazon Redshift cluster. Default value is 5439.
 836        source_address: No description provided
 837        unix_sock: No description provided
  838        ssl: Whether SSL is enabled. Default value is ``True``. SSL must be enabled when authenticating using IAM.
  839        sslmode: The security of the connection to the Amazon Redshift cluster. 'verify-ca' and 'verify-full' are supported.
  840        timeout: The number of seconds before the connection to the server will time out. By default there is no timeout.
  841        tcp_keepalive: Whether `TCP keepalive <https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive>`_ is used. The default value is ``True``.
 842        application_name: Sets the application name. The default value is None.
 843        preferred_role: The IAM role preferred for the current connection.
 844        principal_arn: The ARN of the IAM entity (user or role) for which you are generating a policy.
 845        credentials_provider: The class name of the IdP that will be used for authenticating with the Amazon Redshift cluster.
 846        region: The AWS region where the Amazon Redshift cluster is located.
 847        cluster_identifier: The cluster identifier of the Amazon Redshift cluster.
 848        iam: If IAM authentication is enabled. Default value is False. IAM must be True when authenticating using an IdP.
  849        is_serverless: Whether the Redshift endpoint is serverless or provisioned. Default value is False.
  850        serverless_acct_id: The account ID of the serverless endpoint. Default value is None.
  851        serverless_work_group: The name of the work group for the serverless endpoint. Default value is None.
 852    """
 853
 854    user: t.Optional[str] = None
 855    password: t.Optional[str] = None
 856    database: t.Optional[str] = None
 857    host: t.Optional[str] = None
 858    port: t.Optional[int] = None
 859    source_address: t.Optional[str] = None
 860    unix_sock: t.Optional[str] = None
 861    ssl: t.Optional[bool] = None
 862    sslmode: t.Optional[str] = None
 863    timeout: t.Optional[int] = None
 864    tcp_keepalive: t.Optional[bool] = None
 865    application_name: t.Optional[str] = None
 866    preferred_role: t.Optional[str] = None
 867    principal_arn: t.Optional[str] = None
 868    credentials_provider: t.Optional[str] = None
 869    region: t.Optional[str] = None
 870    cluster_identifier: t.Optional[str] = None
 871    iam: t.Optional[bool] = None
 872    is_serverless: t.Optional[bool] = None
 873    serverless_acct_id: t.Optional[str] = None
 874    serverless_work_group: t.Optional[str] = None
 875
 876    concurrent_tasks: int = 4
 877    register_comments: bool = True
 878
 879    type_: Literal["redshift"] = Field(alias="type", default="redshift")
 880
 881    @property
 882    def _connection_kwargs_keys(self) -> t.Set[str]:
 883        return {
 884            "user",
 885            "password",
 886            "database",
 887            "host",
 888            "port",
 889            "source_address",
 890            "unix_sock",
 891            "ssl",
 892            "sslmode",
 893            "timeout",
 894            "tcp_keepalive",
 895            "application_name",
 896            "preferred_role",
 897            "principal_arn",
 898            "credentials_provider",
 899            "region",
 900            "cluster_identifier",
 901            "iam",
 902            "is_serverless",
 903            "serverless_acct_id",
 904            "serverless_work_group",
 905        }
 906
 907    @property
 908    def _engine_adapter(self) -> t.Type[EngineAdapter]:
 909        return engine_adapter.RedshiftEngineAdapter
 910
 911    @property
 912    def _connection_factory(self) -> t.Callable:
 913        from redshift_connector import connect
 914
 915        return connect
 916
 917
 918class PostgresConnectionConfig(ConnectionConfig):
 919    host: str
 920    user: str
 921    password: str
 922    port: int
 923    database: str
 924    keepalives_idle: t.Optional[int] = None
 925    connect_timeout: int = 10
 926    role: t.Optional[str] = None
 927    sslmode: t.Optional[str] = None
 928
 929    concurrent_tasks: int = 4
 930    register_comments: bool = True
 931
 932    type_: Literal["postgres"] = Field(alias="type", default="postgres")
 933
 934    @property
 935    def _connection_kwargs_keys(self) -> t.Set[str]:
 936        return {
 937            "host",
 938            "user",
 939            "password",
 940            "port",
 941            "database",
 942            "keepalives_idle",
 943            "connect_timeout",
 944            "role",
 945            "sslmode",
 946        }
 947
 948    @property
 949    def _engine_adapter(self) -> t.Type[EngineAdapter]:
 950        return engine_adapter.PostgresEngineAdapter
 951
 952    @property
 953    def _connection_factory(self) -> t.Callable:
 954        from psycopg2 import connect
 955
 956        return connect
 957
 958
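A sketch of a Postgres connection used for state sync (host and credentials are placeholders); note that `postgres` is one of the engines in `RECOMMENDED_STATE_SYNC_ENGINES`:

    state_connection = PostgresConnectionConfig(
        host="localhost",
        port=5432,
        user="sqlmesh",
        password="<password>",
        database="sqlmesh_state",
    )
    state_connection.is_recommended_for_state_sync  # -> True
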
 959class MySQLConnectionConfig(ConnectionConfig):
 960    host: str
 961    user: str
 962    password: str
 963    port: t.Optional[int] = None
 964    charset: t.Optional[str] = None
 965    ssl_disabled: t.Optional[bool] = None
 966
 967    concurrent_tasks: int = 4
 968    register_comments: bool = True
 969
 970    type_: Literal["mysql"] = Field(alias="type", default="mysql")
 971
 972    @property
 973    def _cursor_kwargs(self) -> t.Optional[t.Dict[str, t.Any]]:
 974        """Key-value arguments that will be passed during cursor construction."""
 975        return {"buffered": True}
 976
 977    @property
 978    def _connection_kwargs_keys(self) -> t.Set[str]:
  979        # `port`, `charset`, and `ssl_disabled` are only added when they are set,
  980        # so that None values are not passed through to the connector.
  981        connection_keys = {
  982            "host",
  983            "user",
  984            "password",
  985        }
 986        if self.port is not None:
 987            connection_keys.add("port")
 988        if self.charset is not None:
 989            connection_keys.add("charset")
 990        if self.ssl_disabled is not None:
 991            connection_keys.add("ssl_disabled")
 992        return connection_keys
 993
 994    @property
 995    def _engine_adapter(self) -> t.Type[EngineAdapter]:
 996        return engine_adapter.MySQLEngineAdapter
 997
 998    @property
 999    def _connection_factory(self) -> t.Callable:
1000        from mysql.connector import connect
1001
1002        return connect
1003
1004
1005class MSSQLConnectionConfig(ConnectionConfig):
1006    host: str
1007    user: t.Optional[str] = None
1008    password: t.Optional[str] = None
1009    database: t.Optional[str] = ""
1010    timeout: t.Optional[int] = 0
1011    login_timeout: t.Optional[int] = 60
1012    charset: t.Optional[str] = "UTF-8"
1013    appname: t.Optional[str] = None
1014    port: t.Optional[int] = 1433
1015    conn_properties: t.Optional[t.Union[t.Iterable[str], str]] = None
1016    autocommit: t.Optional[bool] = False
1017    tds_version: t.Optional[str] = None
1018
1019    concurrent_tasks: int = 4
1020    register_comments: bool = True
1021
1022    type_: Literal["mssql"] = Field(alias="type", default="mssql")
1023
1024    @property
1025    def _connection_kwargs_keys(self) -> t.Set[str]:
1026        return {
1027            "host",
1028            "user",
1029            "password",
1030            "database",
1031            "timeout",
1032            "login_timeout",
1033            "charset",
1034            "appname",
1035            "port",
1036            "conn_properties",
1037            "autocommit",
1038            "tds_version",
1039        }
1040
1041    @property
1042    def _engine_adapter(self) -> t.Type[EngineAdapter]:
1043        return engine_adapter.MSSQLEngineAdapter
1044
1045    @property
1046    def _connection_factory(self) -> t.Callable:
1047        import pymssql
1048
1049        return pymssql.connect
1050
1051
1052class SparkConnectionConfig(ConnectionConfig):
1053    """
1054    Vanilla Spark Connection Configuration. Use `DatabricksConnectionConfig` for Databricks.
1055    """
1056
1057    config_dir: t.Optional[str] = None
1058    catalog: t.Optional[str] = None
1059    config: t.Dict[str, t.Any] = {}
1060
1061    concurrent_tasks: int = 4
1062    register_comments: bool = True
1063
1064    type_: Literal["spark"] = Field(alias="type", default="spark")
1065
1066    @property
1067    def _connection_kwargs_keys(self) -> t.Set[str]:
1068        return {
1069            "catalog",
1070        }
1071
1072    @property
1073    def _engine_adapter(self) -> t.Type[EngineAdapter]:
1074        return engine_adapter.SparkEngineAdapter
1075
1076    @property
1077    def _connection_factory(self) -> t.Callable:
1078        from sqlmesh.engines.spark.db_api.spark_session import connection
1079
1080        return connection
1081
1082    @property
1083    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
1084        from pyspark.conf import SparkConf
1085        from pyspark.sql import SparkSession
1086
1087        spark_config = SparkConf()
1088        if self.config:
1089            for k, v in self.config.items():
1090                spark_config.set(k, v)
1091
1092        if self.config_dir:
1093            os.environ["SPARK_CONF_DIR"] = self.config_dir
1094        return {
1095            "spark": SparkSession.builder.config(conf=spark_config)
1096            .enableHiveSupport()
1097            .getOrCreate(),
1098        }
1099
1100
1101class TrinoAuthenticationMethod(str, Enum):
1102    NO_AUTH = "no-auth"
1103    BASIC = "basic"
1104    LDAP = "ldap"
1105    KERBEROS = "kerberos"
1106    JWT = "jwt"
1107    CERTIFICATE = "certificate"
1108    OAUTH = "oauth"
1109
1110    @property
1111    def is_no_auth(self) -> bool:
1112        return self == self.NO_AUTH
1113
1114    @property
1115    def is_basic(self) -> bool:
1116        return self == self.BASIC
1117
1118    @property
1119    def is_ldap(self) -> bool:
1120        return self == self.LDAP
1121
1122    @property
1123    def is_kerberos(self) -> bool:
1124        return self == self.KERBEROS
1125
1126    @property
1127    def is_jwt(self) -> bool:
1128        return self == self.JWT
1129
1130    @property
1131    def is_certificate(self) -> bool:
1132        return self == self.CERTIFICATE
1133
1134    @property
1135    def is_oauth(self) -> bool:
1136        return self == self.OAUTH
1137
1138
1139class TrinoConnectionConfig(ConnectionConfig):
1140    method: TrinoAuthenticationMethod = TrinoAuthenticationMethod.NO_AUTH
1141    host: str
1142    user: str
1143    catalog: str
1144    port: t.Optional[int] = None
1145    http_scheme: Literal["http", "https"] = "https"
1146    # General Optional
1147    roles: t.Optional[t.Dict[str, str]] = None
1148    http_headers: t.Optional[t.Dict[str, str]] = None
1149    session_properties: t.Optional[t.Dict[str, str]] = None
1150    retries: int = 3
1151    timezone: t.Optional[str] = None
1152    # Basic/LDAP
1153    password: t.Optional[str] = None
1154    # LDAP
1155    impersonation_user: t.Optional[str] = None
1156    # Kerberos
1157    keytab: t.Optional[str] = None
1158    krb5_config: t.Optional[str] = None
1159    principal: t.Optional[str] = None
1160    service_name: str = "trino"
1161    hostname_override: t.Optional[str] = None
1162    mutual_authentication: bool = False
1163    force_preemptive: bool = False
1164    sanitize_mutual_error_response: bool = True
1165    delegate: bool = False
1166    # JWT
1167    jwt_token: t.Optional[str] = None
1168    # Certificate
1169    client_certificate: t.Optional[str] = None
1170    client_private_key: t.Optional[str] = None
1171    cert: t.Optional[str] = None
1172
1173    concurrent_tasks: int = 4
1174    register_comments: bool = True
1175
1176    type_: Literal["trino"] = Field(alias="type", default="trino")
1177
1178    @model_validator(mode="after")
1179    @model_validator_v1_args
1180    def _root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
1181        port = values.get("port")
1182        if (
1183            values["http_scheme"] == "http"
1184            and not values["method"].is_no_auth
1185            and not values["method"].is_basic
1186        ):
1187            raise ConfigError("HTTP scheme can only be used with no-auth or basic method")
1188        if port is None:
1189            values["port"] = 80 if values["http_scheme"] == "http" else 443
1190        if (values["method"].is_ldap or values["method"].is_basic) and (
1191            not values["password"] or not values["user"]
1192        ):
1193            raise ConfigError(
1194                f"Username and Password must be provided if using {values['method'].value} authentication"
1195            )
1196        if values["method"].is_kerberos and (
1197            not values["principal"] or not values["keytab"] or not values["krb5_config"]
1198        ):
1199            raise ConfigError(
1200                "Kerberos requires the following fields: principal, keytab, and krb5_config"
1201            )
1202        if values["method"].is_jwt and not values["jwt_token"]:
1203            raise ConfigError("JWT requires `jwt_token` to be set")
1204        if values["method"].is_certificate and (
1205            not values["cert"]
1206            or not values["client_certificate"]
1207            or not values["client_private_key"]
1208        ):
1209            raise ConfigError(
1210                "Certificate requires the following fields: cert, client_certificate, and client_private_key"
1211            )
1212        return values
1213
1214    @property
1215    def _connection_kwargs_keys(self) -> t.Set[str]:
1216        kwargs = {
1217            "host",
1218            "port",
1219            "catalog",
1220            "roles",
1221            "http_scheme",
1222            "http_headers",
1223            "session_properties",
1224            "timezone",
1225        }
1226        return kwargs
1227
1228    @property
1229    def _engine_adapter(self) -> t.Type[EngineAdapter]:
1230        return engine_adapter.TrinoEngineAdapter
1231
1232    @property
1233    def _connection_factory(self) -> t.Callable:
1234        from trino.dbapi import connect
1235
1236        return connect
1237
1238    @property
1239    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
1240        from trino.auth import (
1241            BasicAuthentication,
1242            CertificateAuthentication,
1243            JWTAuthentication,
1244            KerberosAuthentication,
1245            OAuth2Authentication,
1246        )
1247
1248        if self.method.is_basic or self.method.is_ldap:
1249            auth = BasicAuthentication(self.user, self.password)
1250        elif self.method.is_kerberos:
1251            if self.keytab:
1252                os.environ["KRB5_CLIENT_KTNAME"] = self.keytab
1253            auth = KerberosAuthentication(
1254                config=self.krb5_config,
1255                service_name=self.service_name,
1256                principal=self.principal,
1257                mutual_authentication=self.mutual_authentication,
1258                ca_bundle=self.cert,
1259                force_preemptive=self.force_preemptive,
1260                hostname_override=self.hostname_override,
1261                sanitize_mutual_error_response=self.sanitize_mutual_error_response,
1262                delegate=self.delegate,
1263            )
1264        elif self.method.is_oauth:
1265            auth = OAuth2Authentication()
1266        elif self.method.is_jwt:
1267            auth = JWTAuthentication(self.jwt_token)
1268        elif self.method.is_certificate:
1269            auth = CertificateAuthentication(self.client_certificate, self.client_private_key)
1270        else:
1271            auth = None
1272
1273        return {
1274            "auth": auth,
1275            "user": self.impersonation_user or self.user,
1276            "max_attempts": self.retries,
1277            "verify": self.cert,
1278            "source": "sqlmesh",
1279        }
1280
1281
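A sketch of basic authentication (host and credentials are placeholders); the validator above fills in port 443 for the default `https` scheme:

    config = TrinoConnectionConfig(
        method=TrinoAuthenticationMethod.BASIC,
        host="trino.example.com",
        user="sqlmesh",
        password="<password>",
        catalog="hive",
    )
    config.port  # -> 443
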
1282CONNECTION_CONFIG_TO_TYPE = {
1283    # Map all subclasses of ConnectionConfig to the value of their `type_` field.
1284    tpe.all_field_infos()["type_"].default: tpe
1285    for tpe in subclasses(
1286        __name__, ConnectionConfig, exclude=(ConnectionConfig, BaseDuckDBConnectionConfig)
1287    )
1288}
1289
1290
1291def parse_connection_config(v: t.Dict[str, t.Any]) -> ConnectionConfig:
1292    if "type" not in v:
1293        raise ConfigError("Missing connection type.")
1294
1295    connection_type = v["type"]
1296    if connection_type not in CONNECTION_CONFIG_TO_TYPE:
1297        raise ConfigError(f"Unknown connection type '{connection_type}'.")
1298
1299    return CONNECTION_CONFIG_TO_TYPE[connection_type](**v)
1300
1301
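For example, a raw dictionary (as parsed from a YAML config) resolves to the matching config class by its `type` key:

    config = parse_connection_config({"type": "duckdb", "database": "local.duckdb"})
    isinstance(config, DuckDBConnectionConfig)  # -> True
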
1302def _connection_config_validator(
1303    cls: t.Type, v: ConnectionConfig | t.Dict[str, t.Any] | None
1304) -> ConnectionConfig | None:
1305    if v is None or isinstance(v, ConnectionConfig):
1306        return v
1307    return parse_connection_config(v)
1308
1309
1310connection_config_validator = field_validator(
1311    "connection",
1312    "state_connection",
1313    "test_connection",
1314    "default_connection",
1315    "default_test_connection",
1316    mode="before",
1317    check_fields=False,
1318)(_connection_config_validator)
1319
1320
1321if t.TYPE_CHECKING:
1322    # TypeAlias wasn't introduced until Python 3.10, which means that we can't use it
1323    # outside the TYPE_CHECKING guard.
1324    SerializableConnectionConfig: t.TypeAlias = ConnectionConfig  # type: ignore
1325elif PYDANTIC_MAJOR_VERSION >= 2:
1326    import pydantic
1327
1328    # Workaround for https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing
1329    SerializableConnectionConfig = pydantic.SerializeAsAny[ConnectionConfig]  # type: ignore
1330else:
1331    SerializableConnectionConfig = ConnectionConfig  # type: ignore
class ConnectionConfig(abc.ABC, sqlmesh.core.config.base.BaseConfig):
 44class ConnectionConfig(abc.ABC, BaseConfig):
 45    type_: str
 46    concurrent_tasks: int
 47    register_comments: bool
 48
 49    @property
 50    @abc.abstractmethod
 51    def _connection_kwargs_keys(self) -> t.Set[str]:
 52        """keywords that should be passed into the connection"""
 53
 54    @property
 55    @abc.abstractmethod
 56    def _engine_adapter(self) -> t.Type[EngineAdapter]:
 57        """The engine adapter for this connection"""
 58
 59    @property
 60    @abc.abstractmethod
 61    def _connection_factory(self) -> t.Callable:
 62        """A function that is called to return a connection object for the given Engine Adapter"""
 63
 64    @property
 65    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
 66        """The static connection kwargs for this connection"""
 67        return {}
 68
 69    @property
 70    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
 71        """kwargs that are for execution config only"""
 72        return {}
 73
 74    @property
 75    def _cursor_kwargs(self) -> t.Optional[t.Dict[str, t.Any]]:
 76        """Key-value arguments that will be passed during cursor construction."""
 77        return None
 78
 79    @property
 80    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
 81        """A function that is called to initialize the cursor"""
 82        return None
 83
 84    @property
 85    def is_recommended_for_state_sync(self) -> bool:
 86        """Whether this connection is recommended for being used as a state sync for production state syncs"""
 87        return self.type_ in RECOMMENDED_STATE_SYNC_ENGINES
 88
 89    @property
 90    def _connection_factory_with_kwargs(self) -> t.Callable[[], t.Any]:
 91        """A function that is called to return a connection object for the given Engine Adapter"""
 92        return partial(
 93            self._connection_factory,
 94            **{
 95                **self._static_connection_kwargs,
 96                **{k: v for k, v in self.dict().items() if k in self._connection_kwargs_keys},
 97            },
 98        )
 99
100    def create_engine_adapter(self, register_comments_override: bool = False) -> EngineAdapter:
101        """Returns a new instance of the Engine Adapter."""
102        return self._engine_adapter(
103            self._connection_factory_with_kwargs,
104            multithreaded=self.concurrent_tasks > 1,
105            cursor_kwargs=self._cursor_kwargs,
106            default_catalog=self.get_catalog(),
107            cursor_init=self._cursor_init,
108            register_comments=register_comments_override or self.register_comments,
109            **self._extra_engine_config,
110        )
111
112    def get_catalog(self) -> t.Optional[str]:
113        """The catalog for this connection"""
114        if hasattr(self, "catalog"):
115            return self.catalog
116        if hasattr(self, "database"):
117            return self.database
118        if hasattr(self, "db"):
119            return self.db
120        return None

Helper class that provides a standard way to create an ABC using inheritance.

def create_engine_adapter( self, register_comments_override: bool = False) -> sqlmesh.core.engine_adapter.base.EngineAdapter:

Returns a new instance of the Engine Adapter.

def get_catalog(self) -> Union[str, NoneType]:

The catalog for this connection
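
A concrete subclass only needs to supply `_connection_kwargs_keys`, `_engine_adapter`, and `_connection_factory`; `create_engine_adapter` and `get_catalog` then work unchanged. A minimal sketch, assuming a hypothetical `MyConnectionConfig` backed by the standard-library `sqlite3` driver (illustrative only, not part of sqlmesh):

    import typing as t

    from sqlmesh.core import engine_adapter
    from sqlmesh.core.config.connection import ConnectionConfig
    from sqlmesh.core.engine_adapter import EngineAdapter


    class MyConnectionConfig(ConnectionConfig):
        """Hypothetical config illustrating the abstract surface."""

        database: str
        type_: str = "my_engine"  # illustrative discriminator value
        concurrent_tasks: int = 1
        register_comments: bool = False

        @property
        def _connection_kwargs_keys(self) -> t.Set[str]:
            # Only `database` from self.dict() is forwarded to the factory.
            return {"database"}

        @property
        def _engine_adapter(self) -> t.Type[EngineAdapter]:
            # Real configs return a dialect-specific subclass.
            return engine_adapter.EngineAdapter

        @property
        def _connection_factory(self) -> t.Callable:
            import sqlite3

            return sqlite3.connect  # any DB-API `connect` works here


    config = MyConnectionConfig(database="example.db")
    # With no `catalog` attribute, get_catalog() falls back to `database`.
    assert config.get_catalog() == "example.db"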

class BaseDuckDBConnectionConfig(ConnectionConfig):
123class BaseDuckDBConnectionConfig(ConnectionConfig):
124    """Common configuration for the DuckDB-based connections.
125
126    Args:
127        extensions: A list of autoloadable extensions to load.
128        connector_config: A dictionary of configuration to pass into the duckdb connector.
129        concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
130        register_comments: Whether or not to register model comments with the SQL engine.
131    """
132
133    extensions: t.List[str] = []
134    connector_config: t.Dict[str, t.Any] = {}
135
136    concurrent_tasks: Literal[1] = 1
137    register_comments: bool = True
138
139    @property
140    def _engine_adapter(self) -> t.Type[EngineAdapter]:
141        return engine_adapter.DuckDBEngineAdapter
142
143    @property
144    def _connection_factory(self) -> t.Callable:
145        import duckdb
146
147        return duckdb.connect
148
149    @property
150    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
151        """A function that is called to initialize the cursor"""
152        import duckdb
153        from duckdb import BinderException
154
155        def init(cursor: duckdb.DuckDBPyConnection) -> None:
156            for extension in self.extensions:
157                try:
158                    cursor.execute(f"INSTALL {extension}")
159                    cursor.execute(f"LOAD {extension}")
160                except Exception as e:
161                    raise ConfigError(f"Failed to load extension {extension}: {e}")
162
163            for field, setting in self.connector_config.items():
164                try:
165                    cursor.execute(f"SET {field} = '{setting}'")
166                except Exception as e:
167                    raise ConfigError(f"Failed to set connector config {field} to {setting}: {e}")
168
169            for i, (alias, path_options) in enumerate(
170                (getattr(self, "catalogs", None) or {}).items()
171            ):
172                # we parse_identifier and generate to ensure that `alias` has exactly one set of quotes
173                # regardless of whether it comes in quoted or not
174                alias = exp.parse_identifier(alias, dialect="duckdb").sql(
175                    identify=True, dialect="duckdb"
176                )
177                try:
178                    query = (
179                        path_options.to_sql(alias)
180                        if isinstance(path_options, DuckDBAttachOptions)
181                        else f"ATTACH '{path_options}' AS {alias}"
182                    )
183                    cursor.execute(query)
184                except BinderException as e:
185                    # If a user tries to create a catalog pointing at `:memory:` and with the name `memory`
186                    # then we don't want to raise since this happens by default. They are just doing this to
187                    # set it as the default catalog.
188                    if not (
189                        'database with name "memory" already exists' in str(e)
190                        and path_options == ":memory:"
191                    ):
192                        raise e
193                if i == 0 and not getattr(self, "database", None):
194                    cursor.execute(f"USE {alias}")
195
196        return init

Common configuration for the DuckDB-based connections.

Arguments:
  • extensions: A list of autoloadable extensions to load.
  • connector_config: A dictionary of configuration to pass into the duckdb connector.
  • concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
  • register_comments: Whether or not to register model comments with the SQL engine.
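
The `_cursor_init` hook above translates these fields into `INSTALL`/`LOAD` and `SET` statements on every new cursor. A sketch using the concrete `DuckDBConnectionConfig` defined below (the extension and setting names are illustrative):

    from sqlmesh.core.config.connection import DuckDBConnectionConfig

    config = DuckDBConnectionConfig(
        database="local.db",  # file-backed database
        extensions=["httpfs"],  # becomes INSTALL httpfs; LOAD httpfs
        connector_config={"s3_region": "us-east-1"},  # becomes SET s3_region = 'us-east-1'
    )
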
class MotherDuckConnectionConfig(BaseDuckDBConnectionConfig):
199class MotherDuckConnectionConfig(BaseDuckDBConnectionConfig):
200    """Configuration for the MotherDuck connection.
201
202    Args:
203        database: The database name.
204        token: The optional MotherDuck token. If not specified, the user will be prompted to login with their web browser.
205    """
206
207    database: str
208    token: t.Optional[str] = None
209
210    type_: Literal["motherduck"] = Field(alias="type", default="motherduck")
211
212    @property
213    def _connection_kwargs_keys(self) -> t.Set[str]:
214        return set()
215
216    @property
217    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
218        """kwargs that are for execution config only"""
219        connection_str = f"md:{self.database}"
220        if self.token:
221            connection_str += f"?motherduck_token={self.token}"
222        return {"database": connection_str}

Configuration for the MotherDuck connection.

Arguments:
  • database: The database name.
  • token: The optional MotherDuck token. If not specified, the user will be prompted to login with their web browser.
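
The token is folded into the DuckDB connection string rather than passed as a separate keyword; a quick sketch (database name and token are placeholders):

    from sqlmesh.core.config.connection import MotherDuckConnectionConfig

    config = MotherDuckConnectionConfig(database="analytics", token="<token>")
    # _static_connection_kwargs yields:
    #   {"database": "md:analytics?motherduck_token=<token>"}
    # Without a token it is just "md:analytics" and MotherDuck falls back to
    # browser-based login.
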
class DuckDBAttachOptions(sqlmesh.core.config.base.BaseConfig):
225class DuckDBAttachOptions(BaseConfig):
226    type: str
227    path: str
228    read_only: bool = False
229
230    def to_sql(self, alias: str) -> str:
231        options = []
232        # 'duckdb' is actually not a supported type, but we'd like to allow it for
233        # fully qualified attach options or integration testing, similar to duckdb-dbt
234        if self.type != "duckdb":
235            options.append(f"TYPE {self.type.upper()}")
236        if self.read_only:
237            options.append("READ_ONLY")
238        options_sql = f" ({', '.join(options)})" if options else ""
239        return f"ATTACH '{self.path}' AS {alias}{options_sql}"

Options controlling how a database is attached to a DuckDB connection.

def to_sql(self, alias: str) -> str:
Renders the DuckDB `ATTACH` statement for this database under the given alias.
class DuckDBConnectionConfig(BaseDuckDBConnectionConfig):
242class DuckDBConnectionConfig(BaseDuckDBConnectionConfig):
243    """Configuration for the DuckDB connection.
244
245    Args:
246        database: The optional database name. If not specified, the in-memory database will be used.
247        catalogs: Key is the name of the catalog and value is the path.
248    """
249
250    database: t.Optional[str] = None
251    catalogs: t.Optional[t.Dict[str, t.Union[str, DuckDBAttachOptions]]] = None
252
253    type_: Literal["duckdb"] = Field(alias="type", default="duckdb")
254
255    _data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {}
256
257    @model_validator(mode="before")
258    @model_validator_v1_args
259    def _validate_database_catalogs(
260        cls, values: t.Dict[str, t.Optional[str]]
261    ) -> t.Dict[str, t.Optional[str]]:
262        if values.get("database") and values.get("catalogs"):
263            raise ConfigError(
264                "Cannot specify both `database` and `catalogs`. Define all your catalogs in `catalogs` and have the first entry be the default catalog"
265            )
266        return values
267
268    @property
269    def _connection_kwargs_keys(self) -> t.Set[str]:
270        return {"database"}
271
272    def create_engine_adapter(self, register_comments_override: bool = False) -> EngineAdapter:
273        """Checks if another engine adapter has already been created that shares a catalog that points to the same data
274        file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration
275        associated with the new adapter will be ignored."""
276        data_files = set((self.catalogs or {}).values())
277        if self.database:
278            data_files.add(self.database)
279        data_files.discard(":memory:")
280        for data_file in data_files:
281            key = data_file if isinstance(data_file, str) else data_file.path
282            if adapter := DuckDBConnectionConfig._data_file_to_adapter.get(key):
283                logger.info(f"Using existing DuckDB adapter due to overlapping data file: {key}")
284                return adapter
285
286        if data_files:
287            logger.info(f"Creating new DuckDB adapter for data files: {data_files}")
288        else:
289            logger.info("Creating new DuckDB adapter for in-memory database")
290        adapter = super().create_engine_adapter(register_comments_override)
291        for data_file in data_files:
292            key = data_file if isinstance(data_file, str) else data_file.path
293            DuckDBConnectionConfig._data_file_to_adapter[key] = adapter
294        return adapter
295
296    def get_catalog(self) -> t.Optional[str]:
297        if self.database:
298            # Remove `:` from the database name in order to handle if `:memory:` is passed in
299            return pathlib.Path(self.database.replace(":", "")).stem
300        if self.catalogs:
301            return list(self.catalogs)[0]
302        return None

Configuration for the DuckDB connection.

Arguments:
  • database: The optional database name. If not specified, the in-memory database will be used.
  • catalogs: Key is the name of the catalog and value is the path.
def create_engine_adapter(self, register_comments_override: bool = False) -> sqlmesh.core.engine_adapter.base.EngineAdapter:

Checks if another engine adapter has already been created that shares a catalog that points to the same data file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration associated with the new adapter will be ignored.

def get_catalog(self) -> Union[str, NoneType]:

The catalog for this connection
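
A sketch of the adapter-caching behavior described above, assuming an illustrative `data/warehouse.db` file:

    from sqlmesh.core.config.connection import DuckDBConnectionConfig

    config_a = DuckDBConnectionConfig(database="data/warehouse.db")
    config_b = DuckDBConnectionConfig(catalogs={"warehouse": "data/warehouse.db"})

    adapter = config_a.create_engine_adapter()
    # Both configs resolve to the same data file, so the cached adapter is
    # reused and config_b's own settings are ignored.
    assert config_b.create_engine_adapter() is adapter

    # get_catalog() strips any `:` and takes the path stem, or the first
    # `catalogs` key:
    assert config_a.get_catalog() == "warehouse"
    assert config_b.get_catalog() == "warehouse"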


class SnowflakeConnectionConfig(ConnectionConfig):
305class SnowflakeConnectionConfig(ConnectionConfig):
306    """Configuration for the Snowflake connection.
307
308    Args:
309        account: The Snowflake account name.
310        user: The Snowflake username.
311        password: The Snowflake password.
312        warehouse: The optional warehouse name.
313        database: The optional database name.
314        role: The optional role name.
315        concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
316        authenticator: The optional authenticator name. Defaults to username/password authentication ("snowflake").
317                       Options: https://github.com/snowflakedb/snowflake-connector-python/blob/e937591356c067a77f34a0a42328907fda792c23/src/snowflake/connector/network.py#L178-L183
318        token: The optional oauth access token to use for authentication when authenticator is set to "oauth".
319        private_key: The optional private key to use for authentication. Key can be Base64-encoded DER format (representing the key bytes), a plain-text PEM format, or bytes (Python config only). https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#using-key-pair-authentication-key-pair-rotation
320        private_key_path: The optional path to the private key to use for authentication. This would be used instead of `private_key`.
321        private_key_passphrase: The optional passphrase to use to decrypt `private_key` or `private_key_path`. Keys can be created without encryption so only provide this if needed.
322        register_comments: Whether or not to register model comments with the SQL engine.
323    """
324
325    account: str
326    user: t.Optional[str] = None
327    password: t.Optional[str] = None
328    warehouse: t.Optional[str] = None
329    database: t.Optional[str] = None
330    role: t.Optional[str] = None
331    authenticator: t.Optional[str] = None
332    token: t.Optional[str] = None
333
334    # Private Key Auth
335    private_key: t.Optional[t.Union[str, bytes]] = None
336    private_key_path: t.Optional[str] = None
337    private_key_passphrase: t.Optional[str] = None
338
339    concurrent_tasks: int = 4
340    register_comments: bool = True
341
342    type_: Literal["snowflake"] = Field(alias="type", default="snowflake")
343
344    _concurrent_tasks_validator = concurrent_tasks_validator
345
346    @model_validator(mode="before")
347    @model_validator_v1_args
348    def _validate_authenticator(
349        cls, values: t.Dict[str, t.Optional[str]]
350    ) -> t.Dict[str, t.Optional[str]]:
351        from snowflake.connector.network import (
352            DEFAULT_AUTHENTICATOR,
353            OAUTH_AUTHENTICATOR,
354        )
355
356        auth = values.get("authenticator")
357        auth = auth.upper() if auth else DEFAULT_AUTHENTICATOR
358        user = values.get("user")
359        password = values.get("password")
360        values["private_key"] = cls._get_private_key(values, auth)  # type: ignore
361        if (
362            auth == DEFAULT_AUTHENTICATOR
363            and not values.get("private_key")
364            and (not user or not password)
365        ):
366            raise ConfigError("User and password must be provided if using default authentication")
367        if auth == OAUTH_AUTHENTICATOR and not values.get("token"):
368            raise ConfigError("Token must be provided if using oauth authentication")
369        return values
370
371    @classmethod
372    def _get_private_key(cls, values: t.Dict[str, t.Optional[str]], auth: str) -> t.Optional[bytes]:
373        """
374        source: https://github.com/dbt-labs/dbt-snowflake/blob/0374b4ec948982f2ac8ec0c95d53d672ad19e09c/dbt/adapters/snowflake/connections.py#L247C5-L285C1
375
376        Overall code change: Use local variables instead of class attributes + Validation
377        """
378        # Start custom code
379        from cryptography.hazmat.backends import default_backend
380        from cryptography.hazmat.primitives import serialization
381        from snowflake.connector.network import (
382            DEFAULT_AUTHENTICATOR,
383            KEY_PAIR_AUTHENTICATOR,
384        )
385
386        private_key = values.get("private_key")
387        private_key_path = values.get("private_key_path")
388        private_key_passphrase = values.get("private_key_passphrase")
389        user = values.get("user")
390        password = values.get("password")
391        auth = auth if auth and auth != DEFAULT_AUTHENTICATOR else KEY_PAIR_AUTHENTICATOR
392
393        if not private_key and not private_key_path:
394            return None
395        if private_key and private_key_path:
396            raise ConfigError("Cannot specify both `private_key` and `private_key_path`")
397        if auth != KEY_PAIR_AUTHENTICATOR:
398            raise ConfigError(
399                f"Private key or private key path can only be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
400            )
401        if not user:
402            raise ConfigError(
403                f"User must be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
404            )
405        if password:
406            raise ConfigError(
407                f"Password cannot be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
408            )
409
410        if isinstance(private_key, bytes):
411            return private_key
412        # End Custom Code
413
414        if private_key_passphrase:
415            encoded_passphrase = private_key_passphrase.encode()
416        else:
417            encoded_passphrase = None
418
419        if private_key:
420            if private_key.startswith("-"):
421                p_key = serialization.load_pem_private_key(
422                    data=bytes(private_key, "utf-8"),
423                    password=encoded_passphrase,
424                    backend=default_backend(),
425                )
426
427            else:
428                p_key = serialization.load_der_private_key(
429                    data=base64.b64decode(private_key),
430                    password=encoded_passphrase,
431                    backend=default_backend(),
432                )
433
434        elif private_key_path:
435            with open(private_key_path, "rb") as key:
436                p_key = serialization.load_pem_private_key(
437                    key.read(), password=encoded_passphrase, backend=default_backend()
438                )
439        else:
440            return None
441
442        return p_key.private_bytes(
443            encoding=serialization.Encoding.DER,
444            format=serialization.PrivateFormat.PKCS8,
445            encryption_algorithm=serialization.NoEncryption(),
446        )
447
448    @property
449    def _connection_kwargs_keys(self) -> t.Set[str]:
450        return {
451            "user",
452            "password",
453            "account",
454            "warehouse",
455            "database",
456            "role",
457            "authenticator",
458            "token",
459            "private_key",
460        }
461
462    @property
463    def _engine_adapter(self) -> t.Type[EngineAdapter]:
464        return engine_adapter.SnowflakeEngineAdapter
465
466    @property
467    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
468        return {"autocommit": False}
469
470    @property
471    def _connection_factory(self) -> t.Callable:
472        from snowflake import connector
473
474        return connector.connect

Configuration for the Snowflake connection.

Arguments:
  • account: The Snowflake account name.
  • user: The Snowflake username.
  • password: The Snowflake password.
  • warehouse: The optional warehouse name.
  • database: The optional database name.
  • role: The optional role name.
  • concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
  • authenticator: The optional authenticator name. Defaults to username/password authentication ("snowflake"). Options: https://github.com/snowflakedb/snowflake-connector-python/blob/e937591356c067a77f34a0a42328907fda792c23/src/snowflake/connector/network.py#L178-L183
  • token: The optional oauth access token to use for authentication when authenticator is set to "oauth".
  • private_key: The optional private key to use for authentication. Key can be Base64-encoded DER format (representing the key bytes), a plain-text PEM format, or bytes (Python config only). https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#using-key-pair-authentication-key-pair-rotation
  • private_key_path: The optional path to the private key to use for authentication. This would be used instead of `private_key`.
  • private_key_passphrase: The optional passphrase to use to decrypt `private_key` or `private_key_path`. Keys can be created without encryption so only provide this if needed.
  • register_comments: Whether or not to register model comments with the SQL engine.
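
Key-pair authentication, for example, requires `user`, forbids `password`, and accepts the key inline or via a path. A sketch (account, user, and path are placeholders; the key file is read during validation, so the path must exist):

    from sqlmesh.core.config.connection import SnowflakeConnectionConfig

    config = SnowflakeConnectionConfig(
        account="my_org-my_account",
        user="etl_user",
        private_key_path="/secrets/rsa_key.p8",  # decrypted with private_key_passphrase if set
        warehouse="compute_wh",
    )
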
class DatabricksConnectionConfig(ConnectionConfig):
477class DatabricksConnectionConfig(ConnectionConfig):
478    """
479    Databricks connection that uses the SQL connector for SQL models and then Databricks Connect for Dataframe operations
480
481    Arg Source: https://github.com/databricks/databricks-sql-python/blob/main/src/databricks/sql/client.py#L39
482    Args:
483        server_hostname: Databricks instance host name.
484        http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef)
485            or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
486        access_token: Http Bearer access token, e.g. Databricks Personal Access Token.
487        catalog: Default catalog to use for SQL models. Defaults to None which means it will use the default set in
488            the Databricks cluster (most likely `hive_metastore`).
489        http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request
490        session_configuration: An optional dictionary of Spark session parameters.
491            Execute the SQL command `SET -v` to get a full list of available commands.
492        databricks_connect_server_hostname: The hostname to use when establishing a connection using Databricks Connect.
493            Defaults to the `server_hostname` value.
494        databricks_connect_access_token: The access token to use when establishing a connection using Databricks Connect.
495            Defaults to the `access_token` value.
496        databricks_connect_cluster_id: The cluster id to use when establishing a connection using Databricks Connect.
497            Defaults to deriving the cluster id from the `http_path` value.
498        force_databricks_connect: Force all queries to run using Databricks Connect instead of the SQL connector.
499        disable_databricks_connect: Even if databricks connect is installed, do not use it.
500    """
501
502    server_hostname: t.Optional[str] = None
503    http_path: t.Optional[str] = None
504    access_token: t.Optional[str] = None
505    catalog: t.Optional[str] = None
506    http_headers: t.Optional[t.List[t.Tuple[str, str]]] = None
507    session_configuration: t.Optional[t.Dict[str, t.Any]] = None
508    databricks_connect_server_hostname: t.Optional[str] = None
509    databricks_connect_access_token: t.Optional[str] = None
510    databricks_connect_cluster_id: t.Optional[str] = None
511    force_databricks_connect: bool = False
512    disable_databricks_connect: bool = False
513
514    concurrent_tasks: int = 1
515    register_comments: bool = True
516
517    type_: Literal["databricks"] = Field(alias="type", default="databricks")
518
519    _concurrent_tasks_validator = concurrent_tasks_validator
520    _http_headers_validator = http_headers_validator
521
522    @model_validator(mode="before")
523    @model_validator_v1_args
524    def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
525        from sqlmesh import RuntimeEnv
526        from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter
527
528        runtime_env = RuntimeEnv.get()
529
530        if runtime_env.is_databricks:
531            return values
532        server_hostname, http_path, access_token = (
533            values.get("server_hostname"),
534            values.get("http_path"),
535            values.get("access_token"),
536        )
537        if not server_hostname or not http_path or not access_token:
538            raise ValueError(
539                "`server_hostname`, `http_path`, and `access_token` are required for Databricks connections when not running in a notebook"
540            )
541        if DatabricksEngineAdapter.can_access_spark_session:
542            if not values.get("databricks_connect_server_hostname"):
543                values["databricks_connect_server_hostname"] = f"https://{server_hostname}"
544            if not values.get("databricks_connect_access_token"):
545                values["databricks_connect_access_token"] = access_token
546            if not values.get("databricks_connect_cluster_id"):
547                values["databricks_connect_cluster_id"] = http_path.split("/")[-1]
548        if not values.get("session_configuration"):
549            values["session_configuration"] = {}
550        values["session_configuration"]["spark.sql.sources.partitionOverwriteMode"] = "dynamic"
551        return values
552
553    @property
554    def _connection_kwargs_keys(self) -> t.Set[str]:
555        if self.use_spark_session_only:
556            return set()
557        return {
558            "server_hostname",
559            "http_path",
560            "access_token",
561            "http_headers",
562            "session_configuration",
563            "catalog",
564        }
565
566    @property
567    def _engine_adapter(self) -> t.Type[engine_adapter.DatabricksEngineAdapter]:
568        return engine_adapter.DatabricksEngineAdapter
569
570    @property
571    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
572        return {
573            k: v
574            for k, v in self.dict().items()
575            if k.startswith("databricks_connect_") or k in ("catalog", "disable_databricks_connect")
576        }
577
578    @property
579    def use_spark_session_only(self) -> bool:
580        from sqlmesh import RuntimeEnv
581
582        return RuntimeEnv.get().is_databricks or self.force_databricks_connect
583
584    @property
585    def _connection_factory(self) -> t.Callable:
586        if self.use_spark_session_only:
587            from sqlmesh.engines.spark.db_api.spark_session import connection
588
589            return connection
590
591        from databricks import sql
592
593        return sql.connect
594
595    @property
596    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
597        from sqlmesh import RuntimeEnv
598
599        if not self.use_spark_session_only:
600            return {}
601
602        if RuntimeEnv.get().is_databricks:
603            from pyspark.sql import SparkSession
604
605            return dict(
606                spark=SparkSession.getActiveSession(),
607                catalog=self.catalog,
608            )
609
610        from databricks.connect import DatabricksSession
611
612        return dict(
613            spark=DatabricksSession.builder.remote(
614                host=self.databricks_connect_server_hostname,
615                token=self.databricks_connect_access_token,
616                cluster_id=self.databricks_connect_cluster_id,
617            ).getOrCreate(),
618            catalog=self.catalog,
619        )

Databricks connection that uses the SQL connector for SQL models and then Databricks Connect for Dataframe operations

Arg Source: https://github.com/databricks/databricks-sql-python/blob/main/src/databricks/sql/client.py#L39

Arguments:
  • server_hostname: Databricks instance host name.
  • http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef) or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
  • access_token: Http Bearer access token, e.g. Databricks Personal Access Token.
  • catalog: Default catalog to use for SQL models. Defaults to None which means it will use the default set in the Databricks cluster (most likely hive_metastore).
  • http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request
  • session_configuration: An optional dictionary of Spark session parameters. Execute the SQL command SET -v to get a full list of available commands.
  • databricks_connect_server_hostname: The hostname to use when establishing a connection using Databricks Connect. Defaults to the server_hostname value.
  • databricks_connect_access_token: The access token to use when establishing a connection using Databricks Connect. Defaults to the access_token value.
  • databricks_connect_cluster_id: The cluster id to use when establishing a connection using Databricks Connect. Defaults to deriving the cluster id from the http_path value.
  • force_databricks_connect: Force all queries to run using Databricks Connect instead of the SQL connector.
  • disable_databricks_connect: Even if databricks connect is installed, do not use it.
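
A sketch of how the validator fills in the Databricks Connect fields when running outside a notebook (hostname and token are placeholders; the `http_path` value reuses the docstring's cluster example):

    from sqlmesh.core.config.connection import DatabricksConnectionConfig

    config = DatabricksConnectionConfig(
        server_hostname="dbc-1234abcd-5678.cloud.databricks.com",
        http_path="/sql/protocolv1/o/1234567890123456/1234-123456-slid123",
        access_token="<token>",
    )
    # When a Spark session is reachable, the validator derives:
    #   databricks_connect_server_hostname = "https://" + server_hostname
    #   databricks_connect_cluster_id      = "1234-123456-slid123"  # last http_path segment
    # and always forces spark.sql.sources.partitionOverwriteMode = "dynamic".
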
class BigQueryConnectionMethod(builtins.str, enum.Enum):
622class BigQueryConnectionMethod(str, Enum):
623    OAUTH = "oauth"
624    OAUTH_SECRETS = "oauth-secrets"
625    SERVICE_ACCOUNT = "service-account"
626    SERVICE_ACCOUNT_JSON = "service-account-json"

The supported methods for authenticating with BigQuery.

class BigQueryPriority(builtins.str, enum.Enum):
629class BigQueryPriority(str, Enum):
630    BATCH = "batch"
631    INTERACTIVE = "interactive"
632
633    @property
634    def is_batch(self) -> bool:
635        return self == self.BATCH
636
637    @property
638    def is_interactive(self) -> bool:
639        return self == self.INTERACTIVE
640
641    @property
642    def bigquery_constant(self) -> str:
643        from google.cloud.bigquery import QueryPriority
644
645        if self.is_batch:
646            return QueryPriority.BATCH
647        return QueryPriority.INTERACTIVE

The priority that BigQuery assigns to jobs issued through this connection.

class BigQueryConnectionConfig(ConnectionConfig):
650class BigQueryConnectionConfig(ConnectionConfig):
651    """
652    BigQuery Connection Configuration.
653    """
654
655    method: BigQueryConnectionMethod = BigQueryConnectionMethod.OAUTH
656
657    project: t.Optional[str] = None
658    execution_project: t.Optional[str] = None
659    location: t.Optional[str] = None
660    # Keyfile Auth
661    keyfile: t.Optional[str] = None
662    keyfile_json: t.Optional[t.Dict[str, t.Any]] = None
663    # OAuth Secrets Auth
664    token: t.Optional[str] = None
665    refresh_token: t.Optional[str] = None
666    client_id: t.Optional[str] = None
667    client_secret: t.Optional[str] = None
668    token_uri: t.Optional[str] = None
669    scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/bigquery",)
670    job_creation_timeout_seconds: t.Optional[int] = None
671    # Extra Engine Config
672    job_execution_timeout_seconds: t.Optional[int] = None
673    job_retries: t.Optional[int] = 1
674    job_retry_deadline_seconds: t.Optional[int] = None
675    priority: t.Optional[BigQueryPriority] = None
676    maximum_bytes_billed: t.Optional[int] = None
677
678    concurrent_tasks: int = 1
679    register_comments: bool = True
680
681    type_: Literal["bigquery"] = Field(alias="type", default="bigquery")
682
683    @property
684    def _connection_kwargs_keys(self) -> t.Set[str]:
685        return set()
686
687    @property
688    def _engine_adapter(self) -> t.Type[EngineAdapter]:
689        return engine_adapter.BigQueryEngineAdapter
690
691    @property
692    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
693        """The static connection kwargs for this connection"""
694        import google.auth
695        from google.api_core import client_info
696        from google.oauth2 import credentials, service_account
697
698        if self.method == BigQueryConnectionMethod.OAUTH:
699            creds, _ = google.auth.default(scopes=self.scopes)
700        elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT:
701            creds = service_account.Credentials.from_service_account_file(
702                self.keyfile, scopes=self.scopes
703            )
704        elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT_JSON:
705            creds = service_account.Credentials.from_service_account_info(
706                self.keyfile_json, scopes=self.scopes
707            )
708        elif self.method == BigQueryConnectionMethod.OAUTH_SECRETS:
709            creds = credentials.Credentials(
710                token=self.token,
711                refresh_token=self.refresh_token,
712                client_id=self.client_id,
713                client_secret=self.client_secret,
714                token_uri=self.token_uri,
715                scopes=self.scopes,
716            )
717        else:
718            raise ConfigError("Invalid BigQuery Connection Method")
719        client = google.cloud.bigquery.Client(
720            project=self.execution_project or self.project,
721            credentials=creds,
722            location=self.location,
723            client_info=client_info.ClientInfo(user_agent="sqlmesh"),
724        )
725
726        return {
727            "client": client,
728        }
729
730    @property
731    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
732        return {
733            k: v
734            for k, v in self.dict().items()
735            if k
736            in {
737                "job_creation_timeout_seconds",
738                "job_execution_timeout_seconds",
739                "job_retries",
740                "job_retry_deadline_seconds",
741                "priority",
742                "maximum_bytes_billed",
743            }
744        }
745
746    @property
747    def _connection_factory(self) -> t.Callable:
748        from google.cloud.bigquery.dbapi import connect
749
750        return connect
751
752    def get_catalog(self) -> t.Optional[str]:
753        return self.project

BigQuery Connection Configuration.

def get_catalog(self) -> Union[str, NoneType]:

The catalog for this connection
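
A sketch of a service-account configuration (project and keyfile path are placeholders):

    from sqlmesh.core.config.connection import (
        BigQueryConnectionConfig,
        BigQueryConnectionMethod,
    )

    config = BigQueryConnectionConfig(
        method=BigQueryConnectionMethod.SERVICE_ACCOUNT,
        project="my-project",
        keyfile="/secrets/service_account.json",
        location="US",
    )
    # The BigQuery Client is built lazily in _static_connection_kwargs;
    # get_catalog() simply reports the project.
    assert config.get_catalog() == "my-project"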

class GCPPostgresConnectionConfig(ConnectionConfig):
756class GCPPostgresConnectionConfig(ConnectionConfig):
757    """
758    Postgres Connection Configuration for GCP.
759
760    Args:
761        instance_connection_string: Connection name for the postgres instance.
762        user: Postgres or IAM user's name
763        password: The postgres user's password. Only needed when the user is a postgres user.
764        enable_iam_auth: Set to True when user is an IAM user.
765        db: Name of the db to connect to.
766    """
767
768    instance_connection_string: str
769    user: str
770    password: t.Optional[str] = None
771    enable_iam_auth: t.Optional[bool] = None
772    db: str
773    timeout: t.Optional[int] = None
774
775    driver: str = "pg8000"
776    type_: Literal["gcp_postgres"] = Field(alias="type", default="gcp_postgres")
777    concurrent_tasks: int = 4
778    register_comments: bool = True
779
780    @model_validator(mode="before")
781    @model_validator_v1_args
782    def _validate_auth_method(
783        cls, values: t.Dict[str, t.Optional[str]]
784    ) -> t.Dict[str, t.Optional[str]]:
785        password = values.get("password")
786        enable_iam_auth = values.get("enable_iam_auth")
787        if password and enable_iam_auth:
788            raise ConfigError(
789                "Invalid GCP Postgres connection configuration - both password and"
790                " enable_iam_auth set. Use password when connecting to a postgres"
791                " user and enable_iam_auth 'True' when connecting to an IAM user."
792            )
793        if not password and not enable_iam_auth:
794            raise ConfigError(
795                "GCP Postgres connection configuration requires either password set"
796                " for a postgres user account or enable_iam_auth set to 'True'"
797                " for an IAM user account."
798            )
799        return values
800
801    @property
802    def _connection_kwargs_keys(self) -> t.Set[str]:
803        return {
804            "instance_connection_string",
805            "driver",
806            "user",
807            "password",
808            "db",
809            "enable_iam_auth",
810            "timeout",
811        }
812
813    @property
814    def _engine_adapter(self) -> t.Type[EngineAdapter]:
815        return engine_adapter.PostgresEngineAdapter
816
817    @property
818    def _connection_factory(self) -> t.Callable:
819        from google.cloud.sql.connector import Connector
820
821        return Connector().connect

Postgres Connection Configuration for GCP.

Arguments:
  • instance_connection_string: Connection name for the postgres instance.
  • user: Postgres or IAM user's name
  • password: The postgres user's password. Only needed when the user is a postgres user.
  • enable_iam_auth: Set to True when user is an IAM user.
  • db: Name of the db to connect to.
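
A sketch of an IAM-user configuration, which sets `enable_iam_auth` and omits `password` (all values are placeholders); supplying both, or neither, raises `ConfigError`:

    from sqlmesh.core.config.connection import GCPPostgresConnectionConfig

    config = GCPPostgresConnectionConfig(
        instance_connection_string="project:region:instance",
        user="service-account@project.iam",
        enable_iam_auth=True,
        db="analytics",
    )
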
class RedshiftConnectionConfig(ConnectionConfig):
824class RedshiftConnectionConfig(ConnectionConfig):
825    """
826    Redshift Connection Configuration.
827
828    Arg Source: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/__init__.py#L146
829    Note: A subset of properties were selected. Please open an issue/PR if you want to see more supported.
830
831    Args:
832        user: The username to use for authentication with the Amazon Redshift cluster.
833        password: The password to use for authentication with the Amazon Redshift cluster.
834        database: The name of the database instance to connect to.
835        host: The hostname of the Amazon Redshift cluster.
836        port: The port number of the Amazon Redshift cluster. Default value is 5439.
837        source_address: No description provided
838        unix_sock: No description provided
839        ssl: Is SSL enabled. Default value is ``True``. SSL must be enabled when authenticating using IAM.
840        sslmode: The security of the connection to the Amazon Redshift cluster. 'verify-ca' and 'verify-full' are supported.
841        timeout: The number of seconds before the connection to the server will timeout. By default there is no timeout.
842        tcp_keepalive: Is `TCP keepalive <https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive>`_ used. The default value is ``True``.
843        application_name: Sets the application name. The default value is None.
844        preferred_role: The IAM role preferred for the current connection.
845        principal_arn: The ARN of the IAM entity (user or role) for which you are generating a policy.
846        credentials_provider: The class name of the IdP that will be used for authenticating with the Amazon Redshift cluster.
847        region: The AWS region where the Amazon Redshift cluster is located.
848        cluster_identifier: The cluster identifier of the Amazon Redshift cluster.
849        iam: If IAM authentication is enabled. Default value is False. IAM must be True when authenticating using an IdP.
850        is_serverless: Whether the Redshift endpoint is serverless or provisioned. Default value False.
851        serverless_acct_id: The account ID of the serverless endpoint. Default value None.
852        serverless_work_group: The name of the workgroup for the serverless endpoint. Default value None.
853    """
854
855    user: t.Optional[str] = None
856    password: t.Optional[str] = None
857    database: t.Optional[str] = None
858    host: t.Optional[str] = None
859    port: t.Optional[int] = None
860    source_address: t.Optional[str] = None
861    unix_sock: t.Optional[str] = None
862    ssl: t.Optional[bool] = None
863    sslmode: t.Optional[str] = None
864    timeout: t.Optional[int] = None
865    tcp_keepalive: t.Optional[bool] = None
866    application_name: t.Optional[str] = None
867    preferred_role: t.Optional[str] = None
868    principal_arn: t.Optional[str] = None
869    credentials_provider: t.Optional[str] = None
870    region: t.Optional[str] = None
871    cluster_identifier: t.Optional[str] = None
872    iam: t.Optional[bool] = None
873    is_serverless: t.Optional[bool] = None
874    serverless_acct_id: t.Optional[str] = None
875    serverless_work_group: t.Optional[str] = None
876
877    concurrent_tasks: int = 4
878    register_comments: bool = True
879
880    type_: Literal["redshift"] = Field(alias="type", default="redshift")
881
882    @property
883    def _connection_kwargs_keys(self) -> t.Set[str]:
884        return {
885            "user",
886            "password",
887            "database",
888            "host",
889            "port",
890            "source_address",
891            "unix_sock",
892            "ssl",
893            "sslmode",
894            "timeout",
895            "tcp_keepalive",
896            "application_name",
897            "preferred_role",
898            "principal_arn",
899            "credentials_provider",
900            "region",
901            "cluster_identifier",
902            "iam",
903            "is_serverless",
904            "serverless_acct_id",
905            "serverless_work_group",
906        }
907
908    @property
909    def _engine_adapter(self) -> t.Type[EngineAdapter]:
910        return engine_adapter.RedshiftEngineAdapter
911
912    @property
913    def _connection_factory(self) -> t.Callable:
914        from redshift_connector import connect
915
916        return connect

Redshift Connection Configuration.

Arg Source: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/__init__.py#L146 Note: A subset of properties were selected. Please open an issue/PR if you want to see more supported.

Arguments:
  • user: The username to use for authentication with the Amazon Redshift cluster.
  • password: The password to use for authentication with the Amazon Redshift cluster.
  • database: The name of the database instance to connect to.
  • host: The hostname of the Amazon Redshift cluster.
  • port: The port number of the Amazon Redshift cluster. Default value is 5439.
  • source_address: No description provided
  • unix_sock: No description provided
  • ssl: Is SSL enabled. Default value is True. SSL must be enabled when authenticating using IAM.
  • sslmode: The security of the connection to the Amazon Redshift cluster. 'verify-ca' and 'verify-full' are supported.
  • timeout: The number of seconds before the connection to the server will timeout. By default there is no timeout.
  • tcp_keepalive: Is TCP keepalive used. The default value is True.
  • application_name: Sets the application name. The default value is None.
  • preferred_role: The IAM role preferred for the current connection.
  • principal_arn: The ARN of the IAM entity (user or role) for which you are generating a policy.
  • credentials_provider: The class name of the IdP that will be used for authenticating with the Amazon Redshift cluster.
  • region: The AWS region where the Amazon Redshift cluster is located.
  • cluster_identifier: The cluster identifier of the Amazon Redshift cluster.
  • iam: If IAM authentication is enabled. Default value is False. IAM must be True when authenticating using an IdP.
  • is_serverless: Whether the Redshift endpoint is serverless or provisioned. Default value False.
  • serverless_acct_id: The account ID of the serverless endpoint. Default value None.
  • serverless_work_group: The name of the workgroup for the serverless endpoint. Default value None.
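
A minimal sketch (host and credentials are placeholders); unset optional fields are left to the driver's documented defaults, e.g. port 5439 and ssl enabled:

    from sqlmesh.core.config.connection import RedshiftConnectionConfig

    config = RedshiftConnectionConfig(
        user="awsuser",
        password="<password>",
        database="dev",
        host="examplecluster.abc123xyz789.us-west-2.redshift.amazonaws.com",
    )
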
class PostgresConnectionConfig(ConnectionConfig):
919class PostgresConnectionConfig(ConnectionConfig):
920    host: str
921    user: str
922    password: str
923    port: int
924    database: str
925    keepalives_idle: t.Optional[int] = None
926    connect_timeout: int = 10
927    role: t.Optional[str] = None
928    sslmode: t.Optional[str] = None
929
930    concurrent_tasks: int = 4
931    register_comments: bool = True
932
933    type_: Literal["postgres"] = Field(alias="type", default="postgres")
934
935    @property
936    def _connection_kwargs_keys(self) -> t.Set[str]:
937        return {
938            "host",
939            "user",
940            "password",
941            "port",
942            "database",
943            "keepalives_idle",
944            "connect_timeout",
945            "role",
946            "sslmode",
947        }
948
949    @property
950    def _engine_adapter(self) -> t.Type[EngineAdapter]:
951        return engine_adapter.PostgresEngineAdapter
952
953    @property
954    def _connection_factory(self) -> t.Callable:
955        from psycopg2 import connect
956
957        return connect

Configuration for the Postgres connection.
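
All five core fields are required and map directly onto `psycopg2.connect` keyword arguments; a sketch with placeholder values:

    from sqlmesh.core.config.connection import PostgresConnectionConfig

    config = PostgresConnectionConfig(
        host="localhost",
        user="sqlmesh",
        password="sqlmesh",
        port=5432,
        database="sqlmesh",
    )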

class MySQLConnectionConfig(ConnectionConfig):
 960class MySQLConnectionConfig(ConnectionConfig):
 961    host: str
 962    user: str
 963    password: str
 964    port: t.Optional[int] = None
 965    charset: t.Optional[str] = None
 966    ssl_disabled: t.Optional[bool] = None
 967
 968    concurrent_tasks: int = 4
 969    register_comments: bool = True
 970
 971    type_: Literal["mysql"] = Field(alias="type", default="mysql")
 972
 973    @property
 974    def _cursor_kwargs(self) -> t.Optional[t.Dict[str, t.Any]]:
 975        """Key-value arguments that will be passed during cursor construction."""
 976        return {"buffered": True}
 977
 978    @property
 979    def _connection_kwargs_keys(self) -> t.Set[str]:
 980        connection_keys = {
 981            "host",
 982            "user",
 983            "password",
 984            "port",
 985            "database",
 986        }
 987        if self.port is not None:
 988            connection_keys.add("port")
 989        if self.charset is not None:
 990            connection_keys.add("charset")
 991        if self.ssl_disabled is not None:
 992            connection_keys.add("ssl_disabled")
 993        return connection_keys
 994
 995    @property
 996    def _engine_adapter(self) -> t.Type[EngineAdapter]:
 997        return engine_adapter.MySQLEngineAdapter
 998
 999    @property
1000    def _connection_factory(self) -> t.Callable:
1001        from mysql.connector import connect
1002
1003        return connect

Configuration for the MySQL connection.
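
Note that `port`, `charset`, and `ssl_disabled` are forwarded to `mysql.connector.connect` only when explicitly set, and cursors are created with `buffered=True`. A sketch with placeholder credentials:

    from sqlmesh.core.config.connection import MySQLConnectionConfig

    config = MySQLConnectionConfig(host="localhost", user="root", password="<password>")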

class MSSQLConnectionConfig(ConnectionConfig):
1006class MSSQLConnectionConfig(ConnectionConfig):
1007    host: str
1008    user: t.Optional[str] = None
1009    password: t.Optional[str] = None
1010    database: t.Optional[str] = ""
1011    timeout: t.Optional[int] = 0
1012    login_timeout: t.Optional[int] = 60
1013    charset: t.Optional[str] = "UTF-8"
1014    appname: t.Optional[str] = None
1015    port: t.Optional[int] = 1433
1016    conn_properties: t.Optional[t.Union[t.Iterable[str], str]] = None
1017    autocommit: t.Optional[bool] = False
1018    tds_version: t.Optional[str] = None
1019
1020    concurrent_tasks: int = 4
1021    register_comments: bool = True
1022
1023    type_: Literal["mssql"] = Field(alias="type", default="mssql")
1024
1025    @property
1026    def _connection_kwargs_keys(self) -> t.Set[str]:
1027        return {
1028            "host",
1029            "user",
1030            "password",
1031            "database",
1032            "timeout",
1033            "login_timeout",
1034            "charset",
1035            "appname",
1036            "port",
1037            "conn_properties",
1038            "autocommit",
1039            "tds_version",
1040        }
1041
1042    @property
1043    def _engine_adapter(self) -> t.Type[EngineAdapter]:
1044        return engine_adapter.MSSQLEngineAdapter
1045
1046    @property
1047    def _connection_factory(self) -> t.Callable:
1048        import pymssql
1049
1050        return pymssql.connect

Configuration for the MSSQL connection.
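
Only `host` is required; the field defaults above (port 1433, 60-second login timeout, UTF-8 charset) mirror `pymssql.connect`'s keyword arguments. A sketch with placeholder credentials:

    from sqlmesh.core.config.connection import MSSQLConnectionConfig

    config = MSSQLConnectionConfig(host="localhost", user="sa", password="<password>")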

class SparkConnectionConfig(ConnectionConfig):
1053class SparkConnectionConfig(ConnectionConfig):
1054    """
1055    Vanilla Spark Connection Configuration. Use `DatabricksConnectionConfig` for Databricks.
1056    """
1057
1058    config_dir: t.Optional[str] = None
1059    catalog: t.Optional[str] = None
1060    config: t.Dict[str, t.Any] = {}
1061
1062    concurrent_tasks: int = 4
1063    register_comments: bool = True
1064
1065    type_: Literal["spark"] = Field(alias="type", default="spark")
1066
1067    @property
1068    def _connection_kwargs_keys(self) -> t.Set[str]:
1069        return {
1070            "catalog",
1071        }
1072
1073    @property
1074    def _engine_adapter(self) -> t.Type[EngineAdapter]:
1075        return engine_adapter.SparkEngineAdapter
1076
1077    @property
1078    def _connection_factory(self) -> t.Callable:
1079        from sqlmesh.engines.spark.db_api.spark_session import connection
1080
1081        return connection
1082
1083    @property
1084    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
1085        from pyspark.conf import SparkConf
1086        from pyspark.sql import SparkSession
1087
1088        spark_config = SparkConf()
1089        if self.config:
1090            for k, v in self.config.items():
1091                spark_config.set(k, v)
1092
1093        if self.config_dir:
1094            os.environ["SPARK_CONF_DIR"] = self.config_dir
1095        return {
1096            "spark": SparkSession.builder.config(conf=spark_config)
1097            .enableHiveSupport()
1098            .getOrCreate(),
1099        }

Vanilla Spark Connection Configuration. Use DatabricksConnectionConfig for Databricks.
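For illustration (a local master and hypothetical settings assumed), every entry in `config` is applied to a `SparkConf` before the `SparkSession` is built with Hive support, and `config_dir`, when set, is exported as `SPARK_CONF_DIR`:

from sqlmesh.core.config.connection import SparkConnectionConfig

config = SparkConnectionConfig(
    catalog="spark_catalog",  # hypothetical catalog name
    config={
        "spark.master": "local[*]",           # hypothetical Spark settings
        "spark.sql.shuffle.partitions": "4",
    },
)

# Note: accessing _static_connection_kwargs builds (or reuses) the
# SparkSession, so pyspark must be installed for that property to resolve.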

class TrinoAuthenticationMethod(builtins.str, enum.Enum):
1102class TrinoAuthenticationMethod(str, Enum):
1103    NO_AUTH = "no-auth"
1104    BASIC = "basic"
1105    LDAP = "ldap"
1106    KERBEROS = "kerberos"
1107    JWT = "jwt"
1108    CERTIFICATE = "certificate"
1109    OAUTH = "oauth"
1110
1111    @property
1112    def is_no_auth(self) -> bool:
1113        return self == self.NO_AUTH
1114
1115    @property
1116    def is_basic(self) -> bool:
1117        return self == self.BASIC
1118
1119    @property
1120    def is_ldap(self) -> bool:
1121        return self == self.LDAP
1122
1123    @property
1124    def is_kerberos(self) -> bool:
1125        return self == self.KERBEROS
1126
1127    @property
1128    def is_jwt(self) -> bool:
1129        return self == self.JWT
1130
1131    @property
1132    def is_certificate(self) -> bool:
1133        return self == self.CERTIFICATE
1134
1135    @property
1136    def is_oauth(self) -> bool:
1137        return self == self.OAUTH

The authentication methods supported by `TrinoConnectionConfig`, each with a convenience predicate (`is_basic`, `is_kerberos`, etc.).
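Because the enum subclasses `str`, members can be constructed from, and compare equal to, their kebab-case string values:

from sqlmesh.core.config.connection import TrinoAuthenticationMethod

method = TrinoAuthenticationMethod("ldap")
assert method is TrinoAuthenticationMethod.LDAP
assert method.is_ldap and not method.is_kerberos
assert method == "ldap"  # str subclass, so plain string comparison works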

class TrinoConnectionConfig(ConnectionConfig):
1140class TrinoConnectionConfig(ConnectionConfig):
1141    method: TrinoAuthenticationMethod = TrinoAuthenticationMethod.NO_AUTH
1142    host: str
1143    user: str
1144    catalog: str
1145    port: t.Optional[int] = None
1146    http_scheme: Literal["http", "https"] = "https"
1147    # General Optional
1148    roles: t.Optional[t.Dict[str, str]] = None
1149    http_headers: t.Optional[t.Dict[str, str]] = None
1150    session_properties: t.Optional[t.Dict[str, str]] = None
1151    retries: int = 3
1152    timezone: t.Optional[str] = None
1153    # Basic/LDAP
1154    password: t.Optional[str] = None
1155    # LDAP
1156    impersonation_user: t.Optional[str] = None
1157    # Kerberos
1158    keytab: t.Optional[str] = None
1159    krb5_config: t.Optional[str] = None
1160    principal: t.Optional[str] = None
1161    service_name: str = "trino"
1162    hostname_override: t.Optional[str] = None
1163    mutual_authentication: bool = False
1164    force_preemptive: bool = False
1165    sanitize_mutual_error_response: bool = True
1166    delegate: bool = False
1167    # JWT
1168    jwt_token: t.Optional[str] = None
1169    # Certificate
1170    client_certificate: t.Optional[str] = None
1171    client_private_key: t.Optional[str] = None
1172    cert: t.Optional[str] = None
1173
1174    concurrent_tasks: int = 4
1175    register_comments: bool = True
1176
1177    type_: Literal["trino"] = Field(alias="type", default="trino")
1178
1179    @model_validator(mode="after")
1180    @model_validator_v1_args
1181    def _root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
1182        port = values.get("port")
1183        if (
1184            values["http_scheme"] == "http"
1185            and not values["method"].is_no_auth
1186            and not values["method"].is_basic
1187        ):
1188            raise ConfigError("HTTP scheme can only be used with no-auth or basic method")
1189        if port is None:
1190            values["port"] = 80 if values["http_scheme"] == "http" else 443
1191        if (values["method"].is_ldap or values["method"].is_basic) and (
1192            not values["password"] or not values["user"]
1193        ):
1194            raise ConfigError(
1195                f"Username and Password must be provided if using {values['method'].value} authentication"
1196            )
1197        if values["method"].is_kerberos and (
1198            not values["principal"] or not values["keytab"] or not values["krb5_config"]
1199        ):
1200            raise ConfigError(
1201                "Kerberos requires the following fields: principal, keytab, and krb5_config"
1202            )
1203        if values["method"].is_jwt and not values["jwt_token"]:
1204            raise ConfigError("JWT requires `jwt_token` to be set")
1205        if values["method"].is_certificate and (
1206            not values["cert"]
1207            or not values["client_certificate"]
1208            or not values["client_private_key"]
1209        ):
1210            raise ConfigError(
1211                "Certificate requires the following fields: cert, client_certificate, and client_private_key"
1212            )
1213        return values
1214
1215    @property
1216    def _connection_kwargs_keys(self) -> t.Set[str]:
1217        kwargs = {
1218            "host",
1219            "port",
1220            "catalog",
1221            "roles",
1222            "http_scheme",
1223            "http_headers",
1224            "session_properties",
1225            "timezone",
1226        }
1227        return kwargs
1228
1229    @property
1230    def _engine_adapter(self) -> t.Type[EngineAdapter]:
1231        return engine_adapter.TrinoEngineAdapter
1232
1233    @property
1234    def _connection_factory(self) -> t.Callable:
1235        from trino.dbapi import connect
1236
1237        return connect
1238
1239    @property
1240    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
1241        from trino.auth import (
1242            BasicAuthentication,
1243            CertificateAuthentication,
1244            JWTAuthentication,
1245            KerberosAuthentication,
1246            OAuth2Authentication,
1247        )
1248
1249        if self.method.is_basic or self.method.is_ldap:
1250            auth = BasicAuthentication(self.user, self.password)
1251        elif self.method.is_kerberos:
1252            if self.keytab:
1253                os.environ["KRB5_CLIENT_KTNAME"] = self.keytab
1254            auth = KerberosAuthentication(
1255                config=self.krb5_config,
1256                service_name=self.service_name,
1257                principal=self.principal,
1258                mutual_authentication=self.mutual_authentication,
1259                ca_bundle=self.cert,
1260                force_preemptive=self.force_preemptive,
1261                hostname_override=self.hostname_override,
1262                sanitize_mutual_error_response=self.sanitize_mutual_error_response,
1263                delegate=self.delegate,
1264            )
1265        elif self.method.is_oauth:
1266            auth = OAuth2Authentication()
1267        elif self.method.is_jwt:
1268            auth = JWTAuthentication(self.jwt_token)
1269        elif self.method.is_certificate:
1270            auth = CertificateAuthentication(self.client_certificate, self.client_private_key)
1271        else:
1272            auth = None
1273
1274        return {
1275            "auth": auth,
1276            "user": self.impersonation_user or self.user,
1277            "max_attempts": self.retries,
1278            "verify": self.cert,
1279            "source": "sqlmesh",
1280        }

Connection configuration for Trino. The model validator derives the port from the HTTP scheme when unset and enforces that each authentication method has its required fields; `_static_connection_kwargs` then builds the matching `trino.auth` object.
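A sketch of the validator's behavior (placeholder host assumed):

from sqlmesh.core.config.connection import TrinoConnectionConfig

config = TrinoConnectionConfig(
    host="trino.example.com",  # placeholder host
    user="sqlmesh",
    catalog="hive",
)

# Port was left unset, so the validator fills it in from the scheme
# (443 for the default "https").
assert config.port == 443

# Combining http_scheme="http" with e.g. method="kerberos" would raise a
# ConfigError ("HTTP scheme can only be used with no-auth or basic method");
# depending on the pydantic major version the error may surface wrapped in a
# ValidationError.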

def parse_connection_config(v: Dict[str, Any]) -> sqlmesh.core.config.connection.ConnectionConfig:
1292def parse_connection_config(v: t.Dict[str, t.Any]) -> ConnectionConfig:
1293    if "type" not in v:
1294        raise ConfigError("Missing connection type.")
1295
1296    connection_type = v["type"]
1297    if connection_type not in CONNECTION_CONFIG_TO_TYPE:
1298        raise ConfigError(f"Unknown connection type '{connection_type}'.")
1299
1300    return CONNECTION_CONFIG_TO_TYPE[connection_type](**v)
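For example, dispatching a raw dictionary such as one loaded from a config file (hypothetical values):

from sqlmesh.core.config.connection import (
    TrinoConnectionConfig,
    parse_connection_config,
)

raw = {
    "type": "trino",
    "host": "trino.example.com",  # placeholder host
    "user": "sqlmesh",
    "catalog": "hive",
}

# The "type" key selects the ConnectionConfig subclass; a missing or
# unknown type raises a ConfigError.
config = parse_connection_config(raw)
assert isinstance(config, TrinoConnectionConfig)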
connection_config_validator

A reusable pydantic field validator built from `parse_connection_config`: it coerces raw dictionary values on connection fields into the appropriate `ConnectionConfig` subclass before model validation.