Edit on GitHub

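Fix expressions that contain Jinja: wrap Jinja-templated model queries, model pre/post statements, and audit queries and expressions in explicit `JINJA_*_BEGIN;` / `JINJA_END;` markers.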

 1"""Fix expressions that contain jinja."""
 2
 3import json
 4import typing as t
 5
 6import pandas as pd
 7from sqlglot import exp
 8
 9from sqlmesh.utils.jinja import has_jinja
10from sqlmesh.utils.migration import index_text_type
11
12
def migrate(state_sync, **kwargs):  # type: ignore
    """Wrap Jinja-bearing SQL in every stored snapshot with explicit markers.

    All rows of the ``_snapshots`` table are fetched. The table name is
    prefixed with ``state_sync.schema`` when one is set. Each JSON payload is
    rewritten by ``_wrap_snapshot_jinja``, and the table contents are then
    replaced: every existing row is deleted and the rewritten rows are
    appended. When no rows are fetched, nothing is deleted or inserted.

    Args:
        state_sync: Object providing the ``engine_adapter`` and ``schema``
            attributes used here.
        **kwargs: Unused by this migration.
    """
    engine_adapter = state_sync.engine_adapter
    schema = state_sync.schema
    table = f"{schema}._snapshots" if schema else "_snapshots"

    column_names = ("name", "identifier", "version", "snapshot", "kind_name")
    rows = engine_adapter.fetchall(
        exp.select(*column_names).from_(table),
        quote_identifiers=True,
    )

    # Only the JSON payload is rewritten; the other columns are copied through as-is.
    migrated_rows = [
        {
            "name": name,
            "identifier": identifier,
            "version": version,
            "snapshot": json.dumps(_wrap_snapshot_jinja(json.loads(raw_snapshot))),
            "kind_name": kind_name,
        }
        for name, identifier, version, raw_snapshot, kind_name in rows
    ]

    if not migrated_rows:
        return

    # NOTE(review): the delete and the append below are not wrapped in a
    # transaction in this function; if the append fails, the table may be left
    # empty unless the caller runs this migration transactionally. Confirm.
    engine_adapter.delete_from(table, "TRUE")

    index_type = index_text_type(engine_adapter.dialect)
    frame = pd.DataFrame(migrated_rows)
    # Build a fresh DataType object per column. Only the payload column uses
    # plain "text"; the others use the dialect's index-friendly text type.
    columns_to_types = {
        column: exp.DataType.build("text" if column == "snapshot" else index_type)
        for column in column_names
    }
    engine_adapter.insert_append(table, frame, columns_to_types=columns_to_types)


def _wrap_snapshot_jinja(parsed_snapshot: t.Dict) -> t.Dict:
    """Wrap Jinja in a parsed snapshot's model and audits in place; return the same dict.

    The model is processed first: its ``query`` (only when the key is present),
    then its ``pre_statements`` and ``post_statements``. Then, for each audit,
    its ``query`` and its ``expressions`` are processed.
    """
    audits = parsed_snapshot.get("audits", [])
    model = parsed_snapshot["model"]

    if "query" in model and has_jinja(model["query"]):
        model["query"] = _wrap_query(model["query"])
    for statements_key in ("pre_statements", "post_statements"):
        _wrap_statements(model, statements_key)

    for audit in audits:
        if has_jinja(audit["query"]):
            audit["query"] = _wrap_query(audit["query"])
        _wrap_statements(audit, "expressions")

    return parsed_snapshot
67
68
def _wrap_statements(obj: t.Dict, key: str) -> None:
    """Wrap, in place, every statement stored under ``obj[key]`` that contains Jinja.

    Statements for which ``has_jinja`` is false are kept unchanged. Jinja
    statements are replaced by ``_wrap_statement(statement)``, and the original
    order is preserved. The key is only (re)assigned when the resulting list is
    non-empty, so a missing key is never added and an existing empty list is
    left untouched.

    Args:
        obj: Serialized model or audit dict; mutated in place.
        key: Field holding the statements (e.g. ``"pre_statements"`` or
            ``"expressions"``). Presumably a list of SQL strings; verify
            against the snapshot schema.
    """
    wrapped = [
        _wrap_statement(statement) if has_jinja(statement) else statement
        for statement in obj.get(key, [])
    ]
    if wrapped:
        obj[key] = wrapped
78
79
80def _wrap_query(sql: str) -> str:
81    return f"JINJA_QUERY_BEGIN;\n{sql}\nJINJA_END;"
82
83
84def _wrap_statement(sql: str) -> str:
85    return f"JINJA_STATEMENT_BEGIN;\n{sql}\nJINJA_END;"
def migrate(state_sync, **kwargs):  # type: ignore
    """Migrate stored snapshots so Jinja-templated SQL is explicitly delimited.

    Reads every row of ``_snapshots`` (schema-qualified when
    ``state_sync.schema`` is truthy) and wraps Jinja-bearing SQL in each
    snapshot's JSON payload:

    - the model's ``query``, when that key is present;
    - the model's ``pre_statements`` and ``post_statements``;
    - each audit's ``query`` and ``expressions``.

    The table is then rewritten wholesale: all rows are deleted and the
    migrated rows are appended. If no rows were read, the table is not touched.

    Args:
        state_sync: Object providing the ``engine_adapter`` and ``schema``
            attributes used here.
        **kwargs: Unused by this migration.
    """
    adapter = state_sync.engine_adapter
    schema = state_sync.schema
    qualified_table = f"{schema}._snapshots" if schema else "_snapshots"

    select_all = exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(
        qualified_table
    )

    output_rows = []
    for name, identifier, version, raw, kind_name in adapter.fetchall(
        select_all, quote_identifiers=True
    ):
        payload = json.loads(raw)
        audits = payload.get("audits", [])
        model = payload["model"]

        # The model's query is optional, so guard on key presence before checking for Jinja.
        if "query" in model and has_jinja(model["query"]):
            model["query"] = _wrap_query(model["query"])
        for statements_key in ("pre_statements", "post_statements"):
            _wrap_statements(model, statements_key)

        for audit in audits:
            audit_sql = audit["query"]
            if has_jinja(audit_sql):
                audit["query"] = _wrap_query(audit_sql)
            _wrap_statements(audit, "expressions")

        output_rows.append(
            dict(
                name=name,
                identifier=identifier,
                version=version,
                snapshot=json.dumps(payload),
                kind_name=kind_name,
            )
        )

    if not output_rows:
        return

    # NOTE(review): delete-then-append is not transactional within this
    # function; if the append fails, the table may be left empty. Confirm
    # that the caller provides a transaction.
    adapter.delete_from(qualified_table, "TRUE")

    key_type = index_text_type(adapter.dialect)
    adapter.insert_append(
        qualified_table,
        pd.DataFrame(output_rows),
        columns_to_types={
            "name": exp.DataType.build(key_type),
            "identifier": exp.DataType.build(key_type),
            "version": exp.DataType.build(key_type),
            "snapshot": exp.DataType.build("text"),
            "kind_name": exp.DataType.build(key_type),
        },
    )