Edit on GitHub

sqlmesh.core.table_diff

  1from __future__ import annotations
  2
  3import math
  4import typing as t
  5from functools import cached_property
  6
  7from sqlmesh.core.dialect import to_schema
  8from sqlmesh.core.engine_adapter.mixins import RowDiffMixin
  9from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter
 10from sqlglot import exp, parse_one
 11from sqlglot.helper import ensure_list
 12from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
 13from sqlglot.optimizer.qualify_columns import quote_identifiers
 14from sqlglot.optimizer.scope import find_all_in_scope
 15
 16from sqlmesh.utils.pydantic import PydanticModel
 17from sqlmesh.utils.errors import SQLMeshError
 18
 19
 20if t.TYPE_CHECKING:
 21    import pandas as pd
 22
 23    from sqlmesh.core._typing import TableName
 24    from sqlmesh.core.engine_adapter import EngineAdapter
 25
 26SQLMESH_JOIN_KEY_COL = "__sqlmesh_join_key"
 27SQLMESH_SAMPLE_TYPE_COL = "__sqlmesh_sample_type"
 28
 29
 30class SchemaDiff(PydanticModel, frozen=True):
 31    """An object containing the schema difference between a source and target table."""
 32
 33    source: str
 34    target: str
 35    source_schema: t.Dict[str, exp.DataType]
 36    target_schema: t.Dict[str, exp.DataType]
 37    source_alias: t.Optional[str] = None
 38    target_alias: t.Optional[str] = None
 39    model_name: t.Optional[str] = None
 40    ignore_case: bool = False
 41
 42    @property
 43    def _comparable_source_schema(self) -> t.Dict[str, exp.DataType]:
 44        return (
 45            self._lowercase_schema_names(self.source_schema)
 46            if self.ignore_case
 47            else self.source_schema
 48        )
 49
 50    @property
 51    def _comparable_target_schema(self) -> t.Dict[str, exp.DataType]:
 52        return (
 53            self._lowercase_schema_names(self.target_schema)
 54            if self.ignore_case
 55            else self.target_schema
 56        )
 57
 58    def _lowercase_schema_names(
 59        self, schema: t.Dict[str, exp.DataType]
 60    ) -> t.Dict[str, exp.DataType]:
 61        return {c.lower(): t for c, t in schema.items()}
 62
 63    def _original_column_name(
 64        self, maybe_lowercased_column_name: str, schema: t.Dict[str, exp.DataType]
 65    ) -> str:
 66        if not self.ignore_case:
 67            return maybe_lowercased_column_name
 68
 69        return next(c for c in schema if c.lower() == maybe_lowercased_column_name)
 70
 71    @property
 72    def added(self) -> t.List[t.Tuple[str, exp.DataType]]:
 73        """Added columns."""
 74        return [
 75            (self._original_column_name(c, self.target_schema), t)
 76            for c, t in self._comparable_target_schema.items()
 77            if c not in self._comparable_source_schema
 78        ]
 79
 80    @property
 81    def removed(self) -> t.List[t.Tuple[str, exp.DataType]]:
 82        """Removed columns."""
 83        return [
 84            (self._original_column_name(c, self.source_schema), t)
 85            for c, t in self._comparable_source_schema.items()
 86            if c not in self._comparable_target_schema
 87        ]
 88
 89    @property
 90    def modified(self) -> t.Dict[str, t.Tuple[exp.DataType, exp.DataType]]:
 91        """Columns with modified types."""
 92        modified = {}
 93        for column in self._comparable_source_schema.keys() & self._comparable_target_schema.keys():
 94            source_type = self._comparable_source_schema[column]
 95            target_type = self._comparable_target_schema[column]
 96
 97            if source_type != target_type:
 98                modified[column] = (source_type, target_type)
 99
100        if self.ignore_case:
101            modified = {
102                self._original_column_name(c, self.source_schema): dt for c, dt in modified.items()
103            }
104
105        return modified
106
107    @property
108    def has_changes(self) -> bool:
109        """Does the schema contain any changes at all between source and target"""
110        return bool(self.added or self.removed or self.modified)
111
112
class RowDiff(PydanticModel, frozen=True):
    """Summary statistics and a sample dataframe."""

    source: str
    target: str
    stats: t.Dict[str, float]
    sample: pd.DataFrame
    joined_sample: pd.DataFrame
    s_sample: pd.DataFrame
    t_sample: pd.DataFrame
    column_stats: pd.DataFrame
    source_alias: t.Optional[str] = None
    target_alias: t.Optional[str] = None
    model_name: t.Optional[str] = None
    decimals: int = 3

    _types_resolved: t.ClassVar[bool] = False

    def __new__(cls, *args: t.Any, **kwargs: t.Any) -> RowDiff:
        # The pandas-typed fields are resolved lazily, on the first instantiation.
        if not cls._types_resolved:
            cls._resolve_types()
        return super().__new__(cls)

    @classmethod
    def _resolve_types(cls) -> None:
        # Pandas is imported by type checking so we need to resolve the types with the real import before instantiating
        import pandas as pd  # noqa

        cls.model_rebuild()
        cls._types_resolved = True

    @property
    def source_count(self) -> int:
        """Count of the source."""
        return int(self.stats["s_count"])

    @property
    def target_count(self) -> int:
        """Count of the target."""
        return int(self.stats["t_count"])

    @property
    def empty(self) -> bool:
        """True when neither table contributed any rows to the diff."""
        return (
            self.source_count == 0
            and self.target_count == 0
            and self.s_only_count == 0
            and self.t_only_count == 0
        )

    @property
    def count_pct_change(self) -> float:
        """The percentage change of the counts."""
        if self.source_count == 0:
            return math.inf
        return ((self.target_count - self.source_count) / self.source_count) * 100

    @property
    def join_count(self) -> int:
        """Count of successfully joined rows."""
        return int(self.stats["join_count"])

    @property
    def full_match_count(self) -> int:
        """The number of rows for which shared columns have same values."""
        return int(self.stats["full_match_count"])

    @property
    def full_match_pct(self) -> float:
        """The percentage of rows for which shared columns have same values."""
        # A joined row is present in both tables, so it counts twice toward the total.
        return self._pct(2 * self.full_match_count)

    @property
    def partial_match_count(self) -> int:
        """The number of rows for which some shared columns have same values."""
        return self.join_count - self.full_match_count

    @property
    def partial_match_pct(self) -> float:
        """The percentage of rows for which some shared columns have same values."""
        return self._pct(2 * self.partial_match_count)

    @property
    def s_only_count(self) -> int:
        """Count of rows only present in source."""
        return int(self.stats["s_only_count"])

    @property
    def s_only_pct(self) -> float:
        """The percentage of rows that are only present in source."""
        return self._pct(self.s_only_count)

    @property
    def t_only_count(self) -> int:
        """Count of rows only present in target."""
        return int(self.stats["t_only_count"])

    @property
    def t_only_pct(self) -> float:
        """The percentage of rows that are only present in target."""
        return self._pct(self.t_only_count)

    def _pct(self, numerator: int) -> float:
        """Express ``numerator`` as a percentage of all source plus target rows.

        Rounded to two decimal places. Returns 0.0 when both tables are empty instead of
        raising ZeroDivisionError.
        """
        total = self.source_count + self.target_count
        if total == 0:
            return 0.0
        return round((numerator / total) * 100, 2)
217
218
class TableDiff:
    """Calculates differences between tables, taking into account schema and row level differences.

    Rows are matched on ``on``, which is either a list of key column names or a join
    condition referencing the ``s`` (source) and ``t`` (target) aliases. Columns named in
    ``skip_columns`` are excluded from the row diff.
    """

    def __init__(
        self,
        adapter: EngineAdapter,
        source: TableName,
        target: TableName,
        on: t.List[str] | exp.Expr,
        skip_columns: t.List[str] | None = None,
        where: t.Optional[str | exp.Expr] = None,
        limit: int = 20,
        source_alias: t.Optional[str] = None,
        target_alias: t.Optional[str] = None,
        model_name: t.Optional[str] = None,
        model_dialect: t.Optional[str] = None,
        decimals: int = 3,
        schema_diff_ignore_case: bool = False,
    ):
        if not isinstance(adapter, RowDiffMixin):
            raise ValueError(f"Engine {adapter} doesnt support RowDiff")

        self.adapter = adapter
        self.source = source
        self.target = target
        self.dialect = adapter.dialect
        self.source_table = exp.to_table(self.source, dialect=self.dialect)
        self.target_table = exp.to_table(self.target, dialect=self.dialect)
        self.where = exp.condition(where, dialect=self.dialect) if where else None
        self.limit = limit
        self.model_name = model_name
        self.model_dialect = model_dialect
        self.decimals = decimals
        self.schema_diff_ignore_case = schema_diff_ignore_case

        # Support environment aliases for diff output improvement in certain cases
        self.source_alias = source_alias
        self.target_alias = target_alias

        # Skipped names are normalized with the same dialect preference as the join keys
        self.skip_columns = {
            normalize_identifiers(
                exp.parse_identifier(t.cast(str, col)),
                dialect=self.model_dialect or self.dialect,
            ).name
            for col in ensure_list(skip_columns)
        }

        self._on = on
        self._row_diff: t.Optional[RowDiff] = None

    @cached_property
    def source_schema(self) -> t.Dict[str, exp.DataType]:
        return self.adapter.columns(self.source_table)

    @cached_property
    def target_schema(self) -> t.Dict[str, exp.DataType]:
        return self.adapter.columns(self.target_table)

    @cached_property
    def key_columns(self) -> t.Tuple[t.List[exp.Column], t.List[exp.Column], t.List[str]]:
        """Return (source key columns, target key columns, distinct key column names)."""
        dialect = self.model_dialect or self.dialect

        # If the columns to join on are explicitly specified, then just return them
        if isinstance(self._on, (list, tuple)):
            identifiers = [normalize_identifiers(c, dialect=dialect) for c in self._on]
            s_index = [exp.column(c, "s") for c in identifiers]
            t_index = [exp.column(c, "t") for c in identifiers]
            return s_index, t_index, [i.name for i in identifiers]

        # Otherwise, we need to parse them out of the supplied "on" condition
        index_cols = []
        s_index = []
        t_index = []

        normalize_identifiers(self._on, dialect=dialect)
        for col in self._on.find_all(exp.Column):
            index_cols.append(col.name)
            if col.table.lower() == "s":
                s_index.append(col)
            elif col.table.lower() == "t":
                t_index.append(col)

        # De-duplicate while preserving first-seen order
        index_cols = list(dict.fromkeys(index_cols))
        s_index = list(dict.fromkeys(s_index))
        t_index = list(dict.fromkeys(t_index))

        return s_index, t_index, index_cols

    @property
    def source_key_expression(self) -> exp.Expr:
        s_index, _, _ = self.key_columns
        return self._key_expression(s_index, self.source_schema)

    @property
    def target_key_expression(self) -> exp.Expr:
        _, t_index, _ = self.key_columns
        return self._key_expression(t_index, self.target_schema)

    def _key_expression(
        self, cols: t.List[exp.Column], schema: t.Dict[str, exp.DataType]
    ) -> exp.Expr:
        # if there is a single column, dont do anything fancy to it in order to allow existing indexes to be hit
        if len(cols) == 1:
            return exp.to_column(cols[0].name)

        # if there are multiple columns, turn them into a single column by stringify-ing/concatenating them together
        key_columns_to_types = {key.name: schema[key.name] for key in cols}
        return self.adapter.concat_columns(key_columns_to_types, self.decimals)

    @staticmethod
    def _strip_side_prefix(column: str) -> str:
        """Remove the leading ``s__``/``t__`` side marker from a sampled column alias.

        Only the 3-character prefix is removed. Column names that themselves contain
        ``__``, or the substrings ``s__``/``t__``, are therefore kept intact.
        """
        return column[3:]

    def schema_diff(self) -> SchemaDiff:
        return SchemaDiff(
            source=self.source,
            target=self.target,
            source_schema=self.source_schema,
            target_schema=self.target_schema,
            source_alias=self.source_alias,
            target_alias=self.target_alias,
            model_name=self.model_name,
            ignore_case=self.schema_diff_ignore_case,
        )

    def row_diff(
        self, temp_schema: t.Optional[str] = None, skip_grain_check: bool = False
    ) -> RowDiff:
        """Compute the row-level diff once and cache it.

        Args:
            temp_schema: Schema in which the intermediate ``diff`` table is created
                (``sqlmesh_temp`` when not given).
            skip_grain_check: When True, distinct join-key counts are not queried.

        Returns:
            The cached RowDiff. Calls after the first return the cached result and
            ignore their arguments.
        """
        if self._row_diff is None:
            source_schema = {
                c: dt for c, dt in self.source_schema.items() if c not in self.skip_columns
            }
            target_schema = {
                c: dt for c, dt in self.target_schema.items() if c not in self.skip_columns
            }

            s_selects = {c: exp.column(c, "s").as_(f"s__{c}") for c in source_schema}
            t_selects = {c: exp.column(c, "t").as_(f"t__{c}") for c in target_schema}
            s_selects[SQLMESH_JOIN_KEY_COL] = exp.column(SQLMESH_JOIN_KEY_COL, "s").as_(
                f"s__{SQLMESH_JOIN_KEY_COL}"
            )
            t_selects[SQLMESH_JOIN_KEY_COL] = exp.column(SQLMESH_JOIN_KEY_COL, "t").as_(
                f"t__{SQLMESH_JOIN_KEY_COL}"
            )

            # Only columns with the same name and type on both sides are compared value-wise
            matched_columns = {
                c: dt for c, dt in source_schema.items() if dt == target_schema.get(c)
            }

            s_index, t_index, index_cols = self.key_columns
            s_index_names = [c.name for c in s_index]
            t_index_names = [c.name for c in t_index]

            def _column_expr(column_name: str, table: str) -> exp.Expr:
                column_type = matched_columns[column_name]
                qualified_column = exp.column(column_name, table)

                if column_type.is_type(*exp.DataType.REAL_TYPES):
                    return self.adapter._normalize_decimal_value(qualified_column, self.decimals)
                if column_type.is_type(*exp.DataType.NESTED_TYPES):
                    return self.adapter._normalize_nested_value(qualified_column)

                return qualified_column

            comparisons = [
                exp.Case()
                .when(_column_expr(c, "s").eq(_column_expr(c, "t")), exp.Literal.number(1))
                .when(
                    exp.column(c, "s").is_(exp.Null()) & exp.column(c, "t").is_(exp.Null()),
                    exp.Literal.number(1),
                )
                .when(
                    exp.column(c, "s").is_(exp.Null()) | exp.column(c, "t").is_(exp.Null()),
                    exp.Literal.number(0),
                )
                .else_(exp.Literal.number(0))
                .as_(f"{c}_matches")
                for c in matched_columns
            ]

            source_query = (
                exp.select(
                    *(exp.column(c) for c in source_schema),
                    self.source_key_expression.as_(SQLMESH_JOIN_KEY_COL),
                )
                .from_(self.source_table.as_("s"))
                .where(self.where)
            )
            target_query = (
                exp.select(
                    *(exp.column(c) for c in target_schema),
                    self.target_key_expression.as_(SQLMESH_JOIN_KEY_COL),
                )
                .from_(self.target_table.as_("t"))
                .where(self.where)
            )

            # Ensure every column is qualified with the alias in the source and target queries
            for col in find_all_in_scope(source_query, exp.Column):
                col.set("table", exp.to_identifier("s"))
            for col in find_all_in_scope(target_query, exp.Column):
                col.set("table", exp.to_identifier("t"))

            source_table = exp.table_("__source")
            target_table = exp.table_("__target")
            stats_table = exp.table_("__stats")

            stats_query = (
                exp.select(
                    *s_selects.values(),
                    *t_selects.values(),
                    exp.func(
                        "IF", exp.column(SQLMESH_JOIN_KEY_COL, "s").is_(exp.Null()).not_(), 1, 0
                    ).as_("s_exists"),
                    exp.func(
                        "IF", exp.column(SQLMESH_JOIN_KEY_COL, "t").is_(exp.Null()).not_(), 1, 0
                    ).as_("t_exists"),
                    exp.func(
                        "IF",
                        exp.column(SQLMESH_JOIN_KEY_COL, "s").eq(
                            exp.column(SQLMESH_JOIN_KEY_COL, "t")
                        ),
                        1,
                        0,
                    ).as_("row_joined"),
                    exp.func(
                        "IF",
                        exp.or_(
                            *(
                                exp.and_(
                                    s_col.is_(exp.Null()),
                                    t_col.is_(exp.Null()),
                                )
                                for s_col, t_col in zip(s_index, t_index)
                            ),
                        ),
                        1,
                        0,
                    ).as_("null_grain"),
                    *comparisons,
                )
                .from_(source_table.as_("s"))
                .join(
                    target_table.as_("t"),
                    on=exp.column(SQLMESH_JOIN_KEY_COL, "s").eq(
                        exp.column(SQLMESH_JOIN_KEY_COL, "t")
                    ),
                    join_type="FULL",
                )
            )

            base_query = (
                exp.Select()
                .with_(source_table, source_query)
                .with_(target_table, target_query)
                .with_(stats_table, stats_query)
                .select(
                    "*",
                    exp.Case()
                    .when(
                        exp.and_(
                            *[
                                exp.column(f"{c}_matches").eq(exp.Literal.number(1))
                                for c in matched_columns
                            ]
                        ),
                        exp.Literal.number(1),
                    )
                    .else_(exp.Literal.number(0))
                    .as_("row_full_match"),
                )
                .from_(stats_table)
            )

            query = self.adapter.ensure_nulls_for_unmatched_after_join(
                quote_identifiers(base_query.copy(), dialect=self.model_dialect or self.dialect)
            )

            if not temp_schema:
                temp_schema = "sqlmesh_temp"

            schema = to_schema(temp_schema, dialect=self.dialect)
            temp_table = exp.table_("diff", db=schema.db, catalog=schema.catalog, quoted=True)

            temp_table_kwargs: t.Dict[str, t.Any] = {}
            if isinstance(self.adapter, AthenaEngineAdapter):
                # Athena has two table formats: Hive (the default) and Iceberg. TableDiff requires that
                # the formats be the same for the source, target, and temp tables.
                source_table_type = self.adapter._query_table_type(self.source_table)
                target_table_type = self.adapter._query_table_type(self.target_table)

                if source_table_type == "iceberg" and target_table_type == "iceberg":
                    # Sets the temp table's format to Iceberg.
                    temp_table_kwargs["table_format"] = "iceberg"
                # If neither source nor target table is Iceberg, it defaults to Hive (Athena's default).
                elif source_table_type == "iceberg" or target_table_type == "iceberg":
                    raise SQLMeshError(
                        f"Source table '{self.source}' format '{source_table_type}' and target table '{self.target}' format '{target_table_type}' "
                        f"do not match for Athena. Diffing between different table formats is not supported."
                    )

            with self.adapter.temp_table(
                query, name=temp_table, target_columns_to_types=None, **temp_table_kwargs
            ) as table:
                summary_sums = [
                    exp.func("SUM", "s_exists").as_("s_count"),
                    exp.func("SUM", "t_exists").as_("t_count"),
                    exp.func("SUM", "row_joined").as_("join_count"),
                    exp.func("SUM", "null_grain").as_("null_grain_count"),
                    exp.func("SUM", "row_full_match").as_("full_match_count"),
                    *(exp.func("SUM", name(c)).as_(c.alias) for c in comparisons),
                ]

                if not skip_grain_check:
                    summary_sums.extend(
                        [
                            parse_one(f"COUNT(DISTINCT(s__{SQLMESH_JOIN_KEY_COL}))").as_(
                                "distinct_count_s"
                            ),
                            parse_one(f"COUNT(DISTINCT(t__{SQLMESH_JOIN_KEY_COL}))").as_(
                                "distinct_count_t"
                            ),
                        ]
                    )

                summary_query = exp.select(*summary_sums).from_(table)

                stats_df = self.adapter.fetchdf(summary_query, quote_identifiers=True).fillna(0)
                stats_df["s_only_count"] = stats_df["s_count"] - stats_df["join_count"]
                stats_df["t_only_count"] = stats_df["t_count"] - stats_df["join_count"]
                stats = stats_df.iloc[0].to_dict()

                column_stats_query = (
                    exp.select(
                        *(
                            exp.func(
                                "ROUND",
                                100
                                * (
                                    exp.cast(
                                        exp.func("SUM", name(c)), exp.DataType.build("NUMERIC")
                                    )
                                    / exp.func("COUNT", name(c))
                                ),
                                9,
                            ).as_(c.alias)
                            for c in comparisons
                        )
                    )
                    .from_(table)
                    .where(exp.column("row_joined").eq(exp.Literal.number(1)))
                )

                column_stats = (
                    self.adapter.fetchdf(column_stats_query, quote_identifiers=True)
                    .T.rename(
                        columns={0: "pct_match"},
                        index=lambda x: str(x).replace("_matches", "") if x else "",
                    )
                    # errors=ignore because all the index_cols might not be present in the DF if the `on` condition was something like "s.id == t.item_id"
                    # because these would not be present in the matching_cols (since they have different names) and thus no summary would be generated
                    .drop(index=index_cols, errors="ignore")
                )

                sample = self._fetch_sample(
                    table, s_selects, s_index, t_selects, t_index, self.limit
                )

                # Joined sample: the source key columns plus an s__/t__ pair for every
                # shared column that did not match 100%.
                joined_sample_cols = [f"s__{c}" for c in s_index_names]
                for c in column_stats[column_stats["pct_match"] < 100].index:
                    joined_sample_cols.extend((f"s__{c}", f"t__{c}"))

                # Key columns are shown without their side prefix. Only the prefix is
                # stripped, so names containing "__" are not truncated.
                joined_renamed_cols = {
                    c: (
                        self._strip_side_prefix(c)
                        if self._strip_side_prefix(c) in index_cols
                        else c
                    )
                    for c in joined_sample_cols
                }

                if (
                    self.source_alias
                    and self.target_alias
                    and self.source != self.source_alias
                    and self.target != self.target_alias
                ):
                    # Swap only the leading side marker for the upper-cased alias;
                    # str.replace would also rewrite "s__"/"t__" inside column names.
                    alias_prefixes = {
                        "s__": f"{self.source_alias.upper()}__",
                        "t__": f"{self.target_alias.upper()}__",
                    }
                    joined_renamed_cols = {
                        c: (
                            alias_prefixes[n[:3]] + self._strip_side_prefix(n)
                            if n[:3] in alias_prefixes
                            else n
                        )
                        for c, n in joined_renamed_cols.items()
                    }

                joined_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "common_rows"][
                    joined_sample_cols
                ].rename(columns=joined_renamed_cols)

                s_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "source_only"][
                    [
                        *[f"s__{c}" for c in s_index_names],
                        *[f"s__{c}" for c in source_schema if c not in s_index_names],
                    ]
                ]
                s_sample = s_sample.rename(
                    columns={c: self._strip_side_prefix(c) for c in s_sample.columns}
                )

                t_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "target_only"][
                    [
                        *[f"t__{c}" for c in t_index_names],
                        *[f"t__{c}" for c in target_schema if c not in t_index_names],
                    ]
                ]
                t_sample = t_sample.rename(
                    columns={c: self._strip_side_prefix(c) for c in t_sample.columns}
                )

                sample = sample.drop(
                    columns=[
                        f"s__{SQLMESH_JOIN_KEY_COL}",
                        f"t__{SQLMESH_JOIN_KEY_COL}",
                        SQLMESH_SAMPLE_TYPE_COL,
                    ],
                )

                self._row_diff = RowDiff(
                    source=self.source,
                    target=self.target,
                    stats=stats,
                    column_stats=column_stats,
                    sample=sample,
                    joined_sample=joined_sample,
                    s_sample=s_sample,
                    t_sample=t_sample,
                    source_alias=self.source_alias,
                    target_alias=self.target_alias,
                    model_name=self.model_name,
                    decimals=self.decimals,
                )

        return self._row_diff

    def _fetch_sample(
        self,
        sample_table: exp.Table,
        s_selects: t.Dict[str, exp.Expr],
        s_index: t.List[exp.Column],
        t_selects: t.Dict[str, exp.Expr],
        t_index: t.List[exp.Column],
        limit: int,
    ) -> pd.DataFrame:
        """Fetch up to ``limit`` rows of each kind of difference from ``sample_table``.

        The three kinds are source-only rows, target-only rows, and joined rows that are
        not a full match. Each row is tagged in the SQLMESH_SAMPLE_TYPE_COL column with
        ``source_only``, ``target_only`` or ``common_rows``.
        """
        rendered_data_column_names = [
            name(s) for s in list(s_selects.values()) + list(t_selects.values())
        ]
        sample_type = exp.to_identifier(SQLMESH_SAMPLE_TYPE_COL)

        source_only_sample = (
            exp.select(
                exp.Literal.string("source_only").as_(sample_type), *rendered_data_column_names
            )
            .from_(sample_table)
            .where(exp.and_(exp.column("s_exists").eq(1), exp.column("row_joined").eq(0)))
            .order_by(*(name(s_selects[c.name]) for c in s_index))
            .limit(limit)
        )

        target_only_sample = (
            exp.select(
                exp.Literal.string("target_only").as_(sample_type), *rendered_data_column_names
            )
            .from_(sample_table)
            .where(exp.and_(exp.column("t_exists").eq(1), exp.column("row_joined").eq(0)))
            .order_by(*(name(t_selects[c.name]) for c in t_index))
            .limit(limit)
        )

        common_rows_sample = (
            exp.select(
                exp.Literal.string("common_rows").as_(sample_type), *rendered_data_column_names
            )
            .from_(sample_table)
            .where(exp.and_(exp.column("row_joined").eq(1), exp.column("row_full_match").eq(0)))
            .order_by(
                *(name(s_selects[c.name]) for c in s_index),
                *(name(t_selects[c.name]) for c in t_index),
            )
            .limit(limit)
        )

        query = (
            exp.Select()
            .with_("source_only", source_only_sample)
            .with_("target_only", target_only_sample)
            .with_("common_rows", common_rows_sample)
            .select(sample_type, *rendered_data_column_names)
            .from_("source_only")
            .union(
                exp.select(sample_type, *rendered_data_column_names).from_("target_only"),
                distinct=False,
            )
            .union(
                exp.select(sample_type, *rendered_data_column_names).from_("common_rows"),
                distinct=False,
            )
        )

        return self.adapter.fetchdf(query, quote_identifiers=True)
743
744
def name(e: exp.Expr) -> str:
    """Return the alias of an aliased expression rendered as a quoted SQL identifier.

    ``e`` must carry an ``alias`` argument, e.g. an expression built with ``.as_(...)``.
    """
    alias_node = e.args["alias"]
    rendered = alias_node.sql(identify=True)
    return rendered
# Alias of the column holding each side's (possibly concatenated) join key.
SQLMESH_JOIN_KEY_COL = "__sqlmesh_join_key"
# Alias of the column tagging each sampled row by its sample type.
SQLMESH_SAMPLE_TYPE_COL = "__sqlmesh_sample_type"
class SchemaDiff(sqlmesh.utils.pydantic.PydanticModel):
 31class SchemaDiff(PydanticModel, frozen=True):
 32    """An object containing the schema difference between a source and target table."""
 33
 34    source: str
 35    target: str
 36    source_schema: t.Dict[str, exp.DataType]
 37    target_schema: t.Dict[str, exp.DataType]
 38    source_alias: t.Optional[str] = None
 39    target_alias: t.Optional[str] = None
 40    model_name: t.Optional[str] = None
 41    ignore_case: bool = False
 42
 43    @property
 44    def _comparable_source_schema(self) -> t.Dict[str, exp.DataType]:
 45        return (
 46            self._lowercase_schema_names(self.source_schema)
 47            if self.ignore_case
 48            else self.source_schema
 49        )
 50
 51    @property
 52    def _comparable_target_schema(self) -> t.Dict[str, exp.DataType]:
 53        return (
 54            self._lowercase_schema_names(self.target_schema)
 55            if self.ignore_case
 56            else self.target_schema
 57        )
 58
 59    def _lowercase_schema_names(
 60        self, schema: t.Dict[str, exp.DataType]
 61    ) -> t.Dict[str, exp.DataType]:
 62        return {c.lower(): t for c, t in schema.items()}
 63
 64    def _original_column_name(
 65        self, maybe_lowercased_column_name: str, schema: t.Dict[str, exp.DataType]
 66    ) -> str:
 67        if not self.ignore_case:
 68            return maybe_lowercased_column_name
 69
 70        return next(c for c in schema if c.lower() == maybe_lowercased_column_name)
 71
 72    @property
 73    def added(self) -> t.List[t.Tuple[str, exp.DataType]]:
 74        """Added columns."""
 75        return [
 76            (self._original_column_name(c, self.target_schema), t)
 77            for c, t in self._comparable_target_schema.items()
 78            if c not in self._comparable_source_schema
 79        ]
 80
 81    @property
 82    def removed(self) -> t.List[t.Tuple[str, exp.DataType]]:
 83        """Removed columns."""
 84        return [
 85            (self._original_column_name(c, self.source_schema), t)
 86            for c, t in self._comparable_source_schema.items()
 87            if c not in self._comparable_target_schema
 88        ]
 89
 90    @property
 91    def modified(self) -> t.Dict[str, t.Tuple[exp.DataType, exp.DataType]]:
 92        """Columns with modified types."""
 93        modified = {}
 94        for column in self._comparable_source_schema.keys() & self._comparable_target_schema.keys():
 95            source_type = self._comparable_source_schema[column]
 96            target_type = self._comparable_target_schema[column]
 97
 98            if source_type != target_type:
 99                modified[column] = (source_type, target_type)
100
101        if self.ignore_case:
102            modified = {
103                self._original_column_name(c, self.source_schema): dt for c, dt in modified.items()
104            }
105
106        return modified
107
108    @property
109    def has_changes(self) -> bool:
110        """Does the schema contain any changes at all between source and target"""
111        return bool(self.added or self.removed or self.modified)

An object containing the schema difference between a source and target table.

source: str
target: str
source_schema: Dict[str, sqlglot.expressions.datatypes.DataType]
target_schema: Dict[str, sqlglot.expressions.datatypes.DataType]
source_alias: Optional[str]
target_alias: Optional[str]
model_name: Optional[str]
ignore_case: bool
added: List[Tuple[str, sqlglot.expressions.datatypes.DataType]]
72    @property
73    def added(self) -> t.List[t.Tuple[str, exp.DataType]]:
74        """Added columns: present in the target schema but not in the source, as (name, target type) pairs."""
75        return [
76            (self._original_column_name(c, self.target_schema), t)
77            for c, t in self._comparable_target_schema.items()
78            if c not in self._comparable_source_schema
79        ]

Added columns.

removed: List[Tuple[str, sqlglot.expressions.datatypes.DataType]]
81    @property
82    def removed(self) -> t.List[t.Tuple[str, exp.DataType]]:
83        """Removed columns: present in the source schema but not in the target, as (name, source type) pairs."""
84        return [
85            (self._original_column_name(c, self.source_schema), t)
86            for c, t in self._comparable_source_schema.items()
87            if c not in self._comparable_target_schema
88        ]

Removed columns.

modified: Dict[str, Tuple[sqlglot.expressions.datatypes.DataType, sqlglot.expressions.datatypes.DataType]]
 90    @property
 91    def modified(self) -> t.Dict[str, t.Tuple[exp.DataType, exp.DataType]]:
 92        """Columns present in both schemas whose types differ, mapped to (source type, target type)."""
 93        modified = {}
 94        for column in self._comparable_source_schema.keys() & self._comparable_target_schema.keys():  # NOTE(review): set intersection, so result order is not deterministic
 95            source_type = self._comparable_source_schema[column]
 96            target_type = self._comparable_target_schema[column]
 97
 98            if source_type != target_type:
 99                modified[column] = (source_type, target_type)
100
101        if self.ignore_case:  # re-key by the original (un-lowercased) source column names
102            modified = {
103                self._original_column_name(c, self.source_schema): dt for c, dt in modified.items()
104            }
105
106        return modified

Columns with modified types.

has_changes: bool
108    @property
109    def has_changes(self) -> bool:
110        """Whether any columns were added, removed, or had their type modified between source and target."""
111        return bool(self.added or self.removed or self.modified)

Whether the schema contains any changes at all (added, removed, or modified columns) between source and target.

model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': (), 'frozen': True}

Configuration for the model; it should be a dictionary conforming to `pydantic.config.ConfigDict`.

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class RowDiff(sqlmesh.utils.pydantic.PydanticModel):
class RowDiff(PydanticModel, frozen=True):
    """Summary statistics and a sample dataframe.

    ``stats`` holds the aggregated counts (``s_count``, ``t_count``, ``join_count``,
    ``full_match_count``, ``s_only_count``, ``t_only_count``, ...) that the properties
    below expose as typed values and percentages.
    """

    source: str
    target: str
    stats: t.Dict[str, float]
    sample: pd.DataFrame
    joined_sample: pd.DataFrame
    s_sample: pd.DataFrame
    t_sample: pd.DataFrame
    column_stats: pd.DataFrame
    source_alias: t.Optional[str] = None
    target_alias: t.Optional[str] = None
    model_name: t.Optional[str] = None
    decimals: int = 3

    _types_resolved: t.ClassVar[bool] = False

    def __new__(cls, *args: t.Any, **kwargs: t.Any) -> RowDiff:
        # Lazily rebuild the model the first time it is instantiated (see _resolve_types).
        if not cls._types_resolved:
            cls._resolve_types()
        return super().__new__(cls)

    @classmethod
    def _resolve_types(cls) -> None:
        # Pandas is imported by type checking so we need to resolve the types with the real import before instantiating
        import pandas as pd  # noqa

        cls.model_rebuild()
        cls._types_resolved = True

    @property
    def source_count(self) -> int:
        """Count of the source."""
        return int(self.stats["s_count"])

    @property
    def target_count(self) -> int:
        """Count of the target."""
        return int(self.stats["t_count"])

    @property
    def empty(self) -> bool:
        """Whether both sides have no rows and no source-only / target-only rows."""
        return (
            self.source_count == 0
            and self.target_count == 0
            and self.s_only_count == 0
            and self.t_only_count == 0
        )

    @property
    def count_pct_change(self) -> float:
        """The percentage change from the source row count to the target row count.

        Returns ``math.inf`` when the source has no rows.
        """
        if self.source_count == 0:
            return math.inf
        return ((self.target_count - self.source_count) / self.source_count) * 100

    @property
    def join_count(self) -> int:
        """Count of successfully joined rows."""
        return int(self.stats["join_count"])

    @property
    def full_match_count(self) -> int:
        """The number of rows for which shared columns have the same values."""
        return int(self.stats["full_match_count"])

    @property
    def full_match_pct(self) -> float:
        """The percentage of rows for which shared columns have the same values."""
        # Doubled because each fully matched pair covers one source row and one target row.
        return self._pct(2 * self.full_match_count)

    @property
    def partial_match_count(self) -> int:
        """The number of joined rows that are not full matches."""
        return self.join_count - self.full_match_count

    @property
    def partial_match_pct(self) -> float:
        """The percentage of rows that were joined but are not full matches."""
        return self._pct(2 * self.partial_match_count)

    @property
    def s_only_count(self) -> int:
        """Count of rows only present in source."""
        return int(self.stats["s_only_count"])

    @property
    def s_only_pct(self) -> float:
        """The percentage of rows that are only present in source."""
        return self._pct(self.s_only_count)

    @property
    def t_only_count(self) -> int:
        """Count of rows only present in target."""
        return int(self.stats["t_only_count"])

    @property
    def t_only_pct(self) -> float:
        """The percentage of rows that are only present in target."""
        return self._pct(self.t_only_count)

    def _pct(self, numerator: int) -> float:
        """Express ``numerator`` as a percentage of all source and target rows, rounded to 2 places.

        Returns 0.0 when both source and target are empty instead of raising
        ``ZeroDivisionError``.
        """
        total = self.source_count + self.target_count
        if total == 0:
            return 0.0
        return round((numerator / total) * 100, 2)

Summary statistics and a sample dataframe.

source: str
target: str
stats: Dict[str, float]
sample: pandas.core.frame.DataFrame
joined_sample: pandas.core.frame.DataFrame
s_sample: pandas.core.frame.DataFrame
t_sample: pandas.core.frame.DataFrame
column_stats: pandas.core.frame.DataFrame
source_alias: Optional[str]
target_alias: Optional[str]
model_name: Optional[str]
decimals: int
source_count: int
145    @property
146    def source_count(self) -> int:
147        """Number of source rows, read from the ``s_count`` stat."""
148        return int(self.stats["s_count"])

Count of the source.

target_count: int
150    @property
151    def target_count(self) -> int:
152        """Number of target rows, read from the ``t_count`` stat."""
153        return int(self.stats["t_count"])

Count of the target.

empty: bool
155    @property
156    def empty(self) -> bool:  # True only when all four counts checked below are zero
157        return (
158            self.source_count == 0
159            and self.target_count == 0
160            and self.s_only_count == 0
161            and self.t_only_count == 0
162        )
count_pct_change: float
164    @property
165    def count_pct_change(self) -> float:
166        """The percentage change from the source row count to the target row count."""
167        if self.source_count == 0:
168            return math.inf  # also returned when both counts are 0
169        return ((self.target_count - self.source_count) / self.source_count) * 100

The percentage change from the source row count to the target row count.

join_count: int
171    @property
172    def join_count(self) -> int:
173        """Count of successfully joined rows, read from the ``join_count`` stat."""
174        return int(self.stats["join_count"])

Count of successfully joined rows.

full_match_count: int
176    @property
177    def full_match_count(self) -> int:
178        """The number of rows for which all shared columns have the same values (``full_match_count`` stat)."""
179        return int(self.stats["full_match_count"])

The number of rows for which shared columns have the same values.

full_match_pct: float
181    @property
182    def full_match_pct(self) -> float:
183        """The percentage of rows for which shared columns have the same values."""
184        return self._pct(2 * self.full_match_count)  # x2: each matched pair covers one source and one target row

The percentage of rows for which shared columns have the same values.

partial_match_count: int
186    @property
187    def partial_match_count(self) -> int:
188        """The number of joined rows that are not full matches."""
189        return self.join_count - self.full_match_count

The number of joined rows that are not full matches, i.e. rows where at least one shared column differs.

partial_match_pct: float
191    @property
192    def partial_match_pct(self) -> float:
193        """The percentage of rows that were joined but are not full matches."""
194        return self._pct(2 * self.partial_match_count)

The percentage of rows that were joined but are not full matches.

s_only_count: int
196    @property
197    def s_only_count(self) -> int:
198        """Count of rows only present in source, read from the ``s_only_count`` stat."""
199        return int(self.stats["s_only_count"])

Count of rows only present in source.

s_only_pct: float
201    @property
202    def s_only_pct(self) -> float:
203        """The percentage of all source and target rows that are only present in source."""
204        return self._pct(self.s_only_count)

The percentage of rows that are only present in source.

t_only_count: int
206    @property
207    def t_only_count(self) -> int:
208        """Count of rows only present in target, read from the ``t_only_count`` stat."""
209        return int(self.stats["t_only_count"])

Count of rows only present in target.

t_only_pct: float
211    @property
212    def t_only_pct(self) -> float:
213        """The percentage of all source and target rows that are only present in target."""
214        return self._pct(self.t_only_count)

The percentage of rows that are only present in target.

model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': (), 'frozen': True}

Configuration for the model; it should be a dictionary conforming to `pydantic.config.ConfigDict`.

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class TableDiff:
220class TableDiff:
221    """Calculates differences between tables, taking into account schema and row level differences."""
222
223    def __init__(
224        self,
225        adapter: EngineAdapter,
226        source: TableName,
227        target: TableName,
228        on: t.List[str] | exp.Expr,
229        skip_columns: t.List[str] | None = None,
230        where: t.Optional[str | exp.Expr] = None,
231        limit: int = 20,
232        source_alias: t.Optional[str] = None,
233        target_alias: t.Optional[str] = None,
234        model_name: t.Optional[str] = None,
235        model_dialect: t.Optional[str] = None,
236        decimals: int = 3,
237        schema_diff_ignore_case: bool = False,
238    ):
239        if not isinstance(adapter, RowDiffMixin):
240            raise ValueError(f"Engine {adapter} doesnt support RowDiff")
241
242        self.adapter = adapter
243        self.source = source
244        self.target = target
245        self.dialect = adapter.dialect
246        self.source_table = exp.to_table(self.source, dialect=self.dialect)
247        self.target_table = exp.to_table(self.target, dialect=self.dialect)
248        self.where = exp.condition(where, dialect=self.dialect) if where else None
249        self.limit = limit
250        self.model_name = model_name
251        self.model_dialect = model_dialect
252        self.decimals = decimals
253        self.schema_diff_ignore_case = schema_diff_ignore_case
254
255        # Support environment aliases for diff output improvement in certain cases
256        self.source_alias = source_alias
257        self.target_alias = target_alias
258
259        self.skip_columns = {
260            normalize_identifiers(
261                exp.parse_identifier(t.cast(str, col)),
262                dialect=self.model_dialect or self.dialect,
263            ).name
264            for col in ensure_list(skip_columns)
265        }
266
267        self._on = on
268        self._row_diff: t.Optional[RowDiff] = None
269
270    @cached_property
271    def source_schema(self) -> t.Dict[str, exp.DataType]:
272        return self.adapter.columns(self.source_table)
273
274    @cached_property
275    def target_schema(self) -> t.Dict[str, exp.DataType]:
276        return self.adapter.columns(self.target_table)
277
278    @cached_property
279    def key_columns(self) -> t.Tuple[t.List[exp.Column], t.List[exp.Column], t.List[str]]:
280        dialect = self.model_dialect or self.dialect
281
282        # If the columns to join on are explicitly specified, then just return them
283        if isinstance(self._on, (list, tuple)):
284            identifiers = [normalize_identifiers(c, dialect=dialect) for c in self._on]
285            s_index = [exp.column(c, "s") for c in identifiers]
286            t_index = [exp.column(c, "t") for c in identifiers]
287            return s_index, t_index, [i.name for i in identifiers]
288
289        # Otherwise, we need to parse them out of the supplied "on" condition
290        index_cols = []
291        s_index = []
292        t_index = []
293
294        normalize_identifiers(self._on, dialect=dialect)
295        for col in self._on.find_all(exp.Column):
296            index_cols.append(col.name)
297            if col.table.lower() == "s":
298                s_index.append(col)
299            elif col.table.lower() == "t":
300                t_index.append(col)
301
302        index_cols = list(dict.fromkeys(index_cols))
303        s_index = list(dict.fromkeys(s_index))
304        t_index = list(dict.fromkeys(t_index))
305
306        return s_index, t_index, index_cols
307
308    @property
309    def source_key_expression(self) -> exp.Expr:
310        s_index, _, _ = self.key_columns
311        return self._key_expression(s_index, self.source_schema)
312
313    @property
314    def target_key_expression(self) -> exp.Expr:
315        _, t_index, _ = self.key_columns
316        return self._key_expression(t_index, self.target_schema)
317
318    def _key_expression(
319        self, cols: t.List[exp.Column], schema: t.Dict[str, exp.DataType]
320    ) -> exp.Expr:
321        # if there is a single column, dont do anything fancy to it in order to allow existing indexes to be hit
322        if len(cols) == 1:
323            return exp.to_column(cols[0].name)
324
325        # if there are multiple columns, turn them into a single column by stringify-ing/concatenating them together
326        key_columns_to_types = {key.name: schema[key.name] for key in cols}
327        return self.adapter.concat_columns(key_columns_to_types, self.decimals)
328
329    def schema_diff(self) -> SchemaDiff:
330        return SchemaDiff(
331            source=self.source,
332            target=self.target,
333            source_schema=self.source_schema,
334            target_schema=self.target_schema,
335            source_alias=self.source_alias,
336            target_alias=self.target_alias,
337            model_name=self.model_name,
338            ignore_case=self.schema_diff_ignore_case,
339        )
340
341    def row_diff(
342        self, temp_schema: t.Optional[str] = None, skip_grain_check: bool = False
343    ) -> RowDiff:
344        if self._row_diff is None:
345            source_schema = {
346                c: t for c, t in self.source_schema.items() if c not in self.skip_columns
347            }
348            target_schema = {
349                c: t for c, t in self.target_schema.items() if c not in self.skip_columns
350            }
351
352            s_selects = {c: exp.column(c, "s").as_(f"s__{c}") for c in source_schema}
353            t_selects = {c: exp.column(c, "t").as_(f"t__{c}") for c in target_schema}
354            s_selects[SQLMESH_JOIN_KEY_COL] = exp.column(SQLMESH_JOIN_KEY_COL, "s").as_(
355                f"s__{SQLMESH_JOIN_KEY_COL}"
356            )
357            t_selects[SQLMESH_JOIN_KEY_COL] = exp.column(SQLMESH_JOIN_KEY_COL, "t").as_(
358                f"t__{SQLMESH_JOIN_KEY_COL}"
359            )
360
361            matched_columns = {c: t for c, t in source_schema.items() if t == target_schema.get(c)}
362
363            s_index, t_index, index_cols = self.key_columns
364            s_index_names = [c.name for c in s_index]
365            t_index_names = [t.name for t in t_index]
366
367            def _column_expr(name: str, table: str) -> exp.Expr:
368                column_type = matched_columns[name]
369                qualified_column = exp.column(name, table)
370
371                if column_type.is_type(*exp.DataType.REAL_TYPES):
372                    return self.adapter._normalize_decimal_value(qualified_column, self.decimals)
373                if column_type.is_type(*exp.DataType.NESTED_TYPES):
374                    return self.adapter._normalize_nested_value(qualified_column)
375
376                return qualified_column
377
378            comparisons = [
379                exp.Case()
380                .when(_column_expr(c, "s").eq(_column_expr(c, "t")), exp.Literal.number(1))
381                .when(
382                    exp.column(c, "s").is_(exp.Null()) & exp.column(c, "t").is_(exp.Null()),
383                    exp.Literal.number(1),
384                )
385                .when(
386                    exp.column(c, "s").is_(exp.Null()) | exp.column(c, "t").is_(exp.Null()),
387                    exp.Literal.number(0),
388                )
389                .else_(exp.Literal.number(0))
390                .as_(f"{c}_matches")
391                for c, t in matched_columns.items()
392            ]
393
394            source_query = (
395                exp.select(
396                    *(exp.column(c) for c in source_schema),
397                    self.source_key_expression.as_(SQLMESH_JOIN_KEY_COL),
398                )
399                .from_(self.source_table.as_("s"))
400                .where(self.where)
401            )
402            target_query = (
403                exp.select(
404                    *(exp.column(c) for c in target_schema),
405                    self.target_key_expression.as_(SQLMESH_JOIN_KEY_COL),
406                )
407                .from_(self.target_table.as_("t"))
408                .where(self.where)
409            )
410
411            # Ensure every column is qualified with the alias in the source and target queries
412            for col in find_all_in_scope(source_query, exp.Column):
413                col.set("table", exp.to_identifier("s"))
414            for col in find_all_in_scope(target_query, exp.Column):
415                col.set("table", exp.to_identifier("t"))
416
417            source_table = exp.table_("__source")
418            target_table = exp.table_("__target")
419            stats_table = exp.table_("__stats")
420
421            stats_query = (
422                exp.select(
423                    *s_selects.values(),
424                    *t_selects.values(),
425                    exp.func(
426                        "IF", exp.column(SQLMESH_JOIN_KEY_COL, "s").is_(exp.Null()).not_(), 1, 0
427                    ).as_("s_exists"),
428                    exp.func(
429                        "IF", exp.column(SQLMESH_JOIN_KEY_COL, "t").is_(exp.Null()).not_(), 1, 0
430                    ).as_("t_exists"),
431                    exp.func(
432                        "IF",
433                        exp.column(SQLMESH_JOIN_KEY_COL, "s").eq(
434                            exp.column(SQLMESH_JOIN_KEY_COL, "t")
435                        ),
436                        1,
437                        0,
438                    ).as_("row_joined"),
439                    exp.func(
440                        "IF",
441                        exp.or_(
442                            *(
443                                exp.and_(
444                                    s.is_(exp.Null()),
445                                    t.is_(exp.Null()),
446                                )
447                                for s, t in zip(s_index, t_index)
448                            ),
449                        ),
450                        1,
451                        0,
452                    ).as_("null_grain"),
453                    *comparisons,
454                )
455                .from_(source_table.as_("s"))
456                .join(
457                    target_table.as_("t"),
458                    on=exp.column(SQLMESH_JOIN_KEY_COL, "s").eq(
459                        exp.column(SQLMESH_JOIN_KEY_COL, "t")
460                    ),
461                    join_type="FULL",
462                )
463            )
464
465            base_query = (
466                exp.Select()
467                .with_(source_table, source_query)
468                .with_(target_table, target_query)
469                .with_(stats_table, stats_query)
470                .select(
471                    "*",
472                    exp.Case()
473                    .when(
474                        exp.and_(
475                            *[
476                                exp.column(f"{c}_matches").eq(exp.Literal.number(1))
477                                for c in matched_columns
478                            ]
479                        ),
480                        exp.Literal.number(1),
481                    )
482                    .else_(exp.Literal.number(0))
483                    .as_("row_full_match"),
484                )
485                .from_(stats_table)
486            )
487
488            query = self.adapter.ensure_nulls_for_unmatched_after_join(
489                quote_identifiers(base_query.copy(), dialect=self.model_dialect or self.dialect)
490            )
491
492            if not temp_schema:
493                temp_schema = "sqlmesh_temp"
494
495            schema = to_schema(temp_schema, dialect=self.dialect)
496            temp_table = exp.table_("diff", db=schema.db, catalog=schema.catalog, quoted=True)
497
498            temp_table_kwargs: t.Dict[str, t.Any] = {}
499            if isinstance(self.adapter, AthenaEngineAdapter):
500                # Athena has two table formats: Hive (the default) and Iceberg. TableDiff requires that
501                # the formats be the same for the source, target, and temp tables.
502                source_table_type = self.adapter._query_table_type(self.source_table)
503                target_table_type = self.adapter._query_table_type(self.target_table)
504
505                if source_table_type == "iceberg" and target_table_type == "iceberg":
506                    temp_table_kwargs["table_format"] = "iceberg"
507                # Sets the temp table's format to Iceberg.
508                # If neither source nor target table is Iceberg, it defaults to Hive (Athena's default).
509                elif source_table_type == "iceberg" or target_table_type == "iceberg":
510                    raise SQLMeshError(
511                        f"Source table '{self.source}' format '{source_table_type}' and target table '{self.target}' format '{target_table_type}' "
512                        f"do not match for Athena. Diffing between different table formats is not supported."
513                    )
514
515            with self.adapter.temp_table(
516                query, name=temp_table, target_columns_to_types=None, **temp_table_kwargs
517            ) as table:
518                summary_sums = [
519                    exp.func("SUM", "s_exists").as_("s_count"),
520                    exp.func("SUM", "t_exists").as_("t_count"),
521                    exp.func("SUM", "row_joined").as_("join_count"),
522                    exp.func("SUM", "null_grain").as_("null_grain_count"),
523                    exp.func("SUM", "row_full_match").as_("full_match_count"),
524                    *(exp.func("SUM", name(c)).as_(c.alias) for c in comparisons),
525                ]
526
527                if not skip_grain_check:
528                    summary_sums.extend(
529                        [
530                            parse_one(f"COUNT(DISTINCT(s__{SQLMESH_JOIN_KEY_COL}))").as_(
531                                "distinct_count_s"
532                            ),
533                            parse_one(f"COUNT(DISTINCT(t__{SQLMESH_JOIN_KEY_COL}))").as_(
534                                "distinct_count_t"
535                            ),
536                        ]
537                    )
538
539                summary_query = exp.select(*summary_sums).from_(table)
540
541                stats_df = self.adapter.fetchdf(summary_query, quote_identifiers=True).fillna(0)
542                stats_df["s_only_count"] = stats_df["s_count"] - stats_df["join_count"]
543                stats_df["t_only_count"] = stats_df["t_count"] - stats_df["join_count"]
544                stats = stats_df.iloc[0].to_dict()
545
546                column_stats_query = (
547                    exp.select(
548                        *(
549                            exp.func(
550                                "ROUND",
551                                100
552                                * (
553                                    exp.cast(
554                                        exp.func("SUM", name(c)), exp.DataType.build("NUMERIC")
555                                    )
556                                    / exp.func("COUNT", name(c))
557                                ),
558                                9,
559                            ).as_(c.alias)
560                            for c in comparisons
561                        )
562                    )
563                    .from_(table)
564                    .where(exp.column("row_joined").eq(exp.Literal.number(1)))
565                )
566
567                column_stats = (
568                    self.adapter.fetchdf(column_stats_query, quote_identifiers=True)
569                    .T.rename(
570                        columns={0: "pct_match"},
571                        index=lambda x: str(x).replace("_matches", "") if x else "",
572                    )
573                    # errors=ignore because all the index_cols might not be present in the DF if the `on` condition was something like "s.id == t.item_id"
574                    # because these would not be present in the matching_cols (since they have different names) and thus no summary would be generated
575                    .drop(index=index_cols, errors="ignore")
576                )
577
578                sample = self._fetch_sample(
579                    table, s_selects, s_index, t_selects, t_index, self.limit
580                )
581
582                joined_sample_cols = [f"s__{c}" for c in s_index_names]
583                comparison_cols = [
584                    (f"s__{c}", f"t__{c}")
585                    for c in column_stats[column_stats["pct_match"] < 100].index
586                ]
587
588                for cols in comparison_cols:
589                    joined_sample_cols.extend(cols)
590
591                joined_renamed_cols = {
592                    c: c.split("__")[1] if c.split("__")[1] in index_cols else c
593                    for c in joined_sample_cols
594                }
595
596                if (
597                    self.source_alias
598                    and self.target_alias
599                    and self.source != self.source_alias
600                    and self.target != self.target_alias
601                ):
602                    joined_renamed_cols = {
603                        c: (
604                            n.replace(
605                                "s__",
606                                f"{self.source_alias.upper()}__",
607                            )
608                            if n.startswith("s__")
609                            else n
610                        )
611                        for c, n in joined_renamed_cols.items()
612                    }
613                    joined_renamed_cols = {
614                        c: (
615                            n.replace(
616                                "t__",
617                                f"{self.target_alias.upper()}__",
618                            )
619                            if n.startswith("t__")
620                            else n
621                        )
622                        for c, n in joined_renamed_cols.items()
623                    }
624
625                joined_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "common_rows"][
626                    joined_sample_cols
627                ]
628                joined_sample.rename(
629                    columns=joined_renamed_cols,
630                    inplace=True,
631                )
632
633                s_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "source_only"][
634                    [
635                        *[f"s__{c}" for c in s_index_names],
636                        *[f"s__{c}" for c in source_schema if c not in s_index_names],
637                    ]
638                ]
639                s_sample.rename(
640                    columns={c: c.replace("s__", "") for c in s_sample.columns}, inplace=True
641                )
642
643                t_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "target_only"][
644                    [
645                        *[f"t__{c}" for c in t_index_names],
646                        *[f"t__{c}" for c in target_schema if c not in t_index_names],
647                    ]
648                ]
649                t_sample.rename(
650                    columns={c: c.replace("t__", "") for c in t_sample.columns}, inplace=True
651                )
652
653                sample.drop(
654                    columns=[
655                        f"s__{SQLMESH_JOIN_KEY_COL}",
656                        f"t__{SQLMESH_JOIN_KEY_COL}",
657                        SQLMESH_SAMPLE_TYPE_COL,
658                    ],
659                    inplace=True,
660                )
661
662                self._row_diff = RowDiff(
663                    source=self.source,
664                    target=self.target,
665                    stats=stats,
666                    column_stats=column_stats,
667                    sample=sample,
668                    joined_sample=joined_sample,
669                    s_sample=s_sample,
670                    t_sample=t_sample,
671                    source_alias=self.source_alias,
672                    target_alias=self.target_alias,
673                    model_name=self.model_name,
674                    decimals=self.decimals,
675                )
676
677        return self._row_diff
678
def _fetch_sample(
    self,
    sample_table: exp.Table,
    s_selects: t.Dict[str, exp.Expr],
    s_index: t.List[exp.Column],
    t_selects: t.Dict[str, exp.Expr],
    t_index: t.List[exp.Column],
    limit: int,
) -> pd.DataFrame:
    """Fetch a bounded sample of differing rows from the materialized diff table.

    Up to ``limit`` rows are taken from each of three buckets, which are stacked
    with ``UNION ALL``. The bucket label is written to the
    ``SQLMESH_SAMPLE_TYPE_COL`` column:

    * ``source_only``: ``s_exists = 1`` and ``row_joined = 0``, ordered by the source keys
    * ``target_only``: ``t_exists = 1`` and ``row_joined = 0``, ordered by the target keys
    * ``common_rows``: ``row_joined = 1`` and ``row_full_match = 0``, ordered by the
      source keys, then the target keys

    Args:
        sample_table: Table holding the per-row diff flags and the prefixed data columns.
        s_selects: Aliased source-side select expressions, keyed by column name. Their
            aliases are the columns pulled into the sample.
        s_index: Source key columns. Their aliases in ``s_selects`` drive the ordering.
        t_selects: Aliased target-side select expressions, keyed by column name.
        t_index: Target key columns. Their aliases in ``t_selects`` drive the ordering.
        limit: Maximum number of rows fetched per bucket.

    Returns:
        The sample DataFrame returned by ``adapter.fetchdf``.
    """
    data_columns = [name(sel) for sel in (*s_selects.values(), *t_selects.values())]
    type_column = exp.to_identifier(SQLMESH_SAMPLE_TYPE_COL)

    source_order = [name(s_selects[col.name]) for col in s_index]
    target_order = [name(t_selects[col.name]) for col in t_index]

    def _sample_bucket(label: str, condition: exp.Expr, order_keys: t.List[str]) -> exp.Select:
        # One filtered, ordered and limited slice of the diff table, tagged with its label
        return (
            exp.select(exp.Literal.string(label).as_(type_column), *data_columns)
            .from_(sample_table)
            .where(condition)
            .order_by(*order_keys)
            .limit(limit)
        )

    source_only = _sample_bucket(
        "source_only",
        exp.and_(exp.column("s_exists").eq(1), exp.column("row_joined").eq(0)),
        source_order,
    )
    target_only = _sample_bucket(
        "target_only",
        exp.and_(exp.column("t_exists").eq(1), exp.column("row_joined").eq(0)),
        target_order,
    )
    common_rows = _sample_bucket(
        "common_rows",
        exp.and_(exp.column("row_joined").eq(1), exp.column("row_full_match").eq(0)),
        [*source_order, *target_order],
    )

    query = (
        exp.Select()
        .with_("source_only", source_only)
        .with_("target_only", target_only)
        .with_("common_rows", common_rows)
        .select(type_column, *data_columns)
        .from_("source_only")
    )
    # Append the remaining buckets; distinct=False keeps every sampled row (UNION ALL)
    for cte_name in ("target_only", "common_rows"):
        query = query.union(
            exp.select(type_column, *data_columns).from_(cte_name),
            distinct=False,
        )

    return self.adapter.fetchdf(query, quote_identifiers=True)

Calculates the differences between two tables, taking into account both schema-level and row-level differences.

TableDiff( adapter: sqlmesh.core.engine_adapter.base.EngineAdapter, source: Union[str, sqlglot.expressions.query.Table], target: Union[str, sqlglot.expressions.query.Table], on: Union[List[str], sqlglot.expressions.core.Expr], skip_columns: Optional[List[str]] = None, where: Union[str, sqlglot.expressions.core.Expr, NoneType] = None, limit: int = 20, source_alias: Optional[str] = None, target_alias: Optional[str] = None, model_name: Optional[str] = None, model_dialect: Optional[str] = None, decimals: int = 3, schema_diff_ignore_case: bool = False)
def __init__(
    self,
    adapter: EngineAdapter,
    source: TableName,
    target: TableName,
    on: t.List[str] | exp.Expr,
    skip_columns: t.List[str] | None = None,
    where: t.Optional[str | exp.Expr] = None,
    limit: int = 20,
    source_alias: t.Optional[str] = None,
    target_alias: t.Optional[str] = None,
    model_name: t.Optional[str] = None,
    model_dialect: t.Optional[str] = None,
    decimals: int = 3,
    schema_diff_ignore_case: bool = False,
):
    """Configure a diff between ``source`` and ``target``.

    No queries are issued here. Schemas, keys and the row diff are computed lazily.

    Args:
        adapter: Engine adapter used for all queries. It must implement ``RowDiffMixin``.
        source: Source table name, parsed with the adapter's dialect.
        target: Target table name, parsed with the adapter's dialect.
        on: Join keys, either a list of column names or a join condition expression.
            It is stored as-is and resolved later.
        skip_columns: Column names excluded from the row diff. They are normalized with
            ``model_dialect`` when given, otherwise with the adapter's dialect.
        where: Optional filter applied to both tables, parsed with the adapter's dialect.
        limit: Maximum number of sample rows fetched per sample bucket.
        source_alias: Optional display alias for the source (e.g. an environment name).
        target_alias: Optional display alias for the target.
        model_name: Optional model name, passed through to the diff results.
        model_dialect: Optional dialect used to normalize identifiers.
        decimals: Precision used when normalizing real-typed values for comparison.
        schema_diff_ignore_case: Passed to ``SchemaDiff`` as ``ignore_case``.

    Raises:
        ValueError: If ``adapter`` is not a ``RowDiffMixin``.
    """
    if not isinstance(adapter, RowDiffMixin):
        raise ValueError(f"Engine {adapter} doesnt support RowDiff")

    dialect = adapter.dialect
    # Identifiers in skip_columns use the model's dialect when one is given,
    # falling back to the adapter's dialect.
    identifier_dialect = model_dialect or dialect

    self.adapter = adapter
    self.source = source
    self.target = target
    self.dialect = dialect
    self.source_table = exp.to_table(source, dialect=dialect)
    self.target_table = exp.to_table(target, dialect=dialect)
    self.where = None if not where else exp.condition(where, dialect=dialect)
    self.limit = limit
    self.model_name = model_name
    self.model_dialect = model_dialect
    self.decimals = decimals
    self.schema_diff_ignore_case = schema_diff_ignore_case

    # Environment aliases; used to label the diff output in certain cases
    self.source_alias = source_alias
    self.target_alias = target_alias

    normalized_skip_columns = set()
    for column in ensure_list(skip_columns):
        identifier = exp.parse_identifier(t.cast(str, column))
        normalized_skip_columns.add(
            normalize_identifiers(identifier, dialect=identifier_dialect).name
        )
    self.skip_columns = normalized_skip_columns

    # `on` is interpreted lazily; the row diff is computed and cached on first request
    self._on = on
    self._row_diff: t.Optional[RowDiff] = None
adapter
source
target
dialect
source_table
target_table
where
limit
model_name
model_dialect
decimals
schema_diff_ignore_case
source_alias
target_alias
skip_columns
source_schema: Dict[str, sqlglot.expressions.datatypes.DataType]
@cached_property
def source_schema(self) -> t.Dict[str, exp.DataType]:
    """Column name -> data type mapping of the source table.

    Fetched once through ``adapter.columns`` and cached on the instance.
    """
    adapter, table = self.adapter, self.source_table
    return adapter.columns(table)
target_schema: Dict[str, sqlglot.expressions.datatypes.DataType]
@cached_property
def target_schema(self) -> t.Dict[str, exp.DataType]:
    """Column name -> data type mapping of the target table.

    Fetched once through ``adapter.columns`` and cached on the instance.
    """
    adapter, table = self.adapter, self.target_table
    return adapter.columns(table)
key_columns: Tuple[List[sqlglot.expressions.core.Column], List[sqlglot.expressions.core.Column], List[str]]
@cached_property
def key_columns(self) -> t.Tuple[t.List[exp.Column], t.List[exp.Column], t.List[str]]:
    """Resolve the join keys as ``(s_index, t_index, index_cols)``.

    ``s_index`` holds the source-side key columns (qualified with ``s``), ``t_index``
    holds the target-side key columns (qualified with ``t``), and ``index_cols`` holds
    the distinct key column names. Identifiers are normalized with ``model_dialect``,
    falling back to the adapter's dialect.

    If ``on`` is a list or tuple of names, each normalized name is used on both sides.
    Otherwise ``on`` is a condition expression. Its columns are split by table
    qualifier (``s``/``t``, case-insensitive). Duplicates are dropped, keeping
    first-seen order. Columns with any other qualifier only contribute to
    ``index_cols``. The normalization call's return value is discarded here, so this
    relies on it normalizing the expression in place.
    """
    dialect = self.model_dialect or self.dialect
    on = self._on

    if isinstance(on, (list, tuple)):
        # Explicit key names: the same normalized identifiers are used on both sides
        keys = [normalize_identifiers(key, dialect=dialect) for key in on]
        return (
            [exp.column(key, "s") for key in keys],
            [exp.column(key, "t") for key in keys],
            [key.name for key in keys],
        )

    # A join condition: pull the key columns out of it and split them by qualifier.
    # Dicts act as insertion-ordered sets (first occurrence wins).
    normalize_identifiers(on, dialect=dialect)
    key_names: t.Dict[str, None] = {}
    sides: t.Dict[str, t.Dict[exp.Column, None]] = {"s": {}, "t": {}}
    for column in on.find_all(exp.Column):
        key_names.setdefault(column.name, None)
        side = sides.get(column.table.lower())
        if side is not None:
            side.setdefault(column, None)

    return list(sides["s"]), list(sides["t"]), list(key_names)
source_key_expression: sqlglot.expressions.core.Expr
@property
def source_key_expression(self) -> exp.Expr:
    """Join-key expression for the source side.

    Built by ``_key_expression`` from the source key columns (the first element of
    ``key_columns``) and the source table's schema.
    """
    source_keys = self.key_columns[0]
    return self._key_expression(source_keys, self.source_schema)
target_key_expression: sqlglot.expressions.core.Expr
@property
def target_key_expression(self) -> exp.Expr:
    """Join-key expression for the target side.

    Built by ``_key_expression`` from the target key columns (the second element of
    ``key_columns``) and the target table's schema.
    """
    target_keys = self.key_columns[1]
    return self._key_expression(target_keys, self.target_schema)
def schema_diff(self) -> SchemaDiff:
def schema_diff(self) -> SchemaDiff:
    """Return the schema comparison between the source and target tables.

    The column/type mappings come from the ``source_schema`` and ``target_schema``
    properties, read in that order. The aliases and model name are passed through,
    and ``schema_diff_ignore_case`` is passed as ``ignore_case``. ``skip_columns``
    is not applied here.
    """
    source_columns = self.source_schema
    target_columns = self.target_schema
    return SchemaDiff(
        source=self.source,
        target=self.target,
        source_schema=source_columns,
        target_schema=target_columns,
        source_alias=self.source_alias,
        target_alias=self.target_alias,
        model_name=self.model_name,
        ignore_case=self.schema_diff_ignore_case,
    )
def row_diff( self, temp_schema: Optional[str] = None, skip_grain_check: bool = False) -> RowDiff:
def row_diff(
    self, temp_schema: t.Optional[str] = None, skip_grain_check: bool = False
) -> RowDiff:
    """Compute (once) and return the row-level diff between the source and target tables.

    The source and target are full-outer-joined on their join-key expressions. Each row
    is flagged (``s_exists``, ``t_exists``, ``row_joined``, ``null_grain``, per-column
    ``<col>_matches`` and ``row_full_match``). The result is materialized through
    ``adapter.temp_table`` as a table named ``diff`` in ``temp_schema``, which defaults
    to ``sqlmesh_temp``. Summary counts, per-column match percentages and a sample of
    differing rows are then read back from it.

    The result is cached on the instance. Later calls return the same ``RowDiff``
    regardless of the arguments passed.

    Args:
        temp_schema: Schema in which the ``diff`` table is created.
        skip_grain_check: When True, skip the ``COUNT(DISTINCT ...)`` over the join keys
            that produces ``distinct_count_s`` / ``distinct_count_t``.

    Returns:
        The cached ``RowDiff``.

    Raises:
        SQLMeshError: On Athena, when exactly one of the source/target tables is Iceberg.
    """
    if self._row_diff is None:
        # Columns listed in skip_columns are excluded from every part of the row diff
        source_schema = {
            c: dtype for c, dtype in self.source_schema.items() if c not in self.skip_columns
        }
        target_schema = {
            c: dtype for c, dtype in self.target_schema.items() if c not in self.skip_columns
        }

        # Data columns are selected as "s__<col>" / "t__<col>" so both sides fit in one
        # row; the join key is carried along the same way.
        s_selects = {c: exp.column(c, "s").as_(f"s__{c}") for c in source_schema}
        t_selects = {c: exp.column(c, "t").as_(f"t__{c}") for c in target_schema}
        s_selects[SQLMESH_JOIN_KEY_COL] = exp.column(SQLMESH_JOIN_KEY_COL, "s").as_(
            f"s__{SQLMESH_JOIN_KEY_COL}"
        )
        t_selects[SQLMESH_JOIN_KEY_COL] = exp.column(SQLMESH_JOIN_KEY_COL, "t").as_(
            f"t__{SQLMESH_JOIN_KEY_COL}"
        )

        # Only columns present on both sides with the same data type are compared
        matched_columns = {
            c: dtype for c, dtype in source_schema.items() if dtype == target_schema.get(c)
        }

        s_index, t_index, index_cols = self.key_columns
        s_index_names = [c.name for c in s_index]
        t_index_names = [c.name for c in t_index]

        # Suffix of the per-column comparison flag columns ("<col>_matches")
        matches_suffix = "_matches"

        def _strip_prefix(value: str, prefix: str) -> str:
            # Remove `prefix` only from the start of `value`. Column names may themselves
            # contain "s__", "t__" or "__", so str.replace / str.split would corrupt them.
            return value[len(prefix) :] if value.startswith(prefix) else value

        def _comparison_column_name(label: t.Any) -> str:
            # Map a "<col>_matches" result label back to "<col>", touching only the suffix
            if not label:
                return ""
            text = str(label)
            return text[: -len(matches_suffix)] if text.endswith(matches_suffix) else text

        def _column_expr(column_name: str, table: str) -> exp.Expr:
            # Real and nested values go through adapter-specific normalization before
            # the equality check; other types are compared as-is.
            column_type = matched_columns[column_name]
            qualified_column = exp.column(column_name, table)

            if column_type.is_type(*exp.DataType.REAL_TYPES):
                return self.adapter._normalize_decimal_value(qualified_column, self.decimals)
            if column_type.is_type(*exp.DataType.NESTED_TYPES):
                return self.adapter._normalize_nested_value(qualified_column)

            return qualified_column

        # Per-column flag: 1 when the values are equal or both NULL, otherwise 0
        comparisons = [
            exp.Case()
            .when(_column_expr(c, "s").eq(_column_expr(c, "t")), exp.Literal.number(1))
            .when(
                exp.column(c, "s").is_(exp.Null()) & exp.column(c, "t").is_(exp.Null()),
                exp.Literal.number(1),
            )
            .when(
                exp.column(c, "s").is_(exp.Null()) | exp.column(c, "t").is_(exp.Null()),
                exp.Literal.number(0),
            )
            .else_(exp.Literal.number(0))
            .as_(f"{c}{matches_suffix}")
            for c in matched_columns
        ]

        source_query = (
            exp.select(
                *(exp.column(c) for c in source_schema),
                self.source_key_expression.as_(SQLMESH_JOIN_KEY_COL),
            )
            .from_(self.source_table.as_("s"))
            .where(self.where)
        )
        target_query = (
            exp.select(
                *(exp.column(c) for c in target_schema),
                self.target_key_expression.as_(SQLMESH_JOIN_KEY_COL),
            )
            .from_(self.target_table.as_("t"))
            .where(self.where)
        )

        # Ensure every column is qualified with the alias in the source and target queries
        for col in find_all_in_scope(source_query, exp.Column):
            col.set("table", exp.to_identifier("s"))
        for col in find_all_in_scope(target_query, exp.Column):
            col.set("table", exp.to_identifier("t"))

        source_table = exp.table_("__source")
        target_table = exp.table_("__target")
        stats_table = exp.table_("__stats")

        stats_query = (
            exp.select(
                *s_selects.values(),
                *t_selects.values(),
                exp.func(
                    "IF", exp.column(SQLMESH_JOIN_KEY_COL, "s").is_(exp.Null()).not_(), 1, 0
                ).as_("s_exists"),
                exp.func(
                    "IF", exp.column(SQLMESH_JOIN_KEY_COL, "t").is_(exp.Null()).not_(), 1, 0
                ).as_("t_exists"),
                exp.func(
                    "IF",
                    exp.column(SQLMESH_JOIN_KEY_COL, "s").eq(
                        exp.column(SQLMESH_JOIN_KEY_COL, "t")
                    ),
                    1,
                    0,
                ).as_("row_joined"),
                exp.func(
                    "IF",
                    exp.or_(
                        *(
                            exp.and_(
                                s_col.is_(exp.Null()),
                                t_col.is_(exp.Null()),
                            )
                            for s_col, t_col in zip(s_index, t_index)
                        ),
                    ),
                    1,
                    0,
                ).as_("null_grain"),
                *comparisons,
            )
            .from_(source_table.as_("s"))
            .join(
                target_table.as_("t"),
                on=exp.column(SQLMESH_JOIN_KEY_COL, "s").eq(
                    exp.column(SQLMESH_JOIN_KEY_COL, "t")
                ),
                join_type="FULL",
            )
        )

        base_query = (
            exp.Select()
            .with_(source_table, source_query)
            .with_(target_table, target_query)
            .with_(stats_table, stats_query)
            .select(
                "*",
                exp.Case()
                .when(
                    exp.and_(
                        *[
                            exp.column(f"{c}{matches_suffix}").eq(exp.Literal.number(1))
                            for c in matched_columns
                        ]
                    ),
                    exp.Literal.number(1),
                )
                .else_(exp.Literal.number(0))
                .as_("row_full_match"),
            )
            .from_(stats_table)
        )

        query = self.adapter.ensure_nulls_for_unmatched_after_join(
            quote_identifiers(base_query.copy(), dialect=self.model_dialect or self.dialect)
        )

        temp_schema = temp_schema or "sqlmesh_temp"

        schema = to_schema(temp_schema, dialect=self.dialect)
        temp_table = exp.table_("diff", db=schema.db, catalog=schema.catalog, quoted=True)

        temp_table_kwargs: t.Dict[str, t.Any] = {}
        if isinstance(self.adapter, AthenaEngineAdapter):
            # Athena has two table formats: Hive (the default) and Iceberg. TableDiff requires that
            # the formats be the same for the source, target, and temp tables.
            source_table_type = self.adapter._query_table_type(self.source_table)
            target_table_type = self.adapter._query_table_type(self.target_table)

            if source_table_type == "iceberg" and target_table_type == "iceberg":
                # Both sides are Iceberg, so the temp table is created as Iceberg too
                temp_table_kwargs["table_format"] = "iceberg"
            elif source_table_type == "iceberg" or target_table_type == "iceberg":
                # Exactly one side is Iceberg: the formats cannot be reconciled
                raise SQLMeshError(
                    f"Source table '{self.source}' format '{source_table_type}' and target table '{self.target}' format '{target_table_type}' "
                    f"do not match for Athena. Diffing between different table formats is not supported."
                )
            # Otherwise neither table is Iceberg and the temp table keeps Athena's default (Hive)

        with self.adapter.temp_table(
            query, name=temp_table, target_columns_to_types=None, **temp_table_kwargs
        ) as table:
            summary_sums = [
                exp.func("SUM", "s_exists").as_("s_count"),
                exp.func("SUM", "t_exists").as_("t_count"),
                exp.func("SUM", "row_joined").as_("join_count"),
                exp.func("SUM", "null_grain").as_("null_grain_count"),
                exp.func("SUM", "row_full_match").as_("full_match_count"),
                *(exp.func("SUM", name(c)).as_(c.alias) for c in comparisons),
            ]

            if not skip_grain_check:
                summary_sums.extend(
                    [
                        parse_one(f"COUNT(DISTINCT(s__{SQLMESH_JOIN_KEY_COL}))").as_(
                            "distinct_count_s"
                        ),
                        parse_one(f"COUNT(DISTINCT(t__{SQLMESH_JOIN_KEY_COL}))").as_(
                            "distinct_count_t"
                        ),
                    ]
                )

            summary_query = exp.select(*summary_sums).from_(table)

            stats_df = self.adapter.fetchdf(summary_query, quote_identifiers=True).fillna(0)
            stats_df["s_only_count"] = stats_df["s_count"] - stats_df["join_count"]
            stats_df["t_only_count"] = stats_df["t_count"] - stats_df["join_count"]
            stats = stats_df.iloc[0].to_dict()

            # Percentage of joined rows in which each compared column matches
            column_stats_query = (
                exp.select(
                    *(
                        exp.func(
                            "ROUND",
                            100
                            * (
                                exp.cast(
                                    exp.func("SUM", name(c)), exp.DataType.build("NUMERIC")
                                )
                                / exp.func("COUNT", name(c))
                            ),
                            9,
                        ).as_(c.alias)
                        for c in comparisons
                    )
                )
                .from_(table)
                .where(exp.column("row_joined").eq(exp.Literal.number(1)))
            )

            column_stats = (
                self.adapter.fetchdf(column_stats_query, quote_identifiers=True)
                .T.rename(
                    columns={0: "pct_match"},
                    index=_comparison_column_name,
                )
                # errors=ignore because all the index_cols might not be present in the DF if the `on` condition was something like "s.id == t.item_id"
                # because these would not be present in the matching_cols (since they have different names) and thus no summary would be generated
                .drop(index=index_cols, errors="ignore")
            )

            sample = self._fetch_sample(
                table, s_selects, s_index, t_selects, t_index, self.limit
            )

            joined_sample_cols = [f"s__{c}" for c in s_index_names]
            comparison_cols = [
                (f"s__{c}", f"t__{c}")
                for c in column_stats[column_stats["pct_match"] < 100].index
            ]

            for cols in comparison_cols:
                joined_sample_cols.extend(cols)

            # Key columns are shown under their bare name; everything else keeps its side
            # prefix. Every entry starts with "s__" or "t__", so only that prefix is split off.
            joined_renamed_cols = {
                c: c.split("__", 1)[1] if c.split("__", 1)[1] in index_cols else c
                for c in joined_sample_cols
            }

            if (
                self.source_alias
                and self.target_alias
                and self.source != self.source_alias
                and self.target != self.target_alias
            ):
                # Swap the "s__"/"t__" prefixes for the upper-cased aliases
                joined_renamed_cols = {
                    c: (
                        f"{self.source_alias.upper()}__{_strip_prefix(n, 's__')}"
                        if n.startswith("s__")
                        else n
                    )
                    for c, n in joined_renamed_cols.items()
                }
                joined_renamed_cols = {
                    c: (
                        f"{self.target_alias.upper()}__{_strip_prefix(n, 't__')}"
                        if n.startswith("t__")
                        else n
                    )
                    for c, n in joined_renamed_cols.items()
                }

            joined_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "common_rows"][
                joined_sample_cols
            ].rename(columns=joined_renamed_cols)

            s_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "source_only"][
                [
                    *[f"s__{c}" for c in s_index_names],
                    *[f"s__{c}" for c in source_schema if c not in s_index_names],
                ]
            ]
            s_sample = s_sample.rename(
                columns={c: _strip_prefix(c, "s__") for c in s_sample.columns}
            )

            t_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "target_only"][
                [
                    *[f"t__{c}" for c in t_index_names],
                    *[f"t__{c}" for c in target_schema if c not in t_index_names],
                ]
            ]
            t_sample = t_sample.rename(
                columns={c: _strip_prefix(c, "t__") for c in t_sample.columns}
            )

            # The full sample keeps the prefixed data columns but drops the internal helpers
            sample.drop(
                columns=[
                    f"s__{SQLMESH_JOIN_KEY_COL}",
                    f"t__{SQLMESH_JOIN_KEY_COL}",
                    SQLMESH_SAMPLE_TYPE_COL,
                ],
                inplace=True,
            )

            self._row_diff = RowDiff(
                source=self.source,
                target=self.target,
                stats=stats,
                column_stats=column_stats,
                sample=sample,
                joined_sample=joined_sample,
                s_sample=s_sample,
                t_sample=t_sample,
                source_alias=self.source_alias,
                target_alias=self.target_alias,
                model_name=self.model_name,
                decimals=self.decimals,
            )

    return self._row_diff
def name(e: sqlglot.expressions.core.Expr) -> str:
def name(e: exp.Expr) -> str:
    """Return the alias of ``e`` rendered as SQL with identifiers quoted (``identify=True``).

    ``e`` must carry an ``alias`` arg, e.g. an expression produced by ``.as_(...)``.
    Otherwise the lookup fails.
    """
    alias_node = e.args["alias"]
    return alias_node.sql(identify=True)