This script's goal is to warn users if there is both a metadata and non-metadata reference in the python environment of a model. Additionally, it warns them if there's a macro referenced in a used audit's query, in the argument list of the audits and signals properties, or in an on_virtual_update statement.
Context:
The metadata status for macros and signals is now transitive, i.e. every dependency of a metadata macro or signal is also metadata, unless it is referenced by a non-metadata object.
This means that global references of metadata objects may now be excluded from the data hash calculation because of their new metadata status, which would lead to a diff.
Additionally, we now implicitly treat macro refs in the aforementioned statements as "metadata-only", even though they may not be marked as such by a user. This may also lead to a diff.
"""
This script's goal is to warn users if there is both a metadata and non-metadata reference in
the python environment of a model. Additionally, it warns them if there's a macro referenced
in a used audit's query, in the argument list of the audits and signals properties, or in an
on_virtual_update statement.

Context:

The metadata status for macros and signals is now transitive, i.e. every dependency of a
metadata macro or signal is also metadata, unless it is referenced by a non-metadata object.

This means that global references of metadata objects may now be excluded from the data hash
calculation because of their new metadata status, which would lead to a diff.

Additionally, we now implicitly treat macro refs in the aforementioned statements as "metadata-only",
even though they may not be marked as such by a user. This may also lead to a diff.
"""

import json

from sqlglot import exp

import sqlmesh.core.dialect as d
from sqlmesh.core.console import get_console


def migrate_schemas(engine_adapter, schema, **kwargs):  # type: ignore
    """No-op: this migration makes no schema changes."""
    pass


def migrate_rows(engine_adapter, schema, **kwargs):  # type: ignore
    """Log a single warning if any stored snapshot may show an unexpected diff.

    Reads the `snapshot` column of the `_snapshots` table (qualified with
    `schema` when one is given) and, for every node that is not a standalone
    audit, warns when either:

    * its python environment mixes metadata and non-metadata entries, or
    * a macro used in an on_virtual_update statement, an audit/signal
      argument, or an audit definition's query maps to a python environment
      entry that is not flagged `is_metadata`.

    The function returns right after the first warning. Nothing is written
    back to the state database.
    """
    snapshots_table = "_snapshots"
    if schema:
        snapshots_table = f"{schema}.{snapshots_table}"

    warning = (
        "SQLMesh detected that it may not be able to fully migrate the state database. This should not impact "
        "the migration process, but may result in unexpected changes being reported by the next `sqlmesh plan` "
        "command. Please run `sqlmesh diff prod` after the migration has completed, before making any new "
        "changes. If any unexpected changes are reported, consider running a forward-only plan to apply these "
        "changes and avoid unnecessary backfills: sqlmesh plan prod --forward-only. "
        "See https://sqlmesh.readthedocs.io/en/stable/concepts/plans/#forward-only-plans for more details.\n"
    )

    for (snapshot,) in engine_adapter.fetchall(
        exp.select("snapshot").from_(snapshots_table), quote_identifiers=True
    ):
        node = json.loads(snapshot)["node"]

        # Standalone audits don't have a data hash, so they're unaffected
        if node.get("source_type") == "audit":
            continue

        python_env = node.get("python_env") or {}

        # Two distinct flags means the environment mixes metadata and non-metadata entries
        metadata_flags = {bool(value.get("is_metadata")) for value in python_env.values()}
        if len(metadata_flags) == 2:
            get_console().log_warning(warning)
            return

        dialect = node.get("dialect")
        metadata_hash_statements = []

        # We use try-except here as a conservative measure to avoid any unexpected exceptions
        try:
            if on_virtual_update := node.get("on_virtual_update"):
                # parse_expression may return one expression or a list of them, so append the
                # result as a unit; extract_used_macros flattens nested lists.
                metadata_hash_statements.append(parse_expression(on_virtual_update, dialect))

            for _, audit_args in func_call_validator(node.get("audits") or []):
                metadata_hash_statements.extend(audit_args.values())

            for _, signal_args in func_call_validator(node.get("signals") or [], is_signal=True):
                metadata_hash_statements.extend(signal_args.values())

            if audit_definitions := node.get("audit_definitions"):
                metadata_hash_statements.extend(
                    parse_expression(audit["query"], audit["dialect"])
                    for audit in audit_definitions.values()
                )

            for macro_name in extract_used_macros(metadata_hash_statements):
                serialized_macro = python_env.get(macro_name)
                if isinstance(serialized_macro, dict) and not serialized_macro.get("is_metadata"):
                    get_console().log_warning(warning)
                    return
        except Exception:
            pass


def extract_used_macros(expressions):
    """Return the lower-cased names of the macro functions used in `expressions`.

    Args:
        expressions: Iterable of parsed expressions. Nested lists/tuples are
            flattened, and entries without a `find_all` method (for example
            argument values that were never parsed, or None) are skipped
            instead of raising.

    Returns:
        A set of macro names.
    """
    used_macros = set()
    pending = list(expressions)

    while pending:
        expression = pending.pop()

        if isinstance(expression, (list, tuple)):
            pending.extend(expression)
            continue

        # Plain values carry no syntax tree to search
        if not hasattr(expression, "find_all"):
            continue

        if isinstance(expression, d.Jinja):
            continue

        for macro_func in expression.find_all(d.MacroFunc):
            # Exact class match: subclasses of MacroFunc are deliberately not counted
            if macro_func.__class__ is d.MacroFunc:
                used_macros.add(macro_func.this.name.lower())

    return used_macros


def func_call_validator(v, is_signal=False):
    """Normalize serialized audit/signal calls into `(name, parsed_args)` pairs.

    Args:
        v: List whose entries are either a dict of arguments (carrying a
            "name" key unless `is_signal` is true) or a `(name, args)` pair.
        is_signal: When true, dict entries are treated as having no name and
            the empty string is used instead.

    Returns:
        A list of `(lower-cased name, args)` tuples in which every string
        argument value has been parsed with `d.parse_one`; other values are
        passed through unchanged. The input entries are not modified.
    """
    assert isinstance(v, list)

    audits = []
    for entry in v:
        if isinstance(entry, dict):
            # Work on a copy so popping "name" does not alter the caller's dict
            args = dict(entry)
            name = "" if is_signal else args.pop("name")
        else:
            assert isinstance(entry, (tuple, list))
            name, args = entry

        parsed_audit = {
            key: d.parse_one(value) if isinstance(value, str) else value
            for key, value in args.items()
        }
        audits.append((name.lower(), parsed_audit))

    return audits


def parse_expression(v, dialect):
    """Parse a SQL string, or each string of a list, using `dialect`.

    Returns None for None, a list of parsed expressions for a list, and a
    single parsed expression for a string.
    """
    if v is None:
        return None

    if isinstance(v, list):
        return [d.parse_one(e, dialect=dialect) for e in v]

    assert isinstance(v, str)
    return d.parse_one(v, dialect=dialect)
def migrate_rows(engine_adapter, schema, **kwargs):  # type: ignore
    """Log a single warning if any stored snapshot may show an unexpected diff.

    Reads the `snapshot` column of the `_snapshots` table (qualified with
    `schema` when one is given) and, for every node that is not a standalone
    audit, warns when either:

    * its python environment mixes metadata and non-metadata entries, or
    * a macro used in an on_virtual_update statement, an audit/signal
      argument, or an audit definition's query maps to a python environment
      entry that is not flagged `is_metadata`.

    The function returns right after the first warning. Nothing is written
    back to the state database.
    """
    snapshots_table = "_snapshots"
    if schema:
        snapshots_table = f"{schema}.{snapshots_table}"

    warning = (
        "SQLMesh detected that it may not be able to fully migrate the state database. This should not impact "
        "the migration process, but may result in unexpected changes being reported by the next `sqlmesh plan` "
        "command. Please run `sqlmesh diff prod` after the migration has completed, before making any new "
        "changes. If any unexpected changes are reported, consider running a forward-only plan to apply these "
        "changes and avoid unnecessary backfills: sqlmesh plan prod --forward-only. "
        "See https://sqlmesh.readthedocs.io/en/stable/concepts/plans/#forward-only-plans for more details.\n"
    )

    def _expressions_in(parsed):
        """Yield every sqlglot expression in `parsed`, descending into lists/tuples.

        Anything else (None, unparsed argument values) is dropped, so a single
        odd value cannot abort the macro check for the whole snapshot.
        """
        pending = [parsed]
        while pending:
            item = pending.pop()
            if isinstance(item, (list, tuple)):
                pending.extend(item)
            elif isinstance(item, exp.Expression):
                yield item

    for (snapshot,) in engine_adapter.fetchall(
        exp.select("snapshot").from_(snapshots_table), quote_identifiers=True
    ):
        node = json.loads(snapshot)["node"]

        # Standalone audits don't have a data hash, so they're unaffected
        if node.get("source_type") == "audit":
            continue

        python_env = node.get("python_env") or {}

        # Two distinct flags means the environment mixes metadata and non-metadata entries
        metadata_flags = {bool(value.get("is_metadata")) for value in python_env.values()}
        if len(metadata_flags) == 2:
            get_console().log_warning(warning)
            return

        dialect = node.get("dialect")
        metadata_hash_statements = []

        # We use try-except here as a conservative measure to avoid any unexpected exceptions
        try:
            if on_virtual_update := node.get("on_virtual_update"):
                # parse_expression returns one expression for a string and a list for a list
                metadata_hash_statements.extend(
                    _expressions_in(parse_expression(on_virtual_update, dialect))
                )

            for _, audit_args in func_call_validator(node.get("audits") or []):
                metadata_hash_statements.extend(_expressions_in(list(audit_args.values())))

            for _, signal_args in func_call_validator(node.get("signals") or [], is_signal=True):
                metadata_hash_statements.extend(_expressions_in(list(signal_args.values())))

            if audit_definitions := node.get("audit_definitions"):
                for audit in audit_definitions.values():
                    metadata_hash_statements.extend(
                        _expressions_in(parse_expression(audit["query"], audit["dialect"]))
                    )

            for macro_name in extract_used_macros(metadata_hash_statements):
                serialized_macro = python_env.get(macro_name)
                if isinstance(serialized_macro, dict) and not serialized_macro.get("is_metadata"):
                    get_console().log_warning(warning)
                    return
        except Exception:
            pass
def extract_used_macros(expressions):
    """Return the lower-cased names of the macro functions used in `expressions`.

    Args:
        expressions: Iterable of parsed expressions. Nested lists/tuples are
            flattened, and entries without a `find_all` method (for example
            argument values that were never parsed, or None) are skipped
            instead of raising.

    Returns:
        A set of macro names.
    """
    used_macros = set()
    pending = list(expressions)

    while pending:
        expression = pending.pop()

        if isinstance(expression, (list, tuple)):
            pending.extend(expression)
            continue

        # Plain values carry no syntax tree to search
        if not hasattr(expression, "find_all"):
            continue

        if isinstance(expression, d.Jinja):
            continue

        for macro_func in expression.find_all(d.MacroFunc):
            # Exact class match: subclasses of MacroFunc are deliberately not counted
            if macro_func.__class__ is d.MacroFunc:
                used_macros.add(macro_func.this.name.lower())

    return used_macros
def func_call_validator(v, is_signal=False):
    """Normalize serialized audit/signal calls into `(name, parsed_args)` pairs.

    Args:
        v: List whose entries are either a dict of arguments (carrying a
            "name" key unless `is_signal` is true) or a `(name, args)` pair.
        is_signal: When true, dict entries are treated as having no name and
            the empty string is used instead.

    Returns:
        A list of `(lower-cased name, args)` tuples in which every string
        argument value has been parsed with `d.parse_one`; other values are
        passed through unchanged. The input entries are not modified.
    """
    assert isinstance(v, list)

    audits = []
    for entry in v:
        if isinstance(entry, dict):
            # Work on a copy so popping "name" does not alter the caller's dict
            args = dict(entry)
            name = "" if is_signal else args.pop("name")
        else:
            assert isinstance(entry, (tuple, list))
            name, args = entry

        parsed_audit = {
            key: d.parse_one(value) if isinstance(value, str) else value
            for key, value in args.items()
        }
        audits.append((name.lower(), parsed_audit))

    return audits