Fix expressions that contain jinja.
"""Fix expressions that contain jinja."""

import json
import typing as t

import pandas as pd
from sqlglot import exp

from sqlmesh.utils.jinja import has_jinja
from sqlmesh.utils.migration import index_text_type


def migrate(state_sync, **kwargs):  # type: ignore
    """Wrap Jinja-containing model and audit SQL stored in the snapshots table.

    Every row of the ``_snapshots`` table (optionally schema-qualified) is read.
    Its serialized snapshot JSON is rewritten so that each of the following is
    wrapped with the markers produced by ``_wrap_query`` / ``_wrap_statement``
    when it contains Jinja:

    - the model query
    - the model pre/post statements
    - audit queries
    - audit expressions

    If any rows were read, the table's contents are then replaced with the
    rewritten rows.

    Args:
        state_sync: Object exposing ``engine_adapter`` and ``schema``, used to
            locate, read and rewrite the snapshots table.
        **kwargs: Not used by this migration.
    """
    engine_adapter = state_sync.engine_adapter
    schema = state_sync.schema
    snapshots_table = "_snapshots"
    if schema:
        snapshots_table = f"{schema}.{snapshots_table}"

    new_snapshots = []

    for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall(
        exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table),
        quote_identifiers=True,
    ):
        parsed_snapshot = json.loads(snapshot)
        # `or []` tolerates an explicit JSON null as well as a missing key.
        audits = parsed_snapshot.get("audits") or []
        model = parsed_snapshot["model"]

        # A model may have no "query" key, so check for it before inspecting it.
        if "query" in model and has_jinja(model["query"]):
            model["query"] = _wrap_query(model["query"])

        _wrap_statements(model, "pre_statements")
        _wrap_statements(model, "post_statements")

        for audit in audits:
            # Same presence guard as for the model, for consistency/robustness.
            if "query" in audit and has_jinja(audit["query"]):
                audit["query"] = _wrap_query(audit["query"])
            _wrap_statements(audit, "expressions")

        new_snapshots.append(
            {
                "name": name,
                "identifier": identifier,
                "version": version,
                "snapshot": json.dumps(parsed_snapshot),
                "kind_name": kind_name,
            }
        )

    if new_snapshots:
        # The table is rewritten wholesale: clear every row, then re-insert all
        # (possibly modified) rows.
        engine_adapter.delete_from(snapshots_table, "TRUE")

        index_type = index_text_type(engine_adapter.dialect)

        engine_adapter.insert_append(
            snapshots_table,
            pd.DataFrame(new_snapshots),
            columns_to_types={
                "name": exp.DataType.build(index_type),
                "identifier": exp.DataType.build(index_type),
                "version": exp.DataType.build(index_type),
                "snapshot": exp.DataType.build("text"),
                "kind_name": exp.DataType.build(index_type),
            },
        )


def _wrap_statements(obj: t.Dict, key: str) -> None:
    """Wrap, in place, each Jinja-containing statement in the list ``obj[key]``.

    Statements without Jinja are kept unchanged. ``obj`` is not modified in
    either of these cases:

    - ``key`` is missing;
    - ``key`` maps to an empty list or to ``None``.
    """
    # `or []` tolerates an explicit JSON null as well as a missing key.
    statements = obj.get(key) or []
    updated_statements = [
        _wrap_statement(statement) if has_jinja(statement) else statement
        for statement in statements
    ]

    if updated_statements:
        obj[key] = updated_statements


def _wrap_query(sql: str) -> str:
    """Surround ``sql`` with the JINJA_QUERY_BEGIN / JINJA_END markers."""
    return f"JINJA_QUERY_BEGIN;\n{sql}\nJINJA_END;"


def _wrap_statement(sql: str) -> str:
    """Surround ``sql`` with the JINJA_STATEMENT_BEGIN / JINJA_END markers."""
    return f"JINJA_STATEMENT_BEGIN;\n{sql}\nJINJA_END;"
def migrate(state_sync, **kwargs):  # type: ignore
    """Wrap Jinja-containing model and audit SQL stored in the snapshots table.

    Every row of the ``_snapshots`` table (optionally schema-qualified) is read.
    Its serialized snapshot JSON is rewritten so that each of the following is
    wrapped with the markers produced by ``_wrap_query`` / ``_wrap_statement``
    when it contains Jinja:

    - the model query
    - the model pre/post statements
    - audit queries
    - audit expressions

    If any rows were read, the table's contents are then replaced with the
    rewritten rows.

    Args:
        state_sync: Object exposing ``engine_adapter`` and ``schema``, used to
            locate, read and rewrite the snapshots table.
        **kwargs: Not used by this migration.
    """
    engine_adapter = state_sync.engine_adapter
    schema = state_sync.schema
    snapshots_table = "_snapshots"
    if schema:
        snapshots_table = f"{schema}.{snapshots_table}"

    new_snapshots = []

    for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall(
        exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table),
        quote_identifiers=True,
    ):
        parsed_snapshot = json.loads(snapshot)
        # `or []` tolerates an explicit JSON null as well as a missing key.
        audits = parsed_snapshot.get("audits") or []
        model = parsed_snapshot["model"]

        # A model may have no "query" key, so check for it before inspecting it.
        if "query" in model and has_jinja(model["query"]):
            model["query"] = _wrap_query(model["query"])

        _wrap_statements(model, "pre_statements")
        _wrap_statements(model, "post_statements")

        for audit in audits:
            # Same presence guard as for the model, for consistency/robustness.
            if "query" in audit and has_jinja(audit["query"]):
                audit["query"] = _wrap_query(audit["query"])
            _wrap_statements(audit, "expressions")

        new_snapshots.append(
            {
                "name": name,
                "identifier": identifier,
                "version": version,
                "snapshot": json.dumps(parsed_snapshot),
                "kind_name": kind_name,
            }
        )

    if new_snapshots:
        # The table is rewritten wholesale: clear every row, then re-insert all
        # (possibly modified) rows.
        engine_adapter.delete_from(snapshots_table, "TRUE")

        index_type = index_text_type(engine_adapter.dialect)

        engine_adapter.insert_append(
            snapshots_table,
            pd.DataFrame(new_snapshots),
            columns_to_types={
                "name": exp.DataType.build(index_type),
                "identifier": exp.DataType.build(index_type),
                "version": exp.DataType.build(index_type),
                "snapshot": exp.DataType.build("text"),
                "kind_name": exp.DataType.build(index_type),
            },
        )