sqlmesh.core.table_diff
"""Schema-level and row-level diffing between a source and a target table."""

from __future__ import annotations

import math
import typing as t

import pandas as pd
from sqlglot import exp
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import quote_identifiers

from sqlmesh.utils.pydantic import PydanticModel

if t.TYPE_CHECKING:
    from sqlmesh.core._typing import TableName
    from sqlmesh.core.engine_adapter import EngineAdapter


class SchemaDiff(PydanticModel, frozen=True):
    """An object containing the schema difference between a source and target table."""

    # Fully-qualified names of the two tables being compared.
    source: str
    target: str
    # Column name -> sqlglot data type, per side.
    source_schema: t.Dict[str, exp.DataType]
    target_schema: t.Dict[str, exp.DataType]
    # Optional environment aliases used to improve diff output.
    source_alias: t.Optional[str] = None
    target_alias: t.Optional[str] = None
    model_name: t.Optional[str] = None

    @property
    def added(self) -> t.List[t.Tuple[str, exp.DataType]]:
        """Added columns: present in target but not in source."""
        # NOTE: `t` here shadows the `typing as t` alias inside the comprehension only.
        return [(c, t) for c, t in self.target_schema.items() if c not in self.source_schema]

    @property
    def removed(self) -> t.List[t.Tuple[str, exp.DataType]]:
        """Removed columns: present in source but not in target."""
        return [(c, t) for c, t in self.source_schema.items() if c not in self.target_schema]

    @property
    def modified(self) -> t.Dict[str, t.Tuple[exp.DataType, exp.DataType]]:
        """Columns with modified types, mapped to (source_type, target_type)."""
        modified = {}
        # Only columns present on both sides can have a modified type.
        for column in self.source_schema.keys() & self.target_schema.keys():
            source_type = self.source_schema[column]
            target_type = self.target_schema[column]

            if source_type != target_type:
                modified[column] = (source_type, target_type)
        return modified


class RowDiff(PydanticModel, frozen=True):
    """Summary statistics and a sample dataframe."""

    source: str
    target: str
    # Aggregate counters produced by TableDiff.row_diff (keys: s_count, t_count,
    # join_count, s_only_count, t_only_count, plus one per compared column).
    stats: t.Dict[str, float]
    # Sample rows where at least one compared column mismatched.
    sample: pd.DataFrame
    # Mismatching rows that joined on the key columns.
    joined_sample: pd.DataFrame
    # Rows present only in source / only in target.
    s_sample: pd.DataFrame
    t_sample: pd.DataFrame
    # Per-column match percentage over joined rows.
    column_stats: pd.DataFrame
    source_alias: t.Optional[str] = None
    target_alias: t.Optional[str] = None
    model_name: t.Optional[str] = None

    @property
    def source_count(self) -> int:
        """Count of the source."""
        return int(self.stats["s_count"])

    @property
    def target_count(self) -> int:
        """Count of the target."""
        return int(self.stats["t_count"])

    @property
    def count_pct_change(self) -> float:
        """The percentage change of the counts."""
        if self.source_count == 0:
            # Any change from an empty source is reported as an infinite pct change.
            return math.inf
        return ((self.target_count - self.source_count) / self.source_count) * 100

    @property
    def join_count(self) -> int:
        """Count of successfully joined rows."""
        return int(self.stats["join_count"])

    @property
    def s_only_count(self) -> int:
        """Count of rows only present in source."""
        return int(self.stats["s_only_count"])

    @property
    def t_only_count(self) -> int:
        """Count of rows only present in target."""
        return int(self.stats["t_only_count"])


class TableDiff:
    """Calculates differences between tables, taking into account schema and row level differences."""

    def __init__(
        self,
        adapter: EngineAdapter,
        source: TableName,
        target: TableName,
        on: t.List[str] | exp.Condition,
        where: t.Optional[str | exp.Condition] = None,
        limit: int = 20,
        source_alias: t.Optional[str] = None,
        target_alias: t.Optional[str] = None,
        model_name: t.Optional[str] = None,
    ):
        """
        Args:
            adapter: Engine adapter used to query both tables.
            source: The source table name.
            target: The target table name.
            on: Either a list of key column names to join on (joined NULL-safely)
                or a pre-built sqlglot join condition.
            where: Optional filter applied to the joined rows.
            limit: Maximum number of sample rows to fetch.
            source_alias: Optional environment alias for the source, used in output.
            target_alias: Optional environment alias for the target, used in output.
            model_name: Optional name of the model being diffed.
        """
        self.adapter = adapter
        self.source = source
        self.target = target
        self.dialect = adapter.dialect
        self.where = exp.condition(where, dialect=self.dialect) if where else None
        self.limit = limit
        self.model_name = model_name

        # Support environment aliases for diff output improvement in certain cases
        self.source_alias = source_alias
        self.target_alias = target_alias

        if isinstance(on, (list, tuple)):
            s_table = exp.to_identifier("s", quoted=True)
            t_table = exp.to_identifier("t", quoted=True)

            # NULL-safe equality per key column:
            # s.c = t.c OR (s.c IS NULL AND t.c IS NULL), all AND-ed together.
            self.on: exp.Condition = exp.and_(
                *(
                    exp.column(c, s_table).eq(exp.column(c, t_table))
                    | (
                        exp.column(c, s_table).is_(exp.null())
                        & exp.column(c, t_table).is_(exp.null())
                    )
                    for c in on
                )
            )
        else:
            self.on = on

        normalize_identifiers(self.on, dialect=self.dialect)

        # Lazily populated caches (see the corresponding properties / row_diff).
        self._source_schema: t.Optional[t.Dict[str, exp.DataType]] = None
        self._target_schema: t.Optional[t.Dict[str, exp.DataType]] = None
        self._row_diff: t.Optional[RowDiff] = None

    @property
    def source_schema(self) -> t.Dict[str, exp.DataType]:
        """Column name -> type of the source table, fetched once and cached."""
        if self._source_schema is None:
            self._source_schema = self.adapter.columns(self.source)
        return self._source_schema

    @property
    def target_schema(self) -> t.Dict[str, exp.DataType]:
        """Column name -> type of the target table, fetched once and cached."""
        if self._target_schema is None:
            self._target_schema = self.adapter.columns(self.target)
        return self._target_schema

    def schema_diff(self) -> SchemaDiff:
        """Return the schema-level difference between source and target."""
        return SchemaDiff(
            source=self.source,
            target=self.target,
            source_schema=self.source_schema,
            target_schema=self.target_schema,
            source_alias=self.source_alias,
            target_alias=self.target_alias,
            model_name=self.model_name,
        )

    def row_diff(self) -> RowDiff:
        """Compute (once, then cache) the row-level difference between source and target.

        Strategy: FULL OUTER JOIN source and target on the key condition into a temp
        table that carries per-row existence flags and per-column match flags, then
        aggregate that temp table for counts/percentages and pull limited samples.
        """
        if self._row_diff is None:
            # Aliased projections of every column from both sides: s__<col> / t__<col>.
            s_selects = {c: exp.column(c, "s").as_(f"s__{c}") for c in self.source_schema}
            t_selects = {c: exp.column(c, "t").as_(f"t__{c}") for c in self.target_schema}

            # Key columns extracted from the join condition, split by side.
            index_cols = []
            s_index = []
            t_index = []

            for col in self.on.find_all(exp.Column):
                index_cols.append(col.name)
                if col.table == "s":
                    s_index.append(col)
                elif col.table == "t":
                    t_index.append(col)
            # De-duplicate while preserving first-seen order.
            index_cols = list(dict.fromkeys(index_cols))

            # One CASE ... AS "<col>_matches" flag per column that exists on both
            # sides with an identical type: 1 when values match (NULL == NULL counts
            # as a match), 0 otherwise.
            # NOTE: `t` shadows the `typing as t` alias inside this comprehension only.
            comparisons = [
                exp.Case()
                .when(exp.column(c, "s").eq(exp.column(c, "t")), exp.Literal.number(1))
                .when(
                    exp.column(c, "s").is_(exp.Null()) & exp.column(c, "t").is_(exp.Null()),
                    exp.Literal.number(1),
                )
                .when(
                    exp.column(c, "s").is_(exp.Null()) | exp.column(c, "t").is_(exp.Null()),
                    exp.Literal.number(0),
                )
                .else_(exp.Literal.number(0))
                .as_(f"{c}_matches")
                for c, t in self.source_schema.items()
                if t == self.target_schema.get(c)
            ]

            def name(e: exp.Expression) -> str:
                # Quoted SQL rendering of an aliased expression's alias.
                return e.args["alias"].sql(identify=True)

            # The joined temp-table query: all columns from both sides, existence
            # flags, a rows_joined flag, and the per-column match flags.
            query = (
                exp.select(
                    *s_selects.values(),
                    *t_selects.values(),
                    # "NOT <key> IS NULL" for any key => the side's row exists.
                    exp.func("IF", exp.or_(*(c.not_().is_(exp.Null()) for c in s_index)), 1, 0).as_(
                        "s_exists"
                    ),
                    exp.func("IF", exp.or_(*(c.not_().is_(exp.Null()) for c in t_index)), 1, 0).as_(
                        "t_exists"
                    ),
                    # rows_joined: all key columns equal and non-NULL on both sides.
                    exp.func(
                        "IF",
                        exp.and_(
                            *(
                                exp.and_(
                                    exp.column(c, "s").eq(exp.column(c, "t")),
                                    exp.column(c, "s").not_().is_(exp.Null()),
                                    exp.column(c, "t").not_().is_(exp.Null()),
                                )
                                for c in index_cols
                            ),
                        ),
                        1,
                        0,
                    ).as_("rows_joined"),
                    *comparisons,
                )
                .from_(exp.alias_(self.source, "s"))
                .join(
                    self.target,
                    on=self.on,
                    join_type="FULL",
                    join_alias="t",
                )
                .where(self.where)
            )

            query = quote_identifiers(query, dialect=self.dialect)
            temp_table = exp.table_("diff", db="sqlmesh_temp", quoted=True)

            with self.adapter.temp_table(query, name=temp_table) as table:
                # Aggregate counts: rows per side, joined rows, and per-column matches.
                summary_query = exp.select(
                    exp.func("SUM", "s_exists").as_("s_count"),
                    exp.func("SUM", "t_exists").as_("t_count"),
                    exp.func("SUM", "rows_joined").as_("join_count"),
                    *(exp.func("SUM", name(c)).as_(c.alias) for c in comparisons),
                ).from_(table)

                stats_df = self.adapter.fetchdf(summary_query, quote_identifiers=True)
                stats_df["s_only_count"] = stats_df["s_count"] - stats_df["join_count"]
                stats_df["t_only_count"] = stats_df["t_count"] - stats_df["join_count"]
                stats = stats_df.iloc[0].to_dict()

                # Per-column match percentage, computed over joined rows only.
                column_stats_query = (
                    exp.select(
                        *(
                            exp.func(
                                "ROUND",
                                100 * (exp.func("SUM", name(c)) / exp.func("COUNT", name(c))),
                                1,
                            ).as_(c.alias)
                            for c in comparisons
                        )
                    )
                    .from_(table)
                    .where(exp.column("rows_joined").eq(exp.Literal.number(1)))
                )
                # Transpose the single result row into a per-column index, then drop
                # key columns (they always match on joined rows by construction).
                column_stats = (
                    self.adapter.fetchdf(column_stats_query, quote_identifiers=True)
                    .T.rename(
                        columns={0: "pct_match"},
                        index=lambda x: str(x).replace("_matches", "") if x else "",
                    )
                    .drop(index=index_cols)
                )

                # Sample rows where at least one compared column mismatched.
                sample_filter_cols = ["s_exists", "t_exists", "rows_joined"]
                sample_query = (
                    exp.select(
                        *(sample_filter_cols),
                        *(name(c) for c in s_selects.values()),
                        *(name(c) for c in t_selects.values()),
                    )
                    .from_(table)
                    .where(exp.or_(*(exp.column(c.alias).eq(0) for c in comparisons)))
                    .order_by(
                        *(name(s_selects[c.name]) for c in s_index),
                        *(name(t_selects[c.name]) for c in t_index),
                    )
                    .limit(self.limit)
                )
                sample = self.adapter.fetchdf(sample_query, quote_identifiers=True)

                # Joined sample: key columns plus every column pair that is not a
                # 100% match.
                joined_sample_cols = [f"s__{c}" for c in index_cols]
                comparison_cols = [
                    (f"s__{c}", f"t__{c}")
                    for c in column_stats[column_stats["pct_match"] < 100].index
                ]
                for cols in comparison_cols:
                    joined_sample_cols.extend(cols)
                # Strip the s__/t__ prefix from key columns; keep it elsewhere to
                # disambiguate the two sides.
                joined_renamed_cols = {
                    c: c.split("__")[1] if c.split("__")[1] in index_cols else c
                    for c in joined_sample_cols
                }
                if self.source != self.source_alias and self.target != self.target_alias:
                    # Replace the s__/t__ prefixes with upper-cased environment aliases.
                    joined_renamed_cols = {
                        c: (
                            n.replace(
                                "s__", f"{self.source_alias.upper() if self.source_alias else ''}__"
                            )
                            if n.startswith("s__")
                            else n
                        )
                        for c, n in joined_renamed_cols.items()
                    }
                    joined_renamed_cols = {
                        c: (
                            n.replace(
                                "t__", f"{self.target_alias.upper() if self.target_alias else ''}__"
                            )
                            if n.startswith("t__")
                            else n
                        )
                        for c, n in joined_renamed_cols.items()
                    }
                joined_sample = sample[sample["rows_joined"] == 1][joined_sample_cols]
                joined_sample.rename(
                    columns=joined_renamed_cols,
                    inplace=True,
                )

                # Rows present only in source, with the s__ prefix stripped.
                s_sample = sample[(sample["s_exists"] == 1) & (sample["rows_joined"] == 0)][
                    [
                        *[f"s__{c}" for c in index_cols],
                        *[f"s__{c}" for c in self.source_schema if c not in index_cols],
                    ]
                ]
                s_sample.rename(
                    columns={c: c.replace("s__", "") for c in s_sample.columns}, inplace=True
                )

                # Rows present only in target, with the t__ prefix stripped.
                t_sample = sample[(sample["t_exists"] == 1) & (sample["rows_joined"] == 0)][
                    [
                        *[f"t__{c}" for c in index_cols],
                        *[f"t__{c}" for c in self.target_schema if c not in index_cols],
                    ]
                ]
                t_sample.rename(
                    columns={c: c.replace("t__", "") for c in t_sample.columns}, inplace=True
                )

                sample.drop(columns=sample_filter_cols, inplace=True)

                self._row_diff = RowDiff(
                    source=self.source,
                    target=self.target,
                    stats=stats,
                    column_stats=column_stats,
                    sample=sample,
                    joined_sample=joined_sample,
                    s_sample=s_sample,
                    t_sample=t_sample,
                    source_alias=self.source_alias,
                    target_alias=self.target_alias,
                    model_name=self.model_name,
                )
        return self._row_diff
class SchemaDiff(PydanticModel, frozen=True):
    """An object containing the schema difference between a source and target table."""

    source: str
    target: str
    # Column name -> sqlglot data type, per side.
    source_schema: t.Dict[str, exp.DataType]
    target_schema: t.Dict[str, exp.DataType]
    source_alias: t.Optional[str] = None
    target_alias: t.Optional[str] = None
    model_name: t.Optional[str] = None

    @property
    def added(self) -> t.List[t.Tuple[str, exp.DataType]]:
        """Added columns: present in target but not in source."""
        return [(c, t) for c, t in self.target_schema.items() if c not in self.source_schema]

    @property
    def removed(self) -> t.List[t.Tuple[str, exp.DataType]]:
        """Removed columns: present in source but not in target."""
        return [(c, t) for c, t in self.source_schema.items() if c not in self.target_schema]

    @property
    def modified(self) -> t.Dict[str, t.Tuple[exp.DataType, exp.DataType]]:
        """Columns with modified types, mapped to (source_type, target_type)."""
        modified = {}
        # Only columns present on both sides can have a modified type.
        for column in self.source_schema.keys() & self.target_schema.keys():
            source_type = self.source_schema[column]
            target_type = self.target_schema[column]

            if source_type != target_type:
                modified[column] = (source_type, target_type)
        return modified
An object containing the schema difference between a source and target table.
modified: Dict[str, Tuple[sqlglot.expressions.DataType, sqlglot.expressions.DataType]]
Columns with modified types.
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class RowDiff(PydanticModel, frozen=True):
    """Summary statistics and a sample dataframe."""

    source: str
    target: str
    # Aggregate counters (keys: s_count, t_count, join_count, s_only_count,
    # t_only_count, plus one entry per compared column).
    stats: t.Dict[str, float]
    sample: pd.DataFrame
    joined_sample: pd.DataFrame
    s_sample: pd.DataFrame
    t_sample: pd.DataFrame
    column_stats: pd.DataFrame
    source_alias: t.Optional[str] = None
    target_alias: t.Optional[str] = None
    model_name: t.Optional[str] = None

    @property
    def source_count(self) -> int:
        """Count of the source."""
        return int(self.stats["s_count"])

    @property
    def target_count(self) -> int:
        """Count of the target."""
        return int(self.stats["t_count"])

    @property
    def count_pct_change(self) -> float:
        """The percentage change of the counts."""
        if self.source_count == 0:
            # Any change from an empty source is reported as infinite.
            return math.inf
        return ((self.target_count - self.source_count) / self.source_count) * 100

    @property
    def join_count(self) -> int:
        """Count of successfully joined rows."""
        return int(self.stats["join_count"])

    @property
    def s_only_count(self) -> int:
        """Count of rows only present in source."""
        return int(self.stats["s_only_count"])

    @property
    def t_only_count(self) -> int:
        """Count of rows only present in target."""
        return int(self.stats["t_only_count"])
Summary statistics and a sample dataframe.
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
class TableDiff:
class TableDiff:
    """Calculates differences between tables, taking into account schema and row level differences."""

    def __init__(
        self,
        adapter: EngineAdapter,
        source: TableName,
        target: TableName,
        on: t.List[str] | exp.Condition,
        where: t.Optional[str | exp.Condition] = None,
        limit: int = 20,
        source_alias: t.Optional[str] = None,
        target_alias: t.Optional[str] = None,
        model_name: t.Optional[str] = None,
    ):
        """
        Args:
            adapter: Engine adapter used to query both tables.
            source: The source table name.
            target: The target table name.
            on: Key column names to join on (joined NULL-safely) or a sqlglot condition.
            where: Optional filter applied to the joined rows.
            limit: Maximum number of sample rows to fetch.
            source_alias: Optional environment alias for the source.
            target_alias: Optional environment alias for the target.
            model_name: Optional name of the model being diffed.
        """
        self.adapter = adapter
        self.source = source
        self.target = target
        self.dialect = adapter.dialect
        self.where = exp.condition(where, dialect=self.dialect) if where else None
        self.limit = limit
        self.model_name = model_name

        # Support environment aliases for diff output improvement in certain cases
        self.source_alias = source_alias
        self.target_alias = target_alias

        if isinstance(on, (list, tuple)):
            s_table = exp.to_identifier("s", quoted=True)
            t_table = exp.to_identifier("t", quoted=True)

            # NULL-safe equality per key column, AND-ed together.
            self.on: exp.Condition = exp.and_(
                *(
                    exp.column(c, s_table).eq(exp.column(c, t_table))
                    | (
                        exp.column(c, s_table).is_(exp.null())
                        & exp.column(c, t_table).is_(exp.null())
                    )
                    for c in on
                )
            )
        else:
            self.on = on

        normalize_identifiers(self.on, dialect=self.dialect)

        # Lazily populated caches.
        self._source_schema: t.Optional[t.Dict[str, exp.DataType]] = None
        self._target_schema: t.Optional[t.Dict[str, exp.DataType]] = None
        self._row_diff: t.Optional[RowDiff] = None

    @property
    def source_schema(self) -> t.Dict[str, exp.DataType]:
        """Column name -> type of the source table, fetched once and cached."""
        if self._source_schema is None:
            self._source_schema = self.adapter.columns(self.source)
        return self._source_schema

    @property
    def target_schema(self) -> t.Dict[str, exp.DataType]:
        """Column name -> type of the target table, fetched once and cached."""
        if self._target_schema is None:
            self._target_schema = self.adapter.columns(self.target)
        return self._target_schema

    def schema_diff(self) -> SchemaDiff:
        """Return the schema-level difference between source and target."""
        return SchemaDiff(
            source=self.source,
            target=self.target,
            source_schema=self.source_schema,
            target_schema=self.target_schema,
            source_alias=self.source_alias,
            target_alias=self.target_alias,
            model_name=self.model_name,
        )

    def row_diff(self) -> RowDiff:
        """Compute (once, then cache) the row-level difference between source and target."""
        if self._row_diff is None:
            # Aliased projections of every column from both sides: s__<col> / t__<col>.
            s_selects = {c: exp.column(c, "s").as_(f"s__{c}") for c in self.source_schema}
            t_selects = {c: exp.column(c, "t").as_(f"t__{c}") for c in self.target_schema}

            # Key columns extracted from the join condition, split by side.
            index_cols = []
            s_index = []
            t_index = []

            for col in self.on.find_all(exp.Column):
                index_cols.append(col.name)
                if col.table == "s":
                    s_index.append(col)
                elif col.table == "t":
                    t_index.append(col)
            index_cols = list(dict.fromkeys(index_cols))

            # Per-column match flags for columns shared (with equal types) by both sides.
            comparisons = [
                exp.Case()
                .when(exp.column(c, "s").eq(exp.column(c, "t")), exp.Literal.number(1))
                .when(
                    exp.column(c, "s").is_(exp.Null()) & exp.column(c, "t").is_(exp.Null()),
                    exp.Literal.number(1),
                )
                .when(
                    exp.column(c, "s").is_(exp.Null()) | exp.column(c, "t").is_(exp.Null()),
                    exp.Literal.number(0),
                )
                .else_(exp.Literal.number(0))
                .as_(f"{c}_matches")
                for c, t in self.source_schema.items()
                if t == self.target_schema.get(c)
            ]

            def name(e: exp.Expression) -> str:
                # Quoted SQL rendering of an aliased expression's alias.
                return e.args["alias"].sql(identify=True)

            # FULL OUTER JOIN of source and target with existence/match flags.
            query = (
                exp.select(
                    *s_selects.values(),
                    *t_selects.values(),
                    exp.func("IF", exp.or_(*(c.not_().is_(exp.Null()) for c in s_index)), 1, 0).as_(
                        "s_exists"
                    ),
                    exp.func("IF", exp.or_(*(c.not_().is_(exp.Null()) for c in t_index)), 1, 0).as_(
                        "t_exists"
                    ),
                    exp.func(
                        "IF",
                        exp.and_(
                            *(
                                exp.and_(
                                    exp.column(c, "s").eq(exp.column(c, "t")),
                                    exp.column(c, "s").not_().is_(exp.Null()),
                                    exp.column(c, "t").not_().is_(exp.Null()),
                                )
                                for c in index_cols
                            ),
                        ),
                        1,
                        0,
                    ).as_("rows_joined"),
                    *comparisons,
                )
                .from_(exp.alias_(self.source, "s"))
                .join(
                    self.target,
                    on=self.on,
                    join_type="FULL",
                    join_alias="t",
                )
                .where(self.where)
            )

            query = quote_identifiers(query, dialect=self.dialect)
            temp_table = exp.table_("diff", db="sqlmesh_temp", quoted=True)

            with self.adapter.temp_table(query, name=temp_table) as table:
                # Aggregate counts: rows per side, joined rows, per-column matches.
                summary_query = exp.select(
                    exp.func("SUM", "s_exists").as_("s_count"),
                    exp.func("SUM", "t_exists").as_("t_count"),
                    exp.func("SUM", "rows_joined").as_("join_count"),
                    *(exp.func("SUM", name(c)).as_(c.alias) for c in comparisons),
                ).from_(table)

                stats_df = self.adapter.fetchdf(summary_query, quote_identifiers=True)
                stats_df["s_only_count"] = stats_df["s_count"] - stats_df["join_count"]
                stats_df["t_only_count"] = stats_df["t_count"] - stats_df["join_count"]
                stats = stats_df.iloc[0].to_dict()

                # Per-column match percentage over joined rows only.
                column_stats_query = (
                    exp.select(
                        *(
                            exp.func(
                                "ROUND",
                                100 * (exp.func("SUM", name(c)) / exp.func("COUNT", name(c))),
                                1,
                            ).as_(c.alias)
                            for c in comparisons
                        )
                    )
                    .from_(table)
                    .where(exp.column("rows_joined").eq(exp.Literal.number(1)))
                )
                column_stats = (
                    self.adapter.fetchdf(column_stats_query, quote_identifiers=True)
                    .T.rename(
                        columns={0: "pct_match"},
                        index=lambda x: str(x).replace("_matches", "") if x else "",
                    )
                    .drop(index=index_cols)
                )

                # Sample rows where at least one compared column mismatched.
                sample_filter_cols = ["s_exists", "t_exists", "rows_joined"]
                sample_query = (
                    exp.select(
                        *(sample_filter_cols),
                        *(name(c) for c in s_selects.values()),
                        *(name(c) for c in t_selects.values()),
                    )
                    .from_(table)
                    .where(exp.or_(*(exp.column(c.alias).eq(0) for c in comparisons)))
                    .order_by(
                        *(name(s_selects[c.name]) for c in s_index),
                        *(name(t_selects[c.name]) for c in t_index),
                    )
                    .limit(self.limit)
                )
                sample = self.adapter.fetchdf(sample_query, quote_identifiers=True)

                # Joined sample: keys plus every column pair that is not a 100% match.
                joined_sample_cols = [f"s__{c}" for c in index_cols]
                comparison_cols = [
                    (f"s__{c}", f"t__{c}")
                    for c in column_stats[column_stats["pct_match"] < 100].index
                ]
                for cols in comparison_cols:
                    joined_sample_cols.extend(cols)
                joined_renamed_cols = {
                    c: c.split("__")[1] if c.split("__")[1] in index_cols else c
                    for c in joined_sample_cols
                }
                if self.source != self.source_alias and self.target != self.target_alias:
                    # Replace s__/t__ prefixes with upper-cased environment aliases.
                    joined_renamed_cols = {
                        c: (
                            n.replace(
                                "s__", f"{self.source_alias.upper() if self.source_alias else ''}__"
                            )
                            if n.startswith("s__")
                            else n
                        )
                        for c, n in joined_renamed_cols.items()
                    }
                    joined_renamed_cols = {
                        c: (
                            n.replace(
                                "t__", f"{self.target_alias.upper() if self.target_alias else ''}__"
                            )
                            if n.startswith("t__")
                            else n
                        )
                        for c, n in joined_renamed_cols.items()
                    }
                joined_sample = sample[sample["rows_joined"] == 1][joined_sample_cols]
                joined_sample.rename(
                    columns=joined_renamed_cols,
                    inplace=True,
                )

                # Rows present only in source / only in target, prefixes stripped.
                s_sample = sample[(sample["s_exists"] == 1) & (sample["rows_joined"] == 0)][
                    [
                        *[f"s__{c}" for c in index_cols],
                        *[f"s__{c}" for c in self.source_schema if c not in index_cols],
                    ]
                ]
                s_sample.rename(
                    columns={c: c.replace("s__", "") for c in s_sample.columns}, inplace=True
                )

                t_sample = sample[(sample["t_exists"] == 1) & (sample["rows_joined"] == 0)][
                    [
                        *[f"t__{c}" for c in index_cols],
                        *[f"t__{c}" for c in self.target_schema if c not in index_cols],
                    ]
                ]
                t_sample.rename(
                    columns={c: c.replace("t__", "") for c in t_sample.columns}, inplace=True
                )

                sample.drop(columns=sample_filter_cols, inplace=True)

                self._row_diff = RowDiff(
                    source=self.source,
                    target=self.target,
                    stats=stats,
                    column_stats=column_stats,
                    sample=sample,
                    joined_sample=joined_sample,
                    s_sample=s_sample,
                    t_sample=t_sample,
                    source_alias=self.source_alias,
                    target_alias=self.target_alias,
                    model_name=self.model_name,
                )
        return self._row_diff
Calculates differences between tables, taking into account schema and row level differences.
TableDiff(adapter: EngineAdapter, source: TableName, target: TableName, on: t.List[str] | exp.Condition, where: t.Optional[str | exp.Condition] = None, limit: int = 20, source_alias: t.Optional[str] = None, target_alias: t.Optional[str] = None, model_name: t.Optional[str] = None)
def __init__(
    self,
    adapter: EngineAdapter,
    source: TableName,
    target: TableName,
    on: t.List[str] | exp.Condition,
    where: t.Optional[str | exp.Condition] = None,
    limit: int = 20,
    source_alias: t.Optional[str] = None,
    target_alias: t.Optional[str] = None,
    model_name: t.Optional[str] = None,
):
    """Initialize a table diff over `source` and `target` joined on `on`."""
    self.adapter = adapter
    self.source = source
    self.target = target
    self.dialect = adapter.dialect
    self.where = exp.condition(where, dialect=self.dialect) if where else None
    self.limit = limit
    self.model_name = model_name

    # Support environment aliases for diff output improvement in certain cases
    self.source_alias = source_alias
    self.target_alias = target_alias

    if isinstance(on, (list, tuple)):
        s_table = exp.to_identifier("s", quoted=True)
        t_table = exp.to_identifier("t", quoted=True)

        # NULL-safe equality per key column, AND-ed together.
        self.on: exp.Condition = exp.and_(
            *(
                exp.column(c, s_table).eq(exp.column(c, t_table))
                | (
                    exp.column(c, s_table).is_(exp.null())
                    & exp.column(c, t_table).is_(exp.null())
                )
                for c in on
            )
        )
    else:
        self.on = on

    normalize_identifiers(self.on, dialect=self.dialect)

    # Lazily populated caches.
    self._source_schema: t.Optional[t.Dict[str, exp.DataType]] = None
    self._target_schema: t.Optional[t.Dict[str, exp.DataType]] = None
    self._row_diff: t.Optional[RowDiff] = None
def row_diff(self) -> RowDiff:
    """Compute (once, then cache) the row-level difference between source and target.

    FULL OUTER JOINs the two tables on the key condition into a temp table carrying
    existence and per-column match flags, then aggregates counts/percentages and
    pulls limited mismatch samples from it.
    """
    if self._row_diff is None:
        # Aliased projections of every column from both sides: s__<col> / t__<col>.
        s_selects = {c: exp.column(c, "s").as_(f"s__{c}") for c in self.source_schema}
        t_selects = {c: exp.column(c, "t").as_(f"t__{c}") for c in self.target_schema}

        # Key columns extracted from the join condition, split by side.
        index_cols = []
        s_index = []
        t_index = []

        for col in self.on.find_all(exp.Column):
            index_cols.append(col.name)
            if col.table == "s":
                s_index.append(col)
            elif col.table == "t":
                t_index.append(col)
        index_cols = list(dict.fromkeys(index_cols))

        # Per-column match flags for columns shared (with equal types) by both sides.
        comparisons = [
            exp.Case()
            .when(exp.column(c, "s").eq(exp.column(c, "t")), exp.Literal.number(1))
            .when(
                exp.column(c, "s").is_(exp.Null()) & exp.column(c, "t").is_(exp.Null()),
                exp.Literal.number(1),
            )
            .when(
                exp.column(c, "s").is_(exp.Null()) | exp.column(c, "t").is_(exp.Null()),
                exp.Literal.number(0),
            )
            .else_(exp.Literal.number(0))
            .as_(f"{c}_matches")
            for c, t in self.source_schema.items()
            if t == self.target_schema.get(c)
        ]

        def name(e: exp.Expression) -> str:
            # Quoted SQL rendering of an aliased expression's alias.
            return e.args["alias"].sql(identify=True)

        # FULL OUTER JOIN of source and target with existence/match flags.
        query = (
            exp.select(
                *s_selects.values(),
                *t_selects.values(),
                exp.func("IF", exp.or_(*(c.not_().is_(exp.Null()) for c in s_index)), 1, 0).as_(
                    "s_exists"
                ),
                exp.func("IF", exp.or_(*(c.not_().is_(exp.Null()) for c in t_index)), 1, 0).as_(
                    "t_exists"
                ),
                exp.func(
                    "IF",
                    exp.and_(
                        *(
                            exp.and_(
                                exp.column(c, "s").eq(exp.column(c, "t")),
                                exp.column(c, "s").not_().is_(exp.Null()),
                                exp.column(c, "t").not_().is_(exp.Null()),
                            )
                            for c in index_cols
                        ),
                    ),
                    1,
                    0,
                ).as_("rows_joined"),
                *comparisons,
            )
            .from_(exp.alias_(self.source, "s"))
            .join(
                self.target,
                on=self.on,
                join_type="FULL",
                join_alias="t",
            )
            .where(self.where)
        )

        query = quote_identifiers(query, dialect=self.dialect)
        temp_table = exp.table_("diff", db="sqlmesh_temp", quoted=True)

        with self.adapter.temp_table(query, name=temp_table) as table:
            # Aggregate counts: rows per side, joined rows, per-column matches.
            summary_query = exp.select(
                exp.func("SUM", "s_exists").as_("s_count"),
                exp.func("SUM", "t_exists").as_("t_count"),
                exp.func("SUM", "rows_joined").as_("join_count"),
                *(exp.func("SUM", name(c)).as_(c.alias) for c in comparisons),
            ).from_(table)

            stats_df = self.adapter.fetchdf(summary_query, quote_identifiers=True)
            stats_df["s_only_count"] = stats_df["s_count"] - stats_df["join_count"]
            stats_df["t_only_count"] = stats_df["t_count"] - stats_df["join_count"]
            stats = stats_df.iloc[0].to_dict()

            # Per-column match percentage over joined rows only.
            column_stats_query = (
                exp.select(
                    *(
                        exp.func(
                            "ROUND",
                            100 * (exp.func("SUM", name(c)) / exp.func("COUNT", name(c))),
                            1,
                        ).as_(c.alias)
                        for c in comparisons
                    )
                )
                .from_(table)
                .where(exp.column("rows_joined").eq(exp.Literal.number(1)))
            )
            column_stats = (
                self.adapter.fetchdf(column_stats_query, quote_identifiers=True)
                .T.rename(
                    columns={0: "pct_match"},
                    index=lambda x: str(x).replace("_matches", "") if x else "",
                )
                .drop(index=index_cols)
            )

            # Sample rows where at least one compared column mismatched.
            sample_filter_cols = ["s_exists", "t_exists", "rows_joined"]
            sample_query = (
                exp.select(
                    *(sample_filter_cols),
                    *(name(c) for c in s_selects.values()),
                    *(name(c) for c in t_selects.values()),
                )
                .from_(table)
                .where(exp.or_(*(exp.column(c.alias).eq(0) for c in comparisons)))
                .order_by(
                    *(name(s_selects[c.name]) for c in s_index),
                    *(name(t_selects[c.name]) for c in t_index),
                )
                .limit(self.limit)
            )
            sample = self.adapter.fetchdf(sample_query, quote_identifiers=True)

            # Joined sample: keys plus every column pair that is not a 100% match.
            joined_sample_cols = [f"s__{c}" for c in index_cols]
            comparison_cols = [
                (f"s__{c}", f"t__{c}")
                for c in column_stats[column_stats["pct_match"] < 100].index
            ]
            for cols in comparison_cols:
                joined_sample_cols.extend(cols)
            joined_renamed_cols = {
                c: c.split("__")[1] if c.split("__")[1] in index_cols else c
                for c in joined_sample_cols
            }
            if self.source != self.source_alias and self.target != self.target_alias:
                # Replace s__/t__ prefixes with upper-cased environment aliases.
                joined_renamed_cols = {
                    c: (
                        n.replace(
                            "s__", f"{self.source_alias.upper() if self.source_alias else ''}__"
                        )
                        if n.startswith("s__")
                        else n
                    )
                    for c, n in joined_renamed_cols.items()
                }
                joined_renamed_cols = {
                    c: (
                        n.replace(
                            "t__", f"{self.target_alias.upper() if self.target_alias else ''}__"
                        )
                        if n.startswith("t__")
                        else n
                    )
                    for c, n in joined_renamed_cols.items()
                }
            joined_sample = sample[sample["rows_joined"] == 1][joined_sample_cols]
            joined_sample.rename(
                columns=joined_renamed_cols,
                inplace=True,
            )

            # Rows present only in source / only in target, prefixes stripped.
            s_sample = sample[(sample["s_exists"] == 1) & (sample["rows_joined"] == 0)][
                [
                    *[f"s__{c}" for c in index_cols],
                    *[f"s__{c}" for c in self.source_schema if c not in index_cols],
                ]
            ]
            s_sample.rename(
                columns={c: c.replace("s__", "") for c in s_sample.columns}, inplace=True
            )

            t_sample = sample[(sample["t_exists"] == 1) & (sample["rows_joined"] == 0)][
                [
                    *[f"t__{c}" for c in index_cols],
                    *[f"t__{c}" for c in self.target_schema if c not in index_cols],
                ]
            ]
            t_sample.rename(
                columns={c: c.replace("t__", "") for c in t_sample.columns}, inplace=True
            )

            sample.drop(columns=sample_filter_cols, inplace=True)

            self._row_diff = RowDiff(
                source=self.source,
                target=self.target,
                stats=stats,
                column_stats=column_stats,
                sample=sample,
                joined_sample=joined_sample,
                s_sample=s_sample,
                t_sample=t_sample,
                source_alias=self.source_alias,
                target_alias=self.target_alias,
                model_name=self.model_name,
            )
    return self._row_diff