Edit on GitHub

This script warns users if a model's Python environment contains both metadata and non-metadata references. It also warns them if a macro is referenced in a used audit's query, in the argument list of the `audits` or `signals` properties, or in an `on_virtual_update` statement.

Context:

The metadata status for macros and signals is now transitive, i.e. every dependency of a metadata macro or signal is also metadata, unless it is referenced by a non-metadata object.

This means that global references of metadata objects may now be excluded from the data hash calculation because of their new metadata status, which would lead to a diff.

Additionally, we now implicitly treat macro references in the aforementioned statements as "metadata-only", even if the user has not marked them as such. This may also lead to a diff.

  1"""
  2This script's goal is to warn users if there is both a metadata and non-metadata reference in
  3the python environment of a model. Additionally, it warns them if there's a macro referenced
  4in a used audit's query, in the argument list of the audits and signals properties, or in an
  5on_virtual_update statement.
  6
  7Context:
  8
  9The metadata status for macros and signals is now transitive, i.e. every dependency of a
 10metadata macro or signal is also metadata, unless it is referenced by a non-metadata object.
 11
 12This means that global references of metadata objects may now be excluded from the data hash
 13calculation because of their new metadata status, which would lead to a diff.
 14
 15Additionally, we now implicitly treat macro refs in the aforementioned statements as "metadata-only",
 16even though they may not be marked as such by a user. This may also lead to a diff.
 17"""
 18
 19import json
 20
 21from sqlglot import exp
 22
 23import sqlmesh.core.dialect as d
 24from sqlmesh.core.console import get_console
 25
 26
 27def migrate_schemas(engine_adapter, schema, **kwargs):  # type: ignore
 28    pass
 29
 30
 31def migrate_rows(engine_adapter, schema, **kwargs):  # type: ignore
 32    snapshots_table = "_snapshots"
 33    if schema:
 34        snapshots_table = f"{schema}.{snapshots_table}"
 35
 36    warning = (
 37        "SQLMesh detected that it may not be able to fully migrate the state database. This should not impact "
 38        "the migration process, but may result in unexpected changes being reported by the next `sqlmesh plan` "
 39        "command. Please run `sqlmesh diff prod` after the migration has completed, before making any new "
 40        "changes. If any unexpected changes are reported, consider running a forward-only plan to apply these "
 41        "changes and avoid unnecessary backfills: sqlmesh plan prod --forward-only. "
 42        "See https://sqlmesh.readthedocs.io/en/stable/concepts/plans/#forward-only-plans for more details.\n"
 43    )
 44
 45    for (snapshot,) in engine_adapter.fetchall(
 46        exp.select("snapshot").from_(snapshots_table), quote_identifiers=True
 47    ):
 48        parsed_snapshot = json.loads(snapshot)
 49        node = parsed_snapshot["node"]
 50
 51        # Standalone audits don't have a data hash, so they're unaffected
 52        if node.get("source_type") == "audit":
 53            continue
 54
 55        python_env = node.get("python_env") or {}
 56
 57        has_metadata = False
 58        has_non_metadata = False
 59
 60        for k, v in python_env.items():
 61            if v.get("is_metadata"):
 62                has_metadata = True
 63            else:
 64                has_non_metadata = True
 65
 66            if has_metadata and has_non_metadata:
 67                get_console().log_warning(warning)
 68                return
 69
 70        dialect = node.get("dialect")
 71        metadata_hash_statements = []
 72
 73        # We use try-except here as a conservative measure to avoid any unexpected exceptions
 74        try:
 75            if on_virtual_update := node.get("on_virtual_update"):
 76                metadata_hash_statements.extend(parse_expression(on_virtual_update, dialect))
 77
 78            for _, audit_args in func_call_validator(node.get("audits") or []):
 79                metadata_hash_statements.extend(audit_args.values())
 80
 81            for signal_name, signal_args in func_call_validator(
 82                node.get("signals") or [], is_signal=True
 83            ):
 84                metadata_hash_statements.extend(signal_args.values())
 85
 86            if audit_definitions := node.get("audit_definitions"):
 87                audit_queries = [
 88                    parse_expression(audit["query"], audit["dialect"])
 89                    for audit in audit_definitions.values()
 90                ]
 91                metadata_hash_statements.extend(audit_queries)
 92
 93            for macro_name in extract_used_macros(metadata_hash_statements):
 94                serialized_macro = python_env.get(macro_name)
 95                if isinstance(serialized_macro, dict) and not serialized_macro.get("is_metadata"):
 96                    get_console().log_warning(warning)
 97                    return
 98        except Exception:
 99            pass
100
101
def extract_used_macros(expressions):
    """Return the lowercased names of macros called anywhere in ``expressions``.

    Nested lists/tuples (e.g. a query parsed into several statements) are flattened. Entries
    that are not parsed expressions (ints, strings, None, ...) and Jinja blocks are skipped.
    Before, a single odd value raised and made the caller abandon the whole scan.

    Only nodes whose class is exactly ``d.MacroFunc`` are collected. Subclasses are excluded by
    the exact-class check (presumably built-in macro forms; confirm against sqlmesh.core.dialect).
    """
    used_macros = set()
    pending = list(expressions)
    while pending:
        expression = pending.pop()
        if isinstance(expression, (list, tuple)):
            pending.extend(expression)
            continue
        # Duck-typed guard: anything without ``find_all`` is not a sqlglot expression.
        if not hasattr(expression, "find_all") or isinstance(expression, d.Jinja):
            continue
        for macro_func in expression.find_all(d.MacroFunc):
            if macro_func.__class__ is d.MacroFunc:
                used_macros.add(macro_func.this.name.lower())
    return used_macros
113
114
def func_call_validator(v, is_signal=False):
    """Normalize serialized audit/signal calls into ``(lowercased_name, parsed_args)`` pairs.

    Each entry of ``v`` is either a dict or a ``(name, args)`` pair. A dict holds the arguments;
    for audits it also carries the call name under ``"name"``, while signals get the name ``""``.
    String argument values are parsed with ``d.parse_one``; other values are kept as-is.
    The input entries are not modified.

    Raises:
        TypeError: if ``v`` is not a list, or an entry is neither a dict nor a tuple/list.
        KeyError: if an audit dict entry has no ``"name"`` key.
    """
    if not isinstance(v, list):
        raise TypeError(f"Expected a list of function calls, got {type(v).__name__}")

    calls = []
    for entry in v:
        if isinstance(entry, dict):
            name = "" if is_signal else entry["name"]
            # Copy without "name" instead of popping it, so the caller's serialized node
            # is left intact.
            args = entry if is_signal else {k: val for k, val in entry.items() if k != "name"}
        elif isinstance(entry, (tuple, list)):
            name, args = entry
        else:
            raise TypeError(
                f"Expected a dict or a (name, args) pair, got {type(entry).__name__}"
            )

        parsed_args = {
            key: d.parse_one(value) if isinstance(value, str) else value
            for key, value in args.items()
        }
        calls.append((name.lower(), parsed_args))

    return calls
134
135
def parse_expression(v, dialect):
    """Parse a serialized SQL string, or a list of them, with ``d.parse_one``.

    Args:
        v: A SQL string, a list of SQL strings, or None.
        dialect: Dialect forwarded to ``d.parse_one``.

    Returns:
        None for None input, a list of parsed expressions for list input, otherwise a single
        parsed expression.

    Raises:
        TypeError: if ``v`` is not None, a list, or a string. An explicit raise is used because
            an ``assert`` would be stripped under ``python -O``.
    """
    if v is None:
        return None

    if isinstance(v, list):
        return [d.parse_one(e, dialect=dialect) for e in v]

    if not isinstance(v, str):
        raise TypeError(f"Expected a SQL string or a list of SQL strings, got {type(v).__name__}")

    return d.parse_one(v, dialect=dialect)
def migrate_schemas(engine_adapter, schema, **kwargs):  # type: ignore
    """No-op: this migration does not alter any table schemas.

    Args:
        engine_adapter: Unused.
        schema: Unused.
        **kwargs: Unused.
    """
    return None
def migrate_rows(engine_adapter, schema, **kwargs):  # type: ignore
    """Warn (at most once) if a stored snapshot may hash differently after this migration.

    No rows are written; every row of the ``_snapshots`` table is only read. The warning is
    logged, and the scan stops, as soon as a non-audit snapshot either:

    * has a python environment mixing metadata and non-metadata entries, or
    * references, from a metadata-only context (``on_virtual_update``, audit/signal arguments,
      or a used audit's query), a macro that is not itself marked as metadata.

    Args:
        engine_adapter: Adapter used to read the snapshots table.
        schema: Optional schema that qualifies the ``_snapshots`` table name.
        **kwargs: Unused.
    """
    snapshots_table = "_snapshots"
    if schema:
        snapshots_table = f"{schema}.{snapshots_table}"

    warning = (
        "SQLMesh detected that it may not be able to fully migrate the state database. This should not impact "
        "the migration process, but may result in unexpected changes being reported by the next `sqlmesh plan` "
        "command. Please run `sqlmesh diff prod` after the migration has completed, before making any new "
        "changes. If any unexpected changes are reported, consider running a forward-only plan to apply these "
        "changes and avoid unnecessary backfills: sqlmesh plan prod --forward-only. "
        "See https://sqlmesh.readthedocs.io/en/stable/concepts/plans/#forward-only-plans for more details.\n"
    )

    for (snapshot,) in engine_adapter.fetchall(
        exp.select("snapshot").from_(snapshots_table), quote_identifiers=True
    ):
        node = json.loads(snapshot)["node"]

        # Standalone audits don't have a data hash, so they're unaffected
        if node.get("source_type") == "audit":
            continue

        python_env = node.get("python_env") or {}

        if _has_mixed_metadata(python_env):
            get_console().log_warning(warning)
            return

        # We use try-except here as a conservative measure to avoid any unexpected exceptions:
        # this migration only emits a warning, so a snapshot we fail to analyze must not abort it.
        try:
            statements = _collect_metadata_hash_statements(node)
            for macro_name in extract_used_macros(statements):
                serialized_macro = python_env.get(macro_name)
                if isinstance(serialized_macro, dict) and not serialized_macro.get("is_metadata"):
                    get_console().log_warning(warning)
                    return
        except Exception:
            pass


def _has_mixed_metadata(python_env):
    """Return True if ``python_env`` holds both metadata and non-metadata entries."""
    has_metadata = False
    has_non_metadata = False
    for serialized in python_env.values():
        if serialized.get("is_metadata"):
            has_metadata = True
        else:
            has_non_metadata = True
        if has_metadata and has_non_metadata:
            return True
    return False


def _as_list(parsed):
    """Normalize a ``parse_expression`` result (None, one expression, or a list) to a list."""
    if parsed is None:
        return []
    # A single expression must be wrapped, not passed to ``list.extend``: extending with an
    # expression would iterate it instead of adding it.
    return list(parsed) if isinstance(parsed, list) else [parsed]


def _collect_metadata_hash_statements(node):
    """Gather the parsed statements of ``node`` whose macro references count as metadata-only."""
    dialect = node.get("dialect")
    statements = []

    if on_virtual_update := node.get("on_virtual_update"):
        statements.extend(_as_list(parse_expression(on_virtual_update, dialect)))

    for _, audit_args in func_call_validator(node.get("audits") or []):
        statements.extend(audit_args.values())

    for _, signal_args in func_call_validator(node.get("signals") or [], is_signal=True):
        statements.extend(signal_args.values())

    for audit in (node.get("audit_definitions") or {}).values():
        # A query may parse into a list of statements; flatten it instead of nesting the list.
        statements.extend(_as_list(parse_expression(audit["query"], audit["dialect"])))

    return statements
def extract_used_macros(expressions):
    """Return the lowercased names of macros called anywhere in ``expressions``.

    Nested lists/tuples (e.g. a query parsed into several statements) are flattened. Entries
    that are not parsed expressions (ints, strings, None, ...) and Jinja blocks are skipped.
    Before, a single odd value raised and made the caller abandon the whole scan.

    Only nodes whose class is exactly ``d.MacroFunc`` are collected. Subclasses are excluded by
    the exact-class check (presumably built-in macro forms; confirm against sqlmesh.core.dialect).
    """
    used_macros = set()
    pending = list(expressions)
    while pending:
        expression = pending.pop()
        if isinstance(expression, (list, tuple)):
            pending.extend(expression)
            continue
        # Duck-typed guard: anything without ``find_all`` is not a sqlglot expression.
        if not hasattr(expression, "find_all") or isinstance(expression, d.Jinja):
            continue
        for macro_func in expression.find_all(d.MacroFunc):
            if macro_func.__class__ is d.MacroFunc:
                used_macros.add(macro_func.this.name.lower())
    return used_macros
def func_call_validator(v, is_signal=False):
    """Normalize serialized audit/signal calls into ``(lowercased_name, parsed_args)`` pairs.

    Each entry of ``v`` is either a dict or a ``(name, args)`` pair. A dict holds the arguments;
    for audits it also carries the call name under ``"name"``, while signals get the name ``""``.
    String argument values are parsed with ``d.parse_one``; other values are kept as-is.
    The input entries are not modified.

    Raises:
        TypeError: if ``v`` is not a list, or an entry is neither a dict nor a tuple/list.
        KeyError: if an audit dict entry has no ``"name"`` key.
    """
    if not isinstance(v, list):
        raise TypeError(f"Expected a list of function calls, got {type(v).__name__}")

    calls = []
    for entry in v:
        if isinstance(entry, dict):
            name = "" if is_signal else entry["name"]
            # Copy without "name" instead of popping it, so the caller's serialized node
            # is left intact.
            args = entry if is_signal else {k: val for k, val in entry.items() if k != "name"}
        elif isinstance(entry, (tuple, list)):
            name, args = entry
        else:
            raise TypeError(
                f"Expected a dict or a (name, args) pair, got {type(entry).__name__}"
            )

        parsed_args = {
            key: d.parse_one(value) if isinstance(value, str) else value
            for key, value in args.items()
        }
        calls.append((name.lower(), parsed_args))

    return calls
def parse_expression(v, dialect):
    """Parse a serialized SQL string, or a list of them, with ``d.parse_one``.

    Args:
        v: A SQL string, a list of SQL strings, or None.
        dialect: Dialect forwarded to ``d.parse_one``.

    Returns:
        None for None input, a list of parsed expressions for list input, otherwise a single
        parsed expression.

    Raises:
        TypeError: if ``v`` is not None, a list, or a string. An explicit raise is used because
            an ``assert`` would be stripped under ``python -O``.
    """
    if v is None:
        return None

    if isinstance(v, list):
        return [d.parse_one(e, dialect=dialect) for e in v]

    if not isinstance(v, str):
        raise TypeError(f"Expected a SQL string or a list of SQL strings, got {type(v).__name__}")

    return d.parse_one(v, dialect=dialect)