Edit on GitHub

Change serialization of signals to allow for function calls.

  1"""Change serialization of signals to allow for function calls."""
  2
  3import json
  4
  5from sqlglot import exp, parse_one
  6
  7from sqlmesh.utils.migration import index_text_type, blob_text_type
  8
  9
 10def migrate_schemas(engine_adapter, schema, **kwargs):  # type: ignore
 11    pass
 12
 13
 14def migrate_rows(engine_adapter, schema, **kwargs):  # type: ignore
 15    import pandas as pd
 16
 17    snapshots_table = "_snapshots"
 18    index_type = index_text_type(engine_adapter.dialect)
 19    if schema:
 20        snapshots_table = f"{schema}.{snapshots_table}"
 21
 22    new_snapshots = []
 23
 24    signal_change = False
 25    for (
 26        name,
 27        identifier,
 28        version,
 29        snapshot,
 30        kind_name,
 31        updated_ts,
 32        unpaused_ts,
 33        ttl_ms,
 34        unrestorable,
 35    ) in engine_adapter.fetchall(
 36        exp.select(
 37            "name",
 38            "identifier",
 39            "version",
 40            "snapshot",
 41            "kind_name",
 42            "updated_ts",
 43            "unpaused_ts",
 44            "ttl_ms",
 45            "unrestorable",
 46        ).from_(snapshots_table),
 47        quote_identifiers=True,
 48    ):
 49        parsed_snapshot = json.loads(snapshot)
 50        node = parsed_snapshot["node"]
 51        signals = node.get("signals")
 52
 53        if signals:
 54            signal_change = True
 55            node["signals"] = []
 56
 57            for signal in signals:
 58                node["signals"].append(
 59                    (
 60                        "",
 61                        {
 62                            eq.left.name: eq.right.sql()
 63                            for eq in parse_one(signal, into=exp.Tuple).expressions
 64                        },
 65                    )
 66                )
 67
 68        new_snapshots.append(
 69            {
 70                "name": name,
 71                "identifier": identifier,
 72                "version": version,
 73                "snapshot": json.dumps(parsed_snapshot),
 74                "kind_name": kind_name,
 75                "updated_ts": updated_ts,
 76                "unpaused_ts": unpaused_ts,
 77                "ttl_ms": ttl_ms,
 78                "unrestorable": unrestorable,
 79            }
 80        )
 81
 82    if signal_change and new_snapshots:
 83        engine_adapter.delete_from(snapshots_table, "TRUE")
 84        blob_type = blob_text_type(engine_adapter.dialect)
 85
 86        engine_adapter.insert_append(
 87            snapshots_table,
 88            pd.DataFrame(new_snapshots),
 89            target_columns_to_types={
 90                "name": exp.DataType.build(index_type),
 91                "identifier": exp.DataType.build(index_type),
 92                "version": exp.DataType.build(index_type),
 93                "snapshot": exp.DataType.build(blob_type),
 94                "kind_name": exp.DataType.build(index_type),
 95                "updated_ts": exp.DataType.build("bigint"),
 96                "unpaused_ts": exp.DataType.build("bigint"),
 97                "ttl_ms": exp.DataType.build("bigint"),
 98                "unrestorable": exp.DataType.build("boolean"),
 99            },
100        )
def migrate_schemas(engine_adapter, schema, **kwargs):  # type: ignore
    """Do nothing: this migration performs no schema changes.

    Args:
        engine_adapter: Unused.
        schema: Unused.
        **kwargs: Unused.
    """
    return None
def migrate_rows(engine_adapter, schema, **kwargs):  # type: ignore
    """Rewrite stored snapshot signals into ``(name, kwargs)`` pairs.

    Every row of the ``_snapshots`` table is read and its ``snapshot`` JSON is
    decoded. When a node carries a non-empty ``signals`` list, each entry is
    parsed as a SQL tuple and replaced by a pair made of an empty name and a
    mapping from each expression's left-hand name to the SQL text of its
    right-hand side. If at least one snapshot was converted, the table is
    emptied and all rows (converted or not) are written back.

    Args:
        engine_adapter: Adapter used to query, delete from and append to the
            snapshots table; its ``dialect`` selects the text column types.
        schema: Optional schema name used to qualify ``_snapshots``.
        **kwargs: Ignored.
    """
    import pandas as pd

    table = f"{schema}._snapshots" if schema else "_snapshots"
    index_type = index_text_type(engine_adapter.dialect)

    columns = (
        "name",
        "identifier",
        "version",
        "snapshot",
        "kind_name",
        "updated_ts",
        "unpaused_ts",
        "ttl_ms",
        "unrestorable",
    )
    query = exp.select(*columns).from_(table)

    rewritten_rows = []
    any_converted = False

    for row in engine_adapter.fetchall(query, quote_identifiers=True):
        # The query selects exactly ``columns``, so values pair up by position.
        record = dict(zip(columns, row))

        payload = json.loads(record["snapshot"])
        node = payload["node"]
        old_signals = node.get("signals")

        if old_signals:
            any_converted = True
            # The new form is (signal name, keyword arguments); legacy
            # signals are given an empty name.
            node["signals"] = [
                (
                    "",
                    {
                        eq.left.name: eq.right.sql()
                        for eq in parse_one(raw_signal, into=exp.Tuple).expressions
                    },
                )
                for raw_signal in old_signals
            ]

        # Every snapshot is re-serialized, including unchanged ones, because
        # the whole table is rewritten below.
        record["snapshot"] = json.dumps(payload)
        rewritten_rows.append(record)

    if not (any_converted and rewritten_rows):
        return

    # Clear the table (condition ``TRUE``) before re-inserting the rows.
    engine_adapter.delete_from(table, "TRUE")
    blob_type = blob_text_type(engine_adapter.dialect)

    column_type_names = {
        "name": index_type,
        "identifier": index_type,
        "version": index_type,
        "snapshot": blob_type,
        "kind_name": index_type,
        "updated_ts": "bigint",
        "unpaused_ts": "bigint",
        "ttl_ms": "bigint",
        "unrestorable": "boolean",
    }

    engine_adapter.insert_append(
        table,
        pd.DataFrame(rewritten_rows),
        target_columns_to_types={
            column: exp.DataType.build(type_name)
            for column, type_name in column_type_names.items()
        },
    )