sqlmesh.core.metric.rewriter
"""Rewrites queries over the semantic layer by expanding metric references into SQL."""

from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.optimizer import Scope, find_all_in_scope, optimize
from sqlglot.optimizer.optimize_joins import optimize_joins
from sqlglot.optimizer.qualify import qualify

from sqlmesh.core import dialect as d
from sqlmesh.core.metric.definition import Metric, remove_namespace

if t.TYPE_CHECKING:
    from sqlmesh.core.reference import ReferenceGraph


# Maps a source name to (aggregations selected from it, models to join to it). A join value of
# None means a LEFT join is generated; otherwise the explicit Join node from the query is reused.
SourceAggsAndJoins = t.Dict[str, t.Tuple[t.Set[exp.AggFunc], t.Dict[str, t.Optional[exp.Join]]]]


class Rewriter:
    """Expands metric references in SELECT statements into joined aggregate subqueries.

    Args:
        graph: Reference graph used to find which models expose a column and the path between
            two models.
        metrics: Metric definitions keyed by metric name.
        dialect: Dialect forwarded to the generated joins.
        join_type: Join type used to combine the subqueries of the individual sources.
        semantic_schema: Schema part of the placeholder semantic table name.
        semantic_table: Table part of the placeholder semantic table name.
    """

    def __init__(
        self,
        graph: ReferenceGraph,
        metrics: t.Dict[str, Metric],
        dialect: DialectType = "",
        join_type: str = "FULL",
        semantic_schema: str = "__semantic",
        semantic_table: str = "__table",
    ):
        self.graph = graph
        self.metrics = metrics
        self.dialect = dialect
        self.join_type = join_type
        self.semantic_name = f"{semantic_schema}.{semantic_table}"

    def rewrite(self, expression: exp.Expr) -> exp.Expr:
        """Expand every SELECT found in `expression` in place and return `expression`."""
        # The SELECTs are collected up front because _expand mutates the tree being searched.
        for select in list(expression.find_all(exp.Select)):
            self._expand(select)

        return expression

    def _build_sources(self, projections: t.List[exp.Expr]) -> SourceAggsAndJoins:
        """Replace metric references with their formulas and collect what each source needs.

        Every `MetricAgg` found in the projections is replaced by the formula of the metric it
        names. For each aggregation of that metric, the aggregation is recorded under its
        measure and each of its dimensions is registered as a join that still has to be built.

        Raises:
            KeyError: If a referenced metric is not present in `self.metrics`.
        """
        sources: SourceAggsAndJoins = {}

        for projection in projections:
            for ref in find_all_in_scope(projection, d.MetricAgg):
                metric = self.metrics[ref.this.name]
                ref.replace(metric.formula.this)

                for agg, (measure, dims) in metric.aggs.items():
                    aggs, joins = sources.setdefault(measure, (set(), dict()))
                    aggs.add(agg)
                    for dim in dims:
                        joins[dim] = None

        return sources

    def _expand(self, select: exp.Select) -> None:
        """Rewrite a single SELECT in place so it reads from one subquery per source.

        The first source becomes the new FROM (and receives the original WHERE); every further
        source is joined on the GROUP BY expressions using `self.join_type`. A SELECT without a
        FROM clause, or whose FROM contains no table, has no base to expand and is left
        untouched.
        """
        from_ = select.args.get("from_")
        if from_ is None:
            return

        base = from_.this.find(exp.Table)
        if base is None:
            return

        base_alias = base.alias_or_name
        base_name = exp.table_name(base)

        # A real base table is itself a source; the semantic placeholder table is not.
        sources: SourceAggsAndJoins = (
            {} if base_name == self.semantic_name else {base_name: (set(), {})}
        )
        sources.update(self._build_sources(select.selects))

        group = select.args.pop("group", None)
        group_by = group.expressions if group else []

        # Namespace-stripped table name -> alias used in the query, for every table but the base.
        mapping = {
            remove_namespace(exp.table_name(source.assert_is(exp.Table))): name
            for name, source in Scope(select).references
            if name != base_alias
        }

        explicit_joins = {exp.table_name(join.this): join for join in select.args.pop("joins", [])}

        for i, (name, (aggs, joins)) in enumerate(sources.items()):
            source: exp.Expr = exp.to_table(name)
            table_name = remove_namespace(name)

            if not isinstance(source, exp.Select):
                source = exp.Select().from_(
                    exp.alias_(source, table_name, table=True, copy=False), copy=False
                )

            joins.update(explicit_joins)
            # Aggregations are sorted by their SQL text so the projection order is deterministic.
            query = self._add_joins(source, name, joins, group_by, mapping).select(
                *sorted(aggs, key=str), copy=False
            )

            if not query.selects:
                query.select("*", copy=False)

            if i == 0:
                where = select.args.pop("where", None)

                if where:
                    query.where(_replace_table(where.this, table_name, base_alias), copy=False)

                select.from_(query.subquery(base_alias, copy=False), copy=False)
            else:
                select.join(
                    query,
                    on=[e.eq(_replace_table(e.copy(), table_name, base_alias)) for e in group_by],  # type: ignore
                    join_type=self.join_type,
                    join_alias=table_name,
                    copy=False,
                )

            # Make the generated subquery use the aliases the user wrote in the original query.
            for node in find_all_in_scope(query, exp.Column, exp.TableAlias):  # type: ignore[arg-type,var-annotated]
                if isinstance(node, exp.Column):
                    if node.table in mapping:
                        node.set("table", exp.to_identifier(mapping[node.table]))
                else:
                    if node.name in mapping:
                        node.set("this", exp.to_identifier(mapping[node.name]))

    def _add_joins(
        self,
        source: exp.Select,
        name: str,
        joins: t.Dict[str, t.Optional[exp.Join]],
        group_by: t.List[exp.Expr],
        mapping: t.Dict[str, str],
    ) -> exp.Select:
        """Join `source` to the models it needs and group it by the query's grain.

        Args:
            source: SELECT over the source table; joins and projections are added to it.
            name: Name of the source model.
            joins: Models to join, keyed by model name. Models needed by the GROUP BY columns
                are added to this dict.
            group_by: GROUP BY expressions of the query being expanded; they are copied.
            mapping: Namespace-stripped table name -> alias used in the query.

        Returns:
            `source` selecting and grouping by the requalified GROUP BY expressions.
        """
        grain = [e.copy() for e in group_by]
        table_name = remove_namespace(name)
        # Inverted view of `mapping`: alias used in the query -> table name it stands for.
        alias_to_table = {v: k for k, v in mapping.items()}

        for expr in grain:
            for node in expr.walk():
                if not isinstance(node, exp.Column):
                    continue

                models = self.graph.models_for_column(name, node.name)

                if name in models:
                    node.args["table"] = exp.to_identifier(table_name)
                elif models:
                    table = alias_to_table.get(node.table)
                    # Prefer the model the column was qualified with, else the first candidate.
                    model = next(
                        (candidate for candidate in models if remove_namespace(candidate) == table),
                        models[0],
                    )
                    node.args["table"] = exp.to_identifier(table or remove_namespace(model))
                    if model not in joins:
                        joins[model] = None

        for target, join in joins.items():
            path = self.graph.find_path(name, target)
            # Each consecutive pair of references on the path is one join hop.
            for a_ref, b_ref in zip(path, path[1:]):
                a_model_alias = remove_namespace(a_ref.model_name)
                b_model_alias = remove_namespace(b_ref.model_name)

                a = a_ref.expression.copy()
                a.set("table", exp.to_identifier(a_model_alias))
                b = b_ref.expression.copy()
                b.set("table", exp.to_identifier(b_model_alias))
                on = a.eq(b)

                if join:
                    join.set("on", on)
                    source.append("joins", join)
                else:
                    source.join(
                        b_ref.model_name,
                        on=on,
                        join_type="LEFT",
                        join_alias=b_model_alias,
                        dialect=self.dialect,
                        copy=False,
                    )

        return source.select(*grain, copy=False).group_by(*grain, copy=False)


def _replace_table(node: exp.Expr, table: str, base_alias: str) -> exp.Expr:
    """Requalify, in place, every column of `node` that uses `base_alias` with `table`."""
    for column in find_all_in_scope(node, exp.Column):
        if column.table == base_alias:
            column.args["table"] = exp.to_identifier(table)
    return node


def rewrite(
    sql: str | exp.Expr,
    graph: ReferenceGraph,
    metrics: t.Dict[str, Metric],
    dialect: t.Optional[str] = "",
) -> exp.Expr:
    """Expand the metric references of a query.

    Args:
        sql: Query text, parsed with `dialect`, or an already parsed expression.
        graph: Reference graph handed to the `Rewriter`.
        metrics: Metric definitions keyed by metric name.
        dialect: Dialect used for parsing and optimizing.

    Returns:
        The expression produced by running the qualify, metric rewrite and optimize_joins
        rules, in that order.
    """
    rewriter = Rewriter(graph=graph, metrics=metrics, dialect=dialect)

    return optimize(
        d.parse_one(sql, dialect=dialect) if isinstance(sql, str) else sql,
        dialect=dialect,
        quote_identifiers=False,
        rules=(
            qualify,
            rewriter.rewrite,
            optimize_joins,
        ),
    )
SourceAggsAndJoins =
typing.Dict[str, typing.Tuple[typing.Set[sqlglot.expressions.core.AggFunc], typing.Dict[str, typing.Optional[sqlglot.expressions.query.Join]]]]
class
Rewriter:
class Rewriter:
    """Expands metric references in SELECT statements into joined aggregate subqueries.

    Args:
        graph: Reference graph used to find which models expose a column and the path between
            two models.
        metrics: Metric definitions keyed by metric name.
        dialect: Dialect forwarded to the generated joins.
        join_type: Join type used to combine the subqueries of the individual sources.
        semantic_schema: Schema part of the placeholder semantic table name.
        semantic_table: Table part of the placeholder semantic table name.
    """

    def __init__(
        self,
        graph: ReferenceGraph,
        metrics: t.Dict[str, Metric],
        dialect: DialectType = "",
        join_type: str = "FULL",
        semantic_schema: str = "__semantic",
        semantic_table: str = "__table",
    ):
        self.graph = graph
        self.metrics = metrics
        self.dialect = dialect
        self.join_type = join_type
        self.semantic_name = f"{semantic_schema}.{semantic_table}"

    def rewrite(self, expression: exp.Expr) -> exp.Expr:
        """Expand every SELECT found in `expression` in place and return `expression`."""
        # The SELECTs are collected up front because _expand mutates the tree being searched.
        for select in list(expression.find_all(exp.Select)):
            self._expand(select)

        return expression

    def _build_sources(self, projections: t.List[exp.Expr]) -> SourceAggsAndJoins:
        """Replace metric references with their formulas and collect what each source needs.

        Every `MetricAgg` found in the projections is replaced by the formula of the metric it
        names. For each aggregation of that metric, the aggregation is recorded under its
        measure and each of its dimensions is registered as a join that still has to be built.

        Raises:
            KeyError: If a referenced metric is not present in `self.metrics`.
        """
        sources: SourceAggsAndJoins = {}

        for projection in projections:
            for ref in find_all_in_scope(projection, d.MetricAgg):
                metric = self.metrics[ref.this.name]
                ref.replace(metric.formula.this)

                for agg, (measure, dims) in metric.aggs.items():
                    aggs, joins = sources.setdefault(measure, (set(), dict()))
                    aggs.add(agg)
                    for dim in dims:
                        joins[dim] = None

        return sources

    def _expand(self, select: exp.Select) -> None:
        """Rewrite a single SELECT in place so it reads from one subquery per source.

        The first source becomes the new FROM (and receives the original WHERE); every further
        source is joined on the GROUP BY expressions using `self.join_type`. A SELECT without a
        FROM clause, or whose FROM contains no table, has no base to expand and is left
        untouched.
        """
        from_ = select.args.get("from_")
        if from_ is None:
            return

        base = from_.this.find(exp.Table)
        if base is None:
            return

        base_alias = base.alias_or_name
        base_name = exp.table_name(base)

        # A real base table is itself a source; the semantic placeholder table is not.
        sources: SourceAggsAndJoins = (
            {} if base_name == self.semantic_name else {base_name: (set(), {})}
        )
        sources.update(self._build_sources(select.selects))

        group = select.args.pop("group", None)
        group_by = group.expressions if group else []

        # Namespace-stripped table name -> alias used in the query, for every table but the base.
        mapping = {
            remove_namespace(exp.table_name(source.assert_is(exp.Table))): name
            for name, source in Scope(select).references
            if name != base_alias
        }

        explicit_joins = {exp.table_name(join.this): join for join in select.args.pop("joins", [])}

        for i, (name, (aggs, joins)) in enumerate(sources.items()):
            source: exp.Expr = exp.to_table(name)
            table_name = remove_namespace(name)

            if not isinstance(source, exp.Select):
                source = exp.Select().from_(
                    exp.alias_(source, table_name, table=True, copy=False), copy=False
                )

            joins.update(explicit_joins)
            # Aggregations are sorted by their SQL text so the projection order is deterministic.
            query = self._add_joins(source, name, joins, group_by, mapping).select(
                *sorted(aggs, key=str), copy=False
            )

            if not query.selects:
                query.select("*", copy=False)

            if i == 0:
                where = select.args.pop("where", None)

                if where:
                    query.where(_replace_table(where.this, table_name, base_alias), copy=False)

                select.from_(query.subquery(base_alias, copy=False), copy=False)
            else:
                select.join(
                    query,
                    on=[e.eq(_replace_table(e.copy(), table_name, base_alias)) for e in group_by],  # type: ignore
                    join_type=self.join_type,
                    join_alias=table_name,
                    copy=False,
                )

            # Make the generated subquery use the aliases the user wrote in the original query.
            for node in find_all_in_scope(query, exp.Column, exp.TableAlias):  # type: ignore[arg-type,var-annotated]
                if isinstance(node, exp.Column):
                    if node.table in mapping:
                        node.set("table", exp.to_identifier(mapping[node.table]))
                else:
                    if node.name in mapping:
                        node.set("this", exp.to_identifier(mapping[node.name]))

    def _add_joins(
        self,
        source: exp.Select,
        name: str,
        joins: t.Dict[str, t.Optional[exp.Join]],
        group_by: t.List[exp.Expr],
        mapping: t.Dict[str, str],
    ) -> exp.Select:
        """Join `source` to the models it needs and group it by the query's grain.

        Args:
            source: SELECT over the source table; joins and projections are added to it.
            name: Name of the source model.
            joins: Models to join, keyed by model name. Models needed by the GROUP BY columns
                are added to this dict.
            group_by: GROUP BY expressions of the query being expanded; they are copied.
            mapping: Namespace-stripped table name -> alias used in the query.

        Returns:
            `source` selecting and grouping by the requalified GROUP BY expressions.
        """
        grain = [e.copy() for e in group_by]
        table_name = remove_namespace(name)
        # Inverted view of `mapping`: alias used in the query -> table name it stands for.
        alias_to_table = {v: k for k, v in mapping.items()}

        for expr in grain:
            for node in expr.walk():
                if not isinstance(node, exp.Column):
                    continue

                models = self.graph.models_for_column(name, node.name)

                if name in models:
                    node.args["table"] = exp.to_identifier(table_name)
                elif models:
                    table = alias_to_table.get(node.table)
                    # Prefer the model the column was qualified with, else the first candidate.
                    model = next(
                        (candidate for candidate in models if remove_namespace(candidate) == table),
                        models[0],
                    )
                    node.args["table"] = exp.to_identifier(table or remove_namespace(model))
                    if model not in joins:
                        joins[model] = None

        for target, join in joins.items():
            path = self.graph.find_path(name, target)
            # Each consecutive pair of references on the path is one join hop.
            for a_ref, b_ref in zip(path, path[1:]):
                a_model_alias = remove_namespace(a_ref.model_name)
                b_model_alias = remove_namespace(b_ref.model_name)

                a = a_ref.expression.copy()
                a.set("table", exp.to_identifier(a_model_alias))
                b = b_ref.expression.copy()
                b.set("table", exp.to_identifier(b_model_alias))
                on = a.eq(b)

                if join:
                    join.set("on", on)
                    source.append("joins", join)
                else:
                    source.join(
                        b_ref.model_name,
                        on=on,
                        join_type="LEFT",
                        join_alias=b_model_alias,
                        dialect=self.dialect,
                        copy=False,
                    )

        return source.select(*grain, copy=False).group_by(*grain, copy=False)
Rewriter( graph: sqlmesh.core.reference.ReferenceGraph, metrics: Dict[str, sqlmesh.core.metric.definition.Metric], dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = '', join_type: str = 'FULL', semantic_schema: str = '__semantic', semantic_table: str = '__table')
def __init__(
    self,
    graph: ReferenceGraph,
    metrics: t.Dict[str, Metric],
    dialect: DialectType = "",
    join_type: str = "FULL",
    semantic_schema: str = "__semantic",
    semantic_table: str = "__table",
):
    """Store the rewriter configuration.

    Args:
        graph: Reference graph kept on the instance as `graph`.
        metrics: Metric definitions keyed by name, kept as `metrics`.
        dialect: Dialect kept as `dialect`.
        join_type: Join type kept as `join_type`.
        semantic_schema: Schema part of the semantic table name.
        semantic_table: Table part of the semantic table name.
    """
    # The fully qualified placeholder name is "<schema>.<table>".
    self.semantic_name = ".".join((semantic_schema, semantic_table))
    self.join_type = join_type
    self.dialect = dialect
    self.metrics = metrics
    self.graph = graph
def
rewrite( sql: str | sqlglot.expressions.core.Expr, graph: sqlmesh.core.reference.ReferenceGraph, metrics: Dict[str, sqlmesh.core.metric.definition.Metric], dialect: Optional[str] = '') -> sqlglot.expressions.core.Expr:
def rewrite(
    sql: str | exp.Expr,
    graph: ReferenceGraph,
    metrics: t.Dict[str, Metric],
    dialect: t.Optional[str] = "",
) -> exp.Expr:
    """Expand the metric references of a query.

    Args:
        sql: Query text, parsed with `dialect`, or an already parsed expression.
        graph: Reference graph handed to the `Rewriter`.
        metrics: Metric definitions keyed by metric name.
        dialect: Dialect used for parsing and optimizing.

    Returns:
        The expression produced by running the qualify, metric rewrite and optimize_joins
        rules, in that order.
    """
    metric_rewriter = Rewriter(graph=graph, metrics=metrics, dialect=dialect)

    if isinstance(sql, str):
        expression = d.parse_one(sql, dialect=dialect)
    else:
        expression = sql

    # Columns are qualified first so the rewriter can rely on table-qualified references.
    pipeline = (qualify, metric_rewriter.rewrite, optimize_joins)

    return optimize(
        expression,
        dialect=dialect,
        quote_identifiers=False,
        rules=pipeline,
    )