Edit on GitHub

sqlmesh.core.lineage

  1from __future__ import annotations
  2
  3import typing as t
  4from collections import defaultdict
  5
  6from sqlglot import exp
  7from sqlglot.helper import first
  8from sqlglot.lineage import Node
  9from sqlglot.lineage import lineage as sqlglot_lineage
 10from sqlglot.optimizer import Scope, build_scope, qualify
 11
 12from sqlmesh.core.dialect import normalize_mapping_schema, normalize_model_name
 13
 14if t.TYPE_CHECKING:
 15    from sqlmesh.core.context import Context
 16    from sqlmesh.core.model import Model
 17
 18
 19CACHE: t.Dict[str, t.Tuple[int, exp.Expr, Scope]] = {}
 20
 21
 22def lineage(
 23    column: str | exp.Column,
 24    model: Model,
 25    trim_selects: bool = True,
 26    **kwargs: t.Any,
 27) -> Node:
 28    query: t.Optional[exp.Expr] = None
 29    scope: t.Optional[Scope] = None
 30
 31    if model.name in CACHE:
 32        obj_id, query, scope = CACHE[model.name]
 33
 34        if obj_id != id(model):
 35            query = None
 36            scope = None
 37
 38    if not query or not scope:
 39        query = t.cast(exp.Query, model.render_query_or_raise().copy())
 40
 41        if model.managed_columns:
 42            query = query.select(
 43                *(
 44                    exp.alias_(exp.cast(exp.Null(), to=col_type), col)
 45                    for col, col_type in model.managed_columns.items()
 46                    if col not in query.named_selects
 47                ),
 48                copy=False,
 49            )
 50
 51        query = qualify.qualify(
 52            query,
 53            dialect=model.dialect,
 54            schema=normalize_mapping_schema(model.mapping_schema, dialect=model.dialect),
 55            **{"validate_qualify_columns": False, "infer_schema": True, **kwargs},
 56        )
 57
 58        scope = build_scope(query)
 59
 60        if scope:
 61            CACHE[model.name] = (id(model), query, scope)
 62
 63    return sqlglot_lineage(
 64        column,
 65        sql=query,
 66        scope=scope,
 67        trim_selects=trim_selects,
 68        dialect=model.dialect,
 69        copy=False,
 70    )
 71
 72
 73def column_dependencies(
 74    context: Context, model_name: str, column: str | exp.Column
 75) -> t.Dict[str, t.Set[str]]:
 76    model = context.get_model(model_name)
 77    parents = defaultdict(set)
 78
 79    for node in lineage(column, model, trim_selects=False).walk():
 80        if node.downstream:
 81            continue
 82
 83        table = node.expression.find(exp.Table)
 84        if table:
 85            name = normalize_model_name(
 86                table, default_catalog=context.default_catalog, dialect=model.dialect
 87            )
 88            parents[name].add(exp.to_column(node.name).name)
 89    return dict(parents)
 90
 91
 92def column_description(
 93    context: Context, model_name: str, column: str, quote_column: bool = False
 94) -> t.Optional[str]:
 95    """Returns a column's description, inferring if needed."""
 96    model = context.get_model(model_name)
 97
 98    if not model:
 99        return None
100
101    if column in model.column_descriptions:
102        return model.column_descriptions[column]
103
104    dependencies = column_dependencies(context, model_name, exp.column(column, quoted=quote_column))
105
106    if len(dependencies) != 1:
107        return None
108
109    parent, columns = first(dependencies.items())
110
111    if len(columns) != 1:
112        return None
113
114    return column_description(context, parent, first(columns))
CACHE: Dict[str, Tuple[int, sqlglot.expressions.core.Expr, sqlglot.optimizer.scope.Scope]] = {}
def lineage( column: str | sqlglot.expressions.core.Column, model: Union[sqlmesh.core.model.definition.SqlModel, sqlmesh.core.model.definition.SeedModel, sqlmesh.core.model.definition.PythonModel, sqlmesh.core.model.definition.ExternalModel], trim_selects: bool = True, **kwargs: Any) -> sqlglot.lineage.Node:
def lineage(
    column: str | exp.Column,
    model: Model,
    trim_selects: bool = True,
    **kwargs: t.Any,
) -> Node:
    """Build the sqlglot lineage graph for ``column`` in ``model``'s rendered query.

    The model's query is rendered and copied. Any managed columns it does not already
    select are appended as ``CAST(NULL AS <type>)`` aliases. The query is then qualified
    against the model's normalized mapping schema and its scope is built. The qualified
    query and scope are cached per model name in the module-level ``CACHE`` so that
    repeated calls are cheap.

    Args:
        column: The output column to trace.
        model: The model whose query is analysed.
        trim_selects: Forwarded to sqlglot's ``lineage``.
        **kwargs: Extra options forwarded to ``qualify.qualify``. They override the
            defaults ``validate_qualify_columns=False`` and ``infer_schema=True``.
            Because the cache key does not include these options, calls that pass any
            kwargs neither read from nor write to the cache.

    Returns:
        The root lineage ``Node`` for ``column``.
    """
    query: t.Optional[exp.Expr] = None
    scope: t.Optional[Scope] = None

    # A cached entry reflects only the default qualify options. Using it, or storing into
    # it, when custom kwargs are present would hand out a query qualified with the wrong
    # settings.
    use_cache = not kwargs

    if use_cache and model.name in CACHE:
        obj_id, cached_query, cached_scope = CACHE[model.name]

        # NOTE(review): id() values may be reused once an object is garbage collected,
        # so this identity check is best-effort.
        if obj_id == id(model):
            query, scope = cached_query, cached_scope

    if not query or not scope:
        query = t.cast(exp.Query, model.render_query_or_raise().copy())

        if model.managed_columns:
            # Expose managed columns that the user query does not already select, so
            # their lineage can be requested too.
            query = query.select(
                *(
                    exp.alias_(exp.cast(exp.Null(), to=col_type), col)
                    for col, col_type in model.managed_columns.items()
                    if col not in query.named_selects
                ),
                copy=False,
            )

        query = qualify.qualify(
            query,
            dialect=model.dialect,
            schema=normalize_mapping_schema(model.mapping_schema, dialect=model.dialect),
            **{"validate_qualify_columns": False, "infer_schema": True, **kwargs},
        )

        scope = build_scope(query)

        if scope and use_cache:
            CACHE[model.name] = (id(model), query, scope)

    return sqlglot_lineage(
        column,
        sql=query,
        scope=scope,
        trim_selects=trim_selects,
        dialect=model.dialect,
        copy=False,
    )
def column_dependencies( context: sqlmesh.core.context.Context, model_name: str, column: str | sqlglot.expressions.core.Column) -> Dict[str, Set[str]]:
def column_dependencies(
    context: Context, model_name: str, column: str | exp.Column
) -> t.Dict[str, t.Set[str]]:
    """Map each upstream table that feeds ``column`` to the source column names read from it.

    Only lineage nodes with no downstream nodes are considered, and only when their
    expression contains a table. Table names are normalized with the context's default
    catalog and the model's dialect.
    """
    model = context.get_model(model_name)
    graph = lineage(column, model, trim_selects=False)

    result: t.Dict[str, t.Set[str]] = defaultdict(set)
    terminal_nodes = (candidate for candidate in graph.walk() if not candidate.downstream)

    for terminal in terminal_nodes:
        source_table = terminal.expression.find(exp.Table)
        if not source_table:
            continue

        table_name = normalize_model_name(
            source_table, default_catalog=context.default_catalog, dialect=model.dialect
        )
        result[table_name].add(exp.to_column(terminal.name).name)

    return dict(result)
def column_description( context: sqlmesh.core.context.Context, model_name: str, column: str, quote_column: bool = False) -> Optional[str]:
 93def column_description(
 94    context: Context, model_name: str, column: str, quote_column: bool = False
 95) -> t.Optional[str]:
 96    """Returns a column's description, inferring if needed."""
 97    model = context.get_model(model_name)
 98
 99    if not model:
100        return None
101
102    if column in model.column_descriptions:
103        return model.column_descriptions[column]
104
105    dependencies = column_dependencies(context, model_name, exp.column(column, quoted=quote_column))
106
107    if len(dependencies) != 1:
108        return None
109
110    parent, columns = first(dependencies.items())
111
112    if len(columns) != 1:
113        return None
114
115    return column_description(context, parent, first(columns))

Returns a column's description, inferring if needed.