Edit on GitHub

sqlmesh.core.metric.rewriter

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import exp
  6from sqlglot.dialects.dialect import DialectType
  7from sqlglot.optimizer import Scope, find_all_in_scope, optimize
  8from sqlglot.optimizer.optimize_joins import optimize_joins
  9from sqlglot.optimizer.qualify import qualify
 10
 11from sqlmesh.core import dialect as d
 12from sqlmesh.core.metric.definition import Metric, remove_namespace
 13
 14if t.TYPE_CHECKING:
 15    from sqlmesh.core.reference import ReferenceGraph
 16
 17
# Per-source bookkeeping used while expanding a query: maps a source model name to
# a pair of (aggregations to compute on that source, join targets for that source).
# A join target mapped to None has no user-supplied join; one is generated for it.
SourceAggsAndJoins = t.Dict[str, t.Tuple[t.Set[exp.AggFunc], t.Dict[str, t.Optional[exp.Join]]]]
 19
 20
 21class Rewriter:
 22    def __init__(
 23        self,
 24        graph: ReferenceGraph,
 25        metrics: t.Dict[str, Metric],
 26        dialect: DialectType = "",
 27        join_type: str = "FULL",
 28        semantic_schema: str = "__semantic",
 29        semantic_table: str = "__table",
 30    ):
 31        self.graph = graph
 32        self.metrics = metrics
 33        self.dialect = dialect
 34        self.join_type = join_type
 35        self.semantic_name = f"{semantic_schema}.{semantic_table}"
 36
 37    def rewrite(self, expression: exp.Expr) -> exp.Expr:
 38        for select in list(expression.find_all(exp.Select)):
 39            self._expand(select)
 40
 41        return expression
 42
 43    def _build_sources(self, projections: t.List[exp.Expr]) -> SourceAggsAndJoins:
 44        sources: SourceAggsAndJoins = {}
 45
 46        for projection in projections:
 47            for ref in find_all_in_scope(projection, d.MetricAgg):
 48                metric = self.metrics[ref.this.name]
 49                ref.replace(metric.formula.this)
 50
 51                for agg, (measure, dims) in metric.aggs.items():
 52                    aggs, joins = sources.setdefault(measure, (set(), dict()))
 53                    aggs.add(agg)
 54                    for dim in dims:
 55                        joins[dim] = None
 56
 57        return sources
 58
 59    def _expand(self, select: exp.Select) -> None:
 60        base = select.args["from_"].this.find(exp.Table)
 61        base_alias = base.alias_or_name
 62        base_name = exp.table_name(base)
 63
 64        sources: SourceAggsAndJoins = (
 65            {} if base_name == self.semantic_name else {base_name: (set(), {})}
 66        )
 67        sources.update(self._build_sources(select.selects))
 68
 69        group = select.args.pop("group", None)
 70        group_by = group.expressions if group else []
 71
 72        mapping = {
 73            remove_namespace(exp.table_name(source.assert_is(exp.Table))): name
 74            for name, source in Scope(select).references
 75            if name != base_alias
 76        }
 77
 78        explicit_joins = {exp.table_name(join.this): join for join in select.args.pop("joins", [])}
 79
 80        for i, (name, (aggs, joins)) in enumerate(sources.items()):
 81            source: exp.Expr = exp.to_table(name)
 82            table_name = remove_namespace(name)
 83
 84            if not isinstance(source, exp.Select):
 85                source = exp.Select().from_(
 86                    exp.alias_(source, table_name, table=True, copy=False), copy=False
 87                )
 88
 89            joins.update(explicit_joins)
 90            query = self._add_joins(source, name, joins, group_by, mapping).select(
 91                *sorted(aggs, key=str), copy=False
 92            )
 93
 94            if not query.selects:
 95                query.select("*", copy=False)
 96
 97            if i == 0:
 98                where = select.args.pop("where", None)
 99
100                if where:
101                    query.where(_replace_table(where.this, table_name, base_alias), copy=False)
102
103                select.from_(query.subquery(base_alias, copy=False), copy=False)
104            else:
105                select.join(
106                    query,
107                    on=[e.eq(_replace_table(e.copy(), table_name, base_alias)) for e in group_by],  # type: ignore
108                    join_type=self.join_type,
109                    join_alias=table_name,
110                    copy=False,
111                )
112
113            for node in find_all_in_scope(query, exp.Column, exp.TableAlias):  # type: ignore[arg-type,var-annotated]
114                if isinstance(node, exp.Column):
115                    if node.table in mapping:
116                        node.set("table", exp.to_identifier(mapping[node.table]))
117                else:
118                    if node.name in mapping:
119                        node.set("this", exp.to_identifier(mapping[node.name]))
120
121    def _add_joins(
122        self,
123        source: exp.Select,
124        name: str,
125        joins: t.Dict[str, t.Optional[exp.Join]],
126        group_by: t.List[exp.Expr],
127        mapping: t.Dict[str, str],
128    ) -> exp.Select:
129        grain = [e.copy() for e in group_by]
130        table_name = remove_namespace(name)
131        mapping = {v: k for k, v in mapping.items()}
132
133        for expr in grain:
134            for node in expr.walk():
135                if isinstance(node, exp.Column):
136                    models = self.graph.models_for_column(name, node.name)
137
138                    if name in models:
139                        node.args["table"] = exp.to_identifier(table_name)
140                    elif models:
141                        t = mapping.get(node.table)
142                        model = next(
143                            (model for model in models if remove_namespace(model) == t),
144                            models[0],
145                        )
146                        node.args["table"] = exp.to_identifier(t or remove_namespace(model))
147                        if model not in joins:
148                            joins[model] = None
149
150        for target, join in joins.items():
151            path = self.graph.find_path(name, target)
152            for i in range(len(path) - 1):
153                a_ref = path[i]
154                b_ref = path[i + 1]
155                a_model_alias = remove_namespace(a_ref.model_name)
156                b_model_alias = remove_namespace(b_ref.model_name)
157
158                a = a_ref.expression.copy()
159                a.set("table", exp.to_identifier(a_model_alias))
160                b = b_ref.expression.copy()
161                b.set("table", exp.to_identifier(b_model_alias))
162                on = a.eq(b)
163
164                if join:
165                    join.set("on", on)
166                    source.append("joins", join)
167                else:
168                    source.join(
169                        b_ref.model_name,
170                        on=on,
171                        join_type="LEFT",
172                        join_alias=b_model_alias,
173                        dialect=self.dialect,
174                        copy=False,
175                    )
176
177        return source.select(*grain, copy=False).group_by(*grain, copy=False)
178
179
def _replace_table(node: exp.Expr, table: str, base_alias: str) -> exp.Expr:
    """Re-qualify columns of ``node`` that use ``base_alias`` so they use ``table``.

    Only columns in the scope of ``node`` are visited. ``node`` is modified in
    place and returned.
    """
    qualified_with_base = (
        column for column in find_all_in_scope(node, exp.Column) if column.table == base_alias
    )
    for column in qualified_with_base:
        column.args["table"] = exp.to_identifier(table)
    return node
185
186
def rewrite(
    sql: str | exp.Expr,
    graph: ReferenceGraph,
    metrics: t.Dict[str, Metric],
    dialect: t.Optional[str] = "",
) -> exp.Expr:
    """Expand the metric references of a query and return the optimized expression.

    Args:
        sql: Query text, parsed with ``dialect``, or an already parsed expression.
        graph: Reference graph handed to the ``Rewriter``.
        metrics: Metric definitions keyed by name, handed to the ``Rewriter``.
        dialect: Dialect used for parsing, rewriting and optimizing.

    Returns:
        The expression after the qualify, metric rewrite and join optimization rules.
    """
    rewriter = Rewriter(graph=graph, metrics=metrics, dialect=dialect)

    if isinstance(sql, str):
        expression = d.parse_one(sql, dialect=dialect)
    else:
        expression = sql

    # Qualify first so columns carry table names before metrics are expanded.
    rules = (qualify, rewriter.rewrite, optimize_joins)
    return optimize(expression, dialect=dialect, quote_identifiers=False, rules=rules)
SourceAggsAndJoins = typing.Dict[str, typing.Tuple[typing.Set[sqlglot.expressions.core.AggFunc], typing.Dict[str, typing.Optional[sqlglot.expressions.query.Join]]]]
class Rewriter:
 22class Rewriter:
 23    def __init__(
 24        self,
 25        graph: ReferenceGraph,
 26        metrics: t.Dict[str, Metric],
 27        dialect: DialectType = "",
 28        join_type: str = "FULL",
 29        semantic_schema: str = "__semantic",
 30        semantic_table: str = "__table",
 31    ):
 32        self.graph = graph
 33        self.metrics = metrics
 34        self.dialect = dialect
 35        self.join_type = join_type
 36        self.semantic_name = f"{semantic_schema}.{semantic_table}"
 37
 38    def rewrite(self, expression: exp.Expr) -> exp.Expr:
 39        for select in list(expression.find_all(exp.Select)):
 40            self._expand(select)
 41
 42        return expression
 43
 44    def _build_sources(self, projections: t.List[exp.Expr]) -> SourceAggsAndJoins:
 45        sources: SourceAggsAndJoins = {}
 46
 47        for projection in projections:
 48            for ref in find_all_in_scope(projection, d.MetricAgg):
 49                metric = self.metrics[ref.this.name]
 50                ref.replace(metric.formula.this)
 51
 52                for agg, (measure, dims) in metric.aggs.items():
 53                    aggs, joins = sources.setdefault(measure, (set(), dict()))
 54                    aggs.add(agg)
 55                    for dim in dims:
 56                        joins[dim] = None
 57
 58        return sources
 59
 60    def _expand(self, select: exp.Select) -> None:
 61        base = select.args["from_"].this.find(exp.Table)
 62        base_alias = base.alias_or_name
 63        base_name = exp.table_name(base)
 64
 65        sources: SourceAggsAndJoins = (
 66            {} if base_name == self.semantic_name else {base_name: (set(), {})}
 67        )
 68        sources.update(self._build_sources(select.selects))
 69
 70        group = select.args.pop("group", None)
 71        group_by = group.expressions if group else []
 72
 73        mapping = {
 74            remove_namespace(exp.table_name(source.assert_is(exp.Table))): name
 75            for name, source in Scope(select).references
 76            if name != base_alias
 77        }
 78
 79        explicit_joins = {exp.table_name(join.this): join for join in select.args.pop("joins", [])}
 80
 81        for i, (name, (aggs, joins)) in enumerate(sources.items()):
 82            source: exp.Expr = exp.to_table(name)
 83            table_name = remove_namespace(name)
 84
 85            if not isinstance(source, exp.Select):
 86                source = exp.Select().from_(
 87                    exp.alias_(source, table_name, table=True, copy=False), copy=False
 88                )
 89
 90            joins.update(explicit_joins)
 91            query = self._add_joins(source, name, joins, group_by, mapping).select(
 92                *sorted(aggs, key=str), copy=False
 93            )
 94
 95            if not query.selects:
 96                query.select("*", copy=False)
 97
 98            if i == 0:
 99                where = select.args.pop("where", None)
100
101                if where:
102                    query.where(_replace_table(where.this, table_name, base_alias), copy=False)
103
104                select.from_(query.subquery(base_alias, copy=False), copy=False)
105            else:
106                select.join(
107                    query,
108                    on=[e.eq(_replace_table(e.copy(), table_name, base_alias)) for e in group_by],  # type: ignore
109                    join_type=self.join_type,
110                    join_alias=table_name,
111                    copy=False,
112                )
113
114            for node in find_all_in_scope(query, exp.Column, exp.TableAlias):  # type: ignore[arg-type,var-annotated]
115                if isinstance(node, exp.Column):
116                    if node.table in mapping:
117                        node.set("table", exp.to_identifier(mapping[node.table]))
118                else:
119                    if node.name in mapping:
120                        node.set("this", exp.to_identifier(mapping[node.name]))
121
122    def _add_joins(
123        self,
124        source: exp.Select,
125        name: str,
126        joins: t.Dict[str, t.Optional[exp.Join]],
127        group_by: t.List[exp.Expr],
128        mapping: t.Dict[str, str],
129    ) -> exp.Select:
130        grain = [e.copy() for e in group_by]
131        table_name = remove_namespace(name)
132        mapping = {v: k for k, v in mapping.items()}
133
134        for expr in grain:
135            for node in expr.walk():
136                if isinstance(node, exp.Column):
137                    models = self.graph.models_for_column(name, node.name)
138
139                    if name in models:
140                        node.args["table"] = exp.to_identifier(table_name)
141                    elif models:
142                        t = mapping.get(node.table)
143                        model = next(
144                            (model for model in models if remove_namespace(model) == t),
145                            models[0],
146                        )
147                        node.args["table"] = exp.to_identifier(t or remove_namespace(model))
148                        if model not in joins:
149                            joins[model] = None
150
151        for target, join in joins.items():
152            path = self.graph.find_path(name, target)
153            for i in range(len(path) - 1):
154                a_ref = path[i]
155                b_ref = path[i + 1]
156                a_model_alias = remove_namespace(a_ref.model_name)
157                b_model_alias = remove_namespace(b_ref.model_name)
158
159                a = a_ref.expression.copy()
160                a.set("table", exp.to_identifier(a_model_alias))
161                b = b_ref.expression.copy()
162                b.set("table", exp.to_identifier(b_model_alias))
163                on = a.eq(b)
164
165                if join:
166                    join.set("on", on)
167                    source.append("joins", join)
168                else:
169                    source.join(
170                        b_ref.model_name,
171                        on=on,
172                        join_type="LEFT",
173                        join_alias=b_model_alias,
174                        dialect=self.dialect,
175                        copy=False,
176                    )
177
178        return source.select(*grain, copy=False).group_by(*grain, copy=False)
Rewriter( graph: sqlmesh.core.reference.ReferenceGraph, metrics: Dict[str, sqlmesh.core.metric.definition.Metric], dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = '', join_type: str = 'FULL', semantic_schema: str = '__semantic', semantic_table: str = '__table')
23    def __init__(
24        self,
25        graph: ReferenceGraph,
26        metrics: t.Dict[str, Metric],
27        dialect: DialectType = "",
28        join_type: str = "FULL",
29        semantic_schema: str = "__semantic",
30        semantic_table: str = "__table",
31    ):
32        self.graph = graph
33        self.metrics = metrics
34        self.dialect = dialect
35        self.join_type = join_type
36        self.semantic_name = f"{semantic_schema}.{semantic_table}"
graph
metrics
dialect
join_type
semantic_name
def rewrite( self, expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
38    def rewrite(self, expression: exp.Expr) -> exp.Expr:
39        for select in list(expression.find_all(exp.Select)):
40            self._expand(select)
41
42        return expression
def rewrite( sql: str | sqlglot.expressions.core.Expr, graph: sqlmesh.core.reference.ReferenceGraph, metrics: Dict[str, sqlmesh.core.metric.definition.Metric], dialect: Optional[str] = '') -> sqlglot.expressions.core.Expr:
def rewrite(
    sql: str | exp.Expr,
    graph: ReferenceGraph,
    metrics: t.Dict[str, Metric],
    dialect: t.Optional[str] = "",
) -> exp.Expr:
    """Expand the metric references of a query and return the optimized expression.

    Args:
        sql: Query text, parsed with ``dialect``, or an already parsed expression.
        graph: Reference graph handed to the ``Rewriter``.
        metrics: Metric definitions keyed by name, handed to the ``Rewriter``.
        dialect: Dialect used for parsing, rewriting and optimizing.

    Returns:
        The expression after the qualify, metric rewrite and join optimization rules.
    """
    rewriter = Rewriter(graph=graph, metrics=metrics, dialect=dialect)

    if isinstance(sql, str):
        expression = d.parse_one(sql, dialect=dialect)
    else:
        expression = sql

    # Qualify first so columns carry table names before metrics are expanded.
    rules = (qualify, rewriter.rewrite, optimize_joins)
    return optimize(expression, dialect=dialect, quote_identifiers=False, rules=rules)