Edit on GitHub

Serialize SQL using the dialect of each model.

 1"""Serialize SQL using the dialect of each model."""
 2
 3import json
 4import typing as t
 5
 6import pandas as pd
 7from sqlglot import exp, parse_one
 8
 9from sqlmesh.utils.jinja import has_jinja
10from sqlmesh.utils.migration import index_text_type
11
12
def migrate(state_sync, **kwargs):  # type: ignore
    """Re-serialize the SQL stored in every snapshot using its own dialect.

    Every row of the ``_snapshots`` state table (schema-qualified when the
    state sync has a schema) is loaded, and its JSON payload is rewritten in
    place: the model's ``query``, ``pre_statements`` and ``post_statements``
    are re-rendered with the model's dialect, and each audit's ``query`` and
    ``expressions`` with that audit's dialect. The table is then emptied and
    repopulated with the rewritten rows. Nothing is written if the table has
    no rows.
    """
    adapter = state_sync.engine_adapter
    schema = state_sync.schema
    table = f"{schema}._snapshots" if schema else "_snapshots"
    columns = ("name", "identifier", "version", "snapshot", "kind_name")

    rewritten_rows = []
    for row in adapter.fetchall(exp.select(*columns).from_(table), quote_identifiers=True):
        name, identifier, version, raw_snapshot, kind_name = row
        payload = json.loads(raw_snapshot)

        model = payload["model"]
        model_dialect = model["dialect"]
        _update_expression(model, "query", model_dialect)
        for statements_key in ("pre_statements", "post_statements"):
            _update_expression_list(model, statements_key, model_dialect)

        # Each audit is re-rendered with the dialect stored on the audit itself.
        for audit in payload.get("audits", []):
            audit_dialect = audit["dialect"]
            _update_expression(audit, "query", audit_dialect)
            _update_expression_list(audit, "expressions", audit_dialect)

        rewritten_rows.append(
            dict(zip(columns, (name, identifier, version, json.dumps(payload), kind_name)))
        )

    if not rewritten_rows:
        return

    # Replace the table contents: delete every existing row, then append the rewritten ones.
    adapter.delete_from(table, "TRUE")

    index_type = index_text_type(adapter.dialect)
    adapter.insert_append(
        table,
        pd.DataFrame(rewritten_rows),
        columns_to_types={
            column: exp.DataType.build("text" if column == "snapshot" else index_type)
            for column in columns
        },
    )
61                "snapshot": exp.DataType.build("text"),
62                "kind_name": exp.DataType.build(index_type),
63            },
64        )
65
66
67# Note: previously we used to do serde using the SQLGlot dialect, so we need to parse the
68# stored queries using that dialect and then write them back using the correct dialect.
69
70
71def _update_expression(obj: t.Dict, key: str, dialect: str) -> None:
72    if key in obj and not has_jinja(obj[key]):
73        obj[key] = parse_one(obj[key]).sql(dialect=dialect)
74
75
76def _update_expression_list(obj: t.Dict, key: str, dialect: str) -> None:
77    if key in obj:
78        obj[key] = [
79            (
80                parse_one(expression).sql(dialect=dialect)
81                if not has_jinja(expression)
82                else expression
83            )
84            for expression in obj[key]
85            if expression
86        ]
def migrate(state_sync, **kwargs):  # type: ignore
    """Re-serialize the SQL stored in every snapshot using its own dialect.

    Every row of the ``_snapshots`` state table (schema-qualified when the
    state sync has a schema) is loaded, and its JSON payload is rewritten in
    place: the model's ``query``, ``pre_statements`` and ``post_statements``
    are re-rendered with the model's dialect, and each audit's ``query`` and
    ``expressions`` with that audit's dialect. The table is then emptied and
    repopulated with the rewritten rows. Nothing is written if the table has
    no rows.
    """
    adapter = state_sync.engine_adapter
    schema = state_sync.schema
    table = f"{schema}._snapshots" if schema else "_snapshots"
    columns = ("name", "identifier", "version", "snapshot", "kind_name")

    rewritten_rows = []
    for row in adapter.fetchall(exp.select(*columns).from_(table), quote_identifiers=True):
        name, identifier, version, raw_snapshot, kind_name = row
        payload = json.loads(raw_snapshot)

        model = payload["model"]
        model_dialect = model["dialect"]
        _update_expression(model, "query", model_dialect)
        for statements_key in ("pre_statements", "post_statements"):
            _update_expression_list(model, statements_key, model_dialect)

        # Each audit is re-rendered with the dialect stored on the audit itself.
        for audit in payload.get("audits", []):
            audit_dialect = audit["dialect"]
            _update_expression(audit, "query", audit_dialect)
            _update_expression_list(audit, "expressions", audit_dialect)

        rewritten_rows.append(
            dict(zip(columns, (name, identifier, version, json.dumps(payload), kind_name)))
        )

    if not rewritten_rows:
        return

    # Replace the table contents: delete every existing row, then append the rewritten ones.
    adapter.delete_from(table, "TRUE")

    index_type = index_text_type(adapter.dialect)
    adapter.insert_append(
        table,
        pd.DataFrame(rewritten_rows),
        columns_to_types={
            column: exp.DataType.build("text" if column == "snapshot" else index_type)
            for column in columns
        },
    )