sqlmesh.core.lineage
1from __future__ import annotations 2 3import typing as t 4from collections import defaultdict 5 6from sqlglot import exp 7from sqlglot.helper import first 8from sqlglot.lineage import Node 9from sqlglot.lineage import lineage as sqlglot_lineage 10from sqlglot.optimizer import Scope, build_scope, qualify 11 12from sqlmesh.core.dialect import normalize_mapping_schema, normalize_model_name 13 14if t.TYPE_CHECKING: 15 from sqlmesh.core.context import Context 16 from sqlmesh.core.model import Model 17 18 19CACHE: t.Dict[str, t.Tuple[int, exp.Expr, Scope]] = {} 20 21 22def lineage( 23 column: str | exp.Column, 24 model: Model, 25 trim_selects: bool = True, 26 **kwargs: t.Any, 27) -> Node: 28 query: t.Optional[exp.Expr] = None 29 scope: t.Optional[Scope] = None 30 31 if model.name in CACHE: 32 obj_id, query, scope = CACHE[model.name] 33 34 if obj_id != id(model): 35 query = None 36 scope = None 37 38 if not query or not scope: 39 query = t.cast(exp.Query, model.render_query_or_raise().copy()) 40 41 if model.managed_columns: 42 query = query.select( 43 *( 44 exp.alias_(exp.cast(exp.Null(), to=col_type), col) 45 for col, col_type in model.managed_columns.items() 46 if col not in query.named_selects 47 ), 48 copy=False, 49 ) 50 51 query = qualify.qualify( 52 query, 53 dialect=model.dialect, 54 schema=normalize_mapping_schema(model.mapping_schema, dialect=model.dialect), 55 **{"validate_qualify_columns": False, "infer_schema": True, **kwargs}, 56 ) 57 58 scope = build_scope(query) 59 60 if scope: 61 CACHE[model.name] = (id(model), query, scope) 62 63 return sqlglot_lineage( 64 column, 65 sql=query, 66 scope=scope, 67 trim_selects=trim_selects, 68 dialect=model.dialect, 69 copy=False, 70 ) 71 72 73def column_dependencies( 74 context: Context, model_name: str, column: str | exp.Column 75) -> t.Dict[str, t.Set[str]]: 76 model = context.get_model(model_name) 77 parents = defaultdict(set) 78 79 for node in lineage(column, model, trim_selects=False).walk(): 80 if node.downstream: 
81 continue 82 83 table = node.expression.find(exp.Table) 84 if table: 85 name = normalize_model_name( 86 table, default_catalog=context.default_catalog, dialect=model.dialect 87 ) 88 parents[name].add(exp.to_column(node.name).name) 89 return dict(parents) 90 91 92def column_description( 93 context: Context, model_name: str, column: str, quote_column: bool = False 94) -> t.Optional[str]: 95 """Returns a column's description, inferring if needed.""" 96 model = context.get_model(model_name) 97 98 if not model: 99 return None 100 101 if column in model.column_descriptions: 102 return model.column_descriptions[column] 103 104 dependencies = column_dependencies(context, model_name, exp.column(column, quoted=quote_column)) 105 106 if len(dependencies) != 1: 107 return None 108 109 parent, columns = first(dependencies.items()) 110 111 if len(columns) != 1: 112 return None 113 114 return column_description(context, parent, first(columns))
CACHE: Dict[str, Tuple[int, sqlglot.expressions.core.Expr, sqlglot.optimizer.scope.Scope]] =
{}
def
lineage( column: str | sqlglot.expressions.core.Column, model: Union[sqlmesh.core.model.definition.SqlModel, sqlmesh.core.model.definition.SeedModel, sqlmesh.core.model.definition.PythonModel, sqlmesh.core.model.definition.ExternalModel], trim_selects: bool = True, **kwargs: Any) -> sqlglot.lineage.Node:
def lineage(
    column: str | exp.Column,
    model: Model,
    trim_selects: bool = True,
    **kwargs: t.Any,
) -> Node:
    """Build the sqlglot lineage node for ``column`` of ``model``.

    The model's rendered query is copied, extended with NULL placeholders for
    managed columns it does not already select, qualified against the model's
    mapping schema, and scoped. The qualified query and scope are memoized in
    ``CACHE`` under the model name; an entry is reused only if it was stored
    for the same model object (compared via ``id``).

    Args:
        column: The column to trace, as a name or a sqlglot column expression.
        model: The model whose query is analysed.
        trim_selects: Forwarded to sqlglot's ``lineage``.
        **kwargs: Extra options merged over the defaults passed to
            ``qualify.qualify``. They are only consulted when the query is
            (re)built, i.e. on a cache miss.

    Returns:
        The root lineage node produced by sqlglot.
    """
    query: t.Optional[exp.Expr] = None
    scope: t.Optional[Scope] = None

    cached = CACHE.get(model.name)
    if cached is not None and cached[0] == id(model):
        _, query, scope = cached

    if not (query and scope):
        rendered = t.cast(exp.Query, model.render_query_or_raise().copy())

        if model.managed_columns:
            # Managed columns absent from the projection are added as typed
            # NULLs so they can be resolved by the lineage walk.
            placeholders = [
                exp.alias_(exp.cast(exp.Null(), to=col_type), col)
                for col, col_type in model.managed_columns.items()
                if col not in rendered.named_selects
            ]
            rendered = rendered.select(*placeholders, copy=False)

        qualify_options = {
            "validate_qualify_columns": False,
            "infer_schema": True,
            **kwargs,
        }
        query = qualify.qualify(
            rendered,
            dialect=model.dialect,
            schema=normalize_mapping_schema(model.mapping_schema, dialect=model.dialect),
            **qualify_options,
        )

        scope = build_scope(query)

        if scope:
            CACHE[model.name] = (id(model), query, scope)

    return sqlglot_lineage(
        column,
        sql=query,
        scope=scope,
        trim_selects=trim_selects,
        dialect=model.dialect,
        copy=False,
    )
def
column_dependencies( context: sqlmesh.core.context.Context, model_name: str, column: str | sqlglot.expressions.core.Column) -> Dict[str, Set[str]]:
def column_dependencies(
    context: Context, model_name: str, column: str | exp.Column
) -> t.Dict[str, t.Set[str]]:
    """Collect the upstream tables and columns that ``column`` derives from.

    Args:
        context: Context used to look up the model and the default catalog.
        model_name: Name of the model that owns ``column``.
        column: The column to trace.

    Returns:
        A mapping of normalized upstream table name to the set of column names
        read from it. Empty when the model cannot be found.
    """
    model = context.get_model(model_name)
    if not model:
        # Mirror column_description(): an unknown model has no dependencies,
        # instead of failing inside lineage() on a missing attribute.
        return {}

    parents: t.Dict[str, t.Set[str]] = defaultdict(set)

    for node in lineage(column, model, trim_selects=False).walk():
        # Only nodes without further downstream nodes are recorded.
        if node.downstream:
            continue

        table = node.expression.find(exp.Table)
        if table:
            name = normalize_model_name(
                table, default_catalog=context.default_catalog, dialect=model.dialect
            )
            parents[name].add(exp.to_column(node.name).name)

    return dict(parents)
def
column_description( context: sqlmesh.core.context.Context, model_name: str, column: str, quote_column: bool = False) -> Optional[str]:
def column_description(
    context: Context, model_name: str, column: str, quote_column: bool = False
) -> t.Optional[str]:
    """Returns a column's description, inferring if needed.

    When the model does not describe the column itself, the description is
    looked up recursively on the parent column, provided the column depends on
    exactly one column of exactly one parent table.

    Args:
        context: Context used to look up models.
        model_name: Name of the model that owns ``column``.
        column: Name of the column to describe.
        quote_column: Whether the column identifier is built as quoted when
            tracing its dependencies. Not propagated to the recursive lookup.

    Returns:
        The description, or None if the model is unknown or no unambiguous
        description could be found.
    """
    model = context.get_model(model_name)
    if not model:
        return None

    described = model.column_descriptions
    if column in described:
        return described[column]

    target = exp.column(column, quoted=quote_column)
    upstream = column_dependencies(context, model_name, target)
    if len(upstream) != 1:
        return None

    parent_name, parent_columns = first(upstream.items())
    if len(parent_columns) != 1:
        return None

    # Exactly one parent column: defer to that column's description.
    return column_description(context, parent_name, first(parent_columns))
Returns a column's description, inferring if needed.