sqlmesh.schedulers.airflow.client
import abc
import json
import time
import typing as t
import uuid
from urllib.parse import urlencode, urljoin

import requests

from sqlmesh.core.console import Console
from sqlmesh.core.environment import Environment
from sqlmesh.core.notification_target import NotificationTarget
from sqlmesh.core.snapshot import Snapshot, SnapshotId
from sqlmesh.core.snapshot.definition import Interval
from sqlmesh.core.state_sync import Versions
from sqlmesh.core.user import User
from sqlmesh.schedulers.airflow import common
from sqlmesh.utils import unique
from sqlmesh.utils.date import TimeLike
from sqlmesh.utils.errors import (
    ApiServerError,
    NotFoundError,
    SQLMeshError,
    raise_for_status,
)
from sqlmesh.utils.pydantic import PydanticModel

DAG_RUN_PATH_TEMPLATE = "api/v1/dags/{}/dagRuns"


PLANS_PATH = f"{common.SQLMESH_API_BASE_PATH}/plans"
ENVIRONMENTS_PATH = f"{common.SQLMESH_API_BASE_PATH}/environments"
SNAPSHOTS_PATH = f"{common.SQLMESH_API_BASE_PATH}/snapshots"
SEEDS_PATH = f"{common.SQLMESH_API_BASE_PATH}/seeds"
INTERVALS_PATH = f"{common.SQLMESH_API_BASE_PATH}/intervals"
MODELS_PATH = f"{common.SQLMESH_API_BASE_PATH}/models"
VERSIONS_PATH = f"{common.SQLMESH_API_BASE_PATH}/versions"


class BaseAirflowClient(abc.ABC):
    def __init__(self, airflow_url: str, console: t.Optional[Console]):
        self._airflow_url = airflow_url
        self._console = console

    @property
    def default_catalog(self) -> str:
        default_catalog = self.get_variable(common.DEFAULT_CATALOG_VARIABLE_NAME)
        if not default_catalog:
            raise SQLMeshError(
                "Must define `default_catalog` when creating `SQLMeshAirflow` object. See docs for more info: https://sqlmesh.readthedocs.io/en/stable/integrations/airflow/#airflow-cluster-configuration"
            )
        return default_catalog

    def print_tracking_url(self, dag_id: str, dag_run_id: str, op_name: str) -> None:
        if not self._console:
            return

        tracking_url = self.dag_run_tracking_url(dag_id, dag_run_id)
        # TODO: Figure out generalized solution for links
        self._console.log_status_update(
            f"Track [green]{op_name}[/green] progress using [link={tracking_url}]link[/link]"
        )

    def dag_run_tracking_url(self, dag_id: str, dag_run_id: str) -> str:
        url_params = urlencode(
            dict(
                dag_id=dag_id,
                run_id=dag_run_id,
            )
        )
        return urljoin(self._airflow_url, f"dagrun_details?{url_params}")

    def wait_for_dag_run_completion(
        self, dag_id: str, dag_run_id: str, poll_interval_secs: int
    ) -> bool:
        """Blocks until the given DAG Run completes.

        Args:
            dag_id: The DAG ID.
            dag_run_id: The DAG Run ID.
            poll_interval_secs: The number of seconds to wait between polling for the DAG Run state.

        Returns:
            True if the DAG Run completed successfully, False otherwise.
        """
        loading_id = self._console_loading_start()

        while True:
            state = self.get_dag_run_state(dag_id, dag_run_id)
            if state in ("failed", "success"):
                if self._console and loading_id:
                    self._console.loading_stop(loading_id)
                return state == "success"

            time.sleep(poll_interval_secs)

    def wait_for_first_dag_run(self, dag_id: str, poll_interval_secs: int, max_retries: int) -> str:
        """Blocks until the first DAG Run for the given DAG ID is created.

        Args:
            dag_id: The DAG ID.
            poll_interval_secs: The number of seconds to wait between polling for the DAG Run.
            max_retries: The maximum number of retries.

        Returns:
            The ID of the first DAG Run for the given DAG ID.
        """

        loading_id = self._console_loading_start()

        attempt_num = 1

        try:
            while True:
                try:
                    first_dag_run_id = self.get_first_dag_run_id(dag_id)
                    if first_dag_run_id is None:
                        raise SQLMeshError(f"Missing a DAG Run for DAG '{dag_id}'")
                    return first_dag_run_id
                except ApiServerError:
                    raise
                except SQLMeshError:
                    if attempt_num > max_retries:
                        raise

                attempt_num += 1
                time.sleep(poll_interval_secs)
        finally:
            if self._console and loading_id:
                self._console.loading_stop(loading_id)

    @abc.abstractmethod
    def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:
        """Returns the ID of the first DAG Run for the given DAG ID, or None if no DAG Runs exist.

        Args:
            dag_id: The DAG ID.

        Returns:
            The ID of the first DAG Run for the given DAG ID, or None if no DAG Runs exist.
        """

    @abc.abstractmethod
    def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
        """Returns the state of the given DAG Run.

        Args:
            dag_id: The DAG ID.
            dag_run_id: The DAG Run ID.

        Returns:
            The state of the given DAG Run.
        """

    @abc.abstractmethod
    def get_variable(self, key: str) -> t.Optional[str]:
        """Returns the value of an Airflow variable with the given key.

        Args:
            key: The variable key.

        Returns:
            The variable value or None if no variable with the given key exists.
        """

    def _console_loading_start(self) -> t.Optional[uuid.UUID]:
        if self._console:
            return self._console.loading_start()
        return None


class AirflowClient(BaseAirflowClient):
    def __init__(
        self,
        session: requests.Session,
        airflow_url: str,
        console: t.Optional[Console] = None,
        snapshot_ids_batch_size: t.Optional[int] = None,
    ):
        super().__init__(airflow_url, console)
        self._session = session
        self._snapshot_ids_batch_size = snapshot_ids_batch_size

    def apply_plan(
        self,
        new_snapshots: t.Iterable[Snapshot],
        environment: Environment,
        request_id: str,
        no_gaps: bool = False,
        skip_backfill: bool = False,
        restatements: t.Optional[t.Dict[SnapshotId, Interval]] = None,
        notification_targets: t.Optional[t.List[NotificationTarget]] = None,
        backfill_concurrent_tasks: int = 1,
        ddl_concurrent_tasks: int = 1,
        users: t.Optional[t.List[User]] = None,
        is_dev: bool = False,
        forward_only: bool = False,
        models_to_backfill: t.Optional[t.Set[str]] = None,
        end_bounded: bool = False,
        ensure_finalized_snapshots: bool = False,
        directly_modified_snapshots: t.Optional[t.List[SnapshotId]] = None,
        indirectly_modified_snapshots: t.Optional[t.Dict[str, t.List[SnapshotId]]] = None,
        removed_snapshots: t.Optional[t.List[SnapshotId]] = None,
        execution_time: t.Optional[TimeLike] = None,
    ) -> None:
        request = common.PlanApplicationRequest(
            new_snapshots=list(new_snapshots),
            environment=environment,
            no_gaps=no_gaps,
            skip_backfill=skip_backfill,
            request_id=request_id,
            restatements={s.name: i for s, i in (restatements or {}).items()},
            notification_targets=notification_targets or [],
            backfill_concurrent_tasks=backfill_concurrent_tasks,
            ddl_concurrent_tasks=ddl_concurrent_tasks,
            users=users or [],
            is_dev=is_dev,
            forward_only=forward_only,
            models_to_backfill=models_to_backfill,
            end_bounded=end_bounded,
            ensure_finalized_snapshots=ensure_finalized_snapshots,
            directly_modified_snapshots=directly_modified_snapshots or [],
            indirectly_modified_snapshots=indirectly_modified_snapshots or {},
            removed_snapshots=removed_snapshots or [],
            execution_time=execution_time,
        )

        response = self._session.post(
            urljoin(self._airflow_url, PLANS_PATH),
            data=request.json(),
        )
        raise_for_status(response)

    def get_snapshots(
        self, snapshot_ids: t.Optional[t.List[SnapshotId]], hydrate_seeds: bool = False
    ) -> t.List[Snapshot]:
        flags = ["hydrate_seeds"] if hydrate_seeds else []

        output = []

        if snapshot_ids is not None:
            for ids_batch in _list_to_json(
                unique(snapshot_ids), batch_size=self._snapshot_ids_batch_size
            ):
                output.extend(
                    common.SnapshotsResponse.parse_obj(
                        self._get(SNAPSHOTS_PATH, *flags, ids=ids_batch)
                    ).snapshots
                )
            return output

        return common.SnapshotsResponse.parse_obj(self._get(SNAPSHOTS_PATH, *flags)).snapshots

    def snapshots_exist(self, snapshot_ids: t.List[SnapshotId]) -> t.Set[SnapshotId]:
        output = set()
        for ids_batch in _list_to_json(
            unique(snapshot_ids), batch_size=self._snapshot_ids_batch_size
        ):
            output |= set(
                common.SnapshotIdsResponse.parse_obj(
                    self._get(SNAPSHOTS_PATH, "check_existence", ids=ids_batch)
                ).snapshot_ids
            )

        return output

    def nodes_exist(self, names: t.Iterable[str], exclude_external: bool = False) -> t.Set[str]:
        flags = ["exclude_external"] if exclude_external else []
        return set(
            common.ExistingModelsResponse.parse_obj(
                self._get(MODELS_PATH, *flags, names=",".join(names))
            ).names
        )

    def get_environment(self, environment: str) -> t.Optional[Environment]:
        try:
            response = self._get(f"{ENVIRONMENTS_PATH}/{environment}")
            return Environment.parse_obj(response)
        except NotFoundError:
            return None

    def get_environments(self) -> t.List[Environment]:
        response = self._get(ENVIRONMENTS_PATH)
        return common.EnvironmentsResponse.parse_obj(response).environments

    def max_interval_end_for_environment(
        self, environment: str, ensure_finalized_snapshots: bool
    ) -> t.Optional[int]:
        flags = ["ensure_finalized_snapshots"] if ensure_finalized_snapshots else []
        response = self._get(f"{ENVIRONMENTS_PATH}/{environment}/max_interval_end", *flags)
        return common.IntervalEndResponse.parse_obj(response).max_interval_end

    def greatest_common_interval_end(
        self, environment: str, models: t.Collection[str], ensure_finalized_snapshots: bool
    ) -> t.Optional[int]:
        flags = ["ensure_finalized_snapshots"] if ensure_finalized_snapshots else []
        response = self._get(
            f"{ENVIRONMENTS_PATH}/{environment}/greatest_common_interval_end",
            *flags,
            models=_json_query_param(list(models)),
        )
        return common.IntervalEndResponse.parse_obj(response).max_interval_end

    def invalidate_environment(self, environment: str) -> None:
        response = self._session.delete(
            urljoin(self._airflow_url, f"{ENVIRONMENTS_PATH}/{environment}")
        )
        raise_for_status(response)

    def get_versions(self) -> Versions:
        return Versions.parse_obj(self._get(VERSIONS_PATH))

    def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
        url = f"{DAG_RUN_PATH_TEMPLATE.format(dag_id)}/{dag_run_id}"
        return self._get(url)["state"].lower()

    def get_janitor_dag(self) -> t.Dict[str, t.Any]:
        return self._get_dag(common.JANITOR_DAG_ID)

    def get_snapshot_dag(self, name: str, version: str) -> t.Dict[str, t.Any]:
        return self._get_dag(common.dag_id_for_name_version(name, version))

    def get_all_dags(self) -> t.Dict[str, t.Any]:
        return self._get("api/v1/dags")

    def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:
        dag_runs_response = self._get(f"{DAG_RUN_PATH_TEMPLATE.format(dag_id)}", limit="1")
        dag_runs = dag_runs_response["dag_runs"]
        if not dag_runs:
            return None
        return dag_runs[0]["dag_run_id"]

    def get_variable(self, key: str) -> t.Optional[str]:
        try:
            variables_response = self._get(f"api/v1/variables/{key}")
            return variables_response["value"]
        except NotFoundError:
            return None

    def close(self) -> None:
        self._session.close()

    def _get_dag(self, dag_id: str) -> t.Dict[str, t.Any]:
        return self._get(f"api/v1/dags/{dag_id}")

    def _get(self, path: str, *flags: str, **params: str) -> t.Dict[str, t.Any]:
        all_params = [*flags, *([urlencode(params)] if params else [])]
        query_string = "&".join(all_params)
        if query_string:
            path = f"{path}?{query_string}"
        response = self._session.get(urljoin(self._airflow_url, path))
        raise_for_status(response)
        return response.json()


T = t.TypeVar("T", bound=PydanticModel)


def _list_to_json(models: t.Collection[T], batch_size: t.Optional[int] = None) -> t.List[str]:
    serialized = [m.dict() for m in models]
    if batch_size is not None:
        batches = [serialized[i : i + batch_size] for i in range(0, len(serialized), batch_size)]
    else:
        batches = [serialized]
    return [_json_query_param(batch) for batch in batches]


def _json_query_param(value: t.Any) -> str:
    return json.dumps(value, separators=(",", ":"))
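Note how `get_snapshots` and `snapshots_exist` funnel snapshot IDs through `_list_to_json`, which serializes them and, when `snapshot_ids_batch_size` is set, splits them into multiple compact JSON query parameters so a single request URL cannot grow without bound. A minimal self-contained sketch of that batching behavior, using plain dicts in place of `PydanticModel` instances:

import json
import typing as t


def batch_to_json(items: t.List[dict], batch_size: t.Optional[int] = None) -> t.List[str]:
    # Mirrors _list_to_json: one compact JSON string per batch.
    if batch_size is not None:
        batches = [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
    else:
        batches = [items]
    return [json.dumps(batch, separators=(",", ":")) for batch in batches]


ids = [{"name": f"model_{i}", "identifier": str(i)} for i in range(5)]
for param in batch_to_json(ids, batch_size=2):
    # Each string becomes the value of a single `ids` query parameter.
    print(param)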
class BaseAirflowClient(abc.ABC):
Abstract base client for interacting with an Airflow cluster. Concrete subclasses supply the low-level lookups (`get_first_dag_run_id`, `get_dag_run_state`, `get_variable`); the base class builds the polling, tracking-URL, and `default_catalog` helpers on top of them.
def print_tracking_url(self, dag_id: str, dag_run_id: str, op_name: str) -> None:
def wait_for_dag_run_completion(
    self, dag_id: str, dag_run_id: str, poll_interval_secs: int
) -> bool:
Blocks until the given DAG Run completes.
Arguments:
- dag_id: The DAG ID.
- dag_run_id: The DAG Run ID.
- poll_interval_secs: The number of seconds to wait between polling for the DAG Run state.
Returns:
True if the DAG Run completed successfully, False otherwise.
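A hedged usage sketch: the URL, credentials, DAG ID, and run ID below are all placeholders, and the session setup assumes a basic-auth Airflow deployment.

import requests

from sqlmesh.schedulers.airflow.client import AirflowClient

session = requests.Session()
session.auth = ("airflow", "airflow")  # assumption: basic-auth deployment

client = AirflowClient(session=session, airflow_url="http://localhost:8080/")

# Block until the run reaches a terminal state, checking every 10 seconds.
succeeded = client.wait_for_dag_run_completion(
    dag_id="example_dag",  # placeholder DAG ID
    dag_run_id="manual__2024-01-01T00:00:00+00:00",  # placeholder run ID
    poll_interval_secs=10,
)
if not succeeded:
    raise RuntimeError("DAG Run finished in the 'failed' state")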
def wait_for_first_dag_run(self, dag_id: str, poll_interval_secs: int, max_retries: int) -> str:
Blocks until the first DAG Run for the given DAG ID is created.
Arguments:
- dag_id: The DAG ID.
- poll_interval_secs: The number of seconds to wait between polling for the DAG Run.
- max_retries: The maximum number of retries.
Returns:
The ID of the first DAG Run for the given DAG ID.
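Note that `ApiServerError` is re-raised immediately; only the "missing DAG Run" condition is retried, so the worst-case wait before the `SQLMeshError` propagates is roughly max_retries * poll_interval_secs seconds. A sketch of the typical "wait for the run, then track it" pattern, reusing the `client` from the previous sketch:

dag_id = "example_dag"  # placeholder DAG ID

# Poll every 5 seconds; give up after 12 failed attempts (~60 seconds).
run_id = client.wait_for_first_dag_run(dag_id, poll_interval_secs=5, max_retries=12)

# Log a tracking link (a no-op when no console was configured).
client.print_tracking_url(dag_id, run_id, op_name="backfill")

succeeded = client.wait_for_dag_run_completion(dag_id, run_id, poll_interval_secs=10)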
@abc.abstractmethod
def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:
Returns the ID of the first DAG Run for the given DAG ID, or None if no DAG Runs exist.
Arguments:
- dag_id: The DAG ID.
Returns:
The ID of the first DAG Run for the given DAG ID, or None if no DAG Runs exist.
@abc.abstractmethod
def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
Returns the state of the given DAG Run.
Arguments:
- dag_id: The DAG ID.
- dag_run_id: The DAG Run ID.
Returns:
The state of the given DAG Run.
@abc.abstractmethod
def get_variable(self, key: str) -> t.Optional[str]:
Returns the value of an Airflow variable with the given key.
Arguments:
- key: The variable key.
Returns:
The variable value or None if no variable with the given key exists.
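These three methods are the only abstract members, so a concrete subclass needs nothing else. A hypothetical in-memory stub (not part of SQLMesh), which can be handy for unit-testing the polling logic without a live Airflow cluster:

import typing as t

from sqlmesh.schedulers.airflow.client import BaseAirflowClient


class StubAirflowClient(BaseAirflowClient):
    """Hypothetical stub backed by in-memory dictionaries."""

    def __init__(self) -> None:
        super().__init__(airflow_url="http://localhost:8080/", console=None)
        self._dag_runs: t.Dict[str, t.List[str]] = {"example_dag": ["run_1"]}
        self._states: t.Dict[t.Tuple[str, str], str] = {("example_dag", "run_1"): "success"}
        self._variables: t.Dict[str, str] = {"some_key": "some_value"}

    def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:
        runs = self._dag_runs.get(dag_id)
        return runs[0] if runs else None

    def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
        return self._states[(dag_id, dag_run_id)]

    def get_variable(self, key: str) -> t.Optional[str]:
        return self._variables.get(key)


client = StubAirflowClient()
assert client.wait_for_dag_run_completion("example_dag", "run_1", poll_interval_secs=1)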
class AirflowClient(BaseAirflowClient):
An HTTP implementation of `BaseAirflowClient` that talks to the Airflow REST API (`api/v1/...`) and the SQLMesh Airflow plugin endpoints through a caller-supplied `requests.Session`.
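The client does not manage authentication or retries itself; everything rides on the `requests.Session` passed in. A sketch of wiring up a session (the URL and credentials are placeholders, and basic auth plus HTTP-level retries are assumptions about the deployment):

import requests
from requests.adapters import HTTPAdapter

from sqlmesh.schedulers.airflow.client import AirflowClient

session = requests.Session()
session.auth = ("airflow", "airflow")  # assumption: basic-auth deployment
session.mount("http://", HTTPAdapter(max_retries=3))  # retry transient failures

client = AirflowClient(
    session=session,
    airflow_url="http://localhost:8080/",
    snapshot_ids_batch_size=100,  # split large snapshot-ID lookups into batches
)

print(client.get_versions())
for env in client.get_environments():
    print(env.name)
client.close()  # closes the underlying session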
def __init__(
    self,
    session: requests.Session,
    airflow_url: str,
    console: t.Optional[Console] = None,
    snapshot_ids_batch_size: t.Optional[int] = None,
):
def apply_plan(
    self,
    new_snapshots: t.Iterable[Snapshot],
    environment: Environment,
    request_id: str,
    no_gaps: bool = False,
    skip_backfill: bool = False,
    restatements: t.Optional[t.Dict[SnapshotId, Interval]] = None,
    notification_targets: t.Optional[t.List[NotificationTarget]] = None,
    backfill_concurrent_tasks: int = 1,
    ddl_concurrent_tasks: int = 1,
    users: t.Optional[t.List[User]] = None,
    is_dev: bool = False,
    forward_only: bool = False,
    models_to_backfill: t.Optional[t.Set[str]] = None,
    end_bounded: bool = False,
    ensure_finalized_snapshots: bool = False,
    directly_modified_snapshots: t.Optional[t.List[SnapshotId]] = None,
    indirectly_modified_snapshots: t.Optional[t.Dict[str, t.List[SnapshotId]]] = None,
    removed_snapshots: t.Optional[t.List[SnapshotId]] = None,
    execution_time: t.Optional[TimeLike] = None,
) -> None:
def get_snapshots(
    self, snapshot_ids: t.Optional[t.List[SnapshotId]], hydrate_seeds: bool = False
) -> t.List[Snapshot]:
def snapshots_exist(self, snapshot_ids: t.List[SnapshotId]) -> t.Set[SnapshotId]:
def max_interval_end_for_environment(
    self, environment: str, ensure_finalized_snapshots: bool
) -> t.Optional[int]:
def greatest_common_interval_end(
    self, environment: str, models: t.Collection[str], ensure_finalized_snapshots: bool
) -> t.Optional[int]:
def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
Returns the state of the given DAG Run.
Arguments:
- dag_id: The DAG ID.
- dag_run_id: The DAG Run ID.
Returns:
The state of the given DAG Run.
def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:
Returns the ID of the first DAG Run for the given DAG ID, or None if no DAG Runs exist.
Arguments:
- dag_id: The DAG ID.
Returns:
The ID of the first DAG Run for the given DAG ID, or None if no DAG Runs exist.
def get_variable(self, key: str) -> t.Optional[str]:
Returns the value of an Airflow variable with the given key.
Arguments:
- key: The variable key.
Returns:
The variable value or None if no variable with the given key exists.
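The `default_catalog` property on the base class is built on this method: it reads the `common.DEFAULT_CATALOG_VARIABLE_NAME` Airflow variable and raises `SQLMeshError` when that variable is unset. A sketch, reusing the `client` from the earlier sketch:

from sqlmesh.utils.errors import SQLMeshError

# Raw variable access: returns None when the key does not exist.
value = client.get_variable("some_airflow_variable")  # placeholder key

# default_catalog wraps get_variable and raises if the variable was
# never configured on the cluster.
try:
    catalog = client.default_catalog
except SQLMeshError as e:
    print(f"default catalog not configured: {e}")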