Serialize SQL using the dialect of each model.
"""Serialize SQL using the dialect of each model."""

import json
import typing as t

import pandas as pd
from sqlglot import exp, parse_one

from sqlmesh.utils.jinja import has_jinja
from sqlmesh.utils.migration import index_text_type


def migrate(state_sync, **kwargs):  # type: ignore
    """Re-serialize the SQL stored in every snapshot using that snapshot's own dialect.

    Reads all rows of the ``_snapshots`` table (qualified with ``state_sync.schema``
    when one is set), rewrites the model's ``query``, ``pre_statements`` and
    ``post_statements`` plus each audit's ``query`` and ``expressions``, and then
    replaces the table contents with the rewritten rows.

    Args:
        state_sync: Object exposing ``engine_adapter`` and ``schema``; nothing else
            on it is used here.
        **kwargs: Not used by this migration.
    """
    engine_adapter = state_sync.engine_adapter
    schema = state_sync.schema
    snapshots_table = "_snapshots"
    if schema:
        snapshots_table = f"{schema}.{snapshots_table}"

    new_snapshots = []

    for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall(
        exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table),
        quote_identifiers=True,
    ):
        parsed_snapshot = json.loads(snapshot)
        model = parsed_snapshot["model"]
        model_dialect = model["dialect"]

        _update_expression(model, "query", model_dialect)
        _update_expression_list(model, "pre_statements", model_dialect)
        _update_expression_list(model, "post_statements", model_dialect)

        # Each audit is rewritten with its own stored dialect, not the model's.
        for audit in parsed_snapshot.get("audits", []):
            audit_dialect = audit["dialect"]
            _update_expression(audit, "query", audit_dialect)
            _update_expression_list(audit, "expressions", audit_dialect)

        new_snapshots.append(
            {
                "name": name,
                "identifier": identifier,
                "version": version,
                "snapshot": json.dumps(parsed_snapshot),
                "kind_name": kind_name,
            }
        )

    if new_snapshots:
        # Replace the table wholesale: delete every row, then re-insert the rewritten ones.
        engine_adapter.delete_from(snapshots_table, "TRUE")

        index_type = index_text_type(engine_adapter.dialect)

        engine_adapter.insert_append(
            snapshots_table,
            pd.DataFrame(new_snapshots),
            columns_to_types={
                "name": exp.DataType.build(index_type),
                "identifier": exp.DataType.build(index_type),
                "version": exp.DataType.build(index_type),
                "snapshot": exp.DataType.build("text"),
                "kind_name": exp.DataType.build(index_type),
            },
        )


# Note: previously we used to do serde using the SQLGlot dialect, so we need to parse the
# stored queries using that dialect and then write them back using the correct dialect.


def _update_expression(obj: t.Dict, key: str, dialect: str) -> None:
    """Re-render the SQL string at ``obj[key]`` in ``dialect``, in place.

    The stored SQL is parsed with SQLGlot's default dialect. Values that are
    missing, ``None`` or empty, and values containing Jinja, are left unchanged.

    Args:
        obj: Serialized model or audit dict; mutated in place.
        key: Key holding a single SQL string.
        dialect: Dialect to render the SQL in.
    """
    expression = obj.get(key)
    # Skip None/empty values instead of handing them to has_jinja/parse_one, mirroring
    # how _update_expression_list ignores falsy entries.
    if not expression or has_jinja(expression):
        return
    obj[key] = parse_one(expression).sql(dialect=dialect)


def _update_expression_list(obj: t.Dict, key: str, dialect: str) -> None:
    """Re-render each SQL string in the list at ``obj[key]`` in ``dialect``, in place.

    Falsy entries are dropped from the list; entries containing Jinja are kept
    verbatim. A missing or ``None`` list is left unchanged.

    Args:
        obj: Serialized model or audit dict; mutated in place.
        key: Key holding a list of SQL strings.
        dialect: Dialect to render the SQL in.
    """
    expressions = obj.get(key)
    if expressions is None:
        return
    obj[key] = [
        expression if has_jinja(expression) else parse_one(expression).sql(dialect=dialect)
        for expression in expressions
        if expression
    ]
def migrate(state_sync, **kwargs):  # type: ignore
    """Rewrite the SQL held in each stored snapshot so it uses that snapshot's dialect.

    Every row of ``_snapshots`` (schema-qualified when ``state_sync.schema`` is
    set) is loaded. The model's query and pre/post statements, and each audit's
    query and expressions, are re-rendered. If any rows were read, the table is
    emptied and refilled with the rewritten rows.

    Args:
        state_sync: Supplies ``engine_adapter`` and ``schema``.
        **kwargs: Ignored here.
    """
    adapter = state_sync.engine_adapter
    schema = state_sync.schema
    table = f"{schema}._snapshots" if schema else "_snapshots"

    # Column order shared by the SELECT, the rebuilt row dicts and the column types.
    columns = ("name", "identifier", "version", "snapshot", "kind_name")

    rewritten = []
    rows = adapter.fetchall(exp.select(*columns).from_(table), quote_identifiers=True)
    for name, identifier, version, payload, kind_name in rows:
        snapshot = json.loads(payload)

        model = snapshot["model"]
        model_dialect = model["dialect"]
        _update_expression(model, "query", model_dialect)
        for statements_key in ("pre_statements", "post_statements"):
            _update_expression_list(model, statements_key, model_dialect)

        # Audits are re-rendered with their own stored dialect.
        for audit in snapshot.get("audits", []):
            audit_dialect = audit["dialect"]
            _update_expression(audit, "query", audit_dialect)
            _update_expression_list(audit, "expressions", audit_dialect)

        values = (name, identifier, version, json.dumps(snapshot), kind_name)
        rewritten.append(dict(zip(columns, values)))

    if not rewritten:
        return

    adapter.delete_from(table, "TRUE")

    index_type = index_text_type(adapter.dialect)
    column_types = {
        column: exp.DataType.build("text" if column == "snapshot" else index_type)
        for column in columns
    }
    adapter.insert_append(table, pd.DataFrame(rewritten), columns_to_types=column_types)