Edit on GitHub

sqlmesh.integrations.dlt

  1import typing as t
  2import click
  3from datetime import datetime, timedelta, timezone
  4from pydantic import ValidationError
  5from sqlglot import exp, parse_one
  6from sqlmesh.core.config.connection import parse_connection_config
  7from sqlmesh.core.context import Context
  8from sqlmesh.utils.date import yesterday_ds
  9
 10
 11def generate_dlt_models_and_settings(
 12    pipeline_name: str,
 13    dialect: str,
 14    tables: t.Optional[t.List[str]] = None,
 15    dlt_path: t.Optional[str] = None,
 16) -> t.Tuple[t.Set[t.Tuple[str, str]], t.Optional[str], str]:
 17    """
 18    This function attaches to a DLT pipeline and retrieves the connection configs and
 19    SQLMesh models based on the tables present in the pipeline's default schema.
 20
 21    Args:
 22        pipeline_name: The name of the DLT pipeline to attach to.
 23        dialect: The SQL dialect to use for generating SQLMesh models.
 24        tables: A list of table names to include.
 25        dlt_path: The path to the directory containing the DLT pipelines.
 26
 27    Returns:
 28        A tuple containing a set of the SQLMesh model definitions, the connection config and the start date.
 29    """
 30
 31    import dlt
 32    from dlt.common.schema.utils import has_table_seen_data, is_complete_column
 33    from dlt.pipeline.exceptions import CannotRestorePipelineException
 34
 35    try:
 36        pipeline = dlt.attach(pipeline_name=pipeline_name, pipelines_dir=dlt_path or "")
 37    except CannotRestorePipelineException:
 38        raise click.ClickException(f"Could not attach to pipeline {pipeline_name}")
 39
 40    schema = pipeline.default_schema
 41    dataset = pipeline.dataset_name
 42
 43    # Get the start date from the load_ids
 44    storage_ids = list(pipeline._get_load_storage().list_loaded_packages())
 45    start_date = get_start_date(storage_ids)
 46
 47    # Get the connection credentials
 48    db_type = pipeline.destination.to_name(pipeline.destination)
 49    if db_type == "filesystem":
 50        connection_config = None
 51    else:
 52        if dlt.__version__ >= "1.10.0":
 53            client = pipeline.destination_client()
 54        else:
 55            client = pipeline._sql_job_client(schema)  # type: ignore
 56        config = client.config
 57        credentials = config.credentials
 58        configs = {
 59            key: value
 60            for key in dir(credentials)
 61            if not key.startswith("_")
 62            and not callable(value := getattr(credentials, key))
 63            and value is not None
 64        }
 65        connection_config = format_config(configs, db_type)
 66
 67    dlt_tables = {
 68        name: table
 69        for name, table in schema.tables.items()
 70        if (
 71            (has_table_seen_data(table) and not name.startswith(schema._dlt_tables_prefix))
 72            or name == schema.loads_table_name
 73        )
 74        and (name in tables if tables else True)
 75    }
 76
 77    sqlmesh_models = set()
 78    for table_name, table in dlt_tables.items():
 79        dlt_columns = {}
 80        primary_key = []
 81
 82        # is_complete_column returns true if column contains a name and a data type
 83        for col in filter(is_complete_column, table["columns"].values()):
 84            dlt_columns[col["name"]] = exp.DataType.build(str(col["data_type"]), dialect=dialect)
 85            if col.get("primary_key"):
 86                primary_key.append(str(col["name"]))
 87
 88        load_id = next(
 89            (col for col in ["_dlt_load_id", "load_id"] if col in dlt_columns),
 90            None,
 91        )
 92        load_key = "c." + load_id if load_id else ""
 93        parent_table = None
 94
 95        # Handling for nested tables: https://dlthub.com/docs/general-usage/destination-tables#nested-tables
 96        if not load_id:
 97            if (
 98                "_dlt_parent_id" in dlt_columns
 99                and (parent_table := table["parent"])
100                and parent_table in dlt_tables
101            ):
102                load_key = "p._dlt_load_id"
103                parent_table = dataset + "." + parent_table
104            else:
105                break
106
107        column_types = [
108            exp.cast(exp.column(column, table="c"), data_type, dialect=dialect)
109            .as_(column)
110            .sql(dialect=dialect)
111            for column, data_type in dlt_columns.items()
112            if isinstance(column, str)
113        ]
114        select_columns = (
115            ",\n".join(f"  {column_name}" for column_name in column_types) if column_types else ""
116        )
117
118        grain = f"\n  grain ({', '.join(primary_key)})," if primary_key else ""
119        incremental_model_name = f"{dataset}_sqlmesh.incremental_{table_name}"
120        incremental_model_sql = generate_incremental_model(
121            incremental_model_name,
122            select_columns,
123            grain,
124            dataset + "." + table_name,
125            dialect,
126            load_key,
127            parent_table,
128        )
129        sqlmesh_models.add((incremental_model_name, incremental_model_sql))
130
131    return sqlmesh_models, connection_config, start_date
132
133
def generate_dlt_models(
    context: Context,
    pipeline_name: str,
    tables: t.List[str],
    force: bool,
    dlt_path: t.Optional[str] = None,
) -> t.List[str]:
    """
    Generate SQLMesh model files for a DLT pipeline inside the project's "models" directory.

    Args:
        context: The SQLMesh context providing the dialect, project path and existing models.
        pipeline_name: The name of the DLT pipeline to attach to.
        tables: The table names to generate models for; when empty no name filter is applied.
        force: When no tables are given, models whose name already exists in the context
            are skipped unless this flag is set.
        dlt_path: The path to the directory containing the DLT pipelines.

    Returns:
        The names of the models that were written, or an empty list if there were none.
    """
    from sqlmesh.cli.project_init import _create_object_files

    generated, _, _ = generate_dlt_models_and_settings(
        pipeline_name=pipeline_name,
        dialect=context.config.dialect or "",
        tables=tables or None,
        dlt_path=dlt_path,
    )

    if not (tables or force):
        # Drop models that are already part of the project
        known_names = [model.name for model in context.models.values()]
        generated = {entry for entry in generated if entry[0] not in known_names}

    if not generated:
        return []

    # Files are named after the last dotted component of the model name
    files = {name.split(".")[-1]: sql for name, sql in generated}
    _create_object_files(context.path / "models", files, "sql")
    return [name for name, _ in generated]
162
163
def generate_incremental_model(
    model_name: str,
    select_columns: str,
    grain: str,
    from_table: str,
    dialect: str,
    load_id: str,
    parent_table: t.Optional[str] = None,
) -> str:
    """
    Generate the SQL definition for an incremental model.

    Args:
        model_name: The name of the model to define.
        select_columns: The pre-rendered, comma separated column projections. May be empty,
            in which case only the load time column is selected.
        grain: The pre-rendered grain property (including its leading newline and trailing
            comma), or an empty string.
        from_table: The table to select from; it is aliased as "c".
        dialect: The SQL dialect used to render the load time expression.
        load_id: The column expression holding the load id, which is cast to DOUBLE and
            converted to a timestamp.
        parent_table: If set, the table joined as "p" on c._dlt_parent_id = p._dlt_id.

    Returns:
        The MODEL definition followed by its SELECT query.
    """

    time_column = parse_one(f"to_timestamp(CAST({load_id} AS DOUBLE))").sql(dialect=dialect)

    from_clause = f"{from_table} as c"
    if parent_table:
        from_clause += f"""\nJOIN
  {parent_table} as p
ON
  c._dlt_parent_id = p._dlt_id"""

    # Only emit the separating comma when there are columns before the load time column;
    # otherwise the query would start with a dangling "," and be invalid SQL.
    projection = f"{select_columns},\n" if select_columns else ""

    return f"""MODEL (
  name {model_name},
  kind INCREMENTAL_BY_TIME_RANGE (
    time_column _dlt_load_time,
  ),{grain}
);

SELECT
{projection}  {time_column} as _dlt_load_time
FROM
  {from_clause}
WHERE
  {time_column} BETWEEN @start_ds AND @end_ds
"""
199
200
def format_config(configs: t.Dict[str, str], db_type: str) -> str:
    """
    Generate a string for the gateway connection config.

    Args:
        configs: The credential fields to include. "username" is emitted as "user" and the
            "password" value is wrapped in double quotes.
        db_type: The connection type, emitted as the "type" field.

    Returns:
        One indented "key: value" line per field, excluding the fields that connection
        config validation reported as invalid.
    """
    config = {
        "type": db_type,
    }

    for key, value in configs.items():
        if key == "password":
            config[key] = f'"{value}"'
        elif key == "username":
            config["user"] = value
        else:
            config[key] = value

    # Validate the connection config fields
    invalid_fields = set()
    try:
        parse_connection_config(config)
    except ValidationError as e:
        for error in e.errors():
            # Errors without a location cannot be attributed to a single field, so they
            # are skipped instead of failing on an empty location.
            location = error.get("loc") or ()
            if location:
                invalid_fields.add(location[0])

    return "\n".join(
        f"      {key}: {value}" for key, value in config.items() if key not in invalid_fields
    )
226
227
def get_start_date(load_ids: t.List[str]) -> str:
    """
    Return the day before the earliest load as 'YYYY-MM-DD'.

    Each load_id is parsed as a float and read as a UTC timestamp. When no load ids are
    given, the result of yesterday_ds() is returned instead.
    """
    loaded_at = [
        datetime.fromtimestamp(float(load_id), tz=timezone.utc) for load_id in load_ids
    ]
    if not loaded_at:
        return yesterday_ds()

    day_before_first_load = min(loaded_at) - timedelta(days=1)
    return day_before_first_load.strftime("%Y-%m-%d")
def generate_dlt_models_and_settings( pipeline_name: str, dialect: str, tables: Optional[List[str]] = None, dlt_path: Optional[str] = None) -> Tuple[Set[Tuple[str, str]], Optional[str], str]:
 12def generate_dlt_models_and_settings(
 13    pipeline_name: str,
 14    dialect: str,
 15    tables: t.Optional[t.List[str]] = None,
 16    dlt_path: t.Optional[str] = None,
 17) -> t.Tuple[t.Set[t.Tuple[str, str]], t.Optional[str], str]:
 18    """
 19    This function attaches to a DLT pipeline and retrieves the connection configs and
 20    SQLMesh models based on the tables present in the pipeline's default schema.
 21
 22    Args:
 23        pipeline_name: The name of the DLT pipeline to attach to.
 24        dialect: The SQL dialect to use for generating SQLMesh models.
 25        tables: A list of table names to include.
 26        dlt_path: The path to the directory containing the DLT pipelines.
 27
 28    Returns:
 29        A tuple containing a set of the SQLMesh model definitions, the connection config and the start date.
 30    """
 31
 32    import dlt
 33    from dlt.common.schema.utils import has_table_seen_data, is_complete_column
 34    from dlt.pipeline.exceptions import CannotRestorePipelineException
 35
 36    try:
 37        pipeline = dlt.attach(pipeline_name=pipeline_name, pipelines_dir=dlt_path or "")
 38    except CannotRestorePipelineException:
 39        raise click.ClickException(f"Could not attach to pipeline {pipeline_name}")
 40
 41    schema = pipeline.default_schema
 42    dataset = pipeline.dataset_name
 43
 44    # Get the start date from the load_ids
 45    storage_ids = list(pipeline._get_load_storage().list_loaded_packages())
 46    start_date = get_start_date(storage_ids)
 47
 48    # Get the connection credentials
 49    db_type = pipeline.destination.to_name(pipeline.destination)
 50    if db_type == "filesystem":
 51        connection_config = None
 52    else:
 53        if dlt.__version__ >= "1.10.0":
 54            client = pipeline.destination_client()
 55        else:
 56            client = pipeline._sql_job_client(schema)  # type: ignore
 57        config = client.config
 58        credentials = config.credentials
 59        configs = {
 60            key: value
 61            for key in dir(credentials)
 62            if not key.startswith("_")
 63            and not callable(value := getattr(credentials, key))
 64            and value is not None
 65        }
 66        connection_config = format_config(configs, db_type)
 67
 68    dlt_tables = {
 69        name: table
 70        for name, table in schema.tables.items()
 71        if (
 72            (has_table_seen_data(table) and not name.startswith(schema._dlt_tables_prefix))
 73            or name == schema.loads_table_name
 74        )
 75        and (name in tables if tables else True)
 76    }
 77
 78    sqlmesh_models = set()
 79    for table_name, table in dlt_tables.items():
 80        dlt_columns = {}
 81        primary_key = []
 82
 83        # is_complete_column returns true if column contains a name and a data type
 84        for col in filter(is_complete_column, table["columns"].values()):
 85            dlt_columns[col["name"]] = exp.DataType.build(str(col["data_type"]), dialect=dialect)
 86            if col.get("primary_key"):
 87                primary_key.append(str(col["name"]))
 88
 89        load_id = next(
 90            (col for col in ["_dlt_load_id", "load_id"] if col in dlt_columns),
 91            None,
 92        )
 93        load_key = "c." + load_id if load_id else ""
 94        parent_table = None
 95
 96        # Handling for nested tables: https://dlthub.com/docs/general-usage/destination-tables#nested-tables
 97        if not load_id:
 98            if (
 99                "_dlt_parent_id" in dlt_columns
100                and (parent_table := table["parent"])
101                and parent_table in dlt_tables
102            ):
103                load_key = "p._dlt_load_id"
104                parent_table = dataset + "." + parent_table
105            else:
106                break
107
108        column_types = [
109            exp.cast(exp.column(column, table="c"), data_type, dialect=dialect)
110            .as_(column)
111            .sql(dialect=dialect)
112            for column, data_type in dlt_columns.items()
113            if isinstance(column, str)
114        ]
115        select_columns = (
116            ",\n".join(f"  {column_name}" for column_name in column_types) if column_types else ""
117        )
118
119        grain = f"\n  grain ({', '.join(primary_key)})," if primary_key else ""
120        incremental_model_name = f"{dataset}_sqlmesh.incremental_{table_name}"
121        incremental_model_sql = generate_incremental_model(
122            incremental_model_name,
123            select_columns,
124            grain,
125            dataset + "." + table_name,
126            dialect,
127            load_key,
128            parent_table,
129        )
130        sqlmesh_models.add((incremental_model_name, incremental_model_sql))
131
132    return sqlmesh_models, connection_config, start_date

This function attaches to a DLT pipeline and retrieves the connection configs and SQLMesh models based on the tables present in the pipeline's default schema.

Arguments:
  • pipeline_name: The name of the DLT pipeline to attach to.
  • dialect: The SQL dialect to use for generating SQLMesh models.
  • tables: A list of table names to include.
  • dlt_path: The path to the directory containing the DLT pipelines.
Returns:

A tuple containing a set of the SQLMesh model definitions, the connection config and the start date.

def generate_dlt_models( context: sqlmesh.core.context.Context, pipeline_name: str, tables: List[str], force: bool, dlt_path: Optional[str] = None) -> List[str]:
def generate_dlt_models(
    context: Context,
    pipeline_name: str,
    tables: t.List[str],
    force: bool,
    dlt_path: t.Optional[str] = None,
) -> t.List[str]:
    """
    Generate SQLMesh model files for a DLT pipeline inside the project's "models" directory.

    Args:
        context: The SQLMesh context providing the dialect, project path and existing models.
        pipeline_name: The name of the DLT pipeline to attach to.
        tables: The table names to generate models for; when empty no name filter is applied.
        force: When no tables are given, models whose name already exists in the context
            are skipped unless this flag is set.
        dlt_path: The path to the directory containing the DLT pipelines.

    Returns:
        The names of the models that were written, or an empty list if there were none.
    """
    from sqlmesh.cli.project_init import _create_object_files

    generated, _, _ = generate_dlt_models_and_settings(
        pipeline_name=pipeline_name,
        dialect=context.config.dialect or "",
        tables=tables or None,
        dlt_path=dlt_path,
    )

    if not (tables or force):
        # Drop models that are already part of the project
        known_names = [model.name for model in context.models.values()]
        generated = {entry for entry in generated if entry[0] not in known_names}

    if not generated:
        return []

    # Files are named after the last dotted component of the model name
    files = {name.split(".")[-1]: sql for name, sql in generated}
    _create_object_files(context.path / "models", files, "sql")
    return [name for name, _ in generated]
def generate_incremental_model( model_name: str, select_columns: str, grain: str, from_table: str, dialect: str, load_id: str, parent_table: Optional[str] = None) -> str:
def generate_incremental_model(
    model_name: str,
    select_columns: str,
    grain: str,
    from_table: str,
    dialect: str,
    load_id: str,
    parent_table: t.Optional[str] = None,
) -> str:
    """
    Generate the SQL definition for an incremental model.

    Args:
        model_name: The name of the model to define.
        select_columns: The pre-rendered, comma separated column projections. May be empty,
            in which case only the load time column is selected.
        grain: The pre-rendered grain property (including its leading newline and trailing
            comma), or an empty string.
        from_table: The table to select from; it is aliased as "c".
        dialect: The SQL dialect used to render the load time expression.
        load_id: The column expression holding the load id, which is cast to DOUBLE and
            converted to a timestamp.
        parent_table: If set, the table joined as "p" on c._dlt_parent_id = p._dlt_id.

    Returns:
        The MODEL definition followed by its SELECT query.
    """

    time_column = parse_one(f"to_timestamp(CAST({load_id} AS DOUBLE))").sql(dialect=dialect)

    from_clause = f"{from_table} as c"
    if parent_table:
        from_clause += f"""\nJOIN
  {parent_table} as p
ON
  c._dlt_parent_id = p._dlt_id"""

    # Only emit the separating comma when there are columns before the load time column;
    # otherwise the query would start with a dangling "," and be invalid SQL.
    projection = f"{select_columns},\n" if select_columns else ""

    return f"""MODEL (
  name {model_name},
  kind INCREMENTAL_BY_TIME_RANGE (
    time_column _dlt_load_time,
  ),{grain}
);

SELECT
{projection}  {time_column} as _dlt_load_time
FROM
  {from_clause}
WHERE
  {time_column} BETWEEN @start_ds AND @end_ds
"""

Generate the SQL definition for an incremental model.

def format_config(configs: Dict[str, str], db_type: str) -> str:
def format_config(configs: t.Dict[str, str], db_type: str) -> str:
    """
    Generate a string for the gateway connection config.

    Args:
        configs: The credential fields to include. "username" is emitted as "user" and the
            "password" value is wrapped in double quotes.
        db_type: The connection type, emitted as the "type" field.

    Returns:
        One indented "key: value" line per field, excluding the fields that connection
        config validation reported as invalid.
    """
    config = {
        "type": db_type,
    }

    for key, value in configs.items():
        if key == "password":
            config[key] = f'"{value}"'
        elif key == "username":
            config["user"] = value
        else:
            config[key] = value

    # Validate the connection config fields
    invalid_fields = set()
    try:
        parse_connection_config(config)
    except ValidationError as e:
        for error in e.errors():
            # Errors without a location cannot be attributed to a single field, so they
            # are skipped instead of failing on an empty location.
            location = error.get("loc") or ()
            if location:
                invalid_fields.add(location[0])

    return "\n".join(
        f"      {key}: {value}" for key, value in config.items() if key not in invalid_fields
    )

Generate a string for the gateway connection config.

def get_start_date(load_ids: List[str]) -> str:
def get_start_date(load_ids: t.List[str]) -> str:
    """
    Return the day before the earliest load as 'YYYY-MM-DD'.

    Each load_id is parsed as a float and read as a UTC timestamp. When no load ids are
    given, the result of yesterday_ds() is returned instead.
    """
    loaded_at = [
        datetime.fromtimestamp(float(load_id), tz=timezone.utc) for load_id in load_ids
    ]
    if not loaded_at:
        return yesterday_ds()

    day_before_first_load = min(loaded_at) - timedelta(days=1)
    return day_before_first_load.strftime("%Y-%m-%d")

Convert the earliest load_id to UTC timestamp, subtract a day and format as 'YYYY-MM-DD'.