Edit on GitHub

sqlmesh.core.lineage

  1from __future__ import annotations
  2
  3import typing as t
  4from collections import defaultdict
  5
  6from sqlglot import exp
  7from sqlglot.helper import first
  8from sqlglot.lineage import Node
  9from sqlglot.lineage import lineage as sqlglot_lineage
 10from sqlglot.optimizer import Scope, build_scope, qualify
 11
 12from sqlmesh.core.dialect import normalize_mapping_schema, normalize_model_name
 13
 14if t.TYPE_CHECKING:
 15    from sqlmesh.core.context import Context
 16    from sqlmesh.core.model import Model
 17
 18
 19CACHE: t.Dict[str, t.Tuple[int, exp.Expression, Scope]] = {}
 20
 21
 22def lineage(
 23    column: str | exp.Column,
 24    model: Model,
 25    trim_selects: bool = True,
 26    **kwargs: t.Any,
 27) -> Node:
 28    query = None
 29    scope = None
 30
 31    if model.name in CACHE:
 32        obj_id, query, scope = CACHE[model.name]
 33
 34        if obj_id != id(model):
 35            query = None
 36            scope = None
 37
 38    if not query or not scope:
 39        query = t.cast(exp.Query, model.render_query_or_raise().copy())
 40
 41        if model.managed_columns:
 42            query = query.select(
 43                *(
 44                    exp.alias_(exp.cast(exp.Null(), to=col_type), col)
 45                    for col, col_type in model.managed_columns.items()
 46                    if col not in query.named_selects
 47                ),
 48                copy=False,
 49            )
 50
 51        query = qualify.qualify(
 52            query,
 53            dialect=model.dialect,
 54            schema=normalize_mapping_schema(model.mapping_schema, dialect=model.dialect),
 55            **{"validate_qualify_columns": False, "infer_schema": True, **kwargs},
 56        )
 57
 58        scope = build_scope(query)
 59
 60        if scope:
 61            CACHE[model.name] = (id(model), query, scope)
 62
 63    return sqlglot_lineage(
 64        column,
 65        sql=query,
 66        scope=scope,
 67        trim_selects=trim_selects,
 68        dialect=model.dialect,
 69    )
 70
 71
 72def column_dependencies(context: Context, model_name: str, column: str) -> t.Dict[str, t.Set[str]]:
 73    model = context.get_model(model_name)
 74    parents = defaultdict(set)
 75
 76    for node in lineage(column, model, trim_selects=False).walk():
 77        if node.downstream:
 78            continue
 79
 80        table = node.expression.find(exp.Table)
 81        if table:
 82            name = normalize_model_name(
 83                table, default_catalog=context.default_catalog, dialect=model.dialect
 84            )
 85            parents[name].add(exp.to_column(node.name).name)
 86    return dict(parents)
 87
 88
 89def column_description(context: Context, model_name: str, column: str) -> t.Optional[str]:
 90    """Returns a column's description, inferring if needed."""
 91    model = context.get_model(model_name)
 92
 93    if column in model.column_descriptions:
 94        return model.column_descriptions[column]
 95
 96    dependencies = column_dependencies(context, model_name, column)
 97
 98    if len(dependencies) != 1:
 99        return None
100
101    parent, columns = first(dependencies.items())
102
103    if len(columns) != 1:
104        return None
105
106    return column_description(context, parent, first(columns))
def lineage( column: 'str | exp.Column', model: Union[sqlmesh.core.model.definition.SqlModel, sqlmesh.core.model.definition.SeedModel, sqlmesh.core.model.definition.PythonModel, sqlmesh.core.model.definition.ExternalModel], trim_selects: bool = True, **kwargs: Any) -> sqlglot.lineage.Node:
def lineage(
    column: str | exp.Column,
    model: Model,
    trim_selects: bool = True,
    **kwargs: t.Any,
) -> Node:
    """Build the sqlglot lineage graph for ``column`` in ``model``'s rendered query.

    The model's query is rendered and extended with any managed columns it does not already
    select. It is then qualified against the model's mapping schema and scoped. The qualified
    query and its scope are cached per model name, tagged with the ``id`` of the model object,
    so repeated lookups against the same model object skip that work.

    Args:
        column: The output column to trace, as a name or a column expression.
        model: The model whose rendered query is analysed.
        trim_selects: Forwarded to sqlglot's ``lineage``.
        **kwargs: Extra options forwarded to ``qualify.qualify``. They override the defaults
            ``validate_qualify_columns=False`` and ``infer_schema=True``. Because they change
            how the query is qualified, calls that pass any are computed fresh and neither
            read nor write the shared cache.

    Returns:
        The root ``Node`` of the lineage graph produced by sqlglot.
    """
    query = None
    scope = None

    # The cache only ever holds queries qualified with the default options. Reusing it for a
    # call with custom qualify kwargs, or storing such a call's result, would hand one
    # caller's qualification to another.
    use_cache = not kwargs

    if use_cache and model.name in CACHE:
        obj_id, cached_query, cached_scope = CACHE[model.name]

        # Only reuse the entry if it was built from this very model object.
        # NOTE(review): id() values may be recycled once an older model object is garbage
        # collected, so this check is best-effort; storing the model itself would be stricter.
        if obj_id == id(model):
            query, scope = cached_query, cached_scope

    if not query or not scope:
        query = t.cast(exp.Query, model.render_query_or_raise().copy())

        if model.managed_columns:
            # Append each managed column the query does not already select, as a NULL cast
            # to the column's declared type, so it appears among the query's outputs.
            query = query.select(
                *(
                    exp.alias_(exp.cast(exp.Null(), to=col_type), col)
                    for col, col_type in model.managed_columns.items()
                    if col not in query.named_selects
                ),
                copy=False,
            )

        query = qualify.qualify(
            query,
            dialect=model.dialect,
            schema=normalize_mapping_schema(model.mapping_schema, dialect=model.dialect),
            **{"validate_qualify_columns": False, "infer_schema": True, **kwargs},
        )

        scope = build_scope(query)

        if use_cache and scope:
            CACHE[model.name] = (id(model), query, scope)

    return sqlglot_lineage(
        column,
        sql=query,
        scope=scope,
        trim_selects=trim_selects,
        dialect=model.dialect,
    )
def column_dependencies( context: sqlmesh.core.context.Context, model_name: str, column: str) -> Dict[str, Set[str]]:
def column_dependencies(context: Context, model_name: str, column: str) -> t.Dict[str, t.Set[str]]:
    """Collect the upstream source columns that ``column`` of ``model_name`` is derived from.

    Args:
        context: The SQLMesh context used to resolve the model and the default catalog.
        model_name: Name of the model that owns ``column``.
        column: The output column whose dependencies are traced.

    Returns:
        A mapping from each normalized upstream table name to the set of column names read
        from that table.
    """
    model = context.get_model(model_name)
    sources: t.DefaultDict[str, t.Set[str]] = defaultdict(set)

    root = lineage(column, model, trim_selects=False)
    for node in root.walk():
        # Nodes that still have ``downstream`` children are skipped; only the terminal
        # nodes are inspected for the table they reference.
        if node.downstream:
            continue

        source_table = node.expression.find(exp.Table)
        if not source_table:
            continue

        table_name = normalize_model_name(
            source_table, default_catalog=context.default_catalog, dialect=model.dialect
        )
        sources[table_name].add(exp.to_column(node.name).name)

    return dict(sources)
def column_description( context: sqlmesh.core.context.Context, model_name: str, column: str) -> Union[str, NoneType]:
 90def column_description(context: Context, model_name: str, column: str) -> t.Optional[str]:
 91    """Returns a column's description, inferring if needed."""
 92    model = context.get_model(model_name)
 93
 94    if column in model.column_descriptions:
 95        return model.column_descriptions[column]
 96
 97    dependencies = column_dependencies(context, model_name, column)
 98
 99    if len(dependencies) != 1:
100        return None
101
102    parent, columns = first(dependencies.items())
103
104    if len(columns) != 1:
105        return None
106
107    return column_description(context, parent, first(columns))

Returns a column's description, inferring if needed.