sqlmesh.core.table_diff
from __future__ import annotations

import math
import typing as t
from functools import cached_property

from sqlmesh.core.dialect import to_schema
from sqlmesh.core.engine_adapter.mixins import RowDiffMixin
from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter
from sqlglot import exp, parse_one
from sqlglot.helper import ensure_list
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import quote_identifiers
from sqlglot.optimizer.scope import find_all_in_scope

from sqlmesh.utils.pydantic import PydanticModel
from sqlmesh.utils.errors import SQLMeshError


if t.TYPE_CHECKING:
    import pandas as pd

    from sqlmesh.core._typing import TableName
    from sqlmesh.core.engine_adapter import EngineAdapter

SQLMESH_JOIN_KEY_COL = "__sqlmesh_join_key"
SQLMESH_SAMPLE_TYPE_COL = "__sqlmesh_sample_type"


def _strip_prefix(value: str, prefix: str) -> str:
    """Return ``value`` with one leading ``prefix`` removed, if present.

    Unlike ``str.replace``, later occurrences of ``prefix`` are preserved, so a column
    named ``is__flag`` survives the round trip through its ``s__is__flag`` alias.
    """
    return value[len(prefix) :] if value.startswith(prefix) else value


def _strip_suffix(value: str, suffix: str) -> str:
    """Return ``value`` with one trailing ``suffix`` removed, if present."""
    return value[: -len(suffix)] if suffix and value.endswith(suffix) else value


class SchemaDiff(PydanticModel, frozen=True):
    """An object containing the schema difference between a source and target table."""

    source: str
    target: str
    source_schema: t.Dict[str, exp.DataType]
    target_schema: t.Dict[str, exp.DataType]
    source_alias: t.Optional[str] = None
    target_alias: t.Optional[str] = None
    model_name: t.Optional[str] = None
    # When True, column names are compared after lower-casing them.
    ignore_case: bool = False

    @property
    def _comparable_source_schema(self) -> t.Dict[str, exp.DataType]:
        return (
            self._lowercase_schema_names(self.source_schema)
            if self.ignore_case
            else self.source_schema
        )

    @property
    def _comparable_target_schema(self) -> t.Dict[str, exp.DataType]:
        return (
            self._lowercase_schema_names(self.target_schema)
            if self.ignore_case
            else self.target_schema
        )

    def _lowercase_schema_names(
        self, schema: t.Dict[str, exp.DataType]
    ) -> t.Dict[str, exp.DataType]:
        return {c.lower(): data_type for c, data_type in schema.items()}

    def _original_column_name(
        self, maybe_lowercased_column_name: str, schema: t.Dict[str, exp.DataType]
    ) -> str:
        """Map a (possibly lower-cased) column name back to its spelling in ``schema``."""
        if not self.ignore_case:
            return maybe_lowercased_column_name

        return next(c for c in schema if c.lower() == maybe_lowercased_column_name)

    @property
    def added(self) -> t.List[t.Tuple[str, exp.DataType]]:
        """Added columns."""
        return [
            (self._original_column_name(c, self.target_schema), data_type)
            for c, data_type in self._comparable_target_schema.items()
            if c not in self._comparable_source_schema
        ]

    @property
    def removed(self) -> t.List[t.Tuple[str, exp.DataType]]:
        """Removed columns."""
        return [
            (self._original_column_name(c, self.source_schema), data_type)
            for c, data_type in self._comparable_source_schema.items()
            if c not in self._comparable_target_schema
        ]

    @property
    def modified(self) -> t.Dict[str, t.Tuple[exp.DataType, exp.DataType]]:
        """Columns with modified types, mapped to ``(source_type, target_type)``."""
        modified = {}
        for column in self._comparable_source_schema.keys() & self._comparable_target_schema.keys():
            source_type = self._comparable_source_schema[column]
            target_type = self._comparable_target_schema[column]

            if source_type != target_type:
                modified[column] = (source_type, target_type)

        if self.ignore_case:
            modified = {
                self._original_column_name(c, self.source_schema): dt for c, dt in modified.items()
            }

        return modified

    @property
    def has_changes(self) -> bool:
        """Does the schema contain any changes at all between source and target?"""
        return bool(self.added or self.removed or self.modified)


class RowDiff(PydanticModel, frozen=True):
    """Summary statistics and a sample dataframe."""

    source: str
    target: str
    stats: t.Dict[str, float]
    sample: pd.DataFrame
    joined_sample: pd.DataFrame
    s_sample: pd.DataFrame
    t_sample: pd.DataFrame
    column_stats: pd.DataFrame
    source_alias: t.Optional[str] = None
    target_alias: t.Optional[str] = None
    model_name: t.Optional[str] = None
    decimals: int = 3

    _types_resolved: t.ClassVar[bool] = False

    def __new__(cls, *args: t.Any, **kwargs: t.Any) -> RowDiff:
        if not cls._types_resolved:
            cls._resolve_types()
        return super().__new__(cls)

    @classmethod
    def _resolve_types(cls) -> None:
        # Pandas is imported by type checking so we need to resolve the types with the real import before instantiating
        import pandas as pd  # noqa

        cls.model_rebuild()
        cls._types_resolved = True

    @property
    def source_count(self) -> int:
        """Count of the source."""
        return int(self.stats["s_count"])

    @property
    def target_count(self) -> int:
        """Count of the target."""
        return int(self.stats["t_count"])

    @property
    def empty(self) -> bool:
        """True when both row counts and both only-in-one-side counts are zero."""
        return (
            self.source_count == 0
            and self.target_count == 0
            and self.s_only_count == 0
            and self.t_only_count == 0
        )

    @property
    def count_pct_change(self) -> float:
        """The percentage change of the counts (``math.inf`` when the source is empty)."""
        if self.source_count == 0:
            return math.inf
        return ((self.target_count - self.source_count) / self.source_count) * 100

    @property
    def join_count(self) -> int:
        """Count of successfully joined rows."""
        return int(self.stats["join_count"])

    @property
    def full_match_count(self) -> int:
        """The number of rows for which shared columns have the same values."""
        return int(self.stats["full_match_count"])

    @property
    def full_match_pct(self) -> float:
        """The percentage of rows for which shared columns have the same values."""
        # A joined row accounts for one row on each side, hence the factor of 2.
        return self._pct(2 * self.full_match_count)

    @property
    def partial_match_count(self) -> int:
        """The number of rows for which some shared columns have the same values."""
        return self.join_count - self.full_match_count

    @property
    def partial_match_pct(self) -> float:
        """The percentage of rows for which some shared columns have the same values."""
        return self._pct(2 * self.partial_match_count)

    @property
    def s_only_count(self) -> int:
        """Count of rows only present in source."""
        return int(self.stats["s_only_count"])

    @property
    def s_only_pct(self) -> float:
        """The percentage of rows that are only present in source."""
        return self._pct(self.s_only_count)

    @property
    def t_only_count(self) -> int:
        """Count of rows only present in target."""
        return int(self.stats["t_only_count"])

    @property
    def t_only_pct(self) -> float:
        """The percentage of rows that are only present in target."""
        return self._pct(self.t_only_count)

    def _pct(self, numerator: int) -> float:
        """Express ``numerator`` as a percentage of the combined source + target row count.

        Returns 0.0 when both tables are empty instead of raising ``ZeroDivisionError``.
        """
        total = self.source_count + self.target_count
        if total == 0:
            return 0.0
        return round((numerator / total) * 100, 2)


class TableDiff:
    """Calculates differences between tables, taking into account schema and row level differences."""

    def __init__(
        self,
        adapter: EngineAdapter,
        source: TableName,
        target: TableName,
        on: t.List[str] | exp.Expr,
        skip_columns: t.List[str] | None = None,
        where: t.Optional[str | exp.Expr] = None,
        limit: int = 20,
        source_alias: t.Optional[str] = None,
        target_alias: t.Optional[str] = None,
        model_name: t.Optional[str] = None,
        model_dialect: t.Optional[str] = None,
        decimals: int = 3,
        schema_diff_ignore_case: bool = False,
    ):
        """Configure a diff between ``source`` and ``target``.

        Args:
            adapter: Engine adapter used to run queries; must implement ``RowDiffMixin``.
            source: The source table.
            target: The target table.
            on: Either a list of key column names, or a join condition whose columns are
                qualified with the ``s`` (source) and ``t`` (target) aliases.
            skip_columns: Column names excluded from the row diff.
            where: Optional filter applied to both the source and the target query.
            limit: Maximum number of rows fetched per sample category.
            source_alias: Display alias for the source (e.g. an environment name).
            target_alias: Display alias for the target.
            model_name: Name of the model being diffed, if any.
            model_dialect: Dialect used to normalize identifiers; defaults to the adapter's.
            decimals: Precision passed to decimal normalization and key concatenation.
            schema_diff_ignore_case: Forwarded to ``SchemaDiff.ignore_case``.

        Raises:
            ValueError: If ``adapter`` does not implement ``RowDiffMixin``.
        """
        if not isinstance(adapter, RowDiffMixin):
            raise ValueError(f"Engine {adapter} doesnt support RowDiff")

        self.adapter = adapter
        self.source = source
        self.target = target
        self.dialect = adapter.dialect
        self.source_table = exp.to_table(self.source, dialect=self.dialect)
        self.target_table = exp.to_table(self.target, dialect=self.dialect)
        self.where = exp.condition(where, dialect=self.dialect) if where else None
        self.limit = limit
        self.model_name = model_name
        self.model_dialect = model_dialect
        self.decimals = decimals
        self.schema_diff_ignore_case = schema_diff_ignore_case

        # Support environment aliases for diff output improvement in certain cases
        self.source_alias = source_alias
        self.target_alias = target_alias

        self.skip_columns = {
            normalize_identifiers(
                exp.parse_identifier(t.cast(str, col)),
                dialect=self.model_dialect or self.dialect,
            ).name
            for col in ensure_list(skip_columns)
        }

        self._on = on
        # Cache for the (expensive) row diff; populated on the first row_diff() call.
        self._row_diff: t.Optional[RowDiff] = None

    @cached_property
    def source_schema(self) -> t.Dict[str, exp.DataType]:
        return self.adapter.columns(self.source_table)

    @cached_property
    def target_schema(self) -> t.Dict[str, exp.DataType]:
        return self.adapter.columns(self.target_table)

    @cached_property
    def key_columns(self) -> t.Tuple[t.List[exp.Column], t.List[exp.Column], t.List[str]]:
        """Return ``(source key columns, target key columns, key column names)``."""
        dialect = self.model_dialect or self.dialect

        # If the columns to join on are explicitly specified, then just return them
        if isinstance(self._on, (list, tuple)):
            identifiers = [normalize_identifiers(c, dialect=dialect) for c in self._on]
            s_index = [exp.column(c, "s") for c in identifiers]
            t_index = [exp.column(c, "t") for c in identifiers]
            return s_index, t_index, [i.name for i in identifiers]

        # Otherwise, we need to parse them out of the supplied "on" condition
        index_cols = []
        s_index = []
        t_index = []

        normalize_identifiers(self._on, dialect=dialect)
        for col in self._on.find_all(exp.Column):
            index_cols.append(col.name)
            if col.table.lower() == "s":
                s_index.append(col)
            elif col.table.lower() == "t":
                t_index.append(col)

        # De-duplicate while preserving first-seen order.
        index_cols = list(dict.fromkeys(index_cols))
        s_index = list(dict.fromkeys(s_index))
        t_index = list(dict.fromkeys(t_index))

        return s_index, t_index, index_cols

    @property
    def source_key_expression(self) -> exp.Expr:
        s_index, _, _ = self.key_columns
        return self._key_expression(s_index, self.source_schema)

    @property
    def target_key_expression(self) -> exp.Expr:
        _, t_index, _ = self.key_columns
        return self._key_expression(t_index, self.target_schema)

    def _key_expression(
        self, cols: t.List[exp.Column], schema: t.Dict[str, exp.DataType]
    ) -> exp.Expr:
        # if there is a single column, dont do anything fancy to it in order to allow existing indexes to be hit
        if len(cols) == 1:
            return exp.to_column(cols[0].name)

        # if there are multiple columns, turn them into a single column by stringify-ing/concatenating them together
        key_columns_to_types = {key.name: schema[key.name] for key in cols}
        return self.adapter.concat_columns(key_columns_to_types, self.decimals)

    def schema_diff(self) -> SchemaDiff:
        return SchemaDiff(
            source=self.source,
            target=self.target,
            source_schema=self.source_schema,
            target_schema=self.target_schema,
            source_alias=self.source_alias,
            target_alias=self.target_alias,
            model_name=self.model_name,
            ignore_case=self.schema_diff_ignore_case,
        )

    def row_diff(
        self, temp_schema: t.Optional[str] = None, skip_grain_check: bool = False
    ) -> RowDiff:
        """Compute (once) and return row-level statistics and samples.

        The result is cached on the instance: later calls return the first result
        regardless of the arguments passed.

        Args:
            temp_schema: Schema for the intermediate temp table; defaults to ``sqlmesh_temp``.
            skip_grain_check: When True, the distinct join-key counts are not computed.
        """
        if self._row_diff is None:
            source_schema = {
                c: data_type
                for c, data_type in self.source_schema.items()
                if c not in self.skip_columns
            }
            target_schema = {
                c: data_type
                for c, data_type in self.target_schema.items()
                if c not in self.skip_columns
            }

            s_selects = {c: exp.column(c, "s").as_(f"s__{c}") for c in source_schema}
            t_selects = {c: exp.column(c, "t").as_(f"t__{c}") for c in target_schema}
            s_selects[SQLMESH_JOIN_KEY_COL] = exp.column(SQLMESH_JOIN_KEY_COL, "s").as_(
                f"s__{SQLMESH_JOIN_KEY_COL}"
            )
            t_selects[SQLMESH_JOIN_KEY_COL] = exp.column(SQLMESH_JOIN_KEY_COL, "t").as_(
                f"t__{SQLMESH_JOIN_KEY_COL}"
            )

            # Only columns present on both sides with identical types are compared by value.
            matched_columns = {
                c: data_type
                for c, data_type in source_schema.items()
                if data_type == target_schema.get(c)
            }

            s_index, t_index, index_cols = self.key_columns
            s_index_names = [c.name for c in s_index]
            t_index_names = [c.name for c in t_index]

            def _column_expr(column_name: str, table: str) -> exp.Expr:
                # Normalize floating point / nested values so equal data compares equal.
                column_type = matched_columns[column_name]
                qualified_column = exp.column(column_name, table)

                if column_type.is_type(*exp.DataType.REAL_TYPES):
                    return self.adapter._normalize_decimal_value(qualified_column, self.decimals)
                if column_type.is_type(*exp.DataType.NESTED_TYPES):
                    return self.adapter._normalize_nested_value(qualified_column)

                return qualified_column

            # One 0/1 flag per matched column; two NULLs count as a match.
            comparisons = [
                exp.Case()
                .when(_column_expr(c, "s").eq(_column_expr(c, "t")), exp.Literal.number(1))
                .when(
                    exp.column(c, "s").is_(exp.Null()) & exp.column(c, "t").is_(exp.Null()),
                    exp.Literal.number(1),
                )
                .when(
                    exp.column(c, "s").is_(exp.Null()) | exp.column(c, "t").is_(exp.Null()),
                    exp.Literal.number(0),
                )
                .else_(exp.Literal.number(0))
                .as_(f"{c}_matches")
                for c in matched_columns
            ]

            source_query = (
                exp.select(
                    *(exp.column(c) for c in source_schema),
                    self.source_key_expression.as_(SQLMESH_JOIN_KEY_COL),
                )
                .from_(self.source_table.as_("s"))
                .where(self.where)
            )
            target_query = (
                exp.select(
                    *(exp.column(c) for c in target_schema),
                    self.target_key_expression.as_(SQLMESH_JOIN_KEY_COL),
                )
                .from_(self.target_table.as_("t"))
                .where(self.where)
            )

            # Ensure every column is qualified with the alias in the source and target queries
            for col in find_all_in_scope(source_query, exp.Column):
                col.set("table", exp.to_identifier("s"))
            for col in find_all_in_scope(target_query, exp.Column):
                col.set("table", exp.to_identifier("t"))

            source_table = exp.table_("__source")
            target_table = exp.table_("__target")
            stats_table = exp.table_("__stats")

            stats_query = (
                exp.select(
                    *s_selects.values(),
                    *t_selects.values(),
                    exp.func(
                        "IF", exp.column(SQLMESH_JOIN_KEY_COL, "s").is_(exp.Null()).not_(), 1, 0
                    ).as_("s_exists"),
                    exp.func(
                        "IF", exp.column(SQLMESH_JOIN_KEY_COL, "t").is_(exp.Null()).not_(), 1, 0
                    ).as_("t_exists"),
                    exp.func(
                        "IF",
                        exp.column(SQLMESH_JOIN_KEY_COL, "s").eq(
                            exp.column(SQLMESH_JOIN_KEY_COL, "t")
                        ),
                        1,
                        0,
                    ).as_("row_joined"),
                    exp.func(
                        "IF",
                        exp.or_(
                            *(
                                exp.and_(
                                    s_col.is_(exp.Null()),
                                    t_col.is_(exp.Null()),
                                )
                                for s_col, t_col in zip(s_index, t_index)
                            ),
                        ),
                        1,
                        0,
                    ).as_("null_grain"),
                    *comparisons,
                )
                .from_(source_table.as_("s"))
                .join(
                    target_table.as_("t"),
                    on=exp.column(SQLMESH_JOIN_KEY_COL, "s").eq(
                        exp.column(SQLMESH_JOIN_KEY_COL, "t")
                    ),
                    join_type="FULL",
                )
            )

            base_query = (
                exp.Select()
                .with_(source_table, source_query)
                .with_(target_table, target_query)
                .with_(stats_table, stats_query)
                .select(
                    "*",
                    exp.Case()
                    .when(
                        exp.and_(
                            *[
                                exp.column(f"{c}_matches").eq(exp.Literal.number(1))
                                for c in matched_columns
                            ]
                        ),
                        exp.Literal.number(1),
                    )
                    .else_(exp.Literal.number(0))
                    .as_("row_full_match"),
                )
                .from_(stats_table)
            )

            query = self.adapter.ensure_nulls_for_unmatched_after_join(
                quote_identifiers(base_query.copy(), dialect=self.model_dialect or self.dialect)
            )

            if not temp_schema:
                temp_schema = "sqlmesh_temp"

            schema = to_schema(temp_schema, dialect=self.dialect)
            temp_table = exp.table_("diff", db=schema.db, catalog=schema.catalog, quoted=True)

            temp_table_kwargs: t.Dict[str, t.Any] = {}
            if isinstance(self.adapter, AthenaEngineAdapter):
                # Athena has two table formats: Hive (the default) and Iceberg. TableDiff requires that
                # the formats be the same for the source, target, and temp tables.
                source_table_type = self.adapter._query_table_type(self.source_table)
                target_table_type = self.adapter._query_table_type(self.target_table)

                if source_table_type == "iceberg" and target_table_type == "iceberg":
                    # Sets the temp table's format to Iceberg.
                    # If neither source nor target table is Iceberg, it defaults to Hive (Athena's default).
                    temp_table_kwargs["table_format"] = "iceberg"
                elif source_table_type == "iceberg" or target_table_type == "iceberg":
                    raise SQLMeshError(
                        f"Source table '{self.source}' format '{source_table_type}' and target table '{self.target}' format '{target_table_type}' "
                        f"do not match for Athena. Diffing between different table formats is not supported."
                    )

            with self.adapter.temp_table(
                query, name=temp_table, target_columns_to_types=None, **temp_table_kwargs
            ) as table:
                summary_sums = [
                    exp.func("SUM", "s_exists").as_("s_count"),
                    exp.func("SUM", "t_exists").as_("t_count"),
                    exp.func("SUM", "row_joined").as_("join_count"),
                    exp.func("SUM", "null_grain").as_("null_grain_count"),
                    exp.func("SUM", "row_full_match").as_("full_match_count"),
                    *(exp.func("SUM", name(c)).as_(c.alias) for c in comparisons),
                ]

                if not skip_grain_check:
                    summary_sums.extend(
                        [
                            parse_one(f"COUNT(DISTINCT(s__{SQLMESH_JOIN_KEY_COL}))").as_(
                                "distinct_count_s"
                            ),
                            parse_one(f"COUNT(DISTINCT(t__{SQLMESH_JOIN_KEY_COL}))").as_(
                                "distinct_count_t"
                            ),
                        ]
                    )

                summary_query = exp.select(*summary_sums).from_(table)

                stats_df = self.adapter.fetchdf(summary_query, quote_identifiers=True).fillna(0)
                stats_df["s_only_count"] = stats_df["s_count"] - stats_df["join_count"]
                stats_df["t_only_count"] = stats_df["t_count"] - stats_df["join_count"]
                stats = stats_df.iloc[0].to_dict()

                column_stats_query = (
                    exp.select(
                        *(
                            exp.func(
                                "ROUND",
                                100
                                * (
                                    exp.cast(
                                        exp.func("SUM", name(c)), exp.DataType.build("NUMERIC")
                                    )
                                    / exp.func("COUNT", name(c))
                                ),
                                9,
                            ).as_(c.alias)
                            for c in comparisons
                        )
                    )
                    .from_(table)
                    .where(exp.column("row_joined").eq(exp.Literal.number(1)))
                )

                column_stats = (
                    self.adapter.fetchdf(column_stats_query, quote_identifiers=True)
                    .T.rename(
                        columns={0: "pct_match"},
                        # Strip only the trailing "_matches" so column names containing it survive.
                        index=lambda x: _strip_suffix(str(x), "_matches") if x else "",
                    )
                    # errors=ignore because all the index_cols might not be present in the DF if the `on` condition was something like "s.id == t.item_id"
                    # because these would not be present in the matching_cols (since they have different names) and thus no summary would be generated
                    .drop(index=index_cols, errors="ignore")
                )

                sample = self._fetch_sample(
                    table, s_selects, s_index, t_selects, t_index, self.limit
                )

                joined_sample_cols = [f"s__{c}" for c in s_index_names]
                comparison_cols = [
                    (f"s__{c}", f"t__{c}")
                    for c in column_stats[column_stats["pct_match"] < 100].index
                ]

                for cols in comparison_cols:
                    joined_sample_cols.extend(cols)

                # Key columns are shown under their bare name; split only on the first "__"
                # so that column names which themselves contain "__" are kept intact.
                joined_renamed_cols = {
                    c: c.split("__", 1)[1] if c.split("__", 1)[1] in index_cols else c
                    for c in joined_sample_cols
                }

                if (
                    self.source_alias
                    and self.target_alias
                    and self.source != self.source_alias
                    and self.target != self.target_alias
                ):
                    # Replace only the leading side prefix with the upper-cased alias.
                    alias_prefixes = {
                        "s__": f"{self.source_alias.upper()}__",
                        "t__": f"{self.target_alias.upper()}__",
                    }

                    def _with_alias(col_name: str) -> str:
                        for side_prefix, alias_prefix in alias_prefixes.items():
                            if col_name.startswith(side_prefix):
                                return alias_prefix + _strip_prefix(col_name, side_prefix)
                        return col_name

                    joined_renamed_cols = {
                        c: _with_alias(n) for c, n in joined_renamed_cols.items()
                    }

                # Reassign instead of renaming in place: these frames are slices of `sample`.
                joined_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "common_rows"][
                    joined_sample_cols
                ].rename(columns=joined_renamed_cols)

                s_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "source_only"][
                    [
                        *[f"s__{c}" for c in s_index_names],
                        *[f"s__{c}" for c in source_schema if c not in s_index_names],
                    ]
                ]
                s_sample = s_sample.rename(
                    columns={c: _strip_prefix(c, "s__") for c in s_sample.columns}
                )

                t_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "target_only"][
                    [
                        *[f"t__{c}" for c in t_index_names],
                        *[f"t__{c}" for c in target_schema if c not in t_index_names],
                    ]
                ]
                t_sample = t_sample.rename(
                    columns={c: _strip_prefix(c, "t__") for c in t_sample.columns}
                )

                sample.drop(
                    columns=[
                        f"s__{SQLMESH_JOIN_KEY_COL}",
                        f"t__{SQLMESH_JOIN_KEY_COL}",
                        SQLMESH_SAMPLE_TYPE_COL,
                    ],
                    inplace=True,
                )

                self._row_diff = RowDiff(
                    source=self.source,
                    target=self.target,
                    stats=stats,
                    column_stats=column_stats,
                    sample=sample,
                    joined_sample=joined_sample,
                    s_sample=s_sample,
                    t_sample=t_sample,
                    source_alias=self.source_alias,
                    target_alias=self.target_alias,
                    model_name=self.model_name,
                    decimals=self.decimals,
                )

        return self._row_diff

    def _fetch_sample(
        self,
        sample_table: exp.Table,
        s_selects: t.Dict[str, exp.Expr],
        s_index: t.List[exp.Column],
        t_selects: t.Dict[str, exp.Expr],
        t_index: t.List[exp.Column],
        limit: int,
    ) -> pd.DataFrame:
        """Fetch up to ``limit`` source-only, target-only and mismatched common rows.

        Each returned row carries a ``SQLMESH_SAMPLE_TYPE_COL`` value of ``source_only``,
        ``target_only`` or ``common_rows``.
        """
        rendered_data_column_names = [
            name(s) for s in list(s_selects.values()) + list(t_selects.values())
        ]
        sample_type = exp.to_identifier(SQLMESH_SAMPLE_TYPE_COL)

        source_only_sample = (
            exp.select(
                exp.Literal.string("source_only").as_(sample_type), *rendered_data_column_names
            )
            .from_(sample_table)
            .where(exp.and_(exp.column("s_exists").eq(1), exp.column("row_joined").eq(0)))
            .order_by(*(name(s_selects[c.name]) for c in s_index))
            .limit(limit)
        )

        target_only_sample = (
            exp.select(
                exp.Literal.string("target_only").as_(sample_type), *rendered_data_column_names
            )
            .from_(sample_table)
            .where(exp.and_(exp.column("t_exists").eq(1), exp.column("row_joined").eq(0)))
            .order_by(*(name(t_selects[c.name]) for c in t_index))
            .limit(limit)
        )

        common_rows_sample = (
            exp.select(
                exp.Literal.string("common_rows").as_(sample_type), *rendered_data_column_names
            )
            .from_(sample_table)
            .where(exp.and_(exp.column("row_joined").eq(1), exp.column("row_full_match").eq(0)))
            .order_by(
                *(name(s_selects[c.name]) for c in s_index),
                *(name(t_selects[c.name]) for c in t_index),
            )
            .limit(limit)
        )

        query = (
            exp.Select()
            .with_("source_only", source_only_sample)
            .with_("target_only", target_only_sample)
            .with_("common_rows", common_rows_sample)
            .select(sample_type, *rendered_data_column_names)
            .from_("source_only")
            .union(
                exp.select(sample_type, *rendered_data_column_names).from_("target_only"),
                distinct=False,
            )
            .union(
                exp.select(sample_type, *rendered_data_column_names).from_("common_rows"),
                distinct=False,
            )
        )

        return self.adapter.fetchdf(query, quote_identifiers=True)


def name(e: exp.Expr) -> str:
    """Return the quoted SQL of an aliased expression's alias."""
    return e.args["alias"].sql(identify=True)
class SchemaDiff(PydanticModel, frozen=True):
    """Schema-level comparison of a source table against a target table.

    Columns are matched by name. When ``ignore_case`` is set, names are compared in
    lower case but reported with the spelling found in the original schema.
    """

    source: str
    target: str
    source_schema: t.Dict[str, exp.DataType]
    target_schema: t.Dict[str, exp.DataType]
    source_alias: t.Optional[str] = None
    target_alias: t.Optional[str] = None
    model_name: t.Optional[str] = None
    ignore_case: bool = False

    @property
    def _comparable_source_schema(self) -> t.Dict[str, exp.DataType]:
        if not self.ignore_case:
            return self.source_schema
        return self._lowercase_schema_names(self.source_schema)

    @property
    def _comparable_target_schema(self) -> t.Dict[str, exp.DataType]:
        if not self.ignore_case:
            return self.target_schema
        return self._lowercase_schema_names(self.target_schema)

    def _lowercase_schema_names(
        self, schema: t.Dict[str, exp.DataType]
    ) -> t.Dict[str, exp.DataType]:
        # Later entries win when two names collide after lower-casing.
        lowered: t.Dict[str, exp.DataType] = {}
        for column_name, data_type in schema.items():
            lowered[column_name.lower()] = data_type
        return lowered

    def _original_column_name(
        self, maybe_lowercased_column_name: str, schema: t.Dict[str, exp.DataType]
    ) -> str:
        if self.ignore_case:
            candidates = (c for c in schema if c.lower() == maybe_lowercased_column_name)
            return next(candidates)
        return maybe_lowercased_column_name

    @property
    def added(self) -> t.List[t.Tuple[str, exp.DataType]]:
        """Columns found only in the target, as ``(name, type)`` pairs."""
        source_names = self._comparable_source_schema
        new_columns = []
        for column_name, data_type in self._comparable_target_schema.items():
            if column_name in source_names:
                continue
            original = self._original_column_name(column_name, self.target_schema)
            new_columns.append((original, data_type))
        return new_columns

    @property
    def removed(self) -> t.List[t.Tuple[str, exp.DataType]]:
        """Columns found only in the source, as ``(name, type)`` pairs."""
        target_names = self._comparable_target_schema
        dropped_columns = []
        for column_name, data_type in self._comparable_source_schema.items():
            if column_name in target_names:
                continue
            original = self._original_column_name(column_name, self.source_schema)
            dropped_columns.append((original, data_type))
        return dropped_columns

    @property
    def modified(self) -> t.Dict[str, t.Tuple[exp.DataType, exp.DataType]]:
        """Shared columns whose type differs, mapped to ``(source_type, target_type)``."""
        source_side = self._comparable_source_schema
        target_side = self._comparable_target_schema
        changed = {
            column_name: (source_side[column_name], target_side[column_name])
            for column_name in source_side.keys() & target_side.keys()
            if source_side[column_name] != target_side[column_name]
        }
        if not self.ignore_case:
            return changed
        return {
            self._original_column_name(column_name, self.source_schema): types
            for column_name, types in changed.items()
        }

    @property
    def has_changes(self) -> bool:
        """Whether any column was added, removed, or changed type."""
        return bool(self.added) or bool(self.removed) or bool(self.modified)
An object containing the schema difference between a source and target table.
@property
def added(self) -> t.List[t.Tuple[str, exp.DataType]]:
    """Columns found only in the target, as ``(name, type)`` pairs.

    Names are reported with their spelling from ``target_schema``.
    """
    source_names = self._comparable_source_schema
    new_columns = []
    for column_name, data_type in self._comparable_target_schema.items():
        if column_name not in source_names:
            original = self._original_column_name(column_name, self.target_schema)
            new_columns.append((original, data_type))
    return new_columns
Added columns.
@property
def removed(self) -> t.List[t.Tuple[str, exp.DataType]]:
    """Columns found only in the source, as ``(name, type)`` pairs.

    Names are reported with their spelling from ``source_schema``.
    """
    target_names = self._comparable_target_schema
    dropped_columns = []
    for column_name, data_type in self._comparable_source_schema.items():
        if column_name not in target_names:
            original = self._original_column_name(column_name, self.source_schema)
            dropped_columns.append((original, data_type))
    return dropped_columns
Removed columns.
@property
def modified(self) -> t.Dict[str, t.Tuple[exp.DataType, exp.DataType]]:
    """Shared columns whose type differs, mapped to ``(source_type, target_type)``."""
    source_side = self._comparable_source_schema
    target_side = self._comparable_target_schema
    changed = {
        column_name: (source_side[column_name], target_side[column_name])
        for column_name in source_side.keys() & target_side.keys()
        if source_side[column_name] != target_side[column_name]
    }
    if not self.ignore_case:
        return changed
    # Report names using the source schema's original spelling.
    return {
        self._original_column_name(column_name, self.source_schema): types
        for column_name, types in changed.items()
    }
Columns with modified types.
@property
def has_changes(self) -> bool:
    """Whether any column was added, removed, or changed type."""
    return bool(self.added) or bool(self.removed) or bool(self.modified)
Does the schema contain any changes at all between source and target?
Configuration for the model; should be a dictionary conforming to `ConfigDict` (`pydantic.config.ConfigDict`).
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class RowDiff(PydanticModel, frozen=True):
    """Summary statistics and a sample dataframe."""

    source: str
    target: str
    stats: t.Dict[str, float]
    sample: pd.DataFrame
    joined_sample: pd.DataFrame
    s_sample: pd.DataFrame
    t_sample: pd.DataFrame
    column_stats: pd.DataFrame
    source_alias: t.Optional[str] = None
    target_alias: t.Optional[str] = None
    model_name: t.Optional[str] = None
    decimals: int = 3

    _types_resolved: t.ClassVar[bool] = False

    def __new__(cls, *args: t.Any, **kwargs: t.Any) -> RowDiff:
        if not cls._types_resolved:
            cls._resolve_types()
        return super().__new__(cls)

    @classmethod
    def _resolve_types(cls) -> None:
        # Pandas is imported by type checking so we need to resolve the types with the real import before instantiating
        import pandas as pd  # noqa

        cls.model_rebuild()
        cls._types_resolved = True

    @property
    def source_count(self) -> int:
        """Count of the source."""
        return int(self.stats["s_count"])

    @property
    def target_count(self) -> int:
        """Count of the target."""
        return int(self.stats["t_count"])

    @property
    def empty(self) -> bool:
        """True when both row counts and both only-in-one-side counts are zero."""
        return (
            self.source_count == 0
            and self.target_count == 0
            and self.s_only_count == 0
            and self.t_only_count == 0
        )

    @property
    def count_pct_change(self) -> float:
        """The percentage change of the counts (``math.inf`` when the source is empty)."""
        if self.source_count == 0:
            return math.inf
        return ((self.target_count - self.source_count) / self.source_count) * 100

    @property
    def join_count(self) -> int:
        """Count of successfully joined rows."""
        return int(self.stats["join_count"])

    @property
    def full_match_count(self) -> int:
        """The number of rows for which shared columns have the same values."""
        return int(self.stats["full_match_count"])

    @property
    def full_match_pct(self) -> float:
        """The percentage of rows for which shared columns have the same values."""
        # A joined row accounts for one row on each side, hence the factor of 2.
        return self._pct(2 * self.full_match_count)

    @property
    def partial_match_count(self) -> int:
        """The number of rows for which some shared columns have the same values."""
        return self.join_count - self.full_match_count

    @property
    def partial_match_pct(self) -> float:
        """The percentage of rows for which some shared columns have the same values."""
        return self._pct(2 * self.partial_match_count)

    @property
    def s_only_count(self) -> int:
        """Count of rows only present in source."""
        return int(self.stats["s_only_count"])

    @property
    def s_only_pct(self) -> float:
        """The percentage of rows that are only present in source."""
        return self._pct(self.s_only_count)

    @property
    def t_only_count(self) -> int:
        """Count of rows only present in target."""
        return int(self.stats["t_only_count"])

    @property
    def t_only_pct(self) -> float:
        """The percentage of rows that are only present in target."""
        return self._pct(self.t_only_count)

    def _pct(self, numerator: int) -> float:
        """Express ``numerator`` as a percentage of the combined source + target row count.

        Returns 0.0 when both tables are empty instead of raising ``ZeroDivisionError``.
        """
        total = self.source_count + self.target_count
        if total == 0:
            return 0.0
        return round((numerator / total) * 100, 2)
Summary statistics and a sample dataframe.
@property
def source_count(self) -> int:
    """Number of rows in the source (the ``s_count`` statistic)."""
    raw_count = self.stats["s_count"]
    return int(raw_count)
Count of the source.
@property
def target_count(self) -> int:
    """Number of rows in the target (the ``t_count`` statistic)."""
    raw_count = self.stats["t_count"]
    return int(raw_count)
Count of the target.
@property
def count_pct_change(self) -> float:
    """Relative change from the source row count to the target row count, in percent.

    Returns ``math.inf`` when the source is empty.
    """
    base = self.source_count
    if base == 0:
        return math.inf
    delta = self.target_count - base
    return (delta / base) * 100
The percentage change of the counts.
@property
def join_count(self) -> int:
    """Number of rows whose join key matched on both sides."""
    raw_count = self.stats["join_count"]
    return int(raw_count)
Count of successfully joined rows.
@property
def full_match_count(self) -> int:
    """Number of joined rows whose shared columns all have the same values."""
    raw_count = self.stats["full_match_count"]
    return int(raw_count)
The number of rows for which shared columns have the same values.
@property
def full_match_pct(self) -> float:
    """Percentage of rows whose shared columns all have the same values."""
    # Each joined row stands for one source row and one target row.
    doubled = 2 * self.full_match_count
    return self._pct(doubled)
The percentage of rows for which shared columns have the same values.
@property
def partial_match_count(self) -> int:
    """Number of joined rows in which at least one compared column differs."""
    joined, fully_matched = self.join_count, self.full_match_count
    return joined - fully_matched
The number of joined rows for which at least one shared column has a different value.
@property
def partial_match_pct(self) -> float:
    """Percentage of rows that belong to joined pairs with at least one differing column.

    The count is doubled so that it is comparable with the combined source + target
    row count that ``_pct`` divides by.
    """
    partial_rows = self.partial_match_count
    return self._pct(partial_rows * 2)
The percentage of rows belonging to joined pairs in which at least one shared column has a different value.
@property
def s_only_count(self) -> int:
    """Number of rows present only in the source, read from ``s_only_count``."""
    raw_count = self.stats["s_only_count"]
    return int(raw_count)
Count of rows only present in source.
@property
def s_only_pct(self) -> float:
    """Percentage of all source + target rows that are present only in the source."""
    only_in_source = self.s_only_count
    return self._pct(only_in_source)
The percentage of rows that are only present in source.
@property
def t_only_count(self) -> int:
    """Number of rows present only in the target, read from ``t_only_count``."""
    raw_count = self.stats["t_only_count"]
    return int(raw_count)
Count of rows only present in target.
@property
def t_only_pct(self) -> float:
    """Percentage of all source + target rows that are present only in the target."""
    only_in_target = self.t_only_count
    return self._pct(only_in_target)
The percentage of rows that are only present in target.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class TableDiff:
    """Calculates differences between tables, taking into account schema and row level differences."""

    def __init__(
        self,
        adapter: EngineAdapter,
        source: TableName,
        target: TableName,
        on: t.List[str] | exp.Expr,
        skip_columns: t.List[str] | None = None,
        where: t.Optional[str | exp.Expr] = None,
        limit: int = 20,
        source_alias: t.Optional[str] = None,
        target_alias: t.Optional[str] = None,
        model_name: t.Optional[str] = None,
        model_dialect: t.Optional[str] = None,
        decimals: int = 3,
        schema_diff_ignore_case: bool = False,
    ):
        """Configure a diff between ``source`` and ``target``.

        Args:
            adapter: Engine adapter used to run the diff queries; must implement ``RowDiffMixin``.
            source: Source table name.
            target: Target table name.
            on: Either the list of key column names shared by both tables, or a join condition
                whose columns are qualified with ``s`` (source) and ``t`` (target).
            skip_columns: Column names excluded from the row diff.
            where: Optional filter applied to both the source and the target query.
            limit: Maximum number of rows fetched for each kind of sample.
            source_alias: Optional display alias for the source.
            target_alias: Optional display alias for the target.
            model_name: Optional model name carried into the diff results.
            model_dialect: Dialect used to normalize key and skip column identifiers; falls back
                to the adapter's dialect.
            decimals: Number of decimals passed to the adapter when normalizing real-valued
                columns for comparison and when concatenating multi-column keys.
            schema_diff_ignore_case: Passed to ``SchemaDiff`` as ``ignore_case``.

        Raises:
            ValueError: If ``adapter`` does not implement ``RowDiffMixin``.
        """
        if not isinstance(adapter, RowDiffMixin):
            raise ValueError(f"Engine {adapter} doesnt support RowDiff")

        self.adapter = adapter
        self.source = source
        self.target = target
        self.dialect = adapter.dialect
        self.source_table = exp.to_table(self.source, dialect=self.dialect)
        self.target_table = exp.to_table(self.target, dialect=self.dialect)
        self.where = exp.condition(where, dialect=self.dialect) if where else None
        self.limit = limit
        self.model_name = model_name
        self.model_dialect = model_dialect
        self.decimals = decimals
        self.schema_diff_ignore_case = schema_diff_ignore_case

        # Support environment aliases for diff output improvement in certain cases
        self.source_alias = source_alias
        self.target_alias = target_alias

        # Normalize skip columns the same way key columns are normalized so lookups agree.
        self.skip_columns = {
            normalize_identifiers(
                exp.parse_identifier(t.cast(str, col)),
                dialect=self.model_dialect or self.dialect,
            ).name
            for col in ensure_list(skip_columns)
        }

        self._on = on
        self._row_diff: t.Optional[RowDiff] = None

    @cached_property
    def source_schema(self) -> t.Dict[str, exp.DataType]:
        """Column names and types of the source table, as reported by the adapter."""
        return self.adapter.columns(self.source_table)

    @cached_property
    def target_schema(self) -> t.Dict[str, exp.DataType]:
        """Column names and types of the target table, as reported by the adapter."""
        return self.adapter.columns(self.target_table)

    @cached_property
    def key_columns(self) -> t.Tuple[t.List[exp.Column], t.List[exp.Column], t.List[str]]:
        """Resolve the join key columns.

        Returns:
            A tuple ``(s_index, t_index, index_cols)``: the source-side key columns qualified
            with ``s``, the target-side key columns qualified with ``t``, and the de-duplicated
            names of every column referenced by the key definition.
        """
        dialect = self.model_dialect or self.dialect

        # If the columns to join on are explicitly specified, then just return them
        if isinstance(self._on, (list, tuple)):
            identifiers = [normalize_identifiers(c, dialect=dialect) for c in self._on]
            s_index = [exp.column(c, "s") for c in identifiers]
            t_index = [exp.column(c, "t") for c in identifiers]
            return s_index, t_index, [i.name for i in identifiers]

        # Otherwise, we need to parse them out of the supplied "on" condition
        index_cols = []
        s_index = []
        t_index = []

        # The return value is ignored: this relies on normalization updating self._on in place.
        normalize_identifiers(self._on, dialect=dialect)
        for col in self._on.find_all(exp.Column):
            index_cols.append(col.name)
            if col.table.lower() == "s":
                s_index.append(col)
            elif col.table.lower() == "t":
                t_index.append(col)

        # De-duplicate while preserving first-seen order.
        index_cols = list(dict.fromkeys(index_cols))
        s_index = list(dict.fromkeys(s_index))
        t_index = list(dict.fromkeys(t_index))

        return s_index, t_index, index_cols

    @property
    def source_key_expression(self) -> exp.Expr:
        """Expression producing the single join key for source rows."""
        s_index, _, _ = self.key_columns
        return self._key_expression(s_index, self.source_schema)

    @property
    def target_key_expression(self) -> exp.Expr:
        """Expression producing the single join key for target rows."""
        _, t_index, _ = self.key_columns
        return self._key_expression(t_index, self.target_schema)

    def _key_expression(
        self, cols: t.List[exp.Column], schema: t.Dict[str, exp.DataType]
    ) -> exp.Expr:
        """Combine the key columns ``cols`` into one join-key expression using ``schema`` types."""
        # if there is a single column, dont do anything fancy to it in order to allow existing indexes to be hit
        if len(cols) == 1:
            return exp.to_column(cols[0].name)

        # if there are multiple columns, turn them into a single column by stringify-ing/concatenating them together
        key_columns_to_types = {key.name: schema[key.name] for key in cols}
        return self.adapter.concat_columns(key_columns_to_types, self.decimals)

    def schema_diff(self) -> SchemaDiff:
        """Build a ``SchemaDiff`` comparing the source and target table schemas."""
        return SchemaDiff(
            source=self.source,
            target=self.target,
            source_schema=self.source_schema,
            target_schema=self.target_schema,
            source_alias=self.source_alias,
            target_alias=self.target_alias,
            model_name=self.model_name,
            ignore_case=self.schema_diff_ignore_case,
        )

    def row_diff(
        self, temp_schema: t.Optional[str] = None, skip_grain_check: bool = False
    ) -> RowDiff:
        """Compute the row-level diff between the source and target tables.

        The source and target are full-outer-joined on a single join key. The per-row comparison
        results are materialized into a temporary table, and summary statistics, per-column match
        percentages and samples of differing rows are read back from it. The result is cached:
        later calls return the first ``RowDiff`` regardless of the arguments they pass.

        Args:
            temp_schema: Schema for the temporary ``diff`` table; defaults to ``"sqlmesh_temp"``.
            skip_grain_check: If True, the distinct join-key counts (``distinct_count_s`` and
                ``distinct_count_t``) are not computed.

        Returns:
            The ``RowDiff`` describing the row-level differences.

        Raises:
            SQLMeshError: On Athena, if only one of the source and target tables is Iceberg.
        """
        if self._row_diff is not None:
            return self._row_diff

        def _strip_prefix(value: str, prefix: str) -> str:
            # Remove ``prefix`` only from the start of ``value``. ``str.replace`` would also rewrite
            # occurrences inside the column name itself, e.g. "s__cars__x" contains "s__" twice.
            return value[len(prefix) :] if value.startswith(prefix) else value

        def _strip_matches_suffix(label: t.Any) -> str:
            # Map a "<column>_matches" comparison alias back to "<column>". Only the trailing
            # suffix is removed so that columns whose own names contain "_matches" survive.
            text = str(label) if label else ""
            return text[: -len("_matches")] if text.endswith("_matches") else text

        source_schema = {
            col: dtype
            for col, dtype in self.source_schema.items()
            if col not in self.skip_columns
        }
        target_schema = {
            col: dtype
            for col, dtype in self.target_schema.items()
            if col not in self.skip_columns
        }

        # Projections of the joined stats query: every column is exposed as "s__<col>"/"t__<col>".
        s_selects = {c: exp.column(c, "s").as_(f"s__{c}") for c in source_schema}
        t_selects = {c: exp.column(c, "t").as_(f"t__{c}") for c in target_schema}
        s_selects[SQLMESH_JOIN_KEY_COL] = exp.column(SQLMESH_JOIN_KEY_COL, "s").as_(
            f"s__{SQLMESH_JOIN_KEY_COL}"
        )
        t_selects[SQLMESH_JOIN_KEY_COL] = exp.column(SQLMESH_JOIN_KEY_COL, "t").as_(
            f"t__{SQLMESH_JOIN_KEY_COL}"
        )

        # Only columns present on both sides with the same type are compared value by value.
        matched_columns = {
            col: dtype
            for col, dtype in source_schema.items()
            if dtype == target_schema.get(col)
        }

        s_index, t_index, index_cols = self.key_columns
        s_index_names = [col.name for col in s_index]
        t_index_names = [col.name for col in t_index]

        def _column_expr(column_name: str, table: str) -> exp.Expr:
            # Real-valued and nested columns are normalized by the adapter before comparison.
            column_type = matched_columns[column_name]
            qualified_column = exp.column(column_name, table)

            if column_type.is_type(*exp.DataType.REAL_TYPES):
                return self.adapter._normalize_decimal_value(qualified_column, self.decimals)
            if column_type.is_type(*exp.DataType.NESTED_TYPES):
                return self.adapter._normalize_nested_value(qualified_column)

            return qualified_column

        # One "<col>_matches" flag per compared column: 1 if the values are equal or both NULL.
        comparisons = [
            exp.Case()
            .when(_column_expr(c, "s").eq(_column_expr(c, "t")), exp.Literal.number(1))
            .when(
                exp.column(c, "s").is_(exp.Null()) & exp.column(c, "t").is_(exp.Null()),
                exp.Literal.number(1),
            )
            .when(
                exp.column(c, "s").is_(exp.Null()) | exp.column(c, "t").is_(exp.Null()),
                exp.Literal.number(0),
            )
            .else_(exp.Literal.number(0))
            .as_(f"{c}_matches")
            for c in matched_columns
        ]

        source_query = (
            exp.select(
                *(exp.column(c) for c in source_schema),
                self.source_key_expression.as_(SQLMESH_JOIN_KEY_COL),
            )
            .from_(self.source_table.as_("s"))
            .where(self.where)
        )
        target_query = (
            exp.select(
                *(exp.column(c) for c in target_schema),
                self.target_key_expression.as_(SQLMESH_JOIN_KEY_COL),
            )
            .from_(self.target_table.as_("t"))
            .where(self.where)
        )

        # Ensure every column is qualified with the alias in the source and target queries
        for col in find_all_in_scope(source_query, exp.Column):
            col.set("table", exp.to_identifier("s"))
        for col in find_all_in_scope(target_query, exp.Column):
            col.set("table", exp.to_identifier("t"))

        source_table = exp.table_("__source")
        target_table = exp.table_("__target")
        stats_table = exp.table_("__stats")

        stats_query = (
            exp.select(
                *s_selects.values(),
                *t_selects.values(),
                exp.func(
                    "IF", exp.column(SQLMESH_JOIN_KEY_COL, "s").is_(exp.Null()).not_(), 1, 0
                ).as_("s_exists"),
                exp.func(
                    "IF", exp.column(SQLMESH_JOIN_KEY_COL, "t").is_(exp.Null()).not_(), 1, 0
                ).as_("t_exists"),
                exp.func(
                    "IF",
                    exp.column(SQLMESH_JOIN_KEY_COL, "s").eq(
                        exp.column(SQLMESH_JOIN_KEY_COL, "t")
                    ),
                    1,
                    0,
                ).as_("row_joined"),
                exp.func(
                    "IF",
                    exp.or_(
                        *(
                            exp.and_(
                                s_col.is_(exp.Null()),
                                t_col.is_(exp.Null()),
                            )
                            for s_col, t_col in zip(s_index, t_index)
                        ),
                    ),
                    1,
                    0,
                ).as_("null_grain"),
                *comparisons,
            )
            .from_(source_table.as_("s"))
            .join(
                target_table.as_("t"),
                on=exp.column(SQLMESH_JOIN_KEY_COL, "s").eq(
                    exp.column(SQLMESH_JOIN_KEY_COL, "t")
                ),
                join_type="FULL",
            )
        )

        base_query = (
            exp.Select()
            .with_(source_table, source_query)
            .with_(target_table, target_query)
            .with_(stats_table, stats_query)
            .select(
                "*",
                exp.Case()
                .when(
                    exp.and_(
                        *[
                            exp.column(f"{c}_matches").eq(exp.Literal.number(1))
                            for c in matched_columns
                        ]
                    ),
                    exp.Literal.number(1),
                )
                .else_(exp.Literal.number(0))
                .as_("row_full_match"),
            )
            .from_(stats_table)
        )

        query = self.adapter.ensure_nulls_for_unmatched_after_join(
            quote_identifiers(base_query.copy(), dialect=self.model_dialect or self.dialect)
        )

        if not temp_schema:
            temp_schema = "sqlmesh_temp"

        schema = to_schema(temp_schema, dialect=self.dialect)
        temp_table = exp.table_("diff", db=schema.db, catalog=schema.catalog, quoted=True)

        temp_table_kwargs: t.Dict[str, t.Any] = {}
        if isinstance(self.adapter, AthenaEngineAdapter):
            # Athena has two table formats: Hive (the default) and Iceberg. TableDiff requires that
            # the formats be the same for the source, target, and temp tables.
            source_table_type = self.adapter._query_table_type(self.source_table)
            target_table_type = self.adapter._query_table_type(self.target_table)

            if source_table_type == "iceberg" and target_table_type == "iceberg":
                temp_table_kwargs["table_format"] = "iceberg"
            # Sets the temp table's format to Iceberg.
            # If neither source nor target table is Iceberg, it defaults to Hive (Athena's default).
            elif source_table_type == "iceberg" or target_table_type == "iceberg":
                raise SQLMeshError(
                    f"Source table '{self.source}' format '{source_table_type}' and target table '{self.target}' format '{target_table_type}' "
                    f"do not match for Athena. Diffing between different table formats is not supported."
                )

        with self.adapter.temp_table(
            query, name=temp_table, target_columns_to_types=None, **temp_table_kwargs
        ) as table:
            summary_sums = [
                exp.func("SUM", "s_exists").as_("s_count"),
                exp.func("SUM", "t_exists").as_("t_count"),
                exp.func("SUM", "row_joined").as_("join_count"),
                exp.func("SUM", "null_grain").as_("null_grain_count"),
                exp.func("SUM", "row_full_match").as_("full_match_count"),
                *(exp.func("SUM", name(c)).as_(c.alias) for c in comparisons),
            ]

            if not skip_grain_check:
                summary_sums.extend(
                    [
                        parse_one(f"COUNT(DISTINCT(s__{SQLMESH_JOIN_KEY_COL}))").as_(
                            "distinct_count_s"
                        ),
                        parse_one(f"COUNT(DISTINCT(t__{SQLMESH_JOIN_KEY_COL}))").as_(
                            "distinct_count_t"
                        ),
                    ]
                )

            summary_query = exp.select(*summary_sums).from_(table)

            stats_df = self.adapter.fetchdf(summary_query, quote_identifiers=True).fillna(0)
            stats_df["s_only_count"] = stats_df["s_count"] - stats_df["join_count"]
            stats_df["t_only_count"] = stats_df["t_count"] - stats_df["join_count"]
            stats = stats_df.iloc[0].to_dict()

            # Percentage of joined rows in which each compared column matches.
            column_stats_query = (
                exp.select(
                    *(
                        exp.func(
                            "ROUND",
                            100
                            * (
                                exp.cast(
                                    exp.func("SUM", name(c)), exp.DataType.build("NUMERIC")
                                )
                                / exp.func("COUNT", name(c))
                            ),
                            9,
                        ).as_(c.alias)
                        for c in comparisons
                    )
                )
                .from_(table)
                .where(exp.column("row_joined").eq(exp.Literal.number(1)))
            )

            column_stats = (
                self.adapter.fetchdf(column_stats_query, quote_identifiers=True)
                .T.rename(columns={0: "pct_match"}, index=_strip_matches_suffix)
                # errors=ignore because all the index_cols might not be present in the DF if the `on` condition was something like "s.id == t.item_id"
                # because these would not be present in the matching_cols (since they have different names) and thus no summary would be generated
                .drop(index=index_cols, errors="ignore")
            )

            sample = self._fetch_sample(
                table, s_selects, s_index, t_selects, t_index, self.limit
            )

            # Key columns first, then a source/target pair for every column that ever differed.
            joined_sample_cols = [f"s__{c}" for c in s_index_names]
            for c in column_stats[column_stats["pct_match"] < 100].index:
                joined_sample_cols.extend((f"s__{c}", f"t__{c}"))

            # Key columns are shown under their bare name; split on the first "__" only so that
            # column names which themselves contain "__" are kept intact.
            joined_renamed_cols: t.Dict[str, str] = {}
            for col in joined_sample_cols:
                bare_name = col.split("__", 1)[1]
                joined_renamed_cols[col] = bare_name if bare_name in index_cols else col

            if (
                self.source_alias
                and self.target_alias
                and self.source != self.source_alias
                and self.target != self.target_alias
            ):
                source_prefix = f"{self.source_alias.upper()}__"
                target_prefix = f"{self.target_alias.upper()}__"

                def _apply_alias(label: str) -> str:
                    # Swap the leading "s__"/"t__" marker for the upper-cased alias.
                    if label.startswith("s__"):
                        return source_prefix + _strip_prefix(label, "s__")
                    if label.startswith("t__"):
                        return target_prefix + _strip_prefix(label, "t__")
                    return label

                joined_renamed_cols = {
                    col: _apply_alias(label) for col, label in joined_renamed_cols.items()
                }

            # .loc with rename() (no inplace) avoids mutating chained-indexing slices.
            joined_sample = sample.loc[
                sample[SQLMESH_SAMPLE_TYPE_COL] == "common_rows", joined_sample_cols
            ].rename(columns=joined_renamed_cols)

            s_sample = sample.loc[
                sample[SQLMESH_SAMPLE_TYPE_COL] == "source_only",
                [
                    *[f"s__{c}" for c in s_index_names],
                    *[f"s__{c}" for c in source_schema if c not in s_index_names],
                ],
            ]
            s_sample = s_sample.rename(
                columns={c: _strip_prefix(c, "s__") for c in s_sample.columns}
            )

            t_sample = sample.loc[
                sample[SQLMESH_SAMPLE_TYPE_COL] == "target_only",
                [
                    *[f"t__{c}" for c in t_index_names],
                    *[f"t__{c}" for c in target_schema if c not in t_index_names],
                ],
            ]
            t_sample = t_sample.rename(
                columns={c: _strip_prefix(c, "t__") for c in t_sample.columns}
            )

            sample.drop(
                columns=[
                    f"s__{SQLMESH_JOIN_KEY_COL}",
                    f"t__{SQLMESH_JOIN_KEY_COL}",
                    SQLMESH_SAMPLE_TYPE_COL,
                ],
                inplace=True,
            )

            self._row_diff = RowDiff(
                source=self.source,
                target=self.target,
                stats=stats,
                column_stats=column_stats,
                sample=sample,
                joined_sample=joined_sample,
                s_sample=s_sample,
                t_sample=t_sample,
                source_alias=self.source_alias,
                target_alias=self.target_alias,
                model_name=self.model_name,
                decimals=self.decimals,
            )

        return self._row_diff

    def _fetch_sample(
        self,
        sample_table: exp.Table,
        s_selects: t.Dict[str, exp.Expr],
        s_index: t.List[exp.Column],
        t_selects: t.Dict[str, exp.Expr],
        t_index: t.List[exp.Column],
        limit: int,
    ) -> pd.DataFrame:
        """Fetch up to ``limit`` rows of each kind of difference from ``sample_table``.

        The three kinds are source-only rows, target-only rows, and joined rows that are not
        full matches. Each is tagged in the ``SQLMESH_SAMPLE_TYPE_COL`` column and ordered by
        the key columns.
        """
        rendered_data_column_names = [
            name(s) for s in list(s_selects.values()) + list(t_selects.values())
        ]
        sample_type = exp.to_identifier(SQLMESH_SAMPLE_TYPE_COL)

        source_only_sample = (
            exp.select(
                exp.Literal.string("source_only").as_(sample_type), *rendered_data_column_names
            )
            .from_(sample_table)
            .where(exp.and_(exp.column("s_exists").eq(1), exp.column("row_joined").eq(0)))
            .order_by(*(name(s_selects[c.name]) for c in s_index))
            .limit(limit)
        )

        target_only_sample = (
            exp.select(
                exp.Literal.string("target_only").as_(sample_type), *rendered_data_column_names
            )
            .from_(sample_table)
            .where(exp.and_(exp.column("t_exists").eq(1), exp.column("row_joined").eq(0)))
            .order_by(*(name(t_selects[c.name]) for c in t_index))
            .limit(limit)
        )

        common_rows_sample = (
            exp.select(
                exp.Literal.string("common_rows").as_(sample_type), *rendered_data_column_names
            )
            .from_(sample_table)
            .where(exp.and_(exp.column("row_joined").eq(1), exp.column("row_full_match").eq(0)))
            .order_by(
                *(name(s_selects[c.name]) for c in s_index),
                *(name(t_selects[c.name]) for c in t_index),
            )
            .limit(limit)
        )

        query = (
            exp.Select()
            .with_("source_only", source_only_sample)
            .with_("target_only", target_only_sample)
            .with_("common_rows", common_rows_sample)
            .select(sample_type, *rendered_data_column_names)
            .from_("source_only")
            .union(
                exp.select(sample_type, *rendered_data_column_names).from_("target_only"),
                distinct=False,
            )
            .union(
                exp.select(sample_type, *rendered_data_column_names).from_("common_rows"),
                distinct=False,
            )
        )

        return self.adapter.fetchdf(query, quote_identifiers=True)
Calculates differences between tables, taking into account schema and row level differences.
def __init__(
    self,
    adapter: EngineAdapter,
    source: TableName,
    target: TableName,
    on: t.List[str] | exp.Expr,
    skip_columns: t.List[str] | None = None,
    where: t.Optional[str | exp.Expr] = None,
    limit: int = 20,
    source_alias: t.Optional[str] = None,
    target_alias: t.Optional[str] = None,
    model_name: t.Optional[str] = None,
    model_dialect: t.Optional[str] = None,
    decimals: int = 3,
    schema_diff_ignore_case: bool = False,
):
    """Configure a diff between ``source`` and ``target``.

    Args:
        adapter: Engine adapter used to run the diff queries; must implement ``RowDiffMixin``.
        source: Source table name.
        target: Target table name.
        on: Either the list of key column names shared by both tables, or a join condition
            whose columns are qualified with ``s`` (source) and ``t`` (target).
        skip_columns: Column names excluded from the row diff.
        where: Optional filter applied to both the source and the target query.
        limit: Maximum number of rows fetched for each kind of sample.
        source_alias: Optional display alias for the source.
        target_alias: Optional display alias for the target.
        model_name: Optional model name carried into the diff results.
        model_dialect: Dialect used to normalize skip/key column identifiers; falls back to
            the adapter's dialect.
        decimals: Number of decimals used when normalizing real-valued columns.
        schema_diff_ignore_case: Passed to ``SchemaDiff`` as ``ignore_case``.

    Raises:
        ValueError: If ``adapter`` does not implement ``RowDiffMixin``.
    """
    if not isinstance(adapter, RowDiffMixin):
        raise ValueError(f"Engine {adapter} doesnt support RowDiff")

    engine_dialect = adapter.dialect

    self.adapter = adapter
    self.dialect = engine_dialect
    self.source, self.target = source, target
    self.source_table = exp.to_table(source, dialect=engine_dialect)
    self.target_table = exp.to_table(target, dialect=engine_dialect)

    if where:
        self.where = exp.condition(where, dialect=engine_dialect)
    else:
        self.where = None

    self.limit = limit
    self.model_name = model_name
    self.model_dialect = model_dialect
    self.decimals = decimals
    self.schema_diff_ignore_case = schema_diff_ignore_case

    # Optional environment aliases, used to label the diff output.
    self.source_alias, self.target_alias = source_alias, target_alias

    # Skip columns are normalized with the model dialect (if any) so they compare equal to
    # the normalized schema column names.
    identifier_dialect = model_dialect or engine_dialect
    normalized_skips = set()
    for raw_column in ensure_list(skip_columns):
        identifier = exp.parse_identifier(t.cast(str, raw_column))
        normalized_skips.add(normalize_identifiers(identifier, dialect=identifier_dialect).name)
    self.skip_columns = normalized_skips

    self._on = on
    self._row_diff: t.Optional[RowDiff] = None
@cached_property
def key_columns(self) -> t.Tuple[t.List[exp.Column], t.List[exp.Column], t.List[str]]:
    """Resolve the join key columns.

    Returns:
        ``(s_index, t_index, index_cols)``: the source key columns qualified with ``s``,
        the target key columns qualified with ``t``, and the de-duplicated names of every
        column referenced by the key definition.
    """
    dialect = self.model_dialect or self.dialect

    if isinstance(self._on, (list, tuple)):
        # Explicit key names: the same names are used on both sides of the join.
        identifiers = [normalize_identifiers(c, dialect=dialect) for c in self._on]
        return (
            [exp.column(ident, "s") for ident in identifiers],
            [exp.column(ident, "t") for ident in identifiers],
            [ident.name for ident in identifiers],
        )

    # A join condition: the return value is ignored, relying on normalization updating
    # self._on in place, then every referenced column is split by its "s"/"t" qualifier.
    normalize_identifiers(self._on, dialect=dialect)
    referenced = list(self._on.find_all(exp.Column))

    # dict.fromkeys de-duplicates while keeping first-seen order.
    index_cols = list(dict.fromkeys(col.name for col in referenced))
    s_index = list(dict.fromkeys(col for col in referenced if col.table.lower() == "s"))
    t_index = list(dict.fromkeys(col for col in referenced if col.table.lower() == "t"))
    return s_index, t_index, index_cols
def schema_diff(self) -> SchemaDiff:
    """Build a ``SchemaDiff`` comparing the source and target table schemas."""
    return SchemaDiff(
        model_name=self.model_name,
        source=self.source,
        source_alias=self.source_alias,
        source_schema=self.source_schema,
        target=self.target,
        target_alias=self.target_alias,
        target_schema=self.target_schema,
        ignore_case=self.schema_diff_ignore_case,
    )
def row_diff(
    self, temp_schema: t.Optional[str] = None, skip_grain_check: bool = False
) -> RowDiff:
    """Compute the row-level diff between the source and target tables.

    The source and target are full-outer-joined on a single join key. The per-row comparison
    results are materialized into a temporary table, and summary statistics, per-column match
    percentages and samples of differing rows are read back from it. The result is cached:
    later calls return the first ``RowDiff`` regardless of the arguments they pass.

    Args:
        temp_schema: Schema for the temporary ``diff`` table; defaults to ``"sqlmesh_temp"``.
        skip_grain_check: If True, the distinct join-key counts (``distinct_count_s`` and
            ``distinct_count_t``) are not computed.

    Returns:
        The ``RowDiff`` describing the row-level differences.

    Raises:
        SQLMeshError: On Athena, if only one of the source and target tables is Iceberg.
    """
    if self._row_diff is not None:
        return self._row_diff

    def _strip_prefix(value: str, prefix: str) -> str:
        # Remove ``prefix`` only from the start of ``value``. ``str.replace`` would also rewrite
        # occurrences inside the column name itself, e.g. "s__cars__x" contains "s__" twice.
        return value[len(prefix) :] if value.startswith(prefix) else value

    def _strip_matches_suffix(label: t.Any) -> str:
        # Map a "<column>_matches" comparison alias back to "<column>". Only the trailing
        # suffix is removed so that columns whose own names contain "_matches" survive.
        text = str(label) if label else ""
        return text[: -len("_matches")] if text.endswith("_matches") else text

    source_schema = {
        col: dtype
        for col, dtype in self.source_schema.items()
        if col not in self.skip_columns
    }
    target_schema = {
        col: dtype
        for col, dtype in self.target_schema.items()
        if col not in self.skip_columns
    }

    # Projections of the joined stats query: every column is exposed as "s__<col>"/"t__<col>".
    s_selects = {c: exp.column(c, "s").as_(f"s__{c}") for c in source_schema}
    t_selects = {c: exp.column(c, "t").as_(f"t__{c}") for c in target_schema}
    s_selects[SQLMESH_JOIN_KEY_COL] = exp.column(SQLMESH_JOIN_KEY_COL, "s").as_(
        f"s__{SQLMESH_JOIN_KEY_COL}"
    )
    t_selects[SQLMESH_JOIN_KEY_COL] = exp.column(SQLMESH_JOIN_KEY_COL, "t").as_(
        f"t__{SQLMESH_JOIN_KEY_COL}"
    )

    # Only columns present on both sides with the same type are compared value by value.
    matched_columns = {
        col: dtype
        for col, dtype in source_schema.items()
        if dtype == target_schema.get(col)
    }

    s_index, t_index, index_cols = self.key_columns
    s_index_names = [col.name for col in s_index]
    t_index_names = [col.name for col in t_index]

    def _column_expr(column_name: str, table: str) -> exp.Expr:
        # Real-valued and nested columns are normalized by the adapter before comparison.
        column_type = matched_columns[column_name]
        qualified_column = exp.column(column_name, table)

        if column_type.is_type(*exp.DataType.REAL_TYPES):
            return self.adapter._normalize_decimal_value(qualified_column, self.decimals)
        if column_type.is_type(*exp.DataType.NESTED_TYPES):
            return self.adapter._normalize_nested_value(qualified_column)

        return qualified_column

    # One "<col>_matches" flag per compared column: 1 if the values are equal or both NULL.
    comparisons = [
        exp.Case()
        .when(_column_expr(c, "s").eq(_column_expr(c, "t")), exp.Literal.number(1))
        .when(
            exp.column(c, "s").is_(exp.Null()) & exp.column(c, "t").is_(exp.Null()),
            exp.Literal.number(1),
        )
        .when(
            exp.column(c, "s").is_(exp.Null()) | exp.column(c, "t").is_(exp.Null()),
            exp.Literal.number(0),
        )
        .else_(exp.Literal.number(0))
        .as_(f"{c}_matches")
        for c in matched_columns
    ]

    source_query = (
        exp.select(
            *(exp.column(c) for c in source_schema),
            self.source_key_expression.as_(SQLMESH_JOIN_KEY_COL),
        )
        .from_(self.source_table.as_("s"))
        .where(self.where)
    )
    target_query = (
        exp.select(
            *(exp.column(c) for c in target_schema),
            self.target_key_expression.as_(SQLMESH_JOIN_KEY_COL),
        )
        .from_(self.target_table.as_("t"))
        .where(self.where)
    )

    # Ensure every column is qualified with the alias in the source and target queries
    for col in find_all_in_scope(source_query, exp.Column):
        col.set("table", exp.to_identifier("s"))
    for col in find_all_in_scope(target_query, exp.Column):
        col.set("table", exp.to_identifier("t"))

    source_table = exp.table_("__source")
    target_table = exp.table_("__target")
    stats_table = exp.table_("__stats")

    stats_query = (
        exp.select(
            *s_selects.values(),
            *t_selects.values(),
            exp.func(
                "IF", exp.column(SQLMESH_JOIN_KEY_COL, "s").is_(exp.Null()).not_(), 1, 0
            ).as_("s_exists"),
            exp.func(
                "IF", exp.column(SQLMESH_JOIN_KEY_COL, "t").is_(exp.Null()).not_(), 1, 0
            ).as_("t_exists"),
            exp.func(
                "IF",
                exp.column(SQLMESH_JOIN_KEY_COL, "s").eq(
                    exp.column(SQLMESH_JOIN_KEY_COL, "t")
                ),
                1,
                0,
            ).as_("row_joined"),
            exp.func(
                "IF",
                exp.or_(
                    *(
                        exp.and_(
                            s_col.is_(exp.Null()),
                            t_col.is_(exp.Null()),
                        )
                        for s_col, t_col in zip(s_index, t_index)
                    ),
                ),
                1,
                0,
            ).as_("null_grain"),
            *comparisons,
        )
        .from_(source_table.as_("s"))
        .join(
            target_table.as_("t"),
            on=exp.column(SQLMESH_JOIN_KEY_COL, "s").eq(
                exp.column(SQLMESH_JOIN_KEY_COL, "t")
            ),
            join_type="FULL",
        )
    )

    base_query = (
        exp.Select()
        .with_(source_table, source_query)
        .with_(target_table, target_query)
        .with_(stats_table, stats_query)
        .select(
            "*",
            exp.Case()
            .when(
                exp.and_(
                    *[
                        exp.column(f"{c}_matches").eq(exp.Literal.number(1))
                        for c in matched_columns
                    ]
                ),
                exp.Literal.number(1),
            )
            .else_(exp.Literal.number(0))
            .as_("row_full_match"),
        )
        .from_(stats_table)
    )

    query = self.adapter.ensure_nulls_for_unmatched_after_join(
        quote_identifiers(base_query.copy(), dialect=self.model_dialect or self.dialect)
    )

    if not temp_schema:
        temp_schema = "sqlmesh_temp"

    schema = to_schema(temp_schema, dialect=self.dialect)
    temp_table = exp.table_("diff", db=schema.db, catalog=schema.catalog, quoted=True)

    temp_table_kwargs: t.Dict[str, t.Any] = {}
    if isinstance(self.adapter, AthenaEngineAdapter):
        # Athena has two table formats: Hive (the default) and Iceberg. TableDiff requires that
        # the formats be the same for the source, target, and temp tables.
        source_table_type = self.adapter._query_table_type(self.source_table)
        target_table_type = self.adapter._query_table_type(self.target_table)

        if source_table_type == "iceberg" and target_table_type == "iceberg":
            temp_table_kwargs["table_format"] = "iceberg"
        # Sets the temp table's format to Iceberg.
        # If neither source nor target table is Iceberg, it defaults to Hive (Athena's default).
        elif source_table_type == "iceberg" or target_table_type == "iceberg":
            raise SQLMeshError(
                f"Source table '{self.source}' format '{source_table_type}' and target table '{self.target}' format '{target_table_type}' "
                f"do not match for Athena. Diffing between different table formats is not supported."
            )

    with self.adapter.temp_table(
        query, name=temp_table, target_columns_to_types=None, **temp_table_kwargs
    ) as table:
        summary_sums = [
            exp.func("SUM", "s_exists").as_("s_count"),
            exp.func("SUM", "t_exists").as_("t_count"),
            exp.func("SUM", "row_joined").as_("join_count"),
            exp.func("SUM", "null_grain").as_("null_grain_count"),
            exp.func("SUM", "row_full_match").as_("full_match_count"),
            *(exp.func("SUM", name(c)).as_(c.alias) for c in comparisons),
        ]

        if not skip_grain_check:
            summary_sums.extend(
                [
                    parse_one(f"COUNT(DISTINCT(s__{SQLMESH_JOIN_KEY_COL}))").as_(
                        "distinct_count_s"
                    ),
                    parse_one(f"COUNT(DISTINCT(t__{SQLMESH_JOIN_KEY_COL}))").as_(
                        "distinct_count_t"
                    ),
                ]
            )

        summary_query = exp.select(*summary_sums).from_(table)

        stats_df = self.adapter.fetchdf(summary_query, quote_identifiers=True).fillna(0)
        stats_df["s_only_count"] = stats_df["s_count"] - stats_df["join_count"]
        stats_df["t_only_count"] = stats_df["t_count"] - stats_df["join_count"]
        stats = stats_df.iloc[0].to_dict()

        # Percentage of joined rows in which each compared column matches.
        column_stats_query = (
            exp.select(
                *(
                    exp.func(
                        "ROUND",
                        100
                        * (
                            exp.cast(
                                exp.func("SUM", name(c)), exp.DataType.build("NUMERIC")
                            )
                            / exp.func("COUNT", name(c))
                        ),
                        9,
                    ).as_(c.alias)
                    for c in comparisons
                )
            )
            .from_(table)
            .where(exp.column("row_joined").eq(exp.Literal.number(1)))
        )

        column_stats = (
            self.adapter.fetchdf(column_stats_query, quote_identifiers=True)
            .T.rename(columns={0: "pct_match"}, index=_strip_matches_suffix)
            # errors=ignore because all the index_cols might not be present in the DF if the `on` condition was something like "s.id == t.item_id"
            # because these would not be present in the matching_cols (since they have different names) and thus no summary would be generated
            .drop(index=index_cols, errors="ignore")
        )

        sample = self._fetch_sample(
            table, s_selects, s_index, t_selects, t_index, self.limit
        )

        # Key columns first, then a source/target pair for every column that ever differed.
        joined_sample_cols = [f"s__{c}" for c in s_index_names]
        for c in column_stats[column_stats["pct_match"] < 100].index:
            joined_sample_cols.extend((f"s__{c}", f"t__{c}"))

        # Key columns are shown under their bare name; split on the first "__" only so that
        # column names which themselves contain "__" are kept intact.
        joined_renamed_cols: t.Dict[str, str] = {}
        for col in joined_sample_cols:
            bare_name = col.split("__", 1)[1]
            joined_renamed_cols[col] = bare_name if bare_name in index_cols else col

        if (
            self.source_alias
            and self.target_alias
            and self.source != self.source_alias
            and self.target != self.target_alias
        ):
            source_prefix = f"{self.source_alias.upper()}__"
            target_prefix = f"{self.target_alias.upper()}__"

            def _apply_alias(label: str) -> str:
                # Swap the leading "s__"/"t__" marker for the upper-cased alias.
                if label.startswith("s__"):
                    return source_prefix + _strip_prefix(label, "s__")
                if label.startswith("t__"):
                    return target_prefix + _strip_prefix(label, "t__")
                return label

            joined_renamed_cols = {
                col: _apply_alias(label) for col, label in joined_renamed_cols.items()
            }

        # .loc with rename() (no inplace) avoids mutating chained-indexing slices.
        joined_sample = sample.loc[
            sample[SQLMESH_SAMPLE_TYPE_COL] == "common_rows", joined_sample_cols
        ].rename(columns=joined_renamed_cols)

        s_sample = sample.loc[
            sample[SQLMESH_SAMPLE_TYPE_COL] == "source_only",
            [
                *[f"s__{c}" for c in s_index_names],
                *[f"s__{c}" for c in source_schema if c not in s_index_names],
            ],
        ]
        s_sample = s_sample.rename(
            columns={c: _strip_prefix(c, "s__") for c in s_sample.columns}
        )

        t_sample = sample.loc[
            sample[SQLMESH_SAMPLE_TYPE_COL] == "target_only",
            [
                *[f"t__{c}" for c in t_index_names],
                *[f"t__{c}" for c in target_schema if c not in t_index_names],
            ],
        ]
        t_sample = t_sample.rename(
            columns={c: _strip_prefix(c, "t__") for c in t_sample.columns}
        )

        sample.drop(
            columns=[
                f"s__{SQLMESH_JOIN_KEY_COL}",
                f"t__{SQLMESH_JOIN_KEY_COL}",
                SQLMESH_SAMPLE_TYPE_COL,
            ],
            inplace=True,
        )

        self._row_diff = RowDiff(
            source=self.source,
            target=self.target,
            stats=stats,
            column_stats=column_stats,
            sample=sample,
            joined_sample=joined_sample,
            s_sample=s_sample,
            t_sample=t_sample,
            source_alias=self.source_alias,
            target_alias=self.target_alias,
            model_name=self.model_name,
            decimals=self.decimals,
        )

    return self._row_diff