Edit on GitHub

sqlmesh.core.test.definition

   1from __future__ import annotations
   2
   3import sys
   4
   5import datetime
   6import threading
   7import typing as t
   8import unittest
   9from collections import Counter
  10from contextlib import nullcontext, contextmanager, AbstractContextManager
  11from itertools import chain
  12from pathlib import Path
  13from unittest.mock import patch
  14
  15
  16from io import StringIO
  17from sqlglot import Dialect, exp
  18from sqlglot.optimizer.annotate_types import annotate_types
  19from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
  20
  21from sqlmesh.core import constants as c
  22from sqlmesh.core.dialect import normalize_model_name, schema_
  23from sqlmesh.core.engine_adapter import EngineAdapter
  24from sqlmesh.core.macros import RuntimeStage
  25from sqlmesh.core.model import Model, PythonModel, SqlModel
  26from sqlmesh.utils import UniqueKeyDict, random_id, type_is_known, yaml
  27from sqlmesh.utils.date import date_dict, pandas_timestamp_to_pydatetime, to_datetime
  28from sqlmesh.utils.errors import ConfigError, TestError
  29from sqlmesh.utils.yaml import load as yaml_load
  30from sqlmesh.utils import Verbosity
  31from sqlmesh.utils.rich import df_to_table
  32
  33if t.TYPE_CHECKING:
  34    import pandas as pd
  35
  36    from sqlglot.dialects.dialect import DialectType
  37
  38    Row = t.Dict[str, t.Any]
  39
  40
  41TIME_KWARG_KEYS = {
  42    "start",
  43    "end",
  44    "execution_time",
  45    "latest",
  46    # all built-in datetime macro var names
  47    *date_dict(execution_time="1970-01-01", start="1970-01-01", end="1970-01-01").keys(),
  48}
  49
  50
  51class ModelTest(unittest.TestCase):
  52    __test__ = False
  53
  54    CONCURRENT_RENDER_LOCK = threading.Lock()
  55
  56    def __init__(
  57        self,
  58        body: t.Dict[str, t.Any],
  59        test_name: str,
  60        model: Model,
  61        models: UniqueKeyDict[str, Model],
  62        engine_adapter: EngineAdapter,
  63        dialect: str | None = None,
  64        path: Path | None = None,
  65        preserve_fixtures: bool = False,
  66        default_catalog: str | None = None,
  67        concurrency: bool = False,
  68        verbosity: Verbosity = Verbosity.DEFAULT,
  69    ) -> None:
  70        """ModelTest encapsulates a unit test for a model.
  71
  72        Args:
  73            body: A dictionary that contains test metadata like inputs and outputs.
  74            test_name: The name of the test.
  75            model: The model that is being tested.
  76            models: All models to use for expansion and mapping of physical locations.
  77            engine_adapter: The engine adapter to use.
  78            dialect: The models' dialect, used for normalization purposes.
  79            path: An optional path to the test definition yaml file.
  80            preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
  81        """
  82        self.body = body
  83        self.test_name = test_name
  84        self.model = model
  85        self.models = models
  86        self.engine_adapter = engine_adapter
  87        self.path = path
  88        self.preserve_fixtures = preserve_fixtures
  89        self.default_catalog = default_catalog
  90        self.dialect = dialect
  91        self.concurrency = concurrency
  92        self.verbosity = verbosity
  93
  94        self._fixture_table_cache: t.Dict[str, exp.Table] = {}
  95        self._normalized_column_name_cache: t.Dict[str, str] = {}
  96        self._normalized_model_name_cache: t.Dict[t.Tuple[str, bool], str] = {}
  97
  98        self._test_adapter_dialect = Dialect.get_or_raise(self.engine_adapter.dialect)
  99
 100        self._validate_and_normalize_test()
 101
 102        if self.engine_adapter.default_catalog:
 103            self._fixture_catalog: t.Optional[exp.Identifier] = normalize_identifiers(
 104                exp.parse_identifier(
 105                    self.engine_adapter.default_catalog, dialect=self._test_adapter_dialect
 106                ),
 107                dialect=self._test_adapter_dialect,
 108            )
 109        else:
 110            self._fixture_catalog = None
 111
 112        # The test schema name is randomized to avoid concurrency issues,
 113        # unless a schema is provided in the unit tests's body
 114        self._fixture_schema = exp.parse_identifier(
 115            self.body.get("schema") or f"sqlmesh_test_{random_id(short=True)}"
 116        )
 117        self._qualified_fixture_schema = schema_(self._fixture_schema, self._fixture_catalog)
 118
 119        self._transforms = self._test_adapter_dialect.generator_class.TRANSFORMS
 120        self._execution_time = str(self.body.get("vars", {}).get("execution_time") or "")
 121
 122        if self._execution_time:
 123            # Normalizes the execution time by converting it into UTC timezone
 124            self._execution_time = str(to_datetime(self._execution_time))
 125
 126        # When execution_time is set, we mock the CURRENT_* SQL expressions so they always return it
 127        if self._execution_time:
 128            exec_time = exp.Literal.string(self._execution_time)
 129            self._transforms = {
 130                **self._transforms,
 131                exp.CurrentDate: lambda self, _: self.sql(
 132                    exp.cast(exec_time, "date", dialect=dialect)
 133                ),
 134                exp.CurrentDatetime: lambda self, _: self.sql(
 135                    exp.cast(exec_time, "datetime", dialect=dialect)
 136                ),
 137                exp.CurrentTime: lambda self, _: self.sql(
 138                    exp.cast(exec_time, "time", dialect=dialect)
 139                ),
 140                exp.CurrentTimestamp: lambda self, _: self.sql(
 141                    exp.cast(exec_time, "timestamp", dialect=dialect)
 142                ),
 143            }
 144
 145        super().__init__()
 146
 147    def defaultTestResult(self) -> unittest.TestResult:
 148        from sqlmesh.core.test.result import ModelTextTestResult
 149
 150        return ModelTextTestResult(stream=sys.stdout, descriptions=True, verbosity=self.verbosity)
 151
 152    def shortDescription(self) -> t.Optional[str]:
 153        return self.body.get("description")
 154
 155    def setUp(self) -> None:
 156        """Load all input tables"""
 157        import pandas as pd
 158        import numpy as np
 159
 160        self.engine_adapter.create_schema(self._qualified_fixture_schema)
 161
 162        for name, values in self.body.get("inputs", {}).items():
 163            all_types_are_known = False
 164            columns_to_known_types: t.Dict[str, exp.DataType] = {}
 165
 166            model = self.models.get(name)
 167            if model:
 168                inferred_columns_to_types = model.columns_to_types or {}
 169                columns_to_known_types = {
 170                    c: t for c, t in inferred_columns_to_types.items() if type_is_known(t)
 171                }
 172                all_types_are_known = bool(inferred_columns_to_types) and (
 173                    len(columns_to_known_types) == len(inferred_columns_to_types)
 174                )
 175
 176            # Types specified in the test will override the corresponding inferred ones
 177            columns_to_known_types.update(values.get("columns", {}))
 178
 179            rows = values.get("rows")
 180            if not all_types_are_known and rows:
 181                for col, value in rows[0].items():
 182                    if col not in columns_to_known_types:
 183                        v_type = annotate_types(exp.convert(value)).type or type(value).__name__
 184                        v_type = exp.maybe_parse(
 185                            v_type, into=exp.DataType, dialect=self._test_adapter_dialect
 186                        )
 187
 188                        if not type_is_known(v_type):
 189                            _raise_error(
 190                                f"Failed to infer the data type of column '{col}' for '{name}'. This issue can be "
 191                                "mitigated by casting the column in the model definition, setting its type in "
 192                                "external_models.yaml if it's an external model, setting the model's 'columns' property, "
 193                                "or setting its 'columns' mapping in the test itself",
 194                                self.path,
 195                            )
 196
 197                        columns_to_known_types[col] = v_type
 198
 199            if rows is None:
 200                query_or_df: exp.Query | pd.DataFrame = self._add_missing_columns(
 201                    values["query"], columns_to_known_types
 202                )
 203                if columns_to_known_types:
 204                    columns_to_known_types = {
 205                        col: columns_to_known_types[col] for col in query_or_df.named_selects
 206                    }
 207            else:
 208                query_or_df = self._create_df(values, columns=columns_to_known_types)
 209
 210            # Convert NaN/NaT values to None if DataFrame
 211            if isinstance(query_or_df, pd.DataFrame):
 212                query_or_df = query_or_df.replace({np.nan: None})
 213
 214            self.engine_adapter.create_view(
 215                self._test_fixture_table(name), query_or_df, columns_to_known_types
 216            )
 217
 218    def tearDown(self) -> None:
 219        """Drop all fixture tables."""
 220        if not self.preserve_fixtures:
 221            self.engine_adapter.drop_schema(self._qualified_fixture_schema, cascade=True)
 222
 223    def assert_equal(
 224        self,
 225        expected: pd.DataFrame,
 226        actual: pd.DataFrame,
 227        sort: bool,
 228        partial: t.Optional[bool] = False,
 229    ) -> None:
 230        """Compare two DataFrames"""
 231        import numpy as np
 232        import pandas as pd
 233        from pandas.api.types import is_object_dtype
 234
 235        if partial:
 236            intersection = actual[actual.columns.intersection(expected.columns)]
 237            if len(intersection.columns) > 0:
 238                actual = intersection
 239
 240        # Two astypes are necessary, pandas converts strings to times as NS,
 241        # but if the actual is US, it doesn't take effect until the 2nd try!
 242        actual_types = actual.dtypes.to_dict()
 243        expected = expected.astype(actual_types, errors="ignore").astype(
 244            actual_types, errors="ignore"
 245        )
 246
 247        # The `actual` df's dtypes will almost always be pd.Timestamp for datetime values,
 248        # but in some scenarios (e.g., DuckDB >=0.10.2) it will be a pandas `object` type
 249        # containing python `datetime.xxx` values.
 250        #
 251        # Pandas `object` columns result in a noop for the `astype` call above. Because any
 252        # quoted YAML value is a string, we must manually convert the `expected` df string
 253        # values to the correct `datetime.xxx` type.
 254        #
 255        # We determine the type from a single sentinel value, but since the `actual` df is
 256        # coming from a database query, it is safe to assume that the column contains only
 257        # a single type.
 258        object_sentinel_values = {
 259            col: actual[col][0]
 260            for col in actual_types
 261            if is_object_dtype(actual_types[col]) and len(actual[col]) != 0
 262        }
 263        for col, value in object_sentinel_values.items():
 264            try:
 265                # can't use `isinstance()` here - https://stackoverflow.com/a/68743663/1707525
 266                if type(value) is datetime.date:
 267                    expected[col] = pd.to_datetime(expected[col]).dt.date
 268                elif type(value) is datetime.time:
 269                    expected[col] = pd.to_datetime(expected[col]).dt.time
 270                elif type(value) is datetime.datetime:
 271                    expected[col] = pd.to_datetime(expected[col]).dt.to_pydatetime()
 272            except Exception as e:
 273                from sqlmesh.core.console import get_console
 274
 275                get_console().log_warning(
 276                    f"Failed to convert expected value for {col} into `datetime` "
 277                    f"for unit test '{str(self)}'. {str(e)}."
 278                )
 279
 280        actual = actual.replace({np.nan: None})
 281        expected = expected.replace({np.nan: None})
 282
 283        # We define this here to avoid a top-level import of numpy and pandas
 284        DATETIME_TYPES = (
 285            datetime.datetime,
 286            datetime.date,
 287            datetime.time,
 288            np.datetime64,
 289            pd.Timestamp,
 290        )
 291
 292        def _to_hashable(x: t.Any) -> t.Any:
 293            if isinstance(x, (list, np.ndarray)):
 294                return tuple(_to_hashable(v) for v in x)
 295            if isinstance(x, dict):
 296                return tuple((k, _to_hashable(v)) for k, v in x.items())
 297            return str(x) if isinstance(x, DATETIME_TYPES) or not isinstance(x, t.Hashable) else x
 298
 299        actual = actual.apply(lambda col: col.map(_to_hashable))
 300        expected = expected.apply(lambda col: col.map(_to_hashable))
 301
 302        if sort:
 303            actual = actual.sort_values(by=actual.columns.to_list()).reset_index(drop=True)
 304            expected = expected.sort_values(by=expected.columns.to_list()).reset_index(drop=True)
 305
 306        try:
 307            pd.testing.assert_frame_equal(
 308                expected,
 309                actual,
 310                check_dtype=False,
 311                check_datetimelike_compat=True,
 312                check_like=True,  # Ignore column order
 313            )
 314        except AssertionError as e:
 315            # There are 2 concepts at play here:
 316            # 1. The Exception args will contain the error message plus the diff dataframe table stringified
 317            #    (backwards compatibility with existing tests, possible to serialize/send over network etc)
 318            # 2. Each test will also transform these diff dataframes into Rich tables, which will be the ones that'll
 319            #    be surfaced to the user through Console for better UX (versus stringified dataframes)
 320            #
 321            # This is a bit of a hack, but it's a way to get the best of both worlds.
 322            args: t.List[t.Any] = []
 323
 324            failed_subtest = ""
 325
 326            if subtest := getattr(self, "_subtest", None):
 327                if cte := subtest.params.get("cte"):
 328                    failed_subtest = f" (CTE {cte})"
 329
 330            if expected.shape != actual.shape:
 331                _raise_if_unexpected_columns(expected.columns, actual.columns)
 332
 333                args.append("Data mismatch (rows are different)")
 334
 335                missing_rows = _row_difference(expected, actual)
 336                if not missing_rows.empty:
 337                    args[0] += f"\n\nMissing rows:\n\n{missing_rows}"
 338                    args.append(df_to_table(f"Missing rows{failed_subtest}", missing_rows))
 339
 340                unexpected_rows = _row_difference(actual, expected)
 341
 342                if not unexpected_rows.empty:
 343                    args[0] += f"\n\nUnexpected rows:\n\n{unexpected_rows}"
 344                    args.append(df_to_table(f"Unexpected rows{failed_subtest}", unexpected_rows))
 345
 346            else:
 347                diff = expected.compare(actual).rename(columns={"self": "exp", "other": "act"})
 348
 349                args.append(f"Data mismatch (exp: expected, act: actual)\n\n{diff}")
 350
 351                diff.rename(columns={"exp": "Expected", "act": "Actual"}, inplace=True)
 352                if self.verbosity == Verbosity.DEFAULT:
 353                    args.extend(
 354                        df_to_table(f"Data mismatch{failed_subtest}", df)
 355                        for df in _split_df_by_column_pairs(diff)
 356                    )
 357                else:
 358                    from pandas import DataFrame, MultiIndex
 359
 360                    levels = t.cast(MultiIndex, diff.columns).levels[0]
 361                    for col in levels:
 362                        # diff[col] returns a DataFrame when columns is a MultiIndex
 363                        col_diff = t.cast(DataFrame, diff[col])
 364                        if not col_diff.empty:
 365                            table = df_to_table(
 366                                f"[bold red]Column '{col}' mismatch{failed_subtest}[/bold red]",
 367                                col_diff,
 368                            )
 369                            args.append(table)
 370
 371            e.args = (*args,)
 372
 373            raise e
 374
 375    def runTest(self) -> None:
 376        raise NotImplementedError
 377
 378    def path_relative_to(self, other: Path) -> Path | None:
 379        """Compute a version of this test's path relative to the `other` path"""
 380        return self.path.relative_to(other) if self.path else None
 381
 382    @staticmethod
 383    def create_test(
 384        body: t.Dict[str, t.Any],
 385        test_name: str,
 386        models: UniqueKeyDict[str, Model],
 387        engine_adapter: EngineAdapter,
 388        dialect: str | None,
 389        path: Path | None,
 390        preserve_fixtures: bool = False,
 391        default_catalog: str | None = None,
 392        concurrency: bool = False,
 393        verbosity: Verbosity = Verbosity.DEFAULT,
 394    ) -> t.Optional[ModelTest]:
 395        """Create a SqlModelTest or a PythonModelTest.
 396
 397        Args:
 398            body: A dictionary that contains test metadata like inputs and outputs.
 399            test_name: The name of the test.
 400            models: All models to use for expansion and mapping of physical locations.
 401            engine_adapter: The engine adapter to use.
 402            dialect: The models' dialect, used for normalization purposes.
 403            path: An optional path to the test definition yaml file.
 404            preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
 405        """
 406        name = body.get("model")
 407        if name is None:
 408            _raise_error("Missing required 'model' field", path)
 409
 410        name = normalize_model_name(name, default_catalog=default_catalog, dialect=dialect)
 411        model = models.get(name)
 412        if not model:
 413            from sqlmesh.core.console import get_console
 414
 415            get_console().log_warning(
 416                f"Model '{name}' was not found{' at ' + str(path) if path else ''}"
 417            )
 418            return None
 419
 420        if isinstance(model, SqlModel):
 421            test_type: t.Type[ModelTest] = SqlModelTest
 422        elif isinstance(model, PythonModel):
 423            test_type = PythonModelTest
 424        else:
 425            _raise_error(f"Model '{name}' is an unsupported model type for testing", path)
 426
 427        try:
 428            return test_type(
 429                body,
 430                test_name,
 431                t.cast(Model, model),
 432                models,
 433                engine_adapter,
 434                dialect,
 435                path,
 436                preserve_fixtures,
 437                default_catalog,
 438                concurrency,
 439                verbosity,
 440            )
 441        except Exception as e:
 442            raise TestError(f"Failed to create test {test_name} ({path})\n{str(e)}")
 443
 444    def __str__(self) -> str:
 445        return f"{self.test_name} ({self.path})"
 446
 447    def _validate_and_normalize_test(self) -> None:
 448        inputs = self.body.get("inputs")
 449        outputs = self.body.get("outputs", {})
 450
 451        if not outputs:
 452            _raise_error("Incomplete test, missing outputs", self.path)
 453
 454        ctes = outputs.get("ctes")
 455        query = outputs.get("query")
 456        partial = outputs.pop("partial", None)
 457
 458        if ctes is None and query is None:
 459            _raise_error("Incomplete test, outputs must contain 'query' or 'ctes'", self.path)
 460
 461        def _normalize_rows(
 462            values: t.List[Row] | t.Dict,
 463            name: str,
 464            partial: bool = False,
 465            dialect: DialectType = None,
 466        ) -> t.Dict:
 467            import pandas as pd
 468
 469            if not isinstance(values, dict):
 470                values = {"rows": values}
 471
 472            rows = values.get("rows")
 473            query = values.get("query")
 474
 475            fmt = values.get("format")
 476            path = values.get("path")
 477            if fmt == "csv":
 478                csv_settings = values.get("csv_settings") or {}
 479                rows = pd.read_csv(path or StringIO(rows), **csv_settings).to_dict(orient="records")
 480            elif fmt in (None, "yaml"):
 481                if path:
 482                    input_rows = yaml_load(Path(path))
 483                    rows = input_rows.get("rows") if isinstance(input_rows, dict) else input_rows
 484            else:
 485                _raise_error(f"Unsupported data format '{fmt}' for '{name}'", self.path)
 486
 487            if query is not None:
 488                if rows is not None:
 489                    _raise_error(
 490                        f"Invalid test, cannot set both 'query' and 'rows' for '{name}'", self.path
 491                    )
 492
 493                # We parse the user-supplied query using the testing adapter dialect, but we
 494                # normalize its identifiers according to the model's dialect, so that, e.g.,
 495                # the projection names match those in its `columns_to_types` field
 496                values["query"] = normalize_identifiers(
 497                    exp.maybe_parse(query, dialect=self._test_adapter_dialect), dialect=dialect
 498                )
 499                return values
 500
 501            if rows is None:
 502                _raise_error(f"Incomplete test, missing row data for '{name}'", self.path)
 503
 504            assert isinstance(rows, list)
 505            values["rows"] = [
 506                {self._normalize_column_name(column): value for column, value in row.items()}
 507                for row in rows
 508            ]
 509            if partial:
 510                values["partial"] = True
 511
 512            return values
 513
 514        def _normalize_sources(
 515            sources: t.Dict, partial: bool = False, with_default_catalog: bool = True
 516        ) -> t.Dict:
 517            normalized_sources = {}
 518            for name, values in sources.items():
 519                normalized_name = self._normalize_model_name(
 520                    name, with_default_catalog=with_default_catalog
 521                )
 522                model = self.models.get(normalized_name)
 523                dialect = model.dialect if model else self.dialect
 524
 525                normalized_sources[normalized_name] = _normalize_rows(
 526                    values, name, partial=partial, dialect=dialect
 527                )
 528
 529            return normalized_sources
 530
 531        normalized_model_name = self._normalize_model_name(self.body["model"])
 532        self.body["model"] = normalized_model_name
 533
 534        if inputs:
 535            inputs = _normalize_sources(inputs)
 536            for name, values in inputs.items():
 537                columns = values.get("columns")
 538                if columns is None:
 539                    continue
 540
 541                if not isinstance(columns, dict):
 542                    _raise_error(
 543                        f"Invalid 'columns' value for model '{name}', expected a mapping name -> type",
 544                        self.path,
 545                    )
 546
 547                values["columns"] = {
 548                    self._normalize_column_name(c): exp.DataType.build(
 549                        t, dialect=self._test_adapter_dialect
 550                    )
 551                    for c, t in columns.items()
 552                }
 553
 554            for depends_on in self.model.depends_on:
 555                if depends_on not in inputs:
 556                    _raise_error(f"Incomplete test, missing input model '{depends_on}'", self.path)
 557
 558            if self.model.depends_on_self and normalized_model_name not in inputs:
 559                inputs[normalized_model_name] = {"rows": []}
 560
 561            self.body["inputs"] = inputs
 562
 563        if ctes:
 564            outputs["ctes"] = _normalize_sources(ctes, partial=partial, with_default_catalog=False)
 565
 566        if query or query == []:
 567            outputs["query"] = _normalize_rows(
 568                query, self.model.name, partial=partial, dialect=self.model.dialect
 569            )
 570
 571    def _test_fixture_table(self, name: str) -> exp.Table:
 572        table = self._fixture_table_cache.get(name)
 573        if not table:
 574            table = exp.to_table(name, dialect=self._test_adapter_dialect)
 575
 576            # We change the table path below, so this ensures there are no name clashes
 577            table.this.set("this", "__".join(part.name for part in table.parts))
 578
 579            table.set("db", self._fixture_schema.copy())
 580            if self._fixture_catalog:
 581                table.set("catalog", self._fixture_catalog.copy())
 582
 583            self._fixture_table_cache[name] = table
 584
 585        return table
 586
 587    def _normalize_model_name(self, name: str, with_default_catalog: bool = True) -> str:
 588        normalized_name = self._normalized_model_name_cache.get((name, with_default_catalog))
 589        if normalized_name is None:
 590            default_catalog = self.default_catalog if with_default_catalog else None
 591            normalized_name = normalize_model_name(
 592                name, default_catalog=default_catalog, dialect=self.dialect
 593            )
 594            self._normalized_model_name_cache[(name, with_default_catalog)] = normalized_name
 595
 596        return normalized_name
 597
 598    def _normalize_column_name(self, name: str) -> str:
 599        normalized_name = self._normalized_column_name_cache.get(name)
 600        if normalized_name is None:
 601            normalized_name = normalize_identifiers(name, dialect=self.dialect).name
 602            self._normalized_column_name_cache[name] = normalized_name
 603
 604        return normalized_name
 605
 606    @contextmanager
 607    def _concurrent_render_context(self) -> t.Iterator[None]:
 608        """
 609        Context manager that ensures that the tests are executed safely in a concurrent environment.
 610        This is needed in case `execution_time` is set, as we'd then have to:
 611        - Freeze time through `time_machine` (not thread safe)
 612        - Globally patch the SQLGlot dialect so that any date/time nodes are evaluated at the `execution_time` during generation
 613        """
 614        import time_machine
 615
 616        lock_ctx: AbstractContextManager = (
 617            self.CONCURRENT_RENDER_LOCK if self.concurrency else nullcontext()
 618        )
 619        time_ctx: AbstractContextManager = nullcontext()
 620        dialect_patch_ctx: AbstractContextManager = nullcontext()
 621
 622        if self._execution_time:
 623            time_ctx = time_machine.travel(self._execution_time, tick=False)
 624            dialect_patch_ctx = patch.dict(
 625                self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms
 626            )
 627
 628        with lock_ctx, time_ctx, dialect_patch_ctx:
 629            yield
 630
 631    def _execute(self, query: exp.Query | str) -> pd.DataFrame:
 632        """Executes the given query using the testing engine adapter and returns a DataFrame."""
 633        return self.engine_adapter.fetchdf(query)
 634
 635    def _create_df(
 636        self,
 637        values: t.Dict[str, t.Any],
 638        columns: t.Optional[t.Collection] = None,
 639        partial: t.Optional[bool] = False,
 640    ) -> pd.DataFrame:
 641        import pandas as pd
 642
 643        query = values.get("query")
 644        if query:
 645            if not partial:
 646                query = self._add_missing_columns(query, columns)
 647
 648            return self._execute(query)
 649
 650        rows = values["rows"]
 651        columns_str: t.Optional[t.List[str]] = None
 652        if columns:
 653            columns_str = [str(c) for c in columns]
 654            referenced_columns = list(dict.fromkeys(col for row in rows for col in row))
 655            _raise_if_unexpected_columns(columns, referenced_columns)
 656
 657            if partial:
 658                columns_str = [c for c in columns_str if c in referenced_columns]
 659
 660        return pd.DataFrame.from_records(rows, columns=columns_str)
 661
 662    def _add_missing_columns(
 663        self, query: exp.Query, all_columns: t.Optional[t.Collection[str]] = None
 664    ) -> exp.Query:
 665        if not all_columns or query.is_star:
 666            return query
 667
 668        query_columns = set(query.named_selects)
 669        missing_columns = [col for col in all_columns if col not in query_columns]
 670        if missing_columns:
 671            query.select(*[exp.null().as_(col) for col in missing_columns], copy=False)
 672
 673        return query
 674
 675
 676class SqlModelTest(ModelTest):
 677    def test_ctes(self, ctes: t.Dict[str, exp.Expr], recursive: bool = False) -> None:
 678        """Run CTE queries and compare output to expected output"""
 679        for cte_name, values in self.body["outputs"].get("ctes", {}).items():
 680            with self.subTest(cte=cte_name):
 681                if cte_name not in ctes:
 682                    _raise_error(
 683                        f"No CTE named {cte_name} found in model {self.model.name}", self.path
 684                    )
 685
 686                cte_query = ctes[cte_name].this
 687
 688                sort = cte_query.args.get("order") is None
 689                partial = values.get("partial")
 690
 691                cte_query = exp.select(*_projection_identifiers(cte_query)).from_(cte_name)
 692                for alias, cte in ctes.items():
 693                    cte_query = cte_query.with_(alias, cte.this, recursive=recursive)
 694
 695                with self._concurrent_render_context():
 696                    # Similar to the model's query, we render the CTE query under the locked context
 697                    # so that the execution (fetchdf) can continue concurrently between the threads
 698                    sql = cte_query.sql(
 699                        self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql
 700                    )
 701
 702                actual = self._execute(sql)
 703                expected = self._create_df(values, columns=cte_query.named_selects, partial=partial)
 704
 705                self.assert_equal(expected, actual, sort=sort, partial=partial)
 706
 707    def runTest(self) -> None:
 708        with self._concurrent_render_context():
 709            # Render the model's query and generate the SQL under the locked context so that
 710            # execution (fetchdf) can continue concurrently between the threads
 711            query = self._render_model_query()
 712            sql = query.sql(self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql)
 713
 714        with_clause = query.args.get("with_")
 715
 716        if with_clause:
 717            self.test_ctes(
 718                {
 719                    self._normalize_model_name(cte.alias, with_default_catalog=False): cte
 720                    for cte in query.ctes
 721                },
 722                recursive=with_clause.recursive,
 723            )
 724
 725        values = self.body["outputs"].get("query")
 726        if values is not None:
 727            partial = values.get("partial")
 728            sort = query.args.get("order") is None
 729
 730            actual = self._execute(sql)
 731            expected = self._create_df(values, columns=self.model.columns_to_types, partial=partial)
 732
 733            self.assert_equal(expected, actual, sort=sort, partial=partial)
 734
 735    def _render_model_query(self) -> exp.Query:
 736        variables = self.body.get("vars", {}).copy()
 737        time_kwargs = {key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables}
 738
 739        query = self.model.render_query_or_raise(
 740            **time_kwargs,
 741            variables=variables,
 742            engine_adapter=self.engine_adapter,
 743            table_mapping={
 744                name: self._test_fixture_table(name).sql() for name in self.body.get("inputs", {})
 745            },
 746            runtime_stage=RuntimeStage.TESTING,
 747        )
 748        return query
 749
 750
 751class PythonModelTest(ModelTest):
 752    def __init__(
 753        self,
 754        body: t.Dict[str, t.Any],
 755        test_name: str,
 756        model: Model,
 757        models: UniqueKeyDict[str, Model],
 758        engine_adapter: EngineAdapter,
 759        dialect: str | None = None,
 760        path: Path | None = None,
 761        preserve_fixtures: bool = False,
 762        default_catalog: str | None = None,
 763        concurrency: bool = False,
 764        verbosity: Verbosity = Verbosity.DEFAULT,
 765    ) -> None:
 766        """PythonModelTest encapsulates a unit test for a Python model.
 767
 768        Args:
 769            body: A dictionary that contains test metadata like inputs and outputs.
 770            test_name: The name of the test.
 771            model: The Python model that is being tested.
 772            models: All models to use for expansion and mapping of physical locations.
 773            engine_adapter: The engine adapter to use.
 774            dialect: The models' dialect, used for normalization purposes.
 775            path: An optional path to the test definition yaml file.
 776            preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
 777        """
 778        from sqlmesh.core.test.context import TestExecutionContext
 779
 780        super().__init__(
 781            body,
 782            test_name,
 783            model,
 784            models,
 785            engine_adapter,
 786            dialect,
 787            path,
 788            preserve_fixtures,
 789            default_catalog,
 790            concurrency,
 791            verbosity,
 792        )
 793
 794        self.context = TestExecutionContext(
 795            engine_adapter=engine_adapter,
 796            models=models,
 797            test=self,
 798            default_dialect=dialect,
 799            default_catalog=default_catalog,
 800        )
 801
 802    def runTest(self) -> None:
 803        values = self.body["outputs"].get("query")
 804        if values is not None:
 805            partial = values.get("partial")
 806
 807            actual_df = self._execute_model()
 808            actual_df.reset_index(drop=True, inplace=True)
 809            expected = self._create_df(values, columns=self.model.columns_to_types, partial=partial)
 810
 811            self.assert_equal(expected, actual_df, sort=True, partial=partial)
 812
 813    def _execute_model(self) -> pd.DataFrame:
 814        """Executes the python model and returns a DataFrame."""
 815        import pandas as pd
 816
 817        with self._concurrent_render_context():
 818            variables = self.body.get("vars", {}).copy()
 819            time_kwargs = {key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables}
 820            df = next(self.model.render(context=self.context, variables=variables, **time_kwargs))
 821
 822        assert not isinstance(df, exp.Expr)
 823        return df if isinstance(df, pd.DataFrame) else df.toPandas()
 824
 825
 826def generate_test(
 827    model: Model,
 828    input_queries: t.Dict[str, str],
 829    models: UniqueKeyDict[str, Model],
 830    engine_adapter: EngineAdapter,
 831    test_engine_adapter: EngineAdapter,
 832    project_path: Path,
 833    overwrite: bool = False,
 834    variables: t.Optional[t.Dict[str, str]] = None,
 835    path: t.Optional[str] = None,
 836    name: t.Optional[str] = None,
 837    include_ctes: bool = False,
 838) -> None:
 839    """Generate a unit test fixture for a given model.
 840
 841    Args:
 842        model: The model to test.
 843        input_queries: Mapping of model names to queries. Each model included in this mapping
 844            will be populated in the test based on the results of the corresponding query.
 845        models: The context's models.
 846        engine_adapter: The target engine adapter.
 847        test_engine_adapter: The test engine adapter.
 848        project_path: The path pointing to the project's root directory.
 849        overwrite: Whether to overwrite the existing test in case of a file path collision.
 850            When set to False, an error will be raised if there is such a collision.
 851        variables: Key-value pairs that will define variables needed by the model.
 852        path: The file path corresponding to the fixture, relative to the test directory.
 853            By default, the fixture will be created under the test directory and the file name
 854            will be inferred from the test's name.
 855        name: The name of the test. This is inferred from the model name by default.
 856        include_ctes: When true, CTE fixtures will also be generated.
 857    """
 858    import numpy as np
 859
 860    test_name = name or f"test_{model.view_name}"
 861    path = path or f"{test_name}.yaml"
 862
 863    extension = path.split(".")[-1].lower()
 864    if extension not in ("yaml", "yml"):
 865        path = f"{path}.yaml"
 866
 867    fixture_path = project_path / c.TESTS / path
 868    if not overwrite and fixture_path.exists():
 869        raise ConfigError(
 870            f"Fixture '{fixture_path}' already exists, make sure to set --overwrite if it can be safely overwritten."
 871        )
 872
 873    # ruamel.yaml does not support pandas Timestamps, so we must convert them to python
 874    # datetime or datetime.date objects based on column type
 875    inputs = {
 876        dep: pandas_timestamp_to_pydatetime(
 877            engine_adapter.fetchdf(query).apply(lambda col: col.map(_normalize_df_value)),
 878            models[dep].columns_to_types,
 879        )
 880        .replace({np.nan: None})
 881        .to_dict(orient="records")
 882        for dep, query in input_queries.items()
 883    }
 884    outputs: t.Dict[str, t.Any] = {"query": {}}
 885    variables = variables or {}
 886    test_body = {"model": model.fqn, "inputs": inputs, "outputs": outputs}
 887
 888    if variables:
 889        test_body["vars"] = variables
 890
 891    test = ModelTest.create_test(
 892        body=test_body.copy(),
 893        test_name=test_name,
 894        models=models,
 895        engine_adapter=test_engine_adapter,
 896        dialect=model.dialect,
 897        path=fixture_path,
 898        default_catalog=model.default_catalog,
 899    )
 900    if not test:
 901        return
 902
 903    test.setUp()
 904
 905    if isinstance(model, SqlModel):
 906        assert isinstance(test, SqlModelTest)
 907        model_query = test._render_model_query()
 908        with_clause = model_query.args.get("with_")
 909
 910        if with_clause and include_ctes:
 911            ctes = {}
 912            recursive = with_clause.recursive
 913            previous_ctes: t.List[exp.CTE] = []
 914
 915            for cte in model_query.ctes:
 916                cte_query = cte.this
 917                cte_identifier = cte.args["alias"].this
 918
 919                cte_query = exp.select(*_projection_identifiers(cte_query)).from_(cte_identifier)
 920
 921                for prev in chain(previous_ctes, [cte]):
 922                    cte_query = cte_query.with_(
 923                        prev.args["alias"].this, prev.this, recursive=recursive
 924                    )
 925
 926                cte_output = test._execute(cte_query)
 927                ctes[cte.alias] = (
 928                    pandas_timestamp_to_pydatetime(
 929                        df=cte_output.apply(lambda col: col.map(_normalize_df_value)),
 930                    )
 931                    .replace({np.nan: None})
 932                    .to_dict(orient="records")
 933                )
 934
 935                previous_ctes.append(cte)
 936
 937            if ctes:
 938                outputs["ctes"] = ctes
 939
 940        output = test._execute(model_query)
 941    else:
 942        output = t.cast(PythonModelTest, test)._execute_model()
 943
 944    outputs["query"] = (
 945        pandas_timestamp_to_pydatetime(
 946            output.apply(lambda col: col.map(_normalize_df_value)), model.columns_to_types
 947        )
 948        .replace({np.nan: None})
 949        .to_dict(orient="records")
 950    )
 951
 952    test.tearDown()
 953
 954    fixture_path.parent.mkdir(exist_ok=True, parents=True)
 955    with open(fixture_path, "w", encoding="utf-8") as file:
 956        yaml.dump({test_name: test_body}, file)
 957
 958
 959def _projection_identifiers(query: exp.Query) -> t.List[str | exp.Identifier]:
 960    identifiers: t.List[str | exp.Identifier] = []
 961    for select in query.selects:
 962        if isinstance(select, exp.Alias):
 963            identifiers.append(select.args["alias"])
 964        elif isinstance(select, exp.Column):
 965            identifiers.append(select.this)
 966        else:
 967            identifiers.append(select.output_name)
 968
 969    return identifiers
 970
 971
 972def _raise_if_unexpected_columns(
 973    expected_cols: t.Collection[str], actual_cols: t.Collection[str]
 974) -> None:
 975    unique_expected_cols = set(expected_cols)
 976    unknown_cols = [col for col in actual_cols if col not in unique_expected_cols]
 977
 978    if unknown_cols:
 979        expected = f"Expected column(s): {', '.join(list(expected_cols))}\n"
 980        unknown = f"Unknown column(s): {', '.join(unknown_cols)}"
 981        _raise_error(f"Detected unknown column(s)\n\n{expected}{unknown}")
 982
 983
 984def _row_difference(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
 985    """Returns all rows in `left` that don't appear in `right`."""
 986    import numpy as np
 987    import pandas as pd
 988
 989    rows_missing_from_right = []
 990
 991    # `None` replaces `np.nan` because `np.nan != np.nan` and this would affect the mapping lookup
 992    right_row_count: t.MutableMapping[t.Tuple, int] = Counter(
 993        right.replace({np.nan: None}).itertuples(index=False, name=None)
 994    )
 995    for left_row in left.replace({np.nan: None}).itertuples(index=False):
 996        left_row_tuple = tuple(left_row)
 997        if right_row_count[left_row_tuple] <= 0:
 998            rows_missing_from_right.append(left_row)
 999        else:
1000            right_row_count[left_row_tuple] -= 1
1001
1002    return pd.DataFrame(rows_missing_from_right)
1003
1004
def _raise_error(msg: str, path: Path | None = None) -> t.NoReturn:
    """Raise a TestError carrying `msg`, prefixed with the test file location when known.

    Annotated as `NoReturn` so type checkers treat code following a call as unreachable.

    Args:
        msg: The error details.
        path: The test definition file, included in the message when provided.

    Raises:
        TestError: Always.
    """
    if path:
        raise TestError(f"Failed to run test at {path}:\n{msg}")
    raise TestError(f"Failed to run test:\n{msg}")
1009
1010
1011def _normalize_df_value(value: t.Any) -> t.Any:
1012    """Normalize data in a pandas dataframe so ruamel and sqlglot can deal with it."""
1013    import numpy as np
1014
1015    if isinstance(value, (list, np.ndarray)):
1016        return [_normalize_df_value(v) for v in value]
1017    if isinstance(value, dict):
1018        if "key" in value and "value" in value:
1019            # Maps returned by DuckDB look like: {'key': ['key1', 'key2'], 'value': [10, 20]}
1020            # so we convert to {'key1': 10, 'key2': 20} (TODO: handle more dialects here)
1021            return {k: _normalize_df_value(v) for k, v in zip(value["key"], value["value"])}
1022        return {k: _normalize_df_value(v) for k, v in value.items()}
1023    return value
1024
1025
1026def _split_df_by_column_pairs(df: pd.DataFrame, pairs_per_chunk: int = 4) -> t.List[pd.DataFrame]:
1027    """Split a dataframe into chunks of column pairs.
1028
1029    Args:
1030        df: The dataframe to split
1031        pairs_per_chunk: Number of column pairs per chunk (default: 4)
1032
1033    Returns:
1034        List of dataframes, each containing an even number of columns
1035    """
1036    total_columns = len(df.columns)
1037
1038    # If we have fewer columns than pairs_per_chunk * 2, return the original df
1039    if total_columns <= pairs_per_chunk * 2:
1040        return [df]
1041
1042    # Calculate number of chunks needed to split columns evenly
1043    num_chunks = (total_columns + (pairs_per_chunk * 2 - 1)) // (pairs_per_chunk * 2)
1044
1045    # Calculate columns per chunk to ensure equal distribution
1046    # We round down to nearest even number to ensure each chunk has even columns
1047    columns_per_chunk = (total_columns // num_chunks) & ~1  # Round down to nearest even number
1048    remainder = total_columns - (columns_per_chunk * num_chunks)
1049
1050    chunks = []
1051    start_idx = 0
1052
1053    # Distribute columns evenly across chunks
1054    for i in range(num_chunks):
1055        # Add 2 columns to early chunks if there's a remainder
1056        # This ensures we always add pairs of columns
1057        extra = 2 if i < remainder // 2 else 0
1058        end_idx = start_idx + columns_per_chunk + extra
1059        chunk = df.iloc[:, start_idx:end_idx]
1060        chunks.append(chunk)
1061        start_idx = end_idx
1062
1063    return chunks
# Expanded value of TIME_KWARG_KEYS: the bare time kwargs plus every built-in datetime macro
# variable, i.e. each of start/end/latest/execution combined with each format suffix.
TIME_KWARG_KEYS = {"start", "end", "latest", "execution_time"} | {
    f"{prefix}_{suffix}"
    for prefix in ("start", "end", "latest", "execution")
    for suffix in ("ds", "ts", "dt", "date", "dtntz", "tstz", "epoch", "millis", "hour")
}
class ModelTest(unittest.case.TestCase):
 52class ModelTest(unittest.TestCase):
 53    __test__ = False
 54
 55    CONCURRENT_RENDER_LOCK = threading.Lock()
 56
 57    def __init__(
 58        self,
 59        body: t.Dict[str, t.Any],
 60        test_name: str,
 61        model: Model,
 62        models: UniqueKeyDict[str, Model],
 63        engine_adapter: EngineAdapter,
 64        dialect: str | None = None,
 65        path: Path | None = None,
 66        preserve_fixtures: bool = False,
 67        default_catalog: str | None = None,
 68        concurrency: bool = False,
 69        verbosity: Verbosity = Verbosity.DEFAULT,
 70    ) -> None:
 71        """ModelTest encapsulates a unit test for a model.
 72
 73        Args:
 74            body: A dictionary that contains test metadata like inputs and outputs.
 75            test_name: The name of the test.
 76            model: The model that is being tested.
 77            models: All models to use for expansion and mapping of physical locations.
 78            engine_adapter: The engine adapter to use.
 79            dialect: The models' dialect, used for normalization purposes.
 80            path: An optional path to the test definition yaml file.
 81            preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
 82        """
 83        self.body = body
 84        self.test_name = test_name
 85        self.model = model
 86        self.models = models
 87        self.engine_adapter = engine_adapter
 88        self.path = path
 89        self.preserve_fixtures = preserve_fixtures
 90        self.default_catalog = default_catalog
 91        self.dialect = dialect
 92        self.concurrency = concurrency
 93        self.verbosity = verbosity
 94
 95        self._fixture_table_cache: t.Dict[str, exp.Table] = {}
 96        self._normalized_column_name_cache: t.Dict[str, str] = {}
 97        self._normalized_model_name_cache: t.Dict[t.Tuple[str, bool], str] = {}
 98
 99        self._test_adapter_dialect = Dialect.get_or_raise(self.engine_adapter.dialect)
100
101        self._validate_and_normalize_test()
102
103        if self.engine_adapter.default_catalog:
104            self._fixture_catalog: t.Optional[exp.Identifier] = normalize_identifiers(
105                exp.parse_identifier(
106                    self.engine_adapter.default_catalog, dialect=self._test_adapter_dialect
107                ),
108                dialect=self._test_adapter_dialect,
109            )
110        else:
111            self._fixture_catalog = None
112
113        # The test schema name is randomized to avoid concurrency issues,
114        # unless a schema is provided in the unit tests's body
115        self._fixture_schema = exp.parse_identifier(
116            self.body.get("schema") or f"sqlmesh_test_{random_id(short=True)}"
117        )
118        self._qualified_fixture_schema = schema_(self._fixture_schema, self._fixture_catalog)
119
120        self._transforms = self._test_adapter_dialect.generator_class.TRANSFORMS
121        self._execution_time = str(self.body.get("vars", {}).get("execution_time") or "")
122
123        if self._execution_time:
124            # Normalizes the execution time by converting it into UTC timezone
125            self._execution_time = str(to_datetime(self._execution_time))
126
127        # When execution_time is set, we mock the CURRENT_* SQL expressions so they always return it
128        if self._execution_time:
129            exec_time = exp.Literal.string(self._execution_time)
130            self._transforms = {
131                **self._transforms,
132                exp.CurrentDate: lambda self, _: self.sql(
133                    exp.cast(exec_time, "date", dialect=dialect)
134                ),
135                exp.CurrentDatetime: lambda self, _: self.sql(
136                    exp.cast(exec_time, "datetime", dialect=dialect)
137                ),
138                exp.CurrentTime: lambda self, _: self.sql(
139                    exp.cast(exec_time, "time", dialect=dialect)
140                ),
141                exp.CurrentTimestamp: lambda self, _: self.sql(
142                    exp.cast(exec_time, "timestamp", dialect=dialect)
143                ),
144            }
145
146        super().__init__()
147
148    def defaultTestResult(self) -> unittest.TestResult:
149        from sqlmesh.core.test.result import ModelTextTestResult
150
151        return ModelTextTestResult(stream=sys.stdout, descriptions=True, verbosity=self.verbosity)
152
153    def shortDescription(self) -> t.Optional[str]:
154        return self.body.get("description")
155
156    def setUp(self) -> None:
157        """Load all input tables"""
158        import pandas as pd
159        import numpy as np
160
161        self.engine_adapter.create_schema(self._qualified_fixture_schema)
162
163        for name, values in self.body.get("inputs", {}).items():
164            all_types_are_known = False
165            columns_to_known_types: t.Dict[str, exp.DataType] = {}
166
167            model = self.models.get(name)
168            if model:
169                inferred_columns_to_types = model.columns_to_types or {}
170                columns_to_known_types = {
171                    c: t for c, t in inferred_columns_to_types.items() if type_is_known(t)
172                }
173                all_types_are_known = bool(inferred_columns_to_types) and (
174                    len(columns_to_known_types) == len(inferred_columns_to_types)
175                )
176
177            # Types specified in the test will override the corresponding inferred ones
178            columns_to_known_types.update(values.get("columns", {}))
179
180            rows = values.get("rows")
181            if not all_types_are_known and rows:
182                for col, value in rows[0].items():
183                    if col not in columns_to_known_types:
184                        v_type = annotate_types(exp.convert(value)).type or type(value).__name__
185                        v_type = exp.maybe_parse(
186                            v_type, into=exp.DataType, dialect=self._test_adapter_dialect
187                        )
188
189                        if not type_is_known(v_type):
190                            _raise_error(
191                                f"Failed to infer the data type of column '{col}' for '{name}'. This issue can be "
192                                "mitigated by casting the column in the model definition, setting its type in "
193                                "external_models.yaml if it's an external model, setting the model's 'columns' property, "
194                                "or setting its 'columns' mapping in the test itself",
195                                self.path,
196                            )
197
198                        columns_to_known_types[col] = v_type
199
200            if rows is None:
201                query_or_df: exp.Query | pd.DataFrame = self._add_missing_columns(
202                    values["query"], columns_to_known_types
203                )
204                if columns_to_known_types:
205                    columns_to_known_types = {
206                        col: columns_to_known_types[col] for col in query_or_df.named_selects
207                    }
208            else:
209                query_or_df = self._create_df(values, columns=columns_to_known_types)
210
211            # Convert NaN/NaT values to None if DataFrame
212            if isinstance(query_or_df, pd.DataFrame):
213                query_or_df = query_or_df.replace({np.nan: None})
214
215            self.engine_adapter.create_view(
216                self._test_fixture_table(name), query_or_df, columns_to_known_types
217            )
218
219    def tearDown(self) -> None:
220        """Drop all fixture tables."""
221        if not self.preserve_fixtures:
222            self.engine_adapter.drop_schema(self._qualified_fixture_schema, cascade=True)
223
224    def assert_equal(
225        self,
226        expected: pd.DataFrame,
227        actual: pd.DataFrame,
228        sort: bool,
229        partial: t.Optional[bool] = False,
230    ) -> None:
231        """Compare two DataFrames"""
232        import numpy as np
233        import pandas as pd
234        from pandas.api.types import is_object_dtype
235
236        if partial:
237            intersection = actual[actual.columns.intersection(expected.columns)]
238            if len(intersection.columns) > 0:
239                actual = intersection
240
241        # Two astypes are necessary, pandas converts strings to times as NS,
242        # but if the actual is US, it doesn't take effect until the 2nd try!
243        actual_types = actual.dtypes.to_dict()
244        expected = expected.astype(actual_types, errors="ignore").astype(
245            actual_types, errors="ignore"
246        )
247
248        # The `actual` df's dtypes will almost always be pd.Timestamp for datetime values,
249        # but in some scenarios (e.g., DuckDB >=0.10.2) it will be a pandas `object` type
250        # containing python `datetime.xxx` values.
251        #
252        # Pandas `object` columns result in a noop for the `astype` call above. Because any
253        # quoted YAML value is a string, we must manually convert the `expected` df string
254        # values to the correct `datetime.xxx` type.
255        #
256        # We determine the type from a single sentinel value, but since the `actual` df is
257        # coming from a database query, it is safe to assume that the column contains only
258        # a single type.
259        object_sentinel_values = {
260            col: actual[col][0]
261            for col in actual_types
262            if is_object_dtype(actual_types[col]) and len(actual[col]) != 0
263        }
264        for col, value in object_sentinel_values.items():
265            try:
266                # can't use `isinstance()` here - https://stackoverflow.com/a/68743663/1707525
267                if type(value) is datetime.date:
268                    expected[col] = pd.to_datetime(expected[col]).dt.date
269                elif type(value) is datetime.time:
270                    expected[col] = pd.to_datetime(expected[col]).dt.time
271                elif type(value) is datetime.datetime:
272                    expected[col] = pd.to_datetime(expected[col]).dt.to_pydatetime()
273            except Exception as e:
274                from sqlmesh.core.console import get_console
275
276                get_console().log_warning(
277                    f"Failed to convert expected value for {col} into `datetime` "
278                    f"for unit test '{str(self)}'. {str(e)}."
279                )
280
281        actual = actual.replace({np.nan: None})
282        expected = expected.replace({np.nan: None})
283
284        # We define this here to avoid a top-level import of numpy and pandas
285        DATETIME_TYPES = (
286            datetime.datetime,
287            datetime.date,
288            datetime.time,
289            np.datetime64,
290            pd.Timestamp,
291        )
292
293        def _to_hashable(x: t.Any) -> t.Any:
294            if isinstance(x, (list, np.ndarray)):
295                return tuple(_to_hashable(v) for v in x)
296            if isinstance(x, dict):
297                return tuple((k, _to_hashable(v)) for k, v in x.items())
298            return str(x) if isinstance(x, DATETIME_TYPES) or not isinstance(x, t.Hashable) else x
299
300        actual = actual.apply(lambda col: col.map(_to_hashable))
301        expected = expected.apply(lambda col: col.map(_to_hashable))
302
303        if sort:
304            actual = actual.sort_values(by=actual.columns.to_list()).reset_index(drop=True)
305            expected = expected.sort_values(by=expected.columns.to_list()).reset_index(drop=True)
306
307        try:
308            pd.testing.assert_frame_equal(
309                expected,
310                actual,
311                check_dtype=False,
312                check_datetimelike_compat=True,
313                check_like=True,  # Ignore column order
314            )
315        except AssertionError as e:
316            # There are 2 concepts at play here:
317            # 1. The Exception args will contain the error message plus the diff dataframe table stringified
318            #    (backwards compatibility with existing tests, possible to serialize/send over network etc)
319            # 2. Each test will also transform these diff dataframes into Rich tables, which will be the ones that'll
320            #    be surfaced to the user through Console for better UX (versus stringified dataframes)
321            #
322            # This is a bit of a hack, but it's a way to get the best of both worlds.
323            args: t.List[t.Any] = []
324
325            failed_subtest = ""
326
327            if subtest := getattr(self, "_subtest", None):
328                if cte := subtest.params.get("cte"):
329                    failed_subtest = f" (CTE {cte})"
330
331            if expected.shape != actual.shape:
332                _raise_if_unexpected_columns(expected.columns, actual.columns)
333
334                args.append("Data mismatch (rows are different)")
335
336                missing_rows = _row_difference(expected, actual)
337                if not missing_rows.empty:
338                    args[0] += f"\n\nMissing rows:\n\n{missing_rows}"
339                    args.append(df_to_table(f"Missing rows{failed_subtest}", missing_rows))
340
341                unexpected_rows = _row_difference(actual, expected)
342
343                if not unexpected_rows.empty:
344                    args[0] += f"\n\nUnexpected rows:\n\n{unexpected_rows}"
345                    args.append(df_to_table(f"Unexpected rows{failed_subtest}", unexpected_rows))
346
347            else:
348                diff = expected.compare(actual).rename(columns={"self": "exp", "other": "act"})
349
350                args.append(f"Data mismatch (exp: expected, act: actual)\n\n{diff}")
351
352                diff.rename(columns={"exp": "Expected", "act": "Actual"}, inplace=True)
353                if self.verbosity == Verbosity.DEFAULT:
354                    args.extend(
355                        df_to_table(f"Data mismatch{failed_subtest}", df)
356                        for df in _split_df_by_column_pairs(diff)
357                    )
358                else:
359                    from pandas import DataFrame, MultiIndex
360
361                    levels = t.cast(MultiIndex, diff.columns).levels[0]
362                    for col in levels:
363                        # diff[col] returns a DataFrame when columns is a MultiIndex
364                        col_diff = t.cast(DataFrame, diff[col])
365                        if not col_diff.empty:
366                            table = df_to_table(
367                                f"[bold red]Column '{col}' mismatch{failed_subtest}[/bold red]",
368                                col_diff,
369                            )
370                            args.append(table)
371
372            e.args = (*args,)
373
374            raise e
375
376    def runTest(self) -> None:
377        raise NotImplementedError
378
379    def path_relative_to(self, other: Path) -> Path | None:
380        """Compute a version of this test's path relative to the `other` path"""
381        return self.path.relative_to(other) if self.path else None
382
@staticmethod
def create_test(
    body: t.Dict[str, t.Any],
    test_name: str,
    models: UniqueKeyDict[str, Model],
    engine_adapter: EngineAdapter,
    dialect: str | None,
    path: Path | None,
    preserve_fixtures: bool = False,
    default_catalog: str | None = None,
    concurrency: bool = False,
    verbosity: Verbosity = Verbosity.DEFAULT,
) -> t.Optional[ModelTest]:
    """Create a SqlModelTest or a PythonModelTest.

    Args:
        body: A dictionary that contains test metadata like inputs and outputs.
        test_name: The name of the test.
        models: All models to use for expansion and mapping of physical locations.
        engine_adapter: The engine adapter to use.
        dialect: The models' dialect, used for normalization purposes.
        path: An optional path to the test definition yaml file.
        preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
        default_catalog: The catalog used when normalizing the tested model's name.
        concurrency: Forwarded to the test; whether tests may run concurrently.
        verbosity: Forwarded to the test; controls the verbosity of its results.

    Returns:
        The created test, or None (after logging a warning) when the model named in the
        test body cannot be found in `models`.

    Raises:
        TestError: If constructing the test fails. The original exception is chained as
            the cause so its traceback is preserved.
    """
    name = body.get("model")
    if name is None:
        _raise_error("Missing required 'model' field", path)

    name = normalize_model_name(name, default_catalog=default_catalog, dialect=dialect)
    model = models.get(name)
    if not model:
        from sqlmesh.core.console import get_console

        get_console().log_warning(
            f"Model '{name}' was not found{' at ' + str(path) if path else ''}"
        )
        return None

    if isinstance(model, SqlModel):
        test_type: t.Type[ModelTest] = SqlModelTest
    elif isinstance(model, PythonModel):
        test_type = PythonModelTest
    else:
        _raise_error(f"Model '{name}' is an unsupported model type for testing", path)

    try:
        return test_type(
            body,
            test_name,
            t.cast(Model, model),
            models,
            engine_adapter,
            dialect,
            path,
            preserve_fixtures,
            default_catalog,
            concurrency,
            verbosity,
        )
    except Exception as e:
        # Chain the original error so its traceback isn't lost behind the TestError
        raise TestError(f"Failed to create test {test_name} ({path})\n{str(e)}") from e
444
445    def __str__(self) -> str:
446        return f"{self.test_name} ({self.path})"
447
def _validate_and_normalize_test(self) -> None:
    """Validate the test body and normalize its inputs and outputs in place.

    The tested model's name, input/CTE names and column names are normalized. Row data given
    inline, as CSV, or through a YAML/CSV file `path` is loaded into lists of row dicts.
    User-supplied queries are parsed, and declared input column types are built into
    DataTypes. A self-referencing model without an explicit input for itself gets an empty
    one. The normalized structures are written back into `self.body`.

    Malformed or incomplete definitions are reported through `_raise_error`
    (presumably raising a test/config error; see that helper).
    """
    inputs = self.body.get("inputs")
    outputs = self.body.get("outputs", {})

    if not outputs:
        _raise_error("Incomplete test, missing outputs", self.path)

    ctes = outputs.get("ctes")
    query = outputs.get("query")
    # `partial` is a flag applied to every output below rather than an output itself
    partial = outputs.pop("partial", None)

    if ctes is None and query is None:
        _raise_error("Incomplete test, outputs must contain 'query' or 'ctes'", self.path)

    def _normalize_rows(
        values: t.List[Row] | t.Dict,
        name: str,
        partial: bool = False,
        dialect: DialectType = None,
    ) -> t.Dict:
        """Load and normalize the data of a single input or output.

        `values` is either a bare list of rows or a mapping that may contain `rows`, `query`,
        `format` (`csv` or `yaml`), `path` and `csv_settings`.
        """
        import pandas as pd

        if not isinstance(values, dict):
            values = {"rows": values}

        rows = values.get("rows")
        query = values.get("query")

        fmt = values.get("format")
        path = values.get("path")
        if fmt == "csv":
            csv_settings = values.get("csv_settings") or {}
            rows = pd.read_csv(path or StringIO(rows), **csv_settings).to_dict(orient="records")
        elif fmt in (None, "yaml"):
            if path:
                input_rows = yaml_load(Path(path))
                rows = input_rows.get("rows") if isinstance(input_rows, dict) else input_rows
        else:
            _raise_error(f"Unsupported data format '{fmt}' for '{name}'", self.path)

        if query is not None:
            if rows is not None:
                _raise_error(
                    f"Invalid test, cannot set both 'query' and 'rows' for '{name}'", self.path
                )

            # We parse the user-supplied query using the testing adapter dialect, but we
            # normalize its identifiers according to the model's dialect, so that, e.g.,
            # the projection names match those in its `columns_to_types` field
            values["query"] = normalize_identifiers(
                exp.maybe_parse(query, dialect=self._test_adapter_dialect), dialect=dialect
            )
            return values

        if rows is None:
            _raise_error(f"Incomplete test, missing row data for '{name}'", self.path)

        # Validate explicitly instead of with `assert`: asserts are stripped under `python -O`,
        # and a bare AssertionError tells the user nothing about what is wrong with the test.
        if not isinstance(rows, list):
            _raise_error(
                f"Invalid test, expected a list of rows for '{name}', got {type(rows).__name__}",
                self.path,
            )

        values["rows"] = [
            {self._normalize_column_name(column): value for column, value in row.items()}
            for row in rows
        ]
        if partial:
            values["partial"] = True

        return values

    def _normalize_sources(
        sources: t.Dict, partial: bool = False, with_default_catalog: bool = True
    ) -> t.Dict:
        """Normalize each source name and its row data, using the source model's dialect if known."""
        normalized_sources = {}
        for name, values in sources.items():
            normalized_name = self._normalize_model_name(
                name, with_default_catalog=with_default_catalog
            )
            model = self.models.get(normalized_name)
            dialect = model.dialect if model else self.dialect

            normalized_sources[normalized_name] = _normalize_rows(
                values, name, partial=partial, dialect=dialect
            )

        return normalized_sources

    normalized_model_name = self._normalize_model_name(self.body["model"])
    self.body["model"] = normalized_model_name

    if inputs:
        inputs = _normalize_sources(inputs)
        for name, values in inputs.items():
            columns = values.get("columns")
            if columns is None:
                continue

            if not isinstance(columns, dict):
                _raise_error(
                    f"Invalid 'columns' value for model '{name}', expected a mapping name -> type",
                    self.path,
                )

            # Named `col`/`typ` to avoid shadowing the module-level `c` and `t` aliases
            values["columns"] = {
                self._normalize_column_name(col): exp.DataType.build(
                    typ, dialect=self._test_adapter_dialect
                )
                for col, typ in columns.items()
            }

        for depends_on in self.model.depends_on:
            if depends_on not in inputs:
                _raise_error(f"Incomplete test, missing input model '{depends_on}'", self.path)

        # A self-referencing model may omit an input for itself; default it to no rows
        if self.model.depends_on_self and normalized_model_name not in inputs:
            inputs[normalized_model_name] = {"rows": []}

        self.body["inputs"] = inputs

    if ctes:
        outputs["ctes"] = _normalize_sources(ctes, partial=partial, with_default_catalog=False)

    # An explicitly empty list of expected rows is still a valid query output
    if query or query == []:
        outputs["query"] = _normalize_rows(
            query, self.model.name, partial=partial, dialect=self.model.dialect
        )
571
def _test_fixture_table(self, name: str) -> exp.Table:
    """Return the fixture table that stands in for `name` during the test, caching it per name.

    All parts of `name` are joined with ``__`` into a single table identifier. The table is
    then placed in the fixture schema, and in the fixture catalog when one is set.
    """
    cached = self._fixture_table_cache.get(name)
    if cached:
        return cached

    fixture = exp.to_table(name, dialect=self._test_adapter_dialect)

    # The table is relocated into the fixture schema below, so flattening all of its name
    # parts into one identifier keeps tables from different schemas/catalogs from clashing
    flattened_name = "__".join(part.name for part in fixture.parts)
    fixture.this.set("this", flattened_name)

    fixture.set("db", self._fixture_schema.copy())
    if self._fixture_catalog:
        fixture.set("catalog", self._fixture_catalog.copy())

    self._fixture_table_cache[name] = fixture
    return fixture
587
588    def _normalize_model_name(self, name: str, with_default_catalog: bool = True) -> str:
589        normalized_name = self._normalized_model_name_cache.get((name, with_default_catalog))
590        if normalized_name is None:
591            default_catalog = self.default_catalog if with_default_catalog else None
592            normalized_name = normalize_model_name(
593                name, default_catalog=default_catalog, dialect=self.dialect
594            )
595            self._normalized_model_name_cache[(name, with_default_catalog)] = normalized_name
596
597        return normalized_name
598
599    def _normalize_column_name(self, name: str) -> str:
600        normalized_name = self._normalized_column_name_cache.get(name)
601        if normalized_name is None:
602            normalized_name = normalize_identifiers(name, dialect=self.dialect).name
603            self._normalized_column_name_cache[name] = normalized_name
604
605        return normalized_name
606
@contextmanager
def _concurrent_render_context(self) -> t.Iterator[None]:
    """Make rendering safe when tests may run concurrently.

    When `self.concurrency` is set, the class-level `CONCURRENT_RENDER_LOCK` is held for the
    duration of the context. When an `execution_time` is configured, the context also:
    - freezes time at `execution_time` through `time_machine` (which is not thread safe), and
    - patches the testing dialect's generator TRANSFORMS so date/time nodes are generated
      at `execution_time`.
    """
    import time_machine

    lock_ctx: AbstractContextManager = nullcontext()
    if self.concurrency:
        lock_ctx = self.CONCURRENT_RENDER_LOCK

    if self._execution_time:
        time_ctx: AbstractContextManager = time_machine.travel(self._execution_time, tick=False)
        dialect_patch_ctx: AbstractContextManager = patch.dict(
            self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms
        )
    else:
        time_ctx = nullcontext()
        dialect_patch_ctx = nullcontext()

    with lock_ctx:
        with time_ctx:
            with dialect_patch_ctx:
                yield
631
632    def _execute(self, query: exp.Query | str) -> pd.DataFrame:
633        """Executes the given query using the testing engine adapter and returns a DataFrame."""
634        return self.engine_adapter.fetchdf(query)
635
636    def _create_df(
637        self,
638        values: t.Dict[str, t.Any],
639        columns: t.Optional[t.Collection] = None,
640        partial: t.Optional[bool] = False,
641    ) -> pd.DataFrame:
642        import pandas as pd
643
644        query = values.get("query")
645        if query:
646            if not partial:
647                query = self._add_missing_columns(query, columns)
648
649            return self._execute(query)
650
651        rows = values["rows"]
652        columns_str: t.Optional[t.List[str]] = None
653        if columns:
654            columns_str = [str(c) for c in columns]
655            referenced_columns = list(dict.fromkeys(col for row in rows for col in row))
656            _raise_if_unexpected_columns(columns, referenced_columns)
657
658            if partial:
659                columns_str = [c for c in columns_str if c in referenced_columns]
660
661        return pd.DataFrame.from_records(rows, columns=columns_str)
662
663    def _add_missing_columns(
664        self, query: exp.Query, all_columns: t.Optional[t.Collection[str]] = None
665    ) -> exp.Query:
666        if not all_columns or query.is_star:
667            return query
668
669        query_columns = set(query.named_selects)
670        missing_columns = [col for col in all_columns if col not in query_columns]
671        if missing_columns:
672            query.select(*[exp.null().as_(col) for col in missing_columns], copy=False)
673
674        return query

A class whose instances are single test cases.

By default, the test code itself should be placed in a method named 'runTest'.

If the fixture may be used for many test cases, create as many test methods as are needed. When instantiating such a TestCase subclass, specify in the constructor arguments the name of the test method that the instance is to execute.

Test authors should subclass TestCase for their own tests. Construction and deconstruction of the test's environment ('fixture') can be implemented by overriding the 'setUp' and 'tearDown' methods respectively.

If it is necessary to override the __init__ method, the base class __init__ method must always be called. It is important that subclasses should not change the signature of their __init__ method, since instances of the classes are instantiated automatically by parts of the framework in order to be run.

When subclassing TestCase, you can set these attributes:

  • failureException: determines which exception will be raised when the instance's assertion methods fail; test methods raising this exception will be deemed to have 'failed' rather than 'errored'.
  • longMessage: determines whether long messages (including repr of objects used in assert methods) will be printed on failure in addition to any explicit message passed.
  • maxDiff: sets the maximum length of a diff in failure messages by assert methods using difflib. It is looked up as an instance attribute so can be configured by individual tests if required.
ModelTest( body: Dict[str, Any], test_name: str, model: Union[sqlmesh.core.model.definition.SqlModel, sqlmesh.core.model.definition.SeedModel, sqlmesh.core.model.definition.PythonModel, sqlmesh.core.model.definition.ExternalModel], models: sqlmesh.utils.UniqueKeyDict[str, typing.Union[sqlmesh.core.model.definition.SqlModel, sqlmesh.core.model.definition.SeedModel, sqlmesh.core.model.definition.PythonModel, sqlmesh.core.model.definition.ExternalModel]], engine_adapter: sqlmesh.core.engine_adapter.base.EngineAdapter, dialect: str | None = None, path: pathlib.Path | None = None, preserve_fixtures: bool = False, default_catalog: str | None = None, concurrency: bool = False, verbosity: sqlmesh.utils.Verbosity = <Verbosity.DEFAULT: 0>)
 57    def __init__(
 58        self,
 59        body: t.Dict[str, t.Any],
 60        test_name: str,
 61        model: Model,
 62        models: UniqueKeyDict[str, Model],
 63        engine_adapter: EngineAdapter,
 64        dialect: str | None = None,
 65        path: Path | None = None,
 66        preserve_fixtures: bool = False,
 67        default_catalog: str | None = None,
 68        concurrency: bool = False,
 69        verbosity: Verbosity = Verbosity.DEFAULT,
 70    ) -> None:
 71        """ModelTest encapsulates a unit test for a model.
 72
 73        Args:
 74            body: A dictionary that contains test metadata like inputs and outputs.
 75            test_name: The name of the test.
 76            model: The model that is being tested.
 77            models: All models to use for expansion and mapping of physical locations.
 78            engine_adapter: The engine adapter to use.
 79            dialect: The models' dialect, used for normalization purposes.
 80            path: An optional path to the test definition yaml file.
 81            preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
 82        """
 83        self.body = body
 84        self.test_name = test_name
 85        self.model = model
 86        self.models = models
 87        self.engine_adapter = engine_adapter
 88        self.path = path
 89        self.preserve_fixtures = preserve_fixtures
 90        self.default_catalog = default_catalog
 91        self.dialect = dialect
 92        self.concurrency = concurrency
 93        self.verbosity = verbosity
 94
 95        self._fixture_table_cache: t.Dict[str, exp.Table] = {}
 96        self._normalized_column_name_cache: t.Dict[str, str] = {}
 97        self._normalized_model_name_cache: t.Dict[t.Tuple[str, bool], str] = {}
 98
 99        self._test_adapter_dialect = Dialect.get_or_raise(self.engine_adapter.dialect)
100
101        self._validate_and_normalize_test()
102
103        if self.engine_adapter.default_catalog:
104            self._fixture_catalog: t.Optional[exp.Identifier] = normalize_identifiers(
105                exp.parse_identifier(
106                    self.engine_adapter.default_catalog, dialect=self._test_adapter_dialect
107                ),
108                dialect=self._test_adapter_dialect,
109            )
110        else:
111            self._fixture_catalog = None
112
113        # The test schema name is randomized to avoid concurrency issues,
114        # unless a schema is provided in the unit tests's body
115        self._fixture_schema = exp.parse_identifier(
116            self.body.get("schema") or f"sqlmesh_test_{random_id(short=True)}"
117        )
118        self._qualified_fixture_schema = schema_(self._fixture_schema, self._fixture_catalog)
119
120        self._transforms = self._test_adapter_dialect.generator_class.TRANSFORMS
121        self._execution_time = str(self.body.get("vars", {}).get("execution_time") or "")
122
123        if self._execution_time:
124            # Normalizes the execution time by converting it into UTC timezone
125            self._execution_time = str(to_datetime(self._execution_time))
126
127        # When execution_time is set, we mock the CURRENT_* SQL expressions so they always return it
128        if self._execution_time:
129            exec_time = exp.Literal.string(self._execution_time)
130            self._transforms = {
131                **self._transforms,
132                exp.CurrentDate: lambda self, _: self.sql(
133                    exp.cast(exec_time, "date", dialect=dialect)
134                ),
135                exp.CurrentDatetime: lambda self, _: self.sql(
136                    exp.cast(exec_time, "datetime", dialect=dialect)
137                ),
138                exp.CurrentTime: lambda self, _: self.sql(
139                    exp.cast(exec_time, "time", dialect=dialect)
140                ),
141                exp.CurrentTimestamp: lambda self, _: self.sql(
142                    exp.cast(exec_time, "timestamp", dialect=dialect)
143                ),
144            }
145
146        super().__init__()

ModelTest encapsulates a unit test for a model.

Arguments:
  • body: A dictionary that contains test metadata like inputs and outputs.
  • test_name: The name of the test.
  • model: The model that is being tested.
  • models: All models to use for expansion and mapping of physical locations.
  • engine_adapter: The engine adapter to use.
  • dialect: The models' dialect, used for normalization purposes.
  • path: An optional path to the test definition yaml file.
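  • preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
  • default_catalog: The catalog used to qualify model names when normalizing them.
  • concurrency: Whether tests may run concurrently; if so, rendering is serialized through CONCURRENT_RENDER_LOCK.
  • verbosity: Verbosity of the test result output and of data-mismatch reports.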
CONCURRENT_RENDER_LOCK = <unlocked _thread.lock object>
body
test_name
model
models
engine_adapter
path
preserve_fixtures
default_catalog
dialect
concurrency
verbosity
def defaultTestResult(self) -> unittest.result.TestResult:
def defaultTestResult(self) -> unittest.TestResult:
    """Build the result object unittest uses when no explicit result is supplied to `run()`.

    The result writes to stdout and honors this test's verbosity.
    """
    from sqlmesh.core.test.result import ModelTextTestResult

    return ModelTextTestResult(
        stream=sys.stdout,
        descriptions=True,
        verbosity=self.verbosity,
    )
def shortDescription(self) -> Optional[str]:
def shortDescription(self) -> t.Optional[str]:
    """Return the test body's `description` field, or None when it has none."""
    body = self.body
    return body.get("description")

Returns the test's description, taken from the `description` field of the test body, or None if no description has been provided.

This overrides the default unittest implementation, which returns the first line of the test method's docstring.

def setUp(self) -> None:
def setUp(self) -> None:
    """Load all input tables.

    Creates the fixture schema and, for every test input, creates a view over its rows or
    query. Column types come from the input model's known column types, overridden by types
    declared in the test. For row data, any remaining types are inferred from the first row.
    """
    import pandas as pd
    import numpy as np

    self.engine_adapter.create_schema(self._qualified_fixture_schema)

    for input_name, input_spec in self.body.get("inputs", {}).items():
        every_type_known = False
        known_types: t.Dict[str, exp.DataType] = {}

        input_model = self.models.get(input_name)
        if input_model:
            model_types = input_model.columns_to_types or {}
            known_types = {col: typ for col, typ in model_types.items() if type_is_known(typ)}
            every_type_known = bool(model_types) and (len(known_types) == len(model_types))

        # Types specified in the test will override the corresponding inferred ones
        known_types.update(input_spec.get("columns", {}))

        input_rows = input_spec.get("rows")
        if input_rows and not every_type_known:
            # Infer any still-missing column types from the values in the first row
            for col, sample in input_rows[0].items():
                if col in known_types:
                    continue

                guessed = annotate_types(exp.convert(sample)).type or type(sample).__name__
                guessed = exp.maybe_parse(
                    guessed, into=exp.DataType, dialect=self._test_adapter_dialect
                )

                if not type_is_known(guessed):
                    _raise_error(
                        f"Failed to infer the data type of column '{col}' for '{input_name}'. This issue can be "
                        "mitigated by casting the column in the model definition, setting its type in "
                        "external_models.yaml if it's an external model, setting the model's 'columns' property, "
                        "or setting its 'columns' mapping in the test itself",
                        self.path,
                    )

                known_types[col] = guessed

        if input_rows is not None:
            fixture_source: exp.Query | pd.DataFrame = self._create_df(
                input_spec, columns=known_types
            )
        else:
            fixture_source = self._add_missing_columns(input_spec["query"], known_types)
            if known_types:
                # Restrict the types to the query's projections, in projection order
                known_types = {col: known_types[col] for col in fixture_source.named_selects}

        # Convert NaN/NaT values to None if DataFrame
        if isinstance(fixture_source, pd.DataFrame):
            fixture_source = fixture_source.replace({np.nan: None})

        self.engine_adapter.create_view(
            self._test_fixture_table(input_name), fixture_source, known_types
        )

Load all input tables

def tearDown(self) -> None:
def tearDown(self) -> None:
    """Drop all fixture tables by dropping the fixture schema, unless fixtures are preserved."""
    if self.preserve_fixtures:
        return
    self.engine_adapter.drop_schema(self._qualified_fixture_schema, cascade=True)

Drop all fixture tables.

def assert_equal( self, expected: pandas.core.frame.DataFrame, actual: pandas.core.frame.DataFrame, sort: bool, partial: Optional[bool] = False) -> None:
def assert_equal(
    self,
    expected: pd.DataFrame,
    actual: pd.DataFrame,
    sort: bool,
    partial: t.Optional[bool] = False,
) -> None:
    """Compare two DataFrames.

    Args:
        expected: The expected data, as declared in the test.
        actual: The data produced by the model under test.
        sort: Whether to sort both frames by all of their columns and reset their indices
            before comparing.
        partial: When truthy, compare only the columns of `actual` that also appear in
            `expected` (if there are any).

    Raises:
        AssertionError: If the frames differ. Its `args` hold a plain-text mismatch description
            followed by Rich tables rendering the differences.
    """
    import numpy as np
    import pandas as pd
    from pandas.api.types import is_object_dtype

    if partial:
        intersection = actual[actual.columns.intersection(expected.columns)]
        if len(intersection.columns) > 0:
            actual = intersection

    # Two astypes are necessary, pandas converts strings to times as NS,
    # but if the actual is US, it doesn't take effect until the 2nd try!
    actual_types = actual.dtypes.to_dict()
    expected = expected.astype(actual_types, errors="ignore").astype(
        actual_types, errors="ignore"
    )

    # The `actual` df's dtypes will almost always be pd.Timestamp for datetime values,
    # but in some scenarios (e.g., DuckDB >=0.10.2) it will be a pandas `object` type
    # containing python `datetime.xxx` values.
    #
    # Pandas `object` columns result in a noop for the `astype` call above. Because any
    # quoted YAML value is a string, we must manually convert the `expected` df string
    # values to the correct `datetime.xxx` type.
    #
    # We determine the type from a single sentinel value, but since the `actual` df is
    # coming from a database query, it is safe to assume that the column contains only
    # a single type.
    #
    # The sentinel is taken positionally (`iloc`): `actual` is not guaranteed to have a
    # default RangeIndex, so a label-based `[0]` lookup could fail or pick the wrong row.
    object_sentinel_values = {
        col: actual[col].iloc[0]
        for col in actual_types
        if is_object_dtype(actual_types[col]) and len(actual[col]) != 0
    }
    for col, value in object_sentinel_values.items():
        try:
            # can't use `isinstance()` here - https://stackoverflow.com/a/68743663/1707525
            if type(value) is datetime.date:
                expected[col] = pd.to_datetime(expected[col]).dt.date
            elif type(value) is datetime.time:
                expected[col] = pd.to_datetime(expected[col]).dt.time
            elif type(value) is datetime.datetime:
                expected[col] = pd.to_datetime(expected[col]).dt.to_pydatetime()
        except Exception as e:
            from sqlmesh.core.console import get_console

            get_console().log_warning(
                f"Failed to convert expected value for {col} into `datetime` "
                f"for unit test '{str(self)}'. {str(e)}."
            )

    actual = actual.replace({np.nan: None})
    expected = expected.replace({np.nan: None})

    # We define this here to avoid a top-level import of numpy and pandas
    DATETIME_TYPES = (
        datetime.datetime,
        datetime.date,
        datetime.time,
        np.datetime64,
        pd.Timestamp,
    )

    def _to_hashable(x: t.Any) -> t.Any:
        """Make a cell value hashable and comparable: containers become tuples and
        datetime-like or unhashable values become strings."""
        if isinstance(x, (list, np.ndarray)):
            return tuple(_to_hashable(v) for v in x)
        if isinstance(x, dict):
            return tuple((k, _to_hashable(v)) for k, v in x.items())
        return str(x) if isinstance(x, DATETIME_TYPES) or not isinstance(x, t.Hashable) else x

    actual = actual.apply(lambda col: col.map(_to_hashable))
    expected = expected.apply(lambda col: col.map(_to_hashable))

    if sort:
        actual = actual.sort_values(by=actual.columns.to_list()).reset_index(drop=True)
        expected = expected.sort_values(by=expected.columns.to_list()).reset_index(drop=True)

    try:
        pd.testing.assert_frame_equal(
            expected,
            actual,
            check_dtype=False,
            check_datetimelike_compat=True,
            check_like=True,  # Ignore column order
        )
    except AssertionError as e:
        # There are 2 concepts at play here:
        # 1. The Exception args will contain the error message plus the diff dataframe table stringified
        #    (backwards compatibility with existing tests, possible to serialize/send over network etc)
        # 2. Each test will also transform these diff dataframes into Rich tables, which will be the ones that'll
        #    be surfaced to the user through Console for better UX (versus stringified dataframes)
        #
        # This is a bit of a hack, but it's a way to get the best of both worlds.
        args: t.List[t.Any] = []

        failed_subtest = ""

        if subtest := getattr(self, "_subtest", None):
            if cte := subtest.params.get("cte"):
                failed_subtest = f" (CTE {cte})"

        if expected.shape != actual.shape:
            _raise_if_unexpected_columns(expected.columns, actual.columns)

            args.append("Data mismatch (rows are different)")

            missing_rows = _row_difference(expected, actual)
            if not missing_rows.empty:
                args[0] += f"\n\nMissing rows:\n\n{missing_rows}"
                args.append(df_to_table(f"Missing rows{failed_subtest}", missing_rows))

            unexpected_rows = _row_difference(actual, expected)

            if not unexpected_rows.empty:
                args[0] += f"\n\nUnexpected rows:\n\n{unexpected_rows}"
                args.append(df_to_table(f"Unexpected rows{failed_subtest}", unexpected_rows))

        else:
            diff = expected.compare(actual).rename(columns={"self": "exp", "other": "act"})

            args.append(f"Data mismatch (exp: expected, act: actual)\n\n{diff}")

            diff.rename(columns={"exp": "Expected", "act": "Actual"}, inplace=True)
            if self.verbosity == Verbosity.DEFAULT:
                args.extend(
                    df_to_table(f"Data mismatch{failed_subtest}", df)
                    for df in _split_df_by_column_pairs(diff)
                )
            else:
                from pandas import DataFrame, MultiIndex

                levels = t.cast(MultiIndex, diff.columns).levels[0]
                for col in levels:
                    # diff[col] returns a DataFrame when columns is a MultiIndex
                    col_diff = t.cast(DataFrame, diff[col])
                    if not col_diff.empty:
                        table = df_to_table(
                            f"[bold red]Column '{col}' mismatch{failed_subtest}[/bold red]",
                            col_diff,
                        )
                        args.append(table)

        e.args = (*args,)

        raise e

Compare two DataFrames

def runTest(self) -> None:
def runTest(self) -> None:
    """Placeholder test method; it always raises `NotImplementedError` and is meant to be overridden."""
    raise NotImplementedError()
def path_relative_to(self, other: pathlib.Path) -> pathlib.Path | None:
def path_relative_to(self, other: Path) -> Path | None:
    """Return this test's definition path expressed relative to `other`.

    Args:
        other: The base path to make the test's path relative to.

    Returns:
        The relative path, or None when the test has no associated path.
    """
    own_path = self.path
    if not own_path:
        return None
    return own_path.relative_to(other)

Compute a version of this test's path relative to the other path

@staticmethod
def create_test( body: Dict[str, Any], test_name: str, models: sqlmesh.utils.UniqueKeyDict[str, typing.Union[sqlmesh.core.model.definition.SqlModel, sqlmesh.core.model.definition.SeedModel, sqlmesh.core.model.definition.PythonModel, sqlmesh.core.model.definition.ExternalModel]], engine_adapter: sqlmesh.core.engine_adapter.base.EngineAdapter, dialect: str | None, path: pathlib.Path | None, preserve_fixtures: bool = False, default_catalog: str | None = None, concurrency: bool = False, verbosity: sqlmesh.utils.Verbosity = <Verbosity.DEFAULT: 0>) -> Optional[ModelTest]:
383    @staticmethod
384    def create_test(
385        body: t.Dict[str, t.Any],
386        test_name: str,
387        models: UniqueKeyDict[str, Model],
388        engine_adapter: EngineAdapter,
389        dialect: str | None,
390        path: Path | None,
391        preserve_fixtures: bool = False,
392        default_catalog: str | None = None,
393        concurrency: bool = False,
394        verbosity: Verbosity = Verbosity.DEFAULT,
395    ) -> t.Optional[ModelTest]:
396        """Create a SqlModelTest or a PythonModelTest.
397
398        Args:
399            body: A dictionary that contains test metadata like inputs and outputs.
400            test_name: The name of the test.
401            models: All models to use for expansion and mapping of physical locations.
402            engine_adapter: The engine adapter to use.
403            dialect: The models' dialect, used for normalization purposes.
404            path: An optional path to the test definition yaml file.
405            preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
406        """
407        name = body.get("model")
408        if name is None:
409            _raise_error("Missing required 'model' field", path)
410
411        name = normalize_model_name(name, default_catalog=default_catalog, dialect=dialect)
412        model = models.get(name)
413        if not model:
414            from sqlmesh.core.console import get_console
415
416            get_console().log_warning(
417                f"Model '{name}' was not found{' at ' + str(path) if path else ''}"
418            )
419            return None
420
421        if isinstance(model, SqlModel):
422            test_type: t.Type[ModelTest] = SqlModelTest
423        elif isinstance(model, PythonModel):
424            test_type = PythonModelTest
425        else:
426            _raise_error(f"Model '{name}' is an unsupported model type for testing", path)
427
428        try:
429            return test_type(
430                body,
431                test_name,
432                t.cast(Model, model),
433                models,
434                engine_adapter,
435                dialect,
436                path,
437                preserve_fixtures,
438                default_catalog,
439                concurrency,
440                verbosity,
441            )
442        except Exception as e:
443            raise TestError(f"Failed to create test {test_name} ({path})\n{str(e)}")

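Create a SqlModelTest or a PythonModelTest, depending on the type of the model named by the body's 'model' field.

Arguments:
  • body: A dictionary that contains test metadata like inputs and outputs.
  • test_name: The name of the test.
  • models: All models to use for expansion and mapping of physical locations.
  • engine_adapter: The engine adapter to use.
  • dialect: The models' dialect, used for normalization purposes.
  • path: An optional path to the test definition yaml file.
  • preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
  • default_catalog: The default catalog, used when normalizing the model name and passed on to the created test.
  • concurrency: Passed on to the created test.
  • verbosity: Passed on to the created test.
Returns:
  The created test, or None if the model is not found in models (a warning is logged in that case).
Raises:
  • TestError: If constructing the test fails.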
Inherited Members
unittest.case.TestCase
failureException
longMessage
maxDiff
addTypeEqualityFunc
addCleanup
addClassCleanup
setUpClass
tearDownClass
countTestCases
id
subTest
run
doCleanups
doClassCleanups
debug
skipTest
fail
assertFalse
assertTrue
assertRaises
assertWarns
assertLogs
assertNoLogs
assertEqual
assertNotEqual
assertAlmostEqual
assertNotAlmostEqual
assertSequenceEqual
assertListEqual
assertTupleEqual
assertSetEqual
assertIn
assertNotIn
assertIs
assertIsNot
assertDictEqual
assertDictContainsSubset
assertCountEqual
assertMultiLineEqual
assertLess
assertLessEqual
assertGreater
assertGreaterEqual
assertIsNone
assertIsNotNone
assertIsInstance
assertNotIsInstance
assertRaisesRegex
assertWarnsRegex
assertRegex
assertNotRegex
failUnlessRaises
failIf
assertRaisesRegexp
assertRegexpMatches
assertNotRegexpMatches
failUnlessEqual
assertEquals
failIfEqual
assertNotEquals
failUnlessAlmostEqual
assertAlmostEquals
failIfAlmostEqual
assertNotAlmostEquals
failUnless
assert_
class SqlModelTest(ModelTest):
677class SqlModelTest(ModelTest):
678    def test_ctes(self, ctes: t.Dict[str, exp.Expr], recursive: bool = False) -> None:
679        """Run CTE queries and compare output to expected output"""
680        for cte_name, values in self.body["outputs"].get("ctes", {}).items():
681            with self.subTest(cte=cte_name):
682                if cte_name not in ctes:
683                    _raise_error(
684                        f"No CTE named {cte_name} found in model {self.model.name}", self.path
685                    )
686
687                cte_query = ctes[cte_name].this
688
689                sort = cte_query.args.get("order") is None
690                partial = values.get("partial")
691
692                cte_query = exp.select(*_projection_identifiers(cte_query)).from_(cte_name)
693                for alias, cte in ctes.items():
694                    cte_query = cte_query.with_(alias, cte.this, recursive=recursive)
695
696                with self._concurrent_render_context():
697                    # Similar to the model's query, we render the CTE query under the locked context
698                    # so that the execution (fetchdf) can continue concurrently between the threads
699                    sql = cte_query.sql(
700                        self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql
701                    )
702
703                actual = self._execute(sql)
704                expected = self._create_df(values, columns=cte_query.named_selects, partial=partial)
705
706                self.assert_equal(expected, actual, sort=sort, partial=partial)
707
708    def runTest(self) -> None:
709        with self._concurrent_render_context():
710            # Render the model's query and generate the SQL under the locked context so that
711            # execution (fetchdf) can continue concurrently between the threads
712            query = self._render_model_query()
713            sql = query.sql(self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql)
714
715        with_clause = query.args.get("with_")
716
717        if with_clause:
718            self.test_ctes(
719                {
720                    self._normalize_model_name(cte.alias, with_default_catalog=False): cte
721                    for cte in query.ctes
722                },
723                recursive=with_clause.recursive,
724            )
725
726        values = self.body["outputs"].get("query")
727        if values is not None:
728            partial = values.get("partial")
729            sort = query.args.get("order") is None
730
731            actual = self._execute(sql)
732            expected = self._create_df(values, columns=self.model.columns_to_types, partial=partial)
733
734            self.assert_equal(expected, actual, sort=sort, partial=partial)
735
736    def _render_model_query(self) -> exp.Query:
737        variables = self.body.get("vars", {}).copy()
738        time_kwargs = {key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables}
739
740        query = self.model.render_query_or_raise(
741            **time_kwargs,
742            variables=variables,
743            engine_adapter=self.engine_adapter,
744            table_mapping={
745                name: self._test_fixture_table(name).sql() for name in self.body.get("inputs", {})
746            },
747            runtime_stage=RuntimeStage.TESTING,
748        )
749        return query

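Unit test for a SQL model: renders the model's query, executes it, and compares the result against the expected query output. Any CTEs listed under the test's outputs are checked the same way. The remainder of this description is inherited from unittest.TestCase.

A class whose instances are single test cases.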

By default, the test code itself should be placed in a method named 'runTest'.

If the fixture may be used for many test cases, create as many test methods as are needed. When instantiating such a TestCase subclass, specify in the constructor arguments the name of the test method that the instance is to execute.

Test authors should subclass TestCase for their own tests. Construction and deconstruction of the test's environment ('fixture') can be implemented by overriding the 'setUp' and 'tearDown' methods respectively.

If it is necessary to override the __init__ method, the base class __init__ method must always be called. It is important that subclasses should not change the signature of their __init__ method, since instances of the classes are instantiated automatically by parts of the framework in order to be run.

When subclassing TestCase, you can set these attributes:

  • failureException: determines which exception will be raised when the instance's assertion methods fail; test methods raising this exception will be deemed to have 'failed' rather than 'errored'.
  • longMessage: determines whether long messages (including repr of objects used in assert methods) will be printed on failure in addition to any explicit message passed.
  • maxDiff: sets the maximum length of a diff in failure messages by assert methods using difflib. It is looked up as an instance attribute so can be configured by individual tests if required.
def test_ctes( self, ctes: Dict[str, sqlglot.expressions.core.Expr], recursive: bool = False) -> None:
678    def test_ctes(self, ctes: t.Dict[str, exp.Expr], recursive: bool = False) -> None:
679        """Run CTE queries and compare output to expected output"""
680        for cte_name, values in self.body["outputs"].get("ctes", {}).items():
681            with self.subTest(cte=cte_name):
682                if cte_name not in ctes:
683                    _raise_error(
684                        f"No CTE named {cte_name} found in model {self.model.name}", self.path
685                    )
686
687                cte_query = ctes[cte_name].this
688
689                sort = cte_query.args.get("order") is None
690                partial = values.get("partial")
691
692                cte_query = exp.select(*_projection_identifiers(cte_query)).from_(cte_name)
693                for alias, cte in ctes.items():
694                    cte_query = cte_query.with_(alias, cte.this, recursive=recursive)
695
696                with self._concurrent_render_context():
697                    # Similar to the model's query, we render the CTE query under the locked context
698                    # so that the execution (fetchdf) can continue concurrently between the threads
699                    sql = cte_query.sql(
700                        self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql
701                    )
702
703                actual = self._execute(sql)
704                expected = self._create_df(values, columns=cte_query.named_selects, partial=partial)
705
706                self.assert_equal(expected, actual, sort=sort, partial=partial)

Run CTE queries and compare output to expected output

def runTest(self) -> None:
708    def runTest(self) -> None:
709        with self._concurrent_render_context():
710            # Render the model's query and generate the SQL under the locked context so that
711            # execution (fetchdf) can continue concurrently between the threads
712            query = self._render_model_query()
713            sql = query.sql(self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql)
714
715        with_clause = query.args.get("with_")
716
717        if with_clause:
718            self.test_ctes(
719                {
720                    self._normalize_model_name(cte.alias, with_default_catalog=False): cte
721                    for cte in query.ctes
722                },
723                recursive=with_clause.recursive,
724            )
725
726        values = self.body["outputs"].get("query")
727        if values is not None:
728            partial = values.get("partial")
729            sort = query.args.get("order") is None
730
731            actual = self._execute(sql)
732            expected = self._create_df(values, columns=self.model.columns_to_types, partial=partial)
733
734            self.assert_equal(expected, actual, sort=sort, partial=partial)
Inherited Members
ModelTest
CONCURRENT_RENDER_LOCK
body
test_name
model
models
engine_adapter
path
preserve_fixtures
default_catalog
dialect
concurrency
verbosity
defaultTestResult
shortDescription
setUp
tearDown
assert_equal
path_relative_to
create_test
unittest.case.TestCase
failureException
longMessage
maxDiff
addTypeEqualityFunc
addCleanup
addClassCleanup
setUpClass
tearDownClass
countTestCases
id
subTest
run
doCleanups
doClassCleanups
debug
skipTest
fail
assertFalse
assertTrue
assertRaises
assertWarns
assertLogs
assertNoLogs
assertEqual
assertNotEqual
assertAlmostEqual
assertNotAlmostEqual
assertSequenceEqual
assertListEqual
assertTupleEqual
assertSetEqual
assertIn
assertNotIn
assertIs
assertIsNot
assertDictEqual
assertDictContainsSubset
assertCountEqual
assertMultiLineEqual
assertLess
assertLessEqual
assertGreater
assertGreaterEqual
assertIsNone
assertIsNotNone
assertIsInstance
assertNotIsInstance
assertRaisesRegex
assertWarnsRegex
assertRegex
assertNotRegex
failUnlessRaises
failIf
assertRaisesRegexp
assertRegexpMatches
assertNotRegexpMatches
failUnlessEqual
assertEquals
failIfEqual
assertNotEquals
failUnlessAlmostEqual
assertAlmostEquals
failIfAlmostEqual
assertNotAlmostEquals
failUnless
assert_
class PythonModelTest(ModelTest):
752class PythonModelTest(ModelTest):
753    def __init__(
754        self,
755        body: t.Dict[str, t.Any],
756        test_name: str,
757        model: Model,
758        models: UniqueKeyDict[str, Model],
759        engine_adapter: EngineAdapter,
760        dialect: str | None = None,
761        path: Path | None = None,
762        preserve_fixtures: bool = False,
763        default_catalog: str | None = None,
764        concurrency: bool = False,
765        verbosity: Verbosity = Verbosity.DEFAULT,
766    ) -> None:
767        """PythonModelTest encapsulates a unit test for a Python model.
768
769        Args:
770            body: A dictionary that contains test metadata like inputs and outputs.
771            test_name: The name of the test.
772            model: The Python model that is being tested.
773            models: All models to use for expansion and mapping of physical locations.
774            engine_adapter: The engine adapter to use.
775            dialect: The models' dialect, used for normalization purposes.
776            path: An optional path to the test definition yaml file.
777            preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
778        """
779        from sqlmesh.core.test.context import TestExecutionContext
780
781        super().__init__(
782            body,
783            test_name,
784            model,
785            models,
786            engine_adapter,
787            dialect,
788            path,
789            preserve_fixtures,
790            default_catalog,
791            concurrency,
792            verbosity,
793        )
794
795        self.context = TestExecutionContext(
796            engine_adapter=engine_adapter,
797            models=models,
798            test=self,
799            default_dialect=dialect,
800            default_catalog=default_catalog,
801        )
802
803    def runTest(self) -> None:
804        values = self.body["outputs"].get("query")
805        if values is not None:
806            partial = values.get("partial")
807
808            actual_df = self._execute_model()
809            actual_df.reset_index(drop=True, inplace=True)
810            expected = self._create_df(values, columns=self.model.columns_to_types, partial=partial)
811
812            self.assert_equal(expected, actual_df, sort=True, partial=partial)
813
814    def _execute_model(self) -> pd.DataFrame:
815        """Executes the python model and returns a DataFrame."""
816        import pandas as pd
817
818        with self._concurrent_render_context():
819            variables = self.body.get("vars", {}).copy()
820            time_kwargs = {key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables}
821            df = next(self.model.render(context=self.context, variables=variables, **time_kwargs))
822
823        assert not isinstance(df, exp.Expr)
824        return df if isinstance(df, pd.DataFrame) else df.toPandas()

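Unit test for a Python model: executes the model with a TestExecutionContext and compares the resulting DataFrame against the expected query output. The remainder of this description is inherited from unittest.TestCase.

A class whose instances are single test cases.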

By default, the test code itself should be placed in a method named 'runTest'.

If the fixture may be used for many test cases, create as many test methods as are needed. When instantiating such a TestCase subclass, specify in the constructor arguments the name of the test method that the instance is to execute.

Test authors should subclass TestCase for their own tests. Construction and deconstruction of the test's environment ('fixture') can be implemented by overriding the 'setUp' and 'tearDown' methods respectively.

If it is necessary to override the __init__ method, the base class __init__ method must always be called. It is important that subclasses should not change the signature of their __init__ method, since instances of the classes are instantiated automatically by parts of the framework in order to be run.

When subclassing TestCase, you can set these attributes:

  • failureException: determines which exception will be raised when the instance's assertion methods fail; test methods raising this exception will be deemed to have 'failed' rather than 'errored'.
  • longMessage: determines whether long messages (including repr of objects used in assert methods) will be printed on failure in addition to any explicit message passed.
  • maxDiff: sets the maximum length of a diff in failure messages by assert methods using difflib. It is looked up as an instance attribute so can be configured by individual tests if required.
PythonModelTest( body: Dict[str, Any], test_name: str, model: Union[sqlmesh.core.model.definition.SqlModel, sqlmesh.core.model.definition.SeedModel, sqlmesh.core.model.definition.PythonModel, sqlmesh.core.model.definition.ExternalModel], models: sqlmesh.utils.UniqueKeyDict[str, typing.Union[sqlmesh.core.model.definition.SqlModel, sqlmesh.core.model.definition.SeedModel, sqlmesh.core.model.definition.PythonModel, sqlmesh.core.model.definition.ExternalModel]], engine_adapter: sqlmesh.core.engine_adapter.base.EngineAdapter, dialect: str | None = None, path: pathlib.Path | None = None, preserve_fixtures: bool = False, default_catalog: str | None = None, concurrency: bool = False, verbosity: sqlmesh.utils.Verbosity = <Verbosity.DEFAULT: 0>)
753    def __init__(
754        self,
755        body: t.Dict[str, t.Any],
756        test_name: str,
757        model: Model,
758        models: UniqueKeyDict[str, Model],
759        engine_adapter: EngineAdapter,
760        dialect: str | None = None,
761        path: Path | None = None,
762        preserve_fixtures: bool = False,
763        default_catalog: str | None = None,
764        concurrency: bool = False,
765        verbosity: Verbosity = Verbosity.DEFAULT,
766    ) -> None:
767        """PythonModelTest encapsulates a unit test for a Python model.
768
769        Args:
770            body: A dictionary that contains test metadata like inputs and outputs.
771            test_name: The name of the test.
772            model: The Python model that is being tested.
773            models: All models to use for expansion and mapping of physical locations.
774            engine_adapter: The engine adapter to use.
775            dialect: The models' dialect, used for normalization purposes.
776            path: An optional path to the test definition yaml file.
777            preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
778        """
779        from sqlmesh.core.test.context import TestExecutionContext
780
781        super().__init__(
782            body,
783            test_name,
784            model,
785            models,
786            engine_adapter,
787            dialect,
788            path,
789            preserve_fixtures,
790            default_catalog,
791            concurrency,
792            verbosity,
793        )
794
795        self.context = TestExecutionContext(
796            engine_adapter=engine_adapter,
797            models=models,
798            test=self,
799            default_dialect=dialect,
800            default_catalog=default_catalog,
801        )

PythonModelTest encapsulates a unit test for a Python model.

Arguments:
  • body: A dictionary that contains test metadata like inputs and outputs.
  • test_name: The name of the test.
  • model: The Python model that is being tested.
  • models: All models to use for expansion and mapping of physical locations.
  • engine_adapter: The engine adapter to use.
  • dialect: The models' dialect, used for normalization purposes.
  • path: An optional path to the test definition yaml file.
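  • preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
  • default_catalog: The default catalog, passed to ModelTest and to the TestExecutionContext.
  • concurrency: Passed to ModelTest.
  • verbosity: Passed to ModelTest.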
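context

The TestExecutionContext that is passed to the model's render call when the test executes the Python model.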
def runTest(self) -> None:
803    def runTest(self) -> None:
804        values = self.body["outputs"].get("query")
805        if values is not None:
806            partial = values.get("partial")
807
808            actual_df = self._execute_model()
809            actual_df.reset_index(drop=True, inplace=True)
810            expected = self._create_df(values, columns=self.model.columns_to_types, partial=partial)
811
812            self.assert_equal(expected, actual_df, sort=True, partial=partial)
Inherited Members
ModelTest
CONCURRENT_RENDER_LOCK
body
test_name
model
models
engine_adapter
path
preserve_fixtures
default_catalog
dialect
concurrency
verbosity
defaultTestResult
shortDescription
setUp
tearDown
assert_equal
path_relative_to
create_test
unittest.case.TestCase
failureException
longMessage
maxDiff
addTypeEqualityFunc
addCleanup
addClassCleanup
setUpClass
tearDownClass
countTestCases
id
subTest
run
doCleanups
doClassCleanups
debug
skipTest
fail
assertFalse
assertTrue
assertRaises
assertWarns
assertLogs
assertNoLogs
assertEqual
assertNotEqual
assertAlmostEqual
assertNotAlmostEqual
assertSequenceEqual
assertListEqual
assertTupleEqual
assertSetEqual
assertIn
assertNotIn
assertIs
assertIsNot
assertDictEqual
assertDictContainsSubset
assertCountEqual
assertMultiLineEqual
assertLess
assertLessEqual
assertGreater
assertGreaterEqual
assertIsNone
assertIsNotNone
assertIsInstance
assertNotIsInstance
assertRaisesRegex
assertWarnsRegex
assertRegex
assertNotRegex
failUnlessRaises
failIf
assertRaisesRegexp
assertRegexpMatches
assertNotRegexpMatches
failUnlessEqual
assertEquals
failIfEqual
assertNotEquals
failUnlessAlmostEqual
assertAlmostEquals
failIfAlmostEqual
assertNotAlmostEquals
failUnless
assert_
def generate_test( model: Union[sqlmesh.core.model.definition.SqlModel, sqlmesh.core.model.definition.SeedModel, sqlmesh.core.model.definition.PythonModel, sqlmesh.core.model.definition.ExternalModel], input_queries: Dict[str, str], models: sqlmesh.utils.UniqueKeyDict[str, typing.Union[sqlmesh.core.model.definition.SqlModel, sqlmesh.core.model.definition.SeedModel, sqlmesh.core.model.definition.PythonModel, sqlmesh.core.model.definition.ExternalModel]], engine_adapter: sqlmesh.core.engine_adapter.base.EngineAdapter, test_engine_adapter: sqlmesh.core.engine_adapter.base.EngineAdapter, project_path: pathlib.Path, overwrite: bool = False, variables: Optional[Dict[str, str]] = None, path: Optional[str] = None, name: Optional[str] = None, include_ctes: bool = False) -> None:
827def generate_test(
828    model: Model,
829    input_queries: t.Dict[str, str],
830    models: UniqueKeyDict[str, Model],
831    engine_adapter: EngineAdapter,
832    test_engine_adapter: EngineAdapter,
833    project_path: Path,
834    overwrite: bool = False,
835    variables: t.Optional[t.Dict[str, str]] = None,
836    path: t.Optional[str] = None,
837    name: t.Optional[str] = None,
838    include_ctes: bool = False,
839) -> None:
840    """Generate a unit test fixture for a given model.
841
842    Args:
843        model: The model to test.
844        input_queries: Mapping of model names to queries. Each model included in this mapping
845            will be populated in the test based on the results of the corresponding query.
846        models: The context's models.
847        engine_adapter: The target engine adapter.
848        test_engine_adapter: The test engine adapter.
849        project_path: The path pointing to the project's root directory.
850        overwrite: Whether to overwrite the existing test in case of a file path collision.
851            When set to False, an error will be raised if there is such a collision.
852        variables: Key-value pairs that will define variables needed by the model.
853        path: The file path corresponding to the fixture, relative to the test directory.
854            By default, the fixture will be created under the test directory and the file name
855            will be inferred from the test's name.
856        name: The name of the test. This is inferred from the model name by default.
857        include_ctes: When true, CTE fixtures will also be generated.
858    """
859    import numpy as np
860
861    test_name = name or f"test_{model.view_name}"
862    path = path or f"{test_name}.yaml"
863
864    extension = path.split(".")[-1].lower()
865    if extension not in ("yaml", "yml"):
866        path = f"{path}.yaml"
867
868    fixture_path = project_path / c.TESTS / path
869    if not overwrite and fixture_path.exists():
870        raise ConfigError(
871            f"Fixture '{fixture_path}' already exists, make sure to set --overwrite if it can be safely overwritten."
872        )
873
874    # ruamel.yaml does not support pandas Timestamps, so we must convert them to python
875    # datetime or datetime.date objects based on column type
876    inputs = {
877        dep: pandas_timestamp_to_pydatetime(
878            engine_adapter.fetchdf(query).apply(lambda col: col.map(_normalize_df_value)),
879            models[dep].columns_to_types,
880        )
881        .replace({np.nan: None})
882        .to_dict(orient="records")
883        for dep, query in input_queries.items()
884    }
885    outputs: t.Dict[str, t.Any] = {"query": {}}
886    variables = variables or {}
887    test_body = {"model": model.fqn, "inputs": inputs, "outputs": outputs}
888
889    if variables:
890        test_body["vars"] = variables
891
892    test = ModelTest.create_test(
893        body=test_body.copy(),
894        test_name=test_name,
895        models=models,
896        engine_adapter=test_engine_adapter,
897        dialect=model.dialect,
898        path=fixture_path,
899        default_catalog=model.default_catalog,
900    )
901    if not test:
902        return
903
904    test.setUp()
905
906    if isinstance(model, SqlModel):
907        assert isinstance(test, SqlModelTest)
908        model_query = test._render_model_query()
909        with_clause = model_query.args.get("with_")
910
911        if with_clause and include_ctes:
912            ctes = {}
913            recursive = with_clause.recursive
914            previous_ctes: t.List[exp.CTE] = []
915
916            for cte in model_query.ctes:
917                cte_query = cte.this
918                cte_identifier = cte.args["alias"].this
919
920                cte_query = exp.select(*_projection_identifiers(cte_query)).from_(cte_identifier)
921
922                for prev in chain(previous_ctes, [cte]):
923                    cte_query = cte_query.with_(
924                        prev.args["alias"].this, prev.this, recursive=recursive
925                    )
926
927                cte_output = test._execute(cte_query)
928                ctes[cte.alias] = (
929                    pandas_timestamp_to_pydatetime(
930                        df=cte_output.apply(lambda col: col.map(_normalize_df_value)),
931                    )
932                    .replace({np.nan: None})
933                    .to_dict(orient="records")
934                )
935
936                previous_ctes.append(cte)
937
938            if ctes:
939                outputs["ctes"] = ctes
940
941        output = test._execute(model_query)
942    else:
943        output = t.cast(PythonModelTest, test)._execute_model()
944
945    outputs["query"] = (
946        pandas_timestamp_to_pydatetime(
947            output.apply(lambda col: col.map(_normalize_df_value)), model.columns_to_types
948        )
949        .replace({np.nan: None})
950        .to_dict(orient="records")
951    )
952
953    test.tearDown()
954
955    fixture_path.parent.mkdir(exist_ok=True, parents=True)
956    with open(fixture_path, "w", encoding="utf-8") as file:
957        yaml.dump({test_name: test_body}, file)

Generate a unit test fixture for a given model.

Arguments:
  • model: The model to test.
  • input_queries: Mapping of model names to queries. Each model included in this mapping will be populated in the test based on the results of the corresponding query.
  • models: The context's models.
  • engine_adapter: The target engine adapter.
  • test_engine_adapter: The test engine adapter.
  • project_path: The path pointing to the project's root directory.
  • overwrite: Whether to overwrite the existing test in case of a file path collision. When set to False, a ConfigError is raised if there is such a collision.
  • variables: Key-value pairs that will define variables needed by the model.
  • path: The file path corresponding to the fixture, relative to the test directory. By default, the fixture will be created under the test directory and the file name will be inferred from the test's name. If the path does not end in .yaml or .yml, a .yaml extension is appended.
  • name: The name of the test. This is inferred from the model name by default.
  • include_ctes: When true, CTE fixtures will also be generated. This only applies to SQL models whose query has a WITH clause.