Add default catalog to snapshots and update names to match new normalization rules.
1"""Add default catalog to snapshots and update names to match new normalization rules.""" 2 3from __future__ import annotations 4 5import json 6import typing as t 7 8import pandas as pd 9from sqlglot import exp 10from sqlglot.dialects.dialect import DialectType 11from sqlglot.helper import dict_depth, seq_get 12from sqlglot.optimizer.normalize_identifiers import normalize_identifiers 13 14from sqlmesh.utils.migration import index_text_type 15 16 17def set_default_catalog( 18 table: exp.Table, 19 default_catalog: t.Optional[str], 20) -> exp.Table: 21 if default_catalog and not table.catalog and table.db: 22 table.set("catalog", exp.parse_identifier(default_catalog)) 23 24 return table 25 26 27def normalize_model_name( 28 table: str | exp.Table, 29 default_catalog: t.Optional[str], 30 dialect: DialectType = None, 31) -> str: 32 table = exp.to_table(table, dialect=dialect) 33 34 table = set_default_catalog(table, default_catalog) 35 return exp.table_name(normalize_identifiers(table, dialect=dialect), identify=True) 36 37 38def normalize_mapping_schema(mapping_schema: t.Dict, dialect: str) -> t.Dict: 39 # Example input: {'"catalog"': {'schema': {'table': {'column': 'INT'}}}} 40 # Example output: {'"catalog"': {'"schema"': {'"table"': {'column': 'INT'}}}} 41 normalized_mapping_schema = {} 42 for key, value in mapping_schema.items(): 43 if isinstance(value, dict): 44 normalized_mapping_schema[normalize_model_name(key, None, dialect)] = ( 45 normalize_mapping_schema(value, dialect) 46 ) 47 else: 48 normalized_mapping_schema[key] = value 49 return normalized_mapping_schema 50 51 52def update_dbt_relations( 53 source: t.Optional[t.Dict], keys: t.List[str], default_catalog: t.Optional[str] 54) -> None: 55 if not default_catalog or not source: 56 return 57 for key in keys: 58 relations = source.get(key) 59 if relations: 60 relations = [relations] if "database" in relations else relations.values() 61 for relation in relations: 62 if not relation["database"]: 63 relation["database"] = default_catalog 64 65 66def migrate(state_sync, default_catalog: t.Optional[str], **kwargs): # type: ignore 67 engine_adapter = state_sync.engine_adapter 68 schema = state_sync.schema 69 snapshots_table = "_snapshots" 70 environments_table = "_environments" 71 intervals_table = "_intervals" 72 seeds_table = "_seeds" 73 74 if schema: 75 snapshots_table = f"{schema}.{snapshots_table}" 76 environments_table = f"{schema}.{environments_table}" 77 intervals_table = f"{schema}.{intervals_table}" 78 seeds_table = f"{schema}.{seeds_table}" 79 80 new_snapshots = [] 81 snapshot_to_dialect = {} 82 index_type = index_text_type(engine_adapter.dialect) 83 84 for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( 85 exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), 86 quote_identifiers=True, 87 ): 88 parsed_snapshot = json.loads(snapshot) 89 # This is here in the case where the user originally had catalog in this model name, and therefore 90 # we would have before created the table with the catalog in the name. New logic removes the catalog, 91 # and therefore we need to make sure the table name is the same as the original table name, so we include 92 # this override 93 parsed_snapshot["base_table_name_override"] = parsed_snapshot["name"] 94 node = parsed_snapshot["node"] 95 dialect = node.get("dialect") 96 normalized_name = ( 97 normalize_model_name(name, default_catalog=default_catalog, dialect=dialect) 98 if node["source_type"] != "audit" 99 else name 100 ) 101 parsed_snapshot["name"] = normalized_name 102 # At the time of migration all nodes had default catalog, so we don't have to check type 103 node["default_catalog"] = default_catalog 104 snapshot_to_dialect[name] = dialect 105 mapping_schema = node.get("mapping_schema", {}) 106 if mapping_schema: 107 normalized_default_catalog = ( 108 normalize_model_name(default_catalog, default_catalog=None, dialect=dialect) 109 if default_catalog 110 else None 111 ) 112 mapping_schema_depth = dict_depth(mapping_schema) 113 if mapping_schema_depth == 3 and normalized_default_catalog: 114 mapping_schema = {normalized_default_catalog: mapping_schema} 115 node["mapping_schema"] = normalize_mapping_schema(mapping_schema, dialect) 116 depends_on = node.get("depends_on", []) 117 if depends_on: 118 node["depends_on"] = [ 119 normalize_model_name(dep, default_catalog, dialect) for dep in depends_on 120 ] 121 if parsed_snapshot["parents"]: 122 parsed_snapshot["parents"] = [ 123 { 124 "name": normalize_model_name(parent["name"], default_catalog, dialect), 125 "identifier": parent["identifier"], 126 } 127 for parent in parsed_snapshot["parents"] 128 ] 129 if parsed_snapshot["indirect_versions"]: 130 parsed_snapshot["indirect_versions"] = { 131 normalize_model_name(name, default_catalog, dialect): snapshot_data_versions 132 for name, snapshot_data_versions in parsed_snapshot["indirect_versions"].items() 133 } 134 # dbt specific migration 135 jinja_macros = node.get("jinja_macros") 136 if ( 137 default_catalog 138 and jinja_macros 139 and jinja_macros.get("create_builtins_module") == "sqlmesh.dbt" 140 ): 141 update_dbt_relations( 142 jinja_macros.get("global_objs"), ["refs", "sources", "this"], default_catalog 143 ) 144 145 new_snapshots.append( 146 { 147 "name": normalized_name, 148 "identifier": identifier, 149 "version": version, 150 "snapshot": json.dumps(parsed_snapshot), 151 "kind_name": kind_name, 152 } 153 ) 154 155 if new_snapshots: 156 engine_adapter.delete_from(snapshots_table, "TRUE") 157 158 engine_adapter.insert_append( 159 snapshots_table, 160 pd.DataFrame(new_snapshots), 161 columns_to_types={ 162 "name": exp.DataType.build(index_type), 163 "identifier": exp.DataType.build(index_type), 164 "version": exp.DataType.build(index_type), 165 "snapshot": exp.DataType.build("text"), 166 "kind_name": exp.DataType.build(index_type), 167 }, 168 ) 169 170 new_environments = [] 171 default_dialect = seq_get(list(snapshot_to_dialect.values()), 0) 172 for ( 173 name, 174 snapshots, 175 start_at, 176 end_at, 177 plan_id, 178 previous_plan_id, 179 expiration_ts, 180 finalized_ts, 181 promoted_snapshot_ids, 182 suffix_target, 183 ) in engine_adapter.fetchall( 184 exp.select( 185 "name", 186 "snapshots", 187 "start_at", 188 "end_at", 189 "plan_id", 190 "previous_plan_id", 191 "expiration_ts", 192 "finalized_ts", 193 "promoted_snapshot_ids", 194 "suffix_target", 195 ).from_(environments_table), 196 quote_identifiers=True, 197 ): 198 new_snapshots = [] 199 for snapshot in json.loads(snapshots): 200 snapshot_name = snapshot["name"] 201 snapshot["base_table_name_override"] = snapshot_name 202 dialect = snapshot_to_dialect.get(snapshot_name, default_dialect) 203 node_type = snapshot.get("node_type") 204 normalized_name = ( 205 normalize_model_name(snapshot_name, default_catalog, dialect) 206 if node_type is None or node_type == "model" 207 else snapshot_name 208 ) 209 snapshot["name"] = normalized_name 210 if snapshot["parents"]: 211 snapshot["parents"] = [ 212 { 213 "name": normalize_model_name(parent["name"], default_catalog, dialect), 214 "identifier": parent["identifier"], 215 } 216 for parent in snapshot["parents"] 217 ] 218 new_snapshots.append(snapshot) 219 220 new_environments.append( 221 { 222 "name": name, 223 "snapshots": json.dumps(new_snapshots), 224 "start_at": start_at, 225 "end_at": end_at, 226 "plan_id": plan_id, 227 "previous_plan_id": previous_plan_id, 228 "expiration_ts": expiration_ts, 229 "finalized_ts": finalized_ts, 230 "promoted_snapshot_ids": promoted_snapshot_ids, 231 "suffix_target": suffix_target, 232 } 233 ) 234 235 if new_environments: 236 engine_adapter.delete_from(environments_table, "TRUE") 237 238 engine_adapter.insert_append( 239 environments_table, 240 pd.DataFrame(new_environments), 241 columns_to_types={ 242 "name": exp.DataType.build(index_type), 243 "snapshots": exp.DataType.build("text"), 244 "start_at": exp.DataType.build("text"), 245 "end_at": exp.DataType.build("text"), 246 "plan_id": exp.DataType.build("text"), 247 "previous_plan_id": exp.DataType.build("text"), 248 "expiration_ts": exp.DataType.build("bigint"), 249 "finalized_ts": exp.DataType.build("bigint"), 250 "promoted_snapshot_ids": exp.DataType.build("text"), 251 "suffix_target": exp.DataType.build("text"), 252 }, 253 ) 254 255 # We update environment to not be finalized in order to force them to update their views 256 # in order to make sure the views now have the fully qualified names 257 # We only do this if a default catalog was applied otherwise the current views are fine 258 # We do this post creating the new environments in order to avoid having to find a way to 259 # expression a null timestamp value in pandas that works across all engines 260 if default_catalog: 261 engine_adapter.execute( 262 exp.update(environments_table, {"finalized_ts": None}, where="1=1"), 263 quote_identifiers=True, 264 ) 265 266 new_intervals = [] 267 for ( 268 id, 269 created_ts, 270 name, 271 identifier, 272 version, 273 start_ts, 274 end_ts, 275 is_dev, 276 is_removed, 277 is_compacted, 278 ) in engine_adapter.fetchall( 279 exp.select( 280 "id", 281 "created_ts", 282 "name", 283 "identifier", 284 "version", 285 "start_ts", 286 "end_ts", 287 "is_dev", 288 "is_removed", 289 "is_compacted", 290 ).from_(intervals_table), 291 quote_identifiers=True, 292 ): 293 dialect = snapshot_to_dialect.get(name, default_dialect) 294 normalized_name = normalize_model_name(name, default_catalog, dialect) 295 new_intervals.append( 296 { 297 "id": id, 298 "created_ts": created_ts, 299 "name": normalized_name, 300 "identifier": identifier, 301 "version": version, 302 "start_ts": start_ts, 303 "end_ts": end_ts, 304 "is_dev": is_dev, 305 "is_removed": is_removed, 306 "is_compacted": is_compacted, 307 } 308 ) 309 310 if new_intervals: 311 engine_adapter.delete_from(intervals_table, "TRUE") 312 313 engine_adapter.insert_append( 314 intervals_table, 315 pd.DataFrame(new_intervals), 316 columns_to_types={ 317 "id": exp.DataType.build(index_type), 318 "created_ts": exp.DataType.build("bigint"), 319 "name": exp.DataType.build(index_type), 320 "identifier": exp.DataType.build(index_type), 321 "version": exp.DataType.build(index_type), 322 "start_ts": exp.DataType.build("bigint"), 323 "end_ts": exp.DataType.build("bigint"), 324 "is_dev": exp.DataType.build("boolean"), 325 "is_removed": exp.DataType.build("boolean"), 326 "is_compacted": exp.DataType.build("boolean"), 327 }, 328 ) 329 330 new_seeds = [] 331 for ( 332 name, 333 identifier, 334 content, 335 ) in engine_adapter.fetchall( 336 exp.select( 337 "name", 338 "identifier", 339 "content", 340 ).from_(seeds_table), 341 quote_identifiers=True, 342 ): 343 dialect = snapshot_to_dialect.get(name, default_dialect) 344 normalized_name = normalize_model_name(name, default_catalog, dialect) 345 new_seeds.append( 346 { 347 "name": normalized_name, 348 "identifier": identifier, 349 "content": content, 350 } 351 ) 352 353 if new_seeds: 354 engine_adapter.delete_from(seeds_table, "TRUE") 355 356 engine_adapter.insert_append( 357 seeds_table, 358 pd.DataFrame(new_seeds), 359 columns_to_types={ 360 "name": exp.DataType.build(index_type), 361 "identifier": exp.DataType.build(index_type), 362 "content": exp.DataType.build("text"), 363 }, 364 )
def
set_default_catalog( table: sqlglot.expressions.Table, default_catalog: Union[str, NoneType]) -> sqlglot.expressions.Table:
def
normalize_model_name( table: 'str | exp.Table', default_catalog: Union[str, NoneType], dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> str:
28def normalize_model_name( 29 table: str | exp.Table, 30 default_catalog: t.Optional[str], 31 dialect: DialectType = None, 32) -> str: 33 table = exp.to_table(table, dialect=dialect) 34 35 table = set_default_catalog(table, default_catalog) 36 return exp.table_name(normalize_identifiers(table, dialect=dialect), identify=True)
def
normalize_mapping_schema(mapping_schema: Dict, dialect: str) -> Dict:
39def normalize_mapping_schema(mapping_schema: t.Dict, dialect: str) -> t.Dict: 40 # Example input: {'"catalog"': {'schema': {'table': {'column': 'INT'}}}} 41 # Example output: {'"catalog"': {'"schema"': {'"table"': {'column': 'INT'}}}} 42 normalized_mapping_schema = {} 43 for key, value in mapping_schema.items(): 44 if isinstance(value, dict): 45 normalized_mapping_schema[normalize_model_name(key, None, dialect)] = ( 46 normalize_mapping_schema(value, dialect) 47 ) 48 else: 49 normalized_mapping_schema[key] = value 50 return normalized_mapping_schema
def
update_dbt_relations( source: Union[Dict, NoneType], keys: List[str], default_catalog: Union[str, NoneType]) -> None:
53def update_dbt_relations( 54 source: t.Optional[t.Dict], keys: t.List[str], default_catalog: t.Optional[str] 55) -> None: 56 if not default_catalog or not source: 57 return 58 for key in keys: 59 relations = source.get(key) 60 if relations: 61 relations = [relations] if "database" in relations else relations.values() 62 for relation in relations: 63 if not relation["database"]: 64 relation["database"] = default_catalog
def
migrate(state_sync, default_catalog: Union[str, NoneType], **kwargs):
67def migrate(state_sync, default_catalog: t.Optional[str], **kwargs): # type: ignore 68 engine_adapter = state_sync.engine_adapter 69 schema = state_sync.schema 70 snapshots_table = "_snapshots" 71 environments_table = "_environments" 72 intervals_table = "_intervals" 73 seeds_table = "_seeds" 74 75 if schema: 76 snapshots_table = f"{schema}.{snapshots_table}" 77 environments_table = f"{schema}.{environments_table}" 78 intervals_table = f"{schema}.{intervals_table}" 79 seeds_table = f"{schema}.{seeds_table}" 80 81 new_snapshots = [] 82 snapshot_to_dialect = {} 83 index_type = index_text_type(engine_adapter.dialect) 84 85 for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( 86 exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), 87 quote_identifiers=True, 88 ): 89 parsed_snapshot = json.loads(snapshot) 90 # This is here in the case where the user originally had catalog in this model name, and therefore 91 # we would have before created the table with the catalog in the name. New logic removes the catalog, 92 # and therefore we need to make sure the table name is the same as the original table name, so we include 93 # this override 94 parsed_snapshot["base_table_name_override"] = parsed_snapshot["name"] 95 node = parsed_snapshot["node"] 96 dialect = node.get("dialect") 97 normalized_name = ( 98 normalize_model_name(name, default_catalog=default_catalog, dialect=dialect) 99 if node["source_type"] != "audit" 100 else name 101 ) 102 parsed_snapshot["name"] = normalized_name 103 # At the time of migration all nodes had default catalog, so we don't have to check type 104 node["default_catalog"] = default_catalog 105 snapshot_to_dialect[name] = dialect 106 mapping_schema = node.get("mapping_schema", {}) 107 if mapping_schema: 108 normalized_default_catalog = ( 109 normalize_model_name(default_catalog, default_catalog=None, dialect=dialect) 110 if default_catalog 111 else None 112 ) 113 mapping_schema_depth = dict_depth(mapping_schema) 114 if mapping_schema_depth == 3 and normalized_default_catalog: 115 mapping_schema = {normalized_default_catalog: mapping_schema} 116 node["mapping_schema"] = normalize_mapping_schema(mapping_schema, dialect) 117 depends_on = node.get("depends_on", []) 118 if depends_on: 119 node["depends_on"] = [ 120 normalize_model_name(dep, default_catalog, dialect) for dep in depends_on 121 ] 122 if parsed_snapshot["parents"]: 123 parsed_snapshot["parents"] = [ 124 { 125 "name": normalize_model_name(parent["name"], default_catalog, dialect), 126 "identifier": parent["identifier"], 127 } 128 for parent in parsed_snapshot["parents"] 129 ] 130 if parsed_snapshot["indirect_versions"]: 131 parsed_snapshot["indirect_versions"] = { 132 normalize_model_name(name, default_catalog, dialect): snapshot_data_versions 133 for name, snapshot_data_versions in parsed_snapshot["indirect_versions"].items() 134 } 135 # dbt specific migration 136 jinja_macros = node.get("jinja_macros") 137 if ( 138 default_catalog 139 and jinja_macros 140 and jinja_macros.get("create_builtins_module") == "sqlmesh.dbt" 141 ): 142 update_dbt_relations( 143 jinja_macros.get("global_objs"), ["refs", "sources", "this"], default_catalog 144 ) 145 146 new_snapshots.append( 147 { 148 "name": normalized_name, 149 "identifier": identifier, 150 "version": version, 151 "snapshot": json.dumps(parsed_snapshot), 152 "kind_name": kind_name, 153 } 154 ) 155 156 if new_snapshots: 157 engine_adapter.delete_from(snapshots_table, "TRUE") 158 159 engine_adapter.insert_append( 160 snapshots_table, 161 pd.DataFrame(new_snapshots), 162 columns_to_types={ 163 "name": exp.DataType.build(index_type), 164 "identifier": exp.DataType.build(index_type), 165 "version": exp.DataType.build(index_type), 166 "snapshot": exp.DataType.build("text"), 167 "kind_name": exp.DataType.build(index_type), 168 }, 169 ) 170 171 new_environments = [] 172 default_dialect = seq_get(list(snapshot_to_dialect.values()), 0) 173 for ( 174 name, 175 snapshots, 176 start_at, 177 end_at, 178 plan_id, 179 previous_plan_id, 180 expiration_ts, 181 finalized_ts, 182 promoted_snapshot_ids, 183 suffix_target, 184 ) in engine_adapter.fetchall( 185 exp.select( 186 "name", 187 "snapshots", 188 "start_at", 189 "end_at", 190 "plan_id", 191 "previous_plan_id", 192 "expiration_ts", 193 "finalized_ts", 194 "promoted_snapshot_ids", 195 "suffix_target", 196 ).from_(environments_table), 197 quote_identifiers=True, 198 ): 199 new_snapshots = [] 200 for snapshot in json.loads(snapshots): 201 snapshot_name = snapshot["name"] 202 snapshot["base_table_name_override"] = snapshot_name 203 dialect = snapshot_to_dialect.get(snapshot_name, default_dialect) 204 node_type = snapshot.get("node_type") 205 normalized_name = ( 206 normalize_model_name(snapshot_name, default_catalog, dialect) 207 if node_type is None or node_type == "model" 208 else snapshot_name 209 ) 210 snapshot["name"] = normalized_name 211 if snapshot["parents"]: 212 snapshot["parents"] = [ 213 { 214 "name": normalize_model_name(parent["name"], default_catalog, dialect), 215 "identifier": parent["identifier"], 216 } 217 for parent in snapshot["parents"] 218 ] 219 new_snapshots.append(snapshot) 220 221 new_environments.append( 222 { 223 "name": name, 224 "snapshots": json.dumps(new_snapshots), 225 "start_at": start_at, 226 "end_at": end_at, 227 "plan_id": plan_id, 228 "previous_plan_id": previous_plan_id, 229 "expiration_ts": expiration_ts, 230 "finalized_ts": finalized_ts, 231 "promoted_snapshot_ids": promoted_snapshot_ids, 232 "suffix_target": suffix_target, 233 } 234 ) 235 236 if new_environments: 237 engine_adapter.delete_from(environments_table, "TRUE") 238 239 engine_adapter.insert_append( 240 environments_table, 241 pd.DataFrame(new_environments), 242 columns_to_types={ 243 "name": exp.DataType.build(index_type), 244 "snapshots": exp.DataType.build("text"), 245 "start_at": exp.DataType.build("text"), 246 "end_at": exp.DataType.build("text"), 247 "plan_id": exp.DataType.build("text"), 248 "previous_plan_id": exp.DataType.build("text"), 249 "expiration_ts": exp.DataType.build("bigint"), 250 "finalized_ts": exp.DataType.build("bigint"), 251 "promoted_snapshot_ids": exp.DataType.build("text"), 252 "suffix_target": exp.DataType.build("text"), 253 }, 254 ) 255 256 # We update environment to not be finalized in order to force them to update their views 257 # in order to make sure the views now have the fully qualified names 258 # We only do this if a default catalog was applied otherwise the current views are fine 259 # We do this post creating the new environments in order to avoid having to find a way to 260 # expression a null timestamp value in pandas that works across all engines 261 if default_catalog: 262 engine_adapter.execute( 263 exp.update(environments_table, {"finalized_ts": None}, where="1=1"), 264 quote_identifiers=True, 265 ) 266 267 new_intervals = [] 268 for ( 269 id, 270 created_ts, 271 name, 272 identifier, 273 version, 274 start_ts, 275 end_ts, 276 is_dev, 277 is_removed, 278 is_compacted, 279 ) in engine_adapter.fetchall( 280 exp.select( 281 "id", 282 "created_ts", 283 "name", 284 "identifier", 285 "version", 286 "start_ts", 287 "end_ts", 288 "is_dev", 289 "is_removed", 290 "is_compacted", 291 ).from_(intervals_table), 292 quote_identifiers=True, 293 ): 294 dialect = snapshot_to_dialect.get(name, default_dialect) 295 normalized_name = normalize_model_name(name, default_catalog, dialect) 296 new_intervals.append( 297 { 298 "id": id, 299 "created_ts": created_ts, 300 "name": normalized_name, 301 "identifier": identifier, 302 "version": version, 303 "start_ts": start_ts, 304 "end_ts": end_ts, 305 "is_dev": is_dev, 306 "is_removed": is_removed, 307 "is_compacted": is_compacted, 308 } 309 ) 310 311 if new_intervals: 312 engine_adapter.delete_from(intervals_table, "TRUE") 313 314 engine_adapter.insert_append( 315 intervals_table, 316 pd.DataFrame(new_intervals), 317 columns_to_types={ 318 "id": exp.DataType.build(index_type), 319 "created_ts": exp.DataType.build("bigint"), 320 "name": exp.DataType.build(index_type), 321 "identifier": exp.DataType.build(index_type), 322 "version": exp.DataType.build(index_type), 323 "start_ts": exp.DataType.build("bigint"), 324 "end_ts": exp.DataType.build("bigint"), 325 "is_dev": exp.DataType.build("boolean"), 326 "is_removed": exp.DataType.build("boolean"), 327 "is_compacted": exp.DataType.build("boolean"), 328 }, 329 ) 330 331 new_seeds = [] 332 for ( 333 name, 334 identifier, 335 content, 336 ) in engine_adapter.fetchall( 337 exp.select( 338 "name", 339 "identifier", 340 "content", 341 ).from_(seeds_table), 342 quote_identifiers=True, 343 ): 344 dialect = snapshot_to_dialect.get(name, default_dialect) 345 normalized_name = normalize_model_name(name, default_catalog, dialect) 346 new_seeds.append( 347 { 348 "name": normalized_name, 349 "identifier": identifier, 350 "content": content, 351 } 352 ) 353 354 if new_seeds: 355 engine_adapter.delete_from(seeds_table, "TRUE") 356 357 engine_adapter.insert_append( 358 seeds_table, 359 pd.DataFrame(new_seeds), 360 columns_to_types={ 361 "name": exp.DataType.build(index_type), 362 "identifier": exp.DataType.build(index_type), 363 "content": exp.DataType.build("text"), 364 }, 365 )