Edit on GitHub

sqlmesh.core.config.connection

   1from __future__ import annotations
   2
   3import abc
   4import base64
   5import logging
   6import os
   7import importlib
   8import pathlib
   9import re
  10import typing as t
  11from enum import Enum
  12from functools import partial
  13
  14import pydantic
  15from pydantic import Field
  16from pydantic_core import from_json
  17from packaging import version
  18from sqlglot import exp
  19from sqlglot.helper import subclasses
  20from sqlglot.errors import ParseError
  21
  22from sqlmesh.core import engine_adapter
  23from sqlmesh.core.config.base import BaseConfig
  24from sqlmesh.core.config.common import (
  25    concurrent_tasks_validator,
  26    http_headers_validator,
  27    compile_regex_mapping,
  28)
  29from sqlmesh.core.engine_adapter.shared import CatalogSupport
  30from sqlmesh.core.engine_adapter import EngineAdapter
  31from sqlmesh.utils import debug_mode_enabled, str_to_bool
  32from sqlmesh.utils.errors import ConfigError
  33from sqlmesh.utils.pydantic import (
  34    ValidationInfo,
  35    field_validator,
  36    model_validator,
  37    validation_error_message,
  38    get_concrete_types_from_typehint,
  39)
  40from sqlmesh.utils.aws import validate_s3_uri
  41
  42if t.TYPE_CHECKING:
  43    from sqlmesh.core._typing import Self
  44
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Engine type names for which `ConnectionConfig.is_recommended_for_state_sync` returns True.
RECOMMENDED_STATE_SYNC_ENGINES = {
    "postgres",
    "gcp_postgres",
    "mysql",
    "mssql",
    "azuresql",
}
# Engine type names for which `ConnectionConfig.is_forbidden_for_state_sync` returns True.
FORBIDDEN_STATE_SYNC_ENGINES = {
    # Do not support row-level operations
    "spark",
    "trino",
    # Nullable types are problematic
    "clickhouse",
}
# Matches a `motherduck_token=<value>` query parameter preceded by `?` or `&`.
# Groups: (1) separator, (2) the `motherduck_token=` key, (3) the (possibly empty) token value.
MOTHERDUCK_TOKEN_REGEX = re.compile(r"(\?|\&)(motherduck_token=)(\S*)")
# Matches `password=<value>` where the value is a run of non-whitespace characters.
# Groups: (1) the `password=` key, (2) the password value.
PASSWORD_REGEX = re.compile(r"(password=)(\S+)")
  63
  64
  65def _get_engine_import_validator(
  66    import_name: str, engine_type: str, extra_name: t.Optional[str] = None, decorate: bool = True
  67) -> t.Callable:
  68    extra_name = extra_name or engine_type
  69
  70    def validate(cls: t.Any, data: t.Any) -> t.Any:
  71        check_import = (
  72            str_to_bool(str(data.pop("check_import", True))) if isinstance(data, dict) else True
  73        )
  74        if not check_import:
  75            return data
  76        try:
  77            importlib.import_module(import_name)
  78        except ImportError:
  79            if debug_mode_enabled():
  80                raise
  81
  82            logger.exception("Failed to import the engine library")
  83
  84            raise ConfigError(
  85                f"Failed to import the '{engine_type}' engine library. This may be due to a missing "
  86                "or incompatible installation. Please ensure the required dependency is installed by "
  87                f'running: `pip install "sqlmesh[{extra_name}]"`. For more details, check the logs '
  88                "in the 'logs/' folder, or rerun the command with the '--debug' flag."
  89            )
  90
  91        return data
  92
  93    return model_validator(mode="before")(validate) if decorate else validate
  94
  95
  96class ConnectionConfig(abc.ABC, BaseConfig):
  97    type_: str
  98    DIALECT: t.ClassVar[str]
  99    DISPLAY_NAME: t.ClassVar[str]
 100    DISPLAY_ORDER: t.ClassVar[int]
 101    concurrent_tasks: int
 102    register_comments: bool
 103    pre_ping: bool
 104    pretty_sql: bool = False
 105    schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None
 106    catalog_type_overrides: t.Optional[t.Dict[str, str]] = None
 107
 108    # Whether to share a  single connection across threads or create a new connection per thread.
 109    shared_connection: t.ClassVar[bool] = False
 110
 111    @property
 112    @abc.abstractmethod
 113    def _connection_kwargs_keys(self) -> t.Set[str]:
 114        """keywords that should be passed into the connection"""
 115
 116    @property
 117    @abc.abstractmethod
 118    def _engine_adapter(self) -> t.Type[EngineAdapter]:
 119        """The engine adapter for this connection"""
 120
 121    @property
 122    @abc.abstractmethod
 123    def _connection_factory(self) -> t.Callable:
 124        """A function that is called to return a connection object for the given Engine Adapter"""
 125
 126    @property
 127    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
 128        """The static connection kwargs for this connection"""
 129        return {}
 130
 131    @property
 132    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
 133        """kwargs that are for execution config only"""
 134        return {}
 135
 136    @property
 137    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
 138        """A function that is called to initialize the cursor"""
 139        return None
 140
 141    @property
 142    def is_recommended_for_state_sync(self) -> bool:
 143        """Whether this engine is recommended for being used as a state sync for production state syncs"""
 144        return self.type_ in RECOMMENDED_STATE_SYNC_ENGINES
 145
 146    @property
 147    def is_forbidden_for_state_sync(self) -> bool:
 148        """Whether this engine is forbidden from being used as a state sync"""
 149        return self.type_ in FORBIDDEN_STATE_SYNC_ENGINES
 150
 151    @property
 152    def _connection_factory_with_kwargs(self) -> t.Callable[[], t.Any]:
 153        """A function that is called to return a connection object for the given Engine Adapter"""
 154        return partial(
 155            self._connection_factory,
 156            **{
 157                **self._static_connection_kwargs,
 158                **{k: v for k, v in self.dict().items() if k in self._connection_kwargs_keys},
 159            },
 160        )
 161
 162    def connection_validator(self) -> t.Callable[[], None]:
 163        """A function that validates the connection configuration"""
 164        return self.create_engine_adapter().ping
 165
 166    def create_engine_adapter(
 167        self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None
 168    ) -> EngineAdapter:
 169        """Returns a new instance of the Engine Adapter."""
 170
 171        concurrent_tasks = concurrent_tasks or self.concurrent_tasks
 172        return self._engine_adapter(
 173            self._connection_factory_with_kwargs,
 174            multithreaded=concurrent_tasks > 1,
 175            default_catalog=self.get_catalog(),
 176            cursor_init=self._cursor_init,
 177            register_comments=register_comments_override or self.register_comments,
 178            pre_ping=self.pre_ping,
 179            pretty_sql=self.pretty_sql,
 180            shared_connection=self.shared_connection,
 181            schema_differ_overrides=self.schema_differ_overrides,
 182            catalog_type_overrides=self.catalog_type_overrides,
 183            **self._extra_engine_config,
 184        )
 185
 186    def get_catalog(self) -> t.Optional[str]:
 187        """The catalog for this connection"""
 188        if hasattr(self, "catalog"):
 189            return self.catalog
 190        if hasattr(self, "database"):
 191            return self.database
 192        if hasattr(self, "db"):
 193            return self.db
 194        return None
 195
 196    @model_validator(mode="before")
 197    @classmethod
 198    def _expand_json_strings_to_concrete_types(cls, data: t.Any) -> t.Any:
 199        """
 200        There are situations where a connection config class has a field that is some kind of complex type
 201        (eg a list of strings or a dict) but the value is being supplied from a source such as an environment variable
 202
 203        When this happens, the value is supplied as a string rather than a Python object. We need some way
 204        of turning this string into the corresponding Python list or dict.
 205
 206        Rather than doing this piecemeal on every config subclass, this provides a generic implementatation
 207        to identify fields that may be be supplied as JSON strings and handle them transparently
 208        """
 209        if data and isinstance(data, dict):
 210            for maybe_json_field_name in cls._get_list_and_dict_field_names():
 211                if (value := data.get(maybe_json_field_name)) and isinstance(value, str):
 212                    # crude JSON check as we dont want to try and parse every string we get
 213                    value = value.strip()
 214                    if value.startswith("{") or value.startswith("["):
 215                        data[maybe_json_field_name] = from_json(value)
 216
 217        return data
 218
 219    @classmethod
 220    def _get_list_and_dict_field_names(cls) -> t.Set[str]:
 221        field_names = set()
 222        for name, field in cls.model_fields.items():
 223            if field.annotation:
 224                field_types = get_concrete_types_from_typehint(field.annotation)
 225
 226                # check if the field type is something that could concievably be supplied as a json string
 227                if any(ft is t for t in (list, tuple, set, dict) for ft in field_types):
 228                    field_names.add(name)
 229
 230        return field_names
 231
 232
 233class DuckDBAttachOptions(BaseConfig):
 234    type: str
 235    path: str
 236    read_only: bool = False
 237
 238    # DuckLake specific options
 239    data_path: t.Optional[str] = None
 240    encrypted: bool = False
 241    data_inlining_row_limit: t.Optional[int] = None
 242    metadata_schema: t.Optional[str] = None
 243
 244    def to_sql(self, alias: str) -> str:
 245        options = []
 246        # 'duckdb' is actually not a supported type, but we'd like to allow it for
 247        # fully qualified attach options or integration testing, similar to duckdb-dbt
 248        if self.type not in ("duckdb", "ducklake", "motherduck"):
 249            options.append(f"TYPE {self.type.upper()}")
 250        if self.read_only:
 251            options.append("READ_ONLY")
 252
 253        # DuckLake specific options
 254        path = self.path
 255        if self.type == "ducklake":
 256            if not path.startswith("ducklake:"):
 257                path = f"ducklake:{path}"
 258            if self.data_path is not None:
 259                options.append(f"DATA_PATH '{self.data_path}'")
 260            if self.encrypted:
 261                options.append("ENCRYPTED")
 262            if self.data_inlining_row_limit is not None:
 263                options.append(f"DATA_INLINING_ROW_LIMIT {self.data_inlining_row_limit}")
 264            if self.metadata_schema is not None:
 265                options.append(f"METADATA_SCHEMA '{self.metadata_schema}'")
 266
 267        options_sql = f" ({', '.join(options)})" if options else ""
 268        alias_sql = ""
 269        # TODO: Add support for Postgres schema. Currently adding it blocks access to the information_schema
 270
 271        # MotherDuck does not support aliasing
 272        alias_sql = (
 273            f" AS {alias}" if not (self.type == "motherduck" or self.path.startswith("md:")) else ""
 274        )
 275        return f"ATTACH IF NOT EXISTS '{path}'{alias_sql}{options_sql}"
 276
 277
 278class BaseDuckDBConnectionConfig(ConnectionConfig):
 279    """Common configuration for the DuckDB-based connections.
 280
 281    Args:
 282        database: The optional database name. If not specified, the in-memory database will be used.
 283        catalogs: Key is the name of the catalog and value is the path.
 284        extensions: A list of autoloadable extensions to load.
 285        connector_config: A dictionary of configuration to pass into the duckdb connector.
 286        secrets: A list of dictionaries used to generate DuckDB secrets for authenticating with external services (e.g. S3).
 287        filesystems: A list of dictionaries used to register `fsspec` filesystems to the DuckDB cursor.
 288        concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
 289        register_comments: Whether or not to register model comments with the SQL engine.
 290        pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
 291        token: The optional MotherDuck token. If not specified and a MotherDuck path is in the catalog, the user will be prompted to login with their web browser.
 292    """
 293
 294    database: t.Optional[str] = None
 295    catalogs: t.Optional[t.Dict[str, t.Union[str, DuckDBAttachOptions]]] = None
 296    extensions: t.List[t.Union[str, t.Dict[str, t.Any]]] = []
 297    connector_config: t.Dict[str, t.Any] = {}
 298    secrets: t.Union[t.List[t.Dict[str, t.Any]], t.Dict[str, t.Dict[str, t.Any]]] = []
 299    filesystems: t.List[t.Dict[str, t.Any]] = []
 300
 301    concurrent_tasks: int = 1
 302    register_comments: bool = True
 303    pre_ping: t.Literal[False] = False
 304
 305    token: t.Optional[str] = None
 306
 307    shared_connection: t.ClassVar[bool] = True
 308
 309    _data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {}
 310
 311    @model_validator(mode="before")
 312    def _validate_database_catalogs(cls, data: t.Any) -> t.Any:
 313        if not isinstance(data, dict):
 314            return data
 315
 316        db_path = data.get("database")
 317        if db_path and data.get("catalogs"):
 318            raise ConfigError(
 319                "Cannot specify both `database` and `catalogs`. Define all your catalogs in `catalogs` and have the first entry be the default catalog"
 320            )
 321        if isinstance(db_path, str) and db_path.startswith("md:"):
 322            raise ConfigError(
 323                "Please use connection type 'motherduck' without the `md:` prefix if you want to use a MotherDuck database as the single `database`."
 324            )
 325
 326        return data
 327
 328    @property
 329    def _engine_adapter(self) -> t.Type[EngineAdapter]:
 330        return engine_adapter.DuckDBEngineAdapter
 331
 332    @property
 333    def _connection_kwargs_keys(self) -> t.Set[str]:
 334        return {"database"}
 335
 336    @property
 337    def _connection_factory(self) -> t.Callable:
 338        import duckdb
 339
 340        return duckdb.connect
 341
 342    @property
 343    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
 344        """A function that is called to initialize the cursor"""
 345        import duckdb
 346        from duckdb import BinderException
 347
 348        def init(cursor: duckdb.DuckDBPyConnection) -> None:
 349            for extension in self.extensions:
 350                extension = extension if isinstance(extension, dict) else {"name": extension}
 351
 352                install_command = f"INSTALL {extension['name']}"
 353
 354                if extension.get("repository"):
 355                    install_command = f"{install_command} FROM {extension['repository']}"
 356
 357                if extension.get("force_install"):
 358                    install_command = f"FORCE {install_command}"
 359
 360                try:
 361                    cursor.execute(install_command)
 362                    cursor.execute(f"LOAD {extension['name']}")
 363                except Exception as e:
 364                    raise ConfigError(f"Failed to load extension {extension['name']}: {e}")
 365
 366            if self.connector_config:
 367                option_names = list(self.connector_config)
 368                in_part = ",".join("?" for _ in range(len(option_names)))
 369
 370                cursor.execute(
 371                    f"SELECT name, value FROM duckdb_settings() WHERE name IN ({in_part})",
 372                    option_names,
 373                )
 374
 375                existing_values = {field: setting for field, setting in cursor.fetchall()}
 376
 377                # only set connector_config items if the values differ from what is already set
 378                # trying to set options like 'temp_directory' even to the same value can throw errors like:
 379                # Not implemented Error: Cannot switch temporary directory after the current one has been used
 380                for field, setting in self.connector_config.items():
 381                    if existing_values.get(field) != setting:
 382                        try:
 383                            cursor.execute(f"SET {field} = '{setting}'")
 384                        except Exception as e:
 385                            raise ConfigError(
 386                                f"Failed to set connector config {field} to {setting}: {e}"
 387                            )
 388
 389            if self.secrets:
 390                duckdb_version = duckdb.__version__
 391                if version.parse(duckdb_version) < version.parse("0.10.0"):
 392                    from sqlmesh.core.console import get_console
 393
 394                    get_console().log_warning(
 395                        f"DuckDB version {duckdb_version} does not support secrets-based authentication (requires 0.10.0 or later).\n"
 396                        "To use secrets, please upgrade DuckDB. For older versions, configure legacy authentication via `connector_config`.\n"
 397                        "More info: https://duckdb.org/docs/stable/extensions/httpfs/s3api_legacy_authentication.html"
 398                    )
 399                else:
 400                    if isinstance(self.secrets, list):
 401                        secrets_items = [(secret_dict, "") for secret_dict in self.secrets]
 402                    else:
 403                        secrets_items = [
 404                            (secret_dict, secret_name)
 405                            for secret_name, secret_dict in self.secrets.items()
 406                        ]
 407
 408                    for secret_dict, secret_name in secrets_items:
 409                        secret_settings: t.List[str] = []
 410                        for field, setting in secret_dict.items():
 411                            secret_settings.append(f"{field} '{setting}'")
 412                        if secret_settings:
 413                            secret_clause = ", ".join(secret_settings)
 414                            try:
 415                                cursor.execute(
 416                                    f"CREATE OR REPLACE SECRET {secret_name} ({secret_clause});"
 417                                )
 418                            except Exception as e:
 419                                raise ConfigError(f"Failed to create secret: {e}")
 420
 421            if self.filesystems:
 422                from fsspec import filesystem  # type: ignore
 423
 424                for file_system in self.filesystems:
 425                    options = file_system.copy()
 426                    fs = options.pop("fs")
 427                    fs = filesystem(fs, **options)
 428                    cursor.register_filesystem(fs)
 429
 430            for i, (alias, path_options) in enumerate(
 431                (getattr(self, "catalogs", None) or {}).items()
 432            ):
 433                # we parse_identifier and generate to ensure that `alias` has exactly one set of quotes
 434                # regardless of whether it comes in quoted or not
 435                alias = exp.parse_identifier(alias, dialect="duckdb").sql(
 436                    identify=True, dialect="duckdb"
 437                )
 438                try:
 439                    if isinstance(path_options, DuckDBAttachOptions):
 440                        query = path_options.to_sql(alias)
 441                    else:
 442                        query = f"ATTACH IF NOT EXISTS '{path_options}'"
 443                        if not path_options.startswith("md:"):
 444                            query += f" AS {alias}"
 445                    cursor.execute(query)
 446                except BinderException as e:
 447                    # If a user tries to create a catalog pointing at `:memory:` and with the name `memory`
 448                    # then we don't want to raise since this happens by default. They are just doing this to
 449                    # set it as the default catalog.
 450                    # If a user tried to attach a MotherDuck database/share which has already by attached via
 451                    # `ATTACH 'md:'`, then we don't want to raise since this is expected.
 452                    if (
 453                        not (
 454                            'database with name "memory" already exists' in str(e)
 455                            and path_options == ":memory:"
 456                        )
 457                        and f"""database with name "{path_options.path.replace("md:", "")}" already exists"""
 458                        not in str(e)
 459                    ):
 460                        raise e
 461                if i == 0 and not getattr(self, "database", None):
 462                    cursor.execute(f"USE {alias}")
 463
 464        return init
 465
 466    def create_engine_adapter(
 467        self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None
 468    ) -> EngineAdapter:
 469        """Checks if another engine adapter has already been created that shares a catalog that points to the same data
 470        file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration
 471        associated with the new adapter will be ignored."""
 472        data_files = set((self.catalogs or {}).values())
 473        if self.database:
 474            if isinstance(self, MotherDuckConnectionConfig):
 475                data_files.add(
 476                    f"md:{self.database}"
 477                    + (f"?motherduck_token={self.token}" if self.token else "")
 478                )
 479            else:
 480                data_files.add(self.database)
 481        data_files.discard(":memory:")
 482        for data_file in data_files:
 483            key = data_file if isinstance(data_file, str) else data_file.path
 484            adapter = BaseDuckDBConnectionConfig._data_file_to_adapter.get(key)
 485            if adapter is not None:
 486                logger.info(
 487                    f"Using existing DuckDB adapter due to overlapping data file: {self._mask_sensitive_data(key)}"
 488                )
 489                return adapter
 490
 491        if data_files:
 492            masked_files = {
 493                self._mask_sensitive_data(file if isinstance(file, str) else file.path)
 494                for file in data_files
 495            }
 496            logger.info(f"Creating new DuckDB adapter for data files: {masked_files}")
 497        else:
 498            logger.info("Creating new DuckDB adapter for in-memory database")
 499        adapter = super().create_engine_adapter(
 500            register_comments_override, concurrent_tasks=concurrent_tasks
 501        )
 502        for data_file in data_files:
 503            key = data_file if isinstance(data_file, str) else data_file.path
 504            BaseDuckDBConnectionConfig._data_file_to_adapter[key] = adapter
 505        return adapter
 506
 507    def get_catalog(self) -> t.Optional[str]:
 508        if self.database:
 509            # Remove `:` from the database name in order to handle if `:memory:` is passed in
 510            return pathlib.Path(self.database.replace(":memory:", "memory")).stem
 511        if self.catalogs:
 512            return list(self.catalogs)[0]
 513        return None
 514
 515    def _mask_sensitive_data(self, string: str) -> str:
 516        # Mask MotherDuck tokens with fixed number of asterisks
 517        result = MOTHERDUCK_TOKEN_REGEX.sub(
 518            lambda m: f"{m.group(1)}{m.group(2)}{'*' * 8 if m.group(3) else ''}", string
 519        )
 520        # Mask PostgreSQL/MySQL passwords with fixed number of asterisks
 521        result = PASSWORD_REGEX.sub(lambda m: f"{m.group(1)}{'*' * 8}", result)
 522        return result
 523
 524
 525class MotherDuckConnectionConfig(BaseDuckDBConnectionConfig):
 526    """Configuration for the MotherDuck connection."""
 527
 528    type_: t.Literal["motherduck"] = Field(alias="type", default="motherduck")
 529    DIALECT: t.ClassVar[t.Literal["duckdb"]] = "duckdb"
 530    DISPLAY_NAME: t.ClassVar[t.Literal["MotherDuck"]] = "MotherDuck"
 531    DISPLAY_ORDER: t.ClassVar[t.Literal[5]] = 5
 532
 533    @property
 534    def _connection_kwargs_keys(self) -> t.Set[str]:
 535        return set()
 536
 537    @property
 538    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
 539        """kwargs that are for execution config only"""
 540        from sqlmesh import __version__
 541
 542        custom_user_agent_config = {"custom_user_agent": f"SQLMesh/{__version__}"}
 543        connection_str = "md:"
 544        if self.database:
 545            # Attach single MD database instead of all databases on the account
 546            connection_str += f"{self.database}?attach_mode=single"
 547        if self.token:
 548            connection_str += f"{'&' if self.database else '?'}motherduck_token={self.token}"
 549        return {"database": connection_str, "config": custom_user_agent_config}
 550
 551    @property
 552    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
 553        return {"is_motherduck": True}
 554
 555
class DuckDBConnectionConfig(BaseDuckDBConnectionConfig):
    """Configuration for the DuckDB connection."""

    # Connection type tag; supplied through the `type` alias and restricted to "duckdb".
    type_: t.Literal["duckdb"] = Field(alias="type", default="duckdb")
    # SQL dialect name for this connection.
    DIALECT: t.ClassVar[t.Literal["duckdb"]] = "duckdb"
    # Human-readable engine name.
    DISPLAY_NAME: t.ClassVar[t.Literal["DuckDB"]] = "DuckDB"
    # NOTE(review): presumably the position of this engine when engines are listed; confirm against callers.
    DISPLAY_ORDER: t.ClassVar[t.Literal[1]] = 1
 563
 564
 565class SnowflakeConnectionConfig(ConnectionConfig):
 566    """Configuration for the Snowflake connection.
 567
 568    Args:
 569        account: The Snowflake account name.
 570        user: The Snowflake username.
 571        password: The Snowflake password.
 572        warehouse: The optional warehouse name.
 573        database: The optional database name.
 574        role: The optional role name.
 575        concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
 576        authenticator: The optional authenticator name. Defaults to username/password authentication ("snowflake").
 577                       Options: https://github.com/snowflakedb/snowflake-connector-python/blob/e937591356c067a77f34a0a42328907fda792c23/src/snowflake/connector/network.py#L178-L183
 578        token: The optional oauth access token to use for authentication when authenticator is set to "oauth".
 579        private_key: The optional private key to use for authentication. Key can be Base64-encoded DER format (representing the key bytes), a plain-text PEM format, or bytes (Python config only). https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#using-key-pair-authentication-key-pair-rotation
 580        private_key_path: The optional path to the private key to use for authentication. This would be used instead of `private_key`.
 581        private_key_passphrase: The optional passphrase to use to decrypt `private_key` or `private_key_path`. Keys can be created without encryption so only provide this if needed.
 582        register_comments: Whether or not to register model comments with the SQL engine.
 583        pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
 584        session_parameters: The optional session parameters to set for the connection.
 585        host: Host address for the connection.
 586        port: Port for the connection.
 587    """
 588
 589    account: str
 590    user: t.Optional[str] = None
 591    password: t.Optional[str] = None
 592    warehouse: t.Optional[str] = None
 593    database: t.Optional[str] = None
 594    role: t.Optional[str] = None
 595    authenticator: t.Optional[str] = None
 596    token: t.Optional[str] = None
 597    host: t.Optional[str] = None
 598    port: t.Optional[int] = None
 599    application: t.Literal["Tobiko_SQLMesh"] = "Tobiko_SQLMesh"
 600
 601    # Private Key Auth
 602    private_key: t.Optional[t.Union[str, bytes]] = None
 603    private_key_path: t.Optional[str] = None
 604    private_key_passphrase: t.Optional[str] = None
 605
 606    concurrent_tasks: int = 4
 607    register_comments: bool = True
 608    pre_ping: bool = False
 609
 610    session_parameters: t.Optional[dict] = None
 611
 612    type_: t.Literal["snowflake"] = Field(alias="type", default="snowflake")
 613    DIALECT: t.ClassVar[t.Literal["snowflake"]] = "snowflake"
 614    DISPLAY_NAME: t.ClassVar[t.Literal["Snowflake"]] = "Snowflake"
 615    DISPLAY_ORDER: t.ClassVar[t.Literal[2]] = 2
 616
 617    _concurrent_tasks_validator = concurrent_tasks_validator
 618
 619    @model_validator(mode="before")
 620    def _validate_authenticator(cls, data: t.Any) -> t.Any:
 621        if not isinstance(data, dict):
 622            return data
 623
 624        from snowflake.connector.network import DEFAULT_AUTHENTICATOR, OAUTH_AUTHENTICATOR
 625
 626        auth = data.get("authenticator")
 627        auth = auth.upper() if auth else DEFAULT_AUTHENTICATOR
 628        user = data.get("user")
 629        password = data.get("password")
 630        data["private_key"] = cls._get_private_key(data, auth)  # type: ignore
 631
 632        if (
 633            auth == DEFAULT_AUTHENTICATOR
 634            and not data.get("private_key")
 635            and (not user or not password)
 636        ):
 637            raise ConfigError("User and password must be provided if using default authentication")
 638
 639        if auth == OAUTH_AUTHENTICATOR and not data.get("token"):
 640            raise ConfigError("Token must be provided if using oauth authentication")
 641
 642        return data
 643
 644    _engine_import_validator = _get_engine_import_validator(
 645        "snowflake.connector.network", "snowflake"
 646    )
 647
 648    @classmethod
 649    def _get_private_key(cls, values: t.Dict[str, t.Optional[str]], auth: str) -> t.Optional[bytes]:
 650        """
 651        source: https://github.com/dbt-labs/dbt-snowflake/blob/0374b4ec948982f2ac8ec0c95d53d672ad19e09c/dbt/adapters/snowflake/connections.py#L247C5-L285C1
 652
 653        Overall code change: Use local variables instead of class attributes + Validation
 654        """
 655        # Start custom code
 656        from cryptography.hazmat.backends import default_backend
 657        from cryptography.hazmat.primitives import serialization
 658        from snowflake.connector.network import (
 659            DEFAULT_AUTHENTICATOR,
 660            KEY_PAIR_AUTHENTICATOR,
 661        )
 662
 663        private_key = values.get("private_key")
 664        private_key_path = values.get("private_key_path")
 665        private_key_passphrase = values.get("private_key_passphrase")
 666        user = values.get("user")
 667        password = values.get("password")
 668        auth = auth if auth and auth != DEFAULT_AUTHENTICATOR else KEY_PAIR_AUTHENTICATOR
 669
 670        if not private_key and not private_key_path:
 671            return None
 672        if private_key and private_key_path:
 673            raise ConfigError("Cannot specify both `private_key` and `private_key_path`")
 674        if auth != KEY_PAIR_AUTHENTICATOR:
 675            raise ConfigError(
 676                f"Private key or private key path can only be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
 677            )
 678        if not user:
 679            raise ConfigError(
 680                f"User must be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
 681            )
 682        if password:
 683            raise ConfigError(
 684                f"Password cannot be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
 685            )
 686
 687        if isinstance(private_key, bytes):
 688            return private_key
 689        # End Custom Code
 690
 691        if private_key_passphrase:
 692            encoded_passphrase = private_key_passphrase.encode()
 693        else:
 694            encoded_passphrase = None
 695
 696        if private_key:
 697            if private_key.startswith("-"):
 698                p_key = serialization.load_pem_private_key(
 699                    data=bytes(private_key, "utf-8"),
 700                    password=encoded_passphrase,
 701                    backend=default_backend(),
 702                )
 703
 704            else:
 705                p_key = serialization.load_der_private_key(
 706                    data=base64.b64decode(private_key),
 707                    password=encoded_passphrase,
 708                    backend=default_backend(),
 709                )
 710
 711        elif private_key_path:
 712            with open(private_key_path, "rb") as key:
 713                p_key = serialization.load_pem_private_key(
 714                    key.read(), password=encoded_passphrase, backend=default_backend()
 715                )
 716        else:
 717            return None
 718
 719        return p_key.private_bytes(
 720            encoding=serialization.Encoding.DER,
 721            format=serialization.PrivateFormat.PKCS8,
 722            encryption_algorithm=serialization.NoEncryption(),
 723        )
 724
 725    @property
 726    def _connection_kwargs_keys(self) -> t.Set[str]:
 727        return {
 728            "user",
 729            "password",
 730            "account",
 731            "warehouse",
 732            "database",
 733            "role",
 734            "authenticator",
 735            "token",
 736            "private_key",
 737            "session_parameters",
 738            "application",
 739            "host",
 740            "port",
 741        }
 742
    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """Return the engine adapter class paired with this config: `SnowflakeEngineAdapter`."""
        return engine_adapter.SnowflakeEngineAdapter
 746
 747    @property
 748    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
 749        return {"autocommit": False}
 750
 751    @property
 752    def _connection_factory(self) -> t.Callable:
 753        from snowflake import connector
 754
 755        return connector.connect
 756
 757
 758class DatabricksConnectionConfig(ConnectionConfig):
 759    """
 760    Databricks connection that uses the SQL connector for SQL models and then Databricks Connect for Dataframe operations
 761
 762    Arg Source: https://github.com/databricks/databricks-sql-python/blob/main/src/databricks/sql/client.py#L39
 763    OAuth ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication
 764
 765    Args:
 766        server_hostname: Databricks instance host name.
 767        http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef)
 768            or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
 769        access_token: Http Bearer access token, e.g. Databricks Personal Access Token.
 770        auth_type: Set to 'databricks-oauth' or 'azure-oauth' to trigger OAuth (or dont set at all to use `access_token`)
 771        oauth_client_id: Client ID to use when auth_type is set to one of the 'oauth' types
 772        oauth_client_secret: Client Secret to use when auth_type is set to one of the 'oauth' types
 773        catalog: Default catalog to use for SQL models. Defaults to None which means it will use the default set in
 774            the Databricks cluster (most likely `hive_metastore`).
 775        http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request
 776        session_configuration: An optional dictionary of Spark session parameters.
 777            Execute the SQL command `SET -v` to get a full list of available commands.
 778        databricks_connect_server_hostname: The hostname to use when establishing a connecting using Databricks Connect.
 779            Defaults to the `server_hostname` value.
 780        databricks_connect_access_token: The access token to use when establishing a connecting using Databricks Connect.
 781            Defaults to the `access_token` value.
 782        databricks_connect_cluster_id: The cluster id to use when establishing a connecting using Databricks Connect.
 783            Defaults to deriving the cluster id from the `http_path` value.
 784        force_databricks_connect: Force all queries to run using Databricks Connect instead of the SQL connector.
 785        disable_databricks_connect: Even if databricks connect is installed, do not use it.
 786        disable_spark_session: Do not use SparkSession if it is available (like when running in a notebook).
 787        pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
 788    """
 789
 790    server_hostname: t.Optional[str] = None
 791    http_path: t.Optional[str] = None
 792    access_token: t.Optional[str] = None
 793    auth_type: t.Optional[str] = None
 794    oauth_client_id: t.Optional[str] = None
 795    oauth_client_secret: t.Optional[str] = None
 796    catalog: t.Optional[str] = None
 797    http_headers: t.Optional[t.List[t.Tuple[str, str]]] = None
 798    session_configuration: t.Optional[t.Dict[str, t.Any]] = None
 799    databricks_connect_server_hostname: t.Optional[str] = None
 800    databricks_connect_access_token: t.Optional[str] = None
 801    databricks_connect_cluster_id: t.Optional[str] = None
 802    databricks_connect_use_serverless: bool = False
 803    force_databricks_connect: bool = False
 804    disable_databricks_connect: bool = False
 805    disable_spark_session: bool = False
 806
 807    concurrent_tasks: int = 1
 808    register_comments: bool = True
 809    pre_ping: t.Literal[False] = False
 810
 811    type_: t.Literal["databricks"] = Field(alias="type", default="databricks")
 812    DIALECT: t.ClassVar[t.Literal["databricks"]] = "databricks"
 813    DISPLAY_NAME: t.ClassVar[t.Literal["Databricks"]] = "Databricks"
 814    DISPLAY_ORDER: t.ClassVar[t.Literal[3]] = 3
 815
 816    _concurrent_tasks_validator = concurrent_tasks_validator
 817    _http_headers_validator = http_headers_validator
 818
 819    @model_validator(mode="before")
 820    def _databricks_connect_validator(cls, data: t.Any) -> t.Any:
 821        # SQLQueryContextLogger will output any error SQL queries even if they are in a try/except block.
 822        # Disabling this allows SQLMesh to determine what should be shown to the user.
 823        # Ex: We describe a table to see if it exists and therefore that execution can fail but we don't need to show
 824        # the user since it is expected if the table doesn't exist. Without this change the user would see the error.
 825        logging.getLogger("SQLQueryContextLogger").setLevel(logging.CRITICAL)
 826
 827        if not isinstance(data, dict):
 828            return data
 829
 830        from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter
 831
 832        if DatabricksEngineAdapter.can_access_spark_session(
 833            bool(data.get("disable_spark_session"))
 834        ):
 835            return data
 836
 837        databricks_connect_use_serverless = data.get("databricks_connect_use_serverless")
 838        server_hostname, http_path, access_token, auth_type = (
 839            data.get("server_hostname"),
 840            data.get("http_path"),
 841            data.get("access_token"),
 842            data.get("auth_type"),
 843        )
 844
 845        if (not server_hostname or not http_path or not access_token) and (
 846            not databricks_connect_use_serverless and not auth_type
 847        ):
 848            raise ValueError(
 849                "`server_hostname`, `http_path`, and `access_token` are required for Databricks connections when not running in a notebook"
 850            )
 851        if (
 852            databricks_connect_use_serverless
 853            and not server_hostname
 854            and not data.get("databricks_connect_server_hostname")
 855        ):
 856            raise ValueError(
 857                "`server_hostname` or `databricks_connect_server_hostname` is required when `databricks_connect_use_serverless` is set"
 858            )
 859        if DatabricksEngineAdapter.can_access_databricks_connect(
 860            bool(data.get("disable_databricks_connect"))
 861        ):
 862            if not data.get("databricks_connect_access_token"):
 863                data["databricks_connect_access_token"] = access_token
 864            if not data.get("databricks_connect_server_hostname"):
 865                data["databricks_connect_server_hostname"] = f"https://{server_hostname}"
 866            if not databricks_connect_use_serverless and not data.get(
 867                "databricks_connect_cluster_id"
 868            ):
 869                if t.TYPE_CHECKING:
 870                    assert http_path is not None
 871                data["databricks_connect_cluster_id"] = http_path.split("/")[-1]
 872
 873        if auth_type:
 874            from databricks.sql.auth.auth import AuthType
 875
 876            all_data = [m.value for m in AuthType]
 877            if auth_type not in all_data:
 878                raise ValueError(
 879                    f"`auth_type` {auth_type} does not match a valid option: {all_data}"
 880                )
 881
 882            client_id = data.get("oauth_client_id")
 883            client_secret = data.get("oauth_client_secret")
 884
 885            if client_secret and not client_id:
 886                raise ValueError(
 887                    "`oauth_client_id` is required when `oauth_client_secret` is specified"
 888                )
 889
 890            if not http_path:
 891                raise ValueError("`http_path` is still required when using `auth_type`")
 892
 893        return data
 894
 895    _engine_import_validator = _get_engine_import_validator("databricks", "databricks")
 896
 897    @property
 898    def _connection_kwargs_keys(self) -> t.Set[str]:
 899        if self.use_spark_session_only:
 900            return set()
 901        return {
 902            "server_hostname",
 903            "http_path",
 904            "access_token",
 905            "http_headers",
 906            "session_configuration",
 907            "catalog",
 908        }
 909
 910    @property
 911    def _engine_adapter(self) -> t.Type[engine_adapter.DatabricksEngineAdapter]:
 912        return engine_adapter.DatabricksEngineAdapter
 913
 914    @property
 915    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
 916        return {
 917            k: v
 918            for k, v in self.dict().items()
 919            if k.startswith("databricks_connect_")
 920            or k in ("catalog", "disable_databricks_connect", "disable_spark_session")
 921        }
 922
 923    @property
 924    def use_spark_session_only(self) -> bool:
 925        from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter
 926
 927        return (
 928            DatabricksEngineAdapter.can_access_spark_session(self.disable_spark_session)
 929            or self.force_databricks_connect
 930        )
 931
 932    @property
 933    def _connection_factory(self) -> t.Callable:
 934        if self.use_spark_session_only:
 935            from sqlmesh.engines.spark.db_api.spark_session import connection
 936
 937            return connection
 938
 939        from databricks import sql  # type: ignore
 940
 941        return sql.connect
 942
 943    @property
 944    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
 945        from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter
 946
 947        if not self.use_spark_session_only:
 948            conn_kwargs: t.Dict[str, t.Any] = {
 949                "_user_agent_entry": "sqlmesh",
 950            }
 951
 952            if self.auth_type and "oauth" in self.auth_type:
 953                # there are two types of oauth: User-to-Machine (U2M) and Machine-to-Machine (M2M)
 954                if self.oauth_client_secret:
 955                    # if a client_secret exists, then a client_id also exists and we are using M2M
 956                    # ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication
 957                    # ref: https://github.com/databricks/databricks-sql-python/blob/main/examples/m2m_oauth.py
 958                    from databricks.sdk.core import oauth_service_principal, Config
 959
 960                    config = Config(
 961                        host=f"https://{self.server_hostname}",
 962                        client_id=self.oauth_client_id,
 963                        client_secret=self.oauth_client_secret,
 964                    )
 965                    conn_kwargs["credentials_provider"] = lambda: oauth_service_principal(config)
 966                else:
 967                    # if auth_type is set to an 'oauth' type but no client_id/secret are set, then we are using U2M
 968                    # ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-user-to-machine-u2m-authentication
 969                    conn_kwargs["auth_type"] = self.auth_type
 970
 971            return conn_kwargs
 972
 973        if DatabricksEngineAdapter.can_access_spark_session(self.disable_spark_session):
 974            from pyspark.sql import SparkSession
 975
 976            return dict(
 977                spark=SparkSession.getActiveSession(),
 978                catalog=self.catalog,
 979            )
 980
 981        from databricks.connect import DatabricksSession
 982
 983        if t.TYPE_CHECKING:
 984            assert self.databricks_connect_server_hostname is not None
 985            assert self.databricks_connect_access_token is not None
 986
 987        if self.databricks_connect_use_serverless:
 988            builder = DatabricksSession.builder.remote(
 989                host=self.databricks_connect_server_hostname,
 990                token=self.databricks_connect_access_token,
 991                serverless=True,
 992            )
 993        else:
 994            if t.TYPE_CHECKING:
 995                assert self.databricks_connect_cluster_id is not None
 996            builder = DatabricksSession.builder.remote(
 997                host=self.databricks_connect_server_hostname,
 998                token=self.databricks_connect_access_token,
 999                cluster_id=self.databricks_connect_cluster_id,
1000            )
1001
1002        return dict(
1003            spark=builder.userAgent("sqlmesh").getOrCreate(),
1004            catalog=self.catalog,
1005        )
1006
1007
class BigQueryConnectionMethod(str, Enum):
    """Authentication strategies selectable through `BigQueryConnectionConfig.method`."""

    # Credentials are resolved through `google.auth.default`.
    OAUTH = "oauth"
    # Credentials are built from the explicit token / refresh token / client id / client secret fields.
    OAUTH_SECRETS = "oauth-secrets"
    # Credentials are loaded from the service account file given in `keyfile`.
    SERVICE_ACCOUNT = "service-account"
    # Credentials are loaded from the service account info dict given in `keyfile_json`.
    SERVICE_ACCOUNT_JSON = "service-account-json"
1013
1014
class BigQueryPriority(str, Enum):
    """Query priority choices for BigQuery jobs."""

    BATCH = "batch"
    INTERACTIVE = "interactive"

    @property
    def is_batch(self) -> bool:
        """True when this member is `BATCH`."""
        # Look the member up on the class: reaching one member through another
        # (`self.BATCH`) emits a DeprecationWarning on Python 3.11.
        return self is BigQueryPriority.BATCH

    @property
    def is_interactive(self) -> bool:
        """True when this member is `INTERACTIVE`."""
        return self is BigQueryPriority.INTERACTIVE

    @property
    def bigquery_constant(self) -> str:
        """Return the matching `google.cloud.bigquery.QueryPriority` constant."""
        from google.cloud.bigquery import QueryPriority

        return QueryPriority.BATCH if self.is_batch else QueryPriority.INTERACTIVE
1034
1035
class BigQueryConnectionConfig(ConnectionConfig):
    """
    BigQuery Connection Configuration.

    The `method` field selects how credentials are obtained (see `BigQueryConnectionMethod`);
    the resulting `google.cloud.bigquery.Client` is handed to the DB-API `connect` function
    through the static connection kwargs.
    """

    method: BigQueryConnectionMethod = BigQueryConnectionMethod.OAUTH

    project: t.Optional[str] = None
    execution_project: t.Optional[str] = None
    quota_project: t.Optional[str] = None
    location: t.Optional[str] = None
    # Keyfile Auth
    keyfile: t.Optional[str] = None
    keyfile_json: t.Optional[t.Dict[str, t.Any]] = None
    # OAuth Secret Auth
    token: t.Optional[str] = None
    refresh_token: t.Optional[str] = None
    client_id: t.Optional[str] = None
    client_secret: t.Optional[str] = None
    token_uri: t.Optional[str] = None
    scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/bigquery",)
    impersonated_service_account: t.Optional[str] = None
    # Extra Engine Config
    job_creation_timeout_seconds: t.Optional[int] = None
    job_execution_timeout_seconds: t.Optional[int] = None
    job_retries: t.Optional[int] = 1
    job_retry_deadline_seconds: t.Optional[int] = None
    priority: t.Optional[BigQueryPriority] = None
    maximum_bytes_billed: t.Optional[int] = None
    reservation: t.Optional[str] = None

    concurrent_tasks: int = 1
    register_comments: bool = True
    pre_ping: t.Literal[False] = False

    type_: t.Literal["bigquery"] = Field(alias="type", default="bigquery")
    DIALECT: t.ClassVar[t.Literal["bigquery"]] = "bigquery"
    DISPLAY_NAME: t.ClassVar[t.Literal["BigQuery"]] = "BigQuery"
    DISPLAY_ORDER: t.ClassVar[t.Literal[4]] = 4

    _engine_import_validator = _get_engine_import_validator("google.cloud.bigquery", "bigquery")

    @field_validator("execution_project")
    def validate_execution_project(
        cls,
        v: t.Optional[str],
        info: ValidationInfo,
    ) -> t.Optional[str]:
        """Reject an `execution_project` that is given without a `project`.

        Raises:
            ConfigError: If `execution_project` is set while `project` is empty.
        """
        if v and not info.data.get("project"):
            raise ConfigError(
                "If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location."
            )
        return v

    @field_validator("quota_project")
    def validate_quota_project(
        cls,
        v: t.Optional[str],
        info: ValidationInfo,
    ) -> t.Optional[str]:
        """Reject a `quota_project` that is given without a `project`.

        Raises:
            ConfigError: If `quota_project` is set while `project` is empty.
        """
        if v and not info.data.get("project"):
            raise ConfigError(
                "If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location."
            )
        return v

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """No config field is used as a connection kwarg; the client comes from the static kwargs."""
        return set()

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """Return the engine adapter class paired with this config: `BigQueryEngineAdapter`."""
        return engine_adapter.BigQueryEngineAdapter

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """
        Build the BigQuery client and return it as the static connection kwargs.

        Credentials are created according to `method`, optionally wrapped in impersonated
        credentials when `impersonated_service_account` is set. The client project is
        `execution_project` if given, otherwise `project`.

        Returns:
            A dict with a single `client` entry.

        Raises:
            ConfigError: If `method` is not one of the supported connection methods.
        """
        import google.auth
        from google.api_core import client_info, client_options
        from google.auth import impersonated_credentials

        # Import the client module explicitly instead of relying on `google.cloud.bigquery`
        # having been imported elsewhere before this property is read.
        from google.cloud import bigquery
        from google.oauth2 import credentials, service_account

        if self.method == BigQueryConnectionMethod.OAUTH:
            creds, _ = google.auth.default(scopes=self.scopes)
        elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT:
            creds = service_account.Credentials.from_service_account_file(
                self.keyfile, scopes=self.scopes
            )
        elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT_JSON:
            creds = service_account.Credentials.from_service_account_info(
                self.keyfile_json, scopes=self.scopes
            )
        elif self.method == BigQueryConnectionMethod.OAUTH_SECRETS:
            creds = credentials.Credentials(
                token=self.token,
                refresh_token=self.refresh_token,
                client_id=self.client_id,
                client_secret=self.client_secret,
                token_uri=self.token_uri,
                scopes=self.scopes,
            )
        else:
            raise ConfigError("Invalid BigQuery Connection Method")

        if self.impersonated_service_account:
            # The credentials resolved above become the source for the impersonated ones.
            creds = impersonated_credentials.Credentials(
                source_credentials=creds,
                target_principal=self.impersonated_service_account,
                target_scopes=self.scopes,
            )

        options = client_options.ClientOptions(quota_project_id=self.quota_project)
        project = self.execution_project or self.project or None

        client = bigquery.Client(
            # The configured value is parsed as a BigQuery identifier and only its name is passed on.
            project=project and exp.parse_identifier(project, dialect="bigquery").name,
            credentials=creds,
            location=self.location,
            client_info=client_info.ClientInfo(user_agent="sqlmesh"),
            client_options=options,
        )

        return {
            "client": client,
        }

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        """Return the job-related settings (timeouts, retries, priority, billing cap, reservation)."""
        return {
            k: v
            for k, v in self.dict().items()
            if k
            in {
                "job_creation_timeout_seconds",
                "job_execution_timeout_seconds",
                "job_retries",
                "job_retry_deadline_seconds",
                "priority",
                "maximum_bytes_billed",
                "reservation",
            }
        }

    @property
    def _connection_factory(self) -> t.Callable:
        """Return `google.cloud.bigquery.dbapi.connect`."""
        from google.cloud.bigquery.dbapi import connect

        return connect

    def get_catalog(self) -> t.Optional[str]:
        """Return the configured `project`, which serves as the catalog."""
        return self.project
1187
1188
class GCPPostgresConnectionConfig(ConnectionConfig):
    """
    Postgres Connection Configuration for GCP.

    Args:
        instance_connection_string: Connection name for the postgres instance.
        user: Postgres or IAM user's name
        password: The postgres user's password. Only needed when the user is a postgres user.
        enable_iam_auth: Set to True when user is an IAM user.
        db: Name of the db to connect to.
        ip_type: One of "public", "private" or "psc"; handed to the Cloud SQL `Connector`.
        keyfile: string path to json service account credentials file
        keyfile_json: dict service account credentials info
        timeout: Optional timeout; handed to the Cloud SQL `Connector` when set to a truthy value.
        scopes: Scopes requested when service account credentials are built from a keyfile.
        driver: Database driver name, "pg8000" by default.
        pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
    """

    instance_connection_string: str
    user: str
    password: t.Optional[str] = None
    enable_iam_auth: t.Optional[bool] = None
    db: str
    ip_type: t.Union[t.Literal["public"], t.Literal["private"], t.Literal["psc"]] = "public"
    # Keyfile Auth
    keyfile: t.Optional[str] = None
    keyfile_json: t.Optional[t.Dict[str, t.Any]] = None
    timeout: t.Optional[int] = None
    scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/sqlservice.admin",)
    driver: str = "pg8000"

    type_: t.Literal["gcp_postgres"] = Field(alias="type", default="gcp_postgres")
    DIALECT: t.ClassVar[t.Literal["postgres"]] = "postgres"
    DISPLAY_NAME: t.ClassVar[t.Literal["GCP Postgres"]] = "GCP Postgres"
    DISPLAY_ORDER: t.ClassVar[t.Literal[13]] = 13

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = True

    _engine_import_validator = _get_engine_import_validator(
        "google.cloud.sql", "gcp_postgres", "gcppostgres"
    )

    @model_validator(mode="before")
    def _validate_auth_method(cls, data: t.Any) -> t.Any:
        """Require either a password or IAM authentication in the raw config.

        Non-dict input is passed through without checks.

        Raises:
            ConfigError: If neither `password` nor `enable_iam_auth` is truthy.
        """
        if not isinstance(data, dict):
            return data

        # One of the two authentication styles has to be configured.
        if data.get("password") or data.get("enable_iam_auth"):
            return data

        raise ConfigError(
            "GCP Postgres connection configuration requires either password set"
            " for a postgres user account or enable_iam_auth set to 'True'"
            " for an IAM user account."
        )

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Names of the fields that make up the connection kwargs."""
        instance_keys = {"instance_connection_string", "driver", "db", "timeout"}
        auth_keys = {"user", "password", "enable_iam_auth"}
        return instance_keys | auth_keys

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """Return the engine adapter class paired with this config: `PostgresEngineAdapter`."""
        return engine_adapter.PostgresEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """
        Create a Cloud SQL `Connector` and return its bound `connect` method.

        Service account credentials are built from `keyfile` first, then `keyfile_json`;
        with neither set, `None` is passed as the connector credentials.
        """
        from google.cloud.sql.connector import Connector
        from google.oauth2 import service_account

        if self.keyfile:
            creds = service_account.Credentials.from_service_account_file(
                self.keyfile, scopes=self.scopes
            )
        elif self.keyfile_json:
            creds = service_account.Credentials.from_service_account_info(
                self.keyfile_json, scopes=self.scopes
            )
        else:
            creds = None

        kwargs = {
            "credentials": creds,
            "ip_type": self.ip_type,
        }

        # Only forward a timeout that was actually configured.
        if self.timeout:
            kwargs["timeout"] = self.timeout

        connector = Connector(**kwargs)  # type: ignore
        return connector.connect
1287
1288
class RedshiftConnectionConfig(ConnectionConfig):
    """
    Redshift Connection Configuration.

    Arg Source: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/__init__.py#L146
    Note: A subset of properties were selected. Please open an issue/PR if you want to see more supported.

    Args:
        user: The username to use for authentication with the Amazon Redshift cluster.
        password: The password to use for authentication with the Amazon Redshift cluster.
        database: The name of the database instance to connect to.
        host: The hostname of the Amazon Redshift cluster.
        port: The port number of the Amazon Redshift cluster. Default value is 5439.
        source_address: No description provided
        unix_sock: No description provided
        ssl: Is SSL enabled. Default value is ``True``. SSL must be enabled when authenticating using IAM.
        sslmode: The security of the connection to the Amazon Redshift cluster. 'verify-ca' and 'verify-full' are supported.
        timeout: The number of seconds before the connection to the server will timeout. By default there is no timeout.
        tcp_keepalive: Is `TCP keepalive <https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive>`_ used. The default value is ``True``.
        application_name: Sets the application name. The default value is None.
        preferred_role: The IAM role preferred for the current connection.
        principal_arn: The ARN of the IAM entity (user or role) for which you are generating a policy.
        credentials_provider: The class name of the IdP that will be used for authenticating with the Amazon Redshift cluster.
        region: The AWS region where the Amazon Redshift cluster is located.
        cluster_identifier: The cluster identifier of the Amazon Redshift cluster.
        iam: If IAM authentication is enabled. Default value is False. IAM must be True when authenticating using an IdP.
        is_serverless: Redshift end-point is serverless or provisional. Default value false.
        serverless_acct_id: The account ID of the serverless. Default value None
        serverless_work_group: The name of work group for serverless end point. Default value None.
        pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
        enable_merge: Whether to use the Redshift merge operation instead of the SQLMesh logical merge.
    """

    user: t.Optional[str] = None
    password: t.Optional[str] = None
    database: t.Optional[str] = None
    host: t.Optional[str] = None
    port: t.Optional[int] = None
    source_address: t.Optional[str] = None
    unix_sock: t.Optional[str] = None
    ssl: t.Optional[bool] = None
    sslmode: t.Optional[str] = None
    timeout: t.Optional[int] = None
    tcp_keepalive: t.Optional[bool] = None
    application_name: t.Optional[str] = None
    preferred_role: t.Optional[str] = None
    principal_arn: t.Optional[str] = None
    credentials_provider: t.Optional[str] = None
    region: t.Optional[str] = None
    cluster_identifier: t.Optional[str] = None
    iam: t.Optional[bool] = None
    is_serverless: t.Optional[bool] = None
    serverless_acct_id: t.Optional[str] = None
    serverless_work_group: t.Optional[str] = None
    enable_merge: t.Optional[bool] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = False

    type_: t.Literal["redshift"] = Field(alias="type", default="redshift")
    DIALECT: t.ClassVar[t.Literal["redshift"]] = "redshift"
    DISPLAY_NAME: t.ClassVar[t.Literal["Redshift"]] = "Redshift"
    DISPLAY_ORDER: t.ClassVar[t.Literal[7]] = 7

    _engine_import_validator = _get_engine_import_validator("redshift_connector", "redshift")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Names of the fields that make up the connection kwargs (`enable_merge` is not among them)."""
        endpoint_keys = {"database", "host", "port", "source_address", "unix_sock"}
        transport_keys = {"ssl", "sslmode", "timeout", "tcp_keepalive", "application_name"}
        auth_keys = {
            "user",
            "password",
            "preferred_role",
            "principal_arn",
            "credentials_provider",
            "iam",
        }
        cluster_keys = {
            "region",
            "cluster_identifier",
            "is_serverless",
            "serverless_acct_id",
            "serverless_work_group",
        }
        return endpoint_keys | transport_keys | auth_keys | cluster_keys

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """Return the engine adapter class paired with this config: `RedshiftEngineAdapter`."""
        return engine_adapter.RedshiftEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """Return `redshift_connector.connect`; the driver is imported only when this is read."""
        import redshift_connector

        return redshift_connector.connect

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        """Expose `enable_merge` as the only extra engine setting."""
        return dict(enable_merge=self.enable_merge)
1395
1396
class PostgresConnectionConfig(ConnectionConfig):
    """
    Postgres Connection Configuration.

    Connections are created with `psycopg2.connect`. The optional `role` is not a connection
    kwarg; it is applied through a cursor initializer that issues `SET ROLE`.
    """

    host: str
    user: str
    password: str
    port: int
    database: str
    keepalives_idle: t.Optional[int] = None
    connect_timeout: int = 10
    role: t.Optional[str] = None
    sslmode: t.Optional[str] = None
    application_name: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = True

    type_: t.Literal["postgres"] = Field(alias="type", default="postgres")
    DIALECT: t.ClassVar[t.Literal["postgres"]] = "postgres"
    DISPLAY_NAME: t.ClassVar[t.Literal["Postgres"]] = "Postgres"
    DISPLAY_ORDER: t.ClassVar[t.Literal[12]] = 12

    _engine_import_validator = _get_engine_import_validator("psycopg2", "postgres")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Names of the fields that make up the connection kwargs (`role` is handled separately)."""
        required_keys = {"host", "user", "password", "port", "database"}
        tuning_keys = {"keepalives_idle", "connect_timeout", "sslmode", "application_name"}
        return required_keys | tuning_keys

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """Return the engine adapter class paired with this config: `PostgresEngineAdapter`."""
        return engine_adapter.PostgresEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """Return `psycopg2.connect`; the driver is imported only when this is read."""
        import psycopg2

        return psycopg2.connect

    @property
    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
        """Return a callback that runs `SET ROLE <role>` on a cursor, or None when no role is set."""

        def _set_role(cursor: t.Any) -> None:
            # The role is read from the config at the moment the callback runs.
            cursor.execute(f"SET ROLE {self.role}")

        return _set_role if self.role else None
1453
1454
class MySQLConnectionConfig(ConnectionConfig):
    """
    MySQL Connection Configuration.

    Connections are created with ``pymysql.connect``. Optional fields are only
    listed as connection kwargs when they have been given a value.
    """

    host: str
    user: str
    password: str
    port: t.Optional[int] = None
    database: t.Optional[str] = None
    charset: t.Optional[str] = None
    collation: t.Optional[str] = None
    ssl_disabled: t.Optional[bool] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = True

    type_: t.Literal["mysql"] = Field(alias="type", default="mysql")
    DIALECT: t.ClassVar[t.Literal["mysql"]] = "mysql"
    DISPLAY_NAME: t.ClassVar[t.Literal["MySQL"]] = "MySQL"
    DISPLAY_ORDER: t.ClassVar[t.Literal[14]] = 14

    _engine_import_validator = _get_engine_import_validator("pymysql", "mysql")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Return the field names this config designates as connection kwargs.

        `host`, `user` and `password` are always present; each optional field
        is included only when its value is not None.
        """
        always_present = {"host", "user", "password"}
        optional_fields = ("port", "database", "charset", "collation", "ssl_disabled")
        provided = {name for name in optional_fields if getattr(self, name) is not None}
        return always_present | provided

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """Return the engine adapter class used with this connection config."""
        adapter_cls: t.Type[EngineAdapter] = engine_adapter.MySQLEngineAdapter
        return adapter_cls

    @property
    def _connection_factory(self) -> t.Callable:
        """Return ``pymysql.connect``; the driver is imported on access."""
        import pymysql

        return pymysql.connect
1504
1505
class MSSQLConnectionConfig(ConnectionConfig):
    """
    MSSQL Connection Configuration.

    Two drivers are supported, selected with `driver`: ``pymssql`` (the default)
    and ``pyodbc``. With ``pyodbc`` the connection is opened from an ODBC
    connection string assembled in `_connection_factory`.
    """

    host: str
    user: t.Optional[str] = None
    password: t.Optional[str] = None
    database: t.Optional[str] = ""
    timeout: t.Optional[int] = 0
    login_timeout: t.Optional[int] = 60
    charset: t.Optional[str] = "UTF-8"
    appname: t.Optional[str] = None
    port: t.Optional[int] = 1433
    conn_properties: t.Optional[t.Union[t.List[str], str]] = None
    autocommit: t.Optional[bool] = False
    tds_version: t.Optional[str] = None

    # Driver options
    driver: t.Literal["pymssql", "pyodbc"] = "pymssql"
    # PyODBC specific options
    driver_name: t.Optional[str] = None  # e.g. "ODBC Driver 18 for SQL Server"
    trust_server_certificate: t.Optional[bool] = None
    encrypt: t.Optional[bool] = None
    # Dictionary of arbitrary ODBC connection properties
    # See: https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute
    odbc_properties: t.Optional[t.Dict[str, t.Any]] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = True

    type_: t.Literal["mssql"] = Field(alias="type", default="mssql")
    DIALECT: t.ClassVar[t.Literal["tsql"]] = "tsql"
    DISPLAY_NAME: t.ClassVar[t.Literal["MSSQL"]] = "MSSQL"
    DISPLAY_ORDER: t.ClassVar[t.Literal[11]] = 11

    @model_validator(mode="before")
    @classmethod
    def _mssql_engine_import_validator(cls, data: t.Any) -> t.Any:
        """Run the engine import check for the driver named in the raw config.

        Non-dict input is returned unchanged. The driver defaults to "pymssql"
        when the input does not name one.

        Raises:
            ValueError: If the configured driver is neither "pymssql" nor "pyodbc".
        """
        if not isinstance(data, dict):
            return data

        driver = data.get("driver", "pymssql")

        # Define the mapping of driver to import module and extra name
        driver_configs = {"pymssql": ("pymssql", "mssql"), "pyodbc": ("pyodbc", "mssql-odbc")}

        if driver not in driver_configs:
            raise ValueError(f"Unsupported driver: {driver}")

        import_module, extra_name = driver_configs[driver]

        # Use _get_engine_import_validator with decorate=False to get the raw validation function
        # This avoids the __wrapped__ issue in Python 3.9
        validator_func = _get_engine_import_validator(
            import_module, driver, extra_name, decorate=False
        )

        # Call the raw validation function directly
        return validator_func(cls, data)

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Return the field names this config designates as connection kwargs.

        For the ``pyodbc`` driver the PyODBC-specific fields are added and the
        pymssql-only ones (`tds_version`, `conn_properties`) are removed.
        """
        base_keys = {
            "host",
            "user",
            "password",
            "database",
            "timeout",
            "login_timeout",
            "charset",
            "appname",
            "port",
            "conn_properties",
            "autocommit",
            "tds_version",
        }

        if self.driver == "pyodbc":
            base_keys.update(
                {
                    "driver_name",
                    "trust_server_certificate",
                    "encrypt",
                    "odbc_properties",
                }
            )
            # Remove pymssql-specific parameters
            base_keys.discard("tds_version")
            base_keys.discard("conn_properties")

        return base_keys

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """Return the engine adapter class used with this connection config."""
        return engine_adapter.MSSQLEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """Return the callable that opens a connection for the configured driver.

        For ``pymssql`` this is ``pymssql.connect``. For ``pyodbc`` it is a local
        ``connect(**kwargs)`` wrapper that builds an ODBC connection string from
        the keyword arguments and `odbc_properties`, opens the connection and
        registers an output converter for DATETIMEOFFSET values.
        """
        if self.driver == "pymssql":
            import pymssql

            return pymssql.connect

        import pyodbc

        def quote_odbc_value(value: t.Any) -> str:
            """Brace-quote a connection string value that would otherwise break the string.

            Attributes are joined with ';', so a value containing ';' (or one that
            opens with an unmatched '{') has to be wrapped in braces, with any '}'
            inside doubled. Values that need no quoting are returned unchanged, and
            a value that is already wrapped in braces is passed through as-is so
            that manually quoted configuration keeps working.
            """
            text = str(value)
            if text.startswith("{") and text.endswith("}"):
                return text
            if ";" in text or text.startswith("{"):
                return "{" + text.replace("}", "}}") + "}"
            return text

        def connect(**kwargs: t.Any) -> t.Any:
            """Open a pyodbc connection from connection keyword arguments."""
            # Extract parameters for connection string
            host = kwargs.pop("host")
            port = kwargs.pop("port", 1433)
            database = kwargs.pop("database", "")
            user = kwargs.pop("user", None)
            password = kwargs.pop("password", None)
            driver_name = kwargs.pop("driver_name", "ODBC Driver 18 for SQL Server")
            trust_server_certificate = kwargs.pop("trust_server_certificate", False)
            encrypt = kwargs.pop("encrypt", True)
            login_timeout = kwargs.pop("login_timeout", 60)

            # Build connection string
            conn_str_parts = [
                f"DRIVER={{{driver_name}}}",
                f"SERVER={host},{port}",
            ]

            if database:
                conn_str_parts.append(f"DATABASE={quote_odbc_value(database)}")

            # Add security options
            conn_str_parts.append(f"Encrypt={'YES' if encrypt else 'NO'}")
            if trust_server_certificate:
                conn_str_parts.append("TrustServerCertificate=YES")

            conn_str_parts.append(f"Connection Timeout={login_timeout}")

            # Standard SQL Server authentication
            if user:
                conn_str_parts.append(f"UID={quote_odbc_value(user)}")
            if password:
                conn_str_parts.append(f"PWD={quote_odbc_value(password)}")

            # Add any additional ODBC properties from the odbc_properties dictionary
            if self.odbc_properties:
                for key, value in self.odbc_properties.items():
                    # Skip properties that we've already set above
                    if key.lower() in (
                        "driver",
                        "server",
                        "database",
                        "uid",
                        "pwd",
                        "encrypt",
                        "trustservercertificate",
                        "connection timeout",
                    ):
                        continue

                    # Handle boolean values properly
                    if isinstance(value, bool):
                        conn_str_parts.append(f"{key}={'YES' if value else 'NO'}")
                    else:
                        conn_str_parts.append(f"{key}={value}")

            # Create the connection string
            conn_str = ";".join(conn_str_parts)

            conn = pyodbc.connect(conn_str, autocommit=kwargs.get("autocommit", False))

            # Set up output converters for MSSQL-specific data types
            # Handle SQL type -155 (DATETIMEOFFSET) which is not yet supported by pyodbc
            # ref: https://github.com/mkleehammer/pyodbc/issues/134#issuecomment-281739794
            def handle_datetimeoffset(dto_value: t.Any) -> t.Any:
                from datetime import datetime, timedelta, timezone
                import struct

                # Unpack the DATETIMEOFFSET binary format:
                # Format: <6hI2h = (year, month, day, hour, minute, second, nanoseconds, tz_hour_offset, tz_minute_offset)
                tup = struct.unpack("<6hI2h", dto_value)
                return datetime(
                    tup[0],
                    tup[1],
                    tup[2],
                    tup[3],
                    tup[4],
                    tup[5],
                    # datetime takes microseconds; the unpacked field is in nanoseconds
                    tup[6] // 1000,
                    timezone(timedelta(hours=tup[7], minutes=tup[8])),
                )

            conn.add_output_converter(-155, handle_datetimeoffset)

            return conn

        return connect

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        """Return the extra engine settings: catalog support is REQUIRES_SET_CATALOG."""
        return {"catalog_support": CatalogSupport.REQUIRES_SET_CATALOG}
1700
1701
class AzureSQLConnectionConfig(MSSQLConnectionConfig):
    """
    Azure SQL Connection Configuration.

    Reuses the MSSQL configuration with the type set to 'azuresql' and catalog
    support reported as SINGLE_CATALOG_ONLY.
    """

    type_: t.Literal["azuresql"] = Field(alias="type", default="azuresql")  # type: ignore
    DISPLAY_NAME: t.ClassVar[t.Literal["Azure SQL"]] = "Azure SQL"  # type: ignore
    DISPLAY_ORDER: t.ClassVar[t.Literal[10]] = 10  # type: ignore

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        """Return the extra engine settings, overriding the MSSQL catalog support value."""
        extra_config: t.Dict[str, t.Any] = dict(
            catalog_support=CatalogSupport.SINGLE_CATALOG_ONLY
        )
        return extra_config
1710
1711
class FabricConnectionConfig(MSSQLConnectionConfig):
    """
    Fabric Connection Configuration.
    Inherits most settings from MSSQLConnectionConfig and sets the type to 'fabric'.
    It is recommended to use the 'pyodbc' driver for Fabric.
    """

    type_: t.Literal["fabric"] = Field(alias="type", default="fabric")  # type: ignore
    DIALECT: t.ClassVar[t.Literal["fabric"]] = "fabric"  # type: ignore
    DISPLAY_NAME: t.ClassVar[t.Literal["Fabric"]] = "Fabric"  # type: ignore
    DISPLAY_ORDER: t.ClassVar[t.Literal[17]] = 17  # type: ignore
    # Only the pyodbc driver is accepted here, and autocommit defaults to True.
    driver: t.Literal["pyodbc"] = "pyodbc"
    workspace_id: str
    tenant_id: str
    autocommit: t.Optional[bool] = True

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """Return the Fabric engine adapter class, imported on access."""
        from sqlmesh.core.engine_adapter import fabric

        return fabric.FabricEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """Return a factory that wraps the MSSQL one and pins the target database.

        The wrapper takes an optional `target_catalog` as its first argument and
        forwards it as the `database` kwarg, using this config's `database` when
        no catalog is given.
        """
        # Override to support catalog switching for Fabric
        mssql_factory = super()._connection_factory

        def create_fabric_connection(
            target_catalog: t.Optional[str] = None, *args: t.Any, **kwargs: t.Any
        ) -> t.Callable:
            database = target_catalog if target_catalog else self.database
            # Any `database` already present in kwargs is replaced.
            return mssql_factory(*args, **{**kwargs, "database": database})

        return create_fabric_connection

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        """Return the extra engine settings built from this config's fields."""
        extra_config: t.Dict[str, t.Any] = dict(
            database=self.database,
            # more operations than not require a specific catalog to be already active
            # in particular, create/drop view, create/drop schema and querying information_schema
            catalog_support=CatalogSupport.REQUIRES_SET_CATALOG,
            workspace_id=self.workspace_id,
            tenant_id=self.tenant_id,
            user=self.user,
            password=self.password,
        )
        return extra_config
1759
1760
class SparkConnectionConfig(ConnectionConfig):
    """
    Vanilla Spark Connection Configuration. Use `DatabricksConnectionConfig` for Databricks.
    """

    config_dir: t.Optional[str] = None
    catalog: t.Optional[str] = None
    config: t.Dict[str, t.Any] = {}
    wap_enabled: bool = False

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: t.Literal[False] = False

    type_: t.Literal["spark"] = Field(alias="type", default="spark")
    DIALECT: t.ClassVar[t.Literal["spark"]] = "spark"
    DISPLAY_NAME: t.ClassVar[t.Literal["Spark"]] = "Spark"
    DISPLAY_ORDER: t.ClassVar[t.Literal[8]] = 8

    _engine_import_validator = _get_engine_import_validator("pyspark", "spark")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Return the field names this config designates as connection kwargs (only `catalog`)."""
        return set(("catalog",))

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """Return the engine adapter class used with this connection config."""
        adapter_cls: t.Type[EngineAdapter] = engine_adapter.SparkEngineAdapter
        return adapter_cls

    @property
    def _connection_factory(self) -> t.Callable:
        """Return the SQLMesh Spark DB-API ``connection`` callable, imported on access."""
        from sqlmesh.engines.spark.db_api import spark_session

        return spark_session.connection

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """Build the ``spark`` kwarg: a Hive-enabled SparkSession configured from `config`.

        Side effect: when `config_dir` is set, the ``SPARK_CONF_DIR`` environment
        variable is overwritten with it before the session is obtained.
        """
        from pyspark.conf import SparkConf
        from pyspark.sql import SparkSession

        conf = SparkConf()
        for option, option_value in (self.config or {}).items():
            conf.set(option, option_value)

        if self.config_dir:
            os.environ["SPARK_CONF_DIR"] = self.config_dir

        session_builder = SparkSession.builder.config(conf=conf).enableHiveSupport()
        return {"spark": session_builder.getOrCreate()}

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        """Return the extra engine settings: this config's ``wap_enabled`` value."""
        extra_config: t.Dict[str, t.Any] = dict(wap_enabled=self.wap_enabled)
        return extra_config
1819
1820
class TrinoAuthenticationMethod(str, Enum):
    """Authentication methods for a Trino connection.

    Members are also plain strings, so they compare equal to their value
    (e.g. ``TrinoAuthenticationMethod.BASIC == "basic"``). Each ``is_*``
    property is True for exactly one member.
    """

    NO_AUTH = "no-auth"
    BASIC = "basic"
    LDAP = "ldap"
    KERBEROS = "kerberos"
    JWT = "jwt"
    CERTIFICATE = "certificate"
    OAUTH = "oauth"

    @property
    def is_no_auth(self) -> bool:
        """True when this member is NO_AUTH."""
        return self is TrinoAuthenticationMethod.NO_AUTH

    @property
    def is_basic(self) -> bool:
        """True when this member is BASIC."""
        return self is TrinoAuthenticationMethod.BASIC

    @property
    def is_ldap(self) -> bool:
        """True when this member is LDAP."""
        return self is TrinoAuthenticationMethod.LDAP

    @property
    def is_kerberos(self) -> bool:
        """True when this member is KERBEROS."""
        return self is TrinoAuthenticationMethod.KERBEROS

    @property
    def is_jwt(self) -> bool:
        """True when this member is JWT."""
        return self is TrinoAuthenticationMethod.JWT

    @property
    def is_certificate(self) -> bool:
        """True when this member is CERTIFICATE."""
        return self is TrinoAuthenticationMethod.CERTIFICATE

    @property
    def is_oauth(self) -> bool:
        """True when this member is OAUTH."""
        return self is TrinoAuthenticationMethod.OAUTH
1857
1858
class TrinoConnectionConfig(ConnectionConfig):
    """
    Trino Connection Configuration.

    `method` selects the authentication mechanism; `_root_validator` checks that
    the fields required by the selected method are present and fills in a
    default `port` based on `http_scheme`.
    """

    method: TrinoAuthenticationMethod = TrinoAuthenticationMethod.NO_AUTH
    host: str
    user: str
    catalog: str
    port: t.Optional[int] = None
    http_scheme: t.Literal["http", "https"] = "https"
    # General Optional
    roles: t.Optional[t.Dict[str, str]] = None
    http_headers: t.Optional[t.Dict[str, str]] = None
    session_properties: t.Optional[t.Dict[str, str]] = None
    retries: int = 3
    timezone: t.Optional[str] = None
    # Basic/LDAP
    password: t.Optional[str] = None
    verify: t.Optional[bool] = None  # disable SSL verification (ignored if `cert` is provided)
    # LDAP
    impersonation_user: t.Optional[str] = None
    # Kerberos
    keytab: t.Optional[str] = None
    krb5_config: t.Optional[str] = None
    principal: t.Optional[str] = None
    service_name: str = "trino"
    hostname_override: t.Optional[str] = None
    mutual_authentication: bool = False
    force_preemptive: bool = False
    sanitize_mutual_error_response: bool = True
    delegate: bool = False
    # JWT
    jwt_token: t.Optional[str] = None
    # Certificate
    client_certificate: t.Optional[str] = None
    client_private_key: t.Optional[str] = None
    cert: t.Optional[str] = None
    source: str = "sqlmesh"

    # SQLMesh options
    schema_location_mapping: t.Optional[dict[re.Pattern, str]] = None
    timestamp_mapping: t.Optional[dict[exp.DataType, exp.DataType]] = None
    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: t.Literal[False] = False

    type_: t.Literal["trino"] = Field(alias="type", default="trino")
    DIALECT: t.ClassVar[t.Literal["trino"]] = "trino"
    DISPLAY_NAME: t.ClassVar[t.Literal["Trino"]] = "Trino"
    DISPLAY_ORDER: t.ClassVar[t.Literal[9]] = 9

    _engine_import_validator = _get_engine_import_validator("trino", "trino")

    @field_validator("schema_location_mapping", mode="before")
    @classmethod
    def _validate_regex_keys(
        cls, value: t.Optional[t.Dict[str | re.Pattern, str]]
    ) -> t.Optional[t.Dict[re.Pattern, t.Any]]:
        """Compile the mapping keys and check every value has the schema placeholder.

        An explicit None is returned unchanged, matching `_validate_timestamp_mapping`.

        Raises:
            ConfigError: If any replacement value lacks the '@{schema_name}' placeholder.
        """
        # The field is optional, so an explicitly supplied null must not reach the
        # mapping compiler, which is given a dict everywhere else.
        if value is None:
            return value

        compiled = compile_regex_mapping(value)
        for replacement in compiled.values():
            if "@{schema_name}" not in replacement:
                raise ConfigError(
                    "schema_location_mapping needs to include the '@{schema_name}' placeholder in the value so SQLMesh knows where to substitute the schema name"
                )
        return compiled

    @field_validator("timestamp_mapping", mode="before")
    @classmethod
    def _validate_timestamp_mapping(
        cls, value: t.Optional[dict[str, str]]
    ) -> t.Optional[dict[exp.DataType, exp.DataType]]:
        """Parse both sides of every mapping entry into `exp.DataType` objects.

        None is returned unchanged.

        Raises:
            ConfigError: If a source or target type string fails to parse.
        """
        if value is None:
            return value

        def build_datatype(type_string: str) -> exp.DataType:
            # One place for the parse + error translation used by keys and values.
            try:
                return exp.DataType.build(type_string)
            except ParseError as err:
                raise ConfigError(
                    f"Invalid SQL type string in timestamp_mapping: "
                    f"'{type_string}' is not a valid SQL data type."
                ) from err

        # The key expression is evaluated before the value expression, so for each
        # entry the source type is parsed (and reported) before the target type.
        return {
            build_datatype(source_type): build_datatype(target_type)
            for source_type, target_type in value.items()
        }

    @model_validator(mode="after")
    def _root_validator(self) -> Self:
        """Validate the field combination for the selected method and default the port.

        When `port` is unset it becomes 80 for "http" and 443 for "https".

        Raises:
            ConfigError: If "http" is combined with a method other than no-auth or
                basic, or if fields required by the selected method are missing.
        """
        if self.http_scheme == "http" and not self.method.is_no_auth and not self.method.is_basic:
            raise ConfigError("HTTP scheme can only be used with no-auth or basic method")

        if self.port is None:
            self.port = 80 if self.http_scheme == "http" else 443

        if (self.method.is_ldap or self.method.is_basic) and (not self.password or not self.user):
            raise ConfigError(
                f"Username and Password must be provided if using {self.method.value} authentication"
            )

        if self.method.is_kerberos and (
            not self.principal or not self.keytab or not self.krb5_config
        ):
            raise ConfigError(
                "Kerberos requires the following fields: principal, keytab, and krb5_config"
            )

        if self.method.is_jwt and not self.jwt_token:
            raise ConfigError("JWT requires `jwt_token` to be set")

        if self.method.is_certificate and (
            not self.cert or not self.client_certificate or not self.client_private_key
        ):
            raise ConfigError(
                "Certificate requires the following fields: cert, client_certificate, and client_private_key"
            )

        return self

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Return the field names this config designates as connection kwargs."""
        kwargs = {
            "host",
            "port",
            "catalog",
            "roles",
            "source",
            "http_scheme",
            "http_headers",
            "session_properties",
            "timezone",
        }
        return kwargs

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """Return the engine adapter class used with this connection config."""
        return engine_adapter.TrinoEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """Return ``trino.dbapi.connect``; the driver is imported on access."""
        from trino.dbapi import connect

        return connect

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """Build the auth object for `method` plus the other fixed connection kwargs.

        `auth` stays None for no-auth. `user` is `impersonation_user` when set,
        otherwise `user`; `verify` is `cert` when set, otherwise `verify`.

        Side effect: for Kerberos with a `keytab`, the ``KRB5_CLIENT_KTNAME``
        environment variable is overwritten with it.
        """
        from trino.auth import (
            BasicAuthentication,
            CertificateAuthentication,
            JWTAuthentication,
            KerberosAuthentication,
            OAuth2Authentication,
        )

        auth: t.Optional[
            t.Union[
                BasicAuthentication,
                KerberosAuthentication,
                OAuth2Authentication,
                JWTAuthentication,
                CertificateAuthentication,
            ]
        ] = None
        if self.method.is_basic or self.method.is_ldap:
            assert self.password is not None  # for mypy since validator already checks this
            auth = BasicAuthentication(self.user, self.password)
        elif self.method.is_kerberos:
            if self.keytab:
                os.environ["KRB5_CLIENT_KTNAME"] = self.keytab
            auth = KerberosAuthentication(
                config=self.krb5_config,
                service_name=self.service_name,
                principal=self.principal,
                mutual_authentication=self.mutual_authentication,
                ca_bundle=self.cert,
                force_preemptive=self.force_preemptive,
                hostname_override=self.hostname_override,
                sanitize_mutual_error_response=self.sanitize_mutual_error_response,
                delegate=self.delegate,
            )
        elif self.method.is_oauth:
            auth = OAuth2Authentication()
        elif self.method.is_jwt:
            assert self.jwt_token is not None
            auth = JWTAuthentication(self.jwt_token)
        elif self.method.is_certificate:
            assert self.client_certificate is not None
            assert self.client_private_key is not None
            auth = CertificateAuthentication(self.client_certificate, self.client_private_key)

        return {
            "auth": auth,
            "user": self.impersonation_user or self.user,
            "max_attempts": self.retries,
            "verify": self.cert if self.cert is not None else self.verify,
            "source": self.source,
        }

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        """Return the extra engine settings: the two validated mapping fields."""
        return {
            "schema_location_mapping": self.schema_location_mapping,
            "timestamp_mapping": self.timestamp_mapping,
        }
2068
2069
class ClickhouseConnectionConfig(ConnectionConfig):
    """
    Clickhouse Connection Configuration.

    Property reference: https://clickhouse.com/docs/en/integrations/python#client-initialization
    """

    host: str
    username: str
    password: t.Optional[str] = None
    port: t.Optional[int] = None
    cluster: t.Optional[str] = None
    connect_timeout: int = 10
    send_receive_timeout: int = 300
    query_limit: int = 0
    use_compression: bool = True
    compression_method: t.Optional[str] = None
    connection_settings: t.Optional[t.Dict[str, t.Any]] = None
    http_proxy: t.Optional[str] = None
    # HTTPS/TLS settings
    verify: bool = True
    ca_cert: t.Optional[str] = None
    client_cert: t.Optional[str] = None
    client_cert_key: t.Optional[str] = None
    https_proxy: t.Optional[str] = None
    server_host_name: t.Optional[str] = None
    tls_mode: t.Optional[str] = None

    concurrent_tasks: int = 1
    register_comments: bool = True
    pre_ping: bool = False

    # This object expects options from urllib3 and also from clickhouse-connect
    # See:
    # * https://urllib3.readthedocs.io/en/stable/advanced-usage.html
    # * https://clickhouse.com/docs/en/integrations/python#customizing-the-http-connection-pool
    connection_pool_options: t.Optional[t.Dict[str, t.Any]] = None

    type_: t.Literal["clickhouse"] = Field(alias="type", default="clickhouse")
    DIALECT: t.ClassVar[t.Literal["clickhouse"]] = "clickhouse"
    DISPLAY_NAME: t.ClassVar[t.Literal["ClickHouse"]] = "ClickHouse"
    DISPLAY_ORDER: t.ClassVar[t.Literal[6]] = 6

    _engine_import_validator = _get_engine_import_validator("clickhouse_connect", "clickhouse")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Return the field names this config designates as connection kwargs."""
        kwargs = {
            "host",
            "username",
            "port",
            "password",
            "connect_timeout",
            "send_receive_timeout",
            "query_limit",
            "http_proxy",
            "verify",
            "ca_cert",
            "client_cert",
            "client_cert_key",
            "https_proxy",
            "server_host_name",
            "tls_mode",
        }
        return kwargs

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """Return the engine adapter class used with this connection config."""
        return engine_adapter.ClickhouseEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """Return ``clickhouse_connect.dbapi.connect`` bound to a dedicated pool manager.

        The pool manager is built from the TLS fields and sized to
        `concurrent_tasks`; entries in `connection_pool_options` override the
        computed options.
        """
        from clickhouse_connect.dbapi import connect  # type: ignore
        from clickhouse_connect.driver import httputil  # type: ignore
        from functools import partial

        pool_manager_options: t.Dict[str, t.Any] = dict(
            # Match the maxsize to the number of concurrent tasks
            maxsize=self.concurrent_tasks,
            # Block if there are no free connections
            block=True,
            verify=self.verify,
            ca_cert=self.ca_cert,
            client_cert=self.client_cert,
            client_cert_key=self.client_cert_key,
            https_proxy=self.https_proxy,
        )
        # this doesn't happen automatically because we always supply our own pool manager to the connection
        # https://github.com/ClickHouse/clickhouse-connect/blob/3a7f4b04cad29c7c2536661b831fb744248e2ec0/clickhouse_connect/driver/httpclient.py#L109
        if self.server_host_name:
            pool_manager_options["server_hostname"] = self.server_host_name
            if self.verify:
                pool_manager_options["assert_hostname"] = self.server_host_name
        if self.connection_pool_options:
            pool_manager_options.update(self.connection_pool_options)
        pool_mgr = httputil.get_pool_manager(**pool_manager_options)

        return partial(connect, pool_mgr=pool_mgr)

    @property
    def cloud_mode(self) -> bool:
        """True when `host` contains the substring "clickhouse.cloud"."""
        return "clickhouse.cloud" in self.host

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        """Return the extra engine settings: `cluster` and the derived `cloud_mode`."""
        return {"cluster": self.cluster, "cloud_mode": self.cloud_mode}

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """Build the fixed connection kwargs: compression, client name and settings.

        The returned settings start from `connection_settings` and are then
        overridden with the synchronous-execution values set below. The
        `connection_settings` dict on this config is left unmodified.
        """
        from sqlmesh import __version__

        # False = no compression
        # True = Clickhouse default compression method
        # string = specific compression method
        compress: bool | str = self.use_compression
        if compress and self.compression_method:
            compress = self.compression_method

        # Clickhouse system settings passed to connection
        # https://clickhouse.com/docs/en/operations/settings/settings
        # - below are set to align with dbt-clickhouse
        # - https://github.com/ClickHouse/dbt-clickhouse/blob/44d26308ea6a3c8ead25c280164aa88191f05f47/dbt/adapters/clickhouse/dbclient.py#L77
        # Work on a copy: writing into `self.connection_settings` directly would
        # alter the user-supplied configuration every time this property is read.
        settings: t.Dict[str, t.Any] = dict(self.connection_settings or {})
        #  mutations_sync = 2: "The query waits for all mutations [ALTER statements] to complete on all replicas (if they exist)"
        settings["mutations_sync"] = "2"
        #  insert_distributed_sync = 1: "INSERT operation succeeds only after all the data is saved on all shards"
        settings["insert_distributed_sync"] = "1"
        if self.cluster or self.cloud_mode:
            # database_replicated_enforce_synchronous_settings = 1:
            #   - "Enforces synchronous waiting for some queries"
            #   - https://github.com/ClickHouse/ClickHouse/blob/ccaa8d03a9351efc16625340268b9caffa8a22ba/src/Core/Settings.h#L709
            settings["database_replicated_enforce_synchronous_settings"] = "1"
            # insert_quorum = auto:
            #   - "INSERT succeeds only when ClickHouse manages to correctly write data to the insert_quorum of replicas during
            #       the insert_quorum_timeout"
            #   - "use majority number (number_of_replicas / 2 + 1) as quorum number"
            settings["insert_quorum"] = "auto"

        return {
            "compress": compress,
            "client_name": f"SQLMesh/{__version__}",
            **settings,
        }
2213
2214
class AthenaConnectionConfig(ConnectionConfig):
    """Connection configuration for Athena, connecting through `pyathena`.

    At least one of `work_group` or `s3_staging_dir` must be set. When provided,
    `s3_staging_dir` and `s3_warehouse_location` are passed through `validate_s3_uri`
    and replaced with the value it returns.
    """

    # PyAthena connection options
    aws_access_key_id: t.Optional[str] = None
    aws_secret_access_key: t.Optional[str] = None
    role_arn: t.Optional[str] = None
    role_session_name: t.Optional[str] = None
    region_name: t.Optional[str] = None
    work_group: t.Optional[str] = None
    s3_staging_dir: t.Optional[str] = None
    schema_name: t.Optional[str] = None
    catalog_name: t.Optional[str] = None

    # SQLMesh options
    s3_warehouse_location: t.Optional[str] = None
    concurrent_tasks: int = 4
    # Pinned to False because Athena doesn't support comments in most cases
    register_comments: t.Literal[False] = False
    pre_ping: t.Literal[False] = False

    type_: t.Literal["athena"] = Field(alias="type", default="athena")
    DIALECT: t.ClassVar[t.Literal["athena"]] = "athena"
    DISPLAY_NAME: t.ClassVar[t.Literal["Athena"]] = "Athena"
    DISPLAY_ORDER: t.ClassVar[t.Literal[15]] = 15

    _engine_import_validator = _get_engine_import_validator("pyathena", "athena")

    @model_validator(mode="after")
    def _root_validator(self) -> Self:
        """Check the required options and validate the configured S3 URIs.

        Raises:
            ConfigError: If neither `work_group` nor `s3_staging_dir` is set.
        """
        staging_dir = self.s3_staging_dir
        warehouse_location = self.s3_warehouse_location

        if not (self.work_group or staging_dir):
            raise ConfigError("At least one of work_group or s3_staging_dir must be set")

        if staging_dir:
            self.s3_staging_dir = validate_s3_uri(staging_dir, base=True, error_type=ConfigError)

        if warehouse_location:
            self.s3_warehouse_location = validate_s3_uri(
                warehouse_location, base=True, error_type=ConfigError
            )

        return self

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Names of the config fields that are forwarded to the connection factory."""
        credential_keys = {
            "aws_access_key_id",
            "aws_secret_access_key",
            "role_arn",
            "role_session_name",
        }
        location_keys = {
            "region_name",
            "work_group",
            "s3_staging_dir",
            "schema_name",
            "catalog_name",
        }
        return credential_keys | location_keys

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """The engine adapter class used for Athena."""
        return engine_adapter.AthenaEngineAdapter

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        """Extra keyword arguments handed to the engine adapter (not to the connection)."""
        return {"s3_warehouse_location": self.s3_warehouse_location}

    @property
    def _connection_factory(self) -> t.Callable:
        """The `pyathena` connect function, imported lazily."""
        import pyathena  # type: ignore

        return pyathena.connect

    def get_catalog(self) -> t.Optional[str]:
        """Return the configured `catalog_name` (may be None)."""
        return self.catalog_name
2291
2292
class RisingwaveConnectionConfig(ConnectionConfig):
    """Connection configuration for RisingWave, connecting through `psycopg2`.

    Every cursor is initialised with `SET RW_IMPLICIT_FLUSH TO true;`.
    """

    host: str
    user: str
    password: t.Optional[str] = None
    port: int
    database: str
    role: t.Optional[str] = None
    sslmode: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = True

    type_: t.Literal["risingwave"] = Field(alias="type", default="risingwave")
    DIALECT: t.ClassVar[t.Literal["risingwave"]] = "risingwave"
    DISPLAY_NAME: t.ClassVar[t.Literal["RisingWave"]] = "RisingWave"
    DISPLAY_ORDER: t.ClassVar[t.Literal[16]] = 16

    _engine_import_validator = _get_engine_import_validator("psycopg2", "risingwave")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Names of the config fields that are forwarded to the connection factory."""
        return {"host", "user", "password", "port", "database", "role", "sslmode"}

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """The engine adapter class used for RisingWave."""
        return engine_adapter.RisingwaveEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """The `psycopg2` connect function, imported lazily."""
        import psycopg2

        return psycopg2.connect

    @property
    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
        """Return a callable that runs the implicit-flush `SET` statement on a new cursor."""

        def init(cursor: t.Any) -> None:
            cursor.execute("SET RW_IMPLICIT_FLUSH TO true;")

        return init
2342
2343
# Every subclass of ConnectionConfig found in this module, excluding the base classes
# themselves. Collected once and shared by the lookup tables below.
_CONNECTION_CONFIG_CLASSES = tuple(
    subclasses(
        __name__,
        ConnectionConfig,
        exclude={ConnectionConfig, BaseDuckDBConnectionConfig},
    )
)

# Maps the default value of each subclass' `type_` field to the subclass itself.
CONNECTION_CONFIG_TO_TYPE = {
    config_cls.all_field_infos()["type_"].default: config_cls
    for config_cls in _CONNECTION_CONFIG_CLASSES
}

# Maps the default value of each subclass' `type_` field to its `DIALECT`.
DIALECT_TO_TYPE = {
    config_cls.all_field_infos()["type_"].default: config_cls.DIALECT
    for config_cls in _CONNECTION_CONFIG_CLASSES
}

# Maps the default value of each subclass' `type_` field to `(DISPLAY_ORDER, DISPLAY_NAME)`.
INIT_DISPLAY_INFO_TO_TYPE = {
    config_cls.all_field_infos()["type_"].default: (
        config_cls.DISPLAY_ORDER,
        config_cls.DISPLAY_NAME,
    )
    for config_cls in _CONNECTION_CONFIG_CLASSES
}
2374
2375
def parse_connection_config(v: t.Dict[str, t.Any]) -> ConnectionConfig:
    """Build the ConnectionConfig subclass selected by the `type` entry of `v`.

    Args:
        v: The raw connection configuration; it is passed to the selected class as
            keyword arguments.

    Returns:
        An instance of the class registered for `v["type"]` in `CONNECTION_CONFIG_TO_TYPE`.

    Raises:
        ConfigError: If `type` is missing or is not a registered connection type.
    """
    if "type" not in v:
        raise ConfigError("Missing connection type.")

    config_cls = CONNECTION_CONFIG_TO_TYPE.get(v["type"])
    if config_cls is None:
        raise ConfigError(f"Unknown connection type '{v['type']}'.")

    return config_cls(**v)
2385
2386
def _connection_config_validator(
    cls: t.Type, v: ConnectionConfig | t.Dict[str, t.Any] | None
) -> ConnectionConfig | None:
    """Coerce a raw connection config into a ConnectionConfig instance.

    Args:
        cls: The class the validator is attached to (unused).
        v: None, an existing ConnectionConfig, or a raw configuration to parse.

    Returns:
        `v` unchanged when it is None or already a ConnectionConfig, otherwise the
        result of `parse_connection_config(v)`.

    Raises:
        ConfigError: If parsing fails. The message is suffixed with a hint to verify
            config.yaml and environment variables, and the original exception is
            chained as the cause.
    """
    if v is None or isinstance(v, ConnectionConfig):
        return v

    check_config_and_vars_msg = "\n\nVerify your config.yaml and environment variables."

    try:
        return parse_connection_config(v)
    except pydantic.ValidationError as e:
        # `v["type"]` is safe here: parsing only reaches pydantic validation after the
        # presence of "type" has been checked.
        raise ConfigError(
            validation_error_message(e, f"Invalid '{v['type']}' connection config:")
            + check_config_and_vars_msg
        ) from e
    except ConfigError as e:
        raise ConfigError(str(e) + check_config_and_vars_msg) from e
2404
2405
# Reusable "before" field validator that runs `_connection_config_validator` on the
# connection-related fields listed below.
connection_config_validator: t.Callable = field_validator(
    "connection",
    "state_connection",
    "test_connection",
    "default_connection",
    "default_test_connection",
    check_fields=False,
    mode="before",
)(_connection_config_validator)
2415
2416
if t.TYPE_CHECKING:
    # TypeAlias hasn't been introduced until Python 3.10 which means that we can't use it
    # outside the TYPE_CHECKING guard.
    SerializableConnectionConfig: t.TypeAlias = ConnectionConfig  # type: ignore
else:
    # `pydantic` is already imported at the top of the module, so no local import is needed.
    # Workaround for https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing
    SerializableConnectionConfig = pydantic.SerializeAsAny[ConnectionConfig]  # type: ignore
logger = <Logger sqlmesh.core.config.connection (WARNING)>
FORBIDDEN_STATE_SYNC_ENGINES = {'clickhouse', 'spark', 'trino'}
MOTHERDUCK_TOKEN_REGEX = re.compile('(\\?|\\&)(motherduck_token=)(\\S*)')
PASSWORD_REGEX = re.compile('(password=)(\\S+)')
class ConnectionConfig(abc.ABC, sqlmesh.core.config.base.BaseConfig):
 97class ConnectionConfig(abc.ABC, BaseConfig):
 98    type_: str
 99    DIALECT: t.ClassVar[str]
100    DISPLAY_NAME: t.ClassVar[str]
101    DISPLAY_ORDER: t.ClassVar[int]
102    concurrent_tasks: int
103    register_comments: bool
104    pre_ping: bool
105    pretty_sql: bool = False
106    schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None
107    catalog_type_overrides: t.Optional[t.Dict[str, str]] = None
108
109    # Whether to share a  single connection across threads or create a new connection per thread.
110    shared_connection: t.ClassVar[bool] = False
111
112    @property
113    @abc.abstractmethod
114    def _connection_kwargs_keys(self) -> t.Set[str]:
115        """keywords that should be passed into the connection"""
116
117    @property
118    @abc.abstractmethod
119    def _engine_adapter(self) -> t.Type[EngineAdapter]:
120        """The engine adapter for this connection"""
121
122    @property
123    @abc.abstractmethod
124    def _connection_factory(self) -> t.Callable:
125        """A function that is called to return a connection object for the given Engine Adapter"""
126
127    @property
128    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
129        """The static connection kwargs for this connection"""
130        return {}
131
132    @property
133    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
134        """kwargs that are for execution config only"""
135        return {}
136
137    @property
138    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
139        """A function that is called to initialize the cursor"""
140        return None
141
142    @property
143    def is_recommended_for_state_sync(self) -> bool:
144        """Whether this engine is recommended for being used as a state sync for production state syncs"""
145        return self.type_ in RECOMMENDED_STATE_SYNC_ENGINES
146
147    @property
148    def is_forbidden_for_state_sync(self) -> bool:
149        """Whether this engine is forbidden from being used as a state sync"""
150        return self.type_ in FORBIDDEN_STATE_SYNC_ENGINES
151
152    @property
153    def _connection_factory_with_kwargs(self) -> t.Callable[[], t.Any]:
154        """A function that is called to return a connection object for the given Engine Adapter"""
155        return partial(
156            self._connection_factory,
157            **{
158                **self._static_connection_kwargs,
159                **{k: v for k, v in self.dict().items() if k in self._connection_kwargs_keys},
160            },
161        )
162
163    def connection_validator(self) -> t.Callable[[], None]:
164        """A function that validates the connection configuration"""
165        return self.create_engine_adapter().ping
166
167    def create_engine_adapter(
168        self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None
169    ) -> EngineAdapter:
170        """Returns a new instance of the Engine Adapter."""
171
172        concurrent_tasks = concurrent_tasks or self.concurrent_tasks
173        return self._engine_adapter(
174            self._connection_factory_with_kwargs,
175            multithreaded=concurrent_tasks > 1,
176            default_catalog=self.get_catalog(),
177            cursor_init=self._cursor_init,
178            register_comments=register_comments_override or self.register_comments,
179            pre_ping=self.pre_ping,
180            pretty_sql=self.pretty_sql,
181            shared_connection=self.shared_connection,
182            schema_differ_overrides=self.schema_differ_overrides,
183            catalog_type_overrides=self.catalog_type_overrides,
184            **self._extra_engine_config,
185        )
186
187    def get_catalog(self) -> t.Optional[str]:
188        """The catalog for this connection"""
189        if hasattr(self, "catalog"):
190            return self.catalog
191        if hasattr(self, "database"):
192            return self.database
193        if hasattr(self, "db"):
194            return self.db
195        return None
196
197    @model_validator(mode="before")
198    @classmethod
199    def _expand_json_strings_to_concrete_types(cls, data: t.Any) -> t.Any:
200        """
201        There are situations where a connection config class has a field that is some kind of complex type
202        (eg a list of strings or a dict) but the value is being supplied from a source such as an environment variable
203
204        When this happens, the value is supplied as a string rather than a Python object. We need some way
205        of turning this string into the corresponding Python list or dict.
206
207        Rather than doing this piecemeal on every config subclass, this provides a generic implementatation
208        to identify fields that may be be supplied as JSON strings and handle them transparently
209        """
210        if data and isinstance(data, dict):
211            for maybe_json_field_name in cls._get_list_and_dict_field_names():
212                if (value := data.get(maybe_json_field_name)) and isinstance(value, str):
213                    # crude JSON check as we dont want to try and parse every string we get
214                    value = value.strip()
215                    if value.startswith("{") or value.startswith("["):
216                        data[maybe_json_field_name] = from_json(value)
217
218        return data
219
220    @classmethod
221    def _get_list_and_dict_field_names(cls) -> t.Set[str]:
222        field_names = set()
223        for name, field in cls.model_fields.items():
224            if field.annotation:
225                field_types = get_concrete_types_from_typehint(field.annotation)
226
227                # check if the field type is something that could concievably be supplied as a json string
228                if any(ft is t for t in (list, tuple, set, dict) for ft in field_types):
229                    field_names.add(name)
230
231        return field_names

Helper class that provides a standard way to create an ABC using inheritance.

type_: str
DIALECT: ClassVar[str]
DISPLAY_NAME: ClassVar[str]
DISPLAY_ORDER: ClassVar[int]
concurrent_tasks: int
register_comments: bool
pre_ping: bool
pretty_sql: bool
schema_differ_overrides: Optional[Dict[str, Any]]
catalog_type_overrides: Optional[Dict[str, str]]
shared_connection: ClassVar[bool] = False
is_forbidden_for_state_sync: bool
147    @property
148    def is_forbidden_for_state_sync(self) -> bool:
149        """Whether this engine is forbidden from being used as a state sync"""
150        return self.type_ in FORBIDDEN_STATE_SYNC_ENGINES

Whether this engine is forbidden from being used as a state sync

def connection_validator(self) -> Callable[[], NoneType]:
163    def connection_validator(self) -> t.Callable[[], None]:
164        """A function that validates the connection configuration"""
165        return self.create_engine_adapter().ping

A function that validates the connection configuration

def create_engine_adapter( self, register_comments_override: bool = False, concurrent_tasks: Optional[int] = None) -> sqlmesh.core.engine_adapter.base.EngineAdapter:
167    def create_engine_adapter(
168        self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None
169    ) -> EngineAdapter:
170        """Returns a new instance of the Engine Adapter."""
171
172        concurrent_tasks = concurrent_tasks or self.concurrent_tasks
173        return self._engine_adapter(
174            self._connection_factory_with_kwargs,
175            multithreaded=concurrent_tasks > 1,
176            default_catalog=self.get_catalog(),
177            cursor_init=self._cursor_init,
178            register_comments=register_comments_override or self.register_comments,
179            pre_ping=self.pre_ping,
180            pretty_sql=self.pretty_sql,
181            shared_connection=self.shared_connection,
182            schema_differ_overrides=self.schema_differ_overrides,
183            catalog_type_overrides=self.catalog_type_overrides,
184            **self._extra_engine_config,
185        )

Returns a new instance of the Engine Adapter.

def get_catalog(self) -> Optional[str]:
187    def get_catalog(self) -> t.Optional[str]:
188        """The catalog for this connection"""
189        if hasattr(self, "catalog"):
190            return self.catalog
191        if hasattr(self, "database"):
192            return self.database
193        if hasattr(self, "db"):
194            return self.db
195        return None

The catalog for this connection

model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class DuckDBAttachOptions(sqlmesh.core.config.base.BaseConfig):
class DuckDBAttachOptions(BaseConfig):
    """Options used to render a DuckDB `ATTACH` statement for a catalog."""

    type: str
    path: str
    read_only: bool = False

    # DuckLake specific options
    data_path: t.Optional[str] = None
    encrypted: bool = False
    data_inlining_row_limit: t.Optional[int] = None
    metadata_schema: t.Optional[str] = None

    def to_sql(self, alias: str) -> str:
        """Render the `ATTACH IF NOT EXISTS` statement for these options.

        Args:
            alias: The catalog alias; it is inserted verbatim after `AS`.

        Returns:
            The statement. The `AS <alias>` part is omitted for MotherDuck (type
            `motherduck` or a path starting with `md:`), and the parenthesised option
            list is omitted when there are no options.
        """
        options = []
        # 'duckdb' is actually not a supported type, but we'd like to allow it for
        # fully qualified attach options or integration testing, similar to duckdb-dbt
        if self.type not in ("duckdb", "ducklake", "motherduck"):
            options.append(f"TYPE {self.type.upper()}")
        if self.read_only:
            options.append("READ_ONLY")

        # DuckLake specific options
        path = self.path
        if self.type == "ducklake":
            if not path.startswith("ducklake:"):
                path = f"ducklake:{path}"
            if self.data_path is not None:
                options.append(f"DATA_PATH '{self.data_path}'")
            if self.encrypted:
                options.append("ENCRYPTED")
            if self.data_inlining_row_limit is not None:
                options.append(f"DATA_INLINING_ROW_LIMIT {self.data_inlining_row_limit}")
            if self.metadata_schema is not None:
                options.append(f"METADATA_SCHEMA '{self.metadata_schema}'")

        options_sql = f" ({', '.join(options)})" if options else ""
        # TODO: Add support for Postgres schema. Currently adding it blocks access to the information_schema

        # MotherDuck does not support aliasing
        is_motherduck = self.type == "motherduck" or self.path.startswith("md:")
        alias_sql = "" if is_motherduck else f" AS {alias}"
        return f"ATTACH IF NOT EXISTS '{path}'{alias_sql}{options_sql}"

Base configuration functionality for configuration classes.

type: str
path: str
read_only: bool
data_path: Optional[str]
encrypted: bool
data_inlining_row_limit: Optional[int]
metadata_schema: Optional[str]
def to_sql(self, alias: str) -> str:
245    def to_sql(self, alias: str) -> str:
246        options = []
247        # 'duckdb' is actually not a supported type, but we'd like to allow it for
248        # fully qualified attach options or integration testing, similar to duckdb-dbt
249        if self.type not in ("duckdb", "ducklake", "motherduck"):
250            options.append(f"TYPE {self.type.upper()}")
251        if self.read_only:
252            options.append("READ_ONLY")
253
254        # DuckLake specific options
255        path = self.path
256        if self.type == "ducklake":
257            if not path.startswith("ducklake:"):
258                path = f"ducklake:{path}"
259            if self.data_path is not None:
260                options.append(f"DATA_PATH '{self.data_path}'")
261            if self.encrypted:
262                options.append("ENCRYPTED")
263            if self.data_inlining_row_limit is not None:
264                options.append(f"DATA_INLINING_ROW_LIMIT {self.data_inlining_row_limit}")
265            if self.metadata_schema is not None:
266                options.append(f"METADATA_SCHEMA '{self.metadata_schema}'")
267
268        options_sql = f" ({', '.join(options)})" if options else ""
269        alias_sql = ""
270        # TODO: Add support for Postgres schema. Currently adding it blocks access to the information_schema
271
272        # MotherDuck does not support aliasing
273        alias_sql = (
274            f" AS {alias}" if not (self.type == "motherduck" or self.path.startswith("md:")) else ""
275        )
276        return f"ATTACH IF NOT EXISTS '{path}'{alias_sql}{options_sql}"
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class BaseDuckDBConnectionConfig(ConnectionConfig):
279class BaseDuckDBConnectionConfig(ConnectionConfig):
280    """Common configuration for the DuckDB-based connections.
281
282    Args:
283        database: The optional database name. If not specified, the in-memory database will be used.
284        catalogs: Key is the name of the catalog and value is the path.
285        extensions: A list of autoloadable extensions to load.
286        connector_config: A dictionary of configuration to pass into the duckdb connector.
287        secrets: A list of dictionaries used to generate DuckDB secrets for authenticating with external services (e.g. S3).
288        filesystems: A list of dictionaries used to register `fsspec` filesystems to the DuckDB cursor.
289        concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
290        register_comments: Whether or not to register model comments with the SQL engine.
291        pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
292        token: The optional MotherDuck token. If not specified and a MotherDuck path is in the catalog, the user will be prompted to login with their web browser.
293    """
294
295    database: t.Optional[str] = None
296    catalogs: t.Optional[t.Dict[str, t.Union[str, DuckDBAttachOptions]]] = None
297    extensions: t.List[t.Union[str, t.Dict[str, t.Any]]] = []
298    connector_config: t.Dict[str, t.Any] = {}
299    secrets: t.Union[t.List[t.Dict[str, t.Any]], t.Dict[str, t.Dict[str, t.Any]]] = []
300    filesystems: t.List[t.Dict[str, t.Any]] = []
301
302    concurrent_tasks: int = 1
303    register_comments: bool = True
304    pre_ping: t.Literal[False] = False
305
306    token: t.Optional[str] = None
307
308    shared_connection: t.ClassVar[bool] = True
309
310    _data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {}
311
312    @model_validator(mode="before")
313    def _validate_database_catalogs(cls, data: t.Any) -> t.Any:
314        if not isinstance(data, dict):
315            return data
316
317        db_path = data.get("database")
318        if db_path and data.get("catalogs"):
319            raise ConfigError(
320                "Cannot specify both `database` and `catalogs`. Define all your catalogs in `catalogs` and have the first entry be the default catalog"
321            )
322        if isinstance(db_path, str) and db_path.startswith("md:"):
323            raise ConfigError(
324                "Please use connection type 'motherduck' without the `md:` prefix if you want to use a MotherDuck database as the single `database`."
325            )
326
327        return data
328
329    @property
330    def _engine_adapter(self) -> t.Type[EngineAdapter]:
331        return engine_adapter.DuckDBEngineAdapter
332
333    @property
334    def _connection_kwargs_keys(self) -> t.Set[str]:
335        return {"database"}
336
337    @property
338    def _connection_factory(self) -> t.Callable:
339        import duckdb
340
341        return duckdb.connect
342
343    @property
344    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
345        """A function that is called to initialize the cursor"""
346        import duckdb
347        from duckdb import BinderException
348
349        def init(cursor: duckdb.DuckDBPyConnection) -> None:
350            for extension in self.extensions:
351                extension = extension if isinstance(extension, dict) else {"name": extension}
352
353                install_command = f"INSTALL {extension['name']}"
354
355                if extension.get("repository"):
356                    install_command = f"{install_command} FROM {extension['repository']}"
357
358                if extension.get("force_install"):
359                    install_command = f"FORCE {install_command}"
360
361                try:
362                    cursor.execute(install_command)
363                    cursor.execute(f"LOAD {extension['name']}")
364                except Exception as e:
365                    raise ConfigError(f"Failed to load extension {extension['name']}: {e}")
366
367            if self.connector_config:
368                option_names = list(self.connector_config)
369                in_part = ",".join("?" for _ in range(len(option_names)))
370
371                cursor.execute(
372                    f"SELECT name, value FROM duckdb_settings() WHERE name IN ({in_part})",
373                    option_names,
374                )
375
376                existing_values = {field: setting for field, setting in cursor.fetchall()}
377
378                # only set connector_config items if the values differ from what is already set
379                # trying to set options like 'temp_directory' even to the same value can throw errors like:
380                # Not implemented Error: Cannot switch temporary directory after the current one has been used
381                for field, setting in self.connector_config.items():
382                    if existing_values.get(field) != setting:
383                        try:
384                            cursor.execute(f"SET {field} = '{setting}'")
385                        except Exception as e:
386                            raise ConfigError(
387                                f"Failed to set connector config {field} to {setting}: {e}"
388                            )
389
390            if self.secrets:
391                duckdb_version = duckdb.__version__
392                if version.parse(duckdb_version) < version.parse("0.10.0"):
393                    from sqlmesh.core.console import get_console
394
395                    get_console().log_warning(
396                        f"DuckDB version {duckdb_version} does not support secrets-based authentication (requires 0.10.0 or later).\n"
397                        "To use secrets, please upgrade DuckDB. For older versions, configure legacy authentication via `connector_config`.\n"
398                        "More info: https://duckdb.org/docs/stable/extensions/httpfs/s3api_legacy_authentication.html"
399                    )
400                else:
401                    if isinstance(self.secrets, list):
402                        secrets_items = [(secret_dict, "") for secret_dict in self.secrets]
403                    else:
404                        secrets_items = [
405                            (secret_dict, secret_name)
406                            for secret_name, secret_dict in self.secrets.items()
407                        ]
408
409                    for secret_dict, secret_name in secrets_items:
410                        secret_settings: t.List[str] = []
411                        for field, setting in secret_dict.items():
412                            secret_settings.append(f"{field} '{setting}'")
413                        if secret_settings:
414                            secret_clause = ", ".join(secret_settings)
415                            try:
416                                cursor.execute(
417                                    f"CREATE OR REPLACE SECRET {secret_name} ({secret_clause});"
418                                )
419                            except Exception as e:
420                                raise ConfigError(f"Failed to create secret: {e}")
421
422            if self.filesystems:
423                from fsspec import filesystem  # type: ignore
424
425                for file_system in self.filesystems:
426                    options = file_system.copy()
427                    fs = options.pop("fs")
428                    fs = filesystem(fs, **options)
429                    cursor.register_filesystem(fs)
430
431            for i, (alias, path_options) in enumerate(
432                (getattr(self, "catalogs", None) or {}).items()
433            ):
434                # we parse_identifier and generate to ensure that `alias` has exactly one set of quotes
435                # regardless of whether it comes in quoted or not
436                alias = exp.parse_identifier(alias, dialect="duckdb").sql(
437                    identify=True, dialect="duckdb"
438                )
439                try:
440                    if isinstance(path_options, DuckDBAttachOptions):
441                        query = path_options.to_sql(alias)
442                    else:
443                        query = f"ATTACH IF NOT EXISTS '{path_options}'"
444                        if not path_options.startswith("md:"):
445                            query += f" AS {alias}"
446                    cursor.execute(query)
447                except BinderException as e:
448                    # If a user tries to create a catalog pointing at `:memory:` and with the name `memory`
449                    # then we don't want to raise since this happens by default. They are just doing this to
450                    # set it as the default catalog.
451                    # If a user tried to attach a MotherDuck database/share which has already by attached via
452                    # `ATTACH 'md:'`, then we don't want to raise since this is expected.
453                    if (
454                        not (
455                            'database with name "memory" already exists' in str(e)
456                            and path_options == ":memory:"
457                        )
458                        and f"""database with name "{path_options.path.replace("md:", "")}" already exists"""
459                        not in str(e)
460                    ):
461                        raise e
462                if i == 0 and not getattr(self, "database", None):
463                    cursor.execute(f"USE {alias}")
464
465        return init
466
467    def create_engine_adapter(
468        self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None
469    ) -> EngineAdapter:
470        """Checks if another engine adapter has already been created that shares a catalog that points to the same data
471        file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration
472        associated with the new adapter will be ignored."""
473        data_files = set((self.catalogs or {}).values())
474        if self.database:
475            if isinstance(self, MotherDuckConnectionConfig):
476                data_files.add(
477                    f"md:{self.database}"
478                    + (f"?motherduck_token={self.token}" if self.token else "")
479                )
480            else:
481                data_files.add(self.database)
482        data_files.discard(":memory:")
483        for data_file in data_files:
484            key = data_file if isinstance(data_file, str) else data_file.path
485            adapter = BaseDuckDBConnectionConfig._data_file_to_adapter.get(key)
486            if adapter is not None:
487                logger.info(
488                    f"Using existing DuckDB adapter due to overlapping data file: {self._mask_sensitive_data(key)}"
489                )
490                return adapter
491
492        if data_files:
493            masked_files = {
494                self._mask_sensitive_data(file if isinstance(file, str) else file.path)
495                for file in data_files
496            }
497            logger.info(f"Creating new DuckDB adapter for data files: {masked_files}")
498        else:
499            logger.info("Creating new DuckDB adapter for in-memory database")
500        adapter = super().create_engine_adapter(
501            register_comments_override, concurrent_tasks=concurrent_tasks
502        )
503        for data_file in data_files:
504            key = data_file if isinstance(data_file, str) else data_file.path
505            BaseDuckDBConnectionConfig._data_file_to_adapter[key] = adapter
506        return adapter
507
508    def get_catalog(self) -> t.Optional[str]:
509        if self.database:
510            # Remove `:` from the database name in order to handle if `:memory:` is passed in
511            return pathlib.Path(self.database.replace(":memory:", "memory")).stem
512        if self.catalogs:
513            return list(self.catalogs)[0]
514        return None
515
516    def _mask_sensitive_data(self, string: str) -> str:
517        # Mask MotherDuck tokens with fixed number of asterisks
518        result = MOTHERDUCK_TOKEN_REGEX.sub(
519            lambda m: f"{m.group(1)}{m.group(2)}{'*' * 8 if m.group(3) else ''}", string
520        )
521        # Mask PostgreSQL/MySQL passwords with fixed number of asterisks
522        result = PASSWORD_REGEX.sub(lambda m: f"{m.group(1)}{'*' * 8}", result)
523        return result

Common configuration for the DuckDB-based connections.

Arguments:
  • database: The optional database name. If not specified, the in-memory database will be used.
  • catalogs: Key is the name of the catalog and value is the path.
  • extensions: A list of autoloadable extensions to load.
  • connector_config: A dictionary of configuration to pass into the duckdb connector.
  • secrets: A list of dictionaries used to generate DuckDB secrets for authenticating with external services (e.g. S3).
  • filesystems: A list of dictionaries used to register fsspec filesystems to the DuckDB cursor.
  • concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
  • register_comments: Whether or not to register model comments with the SQL engine.
  • pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
  • token: The optional MotherDuck token. If not specified and a MotherDuck path is in the catalog, the user will be prompted to login with their web browser.
database: Optional[str]
catalogs: Optional[Dict[str, Union[str, DuckDBAttachOptions]]]
extensions: List[Union[str, Dict[str, Any]]]
connector_config: Dict[str, Any]
secrets: Union[List[Dict[str, Any]], Dict[str, Dict[str, Any]]]
filesystems: List[Dict[str, Any]]
concurrent_tasks: int
register_comments: bool
pre_ping: Literal[False]
token: Optional[str]
shared_connection: ClassVar[bool] = True
def create_engine_adapter( self, register_comments_override: bool = False, concurrent_tasks: Optional[int] = None) -> sqlmesh.core.engine_adapter.base.EngineAdapter:
467    def create_engine_adapter(
468        self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None
469    ) -> EngineAdapter:
470        """Checks if another engine adapter has already been created that shares a catalog that points to the same data
471        file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration
472        associated with the new adapter will be ignored."""
473        data_files = set((self.catalogs or {}).values())
474        if self.database:
475            if isinstance(self, MotherDuckConnectionConfig):
476                data_files.add(
477                    f"md:{self.database}"
478                    + (f"?motherduck_token={self.token}" if self.token else "")
479                )
480            else:
481                data_files.add(self.database)
482        data_files.discard(":memory:")
483        for data_file in data_files:
484            key = data_file if isinstance(data_file, str) else data_file.path
485            adapter = BaseDuckDBConnectionConfig._data_file_to_adapter.get(key)
486            if adapter is not None:
487                logger.info(
488                    f"Using existing DuckDB adapter due to overlapping data file: {self._mask_sensitive_data(key)}"
489                )
490                return adapter
491
492        if data_files:
493            masked_files = {
494                self._mask_sensitive_data(file if isinstance(file, str) else file.path)
495                for file in data_files
496            }
497            logger.info(f"Creating new DuckDB adapter for data files: {masked_files}")
498        else:
499            logger.info("Creating new DuckDB adapter for in-memory database")
500        adapter = super().create_engine_adapter(
501            register_comments_override, concurrent_tasks=concurrent_tasks
502        )
503        for data_file in data_files:
504            key = data_file if isinstance(data_file, str) else data_file.path
505            BaseDuckDBConnectionConfig._data_file_to_adapter[key] = adapter
506        return adapter

Checks if another engine adapter has already been created that shares a catalog that points to the same data file. If so, it uses that same adapter instead of creating a new one. As a result, any additional configuration associated with the new adapter will be ignored.

def get_catalog(self) -> Optional[str]:
508    def get_catalog(self) -> t.Optional[str]:
509        if self.database:
510            # Remove `:` from the database name in order to handle if `:memory:` is passed in
511            return pathlib.Path(self.database.replace(":memory:", "memory")).stem
512        if self.catalogs:
513            return list(self.catalogs)[0]
514        return None

The catalog for this connection

model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
ConnectionConfig
type_
DIALECT
DISPLAY_NAME
DISPLAY_ORDER
pretty_sql
schema_differ_overrides
catalog_type_overrides
is_forbidden_for_state_sync
connection_validator
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class MotherDuckConnectionConfig(BaseDuckDBConnectionConfig):
class MotherDuckConnectionConfig(BaseDuckDBConnectionConfig):
    """Connection configuration for MotherDuck, built on the shared DuckDB base config."""

    type_: t.Literal["motherduck"] = Field(alias="type", default="motherduck")
    DIALECT: t.ClassVar[t.Literal["duckdb"]] = "duckdb"
    DISPLAY_NAME: t.ClassVar[t.Literal["MotherDuck"]] = "MotherDuck"
    DISPLAY_ORDER: t.ClassVar[t.Literal[5]] = 5

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Return an empty set of connection kwarg keys."""
        return set()

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """Build the fixed connection kwargs: the `md:` connection string and connector config.

        The connection string has the form ``md:[database][?param[&param]]``. When a
        database is set, ``attach_mode=single`` is added so only that MotherDuck
        database is attached rather than every database on the account. When a token
        is set, it is passed as the ``motherduck_token`` parameter.
        """
        from sqlmesh import __version__

        query_params: t.List[str] = []
        if self.database:
            query_params.append("attach_mode=single")
        if self.token:
            query_params.append(f"motherduck_token={self.token}")

        connection_str = f"md:{self.database or ''}"
        if query_params:
            connection_str = f"{connection_str}?{'&'.join(query_params)}"

        return {
            "database": connection_str,
            "config": {"custom_user_agent": f"SQLMesh/{__version__}"},
        }

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        """Return the extra engine config that sets `is_motherduck` to True."""
        return {"is_motherduck": True}

Configuration for the MotherDuck connection.

type_: Literal['motherduck']
DIALECT: ClassVar[Literal['duckdb']] = 'duckdb'
DISPLAY_NAME: ClassVar[Literal['MotherDuck']] = 'MotherDuck'
DISPLAY_ORDER: ClassVar[Literal[5]] = 5
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
BaseDuckDBConnectionConfig
database
catalogs
extensions
connector_config
secrets
filesystems
concurrent_tasks
register_comments
pre_ping
token
shared_connection
create_engine_adapter
get_catalog
ConnectionConfig
pretty_sql
schema_differ_overrides
catalog_type_overrides
is_forbidden_for_state_sync
connection_validator
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class DuckDBConnectionConfig(BaseDuckDBConnectionConfig):
class DuckDBConnectionConfig(BaseDuckDBConnectionConfig):
    """Connection configuration for DuckDB.

    All behaviour comes from `BaseDuckDBConnectionConfig`; this class only declares
    the `duckdb` type discriminator, dialect and display metadata.
    """

    # Class-level metadata (not pydantic fields).
    DIALECT: t.ClassVar[t.Literal["duckdb"]] = "duckdb"
    DISPLAY_NAME: t.ClassVar[t.Literal["DuckDB"]] = "DuckDB"
    DISPLAY_ORDER: t.ClassVar[t.Literal[1]] = 1

    # Exposed under the `type` alias in configuration.
    type_: t.Literal["duckdb"] = Field(alias="type", default="duckdb")

Configuration for the DuckDB connection.

type_: Literal['duckdb']
DIALECT: ClassVar[Literal['duckdb']] = 'duckdb'
DISPLAY_NAME: ClassVar[Literal['DuckDB']] = 'DuckDB'
DISPLAY_ORDER: ClassVar[Literal[1]] = 1
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
BaseDuckDBConnectionConfig
database
catalogs
extensions
connector_config
secrets
filesystems
concurrent_tasks
register_comments
pre_ping
token
shared_connection
create_engine_adapter
get_catalog
ConnectionConfig
pretty_sql
schema_differ_overrides
catalog_type_overrides
is_forbidden_for_state_sync
connection_validator
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class SnowflakeConnectionConfig(ConnectionConfig):
class SnowflakeConnectionConfig(ConnectionConfig):
    """Connection configuration for Snowflake.

    Args:
        account: The Snowflake account name.
        user: The Snowflake username.
        password: The Snowflake password.
        warehouse: The optional warehouse name.
        database: The optional database name.
        role: The optional role name.
        concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
        authenticator: The optional authenticator name. Defaults to username/password authentication ("snowflake").
                       Options: https://github.com/snowflakedb/snowflake-connector-python/blob/e937591356c067a77f34a0a42328907fda792c23/src/snowflake/connector/network.py#L178-L183
        token: The optional oauth access token to use for authentication when authenticator is set to "oauth".
        private_key: The optional private key to use for authentication. Key can be Base64-encoded DER format (representing the key bytes), a plain-text PEM format, or bytes (Python config only). https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#using-key-pair-authentication-key-pair-rotation
        private_key_path: The optional path to the private key to use for authentication. This would be used instead of `private_key`.
        private_key_passphrase: The optional passphrase to use to decrypt `private_key` or `private_key_path`. Keys can be created without encryption so only provide this if needed.
        register_comments: Whether or not to register model comments with the SQL engine.
        pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
        session_parameters: The optional session parameters to set for the connection.
        host: Host address for the connection.
        port: Port for the connection.
    """

    account: str
    user: t.Optional[str] = None
    password: t.Optional[str] = None
    warehouse: t.Optional[str] = None
    database: t.Optional[str] = None
    role: t.Optional[str] = None
    authenticator: t.Optional[str] = None
    token: t.Optional[str] = None
    host: t.Optional[str] = None
    port: t.Optional[int] = None
    application: t.Literal["Tobiko_SQLMesh"] = "Tobiko_SQLMesh"

    # Key-pair authentication inputs
    private_key: t.Optional[t.Union[str, bytes]] = None
    private_key_path: t.Optional[str] = None
    private_key_passphrase: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = False

    session_parameters: t.Optional[dict] = None

    type_: t.Literal["snowflake"] = Field(alias="type", default="snowflake")
    DIALECT: t.ClassVar[t.Literal["snowflake"]] = "snowflake"
    DISPLAY_NAME: t.ClassVar[t.Literal["Snowflake"]] = "Snowflake"
    DISPLAY_ORDER: t.ClassVar[t.Literal[2]] = 2

    _concurrent_tasks_validator = concurrent_tasks_validator

    @model_validator(mode="before")
    def _validate_authenticator(cls, data: t.Any) -> t.Any:
        """Check that the credentials match the selected authenticator.

        Non-dict input is returned untouched. For dict input, `private_key` is replaced
        in place with the value produced by `_get_private_key` (None when no key was
        configured).

        Raises:
            ConfigError: If default authentication is used without a private key and
                without both user and password, or if oauth is used without a token.
        """
        if not isinstance(data, dict):
            return data

        from snowflake.connector.network import DEFAULT_AUTHENTICATOR, OAUTH_AUTHENTICATOR

        requested = data.get("authenticator")
        auth = requested.upper() if requested else DEFAULT_AUTHENTICATOR
        has_user_and_password = bool(data.get("user")) and bool(data.get("password"))
        data["private_key"] = cls._get_private_key(data, auth)  # type: ignore

        if auth == DEFAULT_AUTHENTICATOR and not data.get("private_key"):
            if not has_user_and_password:
                raise ConfigError(
                    "User and password must be provided if using default authentication"
                )

        if auth == OAUTH_AUTHENTICATOR and not data.get("token"):
            raise ConfigError("Token must be provided if using oauth authentication")

        return data

    _engine_import_validator = _get_engine_import_validator(
        "snowflake.connector.network", "snowflake"
    )

    @classmethod
    def _get_private_key(cls, values: t.Dict[str, t.Optional[str]], auth: str) -> t.Optional[bytes]:
        """Resolve the configured private key into unencrypted PKCS8 DER bytes.

        Adapted from dbt-snowflake:
        https://github.com/dbt-labs/dbt-snowflake/blob/0374b4ec948982f2ac8ec0c95d53d672ad19e09c/dbt/adapters/snowflake/connections.py#L247C5-L285C1
        with local variables instead of class attributes, plus validation.

        Args:
            values: The raw config values being validated.
            auth: The upper-cased authenticator name. The default authenticator (or an
                empty value) is treated as key-pair authentication here.

        Returns:
            None when neither `private_key` nor `private_key_path` is set; the key
            unchanged when it is already bytes; otherwise the loaded key serialized as
            unencrypted PKCS8 DER bytes.

        Raises:
            ConfigError: If both key inputs are given, the authenticator is not
                key-pair, the user is missing, or a password is also supplied.
        """
        from cryptography.hazmat.backends import default_backend
        from cryptography.hazmat.primitives import serialization
        from snowflake.connector.network import (
            DEFAULT_AUTHENTICATOR,
            KEY_PAIR_AUTHENTICATOR,
        )

        private_key = values.get("private_key")
        private_key_path = values.get("private_key_path")
        private_key_passphrase = values.get("private_key_passphrase")
        user = values.get("user")
        password = values.get("password")
        effective_auth = (
            KEY_PAIR_AUTHENTICATOR if not auth or auth == DEFAULT_AUTHENTICATOR else auth
        )

        if not (private_key or private_key_path):
            return None
        if private_key and private_key_path:
            raise ConfigError("Cannot specify both `private_key` and `private_key_path`")
        if effective_auth != KEY_PAIR_AUTHENTICATOR:
            raise ConfigError(
                f"Private key or private key path can only be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
            )
        if not user:
            raise ConfigError(
                f"User must be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
            )
        if password:
            raise ConfigError(
                f"Password cannot be provided when using {KEY_PAIR_AUTHENTICATOR} authentication"
            )

        # Bytes are passed through untouched (Python config only).
        if isinstance(private_key, bytes):
            return private_key

        encoded_passphrase = private_key_passphrase.encode() if private_key_passphrase else None
        backend = default_backend()

        if private_key:
            if private_key.startswith("-"):
                # A leading dash marks plain-text PEM (`-----BEGIN ...`).
                p_key = serialization.load_pem_private_key(
                    data=bytes(private_key, "utf-8"),
                    password=encoded_passphrase,
                    backend=backend,
                )
            else:
                # Anything else is treated as Base64-encoded DER.
                p_key = serialization.load_der_private_key(
                    data=base64.b64decode(private_key),
                    password=encoded_passphrase,
                    backend=backend,
                )
        elif private_key_path:
            with open(private_key_path, "rb") as key_file:
                p_key = serialization.load_pem_private_key(
                    key_file.read(), password=encoded_passphrase, backend=backend
                )
        else:
            # Not reachable after the checks above; kept so every path is explicit.
            return None

        return p_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Return the names of the config fields listed as connection kwargs."""
        return {
            "account",
            "user",
            "password",
            "authenticator",
            "token",
            "private_key",
            "warehouse",
            "database",
            "role",
            "session_parameters",
            "application",
            "host",
            "port",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """Return the Snowflake engine adapter class."""
        return engine_adapter.SnowflakeEngineAdapter

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """Return the fixed connection kwargs, which set `autocommit` to False."""
        return {"autocommit": False}

    @property
    def _connection_factory(self) -> t.Callable:
        """Return `snowflake.connector.connect`, imported lazily."""
        from snowflake import connector

        return connector.connect

Configuration for the Snowflake connection.

Arguments:
  • account: The Snowflake account name.
  • user: The Snowflake username.
  • password: The Snowflake password.
  • warehouse: The optional warehouse name.
  • database: The optional database name.
  • role: The optional role name.
  • concurrent_tasks: The maximum number of tasks that can use this connection concurrently.
  • authenticator: The optional authenticator name. Defaults to username/password authentication ("snowflake"). Options: https://github.com/snowflakedb/snowflake-connector-python/blob/e937591356c067a77f34a0a42328907fda792c23/src/snowflake/connector/network.py#L178-L183
  • token: The optional oauth access token to use for authentication when authenticator is set to "oauth".
  • private_key: The optional private key to use for authentication. Key can be Base64-encoded DER format (representing the key bytes), a plain-text PEM format, or bytes (Python config only). https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#using-key-pair-authentication-key-pair-rotation
  • private_key_path: The optional path to the private key to use for authentication. This would be used instead of private_key.
  • private_key_passphrase: The optional passphrase to use to decrypt private_key or private_key_path. Keys can be created without encryption so only provide this if needed.
  • register_comments: Whether or not to register model comments with the SQL engine.
  • pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
  • session_parameters: The optional session parameters to set for the connection.
  • host: Host address for the connection.
  • port: Port for the connection.
account: str
user: Optional[str]
password: Optional[str]
warehouse: Optional[str]
database: Optional[str]
role: Optional[str]
authenticator: Optional[str]
token: Optional[str]
host: Optional[str]
port: Optional[int]
application: Literal['Tobiko_SQLMesh']
private_key: Union[str, bytes, NoneType]
private_key_path: Optional[str]
private_key_passphrase: Optional[str]
concurrent_tasks: int
register_comments: bool
pre_ping: bool
session_parameters: Optional[dict]
type_: Literal['snowflake']
DIALECT: ClassVar[Literal['snowflake']] = 'snowflake'
DISPLAY_NAME: ClassVar[Literal['Snowflake']] = 'Snowflake'
DISPLAY_ORDER: ClassVar[Literal[2]] = 2
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
ConnectionConfig
pretty_sql
schema_differ_overrides
catalog_type_overrides
shared_connection
is_forbidden_for_state_sync
connection_validator
create_engine_adapter
get_catalog
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class DatabricksConnectionConfig(ConnectionConfig):
 759class DatabricksConnectionConfig(ConnectionConfig):
 760    """
 761    Databricks connection that uses the SQL connector for SQL models and then Databricks Connect for Dataframe operations
 762
 763    Arg Source: https://github.com/databricks/databricks-sql-python/blob/main/src/databricks/sql/client.py#L39
 764    OAuth ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication
 765
 766    Args:
 767        server_hostname: Databricks instance host name.
 768        http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef)
 769            or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
 770        access_token: Http Bearer access token, e.g. Databricks Personal Access Token.
 771        auth_type: Set to 'databricks-oauth' or 'azure-oauth' to trigger OAuth (or dont set at all to use `access_token`)
 772        oauth_client_id: Client ID to use when auth_type is set to one of the 'oauth' types
 773        oauth_client_secret: Client Secret to use when auth_type is set to one of the 'oauth' types
 774        catalog: Default catalog to use for SQL models. Defaults to None which means it will use the default set in
 775            the Databricks cluster (most likely `hive_metastore`).
 776        http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request
 777        session_configuration: An optional dictionary of Spark session parameters.
 778            Execute the SQL command `SET -v` to get a full list of available commands.
 779        databricks_connect_server_hostname: The hostname to use when establishing a connecting using Databricks Connect.
 780            Defaults to the `server_hostname` value.
 781        databricks_connect_access_token: The access token to use when establishing a connecting using Databricks Connect.
 782            Defaults to the `access_token` value.
 783        databricks_connect_cluster_id: The cluster id to use when establishing a connecting using Databricks Connect.
 784            Defaults to deriving the cluster id from the `http_path` value.
 785        force_databricks_connect: Force all queries to run using Databricks Connect instead of the SQL connector.
 786        disable_databricks_connect: Even if databricks connect is installed, do not use it.
 787        disable_spark_session: Do not use SparkSession if it is available (like when running in a notebook).
 788        pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
 789    """
 790
 791    server_hostname: t.Optional[str] = None
 792    http_path: t.Optional[str] = None
 793    access_token: t.Optional[str] = None
 794    auth_type: t.Optional[str] = None
 795    oauth_client_id: t.Optional[str] = None
 796    oauth_client_secret: t.Optional[str] = None
 797    catalog: t.Optional[str] = None
 798    http_headers: t.Optional[t.List[t.Tuple[str, str]]] = None
 799    session_configuration: t.Optional[t.Dict[str, t.Any]] = None
 800    databricks_connect_server_hostname: t.Optional[str] = None
 801    databricks_connect_access_token: t.Optional[str] = None
 802    databricks_connect_cluster_id: t.Optional[str] = None
 803    databricks_connect_use_serverless: bool = False
 804    force_databricks_connect: bool = False
 805    disable_databricks_connect: bool = False
 806    disable_spark_session: bool = False
 807
 808    concurrent_tasks: int = 1
 809    register_comments: bool = True
 810    pre_ping: t.Literal[False] = False
 811
 812    type_: t.Literal["databricks"] = Field(alias="type", default="databricks")
 813    DIALECT: t.ClassVar[t.Literal["databricks"]] = "databricks"
 814    DISPLAY_NAME: t.ClassVar[t.Literal["Databricks"]] = "Databricks"
 815    DISPLAY_ORDER: t.ClassVar[t.Literal[3]] = 3
 816
 817    _concurrent_tasks_validator = concurrent_tasks_validator
 818    _http_headers_validator = http_headers_validator
 819
 820    @model_validator(mode="before")
 821    def _databricks_connect_validator(cls, data: t.Any) -> t.Any:
 822        # SQLQueryContextLogger will output any error SQL queries even if they are in a try/except block.
 823        # Disabling this allows SQLMesh to determine what should be shown to the user.
 824        # Ex: We describe a table to see if it exists and therefore that execution can fail but we don't need to show
 825        # the user since it is expected if the table doesn't exist. Without this change the user would see the error.
 826        logging.getLogger("SQLQueryContextLogger").setLevel(logging.CRITICAL)
 827
 828        if not isinstance(data, dict):
 829            return data
 830
 831        from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter
 832
 833        if DatabricksEngineAdapter.can_access_spark_session(
 834            bool(data.get("disable_spark_session"))
 835        ):
 836            return data
 837
 838        databricks_connect_use_serverless = data.get("databricks_connect_use_serverless")
 839        server_hostname, http_path, access_token, auth_type = (
 840            data.get("server_hostname"),
 841            data.get("http_path"),
 842            data.get("access_token"),
 843            data.get("auth_type"),
 844        )
 845
 846        if (not server_hostname or not http_path or not access_token) and (
 847            not databricks_connect_use_serverless and not auth_type
 848        ):
 849            raise ValueError(
 850                "`server_hostname`, `http_path`, and `access_token` are required for Databricks connections when not running in a notebook"
 851            )
 852        if (
 853            databricks_connect_use_serverless
 854            and not server_hostname
 855            and not data.get("databricks_connect_server_hostname")
 856        ):
 857            raise ValueError(
 858                "`server_hostname` or `databricks_connect_server_hostname` is required when `databricks_connect_use_serverless` is set"
 859            )
 860        if DatabricksEngineAdapter.can_access_databricks_connect(
 861            bool(data.get("disable_databricks_connect"))
 862        ):
 863            if not data.get("databricks_connect_access_token"):
 864                data["databricks_connect_access_token"] = access_token
 865            if not data.get("databricks_connect_server_hostname"):
 866                data["databricks_connect_server_hostname"] = f"https://{server_hostname}"
 867            if not databricks_connect_use_serverless and not data.get(
 868                "databricks_connect_cluster_id"
 869            ):
 870                if t.TYPE_CHECKING:
 871                    assert http_path is not None
 872                data["databricks_connect_cluster_id"] = http_path.split("/")[-1]
 873
 874        if auth_type:
 875            from databricks.sql.auth.auth import AuthType
 876
 877            all_data = [m.value for m in AuthType]
 878            if auth_type not in all_data:
 879                raise ValueError(
 880                    f"`auth_type` {auth_type} does not match a valid option: {all_data}"
 881                )
 882
 883            client_id = data.get("oauth_client_id")
 884            client_secret = data.get("oauth_client_secret")
 885
 886            if client_secret and not client_id:
 887                raise ValueError(
 888                    "`oauth_client_id` is required when `oauth_client_secret` is specified"
 889                )
 890
 891            if not http_path:
 892                raise ValueError("`http_path` is still required when using `auth_type`")
 893
 894        return data
 895
 896    _engine_import_validator = _get_engine_import_validator("databricks", "databricks")
 897
 898    @property
 899    def _connection_kwargs_keys(self) -> t.Set[str]:
 900        if self.use_spark_session_only:
 901            return set()
 902        return {
 903            "server_hostname",
 904            "http_path",
 905            "access_token",
 906            "http_headers",
 907            "session_configuration",
 908            "catalog",
 909        }
 910
 911    @property
 912    def _engine_adapter(self) -> t.Type[engine_adapter.DatabricksEngineAdapter]:
 913        return engine_adapter.DatabricksEngineAdapter
 914
 915    @property
 916    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
 917        return {
 918            k: v
 919            for k, v in self.dict().items()
 920            if k.startswith("databricks_connect_")
 921            or k in ("catalog", "disable_databricks_connect", "disable_spark_session")
 922        }
 923
 924    @property
 925    def use_spark_session_only(self) -> bool:
 926        from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter
 927
 928        return (
 929            DatabricksEngineAdapter.can_access_spark_session(self.disable_spark_session)
 930            or self.force_databricks_connect
 931        )
 932
 933    @property
 934    def _connection_factory(self) -> t.Callable:
 935        if self.use_spark_session_only:
 936            from sqlmesh.engines.spark.db_api.spark_session import connection
 937
 938            return connection
 939
 940        from databricks import sql  # type: ignore
 941
 942        return sql.connect
 943
 944    @property
 945    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
 946        from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter
 947
 948        if not self.use_spark_session_only:
 949            conn_kwargs: t.Dict[str, t.Any] = {
 950                "_user_agent_entry": "sqlmesh",
 951            }
 952
 953            if self.auth_type and "oauth" in self.auth_type:
 954                # there are two types of oauth: User-to-Machine (U2M) and Machine-to-Machine (M2M)
 955                if self.oauth_client_secret:
 956                    # if a client_secret exists, then a client_id also exists and we are using M2M
 957                    # ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication
 958                    # ref: https://github.com/databricks/databricks-sql-python/blob/main/examples/m2m_oauth.py
 959                    from databricks.sdk.core import oauth_service_principal, Config
 960
 961                    config = Config(
 962                        host=f"https://{self.server_hostname}",
 963                        client_id=self.oauth_client_id,
 964                        client_secret=self.oauth_client_secret,
 965                    )
 966                    conn_kwargs["credentials_provider"] = lambda: oauth_service_principal(config)
 967                else:
 968                    # if auth_type is set to an 'oauth' type but no client_id/secret are set, then we are using U2M
 969                    # ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-user-to-machine-u2m-authentication
 970                    conn_kwargs["auth_type"] = self.auth_type
 971
 972            return conn_kwargs
 973
 974        if DatabricksEngineAdapter.can_access_spark_session(self.disable_spark_session):
 975            from pyspark.sql import SparkSession
 976
 977            return dict(
 978                spark=SparkSession.getActiveSession(),
 979                catalog=self.catalog,
 980            )
 981
 982        from databricks.connect import DatabricksSession
 983
 984        if t.TYPE_CHECKING:
 985            assert self.databricks_connect_server_hostname is not None
 986            assert self.databricks_connect_access_token is not None
 987
 988        if self.databricks_connect_use_serverless:
 989            builder = DatabricksSession.builder.remote(
 990                host=self.databricks_connect_server_hostname,
 991                token=self.databricks_connect_access_token,
 992                serverless=True,
 993            )
 994        else:
 995            if t.TYPE_CHECKING:
 996                assert self.databricks_connect_cluster_id is not None
 997            builder = DatabricksSession.builder.remote(
 998                host=self.databricks_connect_server_hostname,
 999                token=self.databricks_connect_access_token,
1000                cluster_id=self.databricks_connect_cluster_id,
1001            )
1002
1003        return dict(
1004            spark=builder.userAgent("sqlmesh").getOrCreate(),
1005            catalog=self.catalog,
1006        )

Databricks connection that uses the SQL connector for SQL models and then Databricks Connect for Dataframe operations

Arg Source: https://github.com/databricks/databricks-sql-python/blob/main/src/databricks/sql/client.py#L39 OAuth ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication

Arguments:
  • server_hostname: Databricks instance host name.
  • http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef) or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
  • access_token: Http Bearer access token, e.g. Databricks Personal Access Token.
  • auth_type: Set to 'databricks-oauth' or 'azure-oauth' to trigger OAuth (or don't set it at all to use access_token)
  • oauth_client_id: Client ID to use when auth_type is set to one of the 'oauth' types
  • oauth_client_secret: Client Secret to use when auth_type is set to one of the 'oauth' types
  • catalog: Default catalog to use for SQL models. Defaults to None which means it will use the default set in the Databricks cluster (most likely hive_metastore).
  • http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request
  • session_configuration: An optional dictionary of Spark session parameters. Execute the SQL command SET -v to get a full list of available commands.
  • databricks_connect_server_hostname: The hostname to use when establishing a connection using Databricks Connect. Defaults to the server_hostname value.
  • databricks_connect_access_token: The access token to use when establishing a connection using Databricks Connect. Defaults to the access_token value.
  • databricks_connect_cluster_id: The cluster id to use when establishing a connection using Databricks Connect. Defaults to deriving the cluster id from the http_path value.
  • force_databricks_connect: Force all queries to run using Databricks Connect instead of the SQL connector.
  • disable_databricks_connect: Even if databricks connect is installed, do not use it.
  • disable_spark_session: Do not use SparkSession if it is available (like when running in a notebook).
  • pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
server_hostname: Optional[str]
http_path: Optional[str]
access_token: Optional[str]
auth_type: Optional[str]
oauth_client_id: Optional[str]
oauth_client_secret: Optional[str]
catalog: Optional[str]
http_headers: Optional[List[Tuple[str, str]]]
session_configuration: Optional[Dict[str, Any]]
databricks_connect_server_hostname: Optional[str]
databricks_connect_access_token: Optional[str]
databricks_connect_cluster_id: Optional[str]
databricks_connect_use_serverless: bool
force_databricks_connect: bool
disable_databricks_connect: bool
disable_spark_session: bool
concurrent_tasks: int
register_comments: bool
pre_ping: Literal[False]
type_: Literal['databricks']
DIALECT: ClassVar[Literal['databricks']] = 'databricks'
DISPLAY_NAME: ClassVar[Literal['Databricks']] = 'Databricks'
DISPLAY_ORDER: ClassVar[Literal[3]] = 3
use_spark_session_only: bool
924    @property
925    def use_spark_session_only(self) -> bool:
926        from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter
927
928        return (
929            DatabricksEngineAdapter.can_access_spark_session(self.disable_spark_session)
930            or self.force_databricks_connect
931        )
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
ConnectionConfig
pretty_sql
schema_differ_overrides
catalog_type_overrides
shared_connection
is_forbidden_for_state_sync
connection_validator
create_engine_adapter
get_catalog
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class BigQueryConnectionMethod(builtins.str, enum.Enum):
class BigQueryConnectionMethod(str, Enum):
    """Authentication methods selectable through `BigQueryConnectionConfig.method`."""

    # Credentials resolved through `google.auth.default`.
    OAUTH = "oauth"
    # OAuth credentials built from the token / refresh token / client settings in the config.
    OAUTH_SECRETS = "oauth-secrets"
    # Service account credentials loaded from the file at `keyfile`.
    SERVICE_ACCOUNT = "service-account"
    # Service account credentials built from the `keyfile_json` mapping.
    SERVICE_ACCOUNT_JSON = "service-account-json"

An enumeration.

OAUTH = <BigQueryConnectionMethod.OAUTH: 'oauth'>
OAUTH_SECRETS = <BigQueryConnectionMethod.OAUTH_SECRETS: 'oauth-secrets'>
SERVICE_ACCOUNT = <BigQueryConnectionMethod.SERVICE_ACCOUNT: 'service-account'>
SERVICE_ACCOUNT_JSON = <BigQueryConnectionMethod.SERVICE_ACCOUNT_JSON: 'service-account-json'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class BigQueryPriority(builtins.str, enum.Enum):
class BigQueryPriority(str, Enum):
    """Priority with which BigQuery jobs are submitted."""

    BATCH = "batch"
    INTERACTIVE = "interactive"

    @property
    def is_batch(self) -> bool:
        """True when this priority is `BATCH`."""
        return self == BigQueryPriority.BATCH

    @property
    def is_interactive(self) -> bool:
        """True when this priority is `INTERACTIVE`."""
        return self == BigQueryPriority.INTERACTIVE

    @property
    def bigquery_constant(self) -> str:
        """The `google.cloud.bigquery.QueryPriority` constant matching this priority."""
        # Imported lazily so the enum can be used without the BigQuery client installed.
        from google.cloud.bigquery import QueryPriority

        return QueryPriority.BATCH if self.is_batch else QueryPriority.INTERACTIVE

An enumeration.

BATCH = <BigQueryPriority.BATCH: 'batch'>
INTERACTIVE = <BigQueryPriority.INTERACTIVE: 'interactive'>
is_batch: bool
1020    @property
1021    def is_batch(self) -> bool:
1022        return self == self.BATCH
is_interactive: bool
1024    @property
1025    def is_interactive(self) -> bool:
1026        return self == self.INTERACTIVE
bigquery_constant: str
1028    @property
1029    def bigquery_constant(self) -> str:
1030        from google.cloud.bigquery import QueryPriority
1031
1032        if self.is_batch:
1033            return QueryPriority.BATCH
1034        return QueryPriority.INTERACTIVE
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class BigQueryConnectionConfig(ConnectionConfig):
class BigQueryConnectionConfig(ConnectionConfig):
    """
    BigQuery Connection Configuration.

    Args:
        method: How credentials are obtained; see `BigQueryConnectionMethod`.
        project: The default project. It is also returned as the catalog of this connection.
        execution_project: Project passed to the BigQuery client in preference to `project`.
            Requires `project` to be set.
        quota_project: Project passed to the client as `quota_project_id`. Requires `project` to be set.
        location: Location passed to the BigQuery client.
        keyfile: Path to a service account file. Used by the `service-account` method.
        keyfile_json: Service account info. Used by the `service-account-json` method.
        token: OAuth token. Used by the `oauth-secrets` method.
        refresh_token: OAuth refresh token. Used by the `oauth-secrets` method.
        client_id: OAuth client id. Used by the `oauth-secrets` method.
        client_secret: OAuth client secret. Used by the `oauth-secrets` method.
        token_uri: OAuth token URI. Used by the `oauth-secrets` method.
        scopes: Scopes requested for the credentials.
        impersonated_service_account: When set, the resolved credentials are used as the source
            for impersonating this service account.
        job_creation_timeout_seconds: Forwarded to the engine adapter as extra engine config.
        job_execution_timeout_seconds: Forwarded to the engine adapter as extra engine config.
        job_retries: Forwarded to the engine adapter as extra engine config.
        job_retry_deadline_seconds: Forwarded to the engine adapter as extra engine config.
        priority: Forwarded to the engine adapter as extra engine config.
        maximum_bytes_billed: Forwarded to the engine adapter as extra engine config.
        reservation: Forwarded to the engine adapter as extra engine config.
    """

    method: BigQueryConnectionMethod = BigQueryConnectionMethod.OAUTH

    project: t.Optional[str] = None
    execution_project: t.Optional[str] = None
    quota_project: t.Optional[str] = None
    location: t.Optional[str] = None
    # Keyfile Auth
    keyfile: t.Optional[str] = None
    keyfile_json: t.Optional[t.Dict[str, t.Any]] = None
    # OAuth Secret Auth
    token: t.Optional[str] = None
    refresh_token: t.Optional[str] = None
    client_id: t.Optional[str] = None
    client_secret: t.Optional[str] = None
    token_uri: t.Optional[str] = None
    scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/bigquery",)
    impersonated_service_account: t.Optional[str] = None
    # Extra Engine Config
    job_creation_timeout_seconds: t.Optional[int] = None
    job_execution_timeout_seconds: t.Optional[int] = None
    job_retries: t.Optional[int] = 1
    job_retry_deadline_seconds: t.Optional[int] = None
    priority: t.Optional[BigQueryPriority] = None
    maximum_bytes_billed: t.Optional[int] = None
    reservation: t.Optional[str] = None

    concurrent_tasks: int = 1
    register_comments: bool = True
    pre_ping: t.Literal[False] = False

    type_: t.Literal["bigquery"] = Field(alias="type", default="bigquery")
    DIALECT: t.ClassVar[t.Literal["bigquery"]] = "bigquery"
    DISPLAY_NAME: t.ClassVar[t.Literal["BigQuery"]] = "BigQuery"
    DISPLAY_ORDER: t.ClassVar[t.Literal[4]] = 4

    _engine_import_validator = _get_engine_import_validator("google.cloud.bigquery", "bigquery")

    @field_validator("execution_project")
    def validate_execution_project(
        cls,
        v: t.Optional[str],
        info: ValidationInfo,
    ) -> t.Optional[str]:
        """Reject an `execution_project` that is set without a `project`."""
        # `project` is declared before this field, so its value is looked up in `info.data`.
        if not v or info.data.get("project"):
            return v
        raise ConfigError(
            "If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location."
        )

    @field_validator("quota_project")
    def validate_quota_project(
        cls,
        v: t.Optional[str],
        info: ValidationInfo,
    ) -> t.Optional[str]:
        """Reject a `quota_project` that is set without a `project`."""
        # `project` is declared before this field, so its value is looked up in `info.data`.
        if not v or info.data.get("project"):
            return v
        raise ConfigError(
            "If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location."
        )

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """No config fields are forwarded directly; the connection is built from a client."""
        return set()

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """The engine adapter class used for this connection."""
        return engine_adapter.BigQueryEngineAdapter

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """The static connection kwargs for this connection.

        Resolves credentials according to `method`, optionally wraps them in impersonated
        credentials, and builds the BigQuery client handed to the DB-API `connect`.

        Raises:
            ConfigError: If `method` is not a known connection method, or the keyfile
                setting required by a service account method is missing.
        """
        import google.auth
        from google.auth import impersonated_credentials
        from google.api_core import client_info, client_options

        # Import the client module explicitly instead of relying on `google.cloud.bigquery`
        # already having been imported elsewhere for the attribute lookup to resolve.
        from google.cloud import bigquery
        from google.oauth2 import credentials, service_account

        if self.method == BigQueryConnectionMethod.OAUTH:
            creds, _ = google.auth.default(scopes=self.scopes)
        elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT:
            if not self.keyfile:
                # Fail with a clear message rather than an obscure error from the Google library.
                raise ConfigError(
                    "The `keyfile` field is required when using the `service-account` connection method."
                )
            creds = service_account.Credentials.from_service_account_file(
                self.keyfile, scopes=self.scopes
            )
        elif self.method == BigQueryConnectionMethod.SERVICE_ACCOUNT_JSON:
            if not self.keyfile_json:
                raise ConfigError(
                    "The `keyfile_json` field is required when using the `service-account-json` connection method."
                )
            creds = service_account.Credentials.from_service_account_info(
                self.keyfile_json, scopes=self.scopes
            )
        elif self.method == BigQueryConnectionMethod.OAUTH_SECRETS:
            creds = credentials.Credentials(
                token=self.token,
                refresh_token=self.refresh_token,
                client_id=self.client_id,
                client_secret=self.client_secret,
                token_uri=self.token_uri,
                scopes=self.scopes,
            )
        else:
            raise ConfigError("Invalid BigQuery Connection Method")

        if self.impersonated_service_account:
            creds = impersonated_credentials.Credentials(
                source_credentials=creds,
                target_principal=self.impersonated_service_account,
                target_scopes=self.scopes,
            )

        options = client_options.ClientOptions(quota_project_id=self.quota_project)
        # The execution project takes precedence over the default project.
        project = self.execution_project or self.project or None

        client = bigquery.Client(
            project=project and exp.parse_identifier(project, dialect="bigquery").name,
            credentials=creds,
            location=self.location,
            client_info=client_info.ClientInfo(user_agent="sqlmesh"),
            client_options=options,
        )

        return {
            "client": client,
        }

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        """The job-related settings forwarded to the engine adapter."""
        return {
            k: v
            for k, v in self.dict().items()
            if k
            in {
                "job_creation_timeout_seconds",
                "job_execution_timeout_seconds",
                "job_retries",
                "job_retry_deadline_seconds",
                "priority",
                "maximum_bytes_billed",
                "reservation",
            }
        }

    @property
    def _connection_factory(self) -> t.Callable:
        """The BigQuery DB-API `connect` callable."""
        from google.cloud.bigquery.dbapi import connect

        return connect

    def get_catalog(self) -> t.Optional[str]:
        """Return the configured `project`, which is the catalog for this connection."""
        return self.project

BigQuery Connection Configuration.

project: Optional[str]
execution_project: Optional[str]
quota_project: Optional[str]
location: Optional[str]
keyfile: Optional[str]
keyfile_json: Optional[Dict[str, Any]]
token: Optional[str]
refresh_token: Optional[str]
client_id: Optional[str]
client_secret: Optional[str]
token_uri: Optional[str]
scopes: Tuple[str, ...]
impersonated_service_account: Optional[str]
job_creation_timeout_seconds: Optional[int]
job_execution_timeout_seconds: Optional[int]
job_retries: Optional[int]
job_retry_deadline_seconds: Optional[int]
priority: Optional[BigQueryPriority]
maximum_bytes_billed: Optional[int]
reservation: Optional[str]
concurrent_tasks: int
register_comments: bool
pre_ping: Literal[False]
type_: Literal['bigquery']
DIALECT: ClassVar[Literal['bigquery']] = 'bigquery'
DISPLAY_NAME: ClassVar[Literal['BigQuery']] = 'BigQuery'
DISPLAY_ORDER: ClassVar[Literal[4]] = 4
@field_validator('execution_project')
def validate_execution_project( cls, v: Optional[str], info: pydantic_core.core_schema.ValidationInfo) -> Optional[str]:
1079    @field_validator("execution_project")
1080    def validate_execution_project(
1081        cls,
1082        v: t.Optional[str],
1083        info: ValidationInfo,
1084    ) -> t.Optional[str]:
1085        if v and not info.data.get("project"):
1086            raise ConfigError(
1087                "If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location."
1088            )
1089        return v
@field_validator('quota_project')
def validate_quota_project( cls, v: Optional[str], info: pydantic_core.core_schema.ValidationInfo) -> Optional[str]:
1091    @field_validator("quota_project")
1092    def validate_quota_project(
1093        cls,
1094        v: t.Optional[str],
1095        info: ValidationInfo,
1096    ) -> t.Optional[str]:
1097        if v and not info.data.get("project"):
1098            raise ConfigError(
1099                "If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location."
1100            )
1101        return v
def get_catalog(self) -> Optional[str]:
1186    def get_catalog(self) -> t.Optional[str]:
1187        return self.project

The catalog for this connection

model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
ConnectionConfig
pretty_sql
schema_differ_overrides
catalog_type_overrides
shared_connection
is_forbidden_for_state_sync
connection_validator
create_engine_adapter
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class GCPPostgresConnectionConfig(ConnectionConfig):
class GCPPostgresConnectionConfig(ConnectionConfig):
    """
    Connection configuration for a Postgres instance hosted on GCP.

    Args:
        instance_connection_string: Connection name for the postgres instance.
        user: Postgres or IAM user's name
        password: The postgres user's password. Only needed when the user is a postgres user.
        enable_iam_auth: Set to True when user is an IAM user.
        db: Name of the db to connect to.
        ip_type: One of "public", "private" or "psc"; forwarded to the Cloud SQL ``Connector``.
            Defaults to "public".
        keyfile: string path to json service account credentials file
        keyfile_json: dict service account credentials info
        timeout: When set to a truthy value it is forwarded to the Cloud SQL ``Connector``.
        scopes: OAuth scopes used when building service account credentials from
            ``keyfile`` or ``keyfile_json``.
        driver: Name of the database driver. Defaults to "pg8000".
        pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
    """

    instance_connection_string: str
    user: str
    password: t.Optional[str] = None
    enable_iam_auth: t.Optional[bool] = None
    db: str
    ip_type: t.Union[t.Literal["public"], t.Literal["private"], t.Literal["psc"]] = "public"
    # Keyfile Auth
    keyfile: t.Optional[str] = None
    keyfile_json: t.Optional[t.Dict[str, t.Any]] = None
    timeout: t.Optional[int] = None
    scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/sqlservice.admin",)
    driver: str = "pg8000"

    type_: t.Literal["gcp_postgres"] = Field(alias="type", default="gcp_postgres")
    DIALECT: t.ClassVar[t.Literal["postgres"]] = "postgres"
    DISPLAY_NAME: t.ClassVar[t.Literal["GCP Postgres"]] = "GCP Postgres"
    DISPLAY_ORDER: t.ClassVar[t.Literal[13]] = 13

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = True

    _engine_import_validator = _get_engine_import_validator(
        "google.cloud.sql", "gcp_postgres", "gcppostgres"
    )

    @model_validator(mode="before")
    def _validate_auth_method(cls, data: t.Any) -> t.Any:
        """Reject raw config dicts that set neither ``password`` nor ``enable_iam_auth``.

        Non-dict input is passed through untouched.

        Raises:
            ConfigError: If both ``password`` and ``enable_iam_auth`` are missing or falsy.
        """
        if isinstance(data, dict):
            has_password = bool(data.get("password"))
            uses_iam_auth = bool(data.get("enable_iam_auth"))
            if not (has_password or uses_iam_auth):
                raise ConfigError(
                    "GCP Postgres connection configuration requires either password set"
                    " for a postgres user account or enable_iam_auth set to 'True'"
                    " for an IAM user account."
                )

        return data

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Names of the config fields that make up the connection keyword arguments."""
        return {
            "instance_connection_string",
            "db",
            "user",
            "password",
            "enable_iam_auth",
            "driver",
            "timeout",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """The engine adapter class used for this connection (plain Postgres)."""
        return engine_adapter.PostgresEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """Build a Cloud SQL ``Connector`` and return its bound ``connect`` method.

        Service account credentials are loaded from ``keyfile`` when set, otherwise
        from ``keyfile_json``; with neither, ``credentials=None`` is passed to the
        connector.
        """
        from google.cloud.sql.connector import Connector
        from google.oauth2 import service_account

        # A keyfile path takes precedence over inline keyfile JSON.
        if self.keyfile:
            credentials = service_account.Credentials.from_service_account_file(
                self.keyfile, scopes=self.scopes
            )
        elif self.keyfile_json:
            credentials = service_account.Credentials.from_service_account_info(
                self.keyfile_json, scopes=self.scopes
            )
        else:
            credentials = None

        connector_kwargs: t.Dict[str, t.Any] = {
            "credentials": credentials,
            "ip_type": self.ip_type,
        }
        # Only override the connector's own timeout when one was configured.
        if self.timeout:
            connector_kwargs["timeout"] = self.timeout

        return Connector(**connector_kwargs).connect  # type: ignore

Postgres Connection Configuration for GCP.

Arguments:
  • instance_connection_string: Connection name for the postgres instance.
  • user: Postgres or IAM user's name
  • password: The postgres user's password. Only needed when the user is a postgres user.
  • enable_iam_auth: Set to True when user is an IAM user.
  • db: Name of the db to connect to.
  • keyfile: string path to json service account credentials file
  • keyfile_json: dict service account credentials info
  • pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
instance_connection_string: str
user: str
password: Optional[str]
enable_iam_auth: Optional[bool]
db: str
ip_type: Union[Literal['public'], Literal['private'], Literal['psc']]
keyfile: Optional[str]
keyfile_json: Optional[Dict[str, Any]]
timeout: Optional[int]
scopes: Tuple[str, ...]
driver: str
type_: Literal['gcp_postgres']
DIALECT: ClassVar[Literal['postgres']] = 'postgres'
DISPLAY_NAME: ClassVar[Literal['GCP Postgres']] = 'GCP Postgres'
DISPLAY_ORDER: ClassVar[Literal[13]] = 13
concurrent_tasks: int
register_comments: bool
pre_ping: bool
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
ConnectionConfig
pretty_sql
schema_differ_overrides
catalog_type_overrides
shared_connection
is_forbidden_for_state_sync
connection_validator
create_engine_adapter
get_catalog
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class RedshiftConnectionConfig(ConnectionConfig):
class RedshiftConnectionConfig(ConnectionConfig):
    """
    Amazon Redshift connection configuration.

    Arg Source: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/__init__.py#L146

    Note: Only a subset of the driver's properties is exposed. Please open an issue/PR if you
    want to see more supported.

    Args:
        user: The username to use for authentication with the Amazon Redshift cluster.
        password: The password to use for authentication with the Amazon Redshift cluster.
        database: The name of the database instance to connect to.
        host: The hostname of the Amazon Redshift cluster.
        port: The port number of the Amazon Redshift cluster. Default value is 5439.
        source_address: No description provided
        unix_sock: No description provided
        ssl: Is SSL enabled. Default value is ``True``. SSL must be enabled when authenticating using IAM.
        sslmode: The security of the connection to the Amazon Redshift cluster. 'verify-ca' and 'verify-full' are supported.
        timeout: The number of seconds before the connection to the server will timeout. By default there is no timeout.
        tcp_keepalive: Is `TCP keepalive <https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive>`_ used. The default value is ``True``.
        application_name: Sets the application name. The default value is None.
        preferred_role: The IAM role preferred for the current connection.
        principal_arn: The ARN of the IAM entity (user or role) for which you are generating a policy.
        credentials_provider: The class name of the IdP that will be used for authenticating with the Amazon Redshift cluster.
        region: The AWS region where the Amazon Redshift cluster is located.
        cluster_identifier: The cluster identifier of the Amazon Redshift cluster.
        iam: If IAM authentication is enabled. Default value is False. IAM must be True when authenticating using an IdP.
        is_serverless: Whether the Redshift endpoint is serverless or provisioned. Default value false.
        serverless_acct_id: The account ID of the serverless. Default value None
        serverless_work_group: The name of work group for serverless end point. Default value None.
        pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
        enable_merge: Whether to use the Redshift merge operation instead of the SQLMesh logical merge.
    """

    user: t.Optional[str] = None
    password: t.Optional[str] = None
    database: t.Optional[str] = None
    host: t.Optional[str] = None
    port: t.Optional[int] = None
    source_address: t.Optional[str] = None
    unix_sock: t.Optional[str] = None
    ssl: t.Optional[bool] = None
    sslmode: t.Optional[str] = None
    timeout: t.Optional[int] = None
    tcp_keepalive: t.Optional[bool] = None
    application_name: t.Optional[str] = None
    preferred_role: t.Optional[str] = None
    principal_arn: t.Optional[str] = None
    credentials_provider: t.Optional[str] = None
    region: t.Optional[str] = None
    cluster_identifier: t.Optional[str] = None
    iam: t.Optional[bool] = None
    is_serverless: t.Optional[bool] = None
    serverless_acct_id: t.Optional[str] = None
    serverless_work_group: t.Optional[str] = None
    enable_merge: t.Optional[bool] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = False

    type_: t.Literal["redshift"] = Field(alias="type", default="redshift")
    DIALECT: t.ClassVar[t.Literal["redshift"]] = "redshift"
    DISPLAY_NAME: t.ClassVar[t.Literal["Redshift"]] = "Redshift"
    DISPLAY_ORDER: t.ClassVar[t.Literal[7]] = 7

    _engine_import_validator = _get_engine_import_validator("redshift_connector", "redshift")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Names of the config fields that make up the connection keyword arguments.

        ``enable_merge`` is deliberately absent: it is exposed through
        ``_extra_engine_config`` instead.
        """
        network_keys = (
            "host",
            "port",
            "source_address",
            "unix_sock",
            "ssl",
            "sslmode",
            "timeout",
            "tcp_keepalive",
        )
        auth_keys = (
            "user",
            "password",
            "database",
            "application_name",
            "preferred_role",
            "principal_arn",
            "credentials_provider",
            "iam",
        )
        cluster_keys = (
            "region",
            "cluster_identifier",
            "is_serverless",
            "serverless_acct_id",
            "serverless_work_group",
        )
        return {*network_keys, *auth_keys, *cluster_keys}

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """The engine adapter class used for this connection."""
        return engine_adapter.RedshiftEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """Return ``redshift_connector.connect``, imported lazily on first use."""
        import redshift_connector

        return redshift_connector.connect

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        """Extra engine configuration: carries the ``enable_merge`` setting."""
        return dict(enable_merge=self.enable_merge)

Redshift Connection Configuration.

Arg Source: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/__init__.py#L146

Note: A subset of properties was selected. Please open an issue/PR if you want to see more supported.

Arguments:
  • user: The username to use for authentication with the Amazon Redshift cluster.
  • password: The password to use for authentication with the Amazon Redshift cluster.
  • database: The name of the database instance to connect to.
  • host: The hostname of the Amazon Redshift cluster.
  • port: The port number of the Amazon Redshift cluster. Default value is 5439.
  • source_address: No description provided
  • unix_sock: No description provided
  • ssl: Is SSL enabled. Default value is True. SSL must be enabled when authenticating using IAM.
  • sslmode: The security of the connection to the Amazon Redshift cluster. 'verify-ca' and 'verify-full' are supported.
  • timeout: The number of seconds before the connection to the server will timeout. By default there is no timeout.
  • tcp_keepalive: Is TCP keepalive used. The default value is True.
  • application_name: Sets the application name. The default value is None.
  • preferred_role: The IAM role preferred for the current connection.
  • principal_arn: The ARN of the IAM entity (user or role) for which you are generating a policy.
  • credentials_provider: The class name of the IdP that will be used for authenticating with the Amazon Redshift cluster.
  • region: The AWS region where the Amazon Redshift cluster is located.
  • cluster_identifier: The cluster identifier of the Amazon Redshift cluster.
  • iam: If IAM authentication is enabled. Default value is False. IAM must be True when authenticating using an IdP.
  • is_serverless: Whether the Redshift endpoint is serverless or provisioned. Default value false.
  • serverless_acct_id: The account ID of the serverless. Default value None
  • serverless_work_group: The name of work group for serverless end point. Default value None.
  • pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive.
  • enable_merge: Whether to use the Redshift merge operation instead of the SQLMesh logical merge.
user: Optional[str]
password: Optional[str]
database: Optional[str]
host: Optional[str]
port: Optional[int]
source_address: Optional[str]
unix_sock: Optional[str]
ssl: Optional[bool]
sslmode: Optional[str]
timeout: Optional[int]
tcp_keepalive: Optional[bool]
application_name: Optional[str]
preferred_role: Optional[str]
principal_arn: Optional[str]
credentials_provider: Optional[str]
region: Optional[str]
cluster_identifier: Optional[str]
iam: Optional[bool]
is_serverless: Optional[bool]
serverless_acct_id: Optional[str]
serverless_work_group: Optional[str]
enable_merge: Optional[bool]
concurrent_tasks: int
register_comments: bool
pre_ping: bool
type_: Literal['redshift']
DIALECT: ClassVar[Literal['redshift']] = 'redshift'
DISPLAY_NAME: ClassVar[Literal['Redshift']] = 'Redshift'
DISPLAY_ORDER: ClassVar[Literal[7]] = 7
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
ConnectionConfig
pretty_sql
schema_differ_overrides
catalog_type_overrides
shared_connection
is_forbidden_for_state_sync
connection_validator
create_engine_adapter
get_catalog
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class PostgresConnectionConfig(ConnectionConfig):
class PostgresConnectionConfig(ConnectionConfig):
    """
    Postgres connection configuration (psycopg2 driver).

    Args:
        host: Host name of the Postgres server.
        user: User name to authenticate with.
        password: Password for ``user``.
        port: Port of the Postgres server.
        database: Name of the database to connect to.
        keepalives_idle: Optional ``keepalives_idle`` connection setting.
        connect_timeout: Connection timeout setting. Defaults to 10.
        role: Optional role; when set, ``_cursor_init`` yields a callback that runs
            ``SET ROLE <role>`` on the cursor it receives.
        sslmode: Optional ``sslmode`` connection setting.
        application_name: Optional ``application_name`` connection setting.
    """

    host: str
    user: str
    password: str
    port: int
    database: str
    keepalives_idle: t.Optional[int] = None
    connect_timeout: int = 10
    role: t.Optional[str] = None
    sslmode: t.Optional[str] = None
    application_name: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = True

    type_: t.Literal["postgres"] = Field(alias="type", default="postgres")
    DIALECT: t.ClassVar[t.Literal["postgres"]] = "postgres"
    DISPLAY_NAME: t.ClassVar[t.Literal["Postgres"]] = "Postgres"
    DISPLAY_ORDER: t.ClassVar[t.Literal[12]] = 12

    _engine_import_validator = _get_engine_import_validator("psycopg2", "postgres")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Names of the config fields that make up the connection keyword arguments.

        ``role`` is not among them; it is applied through ``_cursor_init``.
        """
        required = ("host", "port", "database", "user", "password")
        tuning = ("keepalives_idle", "connect_timeout", "sslmode", "application_name")
        return {*required, *tuning}

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """The engine adapter class used for this connection."""
        return engine_adapter.PostgresEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """Return ``psycopg2.connect``, imported lazily on first use."""
        import psycopg2

        return psycopg2.connect

    @property
    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
        """Return a cursor callback that switches to ``role``, or None when no role is set."""
        if self.role:

            def set_role(cursor: t.Any) -> None:
                cursor.execute(f"SET ROLE {self.role}")

            return set_role

        return None

Postgres connection configuration. Connections are created with psycopg2's `connect`; when `role` is set, a cursor-init callback issues `SET ROLE <role>`.

host: str
user: str
password: str
port: int
database: str
keepalives_idle: Optional[int]
connect_timeout: int
role: Optional[str]
sslmode: Optional[str]
application_name: Optional[str]
concurrent_tasks: int
register_comments: bool
pre_ping: bool
type_: Literal['postgres']
DIALECT: ClassVar[Literal['postgres']] = 'postgres'
DISPLAY_NAME: ClassVar[Literal['Postgres']] = 'Postgres'
DISPLAY_ORDER: ClassVar[Literal[12]] = 12
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
ConnectionConfig
pretty_sql
schema_differ_overrides
catalog_type_overrides
shared_connection
is_forbidden_for_state_sync
connection_validator
create_engine_adapter
get_catalog
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class MySQLConnectionConfig(ConnectionConfig):
class MySQLConnectionConfig(ConnectionConfig):
    """
    MySQL connection configuration (pymysql driver).

    Args:
        host: Host name of the MySQL server.
        user: User name to authenticate with.
        password: Password for ``user``.
        port: Optional port; only part of the connection keys when set.
        database: Optional database name; only part of the connection keys when set.
        charset: Optional charset; only part of the connection keys when set.
        collation: Optional collation; only part of the connection keys when set.
        ssl_disabled: Optional flag; only part of the connection keys when set.
    """

    host: str
    user: str
    password: str
    port: t.Optional[int] = None
    database: t.Optional[str] = None
    charset: t.Optional[str] = None
    collation: t.Optional[str] = None
    ssl_disabled: t.Optional[bool] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = True

    type_: t.Literal["mysql"] = Field(alias="type", default="mysql")
    DIALECT: t.ClassVar[t.Literal["mysql"]] = "mysql"
    DISPLAY_NAME: t.ClassVar[t.Literal["MySQL"]] = "MySQL"
    DISPLAY_ORDER: t.ClassVar[t.Literal[14]] = 14

    _engine_import_validator = _get_engine_import_validator("pymysql", "mysql")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Names of the config fields that make up the connection keyword arguments.

        The mandatory trio is always present; each optional field is added only
        when it has been given a value other than None.
        """
        optional_fields = ("port", "database", "charset", "collation", "ssl_disabled")
        keys = {"host", "user", "password"}
        keys.update(name for name in optional_fields if getattr(self, name) is not None)
        return keys

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """The engine adapter class used for this connection."""
        return engine_adapter.MySQLEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """Return ``pymysql.connect``, imported lazily on first use."""
        import pymysql

        return pymysql.connect

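MySQL connection configuration. Connections are created with pymysql's `connect`; the optional settings (`port`, `database`, `charset`, `collation`, `ssl_disabled`) are only included in the connection keys when they are set.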

host: str
user: str
password: str
port: Optional[int]
database: Optional[str]
charset: Optional[str]
collation: Optional[str]
ssl_disabled: Optional[bool]
concurrent_tasks: int
register_comments: bool
pre_ping: bool
type_: Literal['mysql']
DIALECT: ClassVar[Literal['mysql']] = 'mysql'
DISPLAY_NAME: ClassVar[Literal['MySQL']] = 'MySQL'
DISPLAY_ORDER: ClassVar[Literal[14]] = 14
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
ConnectionConfig
pretty_sql
schema_differ_overrides
catalog_type_overrides
shared_connection
is_forbidden_for_state_sync
connection_validator
create_engine_adapter
get_catalog
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class MSSQLConnectionConfig(ConnectionConfig):
class MSSQLConnectionConfig(ConnectionConfig):
    """
    Microsoft SQL Server connection configuration.

    Connections are opened with ``pymssql`` (the default) or ``pyodbc``, selected
    through ``driver``. For pyodbc, an ODBC connection string is assembled from
    the settings below.

    Args:
        host: Server host name. With pyodbc it is sent as ``SERVER=host,port``.
        user: Optional login name. With pyodbc it is sent as ``UID`` when set.
        password: Optional password. With pyodbc it is sent as ``PWD`` when set.
        database: Database name. With pyodbc it is sent as ``DATABASE`` when non-empty.
        timeout: Defaults to 0.
        login_timeout: Defaults to 60. With pyodbc it is sent as ``Connection Timeout``.
        charset: Defaults to "UTF-8".
        appname: Optional application name.
        port: Defaults to 1433.
        conn_properties: pymssql-specific; removed from the connection keys for pyodbc.
        autocommit: Defaults to False. With pyodbc it is passed to ``pyodbc.connect``.
        tds_version: pymssql-specific; removed from the connection keys for pyodbc.
        driver: Either "pymssql" or "pyodbc".
        driver_name: pyodbc only. ODBC driver name, e.g. "ODBC Driver 18 for SQL Server".
        trust_server_certificate: pyodbc only. Adds ``TrustServerCertificate=YES`` when truthy.
        encrypt: pyodbc only. Sent as ``Encrypt=YES`` or ``Encrypt=NO``.
        odbc_properties: pyodbc only. Extra ODBC connection string attributes; keys that
            duplicate the attributes built from the settings above are ignored.
    """

    host: str
    user: t.Optional[str] = None
    password: t.Optional[str] = None
    database: t.Optional[str] = ""
    timeout: t.Optional[int] = 0
    login_timeout: t.Optional[int] = 60
    charset: t.Optional[str] = "UTF-8"
    appname: t.Optional[str] = None
    port: t.Optional[int] = 1433
    conn_properties: t.Optional[t.Union[t.List[str], str]] = None
    autocommit: t.Optional[bool] = False
    tds_version: t.Optional[str] = None

    # Driver options
    driver: t.Literal["pymssql", "pyodbc"] = "pymssql"
    # PyODBC specific options
    driver_name: t.Optional[str] = None  # e.g. "ODBC Driver 18 for SQL Server"
    trust_server_certificate: t.Optional[bool] = None
    encrypt: t.Optional[bool] = None
    # Dictionary of arbitrary ODBC connection properties
    # See: https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute
    odbc_properties: t.Optional[t.Dict[str, t.Any]] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = True

    type_: t.Literal["mssql"] = Field(alias="type", default="mssql")
    DIALECT: t.ClassVar[t.Literal["tsql"]] = "tsql"
    DISPLAY_NAME: t.ClassVar[t.Literal["MSSQL"]] = "MSSQL"
    DISPLAY_ORDER: t.ClassVar[t.Literal[11]] = 11

    @model_validator(mode="before")
    @classmethod
    def _mssql_engine_import_validator(cls, data: t.Any) -> t.Any:
        """Check that the Python package for the selected ``driver`` can be imported.

        Non-dict input is passed through untouched.

        Raises:
            ValueError: If ``driver`` is neither "pymssql" nor "pyodbc".
        """
        if not isinstance(data, dict):
            return data

        driver = data.get("driver", "pymssql")

        # Define the mapping of driver to import module and extra name
        driver_configs = {"pymssql": ("pymssql", "mssql"), "pyodbc": ("pyodbc", "mssql-odbc")}

        if driver not in driver_configs:
            raise ValueError(f"Unsupported driver: {driver}")

        import_module, extra_name = driver_configs[driver]

        # Use _get_engine_import_validator with decorate=False to get the raw validation function
        # This avoids the __wrapped__ issue in Python 3.9
        validator_func = _get_engine_import_validator(
            import_module, driver, extra_name, decorate=False
        )

        # Call the raw validation function directly
        return validator_func(cls, data)

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Names of the config fields that make up the connection keyword arguments.

        For pyodbc the ODBC-only options are added and the pymssql-only options
        (``tds_version``, ``conn_properties``) are removed.
        """
        base_keys = {
            "host",
            "user",
            "password",
            "database",
            "timeout",
            "login_timeout",
            "charset",
            "appname",
            "port",
            "conn_properties",
            "autocommit",
            "tds_version",
        }

        if self.driver == "pyodbc":
            base_keys.update(
                {
                    "driver_name",
                    "trust_server_certificate",
                    "encrypt",
                    "odbc_properties",
                }
            )
            # Remove pymssql-specific parameters
            base_keys.discard("tds_version")
            base_keys.discard("conn_properties")

        return base_keys

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """The engine adapter class used for this connection."""
        return engine_adapter.MSSQLEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """Return the callable that opens a connection for the configured ``driver``.

        For pymssql this is ``pymssql.connect``. For pyodbc it is a wrapper that
        builds an ODBC connection string from its keyword arguments, connects, and
        registers an output converter for DATETIMEOFFSET values.
        """
        if self.driver == "pymssql":
            import pymssql

            return pymssql.connect

        import struct
        from datetime import datetime, timedelta, timezone

        import pyodbc

        # Attributes generated from dedicated config fields; the same keys coming
        # from odbc_properties are ignored so they cannot override them.
        reserved_keys = frozenset(
            {
                "driver",
                "server",
                "database",
                "uid",
                "pwd",
                "encrypt",
                "trustservercertificate",
                "connection timeout",
            }
        )

        def escape_odbc_value(value: t.Any) -> str:
            """Make a value safe to embed in an ODBC connection string.

            A value containing ``;``, ``{`` or ``}`` is wrapped in braces with any
            ``}`` doubled, so it cannot terminate the attribute early or add extra
            attributes. A value that is already wrapped in braces is returned
            unchanged so configs that pre-escaped their values keep working.
            """
            text = str(value)
            if text.startswith("{") and text.endswith("}"):
                return text
            if any(ch in text for ch in ";{}"):
                return "{" + text.replace("}", "}}") + "}"
            return text

        # Handle SQL type -155 (DATETIMEOFFSET) which is not yet supported by pyodbc
        # ref: https://github.com/mkleehammer/pyodbc/issues/134#issuecomment-281739794
        def handle_datetimeoffset(dto_value: t.Any) -> t.Any:
            # Unpack the DATETIMEOFFSET binary format:
            # Format: <6hI2h = (year, month, day, hour, minute, second, nanoseconds, tz_hour_offset, tz_minute_offset)
            tup = struct.unpack("<6hI2h", dto_value)
            return datetime(
                tup[0],
                tup[1],
                tup[2],
                tup[3],
                tup[4],
                tup[5],
                tup[6] // 1000,  # nanoseconds -> microseconds
                timezone(timedelta(hours=tup[7], minutes=tup[8])),
            )

        def connect(**kwargs: t.Any) -> t.Any:
            # Extract parameters for connection string
            host = kwargs.pop("host")
            port = kwargs.pop("port", 1433)
            database = kwargs.pop("database", "")
            user = kwargs.pop("user", None)
            password = kwargs.pop("password", None)
            driver_name = kwargs.pop("driver_name", "ODBC Driver 18 for SQL Server")
            trust_server_certificate = kwargs.pop("trust_server_certificate", False)
            encrypt = kwargs.pop("encrypt", True)
            login_timeout = kwargs.pop("login_timeout", 60)

            # Build connection string
            conn_str_parts = [
                f"DRIVER={{{driver_name}}}",
                f"SERVER={host},{port}",
            ]

            if database:
                conn_str_parts.append(f"DATABASE={escape_odbc_value(database)}")

            # Add security options
            conn_str_parts.append(f"Encrypt={'YES' if encrypt else 'NO'}")
            if trust_server_certificate:
                conn_str_parts.append("TrustServerCertificate=YES")

            conn_str_parts.append(f"Connection Timeout={login_timeout}")

            # Standard SQL Server authentication
            if user:
                conn_str_parts.append(f"UID={escape_odbc_value(user)}")
            if password:
                conn_str_parts.append(f"PWD={escape_odbc_value(password)}")

            # Add any additional ODBC properties from the odbc_properties dictionary
            for key, value in (self.odbc_properties or {}).items():
                # Skip properties that we've already set above
                if key.lower() in reserved_keys:
                    continue

                # Handle boolean values properly
                if isinstance(value, bool):
                    conn_str_parts.append(f"{key}={'YES' if value else 'NO'}")
                else:
                    conn_str_parts.append(f"{key}={escape_odbc_value(value)}")

            # Create the connection string
            conn_str = ";".join(conn_str_parts)

            conn = pyodbc.connect(conn_str, autocommit=kwargs.get("autocommit", False))

            # Set up output converters for MSSQL-specific data types
            conn.add_output_converter(-155, handle_datetimeoffset)

            return conn

        return connect

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        """Extra engine configuration: catalog support requires setting the catalog."""
        return {"catalog_support": CatalogSupport.REQUIRES_SET_CATALOG}

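Microsoft SQL Server connection configuration. Connects through `pymssql` (the default) or `pyodbc`, selected with `driver`; the pyodbc path builds an ODBC connection string from these settings.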

host: str
user: Optional[str]
password: Optional[str]
database: Optional[str]
timeout: Optional[int]
login_timeout: Optional[int]
charset: Optional[str]
appname: Optional[str]
port: Optional[int]
conn_properties: Union[List[str], str, NoneType]
autocommit: Optional[bool]
tds_version: Optional[str]
driver: Literal['pymssql', 'pyodbc']
driver_name: Optional[str]
trust_server_certificate: Optional[bool]
encrypt: Optional[bool]
odbc_properties: Optional[Dict[str, Any]]
concurrent_tasks: int
register_comments: bool
pre_ping: bool
type_: Literal['mssql']
DIALECT: ClassVar[Literal['tsql']] = 'tsql'
DISPLAY_NAME: ClassVar[Literal['MSSQL']] = 'MSSQL'
DISPLAY_ORDER: ClassVar[Literal[11]] = 11
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
ConnectionConfig
pretty_sql
schema_differ_overrides
catalog_type_overrides
shared_connection
is_forbidden_for_state_sync
connection_validator
create_engine_adapter
get_catalog
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class AzureSQLConnectionConfig(MSSQLConnectionConfig):
class AzureSQLConnectionConfig(MSSQLConnectionConfig):
    """
    Azure SQL Connection Configuration.

    Subclass of MSSQLConnectionConfig that overrides the `type` discriminator,
    the display metadata and the extra engine config.
    """

    type_: t.Literal["azuresql"] = Field(alias="type", default="azuresql")  # type: ignore
    DISPLAY_NAME: t.ClassVar[t.Literal["Azure SQL"]] = "Azure SQL"  # type: ignore
    DISPLAY_ORDER: t.ClassVar[t.Literal[10]] = 10  # type: ignore

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        # The only extra engine setting is the catalog support level.
        engine_config: t.Dict[str, t.Any] = {}
        engine_config["catalog_support"] = CatalogSupport.SINGLE_CATALOG_ONLY
        return engine_config

Helper class that provides a standard way to create an ABC using inheritance.

type_: Literal['azuresql']
DISPLAY_NAME: ClassVar[Literal['Azure SQL']] = 'Azure SQL'
DISPLAY_ORDER: ClassVar[Literal[10]] = 10
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
MSSQLConnectionConfig
host
user
password
database
timeout
login_timeout
charset
appname
port
conn_properties
autocommit
tds_version
driver
driver_name
trust_server_certificate
encrypt
odbc_properties
concurrent_tasks
register_comments
pre_ping
DIALECT
ConnectionConfig
pretty_sql
schema_differ_overrides
catalog_type_overrides
shared_connection
is_forbidden_for_state_sync
connection_validator
create_engine_adapter
get_catalog
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class FabricConnectionConfig(MSSQLConnectionConfig):
class FabricConnectionConfig(MSSQLConnectionConfig):
    """
    Fabric Connection Configuration.
    Inherits most settings from MSSQLConnectionConfig and sets the type to 'fabric'.
    It is recommended to use the 'pyodbc' driver for Fabric.
    """

    type_: t.Literal["fabric"] = Field(alias="type", default="fabric")  # type: ignore
    DIALECT: t.ClassVar[t.Literal["fabric"]] = "fabric"  # type: ignore
    DISPLAY_NAME: t.ClassVar[t.Literal["Fabric"]] = "Fabric"  # type: ignore
    DISPLAY_ORDER: t.ClassVar[t.Literal[17]] = 17  # type: ignore
    # Only the pyodbc driver is accepted by this config.
    driver: t.Literal["pyodbc"] = "pyodbc"
    workspace_id: str
    tenant_id: str
    autocommit: t.Optional[bool] = True

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        # The adapter module is imported lazily, only when this property is read.
        from sqlmesh.core.engine_adapter.fabric import FabricEngineAdapter

        return FabricEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        # Wrap the parent factory so that a connection can be opened against a specific catalog.
        parent_factory = super()._connection_factory

        def create_fabric_connection(
            target_catalog: t.Optional[str] = None, *args: t.Any, **kwargs: t.Any
        ) -> t.Callable:
            # A missing or empty target catalog falls back to the configured database.
            database = target_catalog if target_catalog else self.database
            kwargs["database"] = database
            return parent_factory(*args, **kwargs)

        return create_fabric_connection

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return dict(
            database=self.database,
            # more operations than not require a specific catalog to be already active
            # in particular, create/drop view, create/drop schema and querying information_schema
            catalog_support=CatalogSupport.REQUIRES_SET_CATALOG,
            workspace_id=self.workspace_id,
            tenant_id=self.tenant_id,
            user=self.user,
            password=self.password,
        )

Fabric Connection Configuration. Inherits most settings from MSSQLConnectionConfig and sets the type to 'fabric'. It is recommended to use the 'pyodbc' driver for Fabric.

type_: Literal['fabric']
DIALECT: ClassVar[Literal['fabric']] = 'fabric'
DISPLAY_NAME: ClassVar[Literal['Fabric']] = 'Fabric'
DISPLAY_ORDER: ClassVar[Literal[17]] = 17
driver: Literal['pyodbc']
workspace_id: str
tenant_id: str
autocommit: Optional[bool]
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
MSSQLConnectionConfig
host
user
password
database
timeout
login_timeout
charset
appname
port
conn_properties
tds_version
driver_name
trust_server_certificate
encrypt
odbc_properties
concurrent_tasks
register_comments
pre_ping
ConnectionConfig
pretty_sql
schema_differ_overrides
catalog_type_overrides
shared_connection
is_forbidden_for_state_sync
connection_validator
create_engine_adapter
get_catalog
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class SparkConnectionConfig(ConnectionConfig):
class SparkConnectionConfig(ConnectionConfig):
    """
    Vanilla Spark Connection Configuration. Use `DatabricksConnectionConfig` for Databricks.
    """

    config_dir: t.Optional[str] = None
    catalog: t.Optional[str] = None
    config: t.Dict[str, t.Any] = {}
    wap_enabled: bool = False

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: t.Literal[False] = False

    type_: t.Literal["spark"] = Field(alias="type", default="spark")
    DIALECT: t.ClassVar[t.Literal["spark"]] = "spark"
    DISPLAY_NAME: t.ClassVar[t.Literal["Spark"]] = "Spark"
    DISPLAY_ORDER: t.ClassVar[t.Literal[8]] = 8

    _engine_import_validator = _get_engine_import_validator("pyspark", "spark")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        # `catalog` is the only field listed as a connection kwarg key.
        return {"catalog"}

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.SparkEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        from sqlmesh.engines.spark.db_api.spark_session import connection

        return connection

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        # pyspark is imported lazily, only when the connection kwargs are built.
        from pyspark.conf import SparkConf
        from pyspark.sql import SparkSession

        conf = SparkConf()
        # Copy every user-supplied option onto the SparkConf (no-op for an empty mapping).
        for option, option_value in (self.config or {}).items():
            conf.set(option, option_value)

        # The config directory is exported through the environment before the session is built.
        if self.config_dir:
            os.environ["SPARK_CONF_DIR"] = self.config_dir

        session_builder = SparkSession.builder.config(conf=conf).enableHiveSupport()
        return {"spark": session_builder.getOrCreate()}

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return dict(wap_enabled=self.wap_enabled)

Vanilla Spark Connection Configuration. Use DatabricksConnectionConfig for Databricks.

config_dir: Optional[str]
catalog: Optional[str]
config: Dict[str, Any]
wap_enabled: bool
concurrent_tasks: int
register_comments: bool
pre_ping: Literal[False]
type_: Literal['spark']
DIALECT: ClassVar[Literal['spark']] = 'spark'
DISPLAY_NAME: ClassVar[Literal['Spark']] = 'Spark'
DISPLAY_ORDER: ClassVar[Literal[8]] = 8
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
ConnectionConfig
pretty_sql
schema_differ_overrides
catalog_type_overrides
shared_connection
is_forbidden_for_state_sync
connection_validator
create_engine_adapter
get_catalog
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class TrinoAuthenticationMethod(builtins.str, enum.Enum):
class TrinoAuthenticationMethod(str, Enum):
    """
    Authentication methods selectable for a Trino connection.

    Each member's value is the string used to select it; every `is_*` property
    is True only for the matching member.
    """

    NO_AUTH = "no-auth"
    BASIC = "basic"
    LDAP = "ldap"
    KERBEROS = "kerberos"
    JWT = "jwt"
    CERTIFICATE = "certificate"
    OAUTH = "oauth"

    @property
    def is_no_auth(self) -> bool:
        return self == TrinoAuthenticationMethod.NO_AUTH

    @property
    def is_basic(self) -> bool:
        return self == TrinoAuthenticationMethod.BASIC

    @property
    def is_ldap(self) -> bool:
        return self == TrinoAuthenticationMethod.LDAP

    @property
    def is_kerberos(self) -> bool:
        return self == TrinoAuthenticationMethod.KERBEROS

    @property
    def is_jwt(self) -> bool:
        return self == TrinoAuthenticationMethod.JWT

    @property
    def is_certificate(self) -> bool:
        return self == TrinoAuthenticationMethod.CERTIFICATE

    @property
    def is_oauth(self) -> bool:
        return self == TrinoAuthenticationMethod.OAUTH

An enumeration.

NO_AUTH = <TrinoAuthenticationMethod.NO_AUTH: 'no-auth'>
BASIC = <TrinoAuthenticationMethod.BASIC: 'basic'>
LDAP = <TrinoAuthenticationMethod.LDAP: 'ldap'>
KERBEROS = <TrinoAuthenticationMethod.KERBEROS: 'kerberos'>
JWT = <TrinoAuthenticationMethod.JWT: 'jwt'>
CERTIFICATE = <TrinoAuthenticationMethod.CERTIFICATE: 'certificate'>
OAUTH = <TrinoAuthenticationMethod.OAUTH: 'oauth'>
is_no_auth: bool
1831    @property
1832    def is_no_auth(self) -> bool:
1833        return self == self.NO_AUTH
is_basic: bool
1835    @property
1836    def is_basic(self) -> bool:
1837        return self == self.BASIC
is_ldap: bool
1839    @property
1840    def is_ldap(self) -> bool:
1841        return self == self.LDAP
is_kerberos: bool
1843    @property
1844    def is_kerberos(self) -> bool:
1845        return self == self.KERBEROS
is_jwt: bool
1847    @property
1848    def is_jwt(self) -> bool:
1849        return self == self.JWT
is_certificate: bool
1851    @property
1852    def is_certificate(self) -> bool:
1853        return self == self.CERTIFICATE
is_oauth: bool
1855    @property
1856    def is_oauth(self) -> bool:
1857        return self == self.OAUTH
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class TrinoConnectionConfig(ConnectionConfig):
class TrinoConnectionConfig(ConnectionConfig):
    """
    Trino Connection Configuration.

    The `method` field selects the authentication mechanism. `_root_validator` checks
    that the fields required by the selected method are set, and fills in a default
    port (80 for http, 443 for https) when none is given.
    """

    method: TrinoAuthenticationMethod = TrinoAuthenticationMethod.NO_AUTH
    host: str
    user: str
    catalog: str
    port: t.Optional[int] = None
    http_scheme: t.Literal["http", "https"] = "https"
    # General Optional
    roles: t.Optional[t.Dict[str, str]] = None
    http_headers: t.Optional[t.Dict[str, str]] = None
    session_properties: t.Optional[t.Dict[str, str]] = None
    retries: int = 3
    timezone: t.Optional[str] = None
    # Basic/LDAP
    password: t.Optional[str] = None
    verify: t.Optional[bool] = None  # disable SSL verification (ignored if `cert` is provided)
    # LDAP
    impersonation_user: t.Optional[str] = None
    # Kerberos
    keytab: t.Optional[str] = None
    krb5_config: t.Optional[str] = None
    principal: t.Optional[str] = None
    service_name: str = "trino"
    hostname_override: t.Optional[str] = None
    mutual_authentication: bool = False
    force_preemptive: bool = False
    sanitize_mutual_error_response: bool = True
    delegate: bool = False
    # JWT
    jwt_token: t.Optional[str] = None
    # Certificate
    client_certificate: t.Optional[str] = None
    client_private_key: t.Optional[str] = None
    cert: t.Optional[str] = None
    source: str = "sqlmesh"

    # SQLMesh options
    schema_location_mapping: t.Optional[dict[re.Pattern, str]] = None
    timestamp_mapping: t.Optional[dict[exp.DataType, exp.DataType]] = None
    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: t.Literal[False] = False

    type_: t.Literal["trino"] = Field(alias="type", default="trino")
    DIALECT: t.ClassVar[t.Literal["trino"]] = "trino"
    DISPLAY_NAME: t.ClassVar[t.Literal["Trino"]] = "Trino"
    DISPLAY_ORDER: t.ClassVar[t.Literal[9]] = 9

    _engine_import_validator = _get_engine_import_validator("trino", "trino")

    @field_validator("schema_location_mapping", mode="before")
    @classmethod
    def _validate_regex_keys(
        cls, value: t.Optional[t.Dict[str | re.Pattern, str]]
    ) -> t.Optional[t.Dict[re.Pattern, t.Any]]:
        """
        Compile the keys of `schema_location_mapping` and check each replacement value.

        Raises:
            ConfigError: if a replacement value lacks the '@{schema_name}' placeholder.
        """
        # An explicitly supplied `None` means "no mapping": pass it through untouched,
        # as `_validate_timestamp_mapping` does, rather than trying to compile it.
        if value is None:
            return value

        compiled = compile_regex_mapping(value)
        for replacement in compiled.values():
            if "@{schema_name}" not in replacement:
                raise ConfigError(
                    "schema_location_mapping needs to include the '@{schema_name}' placeholder in the value so SQLMesh knows where to substitute the schema name"
                )
        return compiled

    @field_validator("timestamp_mapping", mode="before")
    @classmethod
    def _validate_timestamp_mapping(
        cls, value: t.Optional[dict[str, str]]
    ) -> t.Optional[dict[exp.DataType, exp.DataType]]:
        """
        Convert a mapping of SQL type strings into a mapping of `exp.DataType` objects.

        Raises:
            ConfigError: if a source or target type string cannot be parsed.
        """
        if value is None:
            return value

        result: dict[exp.DataType, exp.DataType] = {}
        for source_type, target_type in value.items():
            try:
                source_datatype = exp.DataType.build(source_type)
            except ParseError:
                raise ConfigError(
                    f"Invalid SQL type string in timestamp_mapping: "
                    f"'{source_type}' is not a valid SQL data type."
                )
            try:
                target_datatype = exp.DataType.build(target_type)
            except ParseError:
                raise ConfigError(
                    f"Invalid SQL type string in timestamp_mapping: "
                    f"'{target_type}' is not a valid SQL data type."
                )
            result[source_datatype] = target_datatype

        return result

    @model_validator(mode="after")
    def _root_validator(self) -> Self:
        """
        Cross-field validation: default the port and enforce per-method requirements.

        Raises:
            ConfigError: if the http scheme is combined with a method other than
                no-auth/basic, or a field required by the selected method is missing.
        """
        port = self.port
        if self.http_scheme == "http" and not self.method.is_no_auth and not self.method.is_basic:
            raise ConfigError("HTTP scheme can only be used with no-auth or basic method")

        # Default port follows the scheme: 80 for http, 443 for https.
        if port is None:
            self.port = 80 if self.http_scheme == "http" else 443

        if (self.method.is_ldap or self.method.is_basic) and (not self.password or not self.user):
            raise ConfigError(
                f"Username and Password must be provided if using {self.method.value} authentication"
            )

        if self.method.is_kerberos and (
            not self.principal or not self.keytab or not self.krb5_config
        ):
            raise ConfigError(
                "Kerberos requires the following fields: principal, keytab, and krb5_config"
            )

        if self.method.is_jwt and not self.jwt_token:
            raise ConfigError("JWT requires `jwt_token` to be set")

        if self.method.is_certificate and (
            not self.cert or not self.client_certificate or not self.client_private_key
        ):
            raise ConfigError(
                "Certificate requires the following fields: cert, client_certificate, and client_private_key"
            )

        return self

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {
            "host",
            "port",
            "catalog",
            "roles",
            "source",
            "http_scheme",
            "http_headers",
            "session_properties",
            "timezone",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.TrinoEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        from trino.dbapi import connect

        return connect

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """Build the auth object for the selected method plus the fixed connection kwargs."""
        from trino.auth import (
            BasicAuthentication,
            CertificateAuthentication,
            JWTAuthentication,
            KerberosAuthentication,
            OAuth2Authentication,
        )

        # `auth` stays None when no branch below matches (e.g. the no-auth method).
        auth: t.Optional[
            t.Union[
                BasicAuthentication,
                KerberosAuthentication,
                OAuth2Authentication,
                JWTAuthentication,
                CertificateAuthentication,
            ]
        ] = None
        if self.method.is_basic or self.method.is_ldap:
            assert self.password is not None  # for mypy since validator already checks this
            auth = BasicAuthentication(self.user, self.password)
        elif self.method.is_kerberos:
            # The keytab path is exported through the process environment.
            if self.keytab:
                os.environ["KRB5_CLIENT_KTNAME"] = self.keytab
            auth = KerberosAuthentication(
                config=self.krb5_config,
                service_name=self.service_name,
                principal=self.principal,
                mutual_authentication=self.mutual_authentication,
                ca_bundle=self.cert,
                force_preemptive=self.force_preemptive,
                hostname_override=self.hostname_override,
                sanitize_mutual_error_response=self.sanitize_mutual_error_response,
                delegate=self.delegate,
            )
        elif self.method.is_oauth:
            auth = OAuth2Authentication()
        elif self.method.is_jwt:
            assert self.jwt_token is not None
            auth = JWTAuthentication(self.jwt_token)
        elif self.method.is_certificate:
            assert self.client_certificate is not None
            assert self.client_private_key is not None
            auth = CertificateAuthentication(self.client_certificate, self.client_private_key)

        return {
            "auth": auth,
            # The impersonation user, when set, takes precedence over `user`.
            "user": self.impersonation_user or self.user,
            "max_attempts": self.retries,
            # A configured `cert` takes precedence over the boolean `verify` flag.
            "verify": self.cert if self.cert is not None else self.verify,
            "source": self.source,
        }

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return {
            "schema_location_mapping": self.schema_location_mapping,
            "timestamp_mapping": self.timestamp_mapping,
        }

Helper class that provides a standard way to create an ABC using inheritance.

host: str
user: str
catalog: str
port: Optional[int]
http_scheme: Literal['http', 'https']
roles: Optional[Dict[str, str]]
http_headers: Optional[Dict[str, str]]
session_properties: Optional[Dict[str, str]]
retries: int
timezone: Optional[str]
password: Optional[str]
verify: Optional[bool]
impersonation_user: Optional[str]
keytab: Optional[str]
krb5_config: Optional[str]
principal: Optional[str]
service_name: str
hostname_override: Optional[str]
mutual_authentication: bool
force_preemptive: bool
sanitize_mutual_error_response: bool
delegate: bool
jwt_token: Optional[str]
client_certificate: Optional[str]
client_private_key: Optional[str]
cert: Optional[str]
source: str
schema_location_mapping: Optional[dict[re.Pattern, str]]
timestamp_mapping: Optional[dict[sqlglot.expressions.datatypes.DataType, sqlglot.expressions.datatypes.DataType]]
concurrent_tasks: int
register_comments: bool
pre_ping: Literal[False]
type_: Literal['trino']
DIALECT: ClassVar[Literal['trino']] = 'trino'
DISPLAY_NAME: ClassVar[Literal['Trino']] = 'Trino'
DISPLAY_ORDER: ClassVar[Literal[9]] = 9
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
ConnectionConfig
pretty_sql
schema_differ_overrides
catalog_type_overrides
shared_connection
is_forbidden_for_state_sync
connection_validator
create_engine_adapter
get_catalog
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class ClickhouseConnectionConfig(ConnectionConfig):
class ClickhouseConnectionConfig(ConnectionConfig):
    """
    Clickhouse Connection Configuration.

    Property reference: https://clickhouse.com/docs/en/integrations/python#client-initialization
    """

    host: str
    username: str
    password: t.Optional[str] = None
    port: t.Optional[int] = None
    cluster: t.Optional[str] = None
    connect_timeout: int = 10
    send_receive_timeout: int = 300
    query_limit: int = 0
    use_compression: bool = True
    compression_method: t.Optional[str] = None
    connection_settings: t.Optional[t.Dict[str, t.Any]] = None
    http_proxy: t.Optional[str] = None
    # HTTPS/TLS settings
    verify: bool = True
    ca_cert: t.Optional[str] = None
    client_cert: t.Optional[str] = None
    client_cert_key: t.Optional[str] = None
    https_proxy: t.Optional[str] = None
    server_host_name: t.Optional[str] = None
    tls_mode: t.Optional[str] = None

    concurrent_tasks: int = 1
    register_comments: bool = True
    pre_ping: bool = False

    # This object expects options from urllib3 and also from clickhouse-connect
    # See:
    # * https://urllib3.readthedocs.io/en/stable/advanced-usage.html
    # * https://clickhouse.com/docs/en/integrations/python#customizing-the-http-connection-pool
    connection_pool_options: t.Optional[t.Dict[str, t.Any]] = None

    type_: t.Literal["clickhouse"] = Field(alias="type", default="clickhouse")
    DIALECT: t.ClassVar[t.Literal["clickhouse"]] = "clickhouse"
    DISPLAY_NAME: t.ClassVar[t.Literal["ClickHouse"]] = "ClickHouse"
    DISPLAY_ORDER: t.ClassVar[t.Literal[6]] = 6

    _engine_import_validator = _get_engine_import_validator("clickhouse_connect", "clickhouse")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        return {
            "host",
            "username",
            "port",
            "password",
            "connect_timeout",
            "send_receive_timeout",
            "query_limit",
            "http_proxy",
            "verify",
            "ca_cert",
            "client_cert",
            "client_cert_key",
            "https_proxy",
            "server_host_name",
            "tls_mode",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.ClickhouseEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """Return `connect` pre-bound to a pool manager built from this config."""
        from clickhouse_connect.dbapi import connect  # type: ignore
        from clickhouse_connect.driver import httputil  # type: ignore
        from functools import partial

        pool_manager_options: t.Dict[str, t.Any] = dict(
            # Match the maxsize to the number of concurrent tasks
            maxsize=self.concurrent_tasks,
            # Block if there are no free connections
            block=True,
            verify=self.verify,
            ca_cert=self.ca_cert,
            client_cert=self.client_cert,
            client_cert_key=self.client_cert_key,
            https_proxy=self.https_proxy,
        )
        # this doesn't happen automatically because we always supply our own pool manager to the connection
        # https://github.com/ClickHouse/clickhouse-connect/blob/3a7f4b04cad29c7c2536661b831fb744248e2ec0/clickhouse_connect/driver/httpclient.py#L109
        if self.server_host_name:
            pool_manager_options["server_hostname"] = self.server_host_name
            if self.verify:
                pool_manager_options["assert_hostname"] = self.server_host_name
        # User-supplied pool options are applied last, so they override the values above.
        if self.connection_pool_options:
            pool_manager_options.update(self.connection_pool_options)
        pool_mgr = httputil.get_pool_manager(**pool_manager_options)

        return partial(connect, pool_mgr=pool_mgr)

    @property
    def cloud_mode(self) -> bool:
        # Cloud mode is detected purely from the host name.
        return "clickhouse.cloud" in self.host

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        return {"cluster": self.cluster, "cloud_mode": self.cloud_mode}

    @property
    def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
        """Build the compression option, client name and Clickhouse settings for the connection."""
        from sqlmesh import __version__

        # False = no compression
        # True = Clickhouse default compression method
        # string = specific compression method
        compress: bool | str = self.use_compression
        if compress and self.compression_method:
            compress = self.compression_method

        # Clickhouse system settings passed to connection
        # https://clickhouse.com/docs/en/operations/settings/settings
        # - below are set to align with dbt-clickhouse
        # - https://github.com/ClickHouse/dbt-clickhouse/blob/44d26308ea6a3c8ead25c280164aa88191f05f47/dbt/adapters/clickhouse/dbclient.py#L77
        #
        # Work on a copy: the keys written below must not leak into the user-supplied
        # `connection_settings` dict stored on this config object.
        settings = dict(self.connection_settings or {})
        #  mutations_sync = 2: "The query waits for all mutations [ALTER statements] to complete on all replicas (if they exist)"
        settings["mutations_sync"] = "2"
        #  insert_distributed_sync = 1: "INSERT operation succeeds only after all the data is saved on all shards"
        settings["insert_distributed_sync"] = "1"
        if self.cluster or self.cloud_mode:
            # database_replicated_enforce_synchronous_settings = 1:
            #   - "Enforces synchronous waiting for some queries"
            #   - https://github.com/ClickHouse/ClickHouse/blob/ccaa8d03a9351efc16625340268b9caffa8a22ba/src/Core/Settings.h#L709
            settings["database_replicated_enforce_synchronous_settings"] = "1"
            # insert_quorum = auto:
            #   - "INSERT succeeds only when ClickHouse manages to correctly write data to the insert_quorum of replicas during
            #       the insert_quorum_timeout"
            #   - "use majority number (number_of_replicas / 2 + 1) as quorum number"
            settings["insert_quorum"] = "auto"

        return {
            "compress": compress,
            "client_name": f"SQLMesh/{__version__}",
            **settings,
        }

Clickhouse Connection Configuration.

Property reference: https://clickhouse.com/docs/en/integrations/python#client-initialization

host: str
username: str
password: Optional[str]
port: Optional[int]
cluster: Optional[str]
connect_timeout: int
send_receive_timeout: int
query_limit: int
use_compression: bool
compression_method: Optional[str]
connection_settings: Optional[Dict[str, Any]]
http_proxy: Optional[str]
verify: bool
ca_cert: Optional[str]
client_cert: Optional[str]
client_cert_key: Optional[str]
https_proxy: Optional[str]
server_host_name: Optional[str]
tls_mode: Optional[str]
concurrent_tasks: int
register_comments: bool
pre_ping: bool
connection_pool_options: Optional[Dict[str, Any]]
type_: Literal['clickhouse']
DIALECT: ClassVar[Literal['clickhouse']] = 'clickhouse'
DISPLAY_NAME: ClassVar[Literal['ClickHouse']] = 'ClickHouse'
DISPLAY_ORDER: ClassVar[Literal[6]] = 6
cloud_mode: bool
2170    @property
2171    def cloud_mode(self) -> bool:
2172        return "clickhouse.cloud" in self.host
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
ConnectionConfig
pretty_sql
schema_differ_overrides
catalog_type_overrides
shared_connection
is_forbidden_for_state_sync
connection_validator
create_engine_adapter
get_catalog
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class AthenaConnectionConfig(ConnectionConfig):
class AthenaConnectionConfig(ConnectionConfig):
    """Connection configuration for the Athena engine.

    The fields listed in `_connection_kwargs_keys` are the PyAthena connection
    options; `s3_warehouse_location` is exposed to the engine adapter through
    `_extra_engine_config`. At least one of `work_group` or `s3_staging_dir`
    must be provided, otherwise validation raises a ConfigError.
    """

    # PyAthena connection options
    aws_access_key_id: t.Optional[str] = None
    aws_secret_access_key: t.Optional[str] = None
    role_arn: t.Optional[str] = None
    role_session_name: t.Optional[str] = None
    region_name: t.Optional[str] = None
    work_group: t.Optional[str] = None
    s3_staging_dir: t.Optional[str] = None
    schema_name: t.Optional[str] = None
    catalog_name: t.Optional[str] = None

    # SQLMesh options
    s3_warehouse_location: t.Optional[str] = None
    concurrent_tasks: int = 4
    # Pinned to False because Athena doesn't support comments in most cases.
    register_comments: t.Literal[False] = False
    pre_ping: t.Literal[False] = False

    type_: t.Literal["athena"] = Field(alias="type", default="athena")
    DIALECT: t.ClassVar[t.Literal["athena"]] = "athena"
    DISPLAY_NAME: t.ClassVar[t.Literal["Athena"]] = "Athena"
    DISPLAY_ORDER: t.ClassVar[t.Literal[15]] = 15

    _engine_import_validator = _get_engine_import_validator("pyathena", "athena")

    @model_validator(mode="after")
    def _root_validator(self) -> Self:
        """Require a work group or staging dir and validate the S3 URI fields.

        Raises:
            ConfigError: If neither `work_group` nor `s3_staging_dir` is set.
        """
        if not (self.work_group or self.s3_staging_dir):
            raise ConfigError("At least one of work_group or s3_staging_dir must be set")

        # Each S3 location that is set is replaced by the value returned from
        # validate_s3_uri; unset (falsy) locations are left untouched.
        for s3_field in ("s3_staging_dir", "s3_warehouse_location"):
            uri = getattr(self, s3_field)
            if uri:
                setattr(self, s3_field, validate_s3_uri(uri, base=True, error_type=ConfigError))

        return self

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Names of the fields treated as connection keyword arguments."""
        return {
            "aws_access_key_id",
            "aws_secret_access_key",
            "role_arn",
            "role_session_name",
            "region_name",
            "work_group",
            "s3_staging_dir",
            "schema_name",
            "catalog_name",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """The engine adapter class used for Athena."""
        return engine_adapter.AthenaEngineAdapter

    @property
    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
        """Extra engine settings: only the S3 warehouse location."""
        return {"s3_warehouse_location": self.s3_warehouse_location}

    @property
    def _connection_factory(self) -> t.Callable:
        """PyAthena's `connect` callable, imported lazily on access."""
        import pyathena  # type: ignore

        return pyathena.connect

    def get_catalog(self) -> t.Optional[str]:
        """Return the configured `catalog_name` (None when unset)."""
        return self.catalog_name

Helper class that provides a standard way to create an ABC using inheritance.

aws_access_key_id: Optional[str]
aws_secret_access_key: Optional[str]
role_arn: Optional[str]
role_session_name: Optional[str]
region_name: Optional[str]
work_group: Optional[str]
s3_staging_dir: Optional[str]
schema_name: Optional[str]
catalog_name: Optional[str]
s3_warehouse_location: Optional[str]
concurrent_tasks: int
register_comments: Literal[False]
pre_ping: Literal[False]
type_: Literal['athena']
DIALECT: ClassVar[Literal['athena']] = 'athena'
DISPLAY_NAME: ClassVar[Literal['Athena']] = 'Athena'
DISPLAY_ORDER: ClassVar[Literal[15]] = 15
def get_catalog(self) -> Optional[str]:
2290    def get_catalog(self) -> t.Optional[str]:
2291        return self.catalog_name

The catalog for this connection

model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model; should be a dictionary conforming to `pydantic.config.ConfigDict`.

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
ConnectionConfig
pretty_sql
schema_differ_overrides
catalog_type_overrides
shared_connection
is_forbidden_for_state_sync
connection_validator
create_engine_adapter
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class RisingwaveConnectionConfig(ConnectionConfig):
class RisingwaveConnectionConfig(ConnectionConfig):
    """Connection configuration for the RisingWave engine.

    Connections are created with psycopg2's `connect`, and every cursor is
    initialized by executing `SET RW_IMPLICIT_FLUSH TO true;`.
    """

    host: str
    user: str
    password: t.Optional[str] = None
    port: int
    database: str
    role: t.Optional[str] = None
    sslmode: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = True

    type_: t.Literal["risingwave"] = Field(alias="type", default="risingwave")
    DIALECT: t.ClassVar[t.Literal["risingwave"]] = "risingwave"
    DISPLAY_NAME: t.ClassVar[t.Literal["RisingWave"]] = "RisingWave"
    DISPLAY_ORDER: t.ClassVar[t.Literal[16]] = 16

    _engine_import_validator = _get_engine_import_validator("psycopg2", "risingwave")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Names of the fields treated as connection keyword arguments."""
        return {"host", "user", "password", "port", "database", "role", "sslmode"}

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """The engine adapter class used for RisingWave."""
        return engine_adapter.RisingwaveEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """psycopg2's `connect` callable, imported lazily on access."""
        import psycopg2

        return psycopg2.connect

    @property
    def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]:
        """Callback that turns on implicit flush for a newly created cursor."""

        def _enable_implicit_flush(cursor: t.Any) -> None:
            cursor.execute("SET RW_IMPLICIT_FLUSH TO true;")

        return _enable_implicit_flush

Helper class that provides a standard way to create an ABC using inheritance.

host: str
user: str
password: Optional[str]
port: int
database: str
role: Optional[str]
sslmode: Optional[str]
concurrent_tasks: int
register_comments: bool
pre_ping: bool
type_: Literal['risingwave']
DIALECT: ClassVar[Literal['risingwave']] = 'risingwave'
DISPLAY_NAME: ClassVar[Literal['RisingWave']] = 'RisingWave'
DISPLAY_ORDER: ClassVar[Literal[16]] = 16
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model; should be a dictionary conforming to `pydantic.config.ConfigDict`.

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
ConnectionConfig
pretty_sql
schema_differ_overrides
catalog_type_overrides
shared_connection
is_forbidden_for_state_sync
connection_validator
create_engine_adapter
get_catalog
sqlmesh.core.config.base.BaseConfig
update_with
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
CONNECTION_CONFIG_TO_TYPE = {'athena': <class 'AthenaConnectionConfig'>, 'azuresql': <class 'AzureSQLConnectionConfig'>, 'bigquery': <class 'BigQueryConnectionConfig'>, 'clickhouse': <class 'ClickhouseConnectionConfig'>, 'databricks': <class 'DatabricksConnectionConfig'>, 'duckdb': <class 'DuckDBConnectionConfig'>, 'fabric': <class 'FabricConnectionConfig'>, 'gcp_postgres': <class 'GCPPostgresConnectionConfig'>, 'mssql': <class 'MSSQLConnectionConfig'>, 'motherduck': <class 'MotherDuckConnectionConfig'>, 'mysql': <class 'MySQLConnectionConfig'>, 'postgres': <class 'PostgresConnectionConfig'>, 'redshift': <class 'RedshiftConnectionConfig'>, 'risingwave': <class 'RisingwaveConnectionConfig'>, 'snowflake': <class 'SnowflakeConnectionConfig'>, 'spark': <class 'SparkConnectionConfig'>, 'trino': <class 'TrinoConnectionConfig'>}
DIALECT_TO_TYPE = {'athena': 'athena', 'azuresql': 'tsql', 'bigquery': 'bigquery', 'clickhouse': 'clickhouse', 'databricks': 'databricks', 'duckdb': 'duckdb', 'fabric': 'fabric', 'gcp_postgres': 'postgres', 'mssql': 'tsql', 'motherduck': 'duckdb', 'mysql': 'mysql', 'postgres': 'postgres', 'redshift': 'redshift', 'risingwave': 'risingwave', 'snowflake': 'snowflake', 'spark': 'spark', 'trino': 'trino'}
INIT_DISPLAY_INFO_TO_TYPE = {'athena': (15, 'Athena'), 'azuresql': (10, 'Azure SQL'), 'bigquery': (4, 'BigQuery'), 'clickhouse': (6, 'ClickHouse'), 'databricks': (3, 'Databricks'), 'duckdb': (1, 'DuckDB'), 'fabric': (17, 'Fabric'), 'gcp_postgres': (13, 'GCP Postgres'), 'mssql': (11, 'MSSQL'), 'motherduck': (5, 'MotherDuck'), 'mysql': (14, 'MySQL'), 'postgres': (12, 'Postgres'), 'redshift': (7, 'Redshift'), 'risingwave': (16, 'RisingWave'), 'snowflake': (2, 'Snowflake'), 'spark': (8, 'Spark'), 'trino': (9, 'Trino')}
def parse_connection_config(v: Dict[str, Any]) -> ConnectionConfig:
def parse_connection_config(v: t.Dict[str, t.Any]) -> ConnectionConfig:
    """Build the connection config selected by the "type" entry of `v`.

    Args:
        v: Raw connection settings; must contain a "type" key whose value is a
            key of CONNECTION_CONFIG_TO_TYPE.

    Returns:
        An instance of the class registered for that type, constructed with
        all entries of `v` as keyword arguments.

    Raises:
        ConfigError: If "type" is missing or is not a registered connection type.
    """
    if "type" not in v:
        raise ConfigError("Missing connection type.")

    connection_type = v["type"]
    config_class = CONNECTION_CONFIG_TO_TYPE.get(connection_type)
    if config_class is None:
        raise ConfigError(f"Unknown connection type '{connection_type}'.")

    return config_class(**v)
def connection_config_validator( cls: Type, v: Union[ConnectionConfig, Dict[str, Any], NoneType]) -> ConnectionConfig | None:
2388def _connection_config_validator(
2389    cls: t.Type, v: ConnectionConfig | t.Dict[str, t.Any] | None
2390) -> ConnectionConfig | None:
2391    if v is None or isinstance(v, ConnectionConfig):
2392        return v
2393
2394    check_config_and_vars_msg = "\n\nVerify your config.yaml and environment variables."
2395
2396    try:
2397        return parse_connection_config(v)
2398    except pydantic.ValidationError as e:
2399        raise ConfigError(
2400            validation_error_message(e, f"Invalid '{v['type']}' connection config:")
2401            + check_config_and_vars_msg
2402        )
2403    except ConfigError as e:
2404        raise ConfigError(str(e) + check_config_and_vars_msg)

Wrap a classmethod, staticmethod, property or unbound function and act as a descriptor that allows us to detect decorated items from the class' attributes.

This class' __get__ returns the wrapped item's __get__ result, which makes it transparent for classmethods and staticmethods.

Attributes:
  • wrapped: The decorator that has to be wrapped.
  • decorator_info: The decorator info.
  • shim: A wrapper function to wrap V1 style function.
SerializableConnectionConfig = typing.Annotated[ConnectionConfig, SerializeAsAny()]