
sqlmesh.schedulers.airflow.client

import abc
import json
import time
import typing as t
import uuid
from urllib.parse import urlencode, urljoin

import requests

from sqlmesh.core.console import Console
from sqlmesh.core.environment import Environment
from sqlmesh.core.notification_target import NotificationTarget
from sqlmesh.core.snapshot import Snapshot, SnapshotId
from sqlmesh.core.snapshot.definition import Interval
from sqlmesh.core.state_sync import Versions
from sqlmesh.core.user import User
from sqlmesh.schedulers.airflow import common
from sqlmesh.utils import unique
from sqlmesh.utils.date import TimeLike
from sqlmesh.utils.errors import (
    ApiServerError,
    NotFoundError,
    SQLMeshError,
    raise_for_status,
)
from sqlmesh.utils.pydantic import PydanticModel

DAG_RUN_PATH_TEMPLATE = "api/v1/dags/{}/dagRuns"


PLANS_PATH = f"{common.SQLMESH_API_BASE_PATH}/plans"
ENVIRONMENTS_PATH = f"{common.SQLMESH_API_BASE_PATH}/environments"
SNAPSHOTS_PATH = f"{common.SQLMESH_API_BASE_PATH}/snapshots"
SEEDS_PATH = f"{common.SQLMESH_API_BASE_PATH}/seeds"
INTERVALS_PATH = f"{common.SQLMESH_API_BASE_PATH}/intervals"
MODELS_PATH = f"{common.SQLMESH_API_BASE_PATH}/models"
VERSIONS_PATH = f"{common.SQLMESH_API_BASE_PATH}/versions"


class BaseAirflowClient(abc.ABC):
    def __init__(self, airflow_url: str, console: t.Optional[Console]):
        self._airflow_url = airflow_url
        self._console = console

    @property
    def default_catalog(self) -> str:
        default_catalog = self.get_variable(common.DEFAULT_CATALOG_VARIABLE_NAME)
        if not default_catalog:
            raise SQLMeshError(
                "Must define `default_catalog` when creating `SQLMeshAirflow` object. See docs for more info: https://sqlmesh.readthedocs.io/en/stable/integrations/airflow/#airflow-cluster-configuration"
            )
        return default_catalog

    def print_tracking_url(self, dag_id: str, dag_run_id: str, op_name: str) -> None:
        if not self._console:
            return

        tracking_url = self.dag_run_tracking_url(dag_id, dag_run_id)
        # TODO: Figure out generalized solution for links
        self._console.log_status_update(
            f"Track [green]{op_name}[/green] progress using [link={tracking_url}]link[/link]"
        )

    def dag_run_tracking_url(self, dag_id: str, dag_run_id: str) -> str:
        url_params = urlencode(
            dict(
                dag_id=dag_id,
                run_id=dag_run_id,
            )
        )
        return urljoin(self._airflow_url, f"dagrun_details?{url_params}")

    def wait_for_dag_run_completion(
        self, dag_id: str, dag_run_id: str, poll_interval_secs: int
    ) -> bool:
        """Blocks until the given DAG Run completes.

        Args:
            dag_id: The DAG ID.
            dag_run_id: The DAG Run ID.
            poll_interval_secs: The number of seconds to wait between polling for the DAG Run state.

        Returns:
            True if the DAG Run completed successfully, False otherwise.
        """
        loading_id = self._console_loading_start()

        while True:
            state = self.get_dag_run_state(dag_id, dag_run_id)
            if state in ("failed", "success"):
                if self._console and loading_id:
                    self._console.loading_stop(loading_id)
                return state == "success"

            time.sleep(poll_interval_secs)

    def wait_for_first_dag_run(self, dag_id: str, poll_interval_secs: int, max_retries: int) -> str:
        """Blocks until the first DAG Run for the given DAG ID is created.

        Args:
            dag_id: The DAG ID.
            poll_interval_secs: The number of seconds to wait between polling for the DAG Run.
            max_retries: The maximum number of retries.

        Returns:
            The ID of the first DAG Run for the given DAG ID.
        """

        loading_id = self._console_loading_start()

        attempt_num = 1

        try:
            while True:
                try:
                    first_dag_run_id = self.get_first_dag_run_id(dag_id)
                    if first_dag_run_id is None:
                        raise SQLMeshError(f"Missing a DAG Run for DAG '{dag_id}'")
                    return first_dag_run_id
                except ApiServerError:
                    raise
                except SQLMeshError:
                    if attempt_num > max_retries:
                        raise

                attempt_num += 1
                time.sleep(poll_interval_secs)
        finally:
            if self._console and loading_id:
                self._console.loading_stop(loading_id)

    @abc.abstractmethod
    def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:
        """Returns the ID of the first DAG Run for the given DAG ID, or None if no DAG Runs exist.

        Args:
            dag_id: The DAG ID.

        Returns:
            The ID of the first DAG Run for the given DAG ID, or None if no DAG Runs exist.
        """

    @abc.abstractmethod
    def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
        """Returns the state of the given DAG Run.

        Args:
            dag_id: The DAG ID.
            dag_run_id: The DAG Run ID.

        Returns:
            The state of the given DAG Run.
        """

    @abc.abstractmethod
    def get_variable(self, key: str) -> t.Optional[str]:
        """Returns the value of an Airflow variable with the given key.

        Args:
            key: The variable key.

        Returns:
            The variable value or None if no variable with the given key exists.
        """

    def _console_loading_start(self) -> t.Optional[uuid.UUID]:
        if self._console:
            return self._console.loading_start()
        return None


class AirflowClient(BaseAirflowClient):
    def __init__(
        self,
        session: requests.Session,
        airflow_url: str,
        console: t.Optional[Console] = None,
        snapshot_ids_batch_size: t.Optional[int] = None,
    ):
        super().__init__(airflow_url, console)
        self._session = session
        self._snapshot_ids_batch_size = snapshot_ids_batch_size

    def apply_plan(
        self,
        new_snapshots: t.Iterable[Snapshot],
        environment: Environment,
        request_id: str,
        no_gaps: bool = False,
        skip_backfill: bool = False,
        restatements: t.Optional[t.Dict[SnapshotId, Interval]] = None,
        notification_targets: t.Optional[t.List[NotificationTarget]] = None,
        backfill_concurrent_tasks: int = 1,
        ddl_concurrent_tasks: int = 1,
        users: t.Optional[t.List[User]] = None,
        is_dev: bool = False,
        forward_only: bool = False,
        models_to_backfill: t.Optional[t.Set[str]] = None,
        end_bounded: bool = False,
        ensure_finalized_snapshots: bool = False,
        directly_modified_snapshots: t.Optional[t.List[SnapshotId]] = None,
        indirectly_modified_snapshots: t.Optional[t.Dict[str, t.List[SnapshotId]]] = None,
        removed_snapshots: t.Optional[t.List[SnapshotId]] = None,
        execution_time: t.Optional[TimeLike] = None,
    ) -> None:
        request = common.PlanApplicationRequest(
            new_snapshots=list(new_snapshots),
            environment=environment,
            no_gaps=no_gaps,
            skip_backfill=skip_backfill,
            request_id=request_id,
            restatements={s.name: i for s, i in (restatements or {}).items()},
            notification_targets=notification_targets or [],
            backfill_concurrent_tasks=backfill_concurrent_tasks,
            ddl_concurrent_tasks=ddl_concurrent_tasks,
            users=users or [],
            is_dev=is_dev,
            forward_only=forward_only,
            models_to_backfill=models_to_backfill,
            end_bounded=end_bounded,
            ensure_finalized_snapshots=ensure_finalized_snapshots,
            directly_modified_snapshots=directly_modified_snapshots or [],
            indirectly_modified_snapshots=indirectly_modified_snapshots or {},
            removed_snapshots=removed_snapshots or [],
            execution_time=execution_time,
        )

        response = self._session.post(
            urljoin(self._airflow_url, PLANS_PATH),
            data=request.json(),
        )
        raise_for_status(response)

    def get_snapshots(
        self, snapshot_ids: t.Optional[t.List[SnapshotId]], hydrate_seeds: bool = False
    ) -> t.List[Snapshot]:
        flags = ["hydrate_seeds"] if hydrate_seeds else []

        output = []

        if snapshot_ids is not None:
            for ids_batch in _list_to_json(
                unique(snapshot_ids), batch_size=self._snapshot_ids_batch_size
            ):
                output.extend(
                    common.SnapshotsResponse.parse_obj(
                        self._get(SNAPSHOTS_PATH, *flags, ids=ids_batch)
                    ).snapshots
                )
            return output

        return common.SnapshotsResponse.parse_obj(self._get(SNAPSHOTS_PATH, *flags)).snapshots

    def snapshots_exist(self, snapshot_ids: t.List[SnapshotId]) -> t.Set[SnapshotId]:
        output = set()
        for ids_batch in _list_to_json(
            unique(snapshot_ids), batch_size=self._snapshot_ids_batch_size
        ):
            output |= set(
                common.SnapshotIdsResponse.parse_obj(
                    self._get(SNAPSHOTS_PATH, "check_existence", ids=ids_batch)
                ).snapshot_ids
            )

        return output

    def nodes_exist(self, names: t.Iterable[str], exclude_external: bool = False) -> t.Set[str]:
        flags = ["exclude_external"] if exclude_external else []
        return set(
            common.ExistingModelsResponse.parse_obj(
                self._get(MODELS_PATH, *flags, names=",".join(names))
            ).names
        )

    def get_environment(self, environment: str) -> t.Optional[Environment]:
        try:
            response = self._get(f"{ENVIRONMENTS_PATH}/{environment}")
            return Environment.parse_obj(response)
        except NotFoundError:
            return None

    def get_environments(self) -> t.List[Environment]:
        response = self._get(ENVIRONMENTS_PATH)
        return common.EnvironmentsResponse.parse_obj(response).environments

    def max_interval_end_for_environment(
        self, environment: str, ensure_finalized_snapshots: bool
    ) -> t.Optional[int]:
        flags = ["ensure_finalized_snapshots"] if ensure_finalized_snapshots else []
        response = self._get(f"{ENVIRONMENTS_PATH}/{environment}/max_interval_end", *flags)
        return common.IntervalEndResponse.parse_obj(response).max_interval_end

    def greatest_common_interval_end(
        self, environment: str, models: t.Collection[str], ensure_finalized_snapshots: bool
    ) -> t.Optional[int]:
        flags = ["ensure_finalized_snapshots"] if ensure_finalized_snapshots else []
        response = self._get(
            f"{ENVIRONMENTS_PATH}/{environment}/greatest_common_interval_end",
            *flags,
            models=_json_query_param(list(models)),
        )
        return common.IntervalEndResponse.parse_obj(response).max_interval_end

    def invalidate_environment(self, environment: str) -> None:
        response = self._session.delete(
            urljoin(self._airflow_url, f"{ENVIRONMENTS_PATH}/{environment}")
        )
        raise_for_status(response)

    def get_versions(self) -> Versions:
        return Versions.parse_obj(self._get(VERSIONS_PATH))

    def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
        url = f"{DAG_RUN_PATH_TEMPLATE.format(dag_id)}/{dag_run_id}"
        return self._get(url)["state"].lower()

    def get_janitor_dag(self) -> t.Dict[str, t.Any]:
        return self._get_dag(common.JANITOR_DAG_ID)

    def get_snapshot_dag(self, name: str, version: str) -> t.Dict[str, t.Any]:
        return self._get_dag(common.dag_id_for_name_version(name, version))

    def get_all_dags(self) -> t.Dict[str, t.Any]:
        return self._get("api/v1/dags")

    def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:
        dag_runs_response = self._get(f"{DAG_RUN_PATH_TEMPLATE.format(dag_id)}", limit="1")
        dag_runs = dag_runs_response["dag_runs"]
        if not dag_runs:
            return None
        return dag_runs[0]["dag_run_id"]

    def get_variable(self, key: str) -> t.Optional[str]:
        try:
            variables_response = self._get(f"api/v1/variables/{key}")
            return variables_response["value"]
        except NotFoundError:
            return None

    def close(self) -> None:
        self._session.close()

    def _get_dag(self, dag_id: str) -> t.Dict[str, t.Any]:
        return self._get(f"api/v1/dags/{dag_id}")

    def _get(self, path: str, *flags: str, **params: str) -> t.Dict[str, t.Any]:
        all_params = [*flags, *([urlencode(params)] if params else [])]
        query_string = "&".join(all_params)
        if query_string:
            path = f"{path}?{query_string}"
        response = self._session.get(urljoin(self._airflow_url, path))
        raise_for_status(response)
        return response.json()


T = t.TypeVar("T", bound=PydanticModel)


def _list_to_json(models: t.Collection[T], batch_size: t.Optional[int] = None) -> t.List[str]:
    serialized = [m.dict() for m in models]
    if batch_size is not None:
        batches = [serialized[i : i + batch_size] for i in range(0, len(serialized), batch_size)]
    else:
        batches = [serialized]
    return [_json_query_param(batch) for batch in batches]


def _json_query_param(value: t.Any) -> str:
    return json.dumps(value, separators=(",", ":"))
class BaseAirflowClient(abc.ABC):

Abstract base class for clients that interact with an Airflow cluster running the SQLMesh integration.

def print_tracking_url(self, dag_id: str, dag_run_id: str, op_name: str) -> None:
Logs a link for tracking the given DAG Run's progress to the attached console, if one is configured.

def dag_run_tracking_url(self, dag_id: str, dag_run_id: str) -> str:
Returns the Airflow UI URL at which the given DAG Run can be tracked.

def wait_for_dag_run_completion(self, dag_id: str, dag_run_id: str, poll_interval_secs: int) -> bool:

Blocks until the given DAG Run completes.

Arguments:
  • dag_id: The DAG ID.
  • dag_run_id: The DAG Run ID.
  • poll_interval_secs: The number of seconds to wait between polling for the DAG Run state.
Returns:
  • True if the DAG Run completed successfully, False otherwise.
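
A minimal usage sketch, assuming client is an instance of a concrete subclass such as the AirflowClient defined below, and that dag_id and dag_run_id are already known; the poll interval is illustrative:

succeeded = client.wait_for_dag_run_completion(dag_id, dag_run_id, poll_interval_secs=10)
if not succeeded:
    print(f"DAG Run '{dag_run_id}' failed")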

def wait_for_first_dag_run(self, dag_id: str, poll_interval_secs: int, max_retries: int) -> str:

Blocks until the first DAG Run for the given DAG ID is created.

Arguments:
  • dag_id: The DAG ID.
  • poll_interval_secs: The number of seconds to wait between polling for the DAG Run.
  • max_retries: The maximum number of retries.
Returns:
  • The ID of the first DAG Run for the given DAG ID.
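
A sketch of the typical pairing with wait_for_dag_run_completion, again assuming a concrete client instance; the intervals, retry count, and operation name are illustrative:

run_id = client.wait_for_first_dag_run(dag_id, poll_interval_secs=5, max_retries=5)
client.print_tracking_url(dag_id, run_id, op_name="plan application")
succeeded = client.wait_for_dag_run_completion(dag_id, run_id, poll_interval_secs=10)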

@abc.abstractmethod
def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:

Returns the ID of the first DAG Run for the given DAG ID, or None if no DAG Runs exist.

Arguments:
  • dag_id: The DAG ID.
Returns:
  • The ID of the first DAG Run for the given DAG ID, or None if no DAG Runs exist.

@abc.abstractmethod
def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:

Returns the state of the given DAG Run.

Arguments:
  • dag_id: The DAG ID.
  • dag_run_id: The DAG Run ID.
Returns:
  • The state of the given DAG Run.

@abc.abstractmethod
def get_variable(self, key: str) -> t.Optional[str]:

Returns the value of an Airflow variable with the given key.

Arguments:
  • key: The variable key.
Returns:
  • The variable value or None if no variable with the given key exists.
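
These three abstract methods are all a subclass needs to implement. A minimal sketch of a hypothetical in-memory stub (illustrative only, e.g. for tests; not part of this module):

import typing as t

from sqlmesh.schedulers.airflow.client import BaseAirflowClient


class InMemoryAirflowClient(BaseAirflowClient):
    # Hypothetical stub that serves DAG Run states and variables from dictionaries.
    def __init__(self, airflow_url: str, variables: t.Dict[str, str]):
        super().__init__(airflow_url, console=None)
        self._variables = variables
        # Maps (dag_id, dag_run_id) -> state, e.g. ("my_dag", "run_1") -> "success".
        self._dag_run_states: t.Dict[t.Tuple[str, str], str] = {}

    def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:
        # "First" here means first by insertion order in this stub.
        run_ids = [run_id for d, run_id in self._dag_run_states if d == dag_id]
        return run_ids[0] if run_ids else None

    def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
        return self._dag_run_states.get((dag_id, dag_run_id), "queued")

    def get_variable(self, key: str) -> t.Optional[str]:
        return self._variables.get(key)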

class AirflowClient(BaseAirflowClient):

An HTTP client for the SQLMesh Airflow integration's REST API, backed by a requests.Session.

AirflowClient(
    session: requests.Session,
    airflow_url: str,
    console: t.Optional[Console] = None,
    snapshot_ids_batch_size: t.Optional[int] = None,
)
Creates a client that communicates with the Airflow REST API through the given requests session.

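A minimal construction sketch; the URL, credentials, and batch size are illustrative placeholders, and authentication should match however your Airflow webserver is configured:

import requests

from sqlmesh.schedulers.airflow.client import AirflowClient

session = requests.Session()
session.auth = ("airflow", "airflow")  # Placeholder basic-auth credentials.

client = AirflowClient(
    session=session,
    airflow_url="http://localhost:8080/",
    snapshot_ids_batch_size=100,  # Optional: caps IDs per request to keep URLs short.
)
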
def apply_plan(
    self,
    new_snapshots: t.Iterable[Snapshot],
    environment: Environment,
    request_id: str,
    no_gaps: bool = False,
    skip_backfill: bool = False,
    restatements: t.Optional[t.Dict[SnapshotId, Interval]] = None,
    notification_targets: t.Optional[t.List[NotificationTarget]] = None,
    backfill_concurrent_tasks: int = 1,
    ddl_concurrent_tasks: int = 1,
    users: t.Optional[t.List[User]] = None,
    is_dev: bool = False,
    forward_only: bool = False,
    models_to_backfill: t.Optional[t.Set[str]] = None,
    end_bounded: bool = False,
    ensure_finalized_snapshots: bool = False,
    directly_modified_snapshots: t.Optional[t.List[SnapshotId]] = None,
    indirectly_modified_snapshots: t.Optional[t.Dict[str, t.List[SnapshotId]]] = None,
    removed_snapshots: t.Optional[t.List[SnapshotId]] = None,
    execution_time: t.Optional[TimeLike] = None,
) -> None:
Serializes a plan application request and POSTs it to the SQLMesh Airflow API's plans endpoint.

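A schematic sketch only: the snapshots and environment passed here would normally be produced by SQLMesh's plan machinery rather than built by hand, so plan_snapshots and target_environment below are placeholders:

import uuid

client.apply_plan(
    new_snapshots=plan_snapshots,       # Placeholder: an iterable of Snapshot objects.
    environment=target_environment,     # Placeholder: the target Environment.
    request_id=str(uuid.uuid4()),
    backfill_concurrent_tasks=4,
    ddl_concurrent_tasks=2,
    is_dev=True,
)
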
def get_snapshots(self, snapshot_ids: t.Optional[t.List[SnapshotId]], hydrate_seeds: bool = False) -> t.List[Snapshot]:
Fetches the given snapshots, or all snapshots when snapshot_ids is None. Seed contents are hydrated when hydrate_seeds is set, and IDs are requested in batches when snapshot_ids_batch_size was provided to the client.

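A usage sketch; some_snapshot_ids is a placeholder list of SnapshotId objects:

# Fetch specific snapshots, hydrating seed model contents.
snapshots = client.get_snapshots(some_snapshot_ids, hydrate_seeds=True)

# Passing None fetches all snapshots in a single request.
all_snapshots = client.get_snapshots(None)
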
def snapshots_exist(self, snapshot_ids: t.List[SnapshotId]) -> t.Set[SnapshotId]:
Returns the subset of the given snapshot IDs that exist, checking IDs in batches when snapshot_ids_batch_size was provided to the client.

def nodes_exist(self, names: Iterable[str], exclude_external: bool = False) -> Set[str]:
Returns the subset of the given node names that exist, optionally excluding external models.

def get_environment(self, environment: str) -> t.Optional[Environment]:
Returns the environment with the given name, or None if no such environment exists.

def get_environments(self) -> t.List[Environment]:
Returns all environments.

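A small sketch combining the two environment getters:

env = client.get_environment("prod")
if env is None:
    print("No environment named 'prod'")

for environment in client.get_environments():
    print(environment.name)
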
def max_interval_end_for_environment(self, environment: str, ensure_finalized_snapshots: bool) -> t.Optional[int]:
Returns the maximum interval end across the snapshots of the given environment, or None. When ensure_finalized_snapshots is set, only snapshots from the environment's last finalized state are considered.

def greatest_common_interval_end(self, environment: str, models: t.Collection[str], ensure_finalized_snapshots: bool) -> t.Optional[int]:
Returns the greatest interval end that all of the given models in the environment have in common, or None. The ensure_finalized_snapshots flag behaves as in max_interval_end_for_environment.

def invalidate_environment(self, environment: str) -> None:
Invalidates the given environment by issuing a DELETE request to its endpoint.

def get_versions(self) -> Versions:
Returns version information from the API server.

def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:

Returns the state of the given DAG Run.

Arguments:
  • dag_id: The DAG ID.
  • dag_run_id: The DAG Run ID.
Returns:
  • The state of the given DAG Run.
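
A usage sketch; the state string comes back lower-cased, so Airflow's run states ("queued", "running", "success", "failed") can be compared directly:

state = client.get_dag_run_state(dag_id, dag_run_id)
if state in ("failed", "success"):
    print(f"DAG Run finished with state '{state}'")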

def get_janitor_dag(self) -> Dict[str, Any]:
Returns metadata for the janitor DAG.

def get_snapshot_dag(self, name: str, version: str) -> Dict[str, Any]:
Returns metadata for the DAG associated with the given snapshot name and version.

def get_all_dags(self) -> Dict[str, Any]:
Returns metadata for all DAGs.

def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:

Returns the ID of the first DAG Run for the given DAG ID, or None if no DAG Runs exist.

Arguments:
  • dag_id: The DAG ID.
Returns:
  • The ID of the first DAG Run for the given DAG ID, or None if no DAG Runs exist.
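
A usage sketch, pairing the run ID with the tracking URL helper:

run_id = client.get_first_dag_run_id(dag_id)
if run_id is None:
    print(f"No DAG Runs exist yet for '{dag_id}'")
else:
    print(client.dag_run_tracking_url(dag_id, run_id))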

def get_variable(self, key: str) -> t.Optional[str]:

Returns the value of an Airflow variable with the given key.

Arguments:
  • key: The variable key.
Returns:
  • The variable value or None if no variable with the given key exists.
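
A usage sketch, reading the variable that backs the default_catalog property:

from sqlmesh.schedulers.airflow import common

catalog = client.get_variable(common.DEFAULT_CATALOG_VARIABLE_NAME)
if catalog is None:
    print("The default catalog variable is not set in Airflow")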

def close(self) -> None:
Closes the underlying HTTP session.

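Since the client owns a requests.Session, a try/finally (sketched below) ensures the session is released even when a call raises:

try:
    versions = client.get_versions()
finally:
    client.close()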