sqlmesh.core.engine_adapter.bigquery
1from __future__ import annotations 2 3import logging 4import typing as t 5from collections import defaultdict 6 7from sqlglot import exp, parse_one 8from sqlglot.transforms import remove_precision_parameterized_types 9 10from sqlmesh.core.dialect import to_schema 11from sqlmesh.core.engine_adapter.base import _get_data_object_cache_key 12from sqlmesh.core.engine_adapter.mixins import ( 13 ClusteredByMixin, 14 GrantsFromInfoSchemaMixin, 15 RowDiffMixin, 16 TableAlterClusterByOperation, 17) 18from sqlmesh.core.engine_adapter.shared import ( 19 CatalogSupport, 20 DataObject, 21 DataObjectType, 22 SourceQuery, 23 set_catalog, 24 InsertOverwriteStrategy, 25) 26from sqlmesh.core.node import IntervalUnit 27from sqlmesh.core.schema_diff import TableAlterOperation, NestedSupport 28from sqlmesh.utils import optional_import, get_source_columns_to_types 29from sqlmesh.utils.date import to_datetime 30from sqlmesh.utils.errors import SQLMeshError 31from sqlmesh.utils.pandas import columns_to_types_from_dtypes 32 33if t.TYPE_CHECKING: 34 import pandas as pd 35 from google.api_core.retry import Retry 36 from google.cloud import bigquery 37 from google.cloud.bigquery import StandardSqlDataType 38 from google.cloud.bigquery.client import Client as BigQueryClient 39 from google.cloud.bigquery.job import QueryJob 40 from google.cloud.bigquery.job.base import _AsyncJob as BigQueryQueryResult 41 from google.cloud.bigquery.table import Table as BigQueryTable 42 43 from sqlmesh.core._typing import SchemaName, SessionProperties, TableName 44 from sqlmesh.core.engine_adapter._typing import BigframeSession, DCL, DF, GrantsConfig, Query 45 from sqlmesh.core.engine_adapter.base import QueryOrDF 46 47 48logger = logging.getLogger(__name__) 49 50bigframes = optional_import("bigframes") 51bigframes_pd = optional_import("bigframes.pandas") 52 53 54NestedField = t.Tuple[str, str, t.List[str]] 55NestedFieldsDict = t.Dict[str, t.List[NestedField]] 56 57 58@set_catalog() 59class 
BigQueryEngineAdapter(ClusteredByMixin, RowDiffMixin, GrantsFromInfoSchemaMixin): 60 """ 61 BigQuery Engine Adapter using the `google-cloud-bigquery` library's DB API. 62 """ 63 64 DIALECT = "bigquery" 65 DEFAULT_BATCH_SIZE = 1000 66 SUPPORTS_TRANSACTIONS = False 67 SUPPORTS_MATERIALIZED_VIEWS = True 68 SUPPORTS_CLONING = True 69 SUPPORTS_GRANTS = True 70 CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("session_user") 71 SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True 72 USE_CATALOG_IN_GRANTS = True 73 GRANT_INFORMATION_SCHEMA_TABLE_NAME = "OBJECT_PRIVILEGES" 74 MAX_TABLE_COMMENT_LENGTH = 1024 75 MAX_COLUMN_COMMENT_LENGTH = 1024 76 SUPPORTS_QUERY_EXECUTION_TRACKING = True 77 SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"] 78 INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE 79 80 SCHEMA_DIFFER_KWARGS = { 81 "compatible_types": { 82 exp.DataType.build("INT64", dialect=DIALECT): { 83 exp.DataType.build("NUMERIC", dialect=DIALECT), 84 exp.DataType.build("FLOAT64", dialect=DIALECT), 85 exp.DataType.build("BIGNUMERIC", dialect=DIALECT), 86 }, 87 exp.DataType.build("NUMERIC", dialect=DIALECT): { 88 exp.DataType.build("FLOAT64", dialect=DIALECT), 89 exp.DataType.build("BIGNUMERIC", dialect=DIALECT), 90 }, 91 exp.DataType.build("DATE", dialect=DIALECT): { 92 exp.DataType.build("DATETIME", dialect=DIALECT), 93 }, 94 }, 95 "coerceable_types": { 96 exp.DataType.build("FLOAT64", dialect=DIALECT): { 97 exp.DataType.build("BIGNUMERIC", dialect=DIALECT), 98 }, 99 }, 100 "support_coercing_compatible_types": True, 101 "parameterized_type_defaults": { 102 exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(38, 9), (0,)], 103 exp.DataType.build("BIGDECIMAL", dialect=DIALECT).this: [(76.76, 38), (0,)], 104 }, 105 "types_with_unlimited_length": { 106 # parameterized `STRING(n)` can ALTER to unparameterized `STRING` 107 exp.DataType.build("STRING", dialect=DIALECT).this: { 108 exp.DataType.build("STRING", dialect=DIALECT).this, 109 }, 110 # parameterized `BYTES(n)` can 
ALTER to unparameterized `BYTES` 111 exp.DataType.build("BYTES", dialect=DIALECT).this: { 112 exp.DataType.build("BYTES", dialect=DIALECT).this, 113 }, 114 }, 115 "nested_support": NestedSupport.ALL_BUT_DROP, 116 } 117 118 @property 119 def client(self) -> BigQueryClient: 120 return self.connection._client 121 122 @property 123 def bigframe(self) -> t.Optional[BigframeSession]: 124 if bigframes: 125 options = bigframes.BigQueryOptions( 126 credentials=self.client._credentials, 127 project=self.client.project, 128 location=self.client.location, 129 ) 130 return bigframes.connect(context=options) 131 return None 132 133 @property 134 def _job_params(self) -> t.Dict[str, t.Any]: 135 from sqlmesh.core.config.connection import BigQueryPriority 136 137 params = { 138 "use_legacy_sql": False, 139 "priority": self._extra_config.get( 140 "priority", BigQueryPriority.INTERACTIVE.bigquery_constant 141 ), 142 } 143 if self._extra_config.get("maximum_bytes_billed"): 144 params["maximum_bytes_billed"] = self._extra_config.get("maximum_bytes_billed") 145 if self.correlation_id: 146 # BigQuery label keys must be lowercase 147 key = self.correlation_id.job_type.value.lower() 148 params["labels"] = {key: self.correlation_id.job_id} 149 return params 150 151 @property 152 def catalog_support(self) -> CatalogSupport: 153 return CatalogSupport.FULL_SUPPORT 154 155 def _df_to_source_queries( 156 self, 157 df: DF, 158 target_columns_to_types: t.Dict[str, exp.DataType], 159 batch_size: int, 160 target_table: TableName, 161 source_columns: t.Optional[t.List[str]] = None, 162 ) -> t.List[SourceQuery]: 163 import pandas as pd 164 165 source_columns_to_types = get_source_columns_to_types( 166 target_columns_to_types, source_columns 167 ) 168 169 temp_bq_table = self.__get_temp_bq_table( 170 self._get_temp_table(target_table or "pandas"), source_columns_to_types 171 ) 172 temp_table = exp.table_( 173 temp_bq_table.table_id, 174 db=temp_bq_table.dataset_id, 175 catalog=temp_bq_table.project, 
176 ) 177 178 def query_factory() -> Query: 179 ordered_df = df[list(source_columns_to_types)] 180 if bigframes_pd and isinstance(ordered_df, bigframes_pd.DataFrame): 181 ordered_df.to_gbq( 182 f"{temp_bq_table.project}.{temp_bq_table.dataset_id}.{temp_bq_table.table_id}", 183 if_exists="replace", 184 ) 185 elif not self.table_exists(temp_table): 186 # Make mypy happy 187 assert isinstance(ordered_df, pd.DataFrame) 188 self._db_call(self.client.create_table, table=temp_bq_table, exists_ok=False) 189 result = self.__load_pandas_to_table( 190 temp_bq_table, ordered_df, source_columns_to_types, replace=False 191 ) 192 if result.errors: 193 raise SQLMeshError(result.errors) 194 return exp.select( 195 *self._casted_columns(target_columns_to_types, source_columns=source_columns) 196 ).from_(temp_table) 197 198 return [ 199 SourceQuery( 200 query_factory=query_factory, 201 cleanup_func=lambda: self.drop_table(temp_table), 202 ) 203 ] 204 205 def close(self) -> t.Any: 206 # Cancel all pending query jobs across all threads 207 all_query_jobs = self._connection_pool.get_all_attributes("query_job") 208 for query_job in all_query_jobs: 209 if query_job: 210 try: 211 if not self._db_call(query_job.done): 212 self._db_call(query_job.cancel) 213 logger.debug( 214 "Cancelled BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s", 215 query_job.project, 216 query_job.location, 217 query_job.job_id, 218 ) 219 except Exception as ex: 220 logger.debug( 221 "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. 
%s", 222 query_job.project, 223 query_job.location, 224 query_job.job_id, 225 str(ex), 226 ) 227 228 return super().close() 229 230 def _begin_session(self, properties: SessionProperties) -> None: 231 from google.cloud.bigquery import QueryJobConfig 232 233 query_label_property = properties.get("query_label") 234 parsed_query_label: list[tuple[str, str]] = [] 235 if isinstance(query_label_property, (exp.Array, exp.Paren, exp.Tuple)): 236 label_tuples = ( 237 [query_label_property.unnest()] 238 if isinstance(query_label_property, exp.Paren) 239 else query_label_property.expressions 240 ) 241 242 # query_label is a Paren, Array or Tuple of 2-tuples and validated at load time 243 parsed_query_label.extend( 244 (label_tuple.expressions[0].name, label_tuple.expressions[1].name) 245 for label_tuple in label_tuples 246 ) 247 elif query_label_property is not None: 248 raise SQLMeshError( 249 "Invalid value for `session_properties.query_label`. Must be an array or tuple." 250 ) 251 252 if self.correlation_id: 253 parsed_query_label.append( 254 (self.correlation_id.job_type.value.lower(), self.correlation_id.job_id) 255 ) 256 257 if parsed_query_label: 258 query_label_str = ",".join([":".join(label) for label in parsed_query_label]) 259 query = f'SET @@query_label = "{query_label_str}";SELECT 1;' 260 else: 261 query = "SELECT 1;" 262 263 job = self.client.query( 264 query, 265 job_config=QueryJobConfig(create_session=True), 266 ) 267 session_info = job.session_info 268 session_id = session_info.session_id if session_info else None 269 self._session_id = session_id 270 job.result() 271 272 def _end_session(self) -> None: 273 self._session_id = None 274 275 def _is_session_active(self) -> bool: 276 return self._session_id is not None 277 278 def get_current_catalog(self) -> t.Optional[str]: 279 """Returns the catalog name of the current connection.""" 280 return self.client.project 281 282 def set_current_catalog(self, catalog: str) -> None: 283 """Sets the catalog name of the 
current connection.""" 284 self.client.project = catalog 285 286 def create_schema( 287 self, 288 schema_name: SchemaName, 289 ignore_if_exists: bool = True, 290 warn_on_error: bool = True, 291 properties: t.List[exp.Expression] = [], 292 ) -> None: 293 """Create a schema from a name or qualified table name.""" 294 from google.api_core.exceptions import Conflict 295 296 try: 297 super().create_schema( 298 schema_name, 299 ignore_if_exists=ignore_if_exists, 300 warn_on_error=False, 301 ) 302 except Exception as e: 303 is_already_exists_error = isinstance(e, Conflict) and "Already Exists:" in str(e) 304 if is_already_exists_error and ignore_if_exists: 305 return 306 if not warn_on_error: 307 raise 308 logger.warning("Failed to create schema '%s': %s", schema_name, e) 309 310 def get_bq_schema(self, table_name: TableName) -> t.List[bigquery.SchemaField]: 311 table = exp.to_table(table_name) 312 if len(table.parts) == 3 and "." in table.name: 313 self.execute(exp.select("*").from_(table).limit(0)) 314 query_job = self._query_job 315 assert query_job is not None 316 return query_job._query_results.schema 317 return self._get_table(table).schema 318 319 def columns( 320 self, table_name: TableName, include_pseudo_columns: bool = False 321 ) -> t.Dict[str, exp.DataType]: 322 """Fetches column names and types for the target table.""" 323 324 def dtype_to_sql( 325 dtype: t.Optional[StandardSqlDataType], field: bigquery.SchemaField 326 ) -> str: 327 assert dtype 328 assert field 329 330 kind = dtype.type_kind 331 assert kind 332 333 # Not using the enum value to preserve compatibility with older versions 334 # of the BigQuery library. 
335 if kind.name == "ARRAY": 336 return f"ARRAY<{dtype_to_sql(dtype.array_element_type, field)}>" 337 if kind.name == "STRUCT": 338 struct_type = dtype.struct_type 339 assert struct_type 340 fields = ", ".join( 341 f"{struct_field.name} {dtype_to_sql(struct_field.type, nested_field)}" 342 for struct_field, nested_field in zip(struct_type.fields, field.fields) 343 ) 344 return f"STRUCT<{fields}>" 345 if kind.name == "TYPE_KIND_UNSPECIFIED": 346 field_type = field.field_type 347 348 if field_type == "RANGE": 349 # If the field is a RANGE then `range_element_type` should be set to 350 # one of `"DATE"`, `"DATETIME"` or `"TIMESTAMP"`. 351 return f"RANGE<{field.range_element_type.element_type}>" 352 353 return field_type 354 355 return kind.name 356 357 def create_mapping_schema( 358 schema: t.Sequence[bigquery.SchemaField], 359 ) -> t.Dict[str, exp.DataType]: 360 return { 361 field.name: exp.DataType.build( 362 dtype_to_sql(field.to_standard_sql().type, field), dialect=self.dialect 363 ) 364 for field in schema 365 } 366 367 table = exp.to_table(table_name) 368 if len(table.parts) == 3 and "." 
in table.name: 369 # The client's `get_table` method can't handle paths with >3 identifiers 370 self.execute(exp.select("*").from_(table).limit(0)) 371 query_job = self._query_job 372 assert query_job is not None 373 374 query_results = query_job._query_results 375 columns = create_mapping_schema(query_results.schema) 376 else: 377 bq_table = self._get_table(table) 378 columns = create_mapping_schema(bq_table.schema) 379 380 if include_pseudo_columns: 381 if bq_table.time_partitioning and not bq_table.time_partitioning.field: 382 columns["_PARTITIONTIME"] = exp.DataType.build("TIMESTAMP", dialect="bigquery") 383 if bq_table.time_partitioning.type_ == "DAY": 384 columns["_PARTITIONDATE"] = exp.DataType.build("DATE") 385 if bq_table.table_id.endswith("*"): 386 columns["_TABLE_SUFFIX"] = exp.DataType.build("STRING", dialect="bigquery") 387 if ( 388 bq_table.external_data_configuration is not None 389 and bq_table.external_data_configuration.source_format 390 in ( 391 "CSV", 392 "NEWLINE_DELIMITED_JSON", 393 "AVRO", 394 "PARQUET", 395 "ORC", 396 "DATASTORE_BACKUP", 397 ) 398 ): 399 columns["_FILE_NAME"] = exp.DataType.build("STRING", dialect="bigquery") 400 401 return columns 402 403 def alter_table( 404 self, 405 alter_expressions: t.Union[t.List[exp.Alter], t.List[TableAlterOperation]], 406 ) -> None: 407 """ 408 Performs the alter statements to change the current table into the structure of the target table, 409 and uses the API to add columns to structs, where SQL is not supported. 
410 """ 411 if not alter_expressions: 412 return 413 414 cluster_by_operations, alter_statements = [], [] 415 for e in alter_expressions: 416 if isinstance(e, TableAlterClusterByOperation): 417 cluster_by_operations.append(e) 418 elif isinstance(e, TableAlterOperation): 419 alter_statements.append(e.expression) 420 else: 421 alter_statements.append(e) 422 423 for op in cluster_by_operations: 424 self._update_clustering_key(op) 425 426 nested_fields, non_nested_expressions = self._split_alter_expressions(alter_statements) 427 428 if nested_fields: 429 self._update_table_schema_nested_fields(nested_fields, alter_statements[0].this) 430 431 if non_nested_expressions: 432 super().alter_table(non_nested_expressions) 433 434 def fetchone( 435 self, 436 query: t.Union[exp.Expression, str], 437 ignore_unsupported_errors: bool = False, 438 quote_identifiers: bool = False, 439 ) -> t.Optional[t.Tuple]: 440 """ 441 BigQuery's `fetchone` method doesn't call execute and therefore would not benefit from the execute 442 configuration we have in place. Therefore this implementation calls execute instead. 443 """ 444 self.execute( 445 query, 446 ignore_unsupported_errors=ignore_unsupported_errors, 447 quote_identifiers=quote_identifiers, 448 ) 449 try: 450 return next(self._query_data) 451 except StopIteration: 452 return None 453 454 def fetchall( 455 self, 456 query: t.Union[exp.Expression, str], 457 ignore_unsupported_errors: bool = False, 458 quote_identifiers: bool = False, 459 ) -> t.List[t.Tuple]: 460 """ 461 BigQuery's `fetchone` method doesn't call execute and therefore would not benefit from the execute 462 configuration we have in place. Therefore this implementation calls execute instead. 
463 """ 464 self.execute( 465 query, 466 ignore_unsupported_errors=ignore_unsupported_errors, 467 quote_identifiers=quote_identifiers, 468 ) 469 return list(self._query_data) 470 471 def _split_alter_expressions( 472 self, 473 alter_expressions: t.List[exp.Alter], 474 ) -> t.Tuple[NestedFieldsDict, t.List[exp.Alter]]: 475 """ 476 Returns a dictionary of the nested fields to add and a list of the non-nested alter expressions. 477 """ 478 nested_fields_to_add: NestedFieldsDict = defaultdict(list) 479 non_nested_expressions = [] 480 481 for alter_expression in alter_expressions: 482 action = alter_expression.args["actions"][0] 483 if ( 484 isinstance(action, exp.ColumnDef) 485 and isinstance(action.this, exp.Dot) 486 and isinstance(action.kind, exp.DataType) 487 ): 488 root_field, *leaf_fields = action.this.this.sql(dialect=self.dialect).split(".") 489 new_field = action.this.expression.sql(dialect=self.dialect) 490 data_type = action.kind.sql(dialect=self.dialect) 491 nested_fields_to_add[root_field].append((new_field, data_type, leaf_fields)) 492 else: 493 non_nested_expressions.append(alter_expression) 494 495 return nested_fields_to_add, non_nested_expressions 496 497 def _build_nested_fields( 498 self, 499 current_fields: t.List[bigquery.SchemaField], 500 fields_to_add: t.List[NestedField], 501 ) -> t.List[bigquery.SchemaField]: 502 """ 503 Recursively builds and updates the schema fields with the new nested fields. 
504 """ 505 from google.cloud import bigquery 506 507 new_fields = [] 508 root: t.List[t.Tuple[str, str]] = [] 509 leaves: NestedFieldsDict = defaultdict(list) 510 for new_field, data_type, leaf_fields in fields_to_add: 511 if leaf_fields: 512 leaves[leaf_fields[0]].append((new_field, data_type, leaf_fields[1:])) 513 else: 514 root.append((new_field, data_type)) 515 516 for field in current_fields: 517 # If the new fields are nested, we need to recursively build them 518 if field.name in leaves: 519 subfields = list(field.fields) 520 subfields = self._build_nested_fields(subfields, leaves[field.name]) 521 new_fields.append( 522 bigquery.SchemaField( 523 field.name, "RECORD", mode=field.mode, fields=tuple(subfields) 524 ) 525 ) 526 else: 527 new_fields.append(field) 528 529 # Build and append the new root-level fields 530 new_fields.extend( 531 self.__get_bq_schemafield( 532 new_field[0], exp.DataType.build(new_field[1], dialect=self.dialect) 533 ) 534 for new_field in root 535 ) 536 return new_fields 537 538 def _update_table_schema_nested_fields( 539 self, nested_fields_to_add: NestedFieldsDict, table_name: str 540 ) -> None: 541 """ 542 Updates a BigQuery table schema by adding the new nested fields provided. 
543 """ 544 from google.cloud import bigquery 545 546 table = self._get_table(table_name) 547 original_schema = table.schema 548 new_schema = [] 549 for field in original_schema: 550 if field.name in nested_fields_to_add: 551 fields = self._build_nested_fields( 552 list(field.fields), nested_fields_to_add[field.name] 553 ) 554 new_schema.append( 555 bigquery.SchemaField( 556 field.name, 557 "RECORD", 558 mode=field.mode, 559 fields=tuple(fields), 560 ) 561 ) 562 else: 563 new_schema.append(field) 564 565 if new_schema != original_schema: 566 table.schema = new_schema 567 self.client.update_table(table, ["schema"]) 568 569 def __load_pandas_to_table( 570 self, 571 table: bigquery.Table, 572 df: pd.DataFrame, 573 columns_to_types: t.Dict[str, exp.DataType], 574 replace: bool = False, 575 ) -> BigQueryQueryResult: 576 """ 577 Loads a pandas dataframe into a table in BigQuery. Will do an overwrite if replace is True. Note that 578 the replace will replace the entire table, not just the rows that are in the dataframe. 579 """ 580 from google.cloud import bigquery 581 582 job_config = bigquery.job.LoadJobConfig(schema=self.__get_bq_schema(columns_to_types)) 583 if replace: 584 job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE 585 logger.info(f"Loading dataframe to BigQuery. Table Path: {table.path}") 586 # This client call does not support retry so we don't use the `_db_call` method. 
587 result = self.__retry( 588 self.__db_load_table_from_dataframe, 589 )(df=df, table=table, job_config=job_config) 590 if result.errors: 591 raise SQLMeshError(result.errors) 592 return result 593 594 def __db_load_table_from_dataframe( 595 self, df: pd.DataFrame, table: bigquery.Table, job_config: bigquery.LoadJobConfig 596 ) -> BigQueryQueryResult: 597 job = self.client.load_table_from_dataframe( 598 dataframe=df, destination=table, job_config=job_config 599 ) 600 return self._db_call(job.result) 601 602 def __get_bq_schemafield(self, name: str, tpe: exp.DataType) -> bigquery.SchemaField: 603 from google.cloud import bigquery 604 605 mode = "NULLABLE" 606 if tpe.is_type(exp.DataType.Type.ARRAY): 607 mode = "REPEATED" 608 tpe = tpe.expressions[0] 609 610 field_type = tpe.sql(dialect=self.dialect) 611 fields = [] 612 if tpe.is_type(*exp.DataType.NESTED_TYPES): 613 field_type = "RECORD" 614 for inner_field in tpe.expressions: 615 if isinstance(inner_field, exp.ColumnDef): 616 inner_name = inner_field.this.sql(dialect=self.dialect) 617 inner_type = inner_field.kind 618 if inner_type is None: 619 raise ValueError( 620 f"cannot convert unknown type to BQ schema field {inner_field}" 621 ) 622 fields.append(self.__get_bq_schemafield(name=inner_name, tpe=inner_type)) 623 else: 624 raise ValueError(f"unexpected nested expression {inner_field}") 625 626 return bigquery.SchemaField( 627 name=name, 628 field_type=field_type, 629 mode=mode, 630 fields=fields, 631 ) 632 633 def __get_bq_schema( 634 self, columns_to_types: t.Dict[str, exp.DataType] 635 ) -> t.List[bigquery.SchemaField]: 636 """ 637 Returns a bigquery schema object from a dictionary of column names to types. 
638 """ 639 640 precisionless_col_to_types = { 641 col_name: remove_precision_parameterized_types(col_type) 642 for col_name, col_type in columns_to_types.items() 643 } 644 return [ 645 self.__get_bq_schemafield(name=col_name, tpe=t.cast(exp.DataType, col_type)) 646 for col_name, col_type in precisionless_col_to_types.items() 647 ] 648 649 def __get_temp_bq_table( 650 self, table: exp.Table, columns_to_type: t.Dict[str, exp.DataType] 651 ) -> bigquery.Table: 652 """ 653 Returns a bigquery table object that is temporary and will expire in 3 hours. 654 """ 655 bq_table = self.__get_bq_table(table, columns_to_type) 656 bq_table.expires = to_datetime("in 3 hours") 657 return bq_table 658 659 def __get_bq_table( 660 self, table: TableName, columns_to_type: t.Dict[str, exp.DataType] 661 ) -> bigquery.Table: 662 """ 663 Returns a bigquery table object with a schema defines that matches the columns_to_type dictionary. 664 """ 665 from google.cloud import bigquery 666 667 table_ = exp.to_table(table).copy() 668 669 if not table_.catalog: 670 table_.set("catalog", exp.to_identifier(self.default_catalog)) 671 672 return bigquery.Table( 673 table_ref=self._table_name(table_), 674 schema=self.__get_bq_schema(columns_to_type), 675 ) 676 677 @property 678 def __retry(self) -> Retry: 679 from google.api_core import retry 680 681 return retry.Retry( 682 predicate=_ErrorCounter(self._extra_config["job_retries"]).should_retry, 683 deadline=self._extra_config.get("job_retry_deadline_seconds"), 684 initial=1.0, 685 maximum=3.0, 686 ) 687 688 def insert_overwrite_by_partition( 689 self, 690 table_name: TableName, 691 query_or_df: QueryOrDF, 692 partitioned_by: t.List[exp.Expression], 693 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 694 source_columns: t.Optional[t.List[str]] = None, 695 ) -> None: 696 if len(partitioned_by) != 1: 697 raise SQLMeshError( 698 f"Bigquery only supports partitioning by one column, {len(partitioned_by)} were provided." 
699 ) 700 701 partition_exp = partitioned_by[0] 702 partition_column = partition_exp.find(exp.Column) 703 704 granularity = partition_exp.args.get("unit") 705 if granularity: 706 granularity = granularity.name.lower() 707 708 if not partition_column: 709 partition_sql = partition_exp.sql(dialect=self.dialect) 710 raise SQLMeshError( 711 f"The partition expression '{partition_sql}' doesn't contain a column." 712 ) 713 with ( 714 self.session({}), 715 self.temp_table( 716 query_or_df, 717 name=table_name, 718 partitioned_by=partitioned_by, 719 source_columns=source_columns, 720 ) as temp_table_name, 721 ): 722 if target_columns_to_types is None or target_columns_to_types[ 723 partition_column.name 724 ] == exp.DataType.build("unknown"): 725 target_columns_to_types = self.columns(table_name) 726 727 partition_type_sql = target_columns_to_types[partition_column.name].sql( 728 dialect=self.dialect 729 ) 730 731 select_array_agg_partitions = select_partitions_expr( 732 temp_table_name.db, 733 temp_table_name.name, 734 partition_type_sql, 735 granularity=granularity, 736 agg_func="ARRAY_AGG", 737 catalog=temp_table_name.catalog or self.default_catalog, 738 ) 739 740 self.execute( 741 f"DECLARE _sqlmesh_target_partitions_ ARRAY<{partition_type_sql}> DEFAULT ({select_array_agg_partitions});" 742 ) 743 744 where = t.cast(exp.Condition, partition_exp).isin(unnest="_sqlmesh_target_partitions_") 745 746 self._insert_overwrite_by_condition( 747 table_name, 748 [SourceQuery(query_factory=lambda: exp.select("*").from_(temp_table_name))], 749 target_columns_to_types, 750 where=where, 751 ) 752 753 def table_exists(self, table_name: TableName) -> bool: 754 table = exp.to_table(table_name) 755 data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name) 756 if data_object_cache_key in self._data_object_cache: 757 logger.debug("Table existence cache hit: %s", data_object_cache_key) 758 return self._data_object_cache[data_object_cache_key] is not None 759 760 
try: 761 from google.cloud.exceptions import NotFound 762 except ModuleNotFoundError: 763 from google.api_core.exceptions import NotFound 764 765 try: 766 self._get_table(table_name) 767 return True 768 except NotFound: 769 return False 770 771 def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]: 772 from sqlmesh.utils.date import to_timestamp 773 774 datasets_to_tables: t.DefaultDict[str, t.List[str]] = defaultdict(list) 775 for table_name in table_names: 776 table = exp.to_table(table_name) 777 datasets_to_tables[table.db].append(table.name) 778 779 results = [] 780 781 for dataset, tables in datasets_to_tables.items(): 782 query = ( 783 f"SELECT TIMESTAMP_MILLIS(last_modified_time) FROM `{dataset}.__TABLES__` WHERE " 784 ) 785 for i, table_name in enumerate(tables): 786 query += f"TABLE_ID = '{table_name}'" 787 if i < len(tables) - 1: 788 query += " OR " 789 results.extend(self.fetchall(query)) 790 791 return [to_timestamp(row[0]) for row in results] 792 793 def _get_table(self, table_name: TableName) -> BigQueryTable: 794 """ 795 Returns a BigQueryTable object for the given table name. 796 797 Raises: `google.cloud.exceptions.NotFound` if the table does not exist. 
798 """ 799 return self._db_call(self.client.get_table, table=self._table_name(table_name)) 800 801 def _table_name(self, table_name: TableName) -> str: 802 # the api doesn't support backticks, so we can't call exp.table_name or sql 803 return ".".join(part.name for part in exp.to_table(table_name).parts) 804 805 def _fetch_native_df( 806 self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False 807 ) -> DF: 808 self.execute(query, quote_identifiers=quote_identifiers) 809 query_job = self._query_job 810 assert query_job is not None 811 return query_job.to_dataframe() 812 813 def _create_column_comments( 814 self, 815 table_name: TableName, 816 column_comments: t.Dict[str, str], 817 table_kind: str = "TABLE", 818 materialized_view: bool = False, 819 ) -> None: 820 if not (table_kind == "VIEW" and materialized_view): 821 table = self._get_table(table_name) 822 823 # convert Table object to dict 824 table_def = table.to_api_repr() 825 826 # Set column descriptions, supporting nested fields (e.g. record.field.nested_field) 827 for column, comment in column_comments.items(): 828 fields = table_def["schema"]["fields"] 829 field_names = column.split(".") 830 last_index = len(field_names) - 1 831 832 # Traverse the fields with nested fields down to leaf level 833 for idx, name in enumerate(field_names): 834 if field := next((field for field in fields if field["name"] == name), None): 835 if idx == last_index: 836 field["description"] = self._truncate_comment( 837 comment, self.MAX_COLUMN_COMMENT_LENGTH 838 ) 839 else: 840 fields = field.get("fields") or [] 841 842 # An "etag" is BQ versioning metadata that changes when an object is updated/modified. `update_table` 843 # compares the etags of the table object passed to it and the remote table, erroring if the etags 844 # don't match. We set the local etag to None to avoid this check. 
845 table_def["etag"] = None 846 847 # convert dict back to a Table object 848 table = table.from_api_repr(table_def) 849 850 # update table schema 851 logger.info(f"Registering column comments for table {table_name}") 852 self._db_call(self.client.update_table, table=table, fields=["schema"]) 853 854 def _build_description_property_exp( 855 self, 856 description: str, 857 trunc_method: t.Callable, 858 ) -> exp.Property: 859 return exp.Property( 860 this=exp.to_identifier("description", quoted=True), 861 value=exp.Literal.string(trunc_method(description)), 862 ) 863 864 def _build_partitioned_by_exp( 865 self, 866 partitioned_by: t.List[exp.Expression], 867 *, 868 partition_interval_unit: t.Optional[IntervalUnit] = None, 869 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 870 **kwargs: t.Any, 871 ) -> t.Optional[exp.PartitionedByProperty]: 872 if len(partitioned_by) > 1: 873 raise SQLMeshError("BigQuery only supports partitioning by a single column") 874 875 this = partitioned_by[0] 876 if ( 877 isinstance(this, exp.Column) 878 and partition_interval_unit is not None 879 and not partition_interval_unit.is_minute 880 ): 881 column_type: t.Optional[exp.DataType] = (target_columns_to_types or {}).get(this.name) 882 883 if column_type == exp.DataType.build( 884 "date", dialect=self.dialect 885 ) and partition_interval_unit in ( 886 IntervalUnit.MONTH, 887 IntervalUnit.YEAR, 888 ): 889 trunc_func = "DATE_TRUNC" 890 elif column_type == exp.DataType.build("timestamp", dialect=self.dialect): 891 trunc_func = "TIMESTAMP_TRUNC" 892 elif column_type == exp.DataType.build("datetime", dialect=self.dialect): 893 trunc_func = "DATETIME_TRUNC" 894 else: 895 trunc_func = "" 896 897 if trunc_func: 898 this = exp.func( 899 trunc_func, 900 this, 901 exp.var(partition_interval_unit.value.upper()), 902 dialect=self.dialect, 903 ) 904 905 return exp.PartitionedByProperty(this=this) 906 907 def _build_table_properties_exp( 908 self, 909 catalog_name: t.Optional[str] 
= None, 910 table_format: t.Optional[str] = None, 911 storage_format: t.Optional[str] = None, 912 partitioned_by: t.Optional[t.List[exp.Expression]] = None, 913 partition_interval_unit: t.Optional[IntervalUnit] = None, 914 clustered_by: t.Optional[t.List[exp.Expression]] = None, 915 table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, 916 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 917 table_description: t.Optional[str] = None, 918 table_kind: t.Optional[str] = None, 919 **kwargs: t.Any, 920 ) -> t.Optional[exp.Properties]: 921 properties: t.List[exp.Expression] = [] 922 923 if partitioned_by and ( 924 partitioned_by_prop := self._build_partitioned_by_exp( 925 partitioned_by, 926 partition_interval_unit=partition_interval_unit, 927 target_columns_to_types=target_columns_to_types, 928 ) 929 ): 930 properties.append(partitioned_by_prop) 931 932 if clustered_by and (clustered_by_exp := self._build_clustered_by_exp(clustered_by)): 933 properties.append(clustered_by_exp) 934 935 if table_description: 936 properties.append( 937 self._build_description_property_exp( 938 table_description, self._truncate_table_comment 939 ), 940 ) 941 942 properties.extend(self._table_or_view_properties_to_expressions(table_properties)) 943 944 if properties: 945 return exp.Properties(expressions=properties) 946 return None 947 948 def _build_column_def( 949 self, 950 col_name: str, 951 column_descriptions: t.Optional[t.Dict[str, str]] = None, 952 engine_supports_schema_comments: bool = False, 953 col_type: t.Optional[exp.DATA_TYPE] = None, 954 nested_names: t.List[str] = [], 955 ) -> exp.ColumnDef: 956 # Helper function to build column definitions with column descriptions 957 def _build_struct_with_descriptions( 958 col_type: exp.DataType, 959 nested_names: t.List[str], 960 ) -> exp.DataType: 961 column_expressions = [] 962 for column_def in col_type.expressions: 963 # This is expected to be true, but this check is included as a 964 # precautionary 
measure in case of an unexpected edge case 965 if isinstance(column_def, exp.ColumnDef): 966 column = self._build_column_def( 967 col_name=column_def.name, 968 column_descriptions=column_descriptions, 969 engine_supports_schema_comments=engine_supports_schema_comments, 970 col_type=column_def.kind, 971 nested_names=nested_names, 972 ) 973 else: 974 column = column_def 975 column_expressions.append(column) 976 return exp.DataType(this=col_type.this, expressions=column_expressions, nested=True) 977 978 # Recursively build column definitions for BigQuery's RECORDs (struct) and REPEATED RECORDs (array of struct) 979 if isinstance(col_type, exp.DataType) and col_type.expressions: 980 expressions = col_type.expressions 981 if col_type.is_type(exp.DataType.Type.STRUCT): 982 col_type = _build_struct_with_descriptions(col_type, nested_names + [col_name]) 983 elif col_type.is_type(exp.DataType.Type.ARRAY) and expressions[0].is_type( 984 exp.DataType.Type.STRUCT 985 ): 986 col_type = exp.DataType( 987 this=exp.DataType.Type.ARRAY, 988 expressions=[ 989 _build_struct_with_descriptions( 990 col_type.expressions[0], nested_names + [col_name] 991 ) 992 ], 993 nested=True, 994 ) 995 996 return exp.ColumnDef( 997 this=exp.to_identifier(col_name), 998 kind=col_type, 999 constraints=( 1000 self._build_col_comment_exp( 1001 ".".join(nested_names + [col_name]), column_descriptions 1002 ) 1003 if engine_supports_schema_comments and self.comments_enabled and column_descriptions 1004 else None 1005 ), 1006 ) 1007 1008 def _build_col_comment_exp( 1009 self, col_name: str, column_descriptions: t.Dict[str, str] 1010 ) -> t.List[exp.ColumnConstraint]: 1011 comment = column_descriptions.get(col_name, None) 1012 if comment: 1013 return [ 1014 exp.ColumnConstraint( 1015 kind=exp.Properties( 1016 expressions=[ 1017 self._build_description_property_exp( 1018 comment, self._truncate_column_comment 1019 ), 1020 ] 1021 ) 1022 ) 1023 ] 1024 return [] 1025 1026 def _build_view_properties_exp( 1027 
self, 1028 view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, 1029 table_description: t.Optional[str] = None, 1030 **kwargs: t.Any, 1031 ) -> t.Optional[exp.Properties]: 1032 """Creates a SQLGlot table properties expression for view""" 1033 properties: t.List[exp.Expression] = [] 1034 1035 if table_description: 1036 properties.append( 1037 self._build_description_property_exp( 1038 table_description, self._truncate_table_comment 1039 ), 1040 ) 1041 1042 properties.extend(self._table_or_view_properties_to_expressions(view_properties)) 1043 1044 if properties: 1045 return exp.Properties(expressions=properties) 1046 return None 1047 1048 def _build_create_comment_table_exp( 1049 self, table: exp.Table, table_comment: str, table_kind: str 1050 ) -> exp.Comment | str: 1051 table_sql = table.sql(dialect=self.dialect, identify=True) 1052 1053 truncated_comment = self._truncate_table_comment(table_comment) 1054 comment_sql = exp.Literal.string(truncated_comment).sql(dialect=self.dialect) 1055 1056 return f"ALTER {table_kind} {table_sql} SET OPTIONS(description = {comment_sql})" 1057 1058 def _build_create_comment_column_exp( 1059 self, table: exp.Table, column_name: str, column_comment: str, table_kind: str = "TABLE" 1060 ) -> exp.Comment | str: 1061 table_sql = table.sql(dialect=self.dialect, identify=True) 1062 column_sql = exp.column(column_name).sql(dialect=self.dialect, identify=True) 1063 1064 truncated_comment = self._truncate_column_comment(column_comment) 1065 comment_sql = exp.Literal.string(truncated_comment).sql(dialect=self.dialect) 1066 1067 return f"ALTER {table_kind} {table_sql} ALTER COLUMN {column_sql} SET OPTIONS(description = {comment_sql})" 1068 1069 def create_state_table( 1070 self, 1071 table_name: str, 1072 target_columns_to_types: t.Dict[str, exp.DataType], 1073 primary_key: t.Optional[t.Tuple[str, ...]] = None, 1074 ) -> None: 1075 self.create_table( 1076 table_name, 1077 target_columns_to_types, 1078 ) 1079 1080 def _db_call(self, 
func: t.Callable[..., t.Any], *args: t.Any, **kwargs: t.Any) -> t.Any: 1081 return func( 1082 retry=self.__retry, 1083 *args, 1084 **kwargs, 1085 ) 1086 1087 def _execute( 1088 self, 1089 sql: str, 1090 track_rows_processed: bool = False, 1091 **kwargs: t.Any, 1092 ) -> None: 1093 """Execute a sql query.""" 1094 from google.cloud.bigquery import QueryJobConfig 1095 from google.cloud.bigquery.query import ConnectionProperty 1096 1097 # BigQuery's Python DB API implementation does not support retries, so we have to implement them ourselves. 1098 # So we update the cursor's query job and query data with the results of the new query job. This makes sure 1099 # that other cursor based operations execute correctly. 1100 session_id = self._session_id 1101 connection_properties = ( 1102 [ 1103 ConnectionProperty(key="session_id", value=session_id), 1104 ] 1105 if session_id 1106 else [] 1107 ) 1108 1109 job_config = QueryJobConfig(**self._job_params, connection_properties=connection_properties) 1110 self._query_job = self._db_call( 1111 self.client.query, 1112 query=sql, 1113 job_config=job_config, 1114 timeout=self._extra_config.get("job_creation_timeout_seconds"), 1115 ) 1116 query_job = self._query_job 1117 assert query_job is not None 1118 1119 logger.debug( 1120 "BigQuery job created: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s", 1121 query_job.project, 1122 query_job.location, 1123 query_job.job_id, 1124 ) 1125 1126 results = self._db_call( 1127 query_job.result, 1128 timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore 1129 ) 1130 1131 self._query_data = iter(results) if results.total_rows else iter([]) 1132 query_results = query_job._query_results 1133 self.cursor._set_rowcount(query_results) 1134 self.cursor._set_description(query_results.schema) 1135 1136 if ( 1137 track_rows_processed 1138 and self._query_execution_tracker 1139 and self._query_execution_tracker.is_tracking() 1140 ): 1141 num_rows = None 1142 if 
query_job.statement_type == "CREATE_TABLE_AS_SELECT": 1143 # since table was just created, number rows in table == number rows processed 1144 query_table = self.client.get_table(query_job.destination) 1145 num_rows = query_table.num_rows 1146 elif query_job.statement_type in ["INSERT", "DELETE", "MERGE", "UPDATE"]: 1147 num_rows = query_job.num_dml_affected_rows 1148 1149 self._query_execution_tracker.record_execution( 1150 sql, num_rows, query_job.total_bytes_processed 1151 ) 1152 1153 def _get_data_objects( 1154 self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None 1155 ) -> t.List[DataObject]: 1156 """ 1157 Returns all the data objects that exist in the given schema and optionally catalog. 1158 """ 1159 1160 # The BigQuery Client's list_tables method does not support filtering by table name, so we have to 1161 # resort to using SQL instead. 1162 schema = to_schema(schema_name) 1163 catalog = schema.catalog or self.default_catalog 1164 query = ( 1165 exp.select( 1166 exp.column("table_catalog").as_("catalog"), 1167 exp.column("table_name").as_("name"), 1168 exp.column("table_schema").as_("schema_name"), 1169 exp.case() 1170 .when(exp.column("table_type").eq("BASE TABLE"), exp.Literal.string("TABLE")) 1171 .when(exp.column("table_type").eq("CLONE"), exp.Literal.string("TABLE")) 1172 .when(exp.column("table_type").eq("EXTERNAL"), exp.Literal.string("TABLE")) 1173 .when(exp.column("table_type").eq("SNAPSHOT"), exp.Literal.string("TABLE")) 1174 .when(exp.column("table_type").eq("VIEW"), exp.Literal.string("VIEW")) 1175 .when( 1176 exp.column("table_type").eq("MATERIALIZED VIEW"), 1177 exp.Literal.string("MATERIALIZED_VIEW"), 1178 ) 1179 .else_(exp.column("table_type")) 1180 .as_("type"), 1181 exp.column("clustering_key", "ci").as_("clustering_key"), 1182 ) 1183 .with_( 1184 "clustering_info", 1185 as_=exp.select( 1186 exp.column("table_catalog"), 1187 exp.column("table_schema"), 1188 exp.column("table_name"), 1189 parse_one( 1190 
"string_agg(column_name order by clustering_ordinal_position)", 1191 dialect=self.dialect, 1192 ).as_("clustering_key"), 1193 ) 1194 .from_( 1195 exp.to_table( 1196 f"`{catalog}`.`{schema.db}`.INFORMATION_SCHEMA.COLUMNS", 1197 dialect=self.dialect, 1198 ) 1199 ) 1200 .where(exp.column("clustering_ordinal_position").is_(exp.not_(exp.null()))) 1201 .group_by("1", "2", "3"), 1202 ) 1203 .from_( 1204 exp.to_table( 1205 f"`{catalog}`.`{schema.db}`.INFORMATION_SCHEMA.TABLES", dialect=self.dialect 1206 ) 1207 ) 1208 .join( 1209 "clustering_info", 1210 using=["table_catalog", "table_schema", "table_name"], 1211 join_type="left", 1212 join_alias="ci", 1213 ) 1214 ) 1215 if object_names: 1216 query = query.where(exp.column("table_name").isin(*object_names)) 1217 1218 try: 1219 df = self.fetchdf(query, quote_identifiers=True) 1220 except Exception as e: 1221 if "Not found" in str(e): 1222 return [] 1223 raise 1224 1225 if df.empty: 1226 return [] 1227 return [ 1228 DataObject( 1229 catalog=row.catalog, # type: ignore 1230 schema=row.schema_name, # type: ignore 1231 name=row.name, # type: ignore 1232 type=DataObjectType.from_str(row.type), # type: ignore 1233 clustering_key=f"({row.clustering_key})" if row.clustering_key else None, # type: ignore 1234 ) 1235 for row in df.itertuples() 1236 ] 1237 1238 def _update_clustering_key(self, operation: TableAlterClusterByOperation) -> None: 1239 cluster_key_expressions = getattr(operation, "cluster_key_expressions", []) 1240 bq_table = self._get_table(operation.target_table) 1241 1242 rendered_columns = [c.sql(dialect=self.dialect) for c in cluster_key_expressions] 1243 bq_table.clustering_fields = ( 1244 rendered_columns or None 1245 ) # causes a drop of the key if cluster_by is empty or None 1246 1247 self._db_call(self.client.update_table, table=bq_table, fields=["clustering_fields"]) 1248 1249 if cluster_key_expressions: 1250 # BigQuery only applies new clustering going forward, so this rewrites the columns to apply the new 
clustering to historical data 1251 # ref: https://cloud.google.com/bigquery/docs/creating-clustered-tables#modifying-cluster-spec 1252 self.execute( 1253 exp.update( 1254 operation.target_table, 1255 {c: c for c in cluster_key_expressions}, 1256 where=exp.true(), 1257 ) 1258 ) 1259 1260 def _normalize_decimal_value(self, col: exp.Expression, precision: int) -> exp.Expression: 1261 return exp.func("FORMAT", exp.Literal.string(f"%.{precision}f"), col) 1262 1263 def _normalize_nested_value(self, col: exp.Expression) -> exp.Expression: 1264 return exp.func("TO_JSON_STRING", col, dialect=self.dialect) 1265 1266 @t.overload 1267 def _columns_to_types( 1268 self, 1269 query_or_df: DF, 1270 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 1271 source_columns: t.Optional[t.List[str]] = None, 1272 ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ... 1273 1274 @t.overload 1275 def _columns_to_types( 1276 self, 1277 query_or_df: Query, 1278 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 1279 source_columns: t.Optional[t.List[str]] = None, 1280 ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ... 
1281 1282 def _columns_to_types( 1283 self, 1284 query_or_df: QueryOrDF, 1285 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 1286 source_columns: t.Optional[t.List[str]] = None, 1287 ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: 1288 if ( 1289 not target_columns_to_types 1290 and bigframes 1291 and isinstance(query_or_df, bigframes.dataframe.DataFrame) 1292 ): 1293 # using dry_run=True attempts to prevent the DataFrame from being materialized just to read the column types from it 1294 dtypes = query_or_df.to_pandas(dry_run=True).columnDtypes 1295 target_columns_to_types = columns_to_types_from_dtypes(dtypes.items()) 1296 return target_columns_to_types, list(source_columns or target_columns_to_types) 1297 1298 return super()._columns_to_types( 1299 query_or_df, target_columns_to_types, source_columns=source_columns 1300 ) 1301 1302 def _native_df_to_pandas_df( 1303 self, 1304 query_or_df: QueryOrDF, 1305 ) -> t.Union[Query, pd.DataFrame]: 1306 if bigframes and isinstance(query_or_df, bigframes.dataframe.DataFrame): 1307 return query_or_df.to_pandas() 1308 1309 return super()._native_df_to_pandas_df(query_or_df) 1310 1311 @property 1312 def _query_data(self) -> t.Any: 1313 return self._connection_pool.get_attribute("query_data") 1314 1315 @_query_data.setter 1316 def _query_data(self, value: t.Any) -> None: 1317 self._connection_pool.set_attribute("query_data", value) 1318 1319 @property 1320 def _query_job(self) -> t.Optional[QueryJob]: 1321 return self._connection_pool.get_attribute("query_job") 1322 1323 @_query_job.setter 1324 def _query_job(self, value: t.Any) -> None: 1325 self._connection_pool.set_attribute("query_job", value) 1326 1327 @property 1328 def _session_id(self) -> t.Any: 1329 return self._connection_pool.get_attribute("session_id") 1330 1331 @_session_id.setter 1332 def _session_id(self, value: t.Any) -> None: 1333 self._connection_pool.set_attribute("session_id", value) 1334 1335 def 
_get_current_schema(self) -> str: 1336 raise NotImplementedError("BigQuery does not support current schema") 1337 1338 def _get_bq_dataset_location(self, project: str, dataset: str) -> str: 1339 return self._db_call(self.client.get_dataset, dataset_ref=f"{project}.{dataset}").location 1340 1341 def _get_grant_expression(self, table: exp.Table) -> exp.Expression: 1342 if not table.db: 1343 raise ValueError( 1344 f"Table {table.sql(dialect=self.dialect)} does not have a schema (dataset)" 1345 ) 1346 project = table.catalog or self.get_current_catalog() 1347 if not project: 1348 raise ValueError( 1349 f"Table {table.sql(dialect=self.dialect)} does not have a catalog (project)" 1350 ) 1351 1352 dataset = table.db 1353 table_name = table.name 1354 location = self._get_bq_dataset_location(project, dataset) 1355 1356 # https://cloud.google.com/bigquery/docs/information-schema-object-privileges 1357 # OBJECT_PRIVILEGES is a project-level INFORMATION_SCHEMA view with regional qualifier 1358 object_privileges_table = exp.to_table( 1359 f"`{project}`.`region-{location}`.INFORMATION_SCHEMA.{self.GRANT_INFORMATION_SCHEMA_TABLE_NAME}", 1360 dialect=self.dialect, 1361 ) 1362 return ( 1363 exp.select("privilege_type", "grantee") 1364 .from_(object_privileges_table) 1365 .where( 1366 exp.and_( 1367 exp.column("object_schema").eq(exp.Literal.string(dataset)), 1368 exp.column("object_name").eq(exp.Literal.string(table_name)), 1369 # Filter out current_user 1370 # BigQuery grantees format: "user:email" or "group:name" 1371 exp.func("split", exp.column("grantee"), exp.Literal.string(":"))[ 1372 exp.func("OFFSET", exp.Literal.number("1")) 1373 ].neq(self.CURRENT_USER_OR_ROLE_EXPRESSION), 1374 ) 1375 ) 1376 ) 1377 1378 @staticmethod 1379 def _grant_object_kind(table_type: DataObjectType) -> str: 1380 if table_type == DataObjectType.VIEW: 1381 return "VIEW" 1382 if table_type == DataObjectType.MATERIALIZED_VIEW: 1383 # We actually need to use "MATERIALIZED VIEW" here even though it's not 
listed 1384 # as a supported resource_type in the BigQuery DCL doc: 1385 # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-control-language 1386 return "MATERIALIZED VIEW" 1387 return "TABLE" 1388 1389 def _dcl_grants_config_expr( 1390 self, 1391 dcl_cmd: t.Type[DCL], 1392 table: exp.Table, 1393 grants_config: GrantsConfig, 1394 table_type: DataObjectType = DataObjectType.TABLE, 1395 ) -> t.List[exp.Expression]: 1396 expressions: t.List[exp.Expression] = [] 1397 if not grants_config: 1398 return expressions 1399 1400 # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-control-language 1401 1402 def normalize_principal(p: str) -> str: 1403 if ":" not in p: 1404 raise ValueError(f"Principal '{p}' missing a prefix label") 1405 1406 # allUsers and allAuthenticatedUsers special groups that are cas-sensitive and must start with "specialGroup:" 1407 if p.endswith("allUsers") or p.endswith("allAuthenticatedUsers"): 1408 if not p.startswith("specialGroup:"): 1409 raise ValueError( 1410 f"Special group principal '{p}' must start with 'specialGroup:' prefix label" 1411 ) 1412 return p 1413 1414 label, principal = p.split(":", 1) 1415 # always lowercase principals 1416 return f"{label}:{principal.lower()}" 1417 1418 object_kind = self._grant_object_kind(table_type) 1419 for privilege, principals in grants_config.items(): 1420 if not principals: 1421 continue 1422 1423 noramlized_principals = [exp.Literal.string(normalize_principal(p)) for p in principals] 1424 args: t.Dict[str, t.Any] = { 1425 "privileges": [exp.GrantPrivilege(this=exp.to_identifier(privilege, quoted=True))], 1426 "securable": table.copy(), 1427 "principals": noramlized_principals, 1428 } 1429 1430 if object_kind: 1431 args["kind"] = exp.Var(this=object_kind) 1432 1433 expressions.append(dcl_cmd(**args)) # type: ignore[arg-type] 1434 1435 return expressions 1436 1437 1438class _ErrorCounter: 1439 """ 1440 A class that counts errors and determines whether or not to retry 
based on the number of errors and the error 1441 type. 1442 1443 Reference implementation: https://github.com/dbt-labs/dbt-bigquery/blob/8339a034929b12e027f0a143abf46582f3f6ffbc/dbt/adapters/bigquery/connections.py#L672 1444 1445 TODO: Implement a retry configuration that works across all engines 1446 """ 1447 1448 def __init__(self, num_retries: int) -> None: 1449 self.num_retries = num_retries 1450 self.error_count = 0 1451 1452 @property 1453 def retryable_errors(self) -> t.Tuple[t.Type[Exception], ...]: 1454 try: 1455 from google.cloud.exceptions import ServerError 1456 except ModuleNotFoundError: 1457 from google.api_core.exceptions import ServerError 1458 from requests.exceptions import ConnectionError 1459 1460 return (ServerError, ConnectionError) 1461 1462 def _is_retryable(self, error: BaseException) -> bool: 1463 from google.api_core.exceptions import Forbidden 1464 1465 if isinstance(error, self.retryable_errors): 1466 return True 1467 if isinstance(error, Forbidden) and any( 1468 e["reason"] == "rateLimitExceeded" for e in error.errors 1469 ): 1470 return True 1471 return False 1472 1473 def should_retry(self, error: BaseException) -> bool: 1474 if self.num_retries == 0: 1475 return False 1476 self.error_count += 1 1477 if self._is_retryable(error) and self.error_count <= self.num_retries: 1478 logger.info(f"Retry Num {self.error_count} of {self.num_retries}. Error: {repr(error)}") 1479 return True 1480 return False 1481 1482 1483def select_partitions_expr( 1484 schema: str, 1485 table_name: str, 1486 data_type: t.Union[str, exp.DataType], 1487 granularity: t.Optional[str] = None, 1488 agg_func: str = "MAX", 1489 catalog: t.Optional[str] = None, 1490) -> str: 1491 """Generates a SQL expression that aggregates partition values for a table. 1492 1493 Args: 1494 schema: The schema (BigQuery dataset) of the table. 1495 table_name: The name of the table. 1496 data_type: The data type of the partition column. 
1497 granularity: The granularity of the partition. Supported values are: 'day', 'month', 'year' and 'hour'. 1498 agg_func: The aggregation function to use. 1499 catalog: The catalog (BigQuery project ID) of the table. 1500 1501 Returns: 1502 A SELECT statement that aggregates partition values for a table. 1503 """ 1504 partitions_table_name = f"`{schema}`.INFORMATION_SCHEMA.PARTITIONS" 1505 if catalog: 1506 partitions_table_name = f"`{catalog}`.{partitions_table_name}" 1507 1508 if isinstance(data_type, exp.DataType): 1509 data_type = data_type.sql(dialect="bigquery") 1510 data_type = data_type.upper() 1511 1512 parse_fun = f"PARSE_{data_type}" if data_type in ("DATE", "DATETIME", "TIMESTAMP") else None 1513 if parse_fun: 1514 granularity = granularity or "day" 1515 parse_format = GRANULARITY_TO_PARTITION_FORMAT[granularity.lower()] 1516 partition_expr = exp.func( 1517 parse_fun, 1518 exp.Literal.string(parse_format), 1519 exp.column("partition_id"), 1520 dialect="bigquery", 1521 ) 1522 else: 1523 partition_expr = exp.cast(exp.column("partition_id"), "INT64", dialect="bigquery") 1524 1525 return ( 1526 exp.select(exp.func(agg_func, partition_expr)) 1527 .from_(partitions_table_name, dialect="bigquery") 1528 .where( 1529 f"table_name = '{table_name}' AND partition_id IS NOT NULL AND partition_id != '__NULL__'", 1530 copy=False, 1531 ) 1532 .sql(dialect="bigquery") 1533 ) 1534 1535 1536GRANULARITY_TO_PARTITION_FORMAT = { 1537 "day": "%Y%m%d", 1538 "month": "%Y%m", 1539 "year": "%Y", 1540 "hour": "%Y%m%d%H", 1541}
59@set_catalog() 60class BigQueryEngineAdapter(ClusteredByMixin, RowDiffMixin, GrantsFromInfoSchemaMixin): 61 """ 62 BigQuery Engine Adapter using the `google-cloud-bigquery` library's DB API. 63 """ 64 65 DIALECT = "bigquery" 66 DEFAULT_BATCH_SIZE = 1000 67 SUPPORTS_TRANSACTIONS = False 68 SUPPORTS_MATERIALIZED_VIEWS = True 69 SUPPORTS_CLONING = True 70 SUPPORTS_GRANTS = True 71 CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("session_user") 72 SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True 73 USE_CATALOG_IN_GRANTS = True 74 GRANT_INFORMATION_SCHEMA_TABLE_NAME = "OBJECT_PRIVILEGES" 75 MAX_TABLE_COMMENT_LENGTH = 1024 76 MAX_COLUMN_COMMENT_LENGTH = 1024 77 SUPPORTS_QUERY_EXECUTION_TRACKING = True 78 SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"] 79 INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE 80 81 SCHEMA_DIFFER_KWARGS = { 82 "compatible_types": { 83 exp.DataType.build("INT64", dialect=DIALECT): { 84 exp.DataType.build("NUMERIC", dialect=DIALECT), 85 exp.DataType.build("FLOAT64", dialect=DIALECT), 86 exp.DataType.build("BIGNUMERIC", dialect=DIALECT), 87 }, 88 exp.DataType.build("NUMERIC", dialect=DIALECT): { 89 exp.DataType.build("FLOAT64", dialect=DIALECT), 90 exp.DataType.build("BIGNUMERIC", dialect=DIALECT), 91 }, 92 exp.DataType.build("DATE", dialect=DIALECT): { 93 exp.DataType.build("DATETIME", dialect=DIALECT), 94 }, 95 }, 96 "coerceable_types": { 97 exp.DataType.build("FLOAT64", dialect=DIALECT): { 98 exp.DataType.build("BIGNUMERIC", dialect=DIALECT), 99 }, 100 }, 101 "support_coercing_compatible_types": True, 102 "parameterized_type_defaults": { 103 exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(38, 9), (0,)], 104 exp.DataType.build("BIGDECIMAL", dialect=DIALECT).this: [(76.76, 38), (0,)], 105 }, 106 "types_with_unlimited_length": { 107 # parameterized `STRING(n)` can ALTER to unparameterized `STRING` 108 exp.DataType.build("STRING", dialect=DIALECT).this: { 109 exp.DataType.build("STRING", dialect=DIALECT).this, 110 }, 111 # 
parameterized `BYTES(n)` can ALTER to unparameterized `BYTES` 112 exp.DataType.build("BYTES", dialect=DIALECT).this: { 113 exp.DataType.build("BYTES", dialect=DIALECT).this, 114 }, 115 }, 116 "nested_support": NestedSupport.ALL_BUT_DROP, 117 } 118 119 @property 120 def client(self) -> BigQueryClient: 121 return self.connection._client 122 123 @property 124 def bigframe(self) -> t.Optional[BigframeSession]: 125 if bigframes: 126 options = bigframes.BigQueryOptions( 127 credentials=self.client._credentials, 128 project=self.client.project, 129 location=self.client.location, 130 ) 131 return bigframes.connect(context=options) 132 return None 133 134 @property 135 def _job_params(self) -> t.Dict[str, t.Any]: 136 from sqlmesh.core.config.connection import BigQueryPriority 137 138 params = { 139 "use_legacy_sql": False, 140 "priority": self._extra_config.get( 141 "priority", BigQueryPriority.INTERACTIVE.bigquery_constant 142 ), 143 } 144 if self._extra_config.get("maximum_bytes_billed"): 145 params["maximum_bytes_billed"] = self._extra_config.get("maximum_bytes_billed") 146 if self.correlation_id: 147 # BigQuery label keys must be lowercase 148 key = self.correlation_id.job_type.value.lower() 149 params["labels"] = {key: self.correlation_id.job_id} 150 return params 151 152 @property 153 def catalog_support(self) -> CatalogSupport: 154 return CatalogSupport.FULL_SUPPORT 155 156 def _df_to_source_queries( 157 self, 158 df: DF, 159 target_columns_to_types: t.Dict[str, exp.DataType], 160 batch_size: int, 161 target_table: TableName, 162 source_columns: t.Optional[t.List[str]] = None, 163 ) -> t.List[SourceQuery]: 164 import pandas as pd 165 166 source_columns_to_types = get_source_columns_to_types( 167 target_columns_to_types, source_columns 168 ) 169 170 temp_bq_table = self.__get_temp_bq_table( 171 self._get_temp_table(target_table or "pandas"), source_columns_to_types 172 ) 173 temp_table = exp.table_( 174 temp_bq_table.table_id, 175 db=temp_bq_table.dataset_id, 176 
catalog=temp_bq_table.project, 177 ) 178 179 def query_factory() -> Query: 180 ordered_df = df[list(source_columns_to_types)] 181 if bigframes_pd and isinstance(ordered_df, bigframes_pd.DataFrame): 182 ordered_df.to_gbq( 183 f"{temp_bq_table.project}.{temp_bq_table.dataset_id}.{temp_bq_table.table_id}", 184 if_exists="replace", 185 ) 186 elif not self.table_exists(temp_table): 187 # Make mypy happy 188 assert isinstance(ordered_df, pd.DataFrame) 189 self._db_call(self.client.create_table, table=temp_bq_table, exists_ok=False) 190 result = self.__load_pandas_to_table( 191 temp_bq_table, ordered_df, source_columns_to_types, replace=False 192 ) 193 if result.errors: 194 raise SQLMeshError(result.errors) 195 return exp.select( 196 *self._casted_columns(target_columns_to_types, source_columns=source_columns) 197 ).from_(temp_table) 198 199 return [ 200 SourceQuery( 201 query_factory=query_factory, 202 cleanup_func=lambda: self.drop_table(temp_table), 203 ) 204 ] 205 206 def close(self) -> t.Any: 207 # Cancel all pending query jobs across all threads 208 all_query_jobs = self._connection_pool.get_all_attributes("query_job") 209 for query_job in all_query_jobs: 210 if query_job: 211 try: 212 if not self._db_call(query_job.done): 213 self._db_call(query_job.cancel) 214 logger.debug( 215 "Cancelled BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s", 216 query_job.project, 217 query_job.location, 218 query_job.job_id, 219 ) 220 except Exception as ex: 221 logger.debug( 222 "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. 
%s", 223 query_job.project, 224 query_job.location, 225 query_job.job_id, 226 str(ex), 227 ) 228 229 return super().close() 230 231 def _begin_session(self, properties: SessionProperties) -> None: 232 from google.cloud.bigquery import QueryJobConfig 233 234 query_label_property = properties.get("query_label") 235 parsed_query_label: list[tuple[str, str]] = [] 236 if isinstance(query_label_property, (exp.Array, exp.Paren, exp.Tuple)): 237 label_tuples = ( 238 [query_label_property.unnest()] 239 if isinstance(query_label_property, exp.Paren) 240 else query_label_property.expressions 241 ) 242 243 # query_label is a Paren, Array or Tuple of 2-tuples and validated at load time 244 parsed_query_label.extend( 245 (label_tuple.expressions[0].name, label_tuple.expressions[1].name) 246 for label_tuple in label_tuples 247 ) 248 elif query_label_property is not None: 249 raise SQLMeshError( 250 "Invalid value for `session_properties.query_label`. Must be an array or tuple." 251 ) 252 253 if self.correlation_id: 254 parsed_query_label.append( 255 (self.correlation_id.job_type.value.lower(), self.correlation_id.job_id) 256 ) 257 258 if parsed_query_label: 259 query_label_str = ",".join([":".join(label) for label in parsed_query_label]) 260 query = f'SET @@query_label = "{query_label_str}";SELECT 1;' 261 else: 262 query = "SELECT 1;" 263 264 job = self.client.query( 265 query, 266 job_config=QueryJobConfig(create_session=True), 267 ) 268 session_info = job.session_info 269 session_id = session_info.session_id if session_info else None 270 self._session_id = session_id 271 job.result() 272 273 def _end_session(self) -> None: 274 self._session_id = None 275 276 def _is_session_active(self) -> bool: 277 return self._session_id is not None 278 279 def get_current_catalog(self) -> t.Optional[str]: 280 """Returns the catalog name of the current connection.""" 281 return self.client.project 282 283 def set_current_catalog(self, catalog: str) -> None: 284 """Sets the catalog name of the 
current connection.""" 285 self.client.project = catalog 286 287 def create_schema( 288 self, 289 schema_name: SchemaName, 290 ignore_if_exists: bool = True, 291 warn_on_error: bool = True, 292 properties: t.List[exp.Expression] = [], 293 ) -> None: 294 """Create a schema from a name or qualified table name.""" 295 from google.api_core.exceptions import Conflict 296 297 try: 298 super().create_schema( 299 schema_name, 300 ignore_if_exists=ignore_if_exists, 301 warn_on_error=False, 302 ) 303 except Exception as e: 304 is_already_exists_error = isinstance(e, Conflict) and "Already Exists:" in str(e) 305 if is_already_exists_error and ignore_if_exists: 306 return 307 if not warn_on_error: 308 raise 309 logger.warning("Failed to create schema '%s': %s", schema_name, e) 310 311 def get_bq_schema(self, table_name: TableName) -> t.List[bigquery.SchemaField]: 312 table = exp.to_table(table_name) 313 if len(table.parts) == 3 and "." in table.name: 314 self.execute(exp.select("*").from_(table).limit(0)) 315 query_job = self._query_job 316 assert query_job is not None 317 return query_job._query_results.schema 318 return self._get_table(table).schema 319 320 def columns( 321 self, table_name: TableName, include_pseudo_columns: bool = False 322 ) -> t.Dict[str, exp.DataType]: 323 """Fetches column names and types for the target table.""" 324 325 def dtype_to_sql( 326 dtype: t.Optional[StandardSqlDataType], field: bigquery.SchemaField 327 ) -> str: 328 assert dtype 329 assert field 330 331 kind = dtype.type_kind 332 assert kind 333 334 # Not using the enum value to preserve compatibility with older versions 335 # of the BigQuery library. 
336 if kind.name == "ARRAY": 337 return f"ARRAY<{dtype_to_sql(dtype.array_element_type, field)}>" 338 if kind.name == "STRUCT": 339 struct_type = dtype.struct_type 340 assert struct_type 341 fields = ", ".join( 342 f"{struct_field.name} {dtype_to_sql(struct_field.type, nested_field)}" 343 for struct_field, nested_field in zip(struct_type.fields, field.fields) 344 ) 345 return f"STRUCT<{fields}>" 346 if kind.name == "TYPE_KIND_UNSPECIFIED": 347 field_type = field.field_type 348 349 if field_type == "RANGE": 350 # If the field is a RANGE then `range_element_type` should be set to 351 # one of `"DATE"`, `"DATETIME"` or `"TIMESTAMP"`. 352 return f"RANGE<{field.range_element_type.element_type}>" 353 354 return field_type 355 356 return kind.name 357 358 def create_mapping_schema( 359 schema: t.Sequence[bigquery.SchemaField], 360 ) -> t.Dict[str, exp.DataType]: 361 return { 362 field.name: exp.DataType.build( 363 dtype_to_sql(field.to_standard_sql().type, field), dialect=self.dialect 364 ) 365 for field in schema 366 } 367 368 table = exp.to_table(table_name) 369 if len(table.parts) == 3 and "." 
in table.name: 370 # The client's `get_table` method can't handle paths with >3 identifiers 371 self.execute(exp.select("*").from_(table).limit(0)) 372 query_job = self._query_job 373 assert query_job is not None 374 375 query_results = query_job._query_results 376 columns = create_mapping_schema(query_results.schema) 377 else: 378 bq_table = self._get_table(table) 379 columns = create_mapping_schema(bq_table.schema) 380 381 if include_pseudo_columns: 382 if bq_table.time_partitioning and not bq_table.time_partitioning.field: 383 columns["_PARTITIONTIME"] = exp.DataType.build("TIMESTAMP", dialect="bigquery") 384 if bq_table.time_partitioning.type_ == "DAY": 385 columns["_PARTITIONDATE"] = exp.DataType.build("DATE") 386 if bq_table.table_id.endswith("*"): 387 columns["_TABLE_SUFFIX"] = exp.DataType.build("STRING", dialect="bigquery") 388 if ( 389 bq_table.external_data_configuration is not None 390 and bq_table.external_data_configuration.source_format 391 in ( 392 "CSV", 393 "NEWLINE_DELIMITED_JSON", 394 "AVRO", 395 "PARQUET", 396 "ORC", 397 "DATASTORE_BACKUP", 398 ) 399 ): 400 columns["_FILE_NAME"] = exp.DataType.build("STRING", dialect="bigquery") 401 402 return columns 403 404 def alter_table( 405 self, 406 alter_expressions: t.Union[t.List[exp.Alter], t.List[TableAlterOperation]], 407 ) -> None: 408 """ 409 Performs the alter statements to change the current table into the structure of the target table, 410 and uses the API to add columns to structs, where SQL is not supported. 
411 """ 412 if not alter_expressions: 413 return 414 415 cluster_by_operations, alter_statements = [], [] 416 for e in alter_expressions: 417 if isinstance(e, TableAlterClusterByOperation): 418 cluster_by_operations.append(e) 419 elif isinstance(e, TableAlterOperation): 420 alter_statements.append(e.expression) 421 else: 422 alter_statements.append(e) 423 424 for op in cluster_by_operations: 425 self._update_clustering_key(op) 426 427 nested_fields, non_nested_expressions = self._split_alter_expressions(alter_statements) 428 429 if nested_fields: 430 self._update_table_schema_nested_fields(nested_fields, alter_statements[0].this) 431 432 if non_nested_expressions: 433 super().alter_table(non_nested_expressions) 434 435 def fetchone( 436 self, 437 query: t.Union[exp.Expression, str], 438 ignore_unsupported_errors: bool = False, 439 quote_identifiers: bool = False, 440 ) -> t.Optional[t.Tuple]: 441 """ 442 BigQuery's `fetchone` method doesn't call execute and therefore would not benefit from the execute 443 configuration we have in place. Therefore this implementation calls execute instead. 444 """ 445 self.execute( 446 query, 447 ignore_unsupported_errors=ignore_unsupported_errors, 448 quote_identifiers=quote_identifiers, 449 ) 450 try: 451 return next(self._query_data) 452 except StopIteration: 453 return None 454 455 def fetchall( 456 self, 457 query: t.Union[exp.Expression, str], 458 ignore_unsupported_errors: bool = False, 459 quote_identifiers: bool = False, 460 ) -> t.List[t.Tuple]: 461 """ 462 BigQuery's `fetchone` method doesn't call execute and therefore would not benefit from the execute 463 configuration we have in place. Therefore this implementation calls execute instead. 
464 """ 465 self.execute( 466 query, 467 ignore_unsupported_errors=ignore_unsupported_errors, 468 quote_identifiers=quote_identifiers, 469 ) 470 return list(self._query_data) 471 472 def _split_alter_expressions( 473 self, 474 alter_expressions: t.List[exp.Alter], 475 ) -> t.Tuple[NestedFieldsDict, t.List[exp.Alter]]: 476 """ 477 Returns a dictionary of the nested fields to add and a list of the non-nested alter expressions. 478 """ 479 nested_fields_to_add: NestedFieldsDict = defaultdict(list) 480 non_nested_expressions = [] 481 482 for alter_expression in alter_expressions: 483 action = alter_expression.args["actions"][0] 484 if ( 485 isinstance(action, exp.ColumnDef) 486 and isinstance(action.this, exp.Dot) 487 and isinstance(action.kind, exp.DataType) 488 ): 489 root_field, *leaf_fields = action.this.this.sql(dialect=self.dialect).split(".") 490 new_field = action.this.expression.sql(dialect=self.dialect) 491 data_type = action.kind.sql(dialect=self.dialect) 492 nested_fields_to_add[root_field].append((new_field, data_type, leaf_fields)) 493 else: 494 non_nested_expressions.append(alter_expression) 495 496 return nested_fields_to_add, non_nested_expressions 497 498 def _build_nested_fields( 499 self, 500 current_fields: t.List[bigquery.SchemaField], 501 fields_to_add: t.List[NestedField], 502 ) -> t.List[bigquery.SchemaField]: 503 """ 504 Recursively builds and updates the schema fields with the new nested fields. 
505 """ 506 from google.cloud import bigquery 507 508 new_fields = [] 509 root: t.List[t.Tuple[str, str]] = [] 510 leaves: NestedFieldsDict = defaultdict(list) 511 for new_field, data_type, leaf_fields in fields_to_add: 512 if leaf_fields: 513 leaves[leaf_fields[0]].append((new_field, data_type, leaf_fields[1:])) 514 else: 515 root.append((new_field, data_type)) 516 517 for field in current_fields: 518 # If the new fields are nested, we need to recursively build them 519 if field.name in leaves: 520 subfields = list(field.fields) 521 subfields = self._build_nested_fields(subfields, leaves[field.name]) 522 new_fields.append( 523 bigquery.SchemaField( 524 field.name, "RECORD", mode=field.mode, fields=tuple(subfields) 525 ) 526 ) 527 else: 528 new_fields.append(field) 529 530 # Build and append the new root-level fields 531 new_fields.extend( 532 self.__get_bq_schemafield( 533 new_field[0], exp.DataType.build(new_field[1], dialect=self.dialect) 534 ) 535 for new_field in root 536 ) 537 return new_fields 538 539 def _update_table_schema_nested_fields( 540 self, nested_fields_to_add: NestedFieldsDict, table_name: str 541 ) -> None: 542 """ 543 Updates a BigQuery table schema by adding the new nested fields provided. 
544 """ 545 from google.cloud import bigquery 546 547 table = self._get_table(table_name) 548 original_schema = table.schema 549 new_schema = [] 550 for field in original_schema: 551 if field.name in nested_fields_to_add: 552 fields = self._build_nested_fields( 553 list(field.fields), nested_fields_to_add[field.name] 554 ) 555 new_schema.append( 556 bigquery.SchemaField( 557 field.name, 558 "RECORD", 559 mode=field.mode, 560 fields=tuple(fields), 561 ) 562 ) 563 else: 564 new_schema.append(field) 565 566 if new_schema != original_schema: 567 table.schema = new_schema 568 self.client.update_table(table, ["schema"]) 569 570 def __load_pandas_to_table( 571 self, 572 table: bigquery.Table, 573 df: pd.DataFrame, 574 columns_to_types: t.Dict[str, exp.DataType], 575 replace: bool = False, 576 ) -> BigQueryQueryResult: 577 """ 578 Loads a pandas dataframe into a table in BigQuery. Will do an overwrite if replace is True. Note that 579 the replace will replace the entire table, not just the rows that are in the dataframe. 580 """ 581 from google.cloud import bigquery 582 583 job_config = bigquery.job.LoadJobConfig(schema=self.__get_bq_schema(columns_to_types)) 584 if replace: 585 job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE 586 logger.info(f"Loading dataframe to BigQuery. Table Path: {table.path}") 587 # This client call does not support retry so we don't use the `_db_call` method. 
588 result = self.__retry( 589 self.__db_load_table_from_dataframe, 590 )(df=df, table=table, job_config=job_config) 591 if result.errors: 592 raise SQLMeshError(result.errors) 593 return result 594 595 def __db_load_table_from_dataframe( 596 self, df: pd.DataFrame, table: bigquery.Table, job_config: bigquery.LoadJobConfig 597 ) -> BigQueryQueryResult: 598 job = self.client.load_table_from_dataframe( 599 dataframe=df, destination=table, job_config=job_config 600 ) 601 return self._db_call(job.result) 602 603 def __get_bq_schemafield(self, name: str, tpe: exp.DataType) -> bigquery.SchemaField: 604 from google.cloud import bigquery 605 606 mode = "NULLABLE" 607 if tpe.is_type(exp.DataType.Type.ARRAY): 608 mode = "REPEATED" 609 tpe = tpe.expressions[0] 610 611 field_type = tpe.sql(dialect=self.dialect) 612 fields = [] 613 if tpe.is_type(*exp.DataType.NESTED_TYPES): 614 field_type = "RECORD" 615 for inner_field in tpe.expressions: 616 if isinstance(inner_field, exp.ColumnDef): 617 inner_name = inner_field.this.sql(dialect=self.dialect) 618 inner_type = inner_field.kind 619 if inner_type is None: 620 raise ValueError( 621 f"cannot convert unknown type to BQ schema field {inner_field}" 622 ) 623 fields.append(self.__get_bq_schemafield(name=inner_name, tpe=inner_type)) 624 else: 625 raise ValueError(f"unexpected nested expression {inner_field}") 626 627 return bigquery.SchemaField( 628 name=name, 629 field_type=field_type, 630 mode=mode, 631 fields=fields, 632 ) 633 634 def __get_bq_schema( 635 self, columns_to_types: t.Dict[str, exp.DataType] 636 ) -> t.List[bigquery.SchemaField]: 637 """ 638 Returns a bigquery schema object from a dictionary of column names to types. 
639 """ 640 641 precisionless_col_to_types = { 642 col_name: remove_precision_parameterized_types(col_type) 643 for col_name, col_type in columns_to_types.items() 644 } 645 return [ 646 self.__get_bq_schemafield(name=col_name, tpe=t.cast(exp.DataType, col_type)) 647 for col_name, col_type in precisionless_col_to_types.items() 648 ] 649 650 def __get_temp_bq_table( 651 self, table: exp.Table, columns_to_type: t.Dict[str, exp.DataType] 652 ) -> bigquery.Table: 653 """ 654 Returns a bigquery table object that is temporary and will expire in 3 hours. 655 """ 656 bq_table = self.__get_bq_table(table, columns_to_type) 657 bq_table.expires = to_datetime("in 3 hours") 658 return bq_table 659 660 def __get_bq_table( 661 self, table: TableName, columns_to_type: t.Dict[str, exp.DataType] 662 ) -> bigquery.Table: 663 """ 664 Returns a bigquery table object with a schema defines that matches the columns_to_type dictionary. 665 """ 666 from google.cloud import bigquery 667 668 table_ = exp.to_table(table).copy() 669 670 if not table_.catalog: 671 table_.set("catalog", exp.to_identifier(self.default_catalog)) 672 673 return bigquery.Table( 674 table_ref=self._table_name(table_), 675 schema=self.__get_bq_schema(columns_to_type), 676 ) 677 678 @property 679 def __retry(self) -> Retry: 680 from google.api_core import retry 681 682 return retry.Retry( 683 predicate=_ErrorCounter(self._extra_config["job_retries"]).should_retry, 684 deadline=self._extra_config.get("job_retry_deadline_seconds"), 685 initial=1.0, 686 maximum=3.0, 687 ) 688 689 def insert_overwrite_by_partition( 690 self, 691 table_name: TableName, 692 query_or_df: QueryOrDF, 693 partitioned_by: t.List[exp.Expression], 694 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 695 source_columns: t.Optional[t.List[str]] = None, 696 ) -> None: 697 if len(partitioned_by) != 1: 698 raise SQLMeshError( 699 f"Bigquery only supports partitioning by one column, {len(partitioned_by)} were provided." 
700 ) 701 702 partition_exp = partitioned_by[0] 703 partition_column = partition_exp.find(exp.Column) 704 705 granularity = partition_exp.args.get("unit") 706 if granularity: 707 granularity = granularity.name.lower() 708 709 if not partition_column: 710 partition_sql = partition_exp.sql(dialect=self.dialect) 711 raise SQLMeshError( 712 f"The partition expression '{partition_sql}' doesn't contain a column." 713 ) 714 with ( 715 self.session({}), 716 self.temp_table( 717 query_or_df, 718 name=table_name, 719 partitioned_by=partitioned_by, 720 source_columns=source_columns, 721 ) as temp_table_name, 722 ): 723 if target_columns_to_types is None or target_columns_to_types[ 724 partition_column.name 725 ] == exp.DataType.build("unknown"): 726 target_columns_to_types = self.columns(table_name) 727 728 partition_type_sql = target_columns_to_types[partition_column.name].sql( 729 dialect=self.dialect 730 ) 731 732 select_array_agg_partitions = select_partitions_expr( 733 temp_table_name.db, 734 temp_table_name.name, 735 partition_type_sql, 736 granularity=granularity, 737 agg_func="ARRAY_AGG", 738 catalog=temp_table_name.catalog or self.default_catalog, 739 ) 740 741 self.execute( 742 f"DECLARE _sqlmesh_target_partitions_ ARRAY<{partition_type_sql}> DEFAULT ({select_array_agg_partitions});" 743 ) 744 745 where = t.cast(exp.Condition, partition_exp).isin(unnest="_sqlmesh_target_partitions_") 746 747 self._insert_overwrite_by_condition( 748 table_name, 749 [SourceQuery(query_factory=lambda: exp.select("*").from_(temp_table_name))], 750 target_columns_to_types, 751 where=where, 752 ) 753 754 def table_exists(self, table_name: TableName) -> bool: 755 table = exp.to_table(table_name) 756 data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name) 757 if data_object_cache_key in self._data_object_cache: 758 logger.debug("Table existence cache hit: %s", data_object_cache_key) 759 return self._data_object_cache[data_object_cache_key] is not None 760 761 
try: 762 from google.cloud.exceptions import NotFound 763 except ModuleNotFoundError: 764 from google.api_core.exceptions import NotFound 765 766 try: 767 self._get_table(table_name) 768 return True 769 except NotFound: 770 return False 771 772 def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]: 773 from sqlmesh.utils.date import to_timestamp 774 775 datasets_to_tables: t.DefaultDict[str, t.List[str]] = defaultdict(list) 776 for table_name in table_names: 777 table = exp.to_table(table_name) 778 datasets_to_tables[table.db].append(table.name) 779 780 results = [] 781 782 for dataset, tables in datasets_to_tables.items(): 783 query = ( 784 f"SELECT TIMESTAMP_MILLIS(last_modified_time) FROM `{dataset}.__TABLES__` WHERE " 785 ) 786 for i, table_name in enumerate(tables): 787 query += f"TABLE_ID = '{table_name}'" 788 if i < len(tables) - 1: 789 query += " OR " 790 results.extend(self.fetchall(query)) 791 792 return [to_timestamp(row[0]) for row in results] 793 794 def _get_table(self, table_name: TableName) -> BigQueryTable: 795 """ 796 Returns a BigQueryTable object for the given table name. 797 798 Raises: `google.cloud.exceptions.NotFound` if the table does not exist. 
799 """ 800 return self._db_call(self.client.get_table, table=self._table_name(table_name)) 801 802 def _table_name(self, table_name: TableName) -> str: 803 # the api doesn't support backticks, so we can't call exp.table_name or sql 804 return ".".join(part.name for part in exp.to_table(table_name).parts) 805 806 def _fetch_native_df( 807 self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False 808 ) -> DF: 809 self.execute(query, quote_identifiers=quote_identifiers) 810 query_job = self._query_job 811 assert query_job is not None 812 return query_job.to_dataframe() 813 814 def _create_column_comments( 815 self, 816 table_name: TableName, 817 column_comments: t.Dict[str, str], 818 table_kind: str = "TABLE", 819 materialized_view: bool = False, 820 ) -> None: 821 if not (table_kind == "VIEW" and materialized_view): 822 table = self._get_table(table_name) 823 824 # convert Table object to dict 825 table_def = table.to_api_repr() 826 827 # Set column descriptions, supporting nested fields (e.g. record.field.nested_field) 828 for column, comment in column_comments.items(): 829 fields = table_def["schema"]["fields"] 830 field_names = column.split(".") 831 last_index = len(field_names) - 1 832 833 # Traverse the fields with nested fields down to leaf level 834 for idx, name in enumerate(field_names): 835 if field := next((field for field in fields if field["name"] == name), None): 836 if idx == last_index: 837 field["description"] = self._truncate_comment( 838 comment, self.MAX_COLUMN_COMMENT_LENGTH 839 ) 840 else: 841 fields = field.get("fields") or [] 842 843 # An "etag" is BQ versioning metadata that changes when an object is updated/modified. `update_table` 844 # compares the etags of the table object passed to it and the remote table, erroring if the etags 845 # don't match. We set the local etag to None to avoid this check. 
846 table_def["etag"] = None 847 848 # convert dict back to a Table object 849 table = table.from_api_repr(table_def) 850 851 # update table schema 852 logger.info(f"Registering column comments for table {table_name}") 853 self._db_call(self.client.update_table, table=table, fields=["schema"]) 854 855 def _build_description_property_exp( 856 self, 857 description: str, 858 trunc_method: t.Callable, 859 ) -> exp.Property: 860 return exp.Property( 861 this=exp.to_identifier("description", quoted=True), 862 value=exp.Literal.string(trunc_method(description)), 863 ) 864 865 def _build_partitioned_by_exp( 866 self, 867 partitioned_by: t.List[exp.Expression], 868 *, 869 partition_interval_unit: t.Optional[IntervalUnit] = None, 870 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 871 **kwargs: t.Any, 872 ) -> t.Optional[exp.PartitionedByProperty]: 873 if len(partitioned_by) > 1: 874 raise SQLMeshError("BigQuery only supports partitioning by a single column") 875 876 this = partitioned_by[0] 877 if ( 878 isinstance(this, exp.Column) 879 and partition_interval_unit is not None 880 and not partition_interval_unit.is_minute 881 ): 882 column_type: t.Optional[exp.DataType] = (target_columns_to_types or {}).get(this.name) 883 884 if column_type == exp.DataType.build( 885 "date", dialect=self.dialect 886 ) and partition_interval_unit in ( 887 IntervalUnit.MONTH, 888 IntervalUnit.YEAR, 889 ): 890 trunc_func = "DATE_TRUNC" 891 elif column_type == exp.DataType.build("timestamp", dialect=self.dialect): 892 trunc_func = "TIMESTAMP_TRUNC" 893 elif column_type == exp.DataType.build("datetime", dialect=self.dialect): 894 trunc_func = "DATETIME_TRUNC" 895 else: 896 trunc_func = "" 897 898 if trunc_func: 899 this = exp.func( 900 trunc_func, 901 this, 902 exp.var(partition_interval_unit.value.upper()), 903 dialect=self.dialect, 904 ) 905 906 return exp.PartitionedByProperty(this=this) 907 908 def _build_table_properties_exp( 909 self, 910 catalog_name: t.Optional[str] 
= None, 911 table_format: t.Optional[str] = None, 912 storage_format: t.Optional[str] = None, 913 partitioned_by: t.Optional[t.List[exp.Expression]] = None, 914 partition_interval_unit: t.Optional[IntervalUnit] = None, 915 clustered_by: t.Optional[t.List[exp.Expression]] = None, 916 table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, 917 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 918 table_description: t.Optional[str] = None, 919 table_kind: t.Optional[str] = None, 920 **kwargs: t.Any, 921 ) -> t.Optional[exp.Properties]: 922 properties: t.List[exp.Expression] = [] 923 924 if partitioned_by and ( 925 partitioned_by_prop := self._build_partitioned_by_exp( 926 partitioned_by, 927 partition_interval_unit=partition_interval_unit, 928 target_columns_to_types=target_columns_to_types, 929 ) 930 ): 931 properties.append(partitioned_by_prop) 932 933 if clustered_by and (clustered_by_exp := self._build_clustered_by_exp(clustered_by)): 934 properties.append(clustered_by_exp) 935 936 if table_description: 937 properties.append( 938 self._build_description_property_exp( 939 table_description, self._truncate_table_comment 940 ), 941 ) 942 943 properties.extend(self._table_or_view_properties_to_expressions(table_properties)) 944 945 if properties: 946 return exp.Properties(expressions=properties) 947 return None 948 949 def _build_column_def( 950 self, 951 col_name: str, 952 column_descriptions: t.Optional[t.Dict[str, str]] = None, 953 engine_supports_schema_comments: bool = False, 954 col_type: t.Optional[exp.DATA_TYPE] = None, 955 nested_names: t.List[str] = [], 956 ) -> exp.ColumnDef: 957 # Helper function to build column definitions with column descriptions 958 def _build_struct_with_descriptions( 959 col_type: exp.DataType, 960 nested_names: t.List[str], 961 ) -> exp.DataType: 962 column_expressions = [] 963 for column_def in col_type.expressions: 964 # This is expected to be true, but this check is included as a 965 # precautionary 
measure in case of an unexpected edge case 966 if isinstance(column_def, exp.ColumnDef): 967 column = self._build_column_def( 968 col_name=column_def.name, 969 column_descriptions=column_descriptions, 970 engine_supports_schema_comments=engine_supports_schema_comments, 971 col_type=column_def.kind, 972 nested_names=nested_names, 973 ) 974 else: 975 column = column_def 976 column_expressions.append(column) 977 return exp.DataType(this=col_type.this, expressions=column_expressions, nested=True) 978 979 # Recursively build column definitions for BigQuery's RECORDs (struct) and REPEATED RECORDs (array of struct) 980 if isinstance(col_type, exp.DataType) and col_type.expressions: 981 expressions = col_type.expressions 982 if col_type.is_type(exp.DataType.Type.STRUCT): 983 col_type = _build_struct_with_descriptions(col_type, nested_names + [col_name]) 984 elif col_type.is_type(exp.DataType.Type.ARRAY) and expressions[0].is_type( 985 exp.DataType.Type.STRUCT 986 ): 987 col_type = exp.DataType( 988 this=exp.DataType.Type.ARRAY, 989 expressions=[ 990 _build_struct_with_descriptions( 991 col_type.expressions[0], nested_names + [col_name] 992 ) 993 ], 994 nested=True, 995 ) 996 997 return exp.ColumnDef( 998 this=exp.to_identifier(col_name), 999 kind=col_type, 1000 constraints=( 1001 self._build_col_comment_exp( 1002 ".".join(nested_names + [col_name]), column_descriptions 1003 ) 1004 if engine_supports_schema_comments and self.comments_enabled and column_descriptions 1005 else None 1006 ), 1007 ) 1008 1009 def _build_col_comment_exp( 1010 self, col_name: str, column_descriptions: t.Dict[str, str] 1011 ) -> t.List[exp.ColumnConstraint]: 1012 comment = column_descriptions.get(col_name, None) 1013 if comment: 1014 return [ 1015 exp.ColumnConstraint( 1016 kind=exp.Properties( 1017 expressions=[ 1018 self._build_description_property_exp( 1019 comment, self._truncate_column_comment 1020 ), 1021 ] 1022 ) 1023 ) 1024 ] 1025 return [] 1026 1027 def _build_view_properties_exp( 1028 
self, 1029 view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, 1030 table_description: t.Optional[str] = None, 1031 **kwargs: t.Any, 1032 ) -> t.Optional[exp.Properties]: 1033 """Creates a SQLGlot table properties expression for view""" 1034 properties: t.List[exp.Expression] = [] 1035 1036 if table_description: 1037 properties.append( 1038 self._build_description_property_exp( 1039 table_description, self._truncate_table_comment 1040 ), 1041 ) 1042 1043 properties.extend(self._table_or_view_properties_to_expressions(view_properties)) 1044 1045 if properties: 1046 return exp.Properties(expressions=properties) 1047 return None 1048 1049 def _build_create_comment_table_exp( 1050 self, table: exp.Table, table_comment: str, table_kind: str 1051 ) -> exp.Comment | str: 1052 table_sql = table.sql(dialect=self.dialect, identify=True) 1053 1054 truncated_comment = self._truncate_table_comment(table_comment) 1055 comment_sql = exp.Literal.string(truncated_comment).sql(dialect=self.dialect) 1056 1057 return f"ALTER {table_kind} {table_sql} SET OPTIONS(description = {comment_sql})" 1058 1059 def _build_create_comment_column_exp( 1060 self, table: exp.Table, column_name: str, column_comment: str, table_kind: str = "TABLE" 1061 ) -> exp.Comment | str: 1062 table_sql = table.sql(dialect=self.dialect, identify=True) 1063 column_sql = exp.column(column_name).sql(dialect=self.dialect, identify=True) 1064 1065 truncated_comment = self._truncate_column_comment(column_comment) 1066 comment_sql = exp.Literal.string(truncated_comment).sql(dialect=self.dialect) 1067 1068 return f"ALTER {table_kind} {table_sql} ALTER COLUMN {column_sql} SET OPTIONS(description = {comment_sql})" 1069 1070 def create_state_table( 1071 self, 1072 table_name: str, 1073 target_columns_to_types: t.Dict[str, exp.DataType], 1074 primary_key: t.Optional[t.Tuple[str, ...]] = None, 1075 ) -> None: 1076 self.create_table( 1077 table_name, 1078 target_columns_to_types, 1079 ) 1080 1081 def _db_call(self, 
func: t.Callable[..., t.Any], *args: t.Any, **kwargs: t.Any) -> t.Any: 1082 return func( 1083 retry=self.__retry, 1084 *args, 1085 **kwargs, 1086 ) 1087 1088 def _execute( 1089 self, 1090 sql: str, 1091 track_rows_processed: bool = False, 1092 **kwargs: t.Any, 1093 ) -> None: 1094 """Execute a sql query.""" 1095 from google.cloud.bigquery import QueryJobConfig 1096 from google.cloud.bigquery.query import ConnectionProperty 1097 1098 # BigQuery's Python DB API implementation does not support retries, so we have to implement them ourselves. 1099 # So we update the cursor's query job and query data with the results of the new query job. This makes sure 1100 # that other cursor based operations execute correctly. 1101 session_id = self._session_id 1102 connection_properties = ( 1103 [ 1104 ConnectionProperty(key="session_id", value=session_id), 1105 ] 1106 if session_id 1107 else [] 1108 ) 1109 1110 job_config = QueryJobConfig(**self._job_params, connection_properties=connection_properties) 1111 self._query_job = self._db_call( 1112 self.client.query, 1113 query=sql, 1114 job_config=job_config, 1115 timeout=self._extra_config.get("job_creation_timeout_seconds"), 1116 ) 1117 query_job = self._query_job 1118 assert query_job is not None 1119 1120 logger.debug( 1121 "BigQuery job created: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s", 1122 query_job.project, 1123 query_job.location, 1124 query_job.job_id, 1125 ) 1126 1127 results = self._db_call( 1128 query_job.result, 1129 timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore 1130 ) 1131 1132 self._query_data = iter(results) if results.total_rows else iter([]) 1133 query_results = query_job._query_results 1134 self.cursor._set_rowcount(query_results) 1135 self.cursor._set_description(query_results.schema) 1136 1137 if ( 1138 track_rows_processed 1139 and self._query_execution_tracker 1140 and self._query_execution_tracker.is_tracking() 1141 ): 1142 num_rows = None 1143 if 
query_job.statement_type == "CREATE_TABLE_AS_SELECT": 1144 # since table was just created, number rows in table == number rows processed 1145 query_table = self.client.get_table(query_job.destination) 1146 num_rows = query_table.num_rows 1147 elif query_job.statement_type in ["INSERT", "DELETE", "MERGE", "UPDATE"]: 1148 num_rows = query_job.num_dml_affected_rows 1149 1150 self._query_execution_tracker.record_execution( 1151 sql, num_rows, query_job.total_bytes_processed 1152 ) 1153 1154 def _get_data_objects( 1155 self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None 1156 ) -> t.List[DataObject]: 1157 """ 1158 Returns all the data objects that exist in the given schema and optionally catalog. 1159 """ 1160 1161 # The BigQuery Client's list_tables method does not support filtering by table name, so we have to 1162 # resort to using SQL instead. 1163 schema = to_schema(schema_name) 1164 catalog = schema.catalog or self.default_catalog 1165 query = ( 1166 exp.select( 1167 exp.column("table_catalog").as_("catalog"), 1168 exp.column("table_name").as_("name"), 1169 exp.column("table_schema").as_("schema_name"), 1170 exp.case() 1171 .when(exp.column("table_type").eq("BASE TABLE"), exp.Literal.string("TABLE")) 1172 .when(exp.column("table_type").eq("CLONE"), exp.Literal.string("TABLE")) 1173 .when(exp.column("table_type").eq("EXTERNAL"), exp.Literal.string("TABLE")) 1174 .when(exp.column("table_type").eq("SNAPSHOT"), exp.Literal.string("TABLE")) 1175 .when(exp.column("table_type").eq("VIEW"), exp.Literal.string("VIEW")) 1176 .when( 1177 exp.column("table_type").eq("MATERIALIZED VIEW"), 1178 exp.Literal.string("MATERIALIZED_VIEW"), 1179 ) 1180 .else_(exp.column("table_type")) 1181 .as_("type"), 1182 exp.column("clustering_key", "ci").as_("clustering_key"), 1183 ) 1184 .with_( 1185 "clustering_info", 1186 as_=exp.select( 1187 exp.column("table_catalog"), 1188 exp.column("table_schema"), 1189 exp.column("table_name"), 1190 parse_one( 1191 
"string_agg(column_name order by clustering_ordinal_position)", 1192 dialect=self.dialect, 1193 ).as_("clustering_key"), 1194 ) 1195 .from_( 1196 exp.to_table( 1197 f"`{catalog}`.`{schema.db}`.INFORMATION_SCHEMA.COLUMNS", 1198 dialect=self.dialect, 1199 ) 1200 ) 1201 .where(exp.column("clustering_ordinal_position").is_(exp.not_(exp.null()))) 1202 .group_by("1", "2", "3"), 1203 ) 1204 .from_( 1205 exp.to_table( 1206 f"`{catalog}`.`{schema.db}`.INFORMATION_SCHEMA.TABLES", dialect=self.dialect 1207 ) 1208 ) 1209 .join( 1210 "clustering_info", 1211 using=["table_catalog", "table_schema", "table_name"], 1212 join_type="left", 1213 join_alias="ci", 1214 ) 1215 ) 1216 if object_names: 1217 query = query.where(exp.column("table_name").isin(*object_names)) 1218 1219 try: 1220 df = self.fetchdf(query, quote_identifiers=True) 1221 except Exception as e: 1222 if "Not found" in str(e): 1223 return [] 1224 raise 1225 1226 if df.empty: 1227 return [] 1228 return [ 1229 DataObject( 1230 catalog=row.catalog, # type: ignore 1231 schema=row.schema_name, # type: ignore 1232 name=row.name, # type: ignore 1233 type=DataObjectType.from_str(row.type), # type: ignore 1234 clustering_key=f"({row.clustering_key})" if row.clustering_key else None, # type: ignore 1235 ) 1236 for row in df.itertuples() 1237 ] 1238 1239 def _update_clustering_key(self, operation: TableAlterClusterByOperation) -> None: 1240 cluster_key_expressions = getattr(operation, "cluster_key_expressions", []) 1241 bq_table = self._get_table(operation.target_table) 1242 1243 rendered_columns = [c.sql(dialect=self.dialect) for c in cluster_key_expressions] 1244 bq_table.clustering_fields = ( 1245 rendered_columns or None 1246 ) # causes a drop of the key if cluster_by is empty or None 1247 1248 self._db_call(self.client.update_table, table=bq_table, fields=["clustering_fields"]) 1249 1250 if cluster_key_expressions: 1251 # BigQuery only applies new clustering going forward, so this rewrites the columns to apply the new 
clustering to historical data 1252 # ref: https://cloud.google.com/bigquery/docs/creating-clustered-tables#modifying-cluster-spec 1253 self.execute( 1254 exp.update( 1255 operation.target_table, 1256 {c: c for c in cluster_key_expressions}, 1257 where=exp.true(), 1258 ) 1259 ) 1260 1261 def _normalize_decimal_value(self, col: exp.Expression, precision: int) -> exp.Expression: 1262 return exp.func("FORMAT", exp.Literal.string(f"%.{precision}f"), col) 1263 1264 def _normalize_nested_value(self, col: exp.Expression) -> exp.Expression: 1265 return exp.func("TO_JSON_STRING", col, dialect=self.dialect) 1266 1267 @t.overload 1268 def _columns_to_types( 1269 self, 1270 query_or_df: DF, 1271 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 1272 source_columns: t.Optional[t.List[str]] = None, 1273 ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ... 1274 1275 @t.overload 1276 def _columns_to_types( 1277 self, 1278 query_or_df: Query, 1279 target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 1280 source_columns: t.Optional[t.List[str]] = None, 1281 ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ... 
def _columns_to_types(
    self,
    query_or_df: QueryOrDF,
    target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
    source_columns: t.Optional[t.List[str]] = None,
) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]:
    """Resolve column types and source column names for a query or DataFrame.

    A BigQuery DataFrame (``bigframes``) with no known target types is handled here:
    its dtypes are read from a dry-run ``to_pandas`` call. All other inputs are
    delegated to the base implementation.

    Args:
        query_or_df: The query or DataFrame whose columns are being resolved.
        target_columns_to_types: Known column types. When provided, the bigframes
            shortcut is skipped.
        source_columns: Optional explicit list of source column names.

    Returns:
        A tuple of (column name -> type mapping, source column names).
    """
    if (
        not target_columns_to_types
        and bigframes
        and isinstance(query_or_df, bigframes.dataframe.DataFrame)
    ):
        # using dry_run=True attempts to prevent the DataFrame from being materialized just to read the column types from it
        dtypes = query_or_df.to_pandas(dry_run=True).columnDtypes
        target_columns_to_types = columns_to_types_from_dtypes(dtypes.items())
        return target_columns_to_types, list(source_columns or target_columns_to_types)

    return super()._columns_to_types(
        query_or_df, target_columns_to_types, source_columns=source_columns
    )

def _native_df_to_pandas_df(
    self,
    query_or_df: QueryOrDF,
) -> t.Union[Query, pd.DataFrame]:
    """Materialize a BigQuery DataFrame into pandas; defer any other input to the base class."""
    if bigframes and isinstance(query_or_df, bigframes.dataframe.DataFrame):
        return query_or_df.to_pandas()

    return super()._native_df_to_pandas_df(query_or_df)

@property
def _query_data(self) -> t.Any:
    # Row iterator of the most recently executed query (assigned by `_execute`),
    # kept as an attribute on the connection pool.
    return self._connection_pool.get_attribute("query_data")

@_query_data.setter
def _query_data(self, value: t.Any) -> None:
    self._connection_pool.set_attribute("query_data", value)

@property
def _query_job(self) -> t.Optional[QueryJob]:
    # The most recent BigQuery QueryJob (assigned by `_execute`), kept on the connection pool.
    return self._connection_pool.get_attribute("query_job")

@_query_job.setter
def _query_job(self, value: t.Any) -> None:
    self._connection_pool.set_attribute("query_job", value)

@property
def _session_id(self) -> t.Any:
    # BigQuery session id; when set, `_execute` attaches it to each query as a
    # `session_id` connection property.
    return self._connection_pool.get_attribute("session_id")

@_session_id.setter
def _session_id(self, value: t.Any) -> None:
    self._connection_pool.set_attribute("session_id", value)

def _get_current_schema(self) -> str:
    """Always raises: this adapter does not support a current schema."""
    raise NotImplementedError("BigQuery does not support current schema")

def _get_bq_dataset_location(self, project: str, dataset: str) -> str:
    """Return the location of dataset ``project.dataset``.

    The lookup goes through the client's ``get_dataset`` call, wrapped by ``_db_call``
    so that the adapter's retry policy applies.
    """
    return self._db_call(self.client.get_dataset, dataset_ref=f"{project}.{dataset}").location

def _get_grant_expression(self, table: exp.Table) -> exp.Expression:
    """Build a query listing the (privilege_type, grantee) pairs granted on ``table``.

    The query reads the project-level, region-qualified INFORMATION_SCHEMA privileges
    view in the dataset's location. It excludes rows whose grantee (the part after the
    ``label:`` prefix) equals the current user.

    Raises:
        ValueError: If the table has no dataset, or no project can be determined.
    """
    if not table.db:
        raise ValueError(
            f"Table {table.sql(dialect=self.dialect)} does not have a schema (dataset)"
        )
    project = table.catalog or self.get_current_catalog()
    if not project:
        raise ValueError(
            f"Table {table.sql(dialect=self.dialect)} does not have a catalog (project)"
        )

    dataset = table.db
    table_name = table.name
    # The privileges view is regional, so the dataset's location is needed to address it.
    location = self._get_bq_dataset_location(project, dataset)

    # https://cloud.google.com/bigquery/docs/information-schema-object-privileges
    # OBJECT_PRIVILEGES is a project-level INFORMATION_SCHEMA view with regional qualifier
    object_privileges_table = exp.to_table(
        f"`{project}`.`region-{location}`.INFORMATION_SCHEMA.{self.GRANT_INFORMATION_SCHEMA_TABLE_NAME}",
        dialect=self.dialect,
    )
    return (
        exp.select("privilege_type", "grantee")
        .from_(object_privileges_table)
        .where(
            exp.and_(
                exp.column("object_schema").eq(exp.Literal.string(dataset)),
                exp.column("object_name").eq(exp.Literal.string(table_name)),
                # Filter out current_user
                # BigQuery grantees format: "user:email" or "group:name"
                exp.func("split", exp.column("grantee"), exp.Literal.string(":"))[
                    exp.func("OFFSET", exp.Literal.number("1"))
                ].neq(self.CURRENT_USER_OR_ROLE_EXPRESSION),
            )
        )
    )

@staticmethod
def _grant_object_kind(table_type: DataObjectType) -> str:
    """Map a data object type to the resource keyword used in BigQuery DCL statements."""
    if table_type == DataObjectType.VIEW:
        return "VIEW"
    if table_type == DataObjectType.MATERIALIZED_VIEW:
        # We actually need to use "MATERIALIZED VIEW" here even though it's not listed
        # as a supported resource_type in the BigQuery DCL doc:
        # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-control-language
        return "MATERIALIZED VIEW"
    return "TABLE"

def _dcl_grants_config_expr(
    self,
    dcl_cmd: t.Type[DCL],
    table: exp.Table,
    grants_config: GrantsConfig,
    table_type: DataObjectType = DataObjectType.TABLE,
) -> t.List[exp.Expression]:
    """Build one DCL statement per privilege in ``grants_config``.

    See https://cloud.google.com/bigquery/docs/reference/standard-sql/data-control-language

    Args:
        dcl_cmd: The sqlglot DCL expression class to instantiate.
        table: The object the privileges apply to.
        grants_config: Mapping of privilege name to the principals receiving it. Each
            principal must carry a prefix label such as ``user:`` or ``group:``.
        table_type: Kind of object being secured; selects the resource keyword.

    Returns:
        The DCL expressions. The list is empty when ``grants_config`` is empty or no
        privilege has any principals.

    Raises:
        ValueError: If a principal lacks a prefix label, or names ``allUsers`` /
            ``allAuthenticatedUsers`` without the ``specialGroup:`` label.
    """
    expressions: t.List[exp.Expression] = []
    if not grants_config:
        return expressions

    special_groups = ("allUsers", "allAuthenticatedUsers")

    def normalize_principal(p: str) -> str:
        """Validate a ``label:principal`` string and lowercase the principal part."""
        if ":" not in p:
            raise ValueError(f"Principal '{p}' missing a prefix label")

        label, principal = p.split(":", 1)
        # allUsers and allAuthenticatedUsers are case-sensitive special groups. They must
        # use the "specialGroup:" prefix label and must not be lowercased. Compare the
        # principal exactly, so that other names merely ending in these strings are not
        # treated as special groups.
        if principal in special_groups:
            if label != "specialGroup":
                raise ValueError(
                    f"Special group principal '{p}' must start with 'specialGroup:' prefix label"
                )
            return p

        # always lowercase principals
        return f"{label}:{principal.lower()}"

    # `_grant_object_kind` always returns a non-empty keyword, so every statement gets a kind.
    object_kind = self._grant_object_kind(table_type)
    for privilege, principals in grants_config.items():
        if not principals:
            continue

        normalized_principals = [exp.Literal.string(normalize_principal(p)) for p in principals]
        args: t.Dict[str, t.Any] = {
            "privileges": [exp.GrantPrivilege(this=exp.to_identifier(privilege, quoted=True))],
            "securable": table.copy(),
            "principals": normalized_principals,
            # A fresh Var per statement so expression nodes are never shared between trees.
            "kind": exp.Var(this=object_kind),
        }

        expressions.append(dcl_cmd(**args))  # type: ignore[arg-type]

    return expressions
BigQuery Engine Adapter using the google-cloud-bigquery library's DB API.
@property
def bigframe(self) -> t.Optional[BigframeSession]:
    """Return a bigframes session bound to this adapter's BigQuery client.

    ``bigframes.connect`` is called on every access, using the client's credentials,
    project and location. Returns ``None`` when the optional ``bigframes`` import is
    unavailable (falsy).
    """
    if not bigframes:
        return None
    bq_options = bigframes.BigQueryOptions(
        credentials=self.client._credentials,
        project=self.client.project,
        location=self.client.location,
    )
    return bigframes.connect(context=bq_options)
def close(self) -> t.Any:
    """Close all open connections and release all allocated resources.

    Before delegating to the parent ``close``, every query job recorded on a pooled
    connection that is not yet done is cancelled. Cancellation is best-effort: a failure
    is logged at debug level and the remaining jobs are still processed.
    """
    recorded_jobs = self._connection_pool.get_all_attributes("query_job")
    for job in filter(None, recorded_jobs):
        try:
            if self._db_call(job.done):
                continue
            self._db_call(job.cancel)
            logger.debug(
                "Cancelled BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s",
                job.project,
                job.location,
                job.job_id,
            )
        except Exception as ex:
            logger.debug(
                "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s",
                job.project,
                job.location,
                job.job_id,
                str(ex),
            )

    return super().close()
Closes all open connections and releases all allocated resources.
def get_current_catalog(self) -> t.Optional[str]:
    """Return the project the BigQuery client is bound to, which serves as the current catalog."""
    client = self.client
    return client.project
Returns the catalog name of the current connection.
def set_current_catalog(self, catalog: str) -> None:
    """Make ``catalog`` the current catalog by pointing the BigQuery client at that project."""
    client = self.client
    client.project = catalog
Sets the catalog name of the current connection.
def create_schema(
    self,
    schema_name: SchemaName,
    ignore_if_exists: bool = True,
    warn_on_error: bool = True,
    properties: t.Optional[t.List[exp.Expression]] = None,
) -> None:
    """Create a schema (BigQuery dataset) from a name or qualified table name.

    Args:
        schema_name: The schema name, or a qualified table name whose schema should be created.
        ignore_if_exists: If True, a BigQuery "Already Exists:" conflict is ignored.
        warn_on_error: If True, any other failure is logged as a warning instead of raised.
        properties: Optional schema properties, forwarded to the base implementation.

    Raises:
        Exception: Whatever the base implementation raised, when ``warn_on_error`` is False
            and the error is not an ignorable "already exists" conflict.
    """
    from google.api_core.exceptions import Conflict

    try:
        # The base call always raises (warn_on_error=False) so the BigQuery-specific
        # "Already Exists" conflict can be recognized here before deciding to warn or raise.
        super().create_schema(
            schema_name,
            ignore_if_exists=ignore_if_exists,
            warn_on_error=False,
            properties=properties or [],
        )
    except Exception as e:
        is_already_exists_error = isinstance(e, Conflict) and "Already Exists:" in str(e)
        if is_already_exists_error and ignore_if_exists:
            return
        if not warn_on_error:
            raise
        logger.warning("Failed to create schema '%s': %s", schema_name, e)
Create a schema from a name or qualified table name.
def get_bq_schema(self, table_name: TableName) -> t.List[bigquery.SchemaField]:
    """Return the BigQuery schema fields of ``table_name``.

    The client's ``get_table`` lookup can't handle paths with more than three identifiers.
    In that case an empty ``SELECT * ... LIMIT 0`` is executed, and the schema is read
    from the resulting query job instead.
    """
    table = exp.to_table(table_name)
    needs_query_job = len(table.parts) == 3 and "." in table.name
    if not needs_query_job:
        return self._get_table(table).schema

    self.execute(exp.select("*").from_(table).limit(0))
    job = self._query_job
    assert job is not None
    return job._query_results.schema
def columns(
    self, table_name: TableName, include_pseudo_columns: bool = False
) -> t.Dict[str, exp.DataType]:
    """Fetch column names and types for the target table.

    Args:
        table_name: The table whose columns should be returned.
        include_pseudo_columns: Whether to also return the BigQuery pseudo-columns that apply
            to the table (``_PARTITIONTIME``, ``_PARTITIONDATE``, ``_TABLE_SUFFIX`` and
            ``_FILE_NAME``). They are derived from table metadata, so they are only added
            when the table was looked up directly. For paths with more than three
            identifiers, whose schema is read from a query job, they are omitted.

    Returns:
        A mapping from column name to data type.
    """

    def dtype_to_sql(
        dtype: t.Optional[StandardSqlDataType], field: bigquery.SchemaField
    ) -> str:
        assert dtype
        assert field

        kind = dtype.type_kind
        assert kind

        # Not using the enum value to preserve compatibility with older versions
        # of the BigQuery library.
        if kind.name == "ARRAY":
            return f"ARRAY<{dtype_to_sql(dtype.array_element_type, field)}>"
        if kind.name == "STRUCT":
            struct_type = dtype.struct_type
            assert struct_type
            fields = ", ".join(
                f"{struct_field.name} {dtype_to_sql(struct_field.type, nested_field)}"
                for struct_field, nested_field in zip(struct_type.fields, field.fields)
            )
            return f"STRUCT<{fields}>"
        if kind.name == "TYPE_KIND_UNSPECIFIED":
            field_type = field.field_type

            if field_type == "RANGE":
                # If the field is a RANGE then `range_element_type` should be set to
                # one of `"DATE"`, `"DATETIME"` or `"TIMESTAMP"`.
                return f"RANGE<{field.range_element_type.element_type}>"

            return field_type

        return kind.name

    def create_mapping_schema(
        schema: t.Sequence[bigquery.SchemaField],
    ) -> t.Dict[str, exp.DataType]:
        return {
            field.name: exp.DataType.build(
                dtype_to_sql(field.to_standard_sql().type, field), dialect=self.dialect
            )
            for field in schema
        }

    table = exp.to_table(table_name)
    bq_table: t.Optional[BigQueryTable] = None
    if len(table.parts) == 3 and "." in table.name:
        # The client's `get_table` method can't handle paths with >3 identifiers
        self.execute(exp.select("*").from_(table).limit(0))
        query_job = self._query_job
        assert query_job is not None

        query_results = query_job._query_results
        columns = create_mapping_schema(query_results.schema)
    else:
        bq_table = self._get_table(table)
        columns = create_mapping_schema(bq_table.schema)

    # Pseudo-columns depend on table metadata, which is unavailable when the schema was
    # read from a query job; skip them in that case rather than using an unbound table.
    if include_pseudo_columns and bq_table is not None:
        if bq_table.time_partitioning and not bq_table.time_partitioning.field:
            columns["_PARTITIONTIME"] = exp.DataType.build("TIMESTAMP", dialect="bigquery")
            if bq_table.time_partitioning.type_ == "DAY":
                columns["_PARTITIONDATE"] = exp.DataType.build("DATE")
        if bq_table.table_id.endswith("*"):
            columns["_TABLE_SUFFIX"] = exp.DataType.build("STRING", dialect="bigquery")
        if (
            bq_table.external_data_configuration is not None
            and bq_table.external_data_configuration.source_format
            in (
                "CSV",
                "NEWLINE_DELIMITED_JSON",
                "AVRO",
                "PARQUET",
                "ORC",
                "DATASTORE_BACKUP",
            )
        ):
            columns["_FILE_NAME"] = exp.DataType.build("STRING", dialect="bigquery")

    return columns
Fetches column names and types for the target table.
def alter_table(
    self,
    alter_expressions: t.Union[t.List[exp.Alter], t.List[TableAlterOperation]],
) -> None:
    """
    Performs the alter statements to change the current table into the structure of the target table,
    and uses the API to add columns to structs, where SQL is not supported.

    Clustering-key operations are applied first via ``_update_clustering_key``. The remaining
    statements are split into nested-field changes, which are applied against the table of the
    first statement, and plain statements, which are delegated to the parent implementation.
    """
    if not alter_expressions:
        return

    clustering_ops = [
        op for op in alter_expressions if isinstance(op, TableAlterClusterByOperation)
    ]
    # Unwrap TableAlterOperation objects into their expressions; raw expressions pass through.
    statements = [
        op.expression if isinstance(op, TableAlterOperation) else op
        for op in alter_expressions
        if not isinstance(op, TableAlterClusterByOperation)
    ]

    for clustering_op in clustering_ops:
        self._update_clustering_key(clustering_op)

    nested_fields, plain_statements = self._split_alter_expressions(statements)
    if nested_fields:
        self._update_table_schema_nested_fields(nested_fields, statements[0].this)
    if plain_statements:
        super().alter_table(plain_statements)
Performs the alter statements to change the current table into the structure of the target table, and uses the API to add columns to structs, where SQL is not supported.
def fetchone(
    self,
    query: t.Union[exp.Expression, str],
    ignore_unsupported_errors: bool = False,
    quote_identifiers: bool = False,
) -> t.Optional[t.Tuple]:
    """
    Execute ``query`` and return its first row, or ``None`` when it yields no rows.

    BigQuery's `fetchone` method doesn't call execute and therefore would not benefit from the execute
    configuration we have in place. Therefore this implementation calls execute instead.
    """
    self.execute(
        query,
        ignore_unsupported_errors=ignore_unsupported_errors,
        quote_identifiers=quote_identifiers,
    )
    return next(self._query_data, None)
BigQuery's fetchone method doesn't call execute and therefore would not benefit from the execute
configuration we have in place. Therefore this implementation calls execute instead.
def fetchall(
    self,
    query: t.Union[exp.Expression, str],
    ignore_unsupported_errors: bool = False,
    quote_identifiers: bool = False,
) -> t.List[t.Tuple]:
    """
    Execute ``query`` and return all resulting rows as a list.

    BigQuery's `fetchall` method doesn't call execute and therefore would not benefit from the execute
    configuration we have in place. Therefore this implementation calls execute instead.
    """
    self.execute(
        query,
        ignore_unsupported_errors=ignore_unsupported_errors,
        quote_identifiers=quote_identifiers,
    )
    return list(self._query_data)
BigQuery's fetchall method doesn't call execute and therefore would not benefit from the execute
configuration we have in place. Therefore this implementation calls execute instead.
def insert_overwrite_by_partition(
    self,
    table_name: TableName,
    query_or_df: QueryOrDF,
    partitioned_by: t.List[exp.Expression],
    target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
    source_columns: t.Optional[t.List[str]] = None,
) -> None:
    """Overwrite the partitions of a table that are covered by the given query or DataFrame.

    The process has three steps:
    1. The source data is loaded into a temporary table.
    2. The partition values it contains are collected into the
       ``_sqlmesh_target_partitions_`` script variable.
    3. Rows of the target table whose partition expression falls in that set are replaced
       with the temporary table's rows.

    Args:
        table_name: The target table.
        query_or_df: The query or DataFrame providing the replacement rows.
        partitioned_by: The partition expression; exactly one must be given.
        target_columns_to_types: Column types of the target table. When omitted, or when the
            partition column's type is ``unknown``, they are fetched from the table itself.
        source_columns: Columns of the source data, forwarded to the temporary table.

    Raises:
        SQLMeshError: If ``partitioned_by`` does not contain exactly one expression, or if the
            partition expression does not reference a column.
    """
    partition_count = len(partitioned_by)
    if partition_count != 1:
        raise SQLMeshError(
            f"Bigquery only supports partitioning by one column, {partition_count} were provided."
        )

    partition_exp = partitioned_by[0]
    partition_column = partition_exp.find(exp.Column)

    # The optional `unit` argument of the partition expression selects the granularity
    # used to parse partition ids in select_partitions_expr.
    granularity = partition_exp.args.get("unit")
    if granularity:
        granularity = granularity.name.lower()

    if not partition_column:
        raise SQLMeshError(
            f"The partition expression '{partition_exp.sql(dialect=self.dialect)}' doesn't contain a column."
        )

    with self.session({}):
        with self.temp_table(
            query_or_df,
            name=table_name,
            partitioned_by=partitioned_by,
            source_columns=source_columns,
        ) as temp_table_name:
            column_name = partition_column.name
            must_fetch_types = target_columns_to_types is None or target_columns_to_types[
                column_name
            ] == exp.DataType.build("unknown")
            if must_fetch_types:
                target_columns_to_types = self.columns(table_name)

            partition_type_sql = target_columns_to_types[column_name].sql(dialect=self.dialect)

            partitions_sql = select_partitions_expr(
                temp_table_name.db,
                temp_table_name.name,
                partition_type_sql,
                granularity=granularity,
                agg_func="ARRAY_AGG",
                catalog=temp_table_name.catalog or self.default_catalog,
            )
            # Store the temp table's partitions in a script variable referenced by the
            # overwrite condition below.
            self.execute(
                f"DECLARE _sqlmesh_target_partitions_ ARRAY<{partition_type_sql}> DEFAULT ({partitions_sql});"
            )

            overwrite_condition = t.cast(exp.Condition, partition_exp).isin(
                unnest="_sqlmesh_target_partitions_"
            )
            source_queries = [
                SourceQuery(query_factory=lambda: exp.select("*").from_(temp_table_name))
            ]
            self._insert_overwrite_by_condition(
                table_name,
                source_queries,
                target_columns_to_types,
                where=overwrite_condition,
            )
def table_exists(self, table_name: TableName) -> bool:
    """Return whether ``table_name`` exists.

    The data-object cache is consulted first; a cached ``None`` entry is reported as
    non-existent. Otherwise the table is fetched and a ``NotFound`` error means False.
    """
    table = exp.to_table(table_name)
    cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name)
    if cache_key in self._data_object_cache:
        logger.debug("Table existence cache hit: %s", cache_key)
        return self._data_object_cache[cache_key] is not None

    # The NotFound exception lives in different modules depending on the installed packages.
    try:
        from google.cloud.exceptions import NotFound
    except ModuleNotFoundError:
        from google.api_core.exceptions import NotFound

    try:
        self._get_table(table_name)
    except NotFound:
        return False
    return True
def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]:
    """Return the last-modified timestamps of the given tables.

    Timestamps are read from each dataset's ``__TABLES__`` meta-table. The meta-table is
    qualified with the table's project when one is given, otherwise the client's default
    project is used. Results follow the order of ``table_names``; tables with no row in
    their dataset's ``__TABLES__`` are omitted.

    Args:
        table_names: The tables to look up.

    Returns:
        The last-modified times, converted with ``to_timestamp``.
    """
    from sqlmesh.utils.date import to_timestamp

    tables = [exp.to_table(table_name) for table_name in table_names]

    # Group by (project, dataset) so each meta-table is queried once, and tables that
    # live in a non-default project are looked up in that project.
    datasets_to_tables: t.DefaultDict[t.Tuple[str, str], t.List[str]] = defaultdict(list)
    for table in tables:
        datasets_to_tables[(table.catalog, table.db)].append(table.name)

    last_modified: t.Dict[t.Tuple[str, str, str], t.Any] = {}
    for (project, dataset), names in datasets_to_tables.items():
        meta_table = f"{project}.{dataset}.__TABLES__" if project else f"{dataset}.__TABLES__"
        # Render names as properly escaped string literals instead of raw interpolation.
        name_literals = ", ".join(
            exp.Literal.string(name).sql(dialect=self.dialect) for name in names
        )
        query = (
            f"SELECT table_id, TIMESTAMP_MILLIS(last_modified_time) FROM `{meta_table}` "
            f"WHERE table_id IN ({name_literals})"
        )
        for table_id, modified_at in self.fetchall(query):
            last_modified[(project, dataset, table_id)] = modified_at

    return [
        to_timestamp(last_modified[key])
        for key in ((table.catalog, table.db, table.name) for table in tables)
        if key in last_modified
    ]
def create_state_table(
    self,
    table_name: str,
    target_columns_to_types: t.Dict[str, exp.DataType],
    primary_key: t.Optional[t.Tuple[str, ...]] = None,
) -> None:
    """Create a table to store SQLMesh internal state.

    Args:
        table_name: The name of the table to create. Can be fully qualified or just table name.
        target_columns_to_types: A mapping between the column name and its data type.
        primary_key: Accepted for interface compatibility but not used; the table is
            created via ``create_table`` without a primary key.
    """
    self.create_table(table_name, target_columns_to_types)
Create a table to store SQLMesh internal state.
Arguments:
- table_name: The name of the table to create. Can be fully qualified or just table name.
- target_columns_to_types: A mapping between the column name and its data type.
- primary_key: Determines the table primary key.
def get_alter_operations(
    self,
    current_table_name: TableName,
    target_table_name: TableName,
    *,
    ignore_destructive: bool = False,
    ignore_additive: bool = False,
) -> t.List[TableAlterOperation]:
    """Determine the alter operations needed to change the current table into the structure of the target table.

    In addition to the operations computed by the parent implementation, a clustering-key
    change or drop is appended when metadata for both tables is found and their clustering
    differs.
    """
    operations = super().get_alter_operations(
        current_table_name,
        target_table_name,
        ignore_destructive=ignore_destructive,
        ignore_additive=ignore_additive,
    )

    current_table = exp.to_table(current_table_name)
    target_table = exp.to_table(target_table_name)

    def lookup_data_object(table: exp.Table) -> t.Optional[DataObject]:
        # First data object matching the table name in the table's schema, if any.
        table_schema = schema_(table.db, catalog=table.catalog)
        return seq_get(self.get_data_objects(table_schema, {table.name}), 0)

    current_info = lookup_data_object(current_table)
    target_info = lookup_data_object(target_table)
    if not (current_info and target_info):
        return operations

    if target_info.is_clustered:
        new_key = target_info.clustering_key
        if new_key and current_info.clustering_key != new_key:
            operations.append(
                TableAlterChangeClusterKeyOperation(
                    target_table=current_table,
                    clustering_key=new_key,
                    dialect=self.dialect,
                )
            )
    elif current_info.is_clustered:
        operations.append(TableAlterDropClusterKeyOperation(target_table=current_table))

    return operations
Determines the alter statements needed to change the current table into the structure of the target table.
Inherited Members
- sqlmesh.core.engine_adapter.base.EngineAdapter
- EngineAdapter
- DATA_OBJECT_FILTER_BATCH_SIZE
- SUPPORTS_INDEXES
- COMMENT_CREATION_TABLE
- COMMENT_CREATION_VIEW
- SUPPORTS_MATERIALIZED_VIEW_SCHEMA
- SUPPORTS_VIEW_SCHEMA
- SUPPORTS_MANAGED_MODELS
- SUPPORTS_CREATE_DROP_CATALOG
- SUPPORTS_TUPLE_IN
- HAS_VIEW_BINDING
- SUPPORTS_REPLACE_TABLE
- DEFAULT_CATALOG_TYPE
- QUOTE_IDENTIFIERS_IN_VIEWS
- MAX_IDENTIFIER_LENGTH
- ATTACH_CORRELATION_ID
- SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS
- dialect
- correlation_id
- with_settings
- cursor
- connection
- spark
- snowpark
- comments_enabled
- schema_differ
- default_catalog
- engine_run_mode
- recycle
- get_catalog_type
- get_catalog_type_from_table
- current_catalog_type
- replace_query
- create_index
- create_table
- create_managed_table
- ctas
- create_table_like
- clone_table
- drop_data_object
- drop_table
- drop_managed_table
- create_view
- drop_schema
- drop_view
- create_catalog
- drop_catalog
- delete_from
- insert_append
- insert_overwrite_by_time_partition
- update_table
- scd_type_2_by_time
- scd_type_2_by_column
- merge
- rename_table
- get_data_object
- get_data_objects
- fetchdf
- fetch_pyspark_df
- wap_enabled
- wap_supported
- wap_table_name
- wap_prepare
- wap_publish
- sync_grants_config
- transaction
- session
- execute
- temp_table
- drop_data_object_on_type_mismatch
- ensure_nulls_for_unmatched_after_join
- use_server_nulls_for_unmatched_after_join
- ping
def select_partitions_expr(
    schema: str,
    table_name: str,
    data_type: t.Union[str, exp.DataType],
    granularity: t.Optional[str] = None,
    agg_func: str = "MAX",
    catalog: t.Optional[str] = None,
) -> str:
    """Generates a SQL expression that aggregates partition values for a table.

    Args:
        schema: The schema (BigQuery dataset) of the table.
        table_name: The name of the table. It is embedded as an escaped string literal.
        data_type: The data type of the partition column.
        granularity: The granularity of the partition. Supported values are: 'day', 'month', 'year' and 'hour'.
            Defaults to 'day' for DATE, DATETIME and TIMESTAMP partition columns.
        agg_func: The aggregation function to use.
        catalog: The catalog (BigQuery project ID) of the table.

    Returns:
        A SELECT statement that aggregates partition values for a table.
    """
    partitions_table_name = f"`{schema}`.INFORMATION_SCHEMA.PARTITIONS"
    if catalog:
        partitions_table_name = f"`{catalog}`.{partitions_table_name}"

    if isinstance(data_type, exp.DataType):
        data_type = data_type.sql(dialect="bigquery")
    data_type = data_type.upper()

    if data_type in ("DATE", "DATETIME", "TIMESTAMP"):
        # Date-like partition ids are parsed back into values using the format that
        # corresponds to the partition granularity.
        parse_format = GRANULARITY_TO_PARTITION_FORMAT[(granularity or "day").lower()]
        partition_expr = exp.func(
            f"PARSE_{data_type}",
            exp.Literal.string(parse_format),
            exp.column("partition_id"),
            dialect="bigquery",
        )
    else:
        partition_expr = exp.cast(exp.column("partition_id"), "INT64", dialect="bigquery")

    # Built with expression builders rather than a raw SQL string so the table name is
    # rendered as a correctly escaped string literal. Fresh column nodes are created for
    # each use because sqlglot nodes must not be shared between parents.
    where = exp.and_(
        exp.column("table_name").eq(exp.Literal.string(table_name)),
        exp.column("partition_id").is_(exp.null()).not_(),
        exp.column("partition_id").neq(exp.Literal.string("__NULL__")),
    )

    return (
        exp.select(exp.func(agg_func, partition_expr))
        .from_(partitions_table_name, dialect="bigquery")
        .where(where, copy=False)
        .sql(dialect="bigquery")
    )
Generates a SQL expression that aggregates partition values for a table.
Arguments:
- schema: The schema (BigQuery dataset) of the table.
- table_name: The name of the table.
- data_type: The data type of the partition column.
- granularity: The granularity of the partition. Supported values are: 'day', 'month', 'year' and 'hour'.
- agg_func: The aggregation function to use.
- catalog: The catalog (BigQuery project ID) of the table.
Returns:
A SELECT statement that aggregates partition values for a table.