Edit on GitHub

sqlmesh.schedulers.airflow.mwaa_client

 1from __future__ import annotations
 2
 3import base64
 4import json
 5import logging
 6import typing as t
 7from urllib.parse import urljoin
 8
 9from requests import Session
10
11from sqlmesh.core.console import Console
12from sqlmesh.schedulers.airflow.client import BaseAirflowClient, raise_for_status
13from sqlmesh.utils.date import now_timestamp
14from sqlmesh.utils.errors import NotFoundError
15
16logger = logging.getLogger(__name__)
17
18
19TOKEN_TTL_MS = 30 * 1000
20
21
class MWAAClient(BaseAirflowClient):
    """Airflow client for an MWAA environment, driven through its CLI endpoint.

    Each operation is sent as an Airflow CLI command string to the
    ``aws_mwaa/cli`` path of the environment's web server, authenticated with
    the CLI token returned by ``url_and_auth_token_for_environment``. A new
    token is fetched once more than ``TOKEN_TTL_MS`` has elapsed since the
    previous fetch.
    """

    def __init__(self, environment: str, console: t.Optional[Console] = None):
        """Fetch the web server URL and an initial CLI token for the environment.

        Args:
            environment: Name of the MWAA environment.
            console: Optional console forwarded to the base client.
        """
        airflow_url, auth_token = url_and_auth_token_for_environment(environment)
        super().__init__(airflow_url, console)

        self._environment = environment
        self._last_token_refresh_ts = now_timestamp()
        self.__session: Session = _create_session(auth_token)

    def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:
        """Return the ID of the first DAG Run for the given DAG ID.

        Args:
            dag_id: The DAG ID.

        Returns:
            The run ID, or None if no DAG Runs exist.
        """
        dag_runs = self._list_dag_runs(dag_id)
        if dag_runs:
            # The last entry of the listing is taken as the first run;
            # presumably the CLI lists runs newest-first -- verify.
            return dag_runs[-1]["run_id"]
        return None

    def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
        """Return the lower-cased state of the given DAG Run.

        Args:
            dag_id: The DAG ID.
            dag_run_id: The DAG Run ID.

        Returns:
            The state of the matching DAG Run, lower-cased.

        Raises:
            NotFoundError: If the listing contains no run with this ID.
        """
        dag_runs = self._list_dag_runs(dag_id) or []
        for dag_run in dag_runs:
            if dag_run["run_id"] == dag_run_id:
                return dag_run["state"].lower()
        raise NotFoundError(f"DAG run '{dag_run_id}' was not found for DAG '{dag_id}'")

    def get_variable(self, key: str) -> t.Optional[str]:
        """Return the value of the Airflow variable with the given key.

        Args:
            key: The variable key.

        Returns:
            The command's stdout, or None when its stderr reports that the
            variable does not exist.
        """
        stdout, stderr = self._post(f"variables get {key}")
        if "does not exist" in stderr:
            return None
        return stdout

    def _list_dag_runs(self, dag_id: str) -> t.Optional[t.List[t.Dict[str, t.Any]]]:
        """Run ``dags list-runs`` for the DAG and parse its JSON output.

        Returns:
            The decoded JSON payload, or None when the command printed nothing.
        """
        # Only stdout carries the listing; stderr is deliberately ignored.
        stdout, _ = self._post(f"dags list-runs -o json -d {dag_id}")
        if stdout:
            return json.loads(stdout)
        return None

    def _post(self, data: str) -> t.Tuple[str, str]:
        """Send a CLI command to the ``aws_mwaa/cli`` endpoint.

        Args:
            data: The CLI command string, sent as the request body.

        Returns:
            A ``(stdout, stderr)`` pair: both response fields base64-decoded,
            decoded as UTF-8 and stripped of surrounding whitespace.
        """
        response = self._session.post(urljoin(self._airflow_url, "aws_mwaa/cli"), data=data)
        raise_for_status(response)
        response_body = response.json()

        cli_stdout = base64.b64decode(response_body["stdout"]).decode("utf8").strip()
        cli_stderr = base64.b64decode(response_body["stderr"]).decode("utf8").strip()
        return cli_stdout, cli_stderr

    @property
    def _session(self) -> Session:
        """The HTTP session, rebuilt with a fresh CLI token once the current one is stale."""
        current_ts = now_timestamp()
        if current_ts - self._last_token_refresh_ts > TOKEN_TTL_MS:
            # Fetch the token first: if this raises, the existing session is kept.
            _, auth_token = url_and_auth_token_for_environment(self._environment)
            stale_session = self.__session
            self.__session = _create_session(auth_token)
            self._last_token_refresh_ts = current_ts
            # Release the connections pooled by the session that was just replaced.
            stale_session.close()
        return self.__session
73
74
def _create_session(auth_token: str) -> Session:
    """Build a session that sends the CLI token as a bearer credential.

    Args:
        auth_token: The CLI token to place in the ``Authorization`` header.

    Returns:
        A new session whose default headers carry the token and a
        ``text/plain`` content type.
    """
    default_headers = {
        "Authorization": f"Bearer {auth_token}",
        "Content-Type": "text/plain",
    }
    http_session = Session()
    http_session.headers.update(default_headers)
    return http_session
79
80
def url_and_auth_token_for_environment(environment_name: str) -> t.Tuple[str, str]:
    """Request a CLI token for the named MWAA environment.

    Args:
        environment_name: Name of the MWAA environment.

    Returns:
        A pair of the web server base URL (``https://<hostname>/``) and the
        CLI token.
    """
    # boto3 is imported at call time, not at module import time.
    import boto3

    logger.info("Fetching the MWAA CLI token")

    token_response = boto3.client("mwaa").create_cli_token(Name=environment_name)

    base_url = f"https://{token_response['WebServerHostname']}/"
    return base_url, token_response["CliToken"]
class MWAAClient(BaseAirflowClient):
    """Airflow client for an MWAA environment, driven through its CLI endpoint.

    Each operation is sent as an Airflow CLI command string to the
    ``aws_mwaa/cli`` path of the environment's web server, authenticated with
    the CLI token returned by ``url_and_auth_token_for_environment``. A new
    token is fetched once more than ``TOKEN_TTL_MS`` has elapsed since the
    previous fetch.
    """

    def __init__(self, environment: str, console: t.Optional[Console] = None):
        """Fetch the web server URL and an initial CLI token for the environment.

        Args:
            environment: Name of the MWAA environment.
            console: Optional console forwarded to the base client.
        """
        airflow_url, auth_token = url_and_auth_token_for_environment(environment)
        super().__init__(airflow_url, console)

        self._environment = environment
        self._last_token_refresh_ts = now_timestamp()
        self.__session: Session = _create_session(auth_token)

    def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:
        """Return the ID of the first DAG Run for the given DAG ID.

        Args:
            dag_id: The DAG ID.

        Returns:
            The run ID, or None if no DAG Runs exist.
        """
        dag_runs = self._list_dag_runs(dag_id)
        if dag_runs:
            # The last entry of the listing is taken as the first run;
            # presumably the CLI lists runs newest-first -- verify.
            return dag_runs[-1]["run_id"]
        return None

    def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
        """Return the lower-cased state of the given DAG Run.

        Args:
            dag_id: The DAG ID.
            dag_run_id: The DAG Run ID.

        Returns:
            The state of the matching DAG Run, lower-cased.

        Raises:
            NotFoundError: If the listing contains no run with this ID.
        """
        dag_runs = self._list_dag_runs(dag_id) or []
        for dag_run in dag_runs:
            if dag_run["run_id"] == dag_run_id:
                return dag_run["state"].lower()
        raise NotFoundError(f"DAG run '{dag_run_id}' was not found for DAG '{dag_id}'")

    def get_variable(self, key: str) -> t.Optional[str]:
        """Return the value of the Airflow variable with the given key.

        Args:
            key: The variable key.

        Returns:
            The command's stdout, or None when its stderr reports that the
            variable does not exist.
        """
        stdout, stderr = self._post(f"variables get {key}")
        if "does not exist" in stderr:
            return None
        return stdout

    def _list_dag_runs(self, dag_id: str) -> t.Optional[t.List[t.Dict[str, t.Any]]]:
        """Run ``dags list-runs`` for the DAG and parse its JSON output.

        Returns:
            The decoded JSON payload, or None when the command printed nothing.
        """
        # Only stdout carries the listing; stderr is deliberately ignored.
        stdout, _ = self._post(f"dags list-runs -o json -d {dag_id}")
        if stdout:
            return json.loads(stdout)
        return None

    def _post(self, data: str) -> t.Tuple[str, str]:
        """Send a CLI command to the ``aws_mwaa/cli`` endpoint.

        Args:
            data: The CLI command string, sent as the request body.

        Returns:
            A ``(stdout, stderr)`` pair: both response fields base64-decoded,
            decoded as UTF-8 and stripped of surrounding whitespace.
        """
        response = self._session.post(urljoin(self._airflow_url, "aws_mwaa/cli"), data=data)
        raise_for_status(response)
        response_body = response.json()

        cli_stdout = base64.b64decode(response_body["stdout"]).decode("utf8").strip()
        cli_stderr = base64.b64decode(response_body["stderr"]).decode("utf8").strip()
        return cli_stdout, cli_stderr

    @property
    def _session(self) -> Session:
        """The HTTP session, rebuilt with a fresh CLI token once the current one is stale."""
        current_ts = now_timestamp()
        if current_ts - self._last_token_refresh_ts > TOKEN_TTL_MS:
            # Fetch the token first: if this raises, the existing session is kept.
            _, auth_token = url_and_auth_token_for_environment(self._environment)
            stale_session = self.__session
            self.__session = _create_session(auth_token)
            self._last_token_refresh_ts = current_ts
            # Release the connections pooled by the session that was just replaced.
            stale_session.close()
        return self.__session

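Airflow client for an MWAA environment. Each operation is sent as an Airflow CLI command to the environment's aws_mwaa/cli endpoint, authenticated with a CLI token that is re-fetched once it is older than TOKEN_TTL_MS.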

MWAAClient( environment: str, console: Union[sqlmesh.core.console.Console, NoneType] = None)
def __init__(self, environment: str, console: t.Optional[Console] = None):
    """Fetch the web server URL and an initial CLI token for the environment.

    Args:
        environment: Name of the MWAA environment.
        console: Optional console forwarded to the base client.
    """
    base_url, cli_token = url_and_auth_token_for_environment(environment)
    super().__init__(base_url, console)

    # Kept so that a new token can be requested for the same environment later.
    self._environment = environment
    self._last_token_refresh_ts = now_timestamp()
    self.__session: Session = _create_session(cli_token)
def get_first_dag_run_id(self, dag_id: str) -> Union[str, NoneType]:
def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:
    """Return the ID of the first DAG Run for the given DAG ID.

    Args:
        dag_id: The DAG ID.

    Returns:
        The run ID taken from the last entry of the run listing, or None if
        no DAG Runs exist.
    """
    runs = self._list_dag_runs(dag_id)
    return runs[-1]["run_id"] if runs else None

Returns the ID of the first DAG Run for the given DAG ID, or None if no DAG Runs exist.

Arguments:
  • dag_id: The DAG ID.
Returns:

The ID of the first DAG Run for the given DAG ID, or None if no DAG Runs exist.

def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
    """Return the lower-cased state of the given DAG Run.

    Args:
        dag_id: The DAG ID.
        dag_run_id: The DAG Run ID.

    Returns:
        The state of the matching DAG Run, lower-cased.

    Raises:
        NotFoundError: If the listing contains no run with this ID.
    """
    for run in self._list_dag_runs(dag_id) or []:
        if run["run_id"] != dag_run_id:
            continue
        return run["state"].lower()
    raise NotFoundError(f"DAG run '{dag_run_id}' was not found for DAG '{dag_id}'")

Returns the state of the given DAG Run.

Arguments:
  • dag_id: The DAG ID.
  • dag_run_id: The DAG Run ID.
Returns:

The state of the given DAG Run.

def get_variable(self, key: str) -> Union[str, NoneType]:
def get_variable(self, key: str) -> t.Optional[str]:
    """Return the value of the Airflow variable with the given key.

    Args:
        key: The variable key.

    Returns:
        The command's stdout, or None when its stderr reports that the
        variable does not exist.
    """
    value, error_output = self._post(f"variables get {key}")
    return None if "does not exist" in error_output else value

Returns the value of an Airflow variable with the given key.

Arguments:
  • key: The variable key.
Returns:

The variable value or None if no variable with the given key exists.

def url_and_auth_token_for_environment(environment_name: str) -> Tuple[str, str]:
def url_and_auth_token_for_environment(environment_name: str) -> t.Tuple[str, str]:
    """Request a CLI token for the named MWAA environment.

    Args:
        environment_name: Name of the MWAA environment.

    Returns:
        A pair of the web server base URL (``https://<hostname>/``) and the
        CLI token.
    """
    # boto3 is imported at call time, not at module import time.
    import boto3

    logger.info("Fetching the MWAA CLI token")

    token_response = boto3.client("mwaa").create_cli_token(Name=environment_name)

    base_url = f"https://{token_response['WebServerHostname']}/"
    return base_url, token_response["CliToken"]