sqlmesh.schedulers.airflow.mwaa_client
1from __future__ import annotations 2 3import base64 4import json 5import logging 6import typing as t 7from urllib.parse import urljoin 8 9from requests import Session 10 11from sqlmesh.core.console import Console 12from sqlmesh.schedulers.airflow.client import BaseAirflowClient, raise_for_status 13from sqlmesh.utils.date import now_timestamp 14from sqlmesh.utils.errors import NotFoundError 15 16logger = logging.getLogger(__name__) 17 18 19TOKEN_TTL_MS = 30 * 1000 20 21 22class MWAAClient(BaseAirflowClient): 23 def __init__(self, environment: str, console: t.Optional[Console] = None): 24 airflow_url, auth_token = url_and_auth_token_for_environment(environment) 25 super().__init__(airflow_url, console) 26 27 self._environment = environment 28 self._last_token_refresh_ts = now_timestamp() 29 self.__session: Session = _create_session(auth_token) 30 31 def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]: 32 dag_runs = self._list_dag_runs(dag_id) 33 if dag_runs: 34 return dag_runs[-1]["run_id"] 35 return None 36 37 def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str: 38 dag_runs = self._list_dag_runs(dag_id) or [] 39 for dag_run in dag_runs: 40 if dag_run["run_id"] == dag_run_id: 41 return dag_run["state"].lower() 42 raise NotFoundError(f"DAG run '{dag_run_id}' was not found for DAG '{dag_id}'") 43 44 def get_variable(self, key: str) -> t.Optional[str]: 45 stdout, stderr = self._post(f"variables get {key}") 46 if "does not exist" in stderr: 47 return None 48 return stdout 49 50 def _list_dag_runs(self, dag_id: str) -> t.Optional[t.List[t.Dict[str, t.Any]]]: 51 stdout, stderr = self._post(f"dags list-runs -o json -d {dag_id}") 52 if stdout: 53 return json.loads(stdout) 54 return None 55 56 def _post(self, data: str) -> t.Tuple[str, str]: 57 response = self._session.post(urljoin(self._airflow_url, "aws_mwaa/cli"), data=data) 58 raise_for_status(response) 59 response_body = response.json() 60 61 cli_stdout = 
base64.b64decode(response_body["stdout"]).decode("utf8").strip() 62 cli_stderr = base64.b64decode(response_body["stderr"]).decode("utf8").strip() 63 return cli_stdout, cli_stderr 64 65 @property 66 def _session(self) -> Session: 67 current_ts = now_timestamp() 68 if current_ts - self._last_token_refresh_ts > TOKEN_TTL_MS: 69 _, auth_token = url_and_auth_token_for_environment(self._environment) 70 self.__session = _create_session(auth_token) 71 self._last_token_refresh_ts = current_ts 72 return self.__session 73 74 75def _create_session(auth_token: str) -> Session: 76 session = Session() 77 session.headers.update({"Authorization": f"Bearer {auth_token}", "Content-Type": "text/plain"}) 78 return session 79 80 81def url_and_auth_token_for_environment(environment_name: str) -> t.Tuple[str, str]: 82 import boto3 83 84 logger.info("Fetching the MWAA CLI token") 85 86 client = boto3.client("mwaa") 87 cli_token = client.create_cli_token(Name=environment_name) 88 89 url = f"https://{cli_token['WebServerHostname']}/" 90 auth_token = cli_token["CliToken"] 91 return url, auth_token
class MWAAClient(BaseAirflowClient):
    """Airflow client that runs Airflow CLI commands via the MWAA `aws_mwaa/cli` endpoint."""

    def __init__(self, environment: str, console: t.Optional[Console] = None):
        """Fetch the web server URL and a CLI token for the environment and open a session.

        Args:
            environment: The MWAA environment name passed to `create_cli_token`.
            console: Optional console forwarded to the base client.
        """
        airflow_url, auth_token = url_and_auth_token_for_environment(environment)
        super().__init__(airflow_url, console)

        self._environment = environment
        self._last_token_refresh_ts = now_timestamp()
        self.__session: Session = _create_session(auth_token)

    def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:
        """Return the ID of the first DAG Run for the given DAG ID, or None if there are none.

        The last entry of the listed runs is used; this presumably relies on the CLI
        listing runs newest-first (verify against the Airflow version in use).

        Args:
            dag_id: The DAG ID.
        """
        dag_runs = self._list_dag_runs(dag_id)
        if dag_runs:
            return dag_runs[-1]["run_id"]
        return None

    def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
        """Return the lower-cased state of the given DAG Run.

        Args:
            dag_id: The DAG ID.
            dag_run_id: The DAG Run ID.

        Raises:
            NotFoundError: If no listed run of the DAG has the given run ID.
        """
        dag_runs = self._list_dag_runs(dag_id) or []
        for dag_run in dag_runs:
            if dag_run["run_id"] == dag_run_id:
                return dag_run["state"].lower()
        raise NotFoundError(f"DAG run '{dag_run_id}' was not found for DAG '{dag_id}'")

    def get_variable(self, key: str) -> t.Optional[str]:
        """Return the value of the Airflow variable with the given key.

        Args:
            key: The variable key.

        Returns:
            The CLI's stdout, or None when its stderr reports that the variable
            does not exist.
        """
        stdout, stderr = self._post(f"variables get {key}")
        if "does not exist" in stderr:
            return None
        return stdout

    def _list_dag_runs(self, dag_id: str) -> t.Optional[t.List[t.Dict[str, t.Any]]]:
        """Return the JSON-decoded output of `dags list-runs`, or None if stdout is empty."""
        stdout, _ = self._post(f"dags list-runs -o json -d {dag_id}")
        if stdout:
            return json.loads(stdout)
        return None

    def _post(self, data: str) -> t.Tuple[str, str]:
        """Send a CLI command to the `aws_mwaa/cli` endpoint.

        Args:
            data: The Airflow CLI command line, sent as the request body.

        Returns:
            The decoded and stripped (stdout, stderr) of the command.
        """
        response = self._session.post(urljoin(self._airflow_url, "aws_mwaa/cli"), data=data)
        raise_for_status(response)
        response_body = response.json()

        # Both streams arrive base64-encoded in the JSON response body.
        cli_stdout = base64.b64decode(response_body["stdout"]).decode("utf8").strip()
        cli_stderr = base64.b64decode(response_body["stderr"]).decode("utf8").strip()
        return cli_stdout, cli_stderr

    @property
    def _session(self) -> Session:
        """Return the HTTP session, replacing it once the token is older than TOKEN_TTL_MS."""
        current_ts = now_timestamp()
        if current_ts - self._last_token_refresh_ts > TOKEN_TTL_MS:
            # Fetch the new token first so a failure leaves the current session intact.
            _, auth_token = url_and_auth_token_for_environment(self._environment)
            # Close the session being replaced so its pooled connections are released
            # instead of lingering until garbage collection.
            self.__session.close()
            self.__session = _create_session(auth_token)
            self._last_token_refresh_ts = current_ts
        return self.__session
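Airflow client for Amazon MWAA: it sends Airflow CLI commands to the environment's `aws_mwaa/cli` endpoint using a CLI token obtained through boto3's `create_cli_token`, and requests a fresh token once the cached one is older than `TOKEN_TTL_MS`. (The class defines no docstring of its own; the text shown here appears to be inherited from `abc.ABC`: "Helper class that provides a standard way to create an ABC using inheritance.")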
MWAAClient( environment: str, console: Union[sqlmesh.core.console.Console, NoneType] = None)
def __init__(self, environment: str, console: t.Optional[Console] = None):
    """Resolve the environment's web server URL and CLI token, then open a session.

    Args:
        environment: The MWAA environment name used to request the CLI token.
        console: Optional console forwarded to the base client.
    """
    # One call yields both the base URL for the parent client and the first token.
    base_url, initial_token = url_and_auth_token_for_environment(environment)
    super().__init__(base_url, console)

    self._environment = environment
    # Remember when the token was obtained so its age can be checked later.
    self._last_token_refresh_ts = now_timestamp()
    self.__session: Session = _create_session(initial_token)
def
get_first_dag_run_id(self, dag_id: str) -> Union[str, NoneType]:
def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:
    """Return the ID of the first DAG Run for the given DAG ID, or None if none exist.

    Args:
        dag_id: The DAG ID.

    Returns:
        The `run_id` of the last entry in the listed runs, or None when the
        listing is empty or missing.
    """
    runs = self._list_dag_runs(dag_id)
    if not runs:
        return None
    # The last listed entry is treated as the first run (presumably the listing
    # is ordered newest-first; confirm against the Airflow CLI in use).
    return runs[-1]["run_id"]
Returns the ID of the first DAG Run for the given DAG ID, or None if no DAG Runs exist.
Arguments:
- dag_id: The DAG ID.
Returns:
The ID of the first DAG Run for the given DAG ID, or None if no DAG Runs exist.
def
get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
    """Return the lower-cased state of the given DAG Run.

    Args:
        dag_id: The DAG ID.
        dag_run_id: The DAG Run ID.

    Returns:
        The state of the first listed run whose `run_id` matches, lower-cased.

    Raises:
        NotFoundError: If no listed run of the DAG has the given run ID.
    """
    # A missing listing is treated the same as an empty one.
    runs = self._list_dag_runs(dag_id) or []
    match = next((run for run in runs if run["run_id"] == dag_run_id), None)
    if match is None:
        raise NotFoundError(f"DAG run '{dag_run_id}' was not found for DAG '{dag_id}'")
    return match["state"].lower()
Returns the state of the given DAG Run.
Arguments:
- dag_id: The DAG ID.
- dag_run_id: The DAG Run ID.
Returns:
The state of the given DAG Run.
def
get_variable(self, key: str) -> Union[str, NoneType]:
def get_variable(self, key: str) -> t.Optional[str]:
    """Return the value of the Airflow variable with the given key.

    Args:
        key: The variable key.

    Returns:
        The command's stdout, or None when its stderr reports that the
        variable does not exist.
    """
    output, errors = self._post(f"variables get {key}")
    # A missing variable is recognised by the message on stderr.
    return None if "does not exist" in errors else output
Returns the value of an Airflow variable with the given key.
Arguments:
- key: The variable key.
Returns:
The variable value or None if no variable with the given key exists.
def
url_and_auth_token_for_environment(environment_name: str) -> Tuple[str, str]:
def url_and_auth_token_for_environment(environment_name: str) -> t.Tuple[str, str]:
    """Request a CLI token for the given MWAA environment.

    Args:
        environment_name: The MWAA environment name.

    Returns:
        A tuple of the web server URL (`https://<hostname>/`) and the CLI token.
    """
    # Imported inside the function, so boto3 is only required when it is called.
    import boto3

    logger.info("Fetching the MWAA CLI token")

    token_response = boto3.client("mwaa").create_cli_token(Name=environment_name)

    hostname = token_response["WebServerHostname"]
    return f"https://{hostname}/", token_response["CliToken"]