sqlmesh.core.test.definition
1from __future__ import annotations 2 3import sys 4 5import datetime 6import threading 7import typing as t 8import unittest 9from collections import Counter 10from contextlib import nullcontext, contextmanager, AbstractContextManager 11from itertools import chain 12from pathlib import Path 13from unittest.mock import patch 14 15 16from io import StringIO 17from sqlglot import Dialect, exp 18from sqlglot.optimizer.annotate_types import annotate_types 19from sqlglot.optimizer.normalize_identifiers import normalize_identifiers 20 21from sqlmesh.core import constants as c 22from sqlmesh.core.dialect import normalize_model_name, schema_ 23from sqlmesh.core.engine_adapter import EngineAdapter 24from sqlmesh.core.macros import RuntimeStage 25from sqlmesh.core.model import Model, PythonModel, SqlModel 26from sqlmesh.utils import UniqueKeyDict, random_id, type_is_known, yaml 27from sqlmesh.utils.date import date_dict, pandas_timestamp_to_pydatetime, to_datetime 28from sqlmesh.utils.errors import ConfigError, TestError 29from sqlmesh.utils.yaml import load as yaml_load 30from sqlmesh.utils import Verbosity 31from sqlmesh.utils.rich import df_to_table 32 33if t.TYPE_CHECKING: 34 import pandas as pd 35 36 from sqlglot.dialects.dialect import DialectType 37 38 Row = t.Dict[str, t.Any] 39 40 41TIME_KWARG_KEYS = { 42 "start", 43 "end", 44 "execution_time", 45 "latest", 46 # all built-in datetime macro var names 47 *date_dict(execution_time="1970-01-01", start="1970-01-01", end="1970-01-01").keys(), 48} 49 50 51class ModelTest(unittest.TestCase): 52 __test__ = False 53 54 CONCURRENT_RENDER_LOCK = threading.Lock() 55 56 def __init__( 57 self, 58 body: t.Dict[str, t.Any], 59 test_name: str, 60 model: Model, 61 models: UniqueKeyDict[str, Model], 62 engine_adapter: EngineAdapter, 63 dialect: str | None = None, 64 path: Path | None = None, 65 preserve_fixtures: bool = False, 66 default_catalog: str | None = None, 67 concurrency: bool = False, 68 verbosity: Verbosity = 
Verbosity.DEFAULT, 69 ) -> None: 70 """ModelTest encapsulates a unit test for a model. 71 72 Args: 73 body: A dictionary that contains test metadata like inputs and outputs. 74 test_name: The name of the test. 75 model: The model that is being tested. 76 models: All models to use for expansion and mapping of physical locations. 77 engine_adapter: The engine adapter to use. 78 dialect: The models' dialect, used for normalization purposes. 79 path: An optional path to the test definition yaml file. 80 preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. 81 """ 82 self.body = body 83 self.test_name = test_name 84 self.model = model 85 self.models = models 86 self.engine_adapter = engine_adapter 87 self.path = path 88 self.preserve_fixtures = preserve_fixtures 89 self.default_catalog = default_catalog 90 self.dialect = dialect 91 self.concurrency = concurrency 92 self.verbosity = verbosity 93 94 self._fixture_table_cache: t.Dict[str, exp.Table] = {} 95 self._normalized_column_name_cache: t.Dict[str, str] = {} 96 self._normalized_model_name_cache: t.Dict[t.Tuple[str, bool], str] = {} 97 98 self._test_adapter_dialect = Dialect.get_or_raise(self.engine_adapter.dialect) 99 100 self._validate_and_normalize_test() 101 102 if self.engine_adapter.default_catalog: 103 self._fixture_catalog: t.Optional[exp.Identifier] = normalize_identifiers( 104 exp.parse_identifier( 105 self.engine_adapter.default_catalog, dialect=self._test_adapter_dialect 106 ), 107 dialect=self._test_adapter_dialect, 108 ) 109 else: 110 self._fixture_catalog = None 111 112 # The test schema name is randomized to avoid concurrency issues, 113 # unless a schema is provided in the unit tests's body 114 self._fixture_schema = exp.parse_identifier( 115 self.body.get("schema") or f"sqlmesh_test_{random_id(short=True)}" 116 ) 117 self._qualified_fixture_schema = schema_(self._fixture_schema, self._fixture_catalog) 118 119 self._transforms = 
self._test_adapter_dialect.generator_class.TRANSFORMS 120 self._execution_time = str(self.body.get("vars", {}).get("execution_time") or "") 121 122 if self._execution_time: 123 # Normalizes the execution time by converting it into UTC timezone 124 self._execution_time = str(to_datetime(self._execution_time)) 125 126 # When execution_time is set, we mock the CURRENT_* SQL expressions so they always return it 127 if self._execution_time: 128 exec_time = exp.Literal.string(self._execution_time) 129 self._transforms = { 130 **self._transforms, 131 exp.CurrentDate: lambda self, _: self.sql( 132 exp.cast(exec_time, "date", dialect=dialect) 133 ), 134 exp.CurrentDatetime: lambda self, _: self.sql( 135 exp.cast(exec_time, "datetime", dialect=dialect) 136 ), 137 exp.CurrentTime: lambda self, _: self.sql( 138 exp.cast(exec_time, "time", dialect=dialect) 139 ), 140 exp.CurrentTimestamp: lambda self, _: self.sql( 141 exp.cast(exec_time, "timestamp", dialect=dialect) 142 ), 143 } 144 145 super().__init__() 146 147 def defaultTestResult(self) -> unittest.TestResult: 148 from sqlmesh.core.test.result import ModelTextTestResult 149 150 return ModelTextTestResult(stream=sys.stdout, descriptions=True, verbosity=self.verbosity) 151 152 def shortDescription(self) -> t.Optional[str]: 153 return self.body.get("description") 154 155 def setUp(self) -> None: 156 """Load all input tables""" 157 import pandas as pd 158 import numpy as np 159 160 self.engine_adapter.create_schema(self._qualified_fixture_schema) 161 162 for name, values in self.body.get("inputs", {}).items(): 163 all_types_are_known = False 164 columns_to_known_types: t.Dict[str, exp.DataType] = {} 165 166 model = self.models.get(name) 167 if model: 168 inferred_columns_to_types = model.columns_to_types or {} 169 columns_to_known_types = { 170 c: t for c, t in inferred_columns_to_types.items() if type_is_known(t) 171 } 172 all_types_are_known = bool(inferred_columns_to_types) and ( 173 len(columns_to_known_types) == 
len(inferred_columns_to_types) 174 ) 175 176 # Types specified in the test will override the corresponding inferred ones 177 columns_to_known_types.update(values.get("columns", {})) 178 179 rows = values.get("rows") 180 if not all_types_are_known and rows: 181 for col, value in rows[0].items(): 182 if col not in columns_to_known_types: 183 v_type = annotate_types(exp.convert(value)).type or type(value).__name__ 184 v_type = exp.maybe_parse( 185 v_type, into=exp.DataType, dialect=self._test_adapter_dialect 186 ) 187 188 if not type_is_known(v_type): 189 _raise_error( 190 f"Failed to infer the data type of column '{col}' for '{name}'. This issue can be " 191 "mitigated by casting the column in the model definition, setting its type in " 192 "external_models.yaml if it's an external model, setting the model's 'columns' property, " 193 "or setting its 'columns' mapping in the test itself", 194 self.path, 195 ) 196 197 columns_to_known_types[col] = v_type 198 199 if rows is None: 200 query_or_df: exp.Query | pd.DataFrame = self._add_missing_columns( 201 values["query"], columns_to_known_types 202 ) 203 if columns_to_known_types: 204 columns_to_known_types = { 205 col: columns_to_known_types[col] for col in query_or_df.named_selects 206 } 207 else: 208 query_or_df = self._create_df(values, columns=columns_to_known_types) 209 210 # Convert NaN/NaT values to None if DataFrame 211 if isinstance(query_or_df, pd.DataFrame): 212 query_or_df = query_or_df.replace({np.nan: None}) 213 214 self.engine_adapter.create_view( 215 self._test_fixture_table(name), query_or_df, columns_to_known_types 216 ) 217 218 def tearDown(self) -> None: 219 """Drop all fixture tables.""" 220 if not self.preserve_fixtures: 221 self.engine_adapter.drop_schema(self._qualified_fixture_schema, cascade=True) 222 223 def assert_equal( 224 self, 225 expected: pd.DataFrame, 226 actual: pd.DataFrame, 227 sort: bool, 228 partial: t.Optional[bool] = False, 229 ) -> None: 230 """Compare two DataFrames""" 231 
import numpy as np 232 import pandas as pd 233 from pandas.api.types import is_object_dtype 234 235 if partial: 236 intersection = actual[actual.columns.intersection(expected.columns)] 237 if len(intersection.columns) > 0: 238 actual = intersection 239 240 # Two astypes are necessary, pandas converts strings to times as NS, 241 # but if the actual is US, it doesn't take effect until the 2nd try! 242 actual_types = actual.dtypes.to_dict() 243 expected = expected.astype(actual_types, errors="ignore").astype( 244 actual_types, errors="ignore" 245 ) 246 247 # The `actual` df's dtypes will almost always be pd.Timestamp for datetime values, 248 # but in some scenarios (e.g., DuckDB >=0.10.2) it will be a pandas `object` type 249 # containing python `datetime.xxx` values. 250 # 251 # Pandas `object` columns result in a noop for the `astype` call above. Because any 252 # quoted YAML value is a string, we must manually convert the `expected` df string 253 # values to the correct `datetime.xxx` type. 254 # 255 # We determine the type from a single sentinel value, but since the `actual` df is 256 # coming from a database query, it is safe to assume that the column contains only 257 # a single type. 
258 object_sentinel_values = { 259 col: actual[col][0] 260 for col in actual_types 261 if is_object_dtype(actual_types[col]) and len(actual[col]) != 0 262 } 263 for col, value in object_sentinel_values.items(): 264 try: 265 # can't use `isinstance()` here - https://stackoverflow.com/a/68743663/1707525 266 if type(value) is datetime.date: 267 expected[col] = pd.to_datetime(expected[col]).dt.date 268 elif type(value) is datetime.time: 269 expected[col] = pd.to_datetime(expected[col]).dt.time 270 elif type(value) is datetime.datetime: 271 expected[col] = pd.to_datetime(expected[col]).dt.to_pydatetime() 272 except Exception as e: 273 from sqlmesh.core.console import get_console 274 275 get_console().log_warning( 276 f"Failed to convert expected value for {col} into `datetime` " 277 f"for unit test '{str(self)}'. {str(e)}." 278 ) 279 280 actual = actual.replace({np.nan: None}) 281 expected = expected.replace({np.nan: None}) 282 283 # We define this here to avoid a top-level import of numpy and pandas 284 DATETIME_TYPES = ( 285 datetime.datetime, 286 datetime.date, 287 datetime.time, 288 np.datetime64, 289 pd.Timestamp, 290 ) 291 292 def _to_hashable(x: t.Any) -> t.Any: 293 if isinstance(x, (list, np.ndarray)): 294 return tuple(_to_hashable(v) for v in x) 295 if isinstance(x, dict): 296 return tuple((k, _to_hashable(v)) for k, v in x.items()) 297 return str(x) if isinstance(x, DATETIME_TYPES) or not isinstance(x, t.Hashable) else x 298 299 actual = actual.apply(lambda col: col.map(_to_hashable)) 300 expected = expected.apply(lambda col: col.map(_to_hashable)) 301 302 if sort: 303 actual = actual.sort_values(by=actual.columns.to_list()).reset_index(drop=True) 304 expected = expected.sort_values(by=expected.columns.to_list()).reset_index(drop=True) 305 306 try: 307 pd.testing.assert_frame_equal( 308 expected, 309 actual, 310 check_dtype=False, 311 check_datetimelike_compat=True, 312 check_like=True, # Ignore column order 313 ) 314 except AssertionError as e: 315 # There 
are 2 concepts at play here: 316 # 1. The Exception args will contain the error message plus the diff dataframe table stringified 317 # (backwards compatibility with existing tests, possible to serialize/send over network etc) 318 # 2. Each test will also transform these diff dataframes into Rich tables, which will be the ones that'll 319 # be surfaced to the user through Console for better UX (versus stringified dataframes) 320 # 321 # This is a bit of a hack, but it's a way to get the best of both worlds. 322 args: t.List[t.Any] = [] 323 324 failed_subtest = "" 325 326 if subtest := getattr(self, "_subtest", None): 327 if cte := subtest.params.get("cte"): 328 failed_subtest = f" (CTE {cte})" 329 330 if expected.shape != actual.shape: 331 _raise_if_unexpected_columns(expected.columns, actual.columns) 332 333 args.append("Data mismatch (rows are different)") 334 335 missing_rows = _row_difference(expected, actual) 336 if not missing_rows.empty: 337 args[0] += f"\n\nMissing rows:\n\n{missing_rows}" 338 args.append(df_to_table(f"Missing rows{failed_subtest}", missing_rows)) 339 340 unexpected_rows = _row_difference(actual, expected) 341 342 if not unexpected_rows.empty: 343 args[0] += f"\n\nUnexpected rows:\n\n{unexpected_rows}" 344 args.append(df_to_table(f"Unexpected rows{failed_subtest}", unexpected_rows)) 345 346 else: 347 diff = expected.compare(actual).rename(columns={"self": "exp", "other": "act"}) 348 349 args.append(f"Data mismatch (exp: expected, act: actual)\n\n{diff}") 350 351 diff.rename(columns={"exp": "Expected", "act": "Actual"}, inplace=True) 352 if self.verbosity == Verbosity.DEFAULT: 353 args.extend( 354 df_to_table(f"Data mismatch{failed_subtest}", df) 355 for df in _split_df_by_column_pairs(diff) 356 ) 357 else: 358 from pandas import DataFrame, MultiIndex 359 360 levels = t.cast(MultiIndex, diff.columns).levels[0] 361 for col in levels: 362 # diff[col] returns a DataFrame when columns is a MultiIndex 363 col_diff = t.cast(DataFrame, diff[col]) 
364 if not col_diff.empty: 365 table = df_to_table( 366 f"[bold red]Column '{col}' mismatch{failed_subtest}[/bold red]", 367 col_diff, 368 ) 369 args.append(table) 370 371 e.args = (*args,) 372 373 raise e 374 375 def runTest(self) -> None: 376 raise NotImplementedError 377 378 def path_relative_to(self, other: Path) -> Path | None: 379 """Compute a version of this test's path relative to the `other` path""" 380 return self.path.relative_to(other) if self.path else None 381 382 @staticmethod 383 def create_test( 384 body: t.Dict[str, t.Any], 385 test_name: str, 386 models: UniqueKeyDict[str, Model], 387 engine_adapter: EngineAdapter, 388 dialect: str | None, 389 path: Path | None, 390 preserve_fixtures: bool = False, 391 default_catalog: str | None = None, 392 concurrency: bool = False, 393 verbosity: Verbosity = Verbosity.DEFAULT, 394 ) -> t.Optional[ModelTest]: 395 """Create a SqlModelTest or a PythonModelTest. 396 397 Args: 398 body: A dictionary that contains test metadata like inputs and outputs. 399 test_name: The name of the test. 400 models: All models to use for expansion and mapping of physical locations. 401 engine_adapter: The engine adapter to use. 402 dialect: The models' dialect, used for normalization purposes. 403 path: An optional path to the test definition yaml file. 404 preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. 
405 """ 406 name = body.get("model") 407 if name is None: 408 _raise_error("Missing required 'model' field", path) 409 410 name = normalize_model_name(name, default_catalog=default_catalog, dialect=dialect) 411 model = models.get(name) 412 if not model: 413 from sqlmesh.core.console import get_console 414 415 get_console().log_warning( 416 f"Model '{name}' was not found{' at ' + str(path) if path else ''}" 417 ) 418 return None 419 420 if isinstance(model, SqlModel): 421 test_type: t.Type[ModelTest] = SqlModelTest 422 elif isinstance(model, PythonModel): 423 test_type = PythonModelTest 424 else: 425 _raise_error(f"Model '{name}' is an unsupported model type for testing", path) 426 427 try: 428 return test_type( 429 body, 430 test_name, 431 t.cast(Model, model), 432 models, 433 engine_adapter, 434 dialect, 435 path, 436 preserve_fixtures, 437 default_catalog, 438 concurrency, 439 verbosity, 440 ) 441 except Exception as e: 442 raise TestError(f"Failed to create test {test_name} ({path})\n{str(e)}") 443 444 def __str__(self) -> str: 445 return f"{self.test_name} ({self.path})" 446 447 def _validate_and_normalize_test(self) -> None: 448 inputs = self.body.get("inputs") 449 outputs = self.body.get("outputs", {}) 450 451 if not outputs: 452 _raise_error("Incomplete test, missing outputs", self.path) 453 454 ctes = outputs.get("ctes") 455 query = outputs.get("query") 456 partial = outputs.pop("partial", None) 457 458 if ctes is None and query is None: 459 _raise_error("Incomplete test, outputs must contain 'query' or 'ctes'", self.path) 460 461 def _normalize_rows( 462 values: t.List[Row] | t.Dict, 463 name: str, 464 partial: bool = False, 465 dialect: DialectType = None, 466 ) -> t.Dict: 467 import pandas as pd 468 469 if not isinstance(values, dict): 470 values = {"rows": values} 471 472 rows = values.get("rows") 473 query = values.get("query") 474 475 fmt = values.get("format") 476 path = values.get("path") 477 if fmt == "csv": 478 csv_settings = 
values.get("csv_settings") or {} 479 rows = pd.read_csv(path or StringIO(rows), **csv_settings).to_dict(orient="records") 480 elif fmt in (None, "yaml"): 481 if path: 482 input_rows = yaml_load(Path(path)) 483 rows = input_rows.get("rows") if isinstance(input_rows, dict) else input_rows 484 else: 485 _raise_error(f"Unsupported data format '{fmt}' for '{name}'", self.path) 486 487 if query is not None: 488 if rows is not None: 489 _raise_error( 490 f"Invalid test, cannot set both 'query' and 'rows' for '{name}'", self.path 491 ) 492 493 # We parse the user-supplied query using the testing adapter dialect, but we 494 # normalize its identifiers according to the model's dialect, so that, e.g., 495 # the projection names match those in its `columns_to_types` field 496 values["query"] = normalize_identifiers( 497 exp.maybe_parse(query, dialect=self._test_adapter_dialect), dialect=dialect 498 ) 499 return values 500 501 if rows is None: 502 _raise_error(f"Incomplete test, missing row data for '{name}'", self.path) 503 504 assert isinstance(rows, list) 505 values["rows"] = [ 506 {self._normalize_column_name(column): value for column, value in row.items()} 507 for row in rows 508 ] 509 if partial: 510 values["partial"] = True 511 512 return values 513 514 def _normalize_sources( 515 sources: t.Dict, partial: bool = False, with_default_catalog: bool = True 516 ) -> t.Dict: 517 normalized_sources = {} 518 for name, values in sources.items(): 519 normalized_name = self._normalize_model_name( 520 name, with_default_catalog=with_default_catalog 521 ) 522 model = self.models.get(normalized_name) 523 dialect = model.dialect if model else self.dialect 524 525 normalized_sources[normalized_name] = _normalize_rows( 526 values, name, partial=partial, dialect=dialect 527 ) 528 529 return normalized_sources 530 531 normalized_model_name = self._normalize_model_name(self.body["model"]) 532 self.body["model"] = normalized_model_name 533 534 if inputs: 535 inputs = 
_normalize_sources(inputs) 536 for name, values in inputs.items(): 537 columns = values.get("columns") 538 if columns is None: 539 continue 540 541 if not isinstance(columns, dict): 542 _raise_error( 543 f"Invalid 'columns' value for model '{name}', expected a mapping name -> type", 544 self.path, 545 ) 546 547 values["columns"] = { 548 self._normalize_column_name(c): exp.DataType.build( 549 t, dialect=self._test_adapter_dialect 550 ) 551 for c, t in columns.items() 552 } 553 554 for depends_on in self.model.depends_on: 555 if depends_on not in inputs: 556 _raise_error(f"Incomplete test, missing input model '{depends_on}'", self.path) 557 558 if self.model.depends_on_self and normalized_model_name not in inputs: 559 inputs[normalized_model_name] = {"rows": []} 560 561 self.body["inputs"] = inputs 562 563 if ctes: 564 outputs["ctes"] = _normalize_sources(ctes, partial=partial, with_default_catalog=False) 565 566 if query or query == []: 567 outputs["query"] = _normalize_rows( 568 query, self.model.name, partial=partial, dialect=self.model.dialect 569 ) 570 571 def _test_fixture_table(self, name: str) -> exp.Table: 572 table = self._fixture_table_cache.get(name) 573 if not table: 574 table = exp.to_table(name, dialect=self._test_adapter_dialect) 575 576 # We change the table path below, so this ensures there are no name clashes 577 table.this.set("this", "__".join(part.name for part in table.parts)) 578 579 table.set("db", self._fixture_schema.copy()) 580 if self._fixture_catalog: 581 table.set("catalog", self._fixture_catalog.copy()) 582 583 self._fixture_table_cache[name] = table 584 585 return table 586 587 def _normalize_model_name(self, name: str, with_default_catalog: bool = True) -> str: 588 normalized_name = self._normalized_model_name_cache.get((name, with_default_catalog)) 589 if normalized_name is None: 590 default_catalog = self.default_catalog if with_default_catalog else None 591 normalized_name = normalize_model_name( 592 name, 
default_catalog=default_catalog, dialect=self.dialect 593 ) 594 self._normalized_model_name_cache[(name, with_default_catalog)] = normalized_name 595 596 return normalized_name 597 598 def _normalize_column_name(self, name: str) -> str: 599 normalized_name = self._normalized_column_name_cache.get(name) 600 if normalized_name is None: 601 normalized_name = normalize_identifiers(name, dialect=self.dialect).name 602 self._normalized_column_name_cache[name] = normalized_name 603 604 return normalized_name 605 606 @contextmanager 607 def _concurrent_render_context(self) -> t.Iterator[None]: 608 """ 609 Context manager that ensures that the tests are executed safely in a concurrent environment. 610 This is needed in case `execution_time` is set, as we'd then have to: 611 - Freeze time through `time_machine` (not thread safe) 612 - Globally patch the SQLGlot dialect so that any date/time nodes are evaluated at the `execution_time` during generation 613 """ 614 import time_machine 615 616 lock_ctx: AbstractContextManager = ( 617 self.CONCURRENT_RENDER_LOCK if self.concurrency else nullcontext() 618 ) 619 time_ctx: AbstractContextManager = nullcontext() 620 dialect_patch_ctx: AbstractContextManager = nullcontext() 621 622 if self._execution_time: 623 time_ctx = time_machine.travel(self._execution_time, tick=False) 624 dialect_patch_ctx = patch.dict( 625 self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms 626 ) 627 628 with lock_ctx, time_ctx, dialect_patch_ctx: 629 yield 630 631 def _execute(self, query: exp.Query | str) -> pd.DataFrame: 632 """Executes the given query using the testing engine adapter and returns a DataFrame.""" 633 return self.engine_adapter.fetchdf(query) 634 635 def _create_df( 636 self, 637 values: t.Dict[str, t.Any], 638 columns: t.Optional[t.Collection] = None, 639 partial: t.Optional[bool] = False, 640 ) -> pd.DataFrame: 641 import pandas as pd 642 643 query = values.get("query") 644 if query: 645 if not partial: 646 query = 
self._add_missing_columns(query, columns) 647 648 return self._execute(query) 649 650 rows = values["rows"] 651 columns_str: t.Optional[t.List[str]] = None 652 if columns: 653 columns_str = [str(c) for c in columns] 654 referenced_columns = list(dict.fromkeys(col for row in rows for col in row)) 655 _raise_if_unexpected_columns(columns, referenced_columns) 656 657 if partial: 658 columns_str = [c for c in columns_str if c in referenced_columns] 659 660 return pd.DataFrame.from_records(rows, columns=columns_str) 661 662 def _add_missing_columns( 663 self, query: exp.Query, all_columns: t.Optional[t.Collection[str]] = None 664 ) -> exp.Query: 665 if not all_columns or query.is_star: 666 return query 667 668 query_columns = set(query.named_selects) 669 missing_columns = [col for col in all_columns if col not in query_columns] 670 if missing_columns: 671 query.select(*[exp.null().as_(col) for col in missing_columns], copy=False) 672 673 return query 674 675 676class SqlModelTest(ModelTest): 677 def test_ctes(self, ctes: t.Dict[str, exp.Expr], recursive: bool = False) -> None: 678 """Run CTE queries and compare output to expected output""" 679 for cte_name, values in self.body["outputs"].get("ctes", {}).items(): 680 with self.subTest(cte=cte_name): 681 if cte_name not in ctes: 682 _raise_error( 683 f"No CTE named {cte_name} found in model {self.model.name}", self.path 684 ) 685 686 cte_query = ctes[cte_name].this 687 688 sort = cte_query.args.get("order") is None 689 partial = values.get("partial") 690 691 cte_query = exp.select(*_projection_identifiers(cte_query)).from_(cte_name) 692 for alias, cte in ctes.items(): 693 cte_query = cte_query.with_(alias, cte.this, recursive=recursive) 694 695 with self._concurrent_render_context(): 696 # Similar to the model's query, we render the CTE query under the locked context 697 # so that the execution (fetchdf) can continue concurrently between the threads 698 sql = cte_query.sql( 699 self._test_adapter_dialect, 
pretty=self.engine_adapter._pretty_sql 700 ) 701 702 actual = self._execute(sql) 703 expected = self._create_df(values, columns=cte_query.named_selects, partial=partial) 704 705 self.assert_equal(expected, actual, sort=sort, partial=partial) 706 707 def runTest(self) -> None: 708 with self._concurrent_render_context(): 709 # Render the model's query and generate the SQL under the locked context so that 710 # execution (fetchdf) can continue concurrently between the threads 711 query = self._render_model_query() 712 sql = query.sql(self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql) 713 714 with_clause = query.args.get("with_") 715 716 if with_clause: 717 self.test_ctes( 718 { 719 self._normalize_model_name(cte.alias, with_default_catalog=False): cte 720 for cte in query.ctes 721 }, 722 recursive=with_clause.recursive, 723 ) 724 725 values = self.body["outputs"].get("query") 726 if values is not None: 727 partial = values.get("partial") 728 sort = query.args.get("order") is None 729 730 actual = self._execute(sql) 731 expected = self._create_df(values, columns=self.model.columns_to_types, partial=partial) 732 733 self.assert_equal(expected, actual, sort=sort, partial=partial) 734 735 def _render_model_query(self) -> exp.Query: 736 variables = self.body.get("vars", {}).copy() 737 time_kwargs = {key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables} 738 739 query = self.model.render_query_or_raise( 740 **time_kwargs, 741 variables=variables, 742 engine_adapter=self.engine_adapter, 743 table_mapping={ 744 name: self._test_fixture_table(name).sql() for name in self.body.get("inputs", {}) 745 }, 746 runtime_stage=RuntimeStage.TESTING, 747 ) 748 return query 749 750 751class PythonModelTest(ModelTest): 752 def __init__( 753 self, 754 body: t.Dict[str, t.Any], 755 test_name: str, 756 model: Model, 757 models: UniqueKeyDict[str, Model], 758 engine_adapter: EngineAdapter, 759 dialect: str | None = None, 760 path: Path | None = None, 761 
preserve_fixtures: bool = False, 762 default_catalog: str | None = None, 763 concurrency: bool = False, 764 verbosity: Verbosity = Verbosity.DEFAULT, 765 ) -> None: 766 """PythonModelTest encapsulates a unit test for a Python model. 767 768 Args: 769 body: A dictionary that contains test metadata like inputs and outputs. 770 test_name: The name of the test. 771 model: The Python model that is being tested. 772 models: All models to use for expansion and mapping of physical locations. 773 engine_adapter: The engine adapter to use. 774 dialect: The models' dialect, used for normalization purposes. 775 path: An optional path to the test definition yaml file. 776 preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. 777 """ 778 from sqlmesh.core.test.context import TestExecutionContext 779 780 super().__init__( 781 body, 782 test_name, 783 model, 784 models, 785 engine_adapter, 786 dialect, 787 path, 788 preserve_fixtures, 789 default_catalog, 790 concurrency, 791 verbosity, 792 ) 793 794 self.context = TestExecutionContext( 795 engine_adapter=engine_adapter, 796 models=models, 797 test=self, 798 default_dialect=dialect, 799 default_catalog=default_catalog, 800 ) 801 802 def runTest(self) -> None: 803 values = self.body["outputs"].get("query") 804 if values is not None: 805 partial = values.get("partial") 806 807 actual_df = self._execute_model() 808 actual_df.reset_index(drop=True, inplace=True) 809 expected = self._create_df(values, columns=self.model.columns_to_types, partial=partial) 810 811 self.assert_equal(expected, actual_df, sort=True, partial=partial) 812 813 def _execute_model(self) -> pd.DataFrame: 814 """Executes the python model and returns a DataFrame.""" 815 import pandas as pd 816 817 with self._concurrent_render_context(): 818 variables = self.body.get("vars", {}).copy() 819 time_kwargs = {key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables} 820 df = 
next(self.model.render(context=self.context, variables=variables, **time_kwargs)) 821 822 assert not isinstance(df, exp.Expr) 823 return df if isinstance(df, pd.DataFrame) else df.toPandas() 824 825 826def generate_test( 827 model: Model, 828 input_queries: t.Dict[str, str], 829 models: UniqueKeyDict[str, Model], 830 engine_adapter: EngineAdapter, 831 test_engine_adapter: EngineAdapter, 832 project_path: Path, 833 overwrite: bool = False, 834 variables: t.Optional[t.Dict[str, str]] = None, 835 path: t.Optional[str] = None, 836 name: t.Optional[str] = None, 837 include_ctes: bool = False, 838) -> None: 839 """Generate a unit test fixture for a given model. 840 841 Args: 842 model: The model to test. 843 input_queries: Mapping of model names to queries. Each model included in this mapping 844 will be populated in the test based on the results of the corresponding query. 845 models: The context's models. 846 engine_adapter: The target engine adapter. 847 test_engine_adapter: The test engine adapter. 848 project_path: The path pointing to the project's root directory. 849 overwrite: Whether to overwrite the existing test in case of a file path collision. 850 When set to False, an error will be raised if there is such a collision. 851 variables: Key-value pairs that will define variables needed by the model. 852 path: The file path corresponding to the fixture, relative to the test directory. 853 By default, the fixture will be created under the test directory and the file name 854 will be inferred from the test's name. 855 name: The name of the test. This is inferred from the model name by default. 856 include_ctes: When true, CTE fixtures will also be generated. 
857 """ 858 import numpy as np 859 860 test_name = name or f"test_{model.view_name}" 861 path = path or f"{test_name}.yaml" 862 863 extension = path.split(".")[-1].lower() 864 if extension not in ("yaml", "yml"): 865 path = f"{path}.yaml" 866 867 fixture_path = project_path / c.TESTS / path 868 if not overwrite and fixture_path.exists(): 869 raise ConfigError( 870 f"Fixture '{fixture_path}' already exists, make sure to set --overwrite if it can be safely overwritten." 871 ) 872 873 # ruamel.yaml does not support pandas Timestamps, so we must convert them to python 874 # datetime or datetime.date objects based on column type 875 inputs = { 876 dep: pandas_timestamp_to_pydatetime( 877 engine_adapter.fetchdf(query).apply(lambda col: col.map(_normalize_df_value)), 878 models[dep].columns_to_types, 879 ) 880 .replace({np.nan: None}) 881 .to_dict(orient="records") 882 for dep, query in input_queries.items() 883 } 884 outputs: t.Dict[str, t.Any] = {"query": {}} 885 variables = variables or {} 886 test_body = {"model": model.fqn, "inputs": inputs, "outputs": outputs} 887 888 if variables: 889 test_body["vars"] = variables 890 891 test = ModelTest.create_test( 892 body=test_body.copy(), 893 test_name=test_name, 894 models=models, 895 engine_adapter=test_engine_adapter, 896 dialect=model.dialect, 897 path=fixture_path, 898 default_catalog=model.default_catalog, 899 ) 900 if not test: 901 return 902 903 test.setUp() 904 905 if isinstance(model, SqlModel): 906 assert isinstance(test, SqlModelTest) 907 model_query = test._render_model_query() 908 with_clause = model_query.args.get("with_") 909 910 if with_clause and include_ctes: 911 ctes = {} 912 recursive = with_clause.recursive 913 previous_ctes: t.List[exp.CTE] = [] 914 915 for cte in model_query.ctes: 916 cte_query = cte.this 917 cte_identifier = cte.args["alias"].this 918 919 cte_query = exp.select(*_projection_identifiers(cte_query)).from_(cte_identifier) 920 921 for prev in chain(previous_ctes, [cte]): 922 cte_query = 
cte_query.with_( 923 prev.args["alias"].this, prev.this, recursive=recursive 924 ) 925 926 cte_output = test._execute(cte_query) 927 ctes[cte.alias] = ( 928 pandas_timestamp_to_pydatetime( 929 df=cte_output.apply(lambda col: col.map(_normalize_df_value)), 930 ) 931 .replace({np.nan: None}) 932 .to_dict(orient="records") 933 ) 934 935 previous_ctes.append(cte) 936 937 if ctes: 938 outputs["ctes"] = ctes 939 940 output = test._execute(model_query) 941 else: 942 output = t.cast(PythonModelTest, test)._execute_model() 943 944 outputs["query"] = ( 945 pandas_timestamp_to_pydatetime( 946 output.apply(lambda col: col.map(_normalize_df_value)), model.columns_to_types 947 ) 948 .replace({np.nan: None}) 949 .to_dict(orient="records") 950 ) 951 952 test.tearDown() 953 954 fixture_path.parent.mkdir(exist_ok=True, parents=True) 955 with open(fixture_path, "w", encoding="utf-8") as file: 956 yaml.dump({test_name: test_body}, file) 957 958 959def _projection_identifiers(query: exp.Query) -> t.List[str | exp.Identifier]: 960 identifiers: t.List[str | exp.Identifier] = [] 961 for select in query.selects: 962 if isinstance(select, exp.Alias): 963 identifiers.append(select.args["alias"]) 964 elif isinstance(select, exp.Column): 965 identifiers.append(select.this) 966 else: 967 identifiers.append(select.output_name) 968 969 return identifiers 970 971 972def _raise_if_unexpected_columns( 973 expected_cols: t.Collection[str], actual_cols: t.Collection[str] 974) -> None: 975 unique_expected_cols = set(expected_cols) 976 unknown_cols = [col for col in actual_cols if col not in unique_expected_cols] 977 978 if unknown_cols: 979 expected = f"Expected column(s): {', '.join(list(expected_cols))}\n" 980 unknown = f"Unknown column(s): {', '.join(unknown_cols)}" 981 _raise_error(f"Detected unknown column(s)\n\n{expected}{unknown}") 982 983 984def _row_difference(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame: 985 """Returns all rows in `left` that don't appear in `right`.""" 986 import 
numpy as np 987 import pandas as pd 988 989 rows_missing_from_right = [] 990 991 # `None` replaces `np.nan` because `np.nan != np.nan` and this would affect the mapping lookup 992 right_row_count: t.MutableMapping[t.Tuple, int] = Counter( 993 right.replace({np.nan: None}).itertuples(index=False, name=None) 994 ) 995 for left_row in left.replace({np.nan: None}).itertuples(index=False): 996 left_row_tuple = tuple(left_row) 997 if right_row_count[left_row_tuple] <= 0: 998 rows_missing_from_right.append(left_row) 999 else: 1000 right_row_count[left_row_tuple] -= 1 1001 1002 return pd.DataFrame(rows_missing_from_right) 1003 1004 1005def _raise_error(msg: str, path: Path | None = None) -> None: 1006 if path: 1007 raise TestError(f"Failed to run test at {path}:\n{msg}") 1008 raise TestError(f"Failed to run test:\n{msg}") 1009 1010 1011def _normalize_df_value(value: t.Any) -> t.Any: 1012 """Normalize data in a pandas dataframe so ruamel and sqlglot can deal with it.""" 1013 import numpy as np 1014 1015 if isinstance(value, (list, np.ndarray)): 1016 return [_normalize_df_value(v) for v in value] 1017 if isinstance(value, dict): 1018 if "key" in value and "value" in value: 1019 # Maps returned by DuckDB look like: {'key': ['key1', 'key2'], 'value': [10, 20]} 1020 # so we convert to {'key1': 10, 'key2': 20} (TODO: handle more dialects here) 1021 return {k: _normalize_df_value(v) for k, v in zip(value["key"], value["value"])} 1022 return {k: _normalize_df_value(v) for k, v in value.items()} 1023 return value 1024 1025 1026def _split_df_by_column_pairs(df: pd.DataFrame, pairs_per_chunk: int = 4) -> t.List[pd.DataFrame]: 1027 """Split a dataframe into chunks of column pairs. 
1028 1029 Args: 1030 df: The dataframe to split 1031 pairs_per_chunk: Number of column pairs per chunk (default: 4) 1032 1033 Returns: 1034 List of dataframes, each containing an even number of columns 1035 """ 1036 total_columns = len(df.columns) 1037 1038 # If we have fewer columns than pairs_per_chunk * 2, return the original df 1039 if total_columns <= pairs_per_chunk * 2: 1040 return [df] 1041 1042 # Calculate number of chunks needed to split columns evenly 1043 num_chunks = (total_columns + (pairs_per_chunk * 2 - 1)) // (pairs_per_chunk * 2) 1044 1045 # Calculate columns per chunk to ensure equal distribution 1046 # We round down to nearest even number to ensure each chunk has even columns 1047 columns_per_chunk = (total_columns // num_chunks) & ~1 # Round down to nearest even number 1048 remainder = total_columns - (columns_per_chunk * num_chunks) 1049 1050 chunks = [] 1051 start_idx = 0 1052 1053 # Distribute columns evenly across chunks 1054 for i in range(num_chunks): 1055 # Add 2 columns to early chunks if there's a remainder 1056 # This ensures we always add pairs of columns 1057 extra = 2 if i < remainder // 2 else 0 1058 end_idx = start_idx + columns_per_chunk + extra 1059 chunk = df.iloc[:, start_idx:end_idx] 1060 chunks.append(chunk) 1061 start_idx = end_idx 1062 1063 return chunks
class ModelTest(unittest.TestCase):
    __test__ = False

    # Shared by all tests: when `concurrency` is enabled, rendering under a mocked
    # execution time is serialized through this lock (see `_concurrent_render_context`).
    CONCURRENT_RENDER_LOCK = threading.Lock()

    def __init__(
        self,
        body: t.Dict[str, t.Any],
        test_name: str,
        model: Model,
        models: UniqueKeyDict[str, Model],
        engine_adapter: EngineAdapter,
        dialect: str | None = None,
        path: Path | None = None,
        preserve_fixtures: bool = False,
        default_catalog: str | None = None,
        concurrency: bool = False,
        verbosity: Verbosity = Verbosity.DEFAULT,
    ) -> None:
        """ModelTest encapsulates a unit test for a model.

        Args:
            body: A dictionary that contains test metadata like inputs and outputs.
            test_name: The name of the test.
            model: The model that is being tested.
            models: All models to use for expansion and mapping of physical locations.
            engine_adapter: The engine adapter to use.
            dialect: The models' dialect, used for normalization purposes.
            path: An optional path to the test definition yaml file.
            preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
            default_catalog: The default catalog used when normalizing model names.
            concurrency: Whether tests may run concurrently; if so, rendering is serialized
                through `CONCURRENT_RENDER_LOCK`.
            verbosity: Controls the level of detail of the reported results, e.g. how data
                mismatches are rendered.
        """
        self.body = body
        self.test_name = test_name
        self.model = model
        self.models = models
        self.engine_adapter = engine_adapter
        self.path = path
        self.preserve_fixtures = preserve_fixtures
        self.default_catalog = default_catalog
        self.dialect = dialect
        self.concurrency = concurrency
        self.verbosity = verbosity

        self._fixture_table_cache: t.Dict[str, exp.Table] = {}
        self._normalized_column_name_cache: t.Dict[str, str] = {}
        self._normalized_model_name_cache: t.Dict[t.Tuple[str, bool], str] = {}

        self._test_adapter_dialect = Dialect.get_or_raise(self.engine_adapter.dialect)

        self._validate_and_normalize_test()

        if self.engine_adapter.default_catalog:
            self._fixture_catalog: t.Optional[exp.Identifier] = normalize_identifiers(
                exp.parse_identifier(
                    self.engine_adapter.default_catalog, dialect=self._test_adapter_dialect
                ),
                dialect=self._test_adapter_dialect,
            )
        else:
            self._fixture_catalog = None

        # The test schema name is randomized to avoid concurrency issues,
        # unless a schema is provided in the unit test's body
        self._fixture_schema = exp.parse_identifier(
            self.body.get("schema") or f"sqlmesh_test_{random_id(short=True)}"
        )
        self._qualified_fixture_schema = schema_(self._fixture_schema, self._fixture_catalog)

        self._transforms = self._test_adapter_dialect.generator_class.TRANSFORMS
        self._execution_time = str(self.body.get("vars", {}).get("execution_time") or "")

        if self._execution_time:
            # Normalizes the execution time by converting it into UTC timezone
            self._execution_time = str(to_datetime(self._execution_time))

            # When execution_time is set, we mock the CURRENT_* SQL expressions so they always
            # return it. Note: inside these lambdas `self` is the SQLGlot generator, not the test.
            exec_time = exp.Literal.string(self._execution_time)
            self._transforms = {
                **self._transforms,
                exp.CurrentDate: lambda self, _: self.sql(
                    exp.cast(exec_time, "date", dialect=dialect)
                ),
                exp.CurrentDatetime: lambda self, _: self.sql(
                    exp.cast(exec_time, "datetime", dialect=dialect)
                ),
                exp.CurrentTime: lambda self, _: self.sql(
                    exp.cast(exec_time, "time", dialect=dialect)
                ),
                exp.CurrentTimestamp: lambda self, _: self.sql(
                    exp.cast(exec_time, "timestamp", dialect=dialect)
                ),
            }

        super().__init__()

    def defaultTestResult(self) -> unittest.TestResult:
        from sqlmesh.core.test.result import ModelTextTestResult

        return ModelTextTestResult(stream=sys.stdout, descriptions=True, verbosity=self.verbosity)

    def shortDescription(self) -> t.Optional[str]:
        return self.body.get("description")

    def setUp(self) -> None:
        """Load all input tables"""
        import pandas as pd
        import numpy as np

        self.engine_adapter.create_schema(self._qualified_fixture_schema)

        for name, values in self.body.get("inputs", {}).items():
            all_types_are_known = False
            columns_to_known_types: t.Dict[str, exp.DataType] = {}

            model = self.models.get(name)
            if model:
                inferred_columns_to_types = model.columns_to_types or {}
                columns_to_known_types = {
                    col: col_type
                    for col, col_type in inferred_columns_to_types.items()
                    if type_is_known(col_type)
                }
                all_types_are_known = bool(inferred_columns_to_types) and (
                    len(columns_to_known_types) == len(inferred_columns_to_types)
                )

            # Types specified in the test will override the corresponding inferred ones
            columns_to_known_types.update(values.get("columns", {}))

            rows = values.get("rows")
            if not all_types_are_known and rows:
                # Infer any still-unknown column types from the values in the first row
                for col, value in rows[0].items():
                    if col not in columns_to_known_types:
                        v_type = annotate_types(exp.convert(value)).type or type(value).__name__
                        v_type = exp.maybe_parse(
                            v_type, into=exp.DataType, dialect=self._test_adapter_dialect
                        )

                        if not type_is_known(v_type):
                            _raise_error(
                                f"Failed to infer the data type of column '{col}' for '{name}'. This issue can be "
                                "mitigated by casting the column in the model definition, setting its type in "
                                "external_models.yaml if it's an external model, setting the model's 'columns' property, "
                                "or setting its 'columns' mapping in the test itself",
                                self.path,
                            )

                        columns_to_known_types[col] = v_type

            if rows is None:
                query_or_df: exp.Query | pd.DataFrame = self._add_missing_columns(
                    values["query"], columns_to_known_types
                )
                if columns_to_known_types:
                    columns_to_known_types = {
                        col: columns_to_known_types[col] for col in query_or_df.named_selects
                    }
            else:
                query_or_df = self._create_df(values, columns=columns_to_known_types)

            # Convert NaN/NaT values to None if DataFrame
            if isinstance(query_or_df, pd.DataFrame):
                query_or_df = query_or_df.replace({np.nan: None})

            self.engine_adapter.create_view(
                self._test_fixture_table(name), query_or_df, columns_to_known_types
            )

    def tearDown(self) -> None:
        """Drop all fixture tables."""
        if not self.preserve_fixtures:
            self.engine_adapter.drop_schema(self._qualified_fixture_schema, cascade=True)

    def assert_equal(
        self,
        expected: pd.DataFrame,
        actual: pd.DataFrame,
        sort: bool,
        partial: t.Optional[bool] = False,
    ) -> None:
        """Compare two DataFrames, ignoring column order.

        Args:
            expected: The expected data.
            actual: The data produced by the query under test.
            sort: Sort both frames by all of their columns before comparing them.
            partial: Only compare the columns of `actual` that also appear in `expected`.

        Raises:
            AssertionError: If the frames differ; its args carry a textual diff followed by
                Rich tables describing the mismatch.
        """
        import numpy as np
        import pandas as pd
        from pandas.api.types import is_object_dtype

        if partial:
            intersection = actual[actual.columns.intersection(expected.columns)]
            if len(intersection.columns) > 0:
                actual = intersection

        # Two astypes are necessary, pandas converts strings to times as NS,
        # but if the actual is US, it doesn't take effect until the 2nd try!
        actual_types = actual.dtypes.to_dict()
        expected = expected.astype(actual_types, errors="ignore").astype(
            actual_types, errors="ignore"
        )

        # The `actual` df's dtypes will almost always be pd.Timestamp for datetime values,
        # but in some scenarios (e.g., DuckDB >=0.10.2) it will be a pandas `object` type
        # containing python `datetime.xxx` values.
        #
        # Pandas `object` columns result in a noop for the `astype` call above. Because any
        # quoted YAML value is a string, we must manually convert the `expected` df string
        # values to the correct `datetime.xxx` type.
        #
        # We determine the type from a single sentinel value, but since the `actual` df is
        # coming from a database query, it is safe to assume that the column contains only
        # a single type. Positional access (`iloc`) is used because the index is not
        # guaranteed to start at label 0.
        object_sentinel_values = {
            col: actual[col].iloc[0]
            for col in actual_types
            if is_object_dtype(actual_types[col]) and len(actual[col]) != 0
        }
        for col, value in object_sentinel_values.items():
            try:
                # can't use `isinstance()` here - https://stackoverflow.com/a/68743663/1707525
                if type(value) is datetime.date:
                    expected[col] = pd.to_datetime(expected[col]).dt.date
                elif type(value) is datetime.time:
                    expected[col] = pd.to_datetime(expected[col]).dt.time
                elif type(value) is datetime.datetime:
                    expected[col] = pd.to_datetime(expected[col]).dt.to_pydatetime()
            except Exception as e:
                from sqlmesh.core.console import get_console

                get_console().log_warning(
                    f"Failed to convert expected value for {col} into `datetime` "
                    f"for unit test '{str(self)}'. {str(e)}."
                )

        actual = actual.replace({np.nan: None})
        expected = expected.replace({np.nan: None})

        # We define this here to avoid a top-level import of numpy and pandas
        DATETIME_TYPES = (
            datetime.datetime,
            datetime.date,
            datetime.time,
            np.datetime64,
            pd.Timestamp,
        )

        def _to_hashable(x: t.Any) -> t.Any:
            # Make every cell hashable (and datetimes comparable as strings) so frames can be sorted
            if isinstance(x, (list, np.ndarray)):
                return tuple(_to_hashable(v) for v in x)
            if isinstance(x, dict):
                return tuple((k, _to_hashable(v)) for k, v in x.items())
            return str(x) if isinstance(x, DATETIME_TYPES) or not isinstance(x, t.Hashable) else x

        actual = actual.apply(lambda col: col.map(_to_hashable))
        expected = expected.apply(lambda col: col.map(_to_hashable))

        if sort:
            actual = actual.sort_values(by=actual.columns.to_list()).reset_index(drop=True)
            expected = expected.sort_values(by=expected.columns.to_list()).reset_index(drop=True)

        try:
            pd.testing.assert_frame_equal(
                expected,
                actual,
                check_dtype=False,
                check_datetimelike_compat=True,
                check_like=True,  # Ignore column order
            )
        except AssertionError as e:
            # There are 2 concepts at play here:
            # 1. The Exception args will contain the error message plus the diff dataframe table stringified
            # (backwards compatibility with existing tests, possible to serialize/send over network etc)
            # 2. Each test will also transform these diff dataframes into Rich tables, which will be the ones that'll
            # be surfaced to the user through Console for better UX (versus stringified dataframes)
            #
            # This is a bit of a hack, but it's a way to get the best of both worlds.
            args: t.List[t.Any] = []

            failed_subtest = ""

            if subtest := getattr(self, "_subtest", None):
                if cte := subtest.params.get("cte"):
                    failed_subtest = f" (CTE {cte})"

            if expected.shape != actual.shape:
                _raise_if_unexpected_columns(expected.columns, actual.columns)

                args.append("Data mismatch (rows are different)")

                missing_rows = _row_difference(expected, actual)
                if not missing_rows.empty:
                    args[0] += f"\n\nMissing rows:\n\n{missing_rows}"
                    args.append(df_to_table(f"Missing rows{failed_subtest}", missing_rows))

                unexpected_rows = _row_difference(actual, expected)

                if not unexpected_rows.empty:
                    args[0] += f"\n\nUnexpected rows:\n\n{unexpected_rows}"
                    args.append(df_to_table(f"Unexpected rows{failed_subtest}", unexpected_rows))

            else:
                # `assert_frame_equal` above ignores column order, but `DataFrame.compare`
                # requires identically-labeled frames, so align the columns first.
                if set(actual.columns) == set(expected.columns):
                    actual = actual[expected.columns]

                diff = expected.compare(actual).rename(columns={"self": "exp", "other": "act"})

                args.append(f"Data mismatch (exp: expected, act: actual)\n\n{diff}")

                diff.rename(columns={"exp": "Expected", "act": "Actual"}, inplace=True)
                if self.verbosity == Verbosity.DEFAULT:
                    args.extend(
                        df_to_table(f"Data mismatch{failed_subtest}", df)
                        for df in _split_df_by_column_pairs(diff)
                    )
                else:
                    from pandas import DataFrame, MultiIndex

                    levels = t.cast(MultiIndex, diff.columns).levels[0]
                    for col in levels:
                        # diff[col] returns a DataFrame when columns is a MultiIndex
                        col_diff = t.cast(DataFrame, diff[col])
                        if not col_diff.empty:
                            table = df_to_table(
                                f"[bold red]Column '{col}' mismatch{failed_subtest}[/bold red]",
                                col_diff,
                            )
                            args.append(table)

            e.args = (*args,)

            raise e

    def runTest(self) -> None:
        raise NotImplementedError

    def path_relative_to(self, other: Path) -> Path | None:
        """Compute a version of this test's path relative to the `other` path"""
        return self.path.relative_to(other) if self.path else None

    @staticmethod
    def create_test(
        body: t.Dict[str, t.Any],
        test_name: str,
        models: UniqueKeyDict[str, Model],
        engine_adapter: EngineAdapter,
        dialect: str | None,
        path: Path | None,
        preserve_fixtures: bool = False,
        default_catalog: str | None = None,
        concurrency: bool = False,
        verbosity: Verbosity = Verbosity.DEFAULT,
    ) -> t.Optional[ModelTest]:
        """Create a SqlModelTest or a PythonModelTest.

        Args:
            body: A dictionary that contains test metadata like inputs and outputs.
            test_name: The name of the test.
            models: All models to use for expansion and mapping of physical locations.
            engine_adapter: The engine adapter to use.
            dialect: The models' dialect, used for normalization purposes.
            path: An optional path to the test definition yaml file.
            preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
            default_catalog: The default catalog used when normalizing the tested model's name.
            concurrency: Whether tests may run concurrently.
            verbosity: Controls the level of detail of the reported results.

        Returns:
            The created test, or None (after logging a warning) if the model is not found.

        Raises:
            TestError: If constructing the test fails; the original error is chained as the cause.
        """
        name = body.get("model")
        if name is None:
            _raise_error("Missing required 'model' field", path)

        name = normalize_model_name(name, default_catalog=default_catalog, dialect=dialect)
        model = models.get(name)
        if not model:
            from sqlmesh.core.console import get_console

            get_console().log_warning(
                f"Model '{name}' was not found{' at ' + str(path) if path else ''}"
            )
            return None

        if isinstance(model, SqlModel):
            test_type: t.Type[ModelTest] = SqlModelTest
        elif isinstance(model, PythonModel):
            test_type = PythonModelTest
        else:
            _raise_error(f"Model '{name}' is an unsupported model type for testing", path)

        try:
            return test_type(
                body,
                test_name,
                t.cast(Model, model),
                models,
                engine_adapter,
                dialect,
                path,
                preserve_fixtures,
                default_catalog,
                concurrency,
                verbosity,
            )
        except Exception as e:
            raise TestError(f"Failed to create test {test_name} ({path})\n{str(e)}") from e

    def __str__(self) -> str:
        return f"{self.test_name} ({self.path})"

    def _validate_and_normalize_test(self) -> None:
        """Validate the test body and normalize its model names, column names and row data in place."""
        inputs = self.body.get("inputs")
        outputs = self.body.get("outputs", {})

        if not outputs:
            _raise_error("Incomplete test, missing outputs", self.path)

        ctes = outputs.get("ctes")
        query = outputs.get("query")
        partial = outputs.pop("partial", None)

        if ctes is None and query is None:
            _raise_error("Incomplete test, outputs must contain 'query' or 'ctes'", self.path)

        def _normalize_rows(
            values: t.List[Row] | t.Dict,
            name: str,
            partial: bool = False,
            dialect: DialectType = None,
        ) -> t.Dict:
            # Turn one input/output spec (inline rows, csv/yaml file, or query) into a dict
            # holding either a parsed `query` or a list of `rows` with normalized column names.
            import pandas as pd

            if not isinstance(values, dict):
                values = {"rows": values}

            rows = values.get("rows")
            query = values.get("query")

            fmt = values.get("format")
            path = values.get("path")
            if fmt == "csv":
                csv_settings = values.get("csv_settings") or {}
                rows = pd.read_csv(path or StringIO(rows), **csv_settings).to_dict(orient="records")
            elif fmt in (None, "yaml"):
                if path:
                    input_rows = yaml_load(Path(path))
                    rows = input_rows.get("rows") if isinstance(input_rows, dict) else input_rows
            else:
                _raise_error(f"Unsupported data format '{fmt}' for '{name}'", self.path)

            if query is not None:
                if rows is not None:
                    _raise_error(
                        f"Invalid test, cannot set both 'query' and 'rows' for '{name}'", self.path
                    )

                # We parse the user-supplied query using the testing adapter dialect, but we
                # normalize its identifiers according to the model's dialect, so that, e.g.,
                # the projection names match those in its `columns_to_types` field
                values["query"] = normalize_identifiers(
                    exp.maybe_parse(query, dialect=self._test_adapter_dialect), dialect=dialect
                )
                return values

            if rows is None:
                _raise_error(f"Incomplete test, missing row data for '{name}'", self.path)

            assert isinstance(rows, list)
            values["rows"] = [
                {self._normalize_column_name(column): value for column, value in row.items()}
                for row in rows
            ]
            if partial:
                values["partial"] = True

            return values

        def _normalize_sources(
            sources: t.Dict, partial: bool = False, with_default_catalog: bool = True
        ) -> t.Dict:
            # Normalize every source's name and data, using each source model's own dialect if known
            normalized_sources = {}
            for name, values in sources.items():
                normalized_name = self._normalize_model_name(
                    name, with_default_catalog=with_default_catalog
                )
                model = self.models.get(normalized_name)
                dialect = model.dialect if model else self.dialect

                normalized_sources[normalized_name] = _normalize_rows(
                    values, name, partial=partial, dialect=dialect
                )

            return normalized_sources

        normalized_model_name = self._normalize_model_name(self.body["model"])
        self.body["model"] = normalized_model_name

        if inputs:
            inputs = _normalize_sources(inputs)
            for name, values in inputs.items():
                columns = values.get("columns")
                if columns is None:
                    continue

                if not isinstance(columns, dict):
                    _raise_error(
                        f"Invalid 'columns' value for model '{name}', expected a mapping name -> type",
                        self.path,
                    )

                values["columns"] = {
                    self._normalize_column_name(col): exp.DataType.build(
                        col_type, dialect=self._test_adapter_dialect
                    )
                    for col, col_type in columns.items()
                }

            for depends_on in self.model.depends_on:
                if depends_on not in inputs:
                    _raise_error(f"Incomplete test, missing input model '{depends_on}'", self.path)

            # A self-referencing model gets an empty fixture for itself unless the test provides one
            if self.model.depends_on_self and normalized_model_name not in inputs:
                inputs[normalized_model_name] = {"rows": []}

            self.body["inputs"] = inputs

        if ctes:
            outputs["ctes"] = _normalize_sources(ctes, partial=partial, with_default_catalog=False)

        if query or query == []:
            outputs["query"] = _normalize_rows(
                query, self.model.name, partial=partial, dialect=self.model.dialect
            )

    def _test_fixture_table(self, name: str) -> exp.Table:
        """Return (and cache) the fixture table inside the test schema that stands in for `name`."""
        table = self._fixture_table_cache.get(name)
        if table is None:
            table = exp.to_table(name, dialect=self._test_adapter_dialect)

            # We change the table path below, so this ensures there are no name clashes
            table.this.set("this", "__".join(part.name for part in table.parts))

            table.set("db", self._fixture_schema.copy())
            if self._fixture_catalog:
                table.set("catalog", self._fixture_catalog.copy())

            self._fixture_table_cache[name] = table

        return table

    def _normalize_model_name(self, name: str, with_default_catalog: bool = True) -> str:
        """Normalize a model name (memoized), optionally qualifying it with the default catalog."""
        normalized_name = self._normalized_model_name_cache.get((name, with_default_catalog))
        if normalized_name is None:
            default_catalog = self.default_catalog if with_default_catalog else None
            normalized_name = normalize_model_name(
                name, default_catalog=default_catalog, dialect=self.dialect
            )
            self._normalized_model_name_cache[(name, with_default_catalog)] = normalized_name

        return normalized_name

    def _normalize_column_name(self, name: str) -> str:
        """Normalize a column name according to the models' dialect (memoized)."""
        normalized_name = self._normalized_column_name_cache.get(name)
        if normalized_name is None:
            normalized_name = normalize_identifiers(name, dialect=self.dialect).name
            self._normalized_column_name_cache[name] = normalized_name

        return normalized_name

    @contextmanager
    def _concurrent_render_context(self) -> t.Iterator[None]:
        """
        Context manager that ensures that the tests are executed safely in a concurrent environment.

        This is needed in case `execution_time` is set, as we'd then have to:
        - Freeze time through `time_machine` (not thread safe)
        - Globally patch the SQLGlot dialect so that any date/time nodes are evaluated at the `execution_time` during generation
        """
        import time_machine

        lock_ctx: AbstractContextManager = (
            self.CONCURRENT_RENDER_LOCK if self.concurrency else nullcontext()
        )
        time_ctx: AbstractContextManager = nullcontext()
        dialect_patch_ctx: AbstractContextManager = nullcontext()

        if self._execution_time:
            time_ctx = time_machine.travel(self._execution_time, tick=False)
            dialect_patch_ctx = patch.dict(
                self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms
            )

        with lock_ctx, time_ctx, dialect_patch_ctx:
            yield

    def _execute(self, query: exp.Query | str) -> pd.DataFrame:
        """Executes the given query using the testing engine adapter and returns a DataFrame."""
        return self.engine_adapter.fetchdf(query)

    def _create_df(
        self,
        values: t.Dict[str, t.Any],
        columns: t.Optional[t.Collection] = None,
        partial: t.Optional[bool] = False,
    ) -> pd.DataFrame:
        """Build a DataFrame from a normalized spec, either by running its query or from its rows.

        When `columns` is given, the rows may only reference those columns; with `partial`,
        only the columns the rows actually reference are kept.
        """
        import pandas as pd

        query = values.get("query")
        if query:
            if not partial:
                query = self._add_missing_columns(query, columns)

            return self._execute(query)

        rows = values["rows"]
        columns_str: t.Optional[t.List[str]] = None
        if columns:
            columns_str = [str(col) for col in columns]
            referenced_columns = list(dict.fromkeys(col for row in rows for col in row))
            _raise_if_unexpected_columns(columns, referenced_columns)

            if partial:
                referenced = set(referenced_columns)
                columns_str = [col for col in columns_str if col in referenced]

        return pd.DataFrame.from_records(rows, columns=columns_str)

    def _add_missing_columns(
        self, query: exp.Query, all_columns: t.Optional[t.Collection[str]] = None
    ) -> exp.Query:
        """Append `NULL AS <col>` projections (in place) for columns the query does not select."""
        if not all_columns or query.is_star:
            return query

        query_columns = set(query.named_selects)
        missing_columns = [col for col in all_columns if col not in query_columns]
        if missing_columns:
            query.select(*[exp.null().as_(col) for col in missing_columns], copy=False)

        return query
A class whose instances are single test cases.
By default, the test code itself should be placed in a method named 'runTest'.
If the fixture may be used for many test cases, create as many test methods as are needed. When instantiating such a TestCase subclass, specify in the constructor arguments the name of the test method that the instance is to execute.
Test authors should subclass TestCase for their own tests. Construction and deconstruction of the test's environment ('fixture') can be implemented by overriding the 'setUp' and 'tearDown' methods respectively.
If it is necessary to override the __init__ method, the base class __init__ method must always be called. It is important that subclasses should not change the signature of their __init__ method, since instances of the classes are instantiated automatically by parts of the framework in order to be run.
When subclassing TestCase, you can set these attributes:
- failureException: determines which exception will be raised when the instance's assertion methods fail; test methods raising this exception will be deemed to have 'failed' rather than 'errored'.
- longMessage: determines whether long messages (including repr of objects used in assert methods) will be printed on failure in addition to any explicit message passed.
- maxDiff: sets the maximum length of a diff in failure messages by assert methods using difflib. It is looked up as an instance attribute so can be configured by individual tests if required.
def __init__(
    self,
    body: t.Dict[str, t.Any],
    test_name: str,
    model: Model,
    models: UniqueKeyDict[str, Model],
    engine_adapter: EngineAdapter,
    dialect: str | None = None,
    path: Path | None = None,
    preserve_fixtures: bool = False,
    default_catalog: str | None = None,
    concurrency: bool = False,
    verbosity: Verbosity = Verbosity.DEFAULT,
) -> None:
    """ModelTest encapsulates a unit test for a model.

    Args:
        body: A dictionary that contains test metadata like inputs and outputs.
        test_name: The name of the test.
        model: The model that is being tested.
        models: All models to use for expansion and mapping of physical locations.
        engine_adapter: The engine adapter to use.
        dialect: The models' dialect, used for normalization purposes.
        path: An optional path to the test definition yaml file.
        preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
        default_catalog: The default catalog used when normalizing model names.
        concurrency: Whether tests may run concurrently; if so, rendering is serialized
            through `CONCURRENT_RENDER_LOCK`.
        verbosity: Controls the level of detail of the reported results, e.g. how data
            mismatches are rendered.
    """
    self.body = body
    self.test_name = test_name
    self.model = model
    self.models = models
    self.engine_adapter = engine_adapter
    self.path = path
    self.preserve_fixtures = preserve_fixtures
    self.default_catalog = default_catalog
    self.dialect = dialect
    self.concurrency = concurrency
    self.verbosity = verbosity

    self._fixture_table_cache: t.Dict[str, exp.Table] = {}
    self._normalized_column_name_cache: t.Dict[str, str] = {}
    self._normalized_model_name_cache: t.Dict[t.Tuple[str, bool], str] = {}

    self._test_adapter_dialect = Dialect.get_or_raise(self.engine_adapter.dialect)

    self._validate_and_normalize_test()

    if self.engine_adapter.default_catalog:
        self._fixture_catalog: t.Optional[exp.Identifier] = normalize_identifiers(
            exp.parse_identifier(
                self.engine_adapter.default_catalog, dialect=self._test_adapter_dialect
            ),
            dialect=self._test_adapter_dialect,
        )
    else:
        self._fixture_catalog = None

    # The test schema name is randomized to avoid concurrency issues,
    # unless a schema is provided in the unit test's body
    self._fixture_schema = exp.parse_identifier(
        self.body.get("schema") or f"sqlmesh_test_{random_id(short=True)}"
    )
    self._qualified_fixture_schema = schema_(self._fixture_schema, self._fixture_catalog)

    self._transforms = self._test_adapter_dialect.generator_class.TRANSFORMS
    self._execution_time = str(self.body.get("vars", {}).get("execution_time") or "")

    if self._execution_time:
        # Normalizes the execution time by converting it into UTC timezone
        self._execution_time = str(to_datetime(self._execution_time))

        # When execution_time is set, we mock the CURRENT_* SQL expressions so they always
        # return it. Note: inside these lambdas `self` is the SQLGlot generator, not the test.
        exec_time = exp.Literal.string(self._execution_time)
        self._transforms = {
            **self._transforms,
            exp.CurrentDate: lambda self, _: self.sql(
                exp.cast(exec_time, "date", dialect=dialect)
            ),
            exp.CurrentDatetime: lambda self, _: self.sql(
                exp.cast(exec_time, "datetime", dialect=dialect)
            ),
            exp.CurrentTime: lambda self, _: self.sql(
                exp.cast(exec_time, "time", dialect=dialect)
            ),
            exp.CurrentTimestamp: lambda self, _: self.sql(
                exp.cast(exec_time, "timestamp", dialect=dialect)
            ),
        }

    super().__init__()
ModelTest encapsulates a unit test for a model.
Arguments:
- body: A dictionary that contains test metadata like inputs and outputs.
- test_name: The name of the test.
- model: The model that is being tested.
- models: All models to use for expansion and mapping of physical locations.
- engine_adapter: The engine adapter to use.
- dialect: The models' dialect, used for normalization purposes.
- path: An optional path to the test definition yaml file.
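- preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
- default_catalog: The default catalog used when normalizing model names.
- concurrency: Whether tests may run concurrently; when set, rendering is serialized through a shared class-level lock.
- verbosity: Controls the level of detail of the reported results, e.g. how data mismatches are displayed.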
Returns a one-line description of the test, or None if no description has been provided.
The default implementation of this method returns the first line of the specified test method's docstring.
def setUp(self) -> None:
    """Create the fixture schema and register one fixture view per test input.

    Column types are taken from the input model's known `columns_to_types`, overridden by
    the test's own `columns` mapping; any type still missing is inferred from the values
    in the first row.
    """
    import numpy as np
    import pandas as pd

    self.engine_adapter.create_schema(self._qualified_fixture_schema)

    for input_name, spec in self.body.get("inputs", {}).items():
        known_types: t.Dict[str, exp.DataType] = {}
        every_type_known = False

        input_model = self.models.get(input_name)
        if input_model:
            model_types = input_model.columns_to_types or {}
            known_types = {
                column: dtype for column, dtype in model_types.items() if type_is_known(dtype)
            }
            every_type_known = bool(model_types) and len(known_types) == len(model_types)

        # Types specified in the test will override the corresponding inferred ones
        known_types.update(spec.get("columns", {}))

        rows = spec.get("rows")
        if not every_type_known and rows:
            for column, sample in rows[0].items():
                if column in known_types:
                    continue

                guessed = annotate_types(exp.convert(sample)).type or type(sample).__name__
                guessed = exp.maybe_parse(
                    guessed, into=exp.DataType, dialect=self._test_adapter_dialect
                )
                if not type_is_known(guessed):
                    _raise_error(
                        f"Failed to infer the data type of column '{column}' for '{input_name}'. This issue can be "
                        "mitigated by casting the column in the model definition, setting its type in "
                        "external_models.yaml if it's an external model, setting the model's 'columns' property, "
                        "or setting its 'columns' mapping in the test itself",
                        self.path,
                    )
                known_types[column] = guessed

        if rows is not None:
            fixture = self._create_df(spec, columns=known_types)
        else:
            fixture = self._add_missing_columns(spec["query"], known_types)
            if known_types:
                # Keep the types in the same order as the query's projections
                known_types = {column: known_types[column] for column in fixture.named_selects}

        if isinstance(fixture, pd.DataFrame):
            # Convert NaN/NaT values to None
            fixture = fixture.replace({np.nan: None})

        self.engine_adapter.create_view(
            self._test_fixture_table(input_name), fixture, known_types
        )
Load all input tables
def tearDown(self) -> None:
    """Drop the fixture schema, cascading to its tables, unless fixtures are preserved."""
    if self.preserve_fixtures:
        return
    self.engine_adapter.drop_schema(self._qualified_fixture_schema, cascade=True)
Drop the fixture schema, cascading to all fixture tables, unless `preserve_fixtures` is set.
def assert_equal(
    self,
    expected: pd.DataFrame,
    actual: pd.DataFrame,
    sort: bool,
    partial: t.Optional[bool] = False,
) -> None:
    """Compare two DataFrames, ignoring column order.

    Args:
        expected: The expected data.
        actual: The data produced by the query under test.
        sort: Sort both frames by all of their columns before comparing them.
        partial: Only compare the columns of `actual` that also appear in `expected`.

    Raises:
        AssertionError: If the frames differ; its args carry a textual diff followed by
            Rich tables describing the mismatch.
    """
    import numpy as np
    import pandas as pd
    from pandas.api.types import is_object_dtype

    if partial:
        intersection = actual[actual.columns.intersection(expected.columns)]
        if len(intersection.columns) > 0:
            actual = intersection

    # Two astypes are necessary, pandas converts strings to times as NS,
    # but if the actual is US, it doesn't take effect until the 2nd try!
    actual_types = actual.dtypes.to_dict()
    expected = expected.astype(actual_types, errors="ignore").astype(
        actual_types, errors="ignore"
    )

    # The `actual` df's dtypes will almost always be pd.Timestamp for datetime values,
    # but in some scenarios (e.g., DuckDB >=0.10.2) it will be a pandas `object` type
    # containing python `datetime.xxx` values.
    #
    # Pandas `object` columns result in a noop for the `astype` call above. Because any
    # quoted YAML value is a string, we must manually convert the `expected` df string
    # values to the correct `datetime.xxx` type.
    #
    # We determine the type from a single sentinel value, but since the `actual` df is
    # coming from a database query, it is safe to assume that the column contains only
    # a single type. Positional access (`iloc`) is used because the index is not
    # guaranteed to start at label 0.
    object_sentinel_values = {
        col: actual[col].iloc[0]
        for col in actual_types
        if is_object_dtype(actual_types[col]) and len(actual[col]) != 0
    }
    for col, value in object_sentinel_values.items():
        try:
            # can't use `isinstance()` here - https://stackoverflow.com/a/68743663/1707525
            if type(value) is datetime.date:
                expected[col] = pd.to_datetime(expected[col]).dt.date
            elif type(value) is datetime.time:
                expected[col] = pd.to_datetime(expected[col]).dt.time
            elif type(value) is datetime.datetime:
                expected[col] = pd.to_datetime(expected[col]).dt.to_pydatetime()
        except Exception as e:
            from sqlmesh.core.console import get_console

            get_console().log_warning(
                f"Failed to convert expected value for {col} into `datetime` "
                f"for unit test '{str(self)}'. {str(e)}."
            )

    actual = actual.replace({np.nan: None})
    expected = expected.replace({np.nan: None})

    # We define this here to avoid a top-level import of numpy and pandas
    DATETIME_TYPES = (
        datetime.datetime,
        datetime.date,
        datetime.time,
        np.datetime64,
        pd.Timestamp,
    )

    def _to_hashable(x: t.Any) -> t.Any:
        # Make every cell hashable (and datetimes comparable as strings) so frames can be sorted
        if isinstance(x, (list, np.ndarray)):
            return tuple(_to_hashable(v) for v in x)
        if isinstance(x, dict):
            return tuple((k, _to_hashable(v)) for k, v in x.items())
        return str(x) if isinstance(x, DATETIME_TYPES) or not isinstance(x, t.Hashable) else x

    actual = actual.apply(lambda col: col.map(_to_hashable))
    expected = expected.apply(lambda col: col.map(_to_hashable))

    if sort:
        actual = actual.sort_values(by=actual.columns.to_list()).reset_index(drop=True)
        expected = expected.sort_values(by=expected.columns.to_list()).reset_index(drop=True)

    try:
        pd.testing.assert_frame_equal(
            expected,
            actual,
            check_dtype=False,
            check_datetimelike_compat=True,
            check_like=True,  # Ignore column order
        )
    except AssertionError as e:
        # There are 2 concepts at play here:
        # 1. The Exception args will contain the error message plus the diff dataframe table stringified
        # (backwards compatibility with existing tests, possible to serialize/send over network etc)
        # 2. Each test will also transform these diff dataframes into Rich tables, which will be the ones that'll
        # be surfaced to the user through Console for better UX (versus stringified dataframes)
        #
        # This is a bit of a hack, but it's a way to get the best of both worlds.
        args: t.List[t.Any] = []

        failed_subtest = ""

        if subtest := getattr(self, "_subtest", None):
            if cte := subtest.params.get("cte"):
                failed_subtest = f" (CTE {cte})"

        if expected.shape != actual.shape:
            _raise_if_unexpected_columns(expected.columns, actual.columns)

            args.append("Data mismatch (rows are different)")

            missing_rows = _row_difference(expected, actual)
            if not missing_rows.empty:
                args[0] += f"\n\nMissing rows:\n\n{missing_rows}"
                args.append(df_to_table(f"Missing rows{failed_subtest}", missing_rows))

            unexpected_rows = _row_difference(actual, expected)

            if not unexpected_rows.empty:
                args[0] += f"\n\nUnexpected rows:\n\n{unexpected_rows}"
                args.append(df_to_table(f"Unexpected rows{failed_subtest}", unexpected_rows))

        else:
            # `assert_frame_equal` above ignores column order, but `DataFrame.compare`
            # requires identically-labeled frames, so align the columns first.
            if set(actual.columns) == set(expected.columns):
                actual = actual[expected.columns]

            diff = expected.compare(actual).rename(columns={"self": "exp", "other": "act"})

            args.append(f"Data mismatch (exp: expected, act: actual)\n\n{diff}")

            diff.rename(columns={"exp": "Expected", "act": "Actual"}, inplace=True)
            if self.verbosity == Verbosity.DEFAULT:
                args.extend(
                    df_to_table(f"Data mismatch{failed_subtest}", df)
                    for df in _split_df_by_column_pairs(diff)
                )
            else:
                from pandas import DataFrame, MultiIndex

                levels = t.cast(MultiIndex, diff.columns).levels[0]
                for col in levels:
                    # diff[col] returns a DataFrame when columns is a MultiIndex
                    col_diff = t.cast(DataFrame, diff[col])
                    if not col_diff.empty:
                        table = df_to_table(
                            f"[bold red]Column '{col}' mismatch{failed_subtest}[/bold red]",
                            col_diff,
                        )
                        args.append(table)

        e.args = (*args,)

        raise e
Compare two DataFrames
379 def path_relative_to(self, other: Path) -> Path | None: 380 """Compute a version of this test's path relative to the `other` path""" 381 return self.path.relative_to(other) if self.path else None
Compute a version of this test's path relative to the `other` path.
@staticmethod
def create_test(
    body: t.Dict[str, t.Any],
    test_name: str,
    models: UniqueKeyDict[str, Model],
    engine_adapter: EngineAdapter,
    dialect: str | None,
    path: Path | None,
    preserve_fixtures: bool = False,
    default_catalog: str | None = None,
    concurrency: bool = False,
    verbosity: Verbosity = Verbosity.DEFAULT,
) -> t.Optional[ModelTest]:
    """Create a SqlModelTest or a PythonModelTest.

    Args:
        body: A dictionary that contains test metadata like inputs and outputs.
        test_name: The name of the test.
        models: All models to use for expansion and mapping of physical locations.
        engine_adapter: The engine adapter to use.
        dialect: The models' dialect, used for normalization purposes.
        path: An optional path to the test definition yaml file.
        preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
        default_catalog: The default catalog, used when normalizing the model name found in
            ``body`` and forwarded to the test instance.
        concurrency: Forwarded unchanged to the test instance's constructor.
        verbosity: Forwarded to the test instance; it controls how data mismatches are rendered.

    Returns:
        The created test, or None (after logging a warning) when the model referenced by
        ``body`` is not present in ``models``.

    Raises:
        TestError: If constructing the test instance fails. A missing 'model' field or an
            unsupported model type is reported through ``_raise_error``.
    """
    name = body.get("model")
    if name is None:
        _raise_error("Missing required 'model' field", path)

    name = normalize_model_name(name, default_catalog=default_catalog, dialect=dialect)
    model = models.get(name)
    if not model:
        from sqlmesh.core.console import get_console

        get_console().log_warning(
            f"Model '{name}' was not found{' at ' + str(path) if path else ''}"
        )
        return None

    if isinstance(model, SqlModel):
        test_type: t.Type[ModelTest] = SqlModelTest
    elif isinstance(model, PythonModel):
        test_type = PythonModelTest
    else:
        _raise_error(f"Model '{name}' is an unsupported model type for testing", path)

    try:
        return test_type(
            body,
            test_name,
            t.cast(Model, model),
            models,
            engine_adapter,
            dialect,
            path,
            preserve_fixtures,
            default_catalog,
            concurrency,
            verbosity,
        )
    except Exception as e:
        # Chain explicitly so the original construction error is reported as the direct cause
        raise TestError(f"Failed to create test {test_name} ({path})\n{e}") from e
Create a SqlModelTest or a PythonModelTest.
Arguments:
- body: A dictionary that contains test metadata like inputs and outputs.
- test_name: The name of the test.
- models: All models to use for expansion and mapping of physical locations.
- engine_adapter: The engine adapter to use.
- dialect: The models' dialect, used for normalization purposes.
- path: An optional path to the test definition yaml file.
- preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
Inherited Members
- unittest.case.TestCase
- failureException
- longMessage
- maxDiff
- addTypeEqualityFunc
- addCleanup
- addClassCleanup
- setUpClass
- tearDownClass
- countTestCases
- id
- subTest
- run
- doCleanups
- doClassCleanups
- debug
- skipTest
- fail
- assertFalse
- assertTrue
- assertRaises
- assertWarns
- assertLogs
- assertNoLogs
- assertEqual
- assertNotEqual
- assertAlmostEqual
- assertNotAlmostEqual
- assertSequenceEqual
- assertListEqual
- assertTupleEqual
- assertSetEqual
- assertIn
- assertNotIn
- assertIs
- assertIsNot
- assertDictEqual
- assertDictContainsSubset
- assertCountEqual
- assertMultiLineEqual
- assertLess
- assertLessEqual
- assertGreater
- assertGreaterEqual
- assertIsNone
- assertIsNotNone
- assertIsInstance
- assertNotIsInstance
- assertRaisesRegex
- assertWarnsRegex
- assertRegex
- assertNotRegex
- failUnlessRaises
- failIf
- assertRaisesRegexp
- assertRegexpMatches
- assertNotRegexpMatches
- failUnlessEqual
- assertEquals
- failIfEqual
- assertNotEquals
- failUnlessAlmostEqual
- assertAlmostEquals
- failIfAlmostEqual
- assertNotAlmostEquals
- failUnless
- assert_
class SqlModelTest(ModelTest):
    def test_ctes(self, ctes: t.Dict[str, exp.Expr], recursive: bool = False) -> None:
        """Execute every CTE listed under ``outputs.ctes`` and compare it to its expected rows.

        Args:
            ctes: Mapping of CTE name to the CTE expression taken from the rendered model query.
            recursive: Whether the model's WITH clause is recursive; applied to the WITH clause
                of each generated CTE query.
        """
        expected_ctes = self.body["outputs"].get("ctes", {})
        for cte_name, expected_values in expected_ctes.items():
            # Each CTE gets its own subtest; assert_equal reads the "cte" param for labelling
            with self.subTest(cte=cte_name):
                if cte_name not in ctes:
                    _raise_error(
                        f"No CTE named {cte_name} found in model {self.model.name}", self.path
                    )

                inner_query = ctes[cte_name].this
                # Rows are only sorted when the CTE does not define its own ordering
                should_sort = inner_query.args.get("order") is None
                is_partial = expected_values.get("partial")

                # Select the CTE's projections from it, with all of the model's CTEs attached
                probe_query = exp.select(*_projection_identifiers(inner_query)).from_(cte_name)
                for cte_alias, cte_expression in ctes.items():
                    probe_query = probe_query.with_(
                        cte_alias, cte_expression.this, recursive=recursive
                    )

                with self._concurrent_render_context():
                    # Only SQL generation is done under the lock; execution below is not,
                    # so fetching results can proceed concurrently between threads
                    sql = probe_query.sql(
                        self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql
                    )

                actual = self._execute(sql)
                expected = self._create_df(
                    expected_values, columns=probe_query.named_selects, partial=is_partial
                )
                self.assert_equal(expected, actual, sort=should_sort, partial=is_partial)

    def runTest(self) -> None:
        """Render the model query, check any expected CTE outputs, then check the query output."""
        with self._concurrent_render_context():
            # Rendering and SQL generation happen under the lock; execution does not
            rendered = self._render_model_query()
            sql = rendered.sql(self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql)

        with_clause = rendered.args.get("with_")
        if with_clause:
            normalized_ctes = {
                self._normalize_model_name(cte.alias, with_default_catalog=False): cte
                for cte in rendered.ctes
            }
            self.test_ctes(normalized_ctes, recursive=with_clause.recursive)

        query_values = self.body["outputs"].get("query")
        if query_values is None:
            return

        is_partial = query_values.get("partial")
        should_sort = rendered.args.get("order") is None

        actual = self._execute(sql)
        expected = self._create_df(
            query_values, columns=self.model.columns_to_types, partial=is_partial
        )
        self.assert_equal(expected, actual, sort=should_sort, partial=is_partial)

    def _render_model_query(self) -> exp.Query:
        """Render the model's query with its inputs mapped onto the test fixture tables."""
        variables = self.body.get("vars", {}).copy()

        # Time-related vars become explicit render kwargs; the rest are passed as `variables`
        time_kwargs = {}
        for key in TIME_KWARG_KEYS:
            if key in variables:
                time_kwargs[key] = variables.pop(key)

        fixture_tables = {
            input_name: self._test_fixture_table(input_name).sql()
            for input_name in self.body.get("inputs", {})
        }
        return self.model.render_query_or_raise(
            **time_kwargs,
            variables=variables,
            engine_adapter=self.engine_adapter,
            table_mapping=fixture_tables,
            runtime_stage=RuntimeStage.TESTING,
        )
A class whose instances are single test cases.
By default, the test code itself should be placed in a method named 'runTest'.
If the fixture may be used for many test cases, create as many test methods as are needed. When instantiating such a TestCase subclass, specify in the constructor arguments the name of the test method that the instance is to execute.
Test authors should subclass TestCase for their own tests. Construction and deconstruction of the test's environment ('fixture') can be implemented by overriding the 'setUp' and 'tearDown' methods respectively.
If it is necessary to override the __init__ method, the base class __init__ method must always be called. It is important that subclasses should not change the signature of their __init__ method, since instances of the classes are instantiated automatically by parts of the framework in order to be run.
When subclassing TestCase, you can set these attributes:
- failureException: determines which exception will be raised when the instance's assertion methods fail; test methods raising this exception will be deemed to have 'failed' rather than 'errored'.
- longMessage: determines whether long messages (including repr of objects used in assert methods) will be printed on failure in addition to any explicit message passed.
- maxDiff: sets the maximum length of a diff in failure messages by assert methods using difflib. It is looked up as an instance attribute so can be configured by individual tests if required.
def test_ctes(self, ctes: t.Dict[str, exp.Expr], recursive: bool = False) -> None:
    """Execute every CTE listed under ``outputs.ctes`` and compare it to its expected rows.

    Args:
        ctes: Mapping of CTE name to the CTE expression taken from the rendered model query.
        recursive: Whether the model's WITH clause is recursive; applied to the WITH clause
            of each generated CTE query.
    """
    expected_ctes = self.body["outputs"].get("ctes", {})
    for cte_name, expected_values in expected_ctes.items():
        # Each CTE gets its own subtest; assert_equal reads the "cte" param for labelling
        with self.subTest(cte=cte_name):
            if cte_name not in ctes:
                _raise_error(
                    f"No CTE named {cte_name} found in model {self.model.name}", self.path
                )

            inner_query = ctes[cte_name].this
            # Rows are only sorted when the CTE does not define its own ordering
            should_sort = inner_query.args.get("order") is None
            is_partial = expected_values.get("partial")

            # Select the CTE's projections from it, with all of the model's CTEs attached
            probe_query = exp.select(*_projection_identifiers(inner_query)).from_(cte_name)
            for cte_alias, cte_expression in ctes.items():
                probe_query = probe_query.with_(
                    cte_alias, cte_expression.this, recursive=recursive
                )

            with self._concurrent_render_context():
                # Only SQL generation is done under the lock; execution below is not,
                # so fetching results can proceed concurrently between threads
                sql = probe_query.sql(
                    self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql
                )

            actual = self._execute(sql)
            expected = self._create_df(
                expected_values, columns=probe_query.named_selects, partial=is_partial
            )
            self.assert_equal(expected, actual, sort=should_sort, partial=is_partial)
Run CTE queries and compare output to expected output
def runTest(self) -> None:
    """Render the model query, check any expected CTE outputs, then check the query output."""
    with self._concurrent_render_context():
        # Rendering and SQL generation happen under the lock; execution does not
        rendered = self._render_model_query()
        sql = rendered.sql(self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql)

    with_clause = rendered.args.get("with_")
    if with_clause:
        normalized_ctes = {
            self._normalize_model_name(cte.alias, with_default_catalog=False): cte
            for cte in rendered.ctes
        }
        self.test_ctes(normalized_ctes, recursive=with_clause.recursive)

    query_values = self.body["outputs"].get("query")
    if query_values is None:
        return

    is_partial = query_values.get("partial")
    should_sort = rendered.args.get("order") is None

    actual = self._execute(sql)
    expected = self._create_df(
        query_values, columns=self.model.columns_to_types, partial=is_partial
    )
    self.assert_equal(expected, actual, sort=should_sort, partial=is_partial)
Inherited Members
- ModelTest
- ModelTest
- CONCURRENT_RENDER_LOCK
- body
- test_name
- model
- models
- engine_adapter
- path
- preserve_fixtures
- default_catalog
- dialect
- concurrency
- verbosity
- defaultTestResult
- shortDescription
- setUp
- tearDown
- assert_equal
- path_relative_to
- create_test
- unittest.case.TestCase
- failureException
- longMessage
- maxDiff
- addTypeEqualityFunc
- addCleanup
- addClassCleanup
- setUpClass
- tearDownClass
- countTestCases
- id
- subTest
- run
- doCleanups
- doClassCleanups
- debug
- skipTest
- fail
- assertFalse
- assertTrue
- assertRaises
- assertWarns
- assertLogs
- assertNoLogs
- assertEqual
- assertNotEqual
- assertAlmostEqual
- assertNotAlmostEqual
- assertSequenceEqual
- assertListEqual
- assertTupleEqual
- assertSetEqual
- assertIn
- assertNotIn
- assertIs
- assertIsNot
- assertDictEqual
- assertDictContainsSubset
- assertCountEqual
- assertMultiLineEqual
- assertLess
- assertLessEqual
- assertGreater
- assertGreaterEqual
- assertIsNone
- assertIsNotNone
- assertIsInstance
- assertNotIsInstance
- assertRaisesRegex
- assertWarnsRegex
- assertRegex
- assertNotRegex
- failUnlessRaises
- failIf
- assertRaisesRegexp
- assertRegexpMatches
- assertNotRegexpMatches
- failUnlessEqual
- assertEquals
- failIfEqual
- assertNotEquals
- failUnlessAlmostEqual
- assertAlmostEquals
- failIfAlmostEqual
- assertNotAlmostEquals
- failUnless
- assert_
class PythonModelTest(ModelTest):
    """A unit test for a Python model."""

    def __init__(
        self,
        body: t.Dict[str, t.Any],
        test_name: str,
        model: Model,
        models: UniqueKeyDict[str, Model],
        engine_adapter: EngineAdapter,
        dialect: str | None = None,
        path: Path | None = None,
        preserve_fixtures: bool = False,
        default_catalog: str | None = None,
        concurrency: bool = False,
        verbosity: Verbosity = Verbosity.DEFAULT,
    ) -> None:
        """PythonModelTest encapsulates a unit test for a Python model.

        Args:
            body: A dictionary that contains test metadata like inputs and outputs.
            test_name: The name of the test.
            model: The Python model that is being tested.
            models: All models to use for expansion and mapping of physical locations.
            engine_adapter: The engine adapter to use.
            dialect: The models' dialect, used for normalization purposes.
            path: An optional path to the test definition yaml file.
            preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
            default_catalog: The default catalog, forwarded to the base test and to the
                execution context.
            concurrency: Forwarded unchanged to the base ModelTest constructor.
            verbosity: Forwarded unchanged to the base ModelTest constructor.
        """
        from sqlmesh.core.test.context import TestExecutionContext

        super().__init__(
            body,
            test_name,
            model,
            models,
            engine_adapter,
            dialect,
            path,
            preserve_fixtures,
            default_catalog,
            concurrency,
            verbosity,
        )

        # The context handed to the Python model when it is rendered (see _execute_model)
        self.context = TestExecutionContext(
            engine_adapter=engine_adapter,
            models=models,
            test=self,
            default_dialect=dialect,
            default_catalog=default_catalog,
        )

    def runTest(self) -> None:
        """Execute the Python model and compare its output to the expected query output."""
        values = self.body["outputs"].get("query")
        if values is not None:
            partial = values.get("partial")

            actual_df = self._execute_model()
            actual_df.reset_index(drop=True, inplace=True)
            expected = self._create_df(values, columns=self.model.columns_to_types, partial=partial)

            # Rows are always sorted before comparison (presumably because a DataFrame
            # produced in Python carries no ORDER BY to honor)
            self.assert_equal(expected, actual_df, sort=True, partial=partial)

    def _execute_model(self) -> pd.DataFrame:
        """Execute the Python model and return its complete output as one pandas DataFrame.

        Python models may yield their output in several batches. All batches are collected
        and concatenated so the comparison covers the model's full output, not only the
        first batch. Non-pandas batches are converted with ``toPandas()``. When the model
        yields nothing, an empty DataFrame with the model's declared columns is returned.
        """
        import pandas as pd

        frames: t.List[pd.DataFrame] = []
        with self._concurrent_render_context():
            variables = self.body.get("vars", {}).copy()
            time_kwargs = {key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables}

            # Batches are produced lazily by the model code, so consume them all while
            # still holding the render lock
            for df in self.model.render(context=self.context, variables=variables, **time_kwargs):
                assert not isinstance(df, exp.Expr)
                frames.append(df if isinstance(df, pd.DataFrame) else df.toPandas())

        if not frames:
            return pd.DataFrame(columns=list(self.model.columns_to_types or {}))
        if len(frames) == 1:
            return frames[0]
        return pd.concat(frames, ignore_index=True)
A class whose instances are single test cases.
By default, the test code itself should be placed in a method named 'runTest'.
If the fixture may be used for many test cases, create as many test methods as are needed. When instantiating such a TestCase subclass, specify in the constructor arguments the name of the test method that the instance is to execute.
Test authors should subclass TestCase for their own tests. Construction and deconstruction of the test's environment ('fixture') can be implemented by overriding the 'setUp' and 'tearDown' methods respectively.
If it is necessary to override the __init__ method, the base class __init__ method must always be called. It is important that subclasses should not change the signature of their __init__ method, since instances of the classes are instantiated automatically by parts of the framework in order to be run.
When subclassing TestCase, you can set these attributes:
- failureException: determines which exception will be raised when the instance's assertion methods fail; test methods raising this exception will be deemed to have 'failed' rather than 'errored'.
- longMessage: determines whether long messages (including repr of objects used in assert methods) will be printed on failure in addition to any explicit message passed.
- maxDiff: sets the maximum length of a diff in failure messages by assert methods using difflib. It is looked up as an instance attribute so can be configured by individual tests if required.
def __init__(
    self,
    body: t.Dict[str, t.Any],
    test_name: str,
    model: Model,
    models: UniqueKeyDict[str, Model],
    engine_adapter: EngineAdapter,
    dialect: str | None = None,
    path: Path | None = None,
    preserve_fixtures: bool = False,
    default_catalog: str | None = None,
    concurrency: bool = False,
    verbosity: Verbosity = Verbosity.DEFAULT,
) -> None:
    """PythonModelTest encapsulates a unit test for a Python model.

    Args:
        body: A dictionary that contains test metadata like inputs and outputs.
        test_name: The name of the test.
        model: The Python model that is being tested.
        models: All models to use for expansion and mapping of physical locations.
        engine_adapter: The engine adapter to use.
        dialect: The models' dialect, used for normalization purposes.
        path: An optional path to the test definition yaml file.
        preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
        default_catalog: The default catalog, forwarded to the base test and to the
            execution context.
        concurrency: Forwarded unchanged to the base ModelTest constructor.
        verbosity: Forwarded unchanged to the base ModelTest constructor.
    """
    from sqlmesh.core.test.context import TestExecutionContext

    super().__init__(
        body=body,
        test_name=test_name,
        model=model,
        models=models,
        engine_adapter=engine_adapter,
        dialect=dialect,
        path=path,
        preserve_fixtures=preserve_fixtures,
        default_catalog=default_catalog,
        concurrency=concurrency,
        verbosity=verbosity,
    )

    # Execution context handed to the Python model when it is rendered
    execution_context = TestExecutionContext(
        engine_adapter=engine_adapter,
        models=models,
        test=self,
        default_dialect=dialect,
        default_catalog=default_catalog,
    )
    self.context = execution_context
PythonModelTest encapsulates a unit test for a Python model.
Arguments:
- body: A dictionary that contains test metadata like inputs and outputs.
- test_name: The name of the test.
- model: The Python model that is being tested.
- models: All models to use for expansion and mapping of physical locations.
- engine_adapter: The engine adapter to use.
- dialect: The models' dialect, used for normalization purposes.
- path: An optional path to the test definition yaml file.
- preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
def runTest(self) -> None:
    """Execute the Python model and compare its output to the expected query output.

    Does nothing when the test body defines no ``outputs.query``.
    """
    query_spec = self.body["outputs"].get("query")
    if query_spec is None:
        return

    is_partial = query_spec.get("partial")

    result = self._execute_model()
    # Normalize to a 0..n-1 index before comparison
    result.reset_index(drop=True, inplace=True)
    expected_df = self._create_df(
        query_spec, columns=self.model.columns_to_types, partial=is_partial
    )

    self.assert_equal(expected_df, result, sort=True, partial=is_partial)
Inherited Members
- ModelTest
- CONCURRENT_RENDER_LOCK
- body
- test_name
- model
- models
- engine_adapter
- path
- preserve_fixtures
- default_catalog
- dialect
- concurrency
- verbosity
- defaultTestResult
- shortDescription
- setUp
- tearDown
- assert_equal
- path_relative_to
- create_test
- unittest.case.TestCase
- failureException
- longMessage
- maxDiff
- addTypeEqualityFunc
- addCleanup
- addClassCleanup
- setUpClass
- tearDownClass
- countTestCases
- id
- subTest
- run
- doCleanups
- doClassCleanups
- debug
- skipTest
- fail
- assertFalse
- assertTrue
- assertRaises
- assertWarns
- assertLogs
- assertNoLogs
- assertEqual
- assertNotEqual
- assertAlmostEqual
- assertNotAlmostEqual
- assertSequenceEqual
- assertListEqual
- assertTupleEqual
- assertSetEqual
- assertIn
- assertNotIn
- assertIs
- assertIsNot
- assertDictEqual
- assertDictContainsSubset
- assertCountEqual
- assertMultiLineEqual
- assertLess
- assertLessEqual
- assertGreater
- assertGreaterEqual
- assertIsNone
- assertIsNotNone
- assertIsInstance
- assertNotIsInstance
- assertRaisesRegex
- assertWarnsRegex
- assertRegex
- assertNotRegex
- failUnlessRaises
- failIf
- assertRaisesRegexp
- assertRegexpMatches
- assertNotRegexpMatches
- failUnlessEqual
- assertEquals
- failIfEqual
- assertNotEquals
- failUnlessAlmostEqual
- assertAlmostEquals
- failIfAlmostEqual
- assertNotAlmostEquals
- failUnless
- assert_
def generate_test(
    model: Model,
    input_queries: t.Dict[str, str],
    models: UniqueKeyDict[str, Model],
    engine_adapter: EngineAdapter,
    test_engine_adapter: EngineAdapter,
    project_path: Path,
    overwrite: bool = False,
    variables: t.Optional[t.Dict[str, str]] = None,
    path: t.Optional[str] = None,
    name: t.Optional[str] = None,
    include_ctes: bool = False,
) -> None:
    """Generate a unit test fixture for a given model.

    Args:
        model: The model to test.
        input_queries: Mapping of model names to queries. Each model included in this mapping
            will be populated in the test based on the results of the corresponding query.
        models: The context's models.
        engine_adapter: The target engine adapter.
        test_engine_adapter: The test engine adapter.
        project_path: The path pointing to the project's root directory.
        overwrite: Whether to overwrite the existing test in case of a file path collision.
            When set to False, an error will be raised if there is such a collision.
        variables: Key-value pairs that will define variables needed by the model.
        path: The file path corresponding to the fixture, relative to the test directory.
            By default, the fixture will be created under the test directory and the file name
            will be inferred from the test's name.
        name: The name of the test. This is inferred from the model name by default.
        include_ctes: When true, CTE fixtures will also be generated.

    Raises:
        ConfigError: If the fixture file already exists and ``overwrite`` is False.
    """
    import numpy as np

    test_name = name or f"test_{model.view_name}"
    path = path or f"{test_name}.yaml"

    extension = path.split(".")[-1].lower()
    if extension not in ("yaml", "yml"):
        path = f"{path}.yaml"

    fixture_path = project_path / c.TESTS / path
    if not overwrite and fixture_path.exists():
        raise ConfigError(
            f"Fixture '{fixture_path}' already exists, make sure to set --overwrite if it can be safely overwritten."
        )

    # ruamel.yaml does not support pandas Timestamps, so we must convert them to python
    # datetime or datetime.date objects based on column type
    inputs = {
        dep: pandas_timestamp_to_pydatetime(
            engine_adapter.fetchdf(query).apply(lambda col: col.map(_normalize_df_value)),
            models[dep].columns_to_types,
        )
        .replace({np.nan: None})
        .to_dict(orient="records")
        for dep, query in input_queries.items()
    }
    outputs: t.Dict[str, t.Any] = {"query": {}}
    variables = variables or {}
    test_body = {"model": model.fqn, "inputs": inputs, "outputs": outputs}

    if variables:
        test_body["vars"] = variables

    test = ModelTest.create_test(
        body=test_body.copy(),
        test_name=test_name,
        models=models,
        engine_adapter=test_engine_adapter,
        dialect=model.dialect,
        path=fixture_path,
        default_catalog=model.default_catalog,
    )
    if not test:
        return

    test.setUp()
    try:
        if isinstance(model, SqlModel):
            assert isinstance(test, SqlModelTest)
            model_query = test._render_model_query()
            with_clause = model_query.args.get("with_")

            if with_clause and include_ctes:
                ctes = {}
                recursive = with_clause.recursive
                previous_ctes: t.List[exp.CTE] = []

                for cte in model_query.ctes:
                    cte_query = cte.this
                    cte_identifier = cte.args["alias"].this

                    cte_query = exp.select(*_projection_identifiers(cte_query)).from_(
                        cte_identifier
                    )

                    # A CTE may only reference itself and the CTEs defined before it
                    for prev in chain(previous_ctes, [cte]):
                        cte_query = cte_query.with_(
                            prev.args["alias"].this, prev.this, recursive=recursive
                        )

                    cte_output = test._execute(cte_query)
                    ctes[cte.alias] = (
                        pandas_timestamp_to_pydatetime(
                            df=cte_output.apply(lambda col: col.map(_normalize_df_value)),
                        )
                        .replace({np.nan: None})
                        .to_dict(orient="records")
                    )

                    previous_ctes.append(cte)

                if ctes:
                    outputs["ctes"] = ctes

            output = test._execute(model_query)
        else:
            output = t.cast(PythonModelTest, test)._execute_model()

        outputs["query"] = (
            pandas_timestamp_to_pydatetime(
                output.apply(lambda col: col.map(_normalize_df_value)), model.columns_to_types
            )
            .replace({np.nan: None})
            .to_dict(orient="records")
        )
    finally:
        # Always run setUp's counterpart, even when rendering or executing the model fails
        test.tearDown()

    fixture_path.parent.mkdir(exist_ok=True, parents=True)
    with open(fixture_path, "w", encoding="utf-8") as file:
        yaml.dump({test_name: test_body}, file)
Generate a unit test fixture for a given model.
Arguments:
- model: The model to test.
- input_queries: Mapping of model names to queries. Each model included in this mapping will be populated in the test based on the results of the corresponding query.
- models: The context's models.
- engine_adapter: The target engine adapter.
- test_engine_adapter: The test engine adapter.
- project_path: The path pointing to the project's root directory.
- overwrite: Whether to overwrite the existing test in case of a file path collision. When set to False, an error will be raised if there is such a collision.
- variables: Key-value pairs that will define variables needed by the model.
- path: The file path corresponding to the fixture, relative to the test directory. By default, the fixture will be created under the test directory and the file name will be inferred from the test's name.
- name: The name of the test. This is inferred from the model name by default.
- include_ctes: When true, CTE fixtures will also be generated.