sqlmesh.core.lineage
"""Column-level lineage helpers for SQLMesh models, built on sqlglot's lineage."""

from __future__ import annotations

import typing as t
from collections import defaultdict

from sqlglot import exp
from sqlglot.helper import first
from sqlglot.lineage import Node
from sqlglot.lineage import lineage as sqlglot_lineage
from sqlglot.optimizer import Scope, build_scope, qualify

from sqlmesh.core.dialect import normalize_mapping_schema, normalize_model_name

if t.TYPE_CHECKING:
    from sqlmesh.core.context import Context
    from sqlmesh.core.model import Model


# model name -> (id() of the model object, qualified query, its scope); filled by `lineage`.
CACHE: t.Dict[str, t.Tuple[int, exp.Expression, Scope]] = {}


def lineage(
    column: str | exp.Column,
    model: Model,
    trim_selects: bool = True,
    **kwargs: t.Any,
) -> Node:
    """Build the sqlglot lineage graph for `column` of `model`.

    The model's query is rendered and copied. A `CAST(NULL AS <type>)` column is
    appended for each managed column the query does not already select. The query
    is then qualified against the model's normalized mapping schema and scoped.
    The qualified query and scope are cached per model name, so repeated lookups on
    the same model object skip that work.

    Args:
        column: The output column to trace.
        model: The model whose rendered query is analyzed.
        trim_selects: Forwarded unchanged to `sqlglot.lineage.lineage`.
        **kwargs: Extra options forwarded to `qualify.qualify`. When any are given,
            the cache is neither read nor written, because the qualified query
            depends on them.

    Returns:
        The root lineage `Node` for `column`.
    """
    query = None
    scope = None
    # Cached entries were qualified without extra options, so they are only valid
    # for calls that pass none; calls with kwargs always qualify afresh.
    use_cache = not kwargs

    if use_cache and model.name in CACHE:
        obj_id, cached_query, cached_scope = CACHE[model.name]
        # Entries are tied to one model object; another object with the same name
        # must be re-rendered.
        # NOTE(review): id() values may be reused after garbage collection, so this
        # identity check is best-effort.
        if obj_id == id(model):
            query, scope = cached_query, cached_scope

    if query is None or scope is None:
        query = t.cast(exp.Query, model.render_query_or_raise().copy())

        if model.managed_columns:
            query = query.select(
                *(
                    exp.alias_(exp.cast(exp.Null(), to=col_type), col)
                    for col, col_type in model.managed_columns.items()
                    if col not in query.named_selects
                ),
                copy=False,
            )

        query = qualify.qualify(
            query,
            dialect=model.dialect,
            schema=normalize_mapping_schema(model.mapping_schema, dialect=model.dialect),
            **{"validate_qualify_columns": False, "infer_schema": True, **kwargs},
        )

        scope = build_scope(query)

        if scope and use_cache:
            CACHE[model.name] = (id(model), query, scope)

    return sqlglot_lineage(
        column,
        sql=query,
        scope=scope,
        trim_selects=trim_selects,
        dialect=model.dialect,
    )


def column_dependencies(context: Context, model_name: str, column: str) -> t.Dict[str, t.Set[str]]:
    """Return the upstream tables and columns that `column` of `model_name` derives from.

    The column's lineage is traced without trimming selects. For each terminal node
    of the graph (one with no downstream nodes) whose expression contains a table,
    the node's column name is recorded under that table's normalized model name.

    Returns:
        A mapping from normalized upstream table name to the set of column names
        read from it.
    """
    model = context.get_model(model_name)
    dependencies: t.DefaultDict[str, t.Set[str]] = defaultdict(set)
    root = lineage(column, model, trim_selects=False)

    # Only terminal nodes of the lineage graph are inspected for source tables.
    leaves = (node for node in root.walk() if not node.downstream)
    for leaf in leaves:
        source_table = leaf.expression.find(exp.Table)
        if not source_table:
            continue
        table_name = normalize_model_name(
            source_table, default_catalog=context.default_catalog, dialect=model.dialect
        )
        dependencies[table_name].add(exp.to_column(leaf.name).name)

    return dict(dependencies)


def column_description(context: Context, model_name: str, column: str) -> t.Optional[str]:
    """Returns a column's description, inferring if needed.

    A description declared on the model itself is returned directly. Otherwise, if
    the column is derived from exactly one column of exactly one upstream table,
    that upstream column's description is used, following the chain recursively.

    Args:
        context: The context used to look up models.
        model_name: Name of the model owning `column`.
        column: Name of the column to describe.

    Returns:
        The description, or None when none is declared and none can be inferred
        unambiguously (including when the lineage chain loops back on itself).
    """
    return _column_description(context, model_name, column, set())


def _column_description(
    context: Context,
    model_name: str,
    column: str,
    visited: t.Set[t.Tuple[str, str]],
) -> t.Optional[str]:
    """Recursive worker for `column_description` that tracks visited (model, column) pairs."""
    # If the chain leads back to a pair already seen (e.g. a model selecting from its
    # own table), recursing again would never terminate.
    key = (model_name, column)
    if key in visited:
        return None
    visited.add(key)

    model = context.get_model(model_name)
    # NOTE(review): assumes `get_model` returns None for names that are not
    # registered models; without a model there is nothing to describe.
    if model is None:
        return None

    if column in model.column_descriptions:
        return model.column_descriptions[column]

    dependencies = column_dependencies(context, model_name, column)
    if len(dependencies) != 1:
        return None

    parent, columns = first(dependencies.items())
    if len(columns) != 1:
        return None

    return _column_description(context, parent, first(columns), visited)
def
lineage( column: 'str | exp.Column', model: Union[sqlmesh.core.model.definition.SqlModel, sqlmesh.core.model.definition.SeedModel, sqlmesh.core.model.definition.PythonModel, sqlmesh.core.model.definition.ExternalModel], trim_selects: bool = True, **kwargs: Any) -> sqlglot.lineage.Node:
def lineage(
    column: str | exp.Column,
    model: Model,
    trim_selects: bool = True,
    **kwargs: t.Any,
) -> Node:
    """Build the sqlglot lineage graph for `column` of `model`.

    The model's query is rendered and copied. A `CAST(NULL AS <type>)` column is
    appended for each managed column the query does not already select. The query
    is then qualified against the model's normalized mapping schema and scoped.
    The qualified query and scope are cached per model name, so repeated lookups on
    the same model object skip that work.

    Args:
        column: The output column to trace.
        model: The model whose rendered query is analyzed.
        trim_selects: Forwarded unchanged to `sqlglot.lineage.lineage`.
        **kwargs: Extra options forwarded to `qualify.qualify`. When any are given,
            the cache is neither read nor written, because the qualified query
            depends on them.

    Returns:
        The root lineage `Node` for `column`.
    """
    query = None
    scope = None
    # Cached entries were qualified without extra options, so they are only valid
    # for calls that pass none; calls with kwargs always qualify afresh.
    use_cache = not kwargs

    if use_cache and model.name in CACHE:
        obj_id, cached_query, cached_scope = CACHE[model.name]
        # Entries are tied to one model object; another object with the same name
        # must be re-rendered.
        # NOTE(review): id() values may be reused after garbage collection, so this
        # identity check is best-effort.
        if obj_id == id(model):
            query, scope = cached_query, cached_scope

    if query is None or scope is None:
        query = t.cast(exp.Query, model.render_query_or_raise().copy())

        if model.managed_columns:
            query = query.select(
                *(
                    exp.alias_(exp.cast(exp.Null(), to=col_type), col)
                    for col, col_type in model.managed_columns.items()
                    if col not in query.named_selects
                ),
                copy=False,
            )

        query = qualify.qualify(
            query,
            dialect=model.dialect,
            schema=normalize_mapping_schema(model.mapping_schema, dialect=model.dialect),
            **{"validate_qualify_columns": False, "infer_schema": True, **kwargs},
        )

        scope = build_scope(query)

        if scope and use_cache:
            CACHE[model.name] = (id(model), query, scope)

    return sqlglot_lineage(
        column,
        sql=query,
        scope=scope,
        trim_selects=trim_selects,
        dialect=model.dialect,
    )
def
column_dependencies( context: sqlmesh.core.context.Context, model_name: str, column: str) -> Dict[str, Set[str]]:
def column_dependencies(context: Context, model_name: str, column: str) -> t.Dict[str, t.Set[str]]:
    """Return the upstream tables and columns that `column` of `model_name` derives from.

    The column's lineage is traced without trimming selects. For each terminal node
    of the graph (one with no downstream nodes) whose expression contains a table,
    the node's column name is recorded under that table's normalized model name.

    Returns:
        A mapping from normalized upstream table name to the set of column names
        read from it.
    """
    model = context.get_model(model_name)
    dependencies: t.DefaultDict[str, t.Set[str]] = defaultdict(set)
    root = lineage(column, model, trim_selects=False)

    # Only terminal nodes of the lineage graph are inspected for source tables.
    leaves = (node for node in root.walk() if not node.downstream)
    for leaf in leaves:
        source_table = leaf.expression.find(exp.Table)
        if not source_table:
            continue
        table_name = normalize_model_name(
            source_table, default_catalog=context.default_catalog, dialect=model.dialect
        )
        dependencies[table_name].add(exp.to_column(leaf.name).name)

    return dict(dependencies)
def
column_description( context: sqlmesh.core.context.Context, model_name: str, column: str) -> Union[str, NoneType]:
def column_description(context: Context, model_name: str, column: str) -> t.Optional[str]:
    """Returns a column's description, inferring if needed.

    A description declared on the model itself is returned directly. Otherwise, if
    the column is derived from exactly one column of exactly one upstream table,
    that upstream column's description is used, following the chain recursively.

    Args:
        context: The context used to look up models.
        model_name: Name of the model owning `column`.
        column: Name of the column to describe.

    Returns:
        The description, or None when none is declared and none can be inferred
        unambiguously (including when the lineage chain loops back on itself).
    """
    return _column_description(context, model_name, column, set())


def _column_description(
    context: Context,
    model_name: str,
    column: str,
    visited: t.Set[t.Tuple[str, str]],
) -> t.Optional[str]:
    """Recursive worker for `column_description` that tracks visited (model, column) pairs."""
    # If the chain leads back to a pair already seen (e.g. a model selecting from its
    # own table), recursing again would never terminate.
    key = (model_name, column)
    if key in visited:
        return None
    visited.add(key)

    model = context.get_model(model_name)
    # NOTE(review): assumes `get_model` returns None for names that are not
    # registered models; without a model there is nothing to describe.
    if model is None:
        return None

    if column in model.column_descriptions:
        return model.column_descriptions[column]

    dependencies = column_dependencies(context, model_name, column)
    if len(dependencies) != 1:
        return None

    parent, columns = first(dependencies.items())
    if len(columns) != 1:
        return None

    return _column_description(context, parent, first(columns), visited)
Returns a column's description, inferring if needed.