Edit on GitHub

sqlmesh.core.model.schema

  1from __future__ import annotations
  2
  3import typing as t
  4from concurrent.futures import as_completed
  5from pathlib import Path
  6
  7from sqlglot.errors import SchemaError
  8from sqlglot.schema import MappingSchema
  9
 10from sqlmesh.core.model.cache import (
 11    load_optimized_query_and_mapping,
 12    optimized_query_cache_pool,
 13    OptimizedQueryCache,
 14)
 15
 16if t.TYPE_CHECKING:
 17    from sqlmesh.core.model.definition import Model
 18    from sqlmesh.utils import UniqueKeyDict
 19    from sqlmesh.utils.dag import DAG
 20
 21
 22def update_model_schemas(
 23    dag: DAG[str],
 24    models: UniqueKeyDict[str, Model],
 25    cache_dir: Path,
 26) -> None:
 27    schema = MappingSchema(normalize=False)
 28    optimized_query_cache: OptimizedQueryCache = OptimizedQueryCache(cache_dir)
 29
 30    _update_model_schemas(dag, models, schema, optimized_query_cache)
 31
 32
 33def _update_schema_with_model(schema: MappingSchema, model: Model) -> None:
 34    columns_to_types = model.columns_to_types
 35    if columns_to_types:
 36        try:
 37            schema.add_table(model.fqn, columns_to_types, dialect=model.dialect)
 38        except SchemaError as e:
 39            if "nesting level:" in str(e):
 40                from sqlmesh.core.console import get_console
 41
 42                get_console().log_error(
 43                    "SQLMesh requires all model names and references to have the same level of nesting."
 44                )
 45            raise
 46
 47
 48def _update_model_schemas(
 49    dag: DAG[str],
 50    models: UniqueKeyDict[str, Model],
 51    schema: MappingSchema,
 52    optimized_query_cache: OptimizedQueryCache,
 53) -> None:
 54    futures = set()
 55    graph = {
 56        model: {dep for dep in deps if dep in models}
 57        for model, deps in dag._dag.items()
 58        if model in models
 59    }
 60
 61    def process_models(completed_model: t.Optional[Model] = None) -> None:
 62        for name in list(graph):
 63            deps = graph[name]
 64
 65            if completed_model:
 66                deps.discard(completed_model.fqn)
 67
 68            if not deps:
 69                del graph[name]
 70                model = models[name]
 71                futures.add(
 72                    executor.submit(
 73                        load_optimized_query_and_mapping,
 74                        model,
 75                        mapping={
 76                            parent: models[parent].columns_to_types
 77                            for parent in model.depends_on
 78                            if parent in models
 79                        },
 80                    )
 81                )
 82
 83    with optimized_query_cache_pool(optimized_query_cache) as executor:
 84        process_models()
 85
 86        while futures:
 87            for future in as_completed(futures):
 88                try:
 89                    futures.remove(future)
 90                    fqn, entry_name, data_hash, metadata_hash, mapping_schema = future.result()
 91                    model = models[fqn]
 92                    model._data_hash = data_hash
 93                    model._metadata_hash = metadata_hash
 94                    if model.mapping_schema != mapping_schema:
 95                        model.set_mapping_schema(mapping_schema)
 96                    optimized_query_cache.with_optimized_query(model, entry_name)
 97                    _update_schema_with_model(schema, model)
 98                    process_models(completed_model=model)
 99                except Exception as ex:
100                    raise SchemaError(f"Failed to update model schemas\n\n{ex}")
def update_model_schemas(
    dag: DAG[str],
    models: UniqueKeyDict[str, Model],
    cache_dir: Path,
) -> None:
    """Refresh the mapping schemas and optimized queries of ``models``.

    Public entry point that builds fresh working state and hands it to
    ``_update_model_schemas``:

    - a ``MappingSchema`` created with ``normalize=False``;
    - an ``OptimizedQueryCache`` constructed from ``cache_dir``.

    Args:
        dag: Dependency graph that decides the order in which models are processed.
        models: Models to update; they are modified in place.
        cache_dir: Directory passed to ``OptimizedQueryCache``.
    """
    # Arguments are evaluated left to right, so the schema is still created
    # before the cache.
    _update_model_schemas(
        dag,
        models,
        MappingSchema(normalize=False),
        OptimizedQueryCache(cache_dir),
    )