Edit on GitHub

sqlmesh.core.engine_adapter.redshift

  1from __future__ import annotations
  2
  3import logging
  4import typing as t
  5
  6from sqlglot import exp
  7
  8from sqlmesh.core.dialect import to_schema
  9from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS
 10from sqlmesh.core.engine_adapter.base_postgres import BasePostgresEngineAdapter
 11from sqlmesh.core.engine_adapter.mixins import (
 12    GetCurrentCatalogFromFunctionMixin,
 13    NonTransactionalTruncateMixin,
 14    VarcharSizeWorkaroundMixin,
 15    RowDiffMixin,
 16    logical_merge,
 17    GrantsFromInfoSchemaMixin,
 18)
 19from sqlmesh.core.engine_adapter.shared import (
 20    CommentCreationView,
 21    DataObject,
 22    DataObjectType,
 23    SourceQuery,
 24    set_catalog,
 25)
 26from sqlmesh.utils.errors import SQLMeshError
 27
 28if t.TYPE_CHECKING:
 29    import pandas as pd
 30
 31    from sqlmesh.core._typing import SchemaName, TableName
 32    from sqlmesh.core.engine_adapter.base import QueryOrDF, Query
 33
# Module-level logger, named after this module so its records can be filtered/configured independently.
logger = logging.getLogger(__name__)
 35
 36
 37@set_catalog()
 38class RedshiftEngineAdapter(
 39    BasePostgresEngineAdapter,
 40    GetCurrentCatalogFromFunctionMixin,
 41    NonTransactionalTruncateMixin,
 42    VarcharSizeWorkaroundMixin,
 43    RowDiffMixin,
 44    GrantsFromInfoSchemaMixin,
 45):
 46    DIALECT = "redshift"
 47    CURRENT_CATALOG_EXPRESSION = exp.func("current_database")
 48    # Redshift doesn't support comments for VIEWs WITH NO SCHEMA BINDING (which we always use)
 49    COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED
 50    SUPPORTS_REPLACE_TABLE = False
 51    SUPPORTS_GRANTS = True
 52    SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True
 53
 54    SCHEMA_DIFFER_KWARGS = {
 55        "parameterized_type_defaults": {
 56            exp.DataType.build("VARBYTE", dialect=DIALECT).this: [(64000,)],
 57            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 0), (0,)],
 58            exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)],
 59            exp.DataType.build("VARCHAR", dialect=DIALECT).this: [(256,)],
 60            exp.DataType.build("NCHAR", dialect=DIALECT).this: [(1,)],
 61            exp.DataType.build("NVARCHAR", dialect=DIALECT).this: [(256,)],
 62        },
 63        "max_parameter_length": {
 64            exp.DataType.build("CHAR", dialect=DIALECT).this: 4096,
 65            exp.DataType.build("VARCHAR", dialect=DIALECT).this: 65535,
 66        },
 67        "precision_increase_allowed_types": {exp.DataType.build("VARCHAR", dialect=DIALECT).this},
 68        "drop_cascade": True,
 69    }
 70    VARIABLE_LENGTH_DATA_TYPES = {
 71        "char",
 72        "character",
 73        "nchar",
 74        "varchar",
 75        "character varying",
 76        "nvarchar",
 77        "varbyte",
 78        "varbinary",
 79        "binary varying",
 80    }
 81
 82    def columns(
 83        self,
 84        table_name: TableName,
 85        include_pseudo_columns: bool = True,
 86    ) -> t.Dict[str, exp.DataType]:
 87        table = exp.to_table(table_name)
 88
 89        sql = (
 90            exp.select(
 91                "column_name",
 92                "data_type",
 93                "character_maximum_length",
 94                "numeric_precision",
 95                "numeric_scale",
 96            )
 97            .from_("svv_columns")  # Includes late-binding views
 98            .where(exp.column("table_name").eq(table.alias_or_name))
 99        )
100        if table.args.get("db"):
101            sql = sql.where(exp.column("table_schema").eq(table.args["db"].name))
102
103        columns_raw = self.fetchall(sql, quote_identifiers=True)
104
105        def build_var_length_col(
106            column_name: str,
107            data_type: str,
108            character_maximum_length: t.Optional[int] = None,
109            numeric_precision: t.Optional[int] = None,
110            numeric_scale: t.Optional[int] = None,
111        ) -> tuple:
112            data_type = data_type.lower()
113            if (
114                data_type in self.VARIABLE_LENGTH_DATA_TYPES
115                and character_maximum_length is not None
116            ):
117                return (column_name, f"{data_type}({character_maximum_length})")
118            if data_type in ("decimal", "numeric"):
119                return (column_name, f"{data_type}({numeric_precision}, {numeric_scale})")
120
121            return (column_name, data_type)
122
123        columns = [build_var_length_col(*row) for row in columns_raw]
124
125        return {
126            column_name: exp.DataType.build(data_type, dialect=self.dialect)
127            for column_name, data_type in columns
128        }
129
130    @property
131    def enable_merge(self) -> bool:
132        # Redshift supports the MERGE operation but we use the logical merge
133        # unless the user has opted in by setting enable_merge in the connection.
134        return bool(self._extra_config.get("enable_merge"))
135
136    @property
137    def cursor(self) -> t.Any:
138        # Redshift by default uses a `format` paramstyle that has issues when we try to write our snapshot
139        # data to snapshot table. There doesn't seem to be a way to disable parameter overriding so we just
140        # set it to `qmark` since that doesn't cause issues.
141        cursor = self._connection_pool.get_cursor()
142        cursor.paramstyle = "qmark"
143        return cursor
144
145    def _fetch_native_df(
146        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
147    ) -> pd.DataFrame:
148        """Fetches a Pandas DataFrame from the cursor"""
149        import pandas as pd
150
151        self.execute(query, quote_identifiers=quote_identifiers)
152
153        # We manually build the `DataFrame` here because the driver's `fetch_dataframe`
154        # method does not respect the active case-sensitivity configuration.
155        #
156        # Context: https://github.com/aws/amazon-redshift-python-driver/issues/238
157        fetcheddata = self.cursor.fetchall()
158
159        try:
160            columns = [column[0] for column in self.cursor.description]
161        except Exception:
162            columns = None
163            logging.warning(
164                "No row description was found, pandas dataframe will be missing column labels."
165            )
166
167        result = [tuple(row) for row in fetcheddata]
168        return pd.DataFrame(result, columns=columns)
169
170    def _create_table_from_source_queries(
171        self,
172        table_name: TableName,
173        source_queries: t.List[SourceQuery],
174        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
175        exists: bool = True,
176        replace: bool = False,
177        table_description: t.Optional[str] = None,
178        column_descriptions: t.Optional[t.Dict[str, str]] = None,
179        table_kind: t.Optional[str] = None,
180        track_rows_processed: bool = True,
181        **kwargs: t.Any,
182    ) -> None:
183        """
184        Redshift doesn't support `CREATE TABLE IF NOT EXISTS AS...` but does support `CREATE TABLE AS...` so
185        we check if the exists check exists and if not then we can use the base implementation. Otherwise we
186        manually check if it exists and if it does then this is a no-op anyways so we return and if it doesn't
187        then we run the query with exists set to False since we just confirmed it doesn't exist.
188        """
189        if not exists:
190            return super()._create_table_from_source_queries(
191                table_name,
192                source_queries,
193                target_columns_to_types,
194                exists,
195                table_description=table_description,
196                column_descriptions=column_descriptions,
197                **kwargs,
198            )
199        if self.table_exists(table_name):
200            return
201        super()._create_table_from_source_queries(
202            table_name,
203            source_queries,
204            exists=False,
205            table_description=table_description,
206            column_descriptions=column_descriptions,
207            **kwargs,
208        )
209
210    def create_view(
211        self,
212        view_name: TableName,
213        query_or_df: QueryOrDF,
214        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
215        replace: bool = True,
216        materialized: bool = False,
217        materialized_properties: t.Optional[t.Dict[str, t.Any]] = None,
218        table_description: t.Optional[str] = None,
219        column_descriptions: t.Optional[t.Dict[str, str]] = None,
220        view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
221        source_columns: t.Optional[t.List[str]] = None,
222        **create_kwargs: t.Any,
223    ) -> None:
224        """
225        Redshift views are "binding" by default to their underlying table which means you can't drop that
226        underlying table without dropping the view first. This is a problem for us since we want to be able to
227        swap tables out from under views. Therefore, we create the view as non-binding.
228        """
229        no_schema_binding = True
230        if isinstance(query_or_df, exp.Expr):
231            # We can't include NO SCHEMA BINDING if the query has a recursive CTE
232            has_recursive_cte = any(
233                w.args.get("recursive", False) for w in query_or_df.find_all(exp.With)
234            )
235            no_schema_binding = not has_recursive_cte
236
237        return super().create_view(
238            view_name,
239            query_or_df,
240            target_columns_to_types,
241            replace,
242            materialized,
243            materialized_properties,
244            table_description=table_description,
245            column_descriptions=column_descriptions,
246            no_schema_binding=no_schema_binding,
247            view_properties=view_properties,
248            source_columns=source_columns,
249            **create_kwargs,
250        )
251
252    def replace_query(
253        self,
254        table_name: TableName,
255        query_or_df: QueryOrDF,
256        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
257        table_description: t.Optional[str] = None,
258        column_descriptions: t.Optional[t.Dict[str, str]] = None,
259        source_columns: t.Optional[t.List[str]] = None,
260        supports_replace_table_override: t.Optional[bool] = None,
261        **kwargs: t.Any,
262    ) -> None:
263        """
264        Redshift doesn't support `CREATE OR REPLACE TABLE...` and it also doesn't support `VALUES` expression so we need to specially
265        handle DataFrame replacements.
266
267        If the table doesn't exist then we just create it and load it with insert statements
268        If it does exist then we need to do the:
269            `CREATE TABLE...`, `INSERT INTO...`, `RENAME TABLE...`, `RENAME TABLE...`, DROP TABLE...`  dance.
270        """
271        import pandas as pd
272
273        target_data_object = self.get_data_object(table_name)
274        table_exists = target_data_object is not None
275        if self.drop_data_object_on_type_mismatch(target_data_object, DataObjectType.TABLE):
276            table_exists = False
277
278        if not isinstance(query_or_df, pd.DataFrame) or not table_exists:
279            return super().replace_query(
280                table_name,
281                query_or_df,
282                target_columns_to_types,
283                table_description,
284                column_descriptions,
285                source_columns=source_columns,
286                **kwargs,
287            )
288        source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types(
289            query_or_df,
290            target_columns_to_types,
291            target_table=table_name,
292            source_columns=source_columns,
293        )
294        target_columns_to_types = target_columns_to_types or self.columns(table_name)
295        target_table = exp.to_table(table_name)
296        with self.transaction():
297            temp_table = self._get_temp_table(target_table)
298            old_table = self._get_temp_table(target_table)
299            self.create_table(
300                temp_table,
301                target_columns_to_types,
302                exists=False,
303                table_description=table_description,
304                column_descriptions=column_descriptions,
305                **kwargs,
306            )
307            self._insert_append_source_queries(temp_table, source_queries, target_columns_to_types)
308            self.rename_table(target_table, old_table)
309            self.rename_table(temp_table, target_table)
310            self.drop_table(old_table)
311
312    def _get_data_objects(
313        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
314    ) -> t.List[DataObject]:
315        """
316        Returns all the data objects that exist in the given schema and optionally catalog.
317        """
318        catalog = self.get_current_catalog()
319        table_query = exp.select(
320            exp.column("schemaname").as_("schema_name"),
321            exp.column("tablename").as_("name"),
322            exp.Literal.string("TABLE").as_("type"),
323        ).from_("pg_tables")
324        view_query = (
325            exp.select(
326                exp.column("schemaname").as_("schema_name"),
327                exp.column("viewname").as_("name"),
328                exp.Literal.string("VIEW").as_("type"),
329            )
330            .from_("pg_views")
331            .where(exp.column("definition").ilike("%create materialized view%").not_())
332        )
333        materialized_view_query = (
334            exp.select(
335                exp.column("schemaname").as_("schema_name"),
336                exp.column("viewname").as_("name"),
337                exp.Literal.string("MATERIALIZED_VIEW").as_("type"),
338            )
339            .from_("pg_views")
340            .where(exp.column("definition").ilike("%create materialized view%"))
341        )
342        subquery = exp.union(
343            table_query,
344            exp.union(view_query, materialized_view_query, distinct=False),
345            distinct=False,
346        )
347        query = (
348            exp.select("*")
349            .from_(subquery.subquery(alias="objs"))
350            .where(exp.column("schema_name").eq(to_schema(schema_name).db))
351        )
352        if object_names:
353            query = query.where(exp.column("name").isin(*object_names))
354        df = self.fetchdf(query)
355        return [
356            DataObject(
357                catalog=catalog,
358                schema=row.schema_name,
359                name=row.name,
360                type=DataObjectType.from_str(row.type),  # type: ignore
361            )
362            for row in df.itertuples()
363        ]
364
365    def merge(
366        self,
367        target_table: TableName,
368        source_table: QueryOrDF,
369        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
370        unique_key: t.Sequence[exp.Expr],
371        when_matched: t.Optional[exp.Whens] = None,
372        merge_filter: t.Optional[exp.Expr] = None,
373        source_columns: t.Optional[t.List[str]] = None,
374        **kwargs: t.Any,
375    ) -> None:
376        if self.enable_merge:
377            # By default we use the logical merge unless the user has opted in
378            super().merge(
379                target_table=target_table,
380                source_table=source_table,
381                target_columns_to_types=target_columns_to_types,
382                unique_key=unique_key,
383                when_matched=when_matched,
384                merge_filter=merge_filter,
385                source_columns=source_columns,
386            )
387        else:
388            logical_merge(
389                self,
390                target_table,
391                source_table,
392                target_columns_to_types,
393                unique_key,
394                when_matched=when_matched,
395                merge_filter=merge_filter,
396                source_columns=source_columns,
397            )
398
399    def _merge(
400        self,
401        target_table: TableName,
402        query: Query,
403        on: exp.Expr,
404        whens: exp.Whens,
405    ) -> None:
406        # Redshift does not support table aliases in the target table of a MERGE statement.
407        # So we must use the actual table name instead of an alias, as we do with the source table.
408        def resolve_target_table(expression: exp.Expr) -> exp.Expr:
409            if (
410                isinstance(expression, exp.Column)
411                and expression.table.upper() == MERGE_TARGET_ALIAS
412            ):
413                expression.set("table", exp.to_table(target_table))
414            return expression
415
416        # Ensure that there is exactly one "WHEN MATCHED" and one "WHEN NOT MATCHED" clause.
417        # Since Redshift does not support multiple "WHEN MATCHED" clauses.
418        if (
419            len(whens.expressions) != 2
420            or whens.expressions[0].args["matched"] == whens.expressions[1].args["matched"]
421        ):
422            raise SQLMeshError(
423                "Redshift only supports a single WHEN MATCHED and WHEN NOT MATCHED clause"
424            )
425
426        using = exp.alias_(
427            exp.Subquery(this=query), alias=MERGE_SOURCE_ALIAS, copy=False, table=True
428        )
429        self.execute(
430            exp.Merge(
431                this=target_table,
432                using=using,
433                on=on.transform(resolve_target_table),
434                whens=whens.transform(resolve_target_table),
435            ),
436            track_rows_processed=True,
437        )
438
439    def _normalize_decimal_value(self, expr: exp.Expr, precision: int) -> exp.Expr:
440        # Redshift is finicky. It truncates when the data is already in a table, but rounds when the data is generated as part of a SELECT.
441        #
442        # The following works:
443        #  > select cast(cast(3.14159 as decimal(6, 5)) as decimal(6, 3)); --produces '3.142', the value we want / what every other database produces
444        #
445        # However, if you write that to a table, and then cast it to a less precise decimal, you get _truncation_.
446        #  > create table foo (val decimal(6, 5)); insert into foo(val) values (3.14159);
447        #  > select cast(val as decimal(6, 3)) from foo; --produces '3.141'
448        #
449        # So to make up for this, we force it to round by injecting a round() expression
450        rounded = exp.func("ROUND", expr, precision)
451
452        return super()._normalize_decimal_value(rounded, precision)
logger = <Logger sqlmesh.core.engine_adapter.redshift (WARNING)>
 38@set_catalog()
 39class RedshiftEngineAdapter(
 40    BasePostgresEngineAdapter,
 41    GetCurrentCatalogFromFunctionMixin,
 42    NonTransactionalTruncateMixin,
 43    VarcharSizeWorkaroundMixin,
 44    RowDiffMixin,
 45    GrantsFromInfoSchemaMixin,
 46):
 47    DIALECT = "redshift"
 48    CURRENT_CATALOG_EXPRESSION = exp.func("current_database")
 49    # Redshift doesn't support comments for VIEWs WITH NO SCHEMA BINDING (which we always use)
 50    COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED
 51    SUPPORTS_REPLACE_TABLE = False
 52    SUPPORTS_GRANTS = True
 53    SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True
 54
 55    SCHEMA_DIFFER_KWARGS = {
 56        "parameterized_type_defaults": {
 57            exp.DataType.build("VARBYTE", dialect=DIALECT).this: [(64000,)],
 58            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 0), (0,)],
 59            exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)],
 60            exp.DataType.build("VARCHAR", dialect=DIALECT).this: [(256,)],
 61            exp.DataType.build("NCHAR", dialect=DIALECT).this: [(1,)],
 62            exp.DataType.build("NVARCHAR", dialect=DIALECT).this: [(256,)],
 63        },
 64        "max_parameter_length": {
 65            exp.DataType.build("CHAR", dialect=DIALECT).this: 4096,
 66            exp.DataType.build("VARCHAR", dialect=DIALECT).this: 65535,
 67        },
 68        "precision_increase_allowed_types": {exp.DataType.build("VARCHAR", dialect=DIALECT).this},
 69        "drop_cascade": True,
 70    }
 71    VARIABLE_LENGTH_DATA_TYPES = {
 72        "char",
 73        "character",
 74        "nchar",
 75        "varchar",
 76        "character varying",
 77        "nvarchar",
 78        "varbyte",
 79        "varbinary",
 80        "binary varying",
 81    }
 82
 83    def columns(
 84        self,
 85        table_name: TableName,
 86        include_pseudo_columns: bool = True,
 87    ) -> t.Dict[str, exp.DataType]:
 88        table = exp.to_table(table_name)
 89
 90        sql = (
 91            exp.select(
 92                "column_name",
 93                "data_type",
 94                "character_maximum_length",
 95                "numeric_precision",
 96                "numeric_scale",
 97            )
 98            .from_("svv_columns")  # Includes late-binding views
 99            .where(exp.column("table_name").eq(table.alias_or_name))
100        )
101        if table.args.get("db"):
102            sql = sql.where(exp.column("table_schema").eq(table.args["db"].name))
103
104        columns_raw = self.fetchall(sql, quote_identifiers=True)
105
106        def build_var_length_col(
107            column_name: str,
108            data_type: str,
109            character_maximum_length: t.Optional[int] = None,
110            numeric_precision: t.Optional[int] = None,
111            numeric_scale: t.Optional[int] = None,
112        ) -> tuple:
113            data_type = data_type.lower()
114            if (
115                data_type in self.VARIABLE_LENGTH_DATA_TYPES
116                and character_maximum_length is not None
117            ):
118                return (column_name, f"{data_type}({character_maximum_length})")
119            if data_type in ("decimal", "numeric"):
120                return (column_name, f"{data_type}({numeric_precision}, {numeric_scale})")
121
122            return (column_name, data_type)
123
124        columns = [build_var_length_col(*row) for row in columns_raw]
125
126        return {
127            column_name: exp.DataType.build(data_type, dialect=self.dialect)
128            for column_name, data_type in columns
129        }
130
131    @property
132    def enable_merge(self) -> bool:
133        # Redshift supports the MERGE operation but we use the logical merge
134        # unless the user has opted in by setting enable_merge in the connection.
135        return bool(self._extra_config.get("enable_merge"))
136
137    @property
138    def cursor(self) -> t.Any:
139        # Redshift by default uses a `format` paramstyle that has issues when we try to write our snapshot
140        # data to snapshot table. There doesn't seem to be a way to disable parameter overriding so we just
141        # set it to `qmark` since that doesn't cause issues.
142        cursor = self._connection_pool.get_cursor()
143        cursor.paramstyle = "qmark"
144        return cursor
145
146    def _fetch_native_df(
147        self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
148    ) -> pd.DataFrame:
149        """Fetches a Pandas DataFrame from the cursor"""
150        import pandas as pd
151
152        self.execute(query, quote_identifiers=quote_identifiers)
153
154        # We manually build the `DataFrame` here because the driver's `fetch_dataframe`
155        # method does not respect the active case-sensitivity configuration.
156        #
157        # Context: https://github.com/aws/amazon-redshift-python-driver/issues/238
158        fetcheddata = self.cursor.fetchall()
159
160        try:
161            columns = [column[0] for column in self.cursor.description]
162        except Exception:
163            columns = None
164            logging.warning(
165                "No row description was found, pandas dataframe will be missing column labels."
166            )
167
168        result = [tuple(row) for row in fetcheddata]
169        return pd.DataFrame(result, columns=columns)
170
171    def _create_table_from_source_queries(
172        self,
173        table_name: TableName,
174        source_queries: t.List[SourceQuery],
175        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
176        exists: bool = True,
177        replace: bool = False,
178        table_description: t.Optional[str] = None,
179        column_descriptions: t.Optional[t.Dict[str, str]] = None,
180        table_kind: t.Optional[str] = None,
181        track_rows_processed: bool = True,
182        **kwargs: t.Any,
183    ) -> None:
184        """
185        Redshift doesn't support `CREATE TABLE IF NOT EXISTS AS...` but does support `CREATE TABLE AS...` so
186        we check if the exists check exists and if not then we can use the base implementation. Otherwise we
187        manually check if it exists and if it does then this is a no-op anyways so we return and if it doesn't
188        then we run the query with exists set to False since we just confirmed it doesn't exist.
189        """
190        if not exists:
191            return super()._create_table_from_source_queries(
192                table_name,
193                source_queries,
194                target_columns_to_types,
195                exists,
196                table_description=table_description,
197                column_descriptions=column_descriptions,
198                **kwargs,
199            )
200        if self.table_exists(table_name):
201            return
202        super()._create_table_from_source_queries(
203            table_name,
204            source_queries,
205            exists=False,
206            table_description=table_description,
207            column_descriptions=column_descriptions,
208            **kwargs,
209        )
210
211    def create_view(
212        self,
213        view_name: TableName,
214        query_or_df: QueryOrDF,
215        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
216        replace: bool = True,
217        materialized: bool = False,
218        materialized_properties: t.Optional[t.Dict[str, t.Any]] = None,
219        table_description: t.Optional[str] = None,
220        column_descriptions: t.Optional[t.Dict[str, str]] = None,
221        view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
222        source_columns: t.Optional[t.List[str]] = None,
223        **create_kwargs: t.Any,
224    ) -> None:
225        """
226        Redshift views are "binding" by default to their underlying table which means you can't drop that
227        underlying table without dropping the view first. This is a problem for us since we want to be able to
228        swap tables out from under views. Therefore, we create the view as non-binding.
229        """
230        no_schema_binding = True
231        if isinstance(query_or_df, exp.Expr):
232            # We can't include NO SCHEMA BINDING if the query has a recursive CTE
233            has_recursive_cte = any(
234                w.args.get("recursive", False) for w in query_or_df.find_all(exp.With)
235            )
236            no_schema_binding = not has_recursive_cte
237
238        return super().create_view(
239            view_name,
240            query_or_df,
241            target_columns_to_types,
242            replace,
243            materialized,
244            materialized_properties,
245            table_description=table_description,
246            column_descriptions=column_descriptions,
247            no_schema_binding=no_schema_binding,
248            view_properties=view_properties,
249            source_columns=source_columns,
250            **create_kwargs,
251        )
252
253    def replace_query(
254        self,
255        table_name: TableName,
256        query_or_df: QueryOrDF,
257        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
258        table_description: t.Optional[str] = None,
259        column_descriptions: t.Optional[t.Dict[str, str]] = None,
260        source_columns: t.Optional[t.List[str]] = None,
261        supports_replace_table_override: t.Optional[bool] = None,
262        **kwargs: t.Any,
263    ) -> None:
264        """
265        Redshift doesn't support `CREATE OR REPLACE TABLE...` and it also doesn't support `VALUES` expression so we need to specially
266        handle DataFrame replacements.
267
268        If the table doesn't exist then we just create it and load it with insert statements
269        If it does exist then we need to do the:
270            `CREATE TABLE...`, `INSERT INTO...`, `RENAME TABLE...`, `RENAME TABLE...`, DROP TABLE...`  dance.
271        """
272        import pandas as pd
273
274        target_data_object = self.get_data_object(table_name)
275        table_exists = target_data_object is not None
276        if self.drop_data_object_on_type_mismatch(target_data_object, DataObjectType.TABLE):
277            table_exists = False
278
279        if not isinstance(query_or_df, pd.DataFrame) or not table_exists:
280            return super().replace_query(
281                table_name,
282                query_or_df,
283                target_columns_to_types,
284                table_description,
285                column_descriptions,
286                source_columns=source_columns,
287                **kwargs,
288            )
289        source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types(
290            query_or_df,
291            target_columns_to_types,
292            target_table=table_name,
293            source_columns=source_columns,
294        )
295        target_columns_to_types = target_columns_to_types or self.columns(table_name)
296        target_table = exp.to_table(table_name)
297        with self.transaction():
298            temp_table = self._get_temp_table(target_table)
299            old_table = self._get_temp_table(target_table)
300            self.create_table(
301                temp_table,
302                target_columns_to_types,
303                exists=False,
304                table_description=table_description,
305                column_descriptions=column_descriptions,
306                **kwargs,
307            )
308            self._insert_append_source_queries(temp_table, source_queries, target_columns_to_types)
309            self.rename_table(target_table, old_table)
310            self.rename_table(temp_table, target_table)
311            self.drop_table(old_table)
312
313    def _get_data_objects(
314        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
315    ) -> t.List[DataObject]:
316        """
317        Returns all the data objects that exist in the given schema and optionally catalog.
318        """
319        catalog = self.get_current_catalog()
320        table_query = exp.select(
321            exp.column("schemaname").as_("schema_name"),
322            exp.column("tablename").as_("name"),
323            exp.Literal.string("TABLE").as_("type"),
324        ).from_("pg_tables")
325        view_query = (
326            exp.select(
327                exp.column("schemaname").as_("schema_name"),
328                exp.column("viewname").as_("name"),
329                exp.Literal.string("VIEW").as_("type"),
330            )
331            .from_("pg_views")
332            .where(exp.column("definition").ilike("%create materialized view%").not_())
333        )
334        materialized_view_query = (
335            exp.select(
336                exp.column("schemaname").as_("schema_name"),
337                exp.column("viewname").as_("name"),
338                exp.Literal.string("MATERIALIZED_VIEW").as_("type"),
339            )
340            .from_("pg_views")
341            .where(exp.column("definition").ilike("%create materialized view%"))
342        )
343        subquery = exp.union(
344            table_query,
345            exp.union(view_query, materialized_view_query, distinct=False),
346            distinct=False,
347        )
348        query = (
349            exp.select("*")
350            .from_(subquery.subquery(alias="objs"))
351            .where(exp.column("schema_name").eq(to_schema(schema_name).db))
352        )
353        if object_names:
354            query = query.where(exp.column("name").isin(*object_names))
355        df = self.fetchdf(query)
356        return [
357            DataObject(
358                catalog=catalog,
359                schema=row.schema_name,
360                name=row.name,
361                type=DataObjectType.from_str(row.type),  # type: ignore
362            )
363            for row in df.itertuples()
364        ]
365
366    def merge(
367        self,
368        target_table: TableName,
369        source_table: QueryOrDF,
370        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
371        unique_key: t.Sequence[exp.Expr],
372        when_matched: t.Optional[exp.Whens] = None,
373        merge_filter: t.Optional[exp.Expr] = None,
374        source_columns: t.Optional[t.List[str]] = None,
375        **kwargs: t.Any,
376    ) -> None:
377        if self.enable_merge:
378            # By default we use the logical merge unless the user has opted in
379            super().merge(
380                target_table=target_table,
381                source_table=source_table,
382                target_columns_to_types=target_columns_to_types,
383                unique_key=unique_key,
384                when_matched=when_matched,
385                merge_filter=merge_filter,
386                source_columns=source_columns,
387            )
388        else:
389            logical_merge(
390                self,
391                target_table,
392                source_table,
393                target_columns_to_types,
394                unique_key,
395                when_matched=when_matched,
396                merge_filter=merge_filter,
397                source_columns=source_columns,
398            )
399
400    def _merge(
401        self,
402        target_table: TableName,
403        query: Query,
404        on: exp.Expr,
405        whens: exp.Whens,
406    ) -> None:
407        # Redshift does not support table aliases in the target table of a MERGE statement.
408        # So we must use the actual table name instead of an alias, as we do with the source table.
409        def resolve_target_table(expression: exp.Expr) -> exp.Expr:
410            if (
411                isinstance(expression, exp.Column)
412                and expression.table.upper() == MERGE_TARGET_ALIAS
413            ):
414                expression.set("table", exp.to_table(target_table))
415            return expression
416
417        # Ensure that there is exactly one "WHEN MATCHED" and one "WHEN NOT MATCHED" clause.
418        # Since Redshift does not support multiple "WHEN MATCHED" clauses.
419        if (
420            len(whens.expressions) != 2
421            or whens.expressions[0].args["matched"] == whens.expressions[1].args["matched"]
422        ):
423            raise SQLMeshError(
424                "Redshift only supports a single WHEN MATCHED and WHEN NOT MATCHED clause"
425            )
426
427        using = exp.alias_(
428            exp.Subquery(this=query), alias=MERGE_SOURCE_ALIAS, copy=False, table=True
429        )
430        self.execute(
431            exp.Merge(
432                this=target_table,
433                using=using,
434                on=on.transform(resolve_target_table),
435                whens=whens.transform(resolve_target_table),
436            ),
437            track_rows_processed=True,
438        )
439
440    def _normalize_decimal_value(self, expr: exp.Expr, precision: int) -> exp.Expr:
441        # Redshift is finicky. It truncates when the data is already in a table, but rounds when the data is generated as part of a SELECT.
442        #
443        # The following works:
444        #  > select cast(cast(3.14159 as decimal(6, 5)) as decimal(6, 3)); --produces '3.142', the value we want / what every other database produces
445        #
446        # However, if you write that to a table, and then cast it to a less precise decimal, you get _truncation_.
447        #  > create table foo (val decimal(6, 5)); insert into foo(val) values (3.14159);
448        #  > select cast(val as decimal(6, 3)) from foo; --produces '3.141'
449        #
450        # So to make up for this, we force it to round by injecting a round() expression
451        rounded = exp.func("ROUND", expr, precision)
452
453        return super()._normalize_decimal_value(rounded, precision)

Base class wrapping a Database API compliant connection.

The EngineAdapter is an easily-subclassable interface that interacts with the underlying engine and data store.

Arguments:
  • connection_factory_or_pool: a callable which produces a new Database API-compliant connection on every call.
  • dialect: The dialect with which this adapter is associated.
  • multithreaded: Indicates whether this adapter will be used by more than one thread.
DIALECT = 'redshift'
CURRENT_CATALOG_EXPRESSION = CurrentDatabase()
COMMENT_CREATION_VIEW = <CommentCreationView.UNSUPPORTED: 1>
SUPPORTS_REPLACE_TABLE = False
SUPPORTS_GRANTS = True
SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True
SCHEMA_DIFFER_KWARGS = {'parameterized_type_defaults': {<DType.VARBINARY: 'VARBINARY'>: [(64000,)], <DType.DECIMAL: 'DECIMAL'>: [(18, 0), (0,)], <DType.CHAR: 'CHAR'>: [(1,)], <DType.VARCHAR: 'VARCHAR'>: [(256,)], <DType.NCHAR: 'NCHAR'>: [(1,)], <DType.NVARCHAR: 'NVARCHAR'>: [(256,)]}, 'max_parameter_length': {<DType.CHAR: 'CHAR'>: 4096, <DType.VARCHAR: 'VARCHAR'>: 65535}, 'precision_increase_allowed_types': {<DType.VARCHAR: 'VARCHAR'>}, 'drop_cascade': True}
VARIABLE_LENGTH_DATA_TYPES = {'character varying', 'nvarchar', 'varbyte', 'character', 'varchar', 'binary varying', 'nchar', 'varbinary', 'char'}
def columns( self, table_name: Union[str, sqlglot.expressions.query.Table], include_pseudo_columns: bool = True) -> Dict[str, sqlglot.expressions.datatypes.DataType]:
 83    def columns(
 84        self,
 85        table_name: TableName,
 86        include_pseudo_columns: bool = True,
 87    ) -> t.Dict[str, exp.DataType]:
 88        table = exp.to_table(table_name)
 89
 90        sql = (
 91            exp.select(
 92                "column_name",
 93                "data_type",
 94                "character_maximum_length",
 95                "numeric_precision",
 96                "numeric_scale",
 97            )
 98            .from_("svv_columns")  # Includes late-binding views
 99            .where(exp.column("table_name").eq(table.alias_or_name))
100        )
101        if table.args.get("db"):
102            sql = sql.where(exp.column("table_schema").eq(table.args["db"].name))
103
104        columns_raw = self.fetchall(sql, quote_identifiers=True)
105
106        def build_var_length_col(
107            column_name: str,
108            data_type: str,
109            character_maximum_length: t.Optional[int] = None,
110            numeric_precision: t.Optional[int] = None,
111            numeric_scale: t.Optional[int] = None,
112        ) -> tuple:
113            data_type = data_type.lower()
114            if (
115                data_type in self.VARIABLE_LENGTH_DATA_TYPES
116                and character_maximum_length is not None
117            ):
118                return (column_name, f"{data_type}({character_maximum_length})")
119            if data_type in ("decimal", "numeric"):
120                return (column_name, f"{data_type}({numeric_precision}, {numeric_scale})")
121
122            return (column_name, data_type)
123
124        columns = [build_var_length_col(*row) for row in columns_raw]
125
126        return {
127            column_name: exp.DataType.build(data_type, dialect=self.dialect)
128            for column_name, data_type in columns
129        }

Fetches column names and types for the target table.

enable_merge: bool
131    @property
132    def enable_merge(self) -> bool:
133        # Redshift supports the MERGE operation but we use the logical merge
134        # unless the user has opted in by setting enable_merge in the connection.
135        return bool(self._extra_config.get("enable_merge"))
cursor: Any
137    @property
138    def cursor(self) -> t.Any:
139        # Redshift by default uses a `format` paramstyle that has issues when we try to write our snapshot
140        # data to snapshot table. There doesn't seem to be a way to disable parameter overriding so we just
141        # set it to `qmark` since that doesn't cause issues.
142        cursor = self._connection_pool.get_cursor()
143        cursor.paramstyle = "qmark"
144        return cursor
def create_view( self, view_name: Union[str, sqlglot.expressions.query.Table], query_or_df: QueryOrDF, target_columns_to_types: Optional[Dict[str, sqlglot.expressions.datatypes.DataType]] = None, replace: bool = True, materialized: bool = False, materialized_properties: Optional[Dict[str, Any]] = None, table_description: Optional[str] = None, column_descriptions: Optional[Dict[str, str]] = None, view_properties: Optional[Dict[str, sqlglot.expressions.core.Expr]] = None, source_columns: Optional[List[str]] = None, **create_kwargs: Any) -> None:
211    def create_view(
212        self,
213        view_name: TableName,
214        query_or_df: QueryOrDF,
215        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
216        replace: bool = True,
217        materialized: bool = False,
218        materialized_properties: t.Optional[t.Dict[str, t.Any]] = None,
219        table_description: t.Optional[str] = None,
220        column_descriptions: t.Optional[t.Dict[str, str]] = None,
221        view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
222        source_columns: t.Optional[t.List[str]] = None,
223        **create_kwargs: t.Any,
224    ) -> None:
225        """
226        Redshift views are "binding" by default to their underlying table which means you can't drop that
227        underlying table without dropping the view first. This is a problem for us since we want to be able to
228        swap tables out from under views. Therefore, we create the view as non-binding.
229        """
230        no_schema_binding = True
231        if isinstance(query_or_df, exp.Expr):
232            # We can't include NO SCHEMA BINDING if the query has a recursive CTE
233            has_recursive_cte = any(
234                w.args.get("recursive", False) for w in query_or_df.find_all(exp.With)
235            )
236            no_schema_binding = not has_recursive_cte
237
238        return super().create_view(
239            view_name,
240            query_or_df,
241            target_columns_to_types,
242            replace,
243            materialized,
244            materialized_properties,
245            table_description=table_description,
246            column_descriptions=column_descriptions,
247            no_schema_binding=no_schema_binding,
248            view_properties=view_properties,
249            source_columns=source_columns,
250            **create_kwargs,
251        )

Redshift views are "binding" by default to their underlying table which means you can't drop that underlying table without dropping the view first. This is a problem for us since we want to be able to swap tables out from under views. Therefore, we create the view as non-binding.

def replace_query( self, table_name: Union[str, sqlglot.expressions.query.Table], query_or_df: QueryOrDF, target_columns_to_types: Optional[Dict[str, sqlglot.expressions.datatypes.DataType]] = None, table_description: Optional[str] = None, column_descriptions: Optional[Dict[str, str]] = None, source_columns: Optional[List[str]] = None, supports_replace_table_override: Optional[bool] = None, **kwargs: Any) -> None:
253    def replace_query(
254        self,
255        table_name: TableName,
256        query_or_df: QueryOrDF,
257        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
258        table_description: t.Optional[str] = None,
259        column_descriptions: t.Optional[t.Dict[str, str]] = None,
260        source_columns: t.Optional[t.List[str]] = None,
261        supports_replace_table_override: t.Optional[bool] = None,
262        **kwargs: t.Any,
263    ) -> None:
264        """
265        Redshift doesn't support `CREATE OR REPLACE TABLE...` and it also doesn't support `VALUES` expression so we need to specially
266        handle DataFrame replacements.
267
268        If the table doesn't exist then we just create it and load it with insert statements
269        If it does exist then we need to do the:
270            `CREATE TABLE...`, `INSERT INTO...`, `RENAME TABLE...`, `RENAME TABLE...`, DROP TABLE...`  dance.
271        """
272        import pandas as pd
273
274        target_data_object = self.get_data_object(table_name)
275        table_exists = target_data_object is not None
276        if self.drop_data_object_on_type_mismatch(target_data_object, DataObjectType.TABLE):
277            table_exists = False
278
279        if not isinstance(query_or_df, pd.DataFrame) or not table_exists:
280            return super().replace_query(
281                table_name,
282                query_or_df,
283                target_columns_to_types,
284                table_description,
285                column_descriptions,
286                source_columns=source_columns,
287                **kwargs,
288            )
289        source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types(
290            query_or_df,
291            target_columns_to_types,
292            target_table=table_name,
293            source_columns=source_columns,
294        )
295        target_columns_to_types = target_columns_to_types or self.columns(table_name)
296        target_table = exp.to_table(table_name)
297        with self.transaction():
298            temp_table = self._get_temp_table(target_table)
299            old_table = self._get_temp_table(target_table)
300            self.create_table(
301                temp_table,
302                target_columns_to_types,
303                exists=False,
304                table_description=table_description,
305                column_descriptions=column_descriptions,
306                **kwargs,
307            )
308            self._insert_append_source_queries(temp_table, source_queries, target_columns_to_types)
309            self.rename_table(target_table, old_table)
310            self.rename_table(temp_table, target_table)
311            self.drop_table(old_table)

Redshift doesn't support CREATE OR REPLACE TABLE... and it also doesn't support VALUES expressions, so we need to specially handle DataFrame replacements.

If the table doesn't exist then we just create it and load it with insert statements

If it does exist then we need to do the:

CREATE TABLE..., INSERT INTO..., RENAME TABLE..., RENAME TABLE..., DROP TABLE... dance.

def merge( self, target_table: Union[str, sqlglot.expressions.query.Table], source_table: QueryOrDF, target_columns_to_types: Optional[Dict[str, sqlglot.expressions.datatypes.DataType]], unique_key: Sequence[sqlglot.expressions.core.Expr], when_matched: Optional[sqlglot.expressions.dml.Whens] = None, merge_filter: Optional[sqlglot.expressions.core.Expr] = None, source_columns: Optional[List[str]] = None, **kwargs: Any) -> None:
366    def merge(
367        self,
368        target_table: TableName,
369        source_table: QueryOrDF,
370        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
371        unique_key: t.Sequence[exp.Expr],
372        when_matched: t.Optional[exp.Whens] = None,
373        merge_filter: t.Optional[exp.Expr] = None,
374        source_columns: t.Optional[t.List[str]] = None,
375        **kwargs: t.Any,
376    ) -> None:
377        if self.enable_merge:
378            # By default we use the logical merge unless the user has opted in
379            super().merge(
380                target_table=target_table,
381                source_table=source_table,
382                target_columns_to_types=target_columns_to_types,
383                unique_key=unique_key,
384                when_matched=when_matched,
385                merge_filter=merge_filter,
386                source_columns=source_columns,
387            )
388        else:
389            logical_merge(
390                self,
391                target_table,
392                source_table,
393                target_columns_to_types,
394                unique_key,
395                when_matched=when_matched,
396                merge_filter=merge_filter,
397                source_columns=source_columns,
398            )
def table_exists(self, table_name: Union[str, sqlglot.expressions.query.Table]) -> bool:
 76    def table_exists(self, table_name: TableName) -> bool:
 77        """
 78        Postgres doesn't support describe so I'm using what the redshift cursor does to check if a table
 79        exists. We don't use this directly in order for this to work as a base class for other postgres
 80
 81        Reference: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/cursor.py#L528-L553
 82        """
 83        table = exp.to_table(table_name)
 84        data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name)
 85        if data_object_cache_key in self._data_object_cache:
 86            logger.debug("Table existence cache hit: %s", data_object_cache_key)
 87            return self._data_object_cache[data_object_cache_key] is not None
 88
 89        sql = (
 90            exp.select("1")
 91            .from_("information_schema.tables")
 92            .where(f"table_name = '{table.alias_or_name}'")
 93        )
 94        database_name = table.db
 95        if database_name:
 96            sql = sql.where(f"table_schema = '{database_name}'")
 97
 98        self.execute(sql)
 99
100        result = self.cursor.fetchone()
101
102        return result[0] == 1 if result is not None else False

Postgres doesn't support DESCRIBE, so this uses the same approach as the Redshift cursor to check whether a table exists. We don't call the cursor's helper directly so that this also works as a base class for other Postgres-based engines.

Reference: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/cursor.py#L528-L553

def drop_view( self, view_name: Union[str, sqlglot.expressions.query.Table], ignore_if_not_exists: bool = True, materialized: bool = False, **kwargs: Any) -> None:
142    def drop_view(
143        self,
144        view_name: TableName,
145        ignore_if_not_exists: bool = True,
146        materialized: bool = False,
147        **kwargs: t.Any,
148    ) -> None:
149        kwargs["cascade"] = kwargs.get("cascade", True)
150        return super().drop_view(
151            view_name,
152            ignore_if_not_exists=ignore_if_not_exists,
153            materialized=materialized,
154            **kwargs,
155        )

Drop a view.

Inherited Members
sqlmesh.core.engine_adapter.base.EngineAdapter
EngineAdapter
DATA_OBJECT_FILTER_BATCH_SIZE
SUPPORTS_TRANSACTIONS
SUPPORTS_INDEXES
MAX_TABLE_COMMENT_LENGTH
MAX_COLUMN_COMMENT_LENGTH
INSERT_OVERWRITE_STRATEGY
SUPPORTS_MATERIALIZED_VIEWS
SUPPORTS_MATERIALIZED_VIEW_SCHEMA
SUPPORTS_VIEW_SCHEMA
SUPPORTS_CLONING
SUPPORTS_MANAGED_MODELS
SUPPORTS_CREATE_DROP_CATALOG
SUPPORTS_TUPLE_IN
HAS_VIEW_BINDING
DEFAULT_CATALOG_TYPE
QUOTE_IDENTIFIERS_IN_VIEWS
MAX_IDENTIFIER_LENGTH
ATTACH_CORRELATION_ID
SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS
dialect
correlation_id
with_settings
connection
spark
snowpark
bigframe
comments_enabled
schema_differ
default_catalog
engine_run_mode
recycle
close
set_current_catalog
get_catalog_type
get_catalog_type_from_table
current_catalog_type
create_index
create_table
create_managed_table
ctas
create_state_table
create_table_like
clone_table
drop_data_object
drop_table
drop_managed_table
get_alter_operations
alter_table
create_schema
drop_schema
create_catalog
drop_catalog
delete_from
insert_append
insert_overwrite_by_partition
insert_overwrite_by_time_partition
update_table
scd_type_2_by_time
scd_type_2_by_column
rename_table
get_data_object
get_data_objects
fetchone
fetchall
fetchdf
fetch_pyspark_df
wap_enabled
wap_supported
wap_table_name
wap_prepare
wap_publish
sync_grants_config
transaction
session
execute
temp_table
drop_data_object_on_type_mismatch
ensure_nulls_for_unmatched_after_join
use_server_nulls_for_unmatched_after_join
ping
get_table_last_modified_ts
sqlmesh.core.engine_adapter.base_postgres.BasePostgresEngineAdapter
DEFAULT_BATCH_SIZE
COMMENT_CREATION_TABLE
SUPPORTS_QUERY_EXECUTION_TRACKING
SUPPORTED_DROP_CASCADE_OBJECT_KINDS
catalog_support
sqlmesh.core.engine_adapter.mixins.GetCurrentCatalogFromFunctionMixin
get_current_catalog
sqlmesh.core.engine_adapter.mixins.RowDiffMixin
MAX_TIMESTAMP_PRECISION
concat_columns
normalize_value
sqlmesh.core.engine_adapter.mixins.GrantsFromInfoSchemaMixin
CURRENT_USER_OR_ROLE_EXPRESSION
USE_CATALOG_IN_GRANTS
GRANT_INFORMATION_SCHEMA_TABLE_NAME