Edit on GitHub

sqlmesh.core.table_diff

  1from __future__ import annotations
  2
  3import math
  4import typing as t
  5
  6import pandas as pd
  7from sqlglot import exp
  8from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
  9from sqlglot.optimizer.qualify_columns import quote_identifiers
 10
 11from sqlmesh.utils.pydantic import PydanticModel
 12
 13if t.TYPE_CHECKING:
 14    from sqlmesh.core._typing import TableName
 15    from sqlmesh.core.engine_adapter import EngineAdapter
 16
 17
 18class SchemaDiff(PydanticModel, frozen=True):
 19    """An object containing the schema difference between a source and target table."""
 20
 21    source: str
 22    target: str
 23    source_schema: t.Dict[str, exp.DataType]
 24    target_schema: t.Dict[str, exp.DataType]
 25    source_alias: t.Optional[str] = None
 26    target_alias: t.Optional[str] = None
 27    model_name: t.Optional[str] = None
 28
 29    @property
 30    def added(self) -> t.List[t.Tuple[str, exp.DataType]]:
 31        """Added columns."""
 32        return [(c, t) for c, t in self.target_schema.items() if c not in self.source_schema]
 33
 34    @property
 35    def removed(self) -> t.List[t.Tuple[str, exp.DataType]]:
 36        """Removed columns."""
 37        return [(c, t) for c, t in self.source_schema.items() if c not in self.target_schema]
 38
 39    @property
 40    def modified(self) -> t.Dict[str, t.Tuple[exp.DataType, exp.DataType]]:
 41        """Columns with modified types."""
 42        modified = {}
 43        for column in self.source_schema.keys() & self.target_schema.keys():
 44            source_type = self.source_schema[column]
 45            target_type = self.target_schema[column]
 46
 47            if source_type != target_type:
 48                modified[column] = (source_type, target_type)
 49        return modified
 50
 51
 52class RowDiff(PydanticModel, frozen=True):
 53    """Summary statistics and a sample dataframe."""
 54
 55    source: str
 56    target: str
 57    stats: t.Dict[str, float]
 58    sample: pd.DataFrame
 59    joined_sample: pd.DataFrame
 60    s_sample: pd.DataFrame
 61    t_sample: pd.DataFrame
 62    column_stats: pd.DataFrame
 63    source_alias: t.Optional[str] = None
 64    target_alias: t.Optional[str] = None
 65    model_name: t.Optional[str] = None
 66
 67    @property
 68    def source_count(self) -> int:
 69        """Count of the source."""
 70        return int(self.stats["s_count"])
 71
 72    @property
 73    def target_count(self) -> int:
 74        """Count of the target."""
 75        return int(self.stats["t_count"])
 76
 77    @property
 78    def count_pct_change(self) -> float:
 79        """The percentage change of the counts."""
 80        if self.source_count == 0:
 81            return math.inf
 82        return ((self.target_count - self.source_count) / self.source_count) * 100
 83
 84    @property
 85    def join_count(self) -> int:
 86        """Count of successfully joined rows."""
 87        return int(self.stats["join_count"])
 88
 89    @property
 90    def s_only_count(self) -> int:
 91        """Count of rows only present in source."""
 92        return int(self.stats["s_only_count"])
 93
 94    @property
 95    def t_only_count(self) -> int:
 96        """Count of rows only present in target."""
 97        return int(self.stats["t_only_count"])
 98
 99
class TableDiff:
    """Calculates differences between tables, taking into account schema and row level differences."""

    def __init__(
        self,
        adapter: EngineAdapter,
        source: TableName,
        target: TableName,
        on: t.List[str] | exp.Condition,
        where: t.Optional[str | exp.Condition] = None,
        limit: int = 20,
        source_alias: t.Optional[str] = None,
        target_alias: t.Optional[str] = None,
        model_name: t.Optional[str] = None,
    ):
        """Configure a diff between two tables.

        Args:
            adapter: Engine adapter used to read table schemas and run the diff queries.
            source: The source table.
            target: The target table.
            on: Either a list of column names to join on (a pair of NULLs also counts as a
                match), or an explicit join condition that refers to the source and target
                through the ``s`` and ``t`` aliases.
            where: Optional filter applied to the joined rows.
            limit: Maximum number of mismatching rows fetched as a sample.
            source_alias: Optional display name for the source, e.g. an environment name.
            target_alias: Optional display name for the target.
            model_name: Optional name of the model being compared.
        """
        self.adapter = adapter
        self.source = source
        self.target = target
        self.dialect = adapter.dialect
        self.where = exp.condition(where, dialect=self.dialect) if where else None
        self.limit = limit
        self.model_name = model_name

        # Support environment aliases for diff output improvement in certain cases
        self.source_alias = source_alias
        self.target_alias = target_alias

        if isinstance(on, (list, tuple)):
            s_table = exp.to_identifier("s", quoted=True)
            t_table = exp.to_identifier("t", quoted=True)

            # Each key matches when the values are equal or when both sides are NULL.
            self.on: exp.Condition = exp.and_(
                *(
                    exp.column(c, s_table).eq(exp.column(c, t_table))
                    | (
                        exp.column(c, s_table).is_(exp.null())
                        & exp.column(c, t_table).is_(exp.null())
                    )
                    for c in on
                )
            )
        else:
            self.on = on

        # The return value is discarded, so this relies on the condition being
        # normalized in place.
        normalize_identifiers(self.on, dialect=self.dialect)

        # Lazily populated caches.
        self._source_schema: t.Optional[t.Dict[str, exp.DataType]] = None
        self._target_schema: t.Optional[t.Dict[str, exp.DataType]] = None
        self._row_diff: t.Optional[RowDiff] = None

    @property
    def source_schema(self) -> t.Dict[str, exp.DataType]:
        """Column name -> type mapping of the source table, fetched once via the adapter."""
        if self._source_schema is None:
            self._source_schema = self.adapter.columns(self.source)
        return self._source_schema

    @property
    def target_schema(self) -> t.Dict[str, exp.DataType]:
        """Column name -> type mapping of the target table, fetched once via the adapter."""
        if self._target_schema is None:
            self._target_schema = self.adapter.columns(self.target)
        return self._target_schema

    def schema_diff(self) -> SchemaDiff:
        """Return the column-level schema difference between the source and target."""
        return SchemaDiff(
            source=self.source,
            target=self.target,
            source_schema=self.source_schema,
            target_schema=self.target_schema,
            source_alias=self.source_alias,
            target_alias=self.target_alias,
            model_name=self.model_name,
        )

    def row_diff(self) -> RowDiff:
        """Compute row-level differences between the source and target tables.

        The two tables are FULL-joined on the join condition into a temporary table.
        Summary counts, per-column match percentages and a sample of mismatching rows are
        then read from it. The result is cached, so later calls return the same ``RowDiff``
        without re-querying.
        """
        if self._row_diff is None:
            s_selects = {c: exp.column(c, "s").as_(f"s__{c}") for c in self.source_schema}
            t_selects = {c: exp.column(c, "t").as_(f"t__{c}") for c in self.target_schema}

            # Collect the join key columns referenced by the join condition, per side.
            index_cols = []
            s_index = []
            t_index = []

            for col in self.on.find_all(exp.Column):
                index_cols.append(col.name)
                if col.table == "s":
                    s_index.append(col)
                elif col.table == "t":
                    t_index.append(col)
            # Deduplicate while preserving first-seen order.
            index_cols = list(dict.fromkeys(index_cols))

            # One 0/1 "<column>_matches" flag per column whose type is identical on both
            # sides; a pair of NULLs counts as a match.
            comparisons = [
                exp.Case()
                .when(exp.column(c, "s").eq(exp.column(c, "t")), exp.Literal.number(1))
                .when(
                    exp.column(c, "s").is_(exp.Null()) & exp.column(c, "t").is_(exp.Null()),
                    exp.Literal.number(1),
                )
                .when(
                    exp.column(c, "s").is_(exp.Null()) | exp.column(c, "t").is_(exp.Null()),
                    exp.Literal.number(0),
                )
                .else_(exp.Literal.number(0))
                .as_(f"{c}_matches")
                for c, dtype in self.source_schema.items()
                if dtype == self.target_schema.get(c)
            ]

            def name(e: exp.Expression) -> str:
                # Quoted alias of an aliased select expression.
                return e.args["alias"].sql(identify=True)

            query = (
                exp.select(
                    *s_selects.values(),
                    *t_selects.values(),
                    exp.func("IF", exp.or_(*(c.not_().is_(exp.Null()) for c in s_index)), 1, 0).as_(
                        "s_exists"
                    ),
                    exp.func("IF", exp.or_(*(c.not_().is_(exp.Null()) for c in t_index)), 1, 0).as_(
                        "t_exists"
                    ),
                    exp.func(
                        "IF",
                        exp.and_(
                            *(
                                exp.and_(
                                    exp.column(c, "s").eq(exp.column(c, "t")),
                                    exp.column(c, "s").not_().is_(exp.Null()),
                                    exp.column(c, "t").not_().is_(exp.Null()),
                                )
                                for c in index_cols
                            ),
                        ),
                        1,
                        0,
                    ).as_("rows_joined"),
                    *comparisons,
                )
                .from_(exp.alias_(self.source, "s"))
                .join(
                    self.target,
                    on=self.on,
                    join_type="FULL",
                    join_alias="t",
                )
                .where(self.where)
            )

            query = quote_identifiers(query, dialect=self.dialect)
            temp_table = exp.table_("diff", db="sqlmesh_temp", quoted=True)

            with self.adapter.temp_table(query, name=temp_table) as table:
                summary_query = exp.select(
                    exp.func("SUM", "s_exists").as_("s_count"),
                    exp.func("SUM", "t_exists").as_("t_count"),
                    exp.func("SUM", "rows_joined").as_("join_count"),
                    *(exp.func("SUM", name(c)).as_(c.alias) for c in comparisons),
                ).from_(table)

                # SUM over an empty table yields NULL; treat it as zero so the stats are
                # valid counts rather than NaN (which int() cannot convert).
                stats_df = self.adapter.fetchdf(summary_query, quote_identifiers=True).fillna(0)
                stats_df["s_only_count"] = stats_df["s_count"] - stats_df["join_count"]
                stats_df["t_only_count"] = stats_df["t_count"] - stats_df["join_count"]
                stats = stats_df.iloc[0].to_dict()

                column_stats_query = (
                    exp.select(
                        *(
                            exp.func(
                                "ROUND",
                                100 * (exp.func("SUM", name(c)) / exp.func("COUNT", name(c))),
                                1,
                            ).as_(c.alias)
                            for c in comparisons
                        )
                    )
                    .from_(table)
                    .where(exp.column("rows_joined").eq(exp.Literal.number(1)))
                )

                matches_suffix = "_matches"

                def strip_matches(label: t.Any) -> str:
                    # Remove only the trailing suffix added above; str.replace would also
                    # mangle column names that contain "_matches" themselves.
                    if not label:
                        return ""
                    text = str(label)
                    return text[: -len(matches_suffix)] if text.endswith(matches_suffix) else text

                column_stats = (
                    self.adapter.fetchdf(column_stats_query, quote_identifiers=True)
                    .T.rename(columns={0: "pct_match"}, index=strip_matches)
                    # Join keys whose type differs between source and target have no
                    # comparison row, so tolerate their absence instead of raising KeyError.
                    .drop(index=index_cols, errors="ignore")
                )

                sample_filter_cols = ["s_exists", "t_exists", "rows_joined"]
                sample_query = (
                    exp.select(
                        *(sample_filter_cols),
                        *(name(c) for c in s_selects.values()),
                        *(name(c) for c in t_selects.values()),
                    )
                    .from_(table)
                    .where(exp.or_(*(exp.column(c.alias).eq(0) for c in comparisons)))
                    .order_by(
                        *(name(s_selects[c.name]) for c in s_index),
                        *(name(t_selects[c.name]) for c in t_index),
                    )
                    .limit(self.limit)
                )
                sample = self.adapter.fetchdf(sample_query, quote_identifiers=True)

                # Joined sample: join keys first, then source/target pairs for every
                # column that does not match 100%.
                joined_sample_cols = [f"s__{c}" for c in index_cols]
                for c in column_stats[column_stats["pct_match"] < 100].index:
                    joined_sample_cols.extend((f"s__{c}", f"t__{c}"))

                def strip_side(col: str) -> str:
                    # "s__<name>" / "t__<name>" -> "<name>"; split once so names that
                    # contain "__" survive intact.
                    return col.split("__", 1)[1]

                joined_renamed_cols = {
                    c: strip_side(c) if strip_side(c) in index_cols else c
                    for c in joined_sample_cols
                }
                if self.source != self.source_alias and self.target != self.target_alias:
                    s_prefix = f"{self.source_alias.upper() if self.source_alias else ''}__"
                    t_prefix = f"{self.target_alias.upper() if self.target_alias else ''}__"

                    def with_alias(col: str) -> str:
                        # Swap only the leading side marker; str.replace would also rewrite
                        # "s__"/"t__" occurring inside the column name itself.
                        if col.startswith("s__"):
                            return s_prefix + col[len("s__") :]
                        if col.startswith("t__"):
                            return t_prefix + col[len("t__") :]
                        return col

                    joined_renamed_cols = {
                        c: with_alias(n) for c, n in joined_renamed_cols.items()
                    }
                joined_sample = sample[sample["rows_joined"] == 1][joined_sample_cols].rename(
                    columns=joined_renamed_cols
                )

                s_sample = sample[(sample["s_exists"] == 1) & (sample["rows_joined"] == 0)][
                    [
                        *[f"s__{c}" for c in index_cols],
                        *[f"s__{c}" for c in self.source_schema if c not in index_cols],
                    ]
                ]
                s_sample = s_sample.rename(columns=lambda col: col[len("s__") :])

                t_sample = sample[(sample["t_exists"] == 1) & (sample["rows_joined"] == 0)][
                    [
                        *[f"t__{c}" for c in index_cols],
                        *[f"t__{c}" for c in self.target_schema if c not in index_cols],
                    ]
                ]
                t_sample = t_sample.rename(columns=lambda col: col[len("t__") :])

                sample = sample.drop(columns=sample_filter_cols)

                self._row_diff = RowDiff(
                    source=self.source,
                    target=self.target,
                    stats=stats,
                    column_stats=column_stats,
                    sample=sample,
                    joined_sample=joined_sample,
                    s_sample=s_sample,
                    t_sample=t_sample,
                    source_alias=self.source_alias,
                    target_alias=self.target_alias,
                    model_name=self.model_name,
                )
        return self._row_diff
class SchemaDiff(sqlmesh.utils.pydantic.PydanticModel):
class SchemaDiff(PydanticModel, frozen=True):
    """An object containing the schema difference between a source and target table."""

    source: str
    target: str
    source_schema: t.Dict[str, exp.DataType]
    target_schema: t.Dict[str, exp.DataType]
    source_alias: t.Optional[str] = None
    target_alias: t.Optional[str] = None
    model_name: t.Optional[str] = None

    @property
    def added(self) -> t.List[t.Tuple[str, exp.DataType]]:
        """Added columns: ``(name, type)`` pairs present in the target but not the source.

        Returned in the target schema's column order.
        """
        return [
            (column, dtype)
            for column, dtype in self.target_schema.items()
            if column not in self.source_schema
        ]

    @property
    def removed(self) -> t.List[t.Tuple[str, exp.DataType]]:
        """Removed columns: ``(name, type)`` pairs present in the source but not the target.

        Returned in the source schema's column order.
        """
        return [
            (column, dtype)
            for column, dtype in self.source_schema.items()
            if column not in self.target_schema
        ]

    @property
    def modified(self) -> t.Dict[str, t.Tuple[exp.DataType, exp.DataType]]:
        """Columns with modified types.

        Maps each column present in both schemas whose type differs to a
        ``(source_type, target_type)`` pair, in the source schema's column order.
        """
        # Walk the source dict instead of a key-set intersection: iterating a set of
        # strings follows hash order, which is randomized per process, so the result
        # order would not be reproducible between runs.
        return {
            column: (source_type, self.target_schema[column])
            for column, source_type in self.source_schema.items()
            if column in self.target_schema and source_type != self.target_schema[column]
        }

An object containing the schema difference between a source and target table.

added: List[Tuple[str, sqlglot.expressions.DataType]]

Added columns.

removed: List[Tuple[str, sqlglot.expressions.DataType]]

Removed columns.

modified: Dict[str, Tuple[sqlglot.expressions.DataType, sqlglot.expressions.DataType]]

Columns with modified types.

Inherited Members
pydantic.main.BaseModel
BaseModel
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
model_post_init
class RowDiff(sqlmesh.utils.pydantic.PydanticModel):
class RowDiff(PydanticModel, frozen=True):
    """Row-level comparison result: summary statistics plus sample dataframes."""

    source: str
    target: str
    stats: t.Dict[str, float]
    sample: pd.DataFrame
    joined_sample: pd.DataFrame
    s_sample: pd.DataFrame
    t_sample: pd.DataFrame
    column_stats: pd.DataFrame
    source_alias: t.Optional[str] = None
    target_alias: t.Optional[str] = None
    model_name: t.Optional[str] = None

    @property
    def source_count(self) -> int:
        """Row count of the source, taken from ``stats["s_count"]``."""
        total = self.stats["s_count"]
        return int(total)

    @property
    def target_count(self) -> int:
        """Row count of the target, taken from ``stats["t_count"]``."""
        total = self.stats["t_count"]
        return int(total)

    @property
    def count_pct_change(self) -> float:
        """Percentage change from the source row count to the target row count.

        Returns ``math.inf`` when the source count is zero.
        """
        base = self.source_count
        if not base:
            return math.inf
        delta = self.target_count - base
        return delta / base * 100

    @property
    def join_count(self) -> int:
        """Number of rows matched on the join keys, from ``stats["join_count"]``."""
        total = self.stats["join_count"]
        return int(total)

    @property
    def s_only_count(self) -> int:
        """Number of rows that exist only in the source, from ``stats["s_only_count"]``."""
        total = self.stats["s_only_count"]
        return int(total)

    @property
    def t_only_count(self) -> int:
        """Number of rows that exist only in the target, from ``stats["t_only_count"]``."""
        total = self.stats["t_only_count"]
        return int(total)

Summary statistics and a sample dataframe.

source_count: int

Count of the source.

target_count: int

Count of the target.

count_pct_change: float

The percentage change of the counts.

join_count: int

Count of successfully joined rows.

s_only_count: int

Count of rows only present in source.

t_only_count: int

Count of rows only present in target.

Inherited Members
pydantic.main.BaseModel
BaseModel
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
model_post_init
class TableDiff:
class TableDiff:
    """Calculates differences between tables, taking into account schema and row level differences."""

    def __init__(
        self,
        adapter: EngineAdapter,
        source: TableName,
        target: TableName,
        on: t.List[str] | exp.Condition,
        where: t.Optional[str | exp.Condition] = None,
        limit: int = 20,
        source_alias: t.Optional[str] = None,
        target_alias: t.Optional[str] = None,
        model_name: t.Optional[str] = None,
    ):
        """Configure a diff between two tables.

        Args:
            adapter: Engine adapter used to read table schemas and run the diff queries.
            source: The source table.
            target: The target table.
            on: Either a list of column names to join on (a pair of NULLs also counts as a
                match), or an explicit join condition that refers to the source and target
                through the ``s`` and ``t`` aliases.
            where: Optional filter applied to the joined rows.
            limit: Maximum number of mismatching rows fetched as a sample.
            source_alias: Optional display name for the source, e.g. an environment name.
            target_alias: Optional display name for the target.
            model_name: Optional name of the model being compared.
        """
        self.adapter = adapter
        self.source = source
        self.target = target
        self.dialect = adapter.dialect
        self.where = exp.condition(where, dialect=self.dialect) if where else None
        self.limit = limit
        self.model_name = model_name

        # Support environment aliases for diff output improvement in certain cases
        self.source_alias = source_alias
        self.target_alias = target_alias

        if isinstance(on, (list, tuple)):
            s_table = exp.to_identifier("s", quoted=True)
            t_table = exp.to_identifier("t", quoted=True)

            # Each key matches when the values are equal or when both sides are NULL.
            self.on: exp.Condition = exp.and_(
                *(
                    exp.column(c, s_table).eq(exp.column(c, t_table))
                    | (
                        exp.column(c, s_table).is_(exp.null())
                        & exp.column(c, t_table).is_(exp.null())
                    )
                    for c in on
                )
            )
        else:
            self.on = on

        # The return value is discarded, so this relies on the condition being
        # normalized in place.
        normalize_identifiers(self.on, dialect=self.dialect)

        # Lazily populated caches.
        self._source_schema: t.Optional[t.Dict[str, exp.DataType]] = None
        self._target_schema: t.Optional[t.Dict[str, exp.DataType]] = None
        self._row_diff: t.Optional[RowDiff] = None

    @property
    def source_schema(self) -> t.Dict[str, exp.DataType]:
        """Column name -> type mapping of the source table, fetched once via the adapter."""
        if self._source_schema is None:
            self._source_schema = self.adapter.columns(self.source)
        return self._source_schema

    @property
    def target_schema(self) -> t.Dict[str, exp.DataType]:
        """Column name -> type mapping of the target table, fetched once via the adapter."""
        if self._target_schema is None:
            self._target_schema = self.adapter.columns(self.target)
        return self._target_schema

    def schema_diff(self) -> SchemaDiff:
        """Return the column-level schema difference between the source and target."""
        return SchemaDiff(
            source=self.source,
            target=self.target,
            source_schema=self.source_schema,
            target_schema=self.target_schema,
            source_alias=self.source_alias,
            target_alias=self.target_alias,
            model_name=self.model_name,
        )

    def row_diff(self) -> RowDiff:
        """Compute row-level differences between the source and target tables.

        The two tables are FULL-joined on the join condition into a temporary table.
        Summary counts, per-column match percentages and a sample of mismatching rows are
        then read from it. The result is cached, so later calls return the same ``RowDiff``
        without re-querying.
        """
        if self._row_diff is None:
            s_selects = {c: exp.column(c, "s").as_(f"s__{c}") for c in self.source_schema}
            t_selects = {c: exp.column(c, "t").as_(f"t__{c}") for c in self.target_schema}

            # Collect the join key columns referenced by the join condition, per side.
            index_cols = []
            s_index = []
            t_index = []

            for col in self.on.find_all(exp.Column):
                index_cols.append(col.name)
                if col.table == "s":
                    s_index.append(col)
                elif col.table == "t":
                    t_index.append(col)
            # Deduplicate while preserving first-seen order.
            index_cols = list(dict.fromkeys(index_cols))

            # One 0/1 "<column>_matches" flag per column whose type is identical on both
            # sides; a pair of NULLs counts as a match.
            comparisons = [
                exp.Case()
                .when(exp.column(c, "s").eq(exp.column(c, "t")), exp.Literal.number(1))
                .when(
                    exp.column(c, "s").is_(exp.Null()) & exp.column(c, "t").is_(exp.Null()),
                    exp.Literal.number(1),
                )
                .when(
                    exp.column(c, "s").is_(exp.Null()) | exp.column(c, "t").is_(exp.Null()),
                    exp.Literal.number(0),
                )
                .else_(exp.Literal.number(0))
                .as_(f"{c}_matches")
                for c, dtype in self.source_schema.items()
                if dtype == self.target_schema.get(c)
            ]

            def name(e: exp.Expression) -> str:
                # Quoted alias of an aliased select expression.
                return e.args["alias"].sql(identify=True)

            query = (
                exp.select(
                    *s_selects.values(),
                    *t_selects.values(),
                    exp.func("IF", exp.or_(*(c.not_().is_(exp.Null()) for c in s_index)), 1, 0).as_(
                        "s_exists"
                    ),
                    exp.func("IF", exp.or_(*(c.not_().is_(exp.Null()) for c in t_index)), 1, 0).as_(
                        "t_exists"
                    ),
                    exp.func(
                        "IF",
                        exp.and_(
                            *(
                                exp.and_(
                                    exp.column(c, "s").eq(exp.column(c, "t")),
                                    exp.column(c, "s").not_().is_(exp.Null()),
                                    exp.column(c, "t").not_().is_(exp.Null()),
                                )
                                for c in index_cols
                            ),
                        ),
                        1,
                        0,
                    ).as_("rows_joined"),
                    *comparisons,
                )
                .from_(exp.alias_(self.source, "s"))
                .join(
                    self.target,
                    on=self.on,
                    join_type="FULL",
                    join_alias="t",
                )
                .where(self.where)
            )

            query = quote_identifiers(query, dialect=self.dialect)
            temp_table = exp.table_("diff", db="sqlmesh_temp", quoted=True)

            with self.adapter.temp_table(query, name=temp_table) as table:
                summary_query = exp.select(
                    exp.func("SUM", "s_exists").as_("s_count"),
                    exp.func("SUM", "t_exists").as_("t_count"),
                    exp.func("SUM", "rows_joined").as_("join_count"),
                    *(exp.func("SUM", name(c)).as_(c.alias) for c in comparisons),
                ).from_(table)

                # SUM over an empty table yields NULL; treat it as zero so the stats are
                # valid counts rather than NaN (which int() cannot convert).
                stats_df = self.adapter.fetchdf(summary_query, quote_identifiers=True).fillna(0)
                stats_df["s_only_count"] = stats_df["s_count"] - stats_df["join_count"]
                stats_df["t_only_count"] = stats_df["t_count"] - stats_df["join_count"]
                stats = stats_df.iloc[0].to_dict()

                column_stats_query = (
                    exp.select(
                        *(
                            exp.func(
                                "ROUND",
                                100 * (exp.func("SUM", name(c)) / exp.func("COUNT", name(c))),
                                1,
                            ).as_(c.alias)
                            for c in comparisons
                        )
                    )
                    .from_(table)
                    .where(exp.column("rows_joined").eq(exp.Literal.number(1)))
                )

                matches_suffix = "_matches"

                def strip_matches(label: t.Any) -> str:
                    # Remove only the trailing suffix added above; str.replace would also
                    # mangle column names that contain "_matches" themselves.
                    if not label:
                        return ""
                    text = str(label)
                    return text[: -len(matches_suffix)] if text.endswith(matches_suffix) else text

                column_stats = (
                    self.adapter.fetchdf(column_stats_query, quote_identifiers=True)
                    .T.rename(columns={0: "pct_match"}, index=strip_matches)
                    # Join keys whose type differs between source and target have no
                    # comparison row, so tolerate their absence instead of raising KeyError.
                    .drop(index=index_cols, errors="ignore")
                )

                sample_filter_cols = ["s_exists", "t_exists", "rows_joined"]
                sample_query = (
                    exp.select(
                        *(sample_filter_cols),
                        *(name(c) for c in s_selects.values()),
                        *(name(c) for c in t_selects.values()),
                    )
                    .from_(table)
                    .where(exp.or_(*(exp.column(c.alias).eq(0) for c in comparisons)))
                    .order_by(
                        *(name(s_selects[c.name]) for c in s_index),
                        *(name(t_selects[c.name]) for c in t_index),
                    )
                    .limit(self.limit)
                )
                sample = self.adapter.fetchdf(sample_query, quote_identifiers=True)

                # Joined sample: join keys first, then source/target pairs for every
                # column that does not match 100%.
                joined_sample_cols = [f"s__{c}" for c in index_cols]
                for c in column_stats[column_stats["pct_match"] < 100].index:
                    joined_sample_cols.extend((f"s__{c}", f"t__{c}"))

                def strip_side(col: str) -> str:
                    # "s__<name>" / "t__<name>" -> "<name>"; split once so names that
                    # contain "__" survive intact.
                    return col.split("__", 1)[1]

                joined_renamed_cols = {
                    c: strip_side(c) if strip_side(c) in index_cols else c
                    for c in joined_sample_cols
                }
                if self.source != self.source_alias and self.target != self.target_alias:
                    s_prefix = f"{self.source_alias.upper() if self.source_alias else ''}__"
                    t_prefix = f"{self.target_alias.upper() if self.target_alias else ''}__"

                    def with_alias(col: str) -> str:
                        # Swap only the leading side marker; str.replace would also rewrite
                        # "s__"/"t__" occurring inside the column name itself.
                        if col.startswith("s__"):
                            return s_prefix + col[len("s__") :]
                        if col.startswith("t__"):
                            return t_prefix + col[len("t__") :]
                        return col

                    joined_renamed_cols = {
                        c: with_alias(n) for c, n in joined_renamed_cols.items()
                    }
                joined_sample = sample[sample["rows_joined"] == 1][joined_sample_cols].rename(
                    columns=joined_renamed_cols
                )

                s_sample = sample[(sample["s_exists"] == 1) & (sample["rows_joined"] == 0)][
                    [
                        *[f"s__{c}" for c in index_cols],
                        *[f"s__{c}" for c in self.source_schema if c not in index_cols],
                    ]
                ]
                s_sample = s_sample.rename(columns=lambda col: col[len("s__") :])

                t_sample = sample[(sample["t_exists"] == 1) & (sample["rows_joined"] == 0)][
                    [
                        *[f"t__{c}" for c in index_cols],
                        *[f"t__{c}" for c in self.target_schema if c not in index_cols],
                    ]
                ]
                t_sample = t_sample.rename(columns=lambda col: col[len("t__") :])

                sample = sample.drop(columns=sample_filter_cols)

                self._row_diff = RowDiff(
                    source=self.source,
                    target=self.target,
                    stats=stats,
                    column_stats=column_stats,
                    sample=sample,
                    joined_sample=joined_sample,
                    s_sample=s_sample,
                    t_sample=t_sample,
                    source_alias=self.source_alias,
                    target_alias=self.target_alias,
                    model_name=self.model_name,
                )
        return self._row_diff

Calculates differences between tables, taking into account schema and row level differences.

TableDiff(adapter: EngineAdapter, source: TableName, target: TableName, on: t.List[str] | exp.Condition, where: t.Optional[str | exp.Condition] = None, limit: int = 20, source_alias: Optional[str] = None, target_alias: Optional[str] = None, model_name: Optional[str] = None)
def __init__(
    self,
    adapter: EngineAdapter,
    source: TableName,
    target: TableName,
    on: t.List[str] | exp.Condition,
    where: t.Optional[str | exp.Condition] = None,
    limit: int = 20,
    source_alias: t.Optional[str] = None,
    target_alias: t.Optional[str] = None,
    model_name: t.Optional[str] = None,
):
    """Store the diff configuration and build the join condition.

    Args:
        adapter: Engine adapter used to read schemas and run queries; also supplies the dialect.
        source: The source table.
        target: The target table.
        on: Column names to join on, or a ready-made join condition. For a list, each
            column matches when the ``s`` and ``t`` values are equal or both NULL.
        where: Optional filter, parsed with the adapter's dialect when given.
        limit: Maximum number of sample rows.
        source_alias: Optional display name for the source.
        target_alias: Optional display name for the target.
        model_name: Optional model name carried into the results.
    """
    self.adapter = adapter
    self.source = source
    self.target = target
    self.dialect = adapter.dialect
    self.where = exp.condition(where, dialect=self.dialect) if where else None
    self.limit = limit
    self.model_name = model_name

    # Display aliases (e.g. environment names) used to label diff output.
    self.source_alias = source_alias
    self.target_alias = target_alias

    if not isinstance(on, (list, tuple)):
        self.on = on
    else:
        s_ref = exp.to_identifier("s", quoted=True)
        t_ref = exp.to_identifier("t", quoted=True)

        def key_match(column: str) -> exp.Condition:
            # s.col = t.col OR (s.col IS NULL AND t.col IS NULL)
            values_equal = exp.column(column, s_ref).eq(exp.column(column, t_ref))
            both_null = exp.column(column, s_ref).is_(exp.null()) & exp.column(
                column, t_ref
            ).is_(exp.null())
            return values_equal | both_null

        self.on: exp.Condition = exp.and_(*(key_match(column) for column in on))

    # The return value is discarded, so this relies on in-place normalization.
    normalize_identifiers(self.on, dialect=self.dialect)

    # Lazily populated caches.
    self._source_schema: t.Optional[t.Dict[str, exp.DataType]] = None
    self._target_schema: t.Optional[t.Dict[str, exp.DataType]] = None
    self._row_diff: t.Optional[RowDiff] = None
def schema_diff(self) -> sqlmesh.core.table_diff.SchemaDiff:
def schema_diff(self) -> SchemaDiff:
    """Return a SchemaDiff comparing the source and target column schemas.

    The table names, aliases and model name are forwarded as-is. The source schema is
    read before the target schema.
    """
    schema_fields = {
        "source": self.source,
        "target": self.target,
        "source_schema": self.source_schema,
        "target_schema": self.target_schema,
        "source_alias": self.source_alias,
        "target_alias": self.target_alias,
        "model_name": self.model_name,
    }
    return SchemaDiff(**schema_fields)
def row_diff(self) -> sqlmesh.core.table_diff.RowDiff:
def row_diff(self) -> RowDiff:
    """Compute, cache and return the row-level differences between source and target.

    The source (aliased ``s``) is FULL-joined to the target (aliased ``t``) on ``self.on``,
    filtered by ``self.where``, and written through ``adapter.temp_table`` to
    ``sqlmesh_temp.diff``. Three queries then read that table:

    1. summary counts (rows per side, joined rows, matches per comparable column);
    2. the percentage of matching values per comparable column among joined rows;
    3. up to ``self.limit`` rows where at least one comparable column's match flag is 0.

    A column is "comparable" when it exists in both tables with the same data type. Every
    column referenced in the join condition is treated as a key (index) column.

    Returns:
        The RowDiff. It is computed on the first call and cached in ``self._row_diff``.
    """
    if self._row_diff is None:
        s_selects = {c: exp.column(c, "s").as_(f"s__{c}") for c in self.source_schema}
        t_selects = {c: exp.column(c, "t").as_(f"t__{c}") for c in self.target_schema}

        # Key columns come from the join condition. Those qualified with `s` / `t` are
        # also used to detect whether a row exists on that side.
        index_cols: t.List[str] = []
        s_index: t.List[exp.Column] = []
        t_index: t.List[exp.Column] = []
        for col in self.on.find_all(exp.Column):
            index_cols.append(col.name)
            if col.table == "s":
                s_index.append(col)
            elif col.table == "t":
                t_index.append(col)
        # De-duplicate while preserving first-seen order.
        index_cols = list(dict.fromkeys(index_cols))

        # A 0/1 "<col>_matches" flag per comparable column: two NULLs count as a match,
        # a NULL on only one side counts as a mismatch.
        comparisons = [
            exp.Case()
            .when(exp.column(c, "s").eq(exp.column(c, "t")), exp.Literal.number(1))
            .when(
                exp.column(c, "s").is_(exp.Null()) & exp.column(c, "t").is_(exp.Null()),
                exp.Literal.number(1),
            )
            .when(
                exp.column(c, "s").is_(exp.Null()) | exp.column(c, "t").is_(exp.Null()),
                exp.Literal.number(0),
            )
            .else_(exp.Literal.number(0))
            .as_(f"{c}_matches")
            for c, dtype in self.source_schema.items()
            if dtype == self.target_schema.get(c)
        ]

        def name(e: exp.Expression) -> str:
            # The alias of an aliased expression, rendered as a quoted identifier.
            return e.args["alias"].sql(identify=True)

        def strip_prefix(column: str, prefix: str) -> str:
            # Remove only a leading prefix; str.replace would also rewrite occurrences
            # inside the name (e.g. "s__abs__x" -> "abx").
            return column[len(prefix):] if column.startswith(prefix) else column

        def column_label(label: t.Any) -> str:
            # Map a "<col>_matches" result column back to "<col>" (suffix only).
            if not label:
                return ""
            text = str(label)
            suffix = "_matches"
            return text[: -len(suffix)] if text.endswith(suffix) else text

        query = (
            exp.select(
                *s_selects.values(),
                *t_selects.values(),
                exp.func("IF", exp.or_(*(c.not_().is_(exp.Null()) for c in s_index)), 1, 0).as_(
                    "s_exists"
                ),
                exp.func("IF", exp.or_(*(c.not_().is_(exp.Null()) for c in t_index)), 1, 0).as_(
                    "t_exists"
                ),
                exp.func(
                    "IF",
                    exp.and_(
                        *(
                            exp.and_(
                                exp.column(c, "s").eq(exp.column(c, "t")),
                                exp.column(c, "s").not_().is_(exp.Null()),
                                exp.column(c, "t").not_().is_(exp.Null()),
                            )
                            for c in index_cols
                        ),
                    ),
                    1,
                    0,
                ).as_("rows_joined"),
                *comparisons,
            )
            .from_(exp.alias_(self.source, "s"))
            .join(
                self.target,
                on=self.on,
                join_type="FULL",
                join_alias="t",
            )
            .where(self.where)
        )

        query = quote_identifiers(query, dialect=self.dialect)
        temp_table = exp.table_("diff", db="sqlmesh_temp", quoted=True)

        with self.adapter.temp_table(query, name=temp_table) as table:
            summary_query = exp.select(
                exp.func("SUM", "s_exists").as_("s_count"),
                exp.func("SUM", "t_exists").as_("t_count"),
                exp.func("SUM", "rows_joined").as_("join_count"),
                *(exp.func("SUM", name(c)).as_(c.alias) for c in comparisons),
            ).from_(table)

            stats_df = self.adapter.fetchdf(summary_query, quote_identifiers=True)
            stats_df["s_only_count"] = stats_df["s_count"] - stats_df["join_count"]
            stats_df["t_only_count"] = stats_df["t_count"] - stats_df["join_count"]
            stats = stats_df.iloc[0].to_dict()

            column_stats_query = (
                exp.select(
                    *(
                        exp.func(
                            "ROUND",
                            100 * (exp.func("SUM", name(c)) / exp.func("COUNT", name(c))),
                            1,
                        ).as_(c.alias)
                        for c in comparisons
                    )
                )
                .from_(table)
                .where(exp.column("rows_joined").eq(exp.Literal.number(1)))
            )
            column_stats = (
                self.adapter.fetchdf(column_stats_query, quote_identifiers=True)
                .T.rename(columns={0: "pct_match"}, index=column_label)
                # Key columns are not reported. errors="ignore": a key whose type differs
                # between the tables has no "_matches" flag and therefore no row here.
                .drop(index=index_cols, errors="ignore")
            )

            sample_filter_cols = ["s_exists", "t_exists", "rows_joined"]
            sample_query = (
                exp.select(
                    *(sample_filter_cols),
                    *(name(c) for c in s_selects.values()),
                    *(name(c) for c in t_selects.values()),
                )
                .from_(table)
                .where(exp.or_(*(exp.column(c.alias).eq(0) for c in comparisons)))
                .order_by(
                    *(name(s_selects[c.name]) for c in s_index),
                    *(name(t_selects[c.name]) for c in t_index),
                )
                .limit(self.limit)
            )
            sample = self.adapter.fetchdf(sample_query, quote_identifiers=True)

            # Joined rows: the key columns, then a source/target pair for every column
            # whose match percentage is below 100.
            joined_sample_cols = [f"s__{c}" for c in index_cols]
            for c in column_stats[column_stats["pct_match"] < 100].index:
                joined_sample_cols.extend((f"s__{c}", f"t__{c}"))

            # Environment aliases replace the s__/t__ prefixes only when both sides are
            # aliased to something other than the table itself. A missing alias keeps the
            # default prefix so source and target columns never collapse onto one name.
            use_aliases = self.source != self.source_alias and self.target != self.target_alias
            s_prefix = (
                f"{self.source_alias.upper()}__" if use_aliases and self.source_alias else "s__"
            )
            t_prefix = (
                f"{self.target_alias.upper()}__" if use_aliases and self.target_alias else "t__"
            )

            joined_renamed_cols = {}
            for c in joined_sample_cols:
                side, display_prefix = ("s__", s_prefix) if c.startswith("s__") else ("t__", t_prefix)
                column = strip_prefix(c, side)
                joined_renamed_cols[c] = column if column in index_cols else display_prefix + column

            joined_sample = sample[sample["rows_joined"] == 1][joined_sample_cols].rename(
                columns=joined_renamed_cols
            )

            # Rows present on only one side, shown with their own column names.
            s_sample = sample[(sample["s_exists"] == 1) & (sample["rows_joined"] == 0)][
                [
                    *[f"s__{c}" for c in index_cols],
                    *[f"s__{c}" for c in self.source_schema if c not in index_cols],
                ]
            ].rename(columns=lambda c: strip_prefix(c, "s__"))

            t_sample = sample[(sample["t_exists"] == 1) & (sample["rows_joined"] == 0)][
                [
                    *[f"t__{c}" for c in index_cols],
                    *[f"t__{c}" for c in self.target_schema if c not in index_cols],
                ]
            ].rename(columns=lambda c: strip_prefix(c, "t__"))

            sample.drop(columns=sample_filter_cols, inplace=True)

            self._row_diff = RowDiff(
                source=self.source,
                target=self.target,
                stats=stats,
                column_stats=column_stats,
                sample=sample,
                joined_sample=joined_sample,
                s_sample=s_sample,
                t_sample=t_sample,
                source_alias=self.source_alias,
                target_alias=self.target_alias,
                model_name=self.model_name,
            )
    return self._row_diff