sqlmesh.integrations.dlt
import json
import re
import typing as t
from datetime import datetime, timedelta, timezone

import click
from pydantic import ValidationError
from sqlglot import exp, parse_one

from sqlmesh.core.config.connection import parse_connection_config
from sqlmesh.core.context import Context
from sqlmesh.utils.date import yesterday_ds


def generate_dlt_models_and_settings(
    pipeline_name: str,
    dialect: str,
    tables: t.Optional[t.List[str]] = None,
    dlt_path: t.Optional[str] = None,
) -> t.Tuple[t.Set[t.Tuple[str, str]], t.Optional[str], str]:
    """
    This function attaches to a DLT pipeline and retrieves the connection configs and
    SQLMesh models based on the tables present in the pipeline's default schema.

    A table without a load id column is only modelled when it is a nested table whose
    parent is also being modelled; any other such table is skipped.

    Args:
        pipeline_name: The name of the DLT pipeline to attach to.
        dialect: The SQL dialect to use for generating SQLMesh models.
        tables: A list of table names to include (all eligible tables when empty/None).
        dlt_path: The path to the directory containing the DLT pipelines.

    Returns:
        A tuple of: the set of generated models as (model name, model SQL) pairs, the
        gateway connection config string (None for filesystem destinations), and the
        start date formatted as YYYY-MM-DD.

    Raises:
        click.ClickException: If the pipeline cannot be attached.
    """

    import dlt
    from dlt.common.schema.utils import has_table_seen_data, is_complete_column
    from dlt.pipeline.exceptions import CannotRestorePipelineException

    try:
        pipeline = dlt.attach(pipeline_name=pipeline_name, pipelines_dir=dlt_path or "")
    except CannotRestorePipelineException:
        raise click.ClickException(f"Could not attach to pipeline {pipeline_name}")

    schema = pipeline.default_schema
    dataset = pipeline.dataset_name

    # Get the start date from the load_ids
    storage_ids = list(pipeline._get_load_storage().list_loaded_packages())
    start_date = get_start_date(storage_ids)

    # Get the connection credentials
    db_type = pipeline.destination.to_name(pipeline.destination)
    if db_type == "filesystem":
        connection_config = None
    else:
        # Compare versions numerically: as plain strings, "1.9.0" >= "1.10.0" is True.
        dlt_version = tuple(int(part) for part in re.findall(r"\d+", dlt.__version__)[:2])
        if dlt_version >= (1, 10):
            client = pipeline.destination_client()
        else:
            client = pipeline._sql_job_client(schema)  # type: ignore
        config = client.config
        credentials = config.credentials
        # Collect the public, non-callable, non-None attributes of the credentials object.
        configs = {
            key: value
            for key in dir(credentials)
            if not key.startswith("_")
            and not callable(value := getattr(credentials, key))
            and value is not None
        }
        connection_config = format_config(configs, db_type)

    # Keep user tables that have seen data (plus the loads table), optionally filtered.
    dlt_tables = {
        name: table
        for name, table in schema.tables.items()
        if (
            (has_table_seen_data(table) and not name.startswith(schema._dlt_tables_prefix))
            or name == schema.loads_table_name
        )
        and (name in tables if tables else True)
    }

    sqlmesh_models = set()
    for table_name, table in dlt_tables.items():
        dlt_columns = {}
        primary_key = []

        # is_complete_column returns true if column contains a name and a data type
        for col in filter(is_complete_column, table["columns"].values()):
            dlt_columns[col["name"]] = exp.DataType.build(str(col["data_type"]), dialect=dialect)
            if col.get("primary_key"):
                primary_key.append(str(col["name"]))

        load_id = next(
            (col for col in ["_dlt_load_id", "load_id"] if col in dlt_columns),
            None,
        )
        load_key = "c." + load_id if load_id else ""
        parent_table = None

        # Handling for nested tables: https://dlthub.com/docs/general-usage/destination-tables#nested-tables
        if not load_id:
            parent_name = table.get("parent")
            if "_dlt_parent_id" in dlt_columns and parent_name and parent_name in dlt_tables:
                # Nested rows have no load id of their own; take it from the joined parent.
                load_key = "p._dlt_load_id"
                parent_table = dataset + "." + parent_name
            else:
                # No way to derive a load time for this table: skip it but keep
                # processing the remaining tables.
                continue

        column_types = [
            exp.cast(exp.column(column, table="c"), data_type, dialect=dialect)
            .as_(column)
            .sql(dialect=dialect)
            for column, data_type in dlt_columns.items()
            if isinstance(column, str)
        ]
        select_columns = ",\n".join(f"  {column_name}" for column_name in column_types)

        grain = f"\n  grain ({', '.join(primary_key)})," if primary_key else ""
        incremental_model_name = f"{dataset}_sqlmesh.incremental_{table_name}"
        incremental_model_sql = generate_incremental_model(
            incremental_model_name,
            select_columns,
            grain,
            dataset + "." + table_name,
            dialect,
            load_key,
            parent_table,
        )
        sqlmesh_models.add((incremental_model_name, incremental_model_sql))

    return sqlmesh_models, connection_config, start_date


def generate_dlt_models(
    context: Context,
    pipeline_name: str,
    tables: t.List[str],
    force: bool,
    dlt_path: t.Optional[str] = None,
) -> t.List[str]:
    """Write SQLMesh model files for a dlt pipeline's tables into the project's models dir.

    Args:
        context: The SQLMesh context; files are written under ``context.path / "models"``.
        pipeline_name: The name of the dlt pipeline to attach to.
        tables: Table names to generate models for; an empty list means all tables.
        force: When no tables are given, also regenerate models the project already has.
        dlt_path: The directory containing the dlt pipelines.

    Returns:
        The sorted names of the models that were written (empty if none).
    """
    from sqlmesh.cli.project_init import _create_object_files

    sqlmesh_models, _, _ = generate_dlt_models_and_settings(
        pipeline_name=pipeline_name,
        dialect=context.config.dialect or "",
        tables=tables or None,
        dlt_path=dlt_path,
    )

    # Without an explicit table list, skip models already defined in the project
    # unless the caller asked to overwrite them.
    if not tables and not force:
        existing_models = {model.name for model in context.models.values()}
        sqlmesh_models = {model for model in sqlmesh_models if model[0] not in existing_models}

    if not sqlmesh_models:
        return []

    _create_object_files(
        context.path / "models",
        {name.split(".")[-1]: sql for name, sql in sqlmesh_models},
        "sql",
    )
    # Sort so callers see a stable order instead of set-iteration order.
    return sorted(name for name, _ in sqlmesh_models)


def generate_incremental_model(
    model_name: str,
    select_columns: str,
    grain: str,
    from_table: str,
    dialect: str,
    load_id: str,
    parent_table: t.Optional[str] = None,
) -> str:
    """Generate the SQL definition for an incremental model.

    The model is INCREMENTAL_BY_TIME_RANGE on ``_dlt_load_time``, computed from the
    ``load_id`` column expression; when ``parent_table`` is given it is joined as ``p``
    on ``c._dlt_parent_id = p._dlt_id``.
    """

    time_column = parse_one(f"to_timestamp(CAST({load_id} AS DOUBLE))").sql(dialect=dialect)

    from_clause = f"{from_table} as c"
    if parent_table:
        from_clause += f"""\nJOIN
  {parent_table} as p
ON
  c._dlt_parent_id = p._dlt_id"""

    return f"""MODEL (
  name {model_name},
  kind INCREMENTAL_BY_TIME_RANGE (
    time_column _dlt_load_time,
  ),{grain}
);

SELECT
{select_columns},
  {time_column} as _dlt_load_time
FROM
  {from_clause}
WHERE
  {time_column} BETWEEN @start_ds AND @end_ds
"""


def format_config(configs: t.Dict[str, t.Any], db_type: str) -> str:
    """Generate a string for the gateway connection config.

    ``username`` is renamed to ``user``; the password is emitted as an escaped
    double-quoted string; fields rejected by ``parse_connection_config`` are dropped.
    """
    config: t.Dict[str, t.Any] = {"type": db_type}

    for key, value in configs.items():
        if key == "password":
            # json.dumps quotes AND escapes `"` / `\`, which a bare f'"{value}"' did not
            # (assumes the output lands in a YAML double-quoted context — confirm).
            config[key] = json.dumps(str(value), ensure_ascii=False)
        elif key == "username":
            config["user"] = value
        else:
            config[key] = value

    # Validate the connection config fields
    invalid_fields: t.Set[t.Any] = set()
    try:
        parse_connection_config(config)
    except ValidationError as e:
        for error in e.errors():
            loc = error.get("loc") or ()
            # Model-level errors have an empty `loc` and name no field to drop.
            if loc:
                invalid_fields.add(loc[0])

    return "\n".join(
        f"      {key}: {value}" for key, value in config.items() if key not in invalid_fields
    )


def get_start_date(load_ids: t.List[str]) -> str:
    """Convert the earliest load_id to UTC timestamp, subtract a day and format as 'YYYY-MM-DD'.

    Falls back to ``yesterday_ds()`` when there are no load ids.
    """
    if not load_ids:
        return yesterday_ds()
    earliest = min(datetime.fromtimestamp(float(load_id), tz=timezone.utc) for load_id in load_ids)
    return (earliest - timedelta(days=1)).strftime("%Y-%m-%d")
def
generate_dlt_models_and_settings( pipeline_name: str, dialect: str, tables: Optional[List[str]] = None, dlt_path: Optional[str] = None) -> Tuple[Set[Tuple[str, str]], Optional[str], str]:
def generate_dlt_models_and_settings(
    pipeline_name: str,
    dialect: str,
    tables: t.Optional[t.List[str]] = None,
    dlt_path: t.Optional[str] = None,
) -> t.Tuple[t.Set[t.Tuple[str, str]], t.Optional[str], str]:
    """
    This function attaches to a DLT pipeline and retrieves the connection configs and
    SQLMesh models based on the tables present in the pipeline's default schema.

    A table without a load id column is only modelled when it is a nested table whose
    parent is also being modelled; any other such table is skipped.

    Args:
        pipeline_name: The name of the DLT pipeline to attach to.
        dialect: The SQL dialect to use for generating SQLMesh models.
        tables: A list of table names to include (all eligible tables when empty/None).
        dlt_path: The path to the directory containing the DLT pipelines.

    Returns:
        A tuple of: the set of generated models as (model name, model SQL) pairs, the
        gateway connection config string (None for filesystem destinations), and the
        start date formatted as YYYY-MM-DD.

    Raises:
        click.ClickException: If the pipeline cannot be attached.
    """
    import re

    import dlt
    from dlt.common.schema.utils import has_table_seen_data, is_complete_column
    from dlt.pipeline.exceptions import CannotRestorePipelineException

    try:
        pipeline = dlt.attach(pipeline_name=pipeline_name, pipelines_dir=dlt_path or "")
    except CannotRestorePipelineException:
        raise click.ClickException(f"Could not attach to pipeline {pipeline_name}")

    schema = pipeline.default_schema
    dataset = pipeline.dataset_name

    # Get the start date from the load_ids
    storage_ids = list(pipeline._get_load_storage().list_loaded_packages())
    start_date = get_start_date(storage_ids)

    # Get the connection credentials
    db_type = pipeline.destination.to_name(pipeline.destination)
    if db_type == "filesystem":
        connection_config = None
    else:
        # Compare versions numerically: as plain strings, "1.9.0" >= "1.10.0" is True.
        dlt_version = tuple(int(part) for part in re.findall(r"\d+", dlt.__version__)[:2])
        if dlt_version >= (1, 10):
            client = pipeline.destination_client()
        else:
            client = pipeline._sql_job_client(schema)  # type: ignore
        config = client.config
        credentials = config.credentials
        # Collect the public, non-callable, non-None attributes of the credentials object.
        configs = {
            key: value
            for key in dir(credentials)
            if not key.startswith("_")
            and not callable(value := getattr(credentials, key))
            and value is not None
        }
        connection_config = format_config(configs, db_type)

    # Keep user tables that have seen data (plus the loads table), optionally filtered.
    dlt_tables = {
        name: table
        for name, table in schema.tables.items()
        if (
            (has_table_seen_data(table) and not name.startswith(schema._dlt_tables_prefix))
            or name == schema.loads_table_name
        )
        and (name in tables if tables else True)
    }

    sqlmesh_models = set()
    for table_name, table in dlt_tables.items():
        dlt_columns = {}
        primary_key = []

        # is_complete_column returns true if column contains a name and a data type
        for col in filter(is_complete_column, table["columns"].values()):
            dlt_columns[col["name"]] = exp.DataType.build(str(col["data_type"]), dialect=dialect)
            if col.get("primary_key"):
                primary_key.append(str(col["name"]))

        load_id = next(
            (col for col in ["_dlt_load_id", "load_id"] if col in dlt_columns),
            None,
        )
        load_key = "c." + load_id if load_id else ""
        parent_table = None

        # Handling for nested tables: https://dlthub.com/docs/general-usage/destination-tables#nested-tables
        if not load_id:
            parent_name = table.get("parent")
            if "_dlt_parent_id" in dlt_columns and parent_name and parent_name in dlt_tables:
                # Nested rows have no load id of their own; take it from the joined parent.
                load_key = "p._dlt_load_id"
                parent_table = dataset + "." + parent_name
            else:
                # No way to derive a load time for this table: skip it but keep
                # processing the remaining tables.
                continue

        column_types = [
            exp.cast(exp.column(column, table="c"), data_type, dialect=dialect)
            .as_(column)
            .sql(dialect=dialect)
            for column, data_type in dlt_columns.items()
            if isinstance(column, str)
        ]
        select_columns = ",\n".join(f"  {column_name}" for column_name in column_types)

        grain = f"\n  grain ({', '.join(primary_key)})," if primary_key else ""
        incremental_model_name = f"{dataset}_sqlmesh.incremental_{table_name}"
        incremental_model_sql = generate_incremental_model(
            incremental_model_name,
            select_columns,
            grain,
            dataset + "." + table_name,
            dialect,
            load_key,
            parent_table,
        )
        sqlmesh_models.add((incremental_model_name, incremental_model_sql))

    return sqlmesh_models, connection_config, start_date
This function attaches to a DLT pipeline and retrieves the connection configs and SQLMesh models based on the tables present in the pipeline's default schema.
Arguments:
- pipeline_name: The name of the DLT pipeline to attach to.
- dialect: The SQL dialect to use for generating SQLMesh models.
- tables: A list of table names to include.
- dlt_path: The path to the directory containing the DLT pipelines.
Returns:
A tuple containing the set of generated SQLMesh models as (model name, model SQL) pairs, the gateway connection config (None for filesystem destinations), and the start date.
def
generate_dlt_models( context: sqlmesh.core.context.Context, pipeline_name: str, tables: List[str], force: bool, dlt_path: Optional[str] = None) -> List[str]:
def generate_dlt_models(
    context: Context,
    pipeline_name: str,
    tables: t.List[str],
    force: bool,
    dlt_path: t.Optional[str] = None,
) -> t.List[str]:
    """Write SQLMesh model files for a dlt pipeline's tables into the project's models dir.

    Args:
        context: The SQLMesh context; files are written under ``context.path / "models"``.
        pipeline_name: The name of the dlt pipeline to attach to.
        tables: Table names to generate models for; an empty list means all tables.
        force: When no tables are given, also regenerate models the project already has.
        dlt_path: The directory containing the dlt pipelines.

    Returns:
        The sorted names of the models that were written (empty if none).
    """
    from sqlmesh.cli.project_init import _create_object_files

    sqlmesh_models, _, _ = generate_dlt_models_and_settings(
        pipeline_name=pipeline_name,
        dialect=context.config.dialect or "",
        tables=tables or None,
        dlt_path=dlt_path,
    )

    # Without an explicit table list, skip models already defined in the project
    # unless the caller asked to overwrite them.
    if not tables and not force:
        existing_models = {model.name for model in context.models.values()}
        sqlmesh_models = {model for model in sqlmesh_models if model[0] not in existing_models}

    if not sqlmesh_models:
        return []

    _create_object_files(
        context.path / "models",
        {name.split(".")[-1]: sql for name, sql in sqlmesh_models},
        "sql",
    )
    # Sort so callers see a stable order instead of set-iteration order.
    return sorted(name for name, _ in sqlmesh_models)
def
generate_incremental_model( model_name: str, select_columns: str, grain: str, from_table: str, dialect: str, load_id: str, parent_table: Optional[str] = None) -> str:
def generate_incremental_model(
    model_name: str,
    select_columns: str,
    grain: str,
    from_table: str,
    dialect: str,
    load_id: str,
    parent_table: t.Optional[str] = None,
) -> str:
    """Render an INCREMENTAL_BY_TIME_RANGE SQLMesh model keyed on ``_dlt_load_time``.

    Args:
        model_name: Fully qualified name of the model being generated.
        select_columns: Pre-rendered, comma/newline separated projection lines.
        grain: Either "" or a pre-rendered ``grain (...)`` property line.
        from_table: The source table, aliased as ``c``.
        dialect: SQL dialect used to render the load-time expression.
        load_id: Column expression holding the dlt load id (epoch seconds).
        parent_table: Optional parent table, joined as ``p`` via ``_dlt_parent_id``.

    Returns:
        The complete model file text.
    """
    load_time_sql = parse_one(f"to_timestamp(CAST({load_id} AS DOUBLE))").sql(dialect=dialect)

    source_lines = [f"{from_table} as c"]
    if parent_table:
        source_lines.extend(
            ["JOIN", f"  {parent_table} as p", "ON", "  c._dlt_parent_id = p._dlt_id"]
        )
    source_sql = "\n".join(source_lines)

    model_lines = [
        "MODEL (",
        f"  name {model_name},",
        "  kind INCREMENTAL_BY_TIME_RANGE (",
        "    time_column _dlt_load_time,",
        f"  ),{grain}",
        ");",
        "",
        "SELECT",
        f"{select_columns},",
        f"  {load_time_sql} as _dlt_load_time",
        "FROM",
        f"  {source_sql}",
        "WHERE",
        f"  {load_time_sql} BETWEEN @start_ds AND @end_ds",
        "",
    ]
    return "\n".join(model_lines)
Generate the SQL definition for an incremental model.
def
format_config(configs: Dict[str, str], db_type: str) -> str:
def format_config(configs: t.Dict[str, t.Any], db_type: str) -> str:
    """Generate a string for the gateway connection config.

    ``username`` is renamed to ``user``; the password is emitted as an escaped
    double-quoted string; fields rejected by ``parse_connection_config`` are dropped.

    Args:
        configs: Credential attribute names and values read from the dlt destination.
        db_type: The dlt destination name, used as the connection ``type``.

    Returns:
        The indented ``key: value`` lines joined with newlines.
    """
    import json

    config: t.Dict[str, t.Any] = {"type": db_type}

    for key, value in configs.items():
        if key == "password":
            # json.dumps quotes AND escapes `"` / `\`, which a bare f'"{value}"' did not
            # (assumes the output lands in a YAML double-quoted context — confirm).
            config[key] = json.dumps(str(value), ensure_ascii=False)
        elif key == "username":
            config["user"] = value
        else:
            config[key] = value

    # Validate the connection config fields
    invalid_fields: t.Set[t.Any] = set()
    try:
        parse_connection_config(config)
    except ValidationError as e:
        for error in e.errors():
            loc = error.get("loc") or ()
            # Model-level errors have an empty `loc` and name no field to drop.
            if loc:
                invalid_fields.add(loc[0])

    return "\n".join(
        f"      {key}: {value}" for key, value in config.items() if key not in invalid_fields
    )
Generate a string for the gateway connection config.
def
get_start_date(load_ids: List[str]) -> str:
def get_start_date(load_ids: t.List[str]) -> str:
    """Return the day before the earliest dlt load as a 'YYYY-MM-DD' string.

    Each load id is parsed as epoch seconds and interpreted in UTC. With no load ids,
    the result of ``yesterday_ds()`` is returned instead.
    """
    if not load_ids:
        return yesterday_ds()

    earliest_load = min(
        datetime.fromtimestamp(float(load_id), tz=timezone.utc) for load_id in load_ids
    )
    day_before = earliest_load - timedelta(days=1)
    return day_before.strftime("%Y-%m-%d")
Convert the earliest load_id to a UTC timestamp, subtract one day, and format it as 'YYYY-MM-DD'; if there are no load_ids, fall back to `yesterday_ds()`.