sqlmesh.utils.pydantic
from __future__ import annotations

import json
import re
import typing as t
from datetime import tzinfo

import pydantic
from pydantic import ValidationInfo as ValidationInfo
from pydantic.fields import FieldInfo
from sqlglot import exp, parse_one
from sqlglot.helper import ensure_list
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import quote_identifiers

from sqlmesh.core import dialect as d
from sqlmesh.utils import str_to_bool

if t.TYPE_CHECKING:
    from sqlglot._typing import E

    Model = t.TypeVar("Model", bound="PydanticModel")


T = t.TypeVar("T")
DEFAULT_ARGS = {"exclude_none": True, "by_alias": True}
PRIVATE_FIELDS = "__pydantic_private__"


def _major_minor_version(version: str) -> t.Tuple[int, int]:
    """Parse the (major, minor) pair out of a version string.

    Only the leading digits of the first two dot-separated components are used, so
    pre-release or local versions such as "2.0b3" or "2.11.0a1" parse cleanly instead of
    raising ValueError at import time. A missing or non-numeric component is treated as 0.
    """
    numbers = [
        # r"\d*" always matches (possibly the empty string), so .group() is safe here.
        int(re.match(r"\d*", part).group() or 0)  # type: ignore[union-attr]
        for part in version.split(".")[:2]
    ]
    numbers += [0] * (2 - len(numbers))
    return numbers[0], numbers[1]


PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION = _major_minor_version(pydantic.__version__)


def field_validator(*args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Any], t.Any]:
    """Thin pass-through to ``pydantic.field_validator``."""
    return pydantic.field_validator(*args, **kwargs)


def model_validator(*args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Any], t.Any]:
    """Thin pass-through to ``pydantic.model_validator``."""
    return pydantic.model_validator(*args, **kwargs)


def field_serializer(*args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Any], t.Any]:
    """Thin pass-through to ``pydantic.field_serializer``."""
    return pydantic.field_serializer(*args, **kwargs)


def get_dialect(values: t.Any) -> str:
    """Extracts dialect from a dict or pydantic obj, defaulting to the globally set dialect.

    Python models allow users to instantiate pydantic models by hand. This is problematic
    because the validators would otherwise run with SQLGlot's default dialect. To instantiate
    Pydantic models used in Python models with the project default dialect, we set a class
    variable on the model registry and use that here.

    Args:
        values: Either a plain dict of field values or an object exposing them through a
            ``data`` mapping (e.g. pydantic's ``ValidationInfo``).

    Returns:
        The ``dialect`` entry of ``values`` when present and not None, otherwise
        ``model._dialect`` from ``sqlmesh.core.model``.
    """
    # Imported lazily; presumably avoids an import cycle with sqlmesh.core.model.
    from sqlmesh.core.model import model

    dialect = (values if isinstance(values, dict) else values.data).get("dialect")
    return model._dialect if dialect is None else dialect  # type: ignore


def _expression_encoder(e: exp.Expr) -> str:
    """Render a sqlglot expression as SQL for JSON encoding.

    Prefers the SQL text cached in ``e.meta["sql"]``; otherwise renders the expression in
    the dialect recorded in ``e.meta["dialect"]``.
    """
    return e.meta.get("sql") or e.sql(dialect=e.meta.get("dialect"))


AuditQueryTypes = t.Union[exp.Query, d.JinjaQuery]
ModelQueryTypes = t.Union[exp.Query, d.JinjaQuery, d.MacroFunc]


class PydanticModel(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(
        # Even though Pydantic v2 kept support for json_encoders, the functionality has been
        # crippled badly. Here we need to enumerate all different ways of how sqlglot expressions
        # show up in pydantic models.
        json_encoders={
            exp.Expr: _expression_encoder,
            exp.DataType: _expression_encoder,
            exp.Tuple: _expression_encoder,
            AuditQueryTypes: _expression_encoder,  # type: ignore
            ModelQueryTypes: _expression_encoder,  # type: ignore
            tzinfo: lambda tz: tz.key,
        },
        arbitrary_types_allowed=True,
        extra="forbid",
        protected_namespaces=(),
    )

    # Cache of pydantic-generated hash functions, keyed by concrete model class.
    _hash_func_mapping: t.ClassVar[t.Dict[t.Type[t.Any], t.Callable[[t.Any], int]]] = {}

    def dict(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:
        """``model_dump`` with DEFAULT_ARGS applied; caller kwargs take precedence."""
        kwargs = {**DEFAULT_ARGS, **kwargs}
        return super().model_dump(**kwargs)  # type: ignore

    def json(
        self,
        **kwargs: t.Any,
    ) -> str:
        """``model_dump_json`` with DEFAULT_ARGS applied; also accepts ``sort_keys``."""
        kwargs = {**DEFAULT_ARGS, **kwargs}
        # Pydantic v2 doesn't support arbitrary arguments for json.dump().
        if kwargs.pop("sort_keys", False):
            return json.dumps(super().model_dump(mode="json", **kwargs), sort_keys=True)

        return super().model_dump_json(**kwargs)

    def copy(self: "Model", **kwargs: t.Any) -> "Model":
        """Alias for ``model_copy``."""
        return super().model_copy(**kwargs)

    @property
    def fields_set(self: "Model") -> t.Set[str]:
        """Names of fields explicitly set on this instance."""
        return self.__pydantic_fields_set__

    @classmethod
    def parse_obj(cls: t.Type["Model"], obj: t.Any) -> "Model":
        """Alias for ``model_validate``."""
        return super().model_validate(obj)

    @classmethod
    def parse_raw(cls: t.Type["Model"], b: t.Union[str, bytes], **kwargs: t.Any) -> "Model":
        """Alias for ``model_validate_json``."""
        return super().model_validate_json(b, **kwargs)

    @classmethod
    def missing_required_fields(
        cls: t.Type["PydanticModel"], provided_fields: t.Set[str]
    ) -> t.Set[str]:
        """Required field names (aliases where set) absent from ``provided_fields``."""
        return cls.required_fields() - provided_fields

    @classmethod
    def extra_fields(cls: t.Type["PydanticModel"], provided_fields: t.Set[str]) -> t.Set[str]:
        """Names in ``provided_fields`` that are not fields (or aliases) of this model."""
        return provided_fields - cls.all_fields()

    @classmethod
    def all_fields(cls: t.Type["PydanticModel"]) -> t.Set[str]:
        """All field names, using the alias where one is defined."""
        return cls._fields()

    @classmethod
    def all_field_infos(cls: t.Type["PydanticModel"]) -> t.Dict[str, FieldInfo]:
        """Mapping of field name to pydantic ``FieldInfo``."""
        return cls.model_fields

    @classmethod
    def required_fields(cls: t.Type["PydanticModel"]) -> t.Set[str]:
        """Names (aliases where set) of fields for which ``is_required()`` is true."""
        return cls._fields(lambda field: field.is_required())

    @classmethod
    def _fields(
        cls: t.Type["PydanticModel"],
        predicate: t.Callable[[t.Any], bool] = lambda _: True,
    ) -> t.Set[str]:
        return {
            field_info.alias if field_info.alias else field_name
            for field_name, field_info in cls.all_field_infos().items()
            if predicate(field_info)
        }

    def __eq__(self, other: t.Any) -> bool:
        # Before pydantic 2.6, equality is implemented here by comparing dict dumps.
        if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) < (2, 6):
            if isinstance(other, pydantic.BaseModel):
                return self.dict() == other.dict()
            return self.dict() == other
        return super().__eq__(other)

    def __hash__(self) -> int:
        if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) < (2, 6):
            obj = {k: v for k, v in self.__dict__.items() if k in self.all_field_infos()}
            return hash(self.__class__) + hash(tuple(obj.values()))

        from pydantic._internal._model_construction import make_hash_func  # type: ignore

        # make_hash_func is built once per concrete class and reused.
        if self.__class__ not in PydanticModel._hash_func_mapping:
            PydanticModel._hash_func_mapping[self.__class__] = make_hash_func(self.__class__)

        return PydanticModel._hash_func_mapping[self.__class__](self)

    def __str__(self) -> str:
        args = []

        for k, info in self.all_field_infos().items():
            v = getattr(self, k)

            # Only fields whose value differs from the default (by type or value) are shown.
            if type(v) != type(info.default) or v != info.default:
                args.append(f"{k}: {v}")

        return f"{self.__class__.__name__}<{', '.join(args)}>"

    def __repr__(self) -> str:
        return str(self)


def validate_list_of_strings(v: t.Any) -> t.List[str]:
    """Coerce an Identifier, a Tuple/Array expression, or an iterable into a list of strings."""
    if isinstance(v, exp.Identifier):
        return [v.name]
    if isinstance(v, (exp.Tuple, exp.Array)):
        return [e.name for e in v.expressions]
    return [i.name if isinstance(i, exp.Identifier) else str(i) for i in v]


def validate_string(v: t.Any) -> str:
    """Return ``v.name`` for sqlglot expressions, otherwise ``str(v)``."""
    if isinstance(v, exp.Expr):
        return v.name
    return str(v)


def validate_expression(expression: E, dialect: str) -> E:
    # this normalizes and quotes identifiers in the given expression according the specified dialect
    # it also sets expression.meta["dialect"] so that when we serialize for state, the expression is serialized in the correct dialect
    return _get_field(expression, {"dialect": dialect})  # type: ignore


def bool_validator(v: t.Any) -> bool:
    """Coerce a sqlglot Boolean, another expression, or a plain value into a bool."""
    if isinstance(v, exp.Boolean):
        return v.this
    if isinstance(v, exp.Expr):
        return str_to_bool(v.name)
    return str_to_bool(str(v or ""))


def positive_int_validator(v: t.Any) -> int:
    """Accept an int (or integer sqlglot literal) greater than zero.

    Raises:
        ValueError: If the value is not an int, or is not positive.
    """
    if isinstance(v, exp.Expr) and v.is_int:
        v = int(v.name)
    if not isinstance(v, int):
        raise ValueError(f"Invalid num {v}. Value must be an integer value")
    if v <= 0:
        raise ValueError(f"Invalid num {v}. Value must be a positive integer")
    return v


def validation_error_message(error: pydantic.ValidationError, base: str) -> str:
    """Build a multi-line message: ``base`` followed by one entry per validation error."""
    errors = "\n ".join(_formatted_validation_errors(error))
    return f"{base}\n {errors}"


def _formatted_validation_errors(error: pydantic.ValidationError) -> t.List[str]:
    result = []
    for e in error.errors():
        msg = e["msg"]
        loc: t.Optional[t.Tuple] = e.get("loc")
        # Locations may mix field names (str) with list/tuple indices (int); join on str()
        # of each part, since ".".join() raises TypeError on non-str items.
        loc_str = ".".join(str(part) for part in loc) if loc else None
        result.append(f"Invalid field '{loc_str}':\n {msg}" if loc_str else msg)
    return result


def _get_field(
    v: t.Any,
    values: t.Any,
) -> exp.Expr:
    """Parse (if needed), normalize and quote ``v`` in the dialect resolved from ``values``.

    Bare identifiers are wrapped in a column, and the dialect is recorded in
    ``expression.meta["dialect"]``.
    """
    dialect = get_dialect(values)

    if isinstance(v, exp.Expr):
        expression = v
    else:
        expression = parse_one(v, dialect=dialect)

    expression = exp.column(expression) if isinstance(expression, exp.Identifier) else expression
    expression = quote_identifiers(
        normalize_identifiers(expression, dialect=dialect), dialect=dialect
    )
    expression.meta["dialect"] = dialect

    return expression


def _get_fields(
    v: t.Any,
    values: t.Any,
) -> t.List[exp.Expr]:
    """Split ``v`` into expressions (Tuple/Array members, one expression, or a list) and
    run each through ``_get_field``."""
    dialect = get_dialect(values)

    if isinstance(v, (exp.Tuple, exp.Array)):
        expressions: t.List[exp.Expr] = v.expressions
    elif isinstance(v, exp.Expr):
        expressions = [v]
    else:
        expressions = [
            parse_one(entry, dialect=dialect) if isinstance(entry, str) else entry  # type: ignore[misc]
            for entry in ensure_list(v)
        ]

    results = []

    for expr in expressions:
        results.append(_get_field(expr, values))

    return results


def list_of_fields_validator(v: t.Any, values: t.Any) -> t.List[exp.Expr]:
    """Validate ``v`` as a list of normalized field expressions."""
    return _get_fields(v, values)


def column_validator(v: t.Any, values: t.Any) -> exp.Column:
    """Validate ``v`` as a single column expression.

    Raises:
        ValueError: If the normalized expression is not an ``exp.Column``.
    """
    expression = _get_field(v, values)
    if not isinstance(expression, exp.Column):
        raise ValueError(f"Invalid column {expression}. Value must be a column")
    return expression


def list_of_fields_or_star_validator(
    v: t.Any, values: t.Any
) -> t.Union[exp.Star, t.List[exp.Expr]]:
    """Return a lone ``*`` as ``exp.Star``; otherwise the list of field expressions."""
    expressions = _get_fields(v, values)
    if len(expressions) == 1 and isinstance(expressions[0], exp.Star):
        return t.cast(exp.Star, expressions[0])
    return t.cast(t.List[exp.Expr], expressions)


def cron_validator(v: t.Any) -> str:
    """Validate ``v`` (a string or sqlglot expression) as a cron expression.

    Raises:
        ValueError: If ``v`` is not a string or croniter rejects it.
    """
    if isinstance(v, exp.Expr):
        v = v.name

    from croniter import CroniterBadCronError, croniter

    if not isinstance(v, str):
        raise ValueError(f"Invalid cron expression '{v}'. Value must be a string.")

    try:
        croniter(v)
    except CroniterBadCronError as e:
        # Chain the croniter error so its detail stays in the traceback.
        raise ValueError(f"Invalid cron expression '{v}'") from e
    return v


def get_concrete_types_from_typehint(typehint: type[t.Any]) -> set[type[t.Any]]:
    """Return the runtime classes referenced by ``typehint``.

    Bare classes (metaclass exactly ``type``) map to themselves, Union members are collected
    (recursing into ``typing.*`` members), and other generics map to their origin.
    """
    concrete_types = set()
    unpacked = t.get_origin(typehint)
    if unpacked is None:
        if type(typehint) == type(type):
            return {typehint}
    elif unpacked is t.Union:
        for item in t.get_args(typehint):
            if str(item).startswith("typing."):
                concrete_types |= get_concrete_types_from_typehint(item)
            else:
                concrete_types.add(item)
    else:
        concrete_types.add(unpacked)

    return concrete_types


if t.TYPE_CHECKING:
    SQLGlotListOfStrings = t.List[str]
    SQLGlotString = str
    SQLGlotBool = bool
    SQLGlotPositiveInt = int
    SQLGlotColumn = exp.Column
    SQLGlotListOfFields = t.List[exp.Expr]
    SQLGlotListOfFieldsOrStar = t.Union[SQLGlotListOfFields, exp.Star]
    SQLGlotCron = str
else:
    from pydantic.functional_validators import BeforeValidator

    SQLGlotListOfStrings = t.Annotated[t.List[str], BeforeValidator(validate_list_of_strings)]
    SQLGlotString = t.Annotated[str, BeforeValidator(validate_string)]
    SQLGlotBool = t.Annotated[bool, BeforeValidator(bool_validator)]
    SQLGlotPositiveInt = t.Annotated[int, BeforeValidator(positive_int_validator)]
    SQLGlotColumn = t.Annotated[exp.Expr, BeforeValidator(column_validator)]
    SQLGlotListOfFields = t.Annotated[t.List[exp.Expr], BeforeValidator(list_of_fields_validator)]
    SQLGlotListOfFieldsOrStar = t.Annotated[
        t.Union[SQLGlotListOfFields, exp.Star], BeforeValidator(list_of_fields_or_star_validator)
    ]
    SQLGlotCron = t.Annotated[str, BeforeValidator(cron_validator)]
def get_dialect(values: t.Any) -> str:
    """Return the dialect stored in ``values``, or the project-wide default when it is unset.

    Users may construct pydantic models from Python models by hand, in which case the
    validators would otherwise run with SQLGlot's default dialect. The project default is
    therefore kept as a class variable on the model registry and consulted here.
    """
    from sqlmesh.core.model import model

    # Plain dicts are read directly; anything else is expected to expose a ``data`` mapping.
    source = values if isinstance(values, dict) else values.data
    explicit = source.get("dialect")
    if explicit is not None:
        return explicit
    return model._dialect  # type: ignore
Extracts the dialect from a dict or pydantic object, defaulting to the globally set dialect.
Python models allow users to instantiate pydantic models by hand. This is problematic because the validators would otherwise run with SQLGlot's default dialect. To instantiate Pydantic models used in Python models with the project's default dialect, we set a class variable on the model registry and use that here.
# Project base model: sqlglot-aware JSON encoding plus pydantic-v1 style helpers
# (dict/json/copy/parse_obj/parse_raw) layered on the v2 API.
class PydanticModel(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(
        # Pydantic v2 still accepts json_encoders, but in a much more limited form, so every
        # way a sqlglot expression can appear on a model is listed explicitly.
        json_encoders={
            exp.Expr: _expression_encoder,
            exp.DataType: _expression_encoder,
            exp.Tuple: _expression_encoder,
            AuditQueryTypes: _expression_encoder,  # type: ignore
            ModelQueryTypes: _expression_encoder,  # type: ignore
            tzinfo: lambda tz: tz.key,
        },
        arbitrary_types_allowed=True,
        extra="forbid",
        protected_namespaces=(),
    )

    # Per-class cache of pydantic-generated hash functions (populated lazily in __hash__).
    _hash_func_mapping: t.ClassVar[t.Dict[t.Type[t.Any], t.Callable[[t.Any], int]]] = {}

    def dict(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:
        """Dump to a dict; DEFAULT_ARGS apply unless overridden by the caller."""
        options = {**DEFAULT_ARGS}
        options.update(kwargs)
        return super().model_dump(**options)  # type: ignore

    def json(self, **kwargs: t.Any) -> str:
        """Dump to a JSON string; DEFAULT_ARGS apply unless overridden by the caller.

        ``sort_keys=True`` is honoured by routing through the stdlib ``json.dumps``, since
        ``model_dump_json`` accepts no such option.
        """
        options = {**DEFAULT_ARGS}
        options.update(kwargs)
        wants_sorted = options.pop("sort_keys", False)
        if not wants_sorted:
            return super().model_dump_json(**options)
        return json.dumps(super().model_dump(mode="json", **options), sort_keys=True)

    def copy(self: "Model", **kwargs: t.Any) -> "Model":
        """Delegate to ``model_copy``."""
        return super().model_copy(**kwargs)

    @property
    def fields_set(self: "Model") -> t.Set[str]:
        """Field names explicitly provided when this instance was created."""
        return self.__pydantic_fields_set__

    @classmethod
    def parse_obj(cls: t.Type["Model"], obj: t.Any) -> "Model":
        """Delegate to ``model_validate``."""
        return super().model_validate(obj)

    @classmethod
    def parse_raw(cls: t.Type["Model"], b: t.Union[str, bytes], **kwargs: t.Any) -> "Model":
        """Delegate to ``model_validate_json``."""
        return super().model_validate_json(b, **kwargs)

    @classmethod
    def missing_required_fields(
        cls: t.Type["PydanticModel"], provided_fields: t.Set[str]
    ) -> t.Set[str]:
        """Required names (alias-aware) that ``provided_fields`` lacks."""
        required = cls.required_fields()
        return required - provided_fields

    @classmethod
    def extra_fields(cls: t.Type["PydanticModel"], provided_fields: t.Set[str]) -> t.Set[str]:
        """Names in ``provided_fields`` unknown to this model (alias-aware)."""
        known = cls.all_fields()
        return provided_fields - known

    @classmethod
    def all_fields(cls: t.Type["PydanticModel"]) -> t.Set[str]:
        """Every field name, substituting the alias where one exists."""
        return cls._fields()

    @classmethod
    def all_field_infos(cls: t.Type["PydanticModel"]) -> t.Dict[str, FieldInfo]:
        """The model's ``model_fields`` mapping."""
        return cls.model_fields

    @classmethod
    def required_fields(cls: t.Type["PydanticModel"]) -> t.Set[str]:
        """Names (alias-aware) of fields whose ``is_required()`` is true."""
        return cls._fields(lambda field: field.is_required())

    @classmethod
    def _fields(
        cls: t.Type["PydanticModel"],
        predicate: t.Callable[[t.Any], bool] = lambda _: True,
    ) -> t.Set[str]:
        selected: t.Set[str] = set()
        for name, info in cls.all_field_infos().items():
            if predicate(info):
                selected.add(info.alias or name)
        return selected

    def __eq__(self, other: t.Any) -> bool:
        if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) >= (2, 6):
            return super().__eq__(other)
        # Older pydantic: compare dict dumps (other models are dumped too).
        mine = self.dict()
        theirs = other.dict() if isinstance(other, pydantic.BaseModel) else other
        return mine == theirs

    def __hash__(self) -> int:
        if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) < (2, 6):
            known = self.all_field_infos()
            field_values = tuple(val for key, val in self.__dict__.items() if key in known)
            return hash(self.__class__) + hash(field_values)

        from pydantic._internal._model_construction import make_hash_func  # type: ignore

        cache = PydanticModel._hash_func_mapping
        model_cls = self.__class__
        hasher = cache.get(model_cls)
        if hasher is None:
            hasher = cache[model_cls] = make_hash_func(model_cls)
        return hasher(self)

    def __str__(self) -> str:
        changed = []
        for name, info in self.all_field_infos().items():
            value = getattr(self, name)
            # Show only fields that differ from their default, by type or by value.
            if type(value) != type(info.default) or value != info.default:
                changed.append(f"{name}: {value}")
        body = ", ".join(changed)
        return f"{self.__class__.__name__}<{body}>"

    def __repr__(self) -> str:
        return str(self)
!!! abstract "Usage Documentation" Models
A base class for creating Pydantic models.
Attributes:
- __class_vars__: The names of the class variables defined on the model.
- __private_attributes__: Metadata about the private attributes of the model.
- __signature__: The synthesized `__init__` [`Signature`][inspect.Signature] of the model.
- __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__: The core schema of the model.
- __pydantic_custom_init__: Whether the model has a custom `__init__` function.
- __pydantic_decorators__: Metadata containing the decorators defined on the model.
This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.
- __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to `__args__`, `__origin__`, `__parameters__` in typing-module generics. May eventually be replaced by these.
- __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__: The name of the post-init method for the model, if defined.
- __pydantic_root_model__: Whether the model is a [`RootModel`][pydantic.root_model.RootModel].
- __pydantic_serializer__: The `pydantic-core` `SchemaSerializer` used to dump instances of the model.
- __pydantic_validator__: The `pydantic-core` `SchemaValidator` used to validate instances of the model.
- __pydantic_fields__: A dictionary of field names and their corresponding [`FieldInfo`][pydantic.fields.FieldInfo] objects.
- __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [`ComputedFieldInfo`][pydantic.fields.ComputedFieldInfo] objects.
- __pydantic_extra__: A dictionary containing extra values, if [`extra`][pydantic.config.ConfigDict.extra] is set to `'allow'`.
- __pydantic_fields_set__: The names of fields explicitly set during instantiation.
- __pydantic_private__: Values of private attributes set on the model instance.
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
def json(self, **kwargs: t.Any) -> str:
    """Serialize the model to JSON, applying DEFAULT_ARGS beneath the caller's kwargs.

    ``sort_keys=True`` is supported by dumping to a JSON-compatible dict and handing it to
    the stdlib ``json.dumps``, because ``model_dump_json`` has no key-sorting option.
    """
    options = {**DEFAULT_ARGS}
    options.update(kwargs)
    if not options.pop("sort_keys", False):
        return super().model_dump_json(**options)
    as_python = super().model_dump(mode="json", **options)
    return json.dumps(as_python, sort_keys=True)
Returns a copy of the model.
!!! warning "Deprecated"
This method is now deprecated; use model_copy instead.
If you need include or exclude, use:
```python {test="skip" lint="skip"}
data = self.model_dump(include=include, exclude=exclude, round_trip=True)
data = {**data, **(update or {})}
copied = self.model_validate(data)
```
Arguments:
- include: Optional set or mapping specifying which fields to include in the copied model.
- exclude: Optional set or mapping specifying which fields to exclude in the copied model.
- update: Optional dictionary of field-value pairs to override field values in the copied model.
- deep: If True, the values of fields that are Pydantic models will be deep-copied.
Returns:
A copy of the model with included, excluded and updated fields as specified.
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- model_fields
- model_computed_fields
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_post_init
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- parse_file
- from_orm
- construct
- schema
- schema_json
- validate
- update_forward_refs
def validate_expression(expression: E, dialect: str) -> E:
    """Normalize and quote the identifiers of ``expression`` for ``dialect``.

    The dialect is also stored in ``expression.meta["dialect"]`` so the expression is
    serialized in that dialect when written to state.
    """
    context = {"dialect": dialect}
    return _get_field(expression, context)  # type: ignore
def positive_int_validator(v: t.Any) -> int:
    """Return ``v`` as a positive int, unwrapping integer sqlglot literals first.

    Raises:
        ValueError: If the value is not an int, or is zero/negative.
    """
    value = int(v.name) if isinstance(v, exp.Expr) and v.is_int else v
    if not isinstance(value, int):
        raise ValueError(f"Invalid num {value}. Value must be an integer value")
    if value > 0:
        return value
    raise ValueError(f"Invalid num {value}. Value must be a positive integer")
def list_of_fields_or_star_validator(
    v: t.Any, values: t.Any
) -> t.Union[exp.Star, t.List[exp.Expr]]:
    """Validate ``v`` as field expressions, collapsing a lone ``*`` to ``exp.Star``."""
    fields = _get_fields(v, values)
    only_star = len(fields) == 1 and isinstance(fields[0], exp.Star)
    if only_star:
        return t.cast(exp.Star, fields[0])
    return t.cast(t.List[exp.Expr], fields)
def cron_validator(v: t.Any) -> str:
    """Validate ``v`` (a string or sqlglot expression) as a cron expression.

    Returns:
        The cron string, unchanged.

    Raises:
        ValueError: If ``v`` is not a string or croniter rejects it.
    """
    if isinstance(v, exp.Expr):
        v = v.name

    from croniter import CroniterBadCronError, croniter

    if not isinstance(v, str):
        raise ValueError(f"Invalid cron expression '{v}'. Value must be a string.")

    try:
        croniter(v)
    except CroniterBadCronError as e:
        # Chain explicitly so croniter's detailed reason stays in the traceback.
        raise ValueError(f"Invalid cron expression '{v}'") from e
    return v
def get_concrete_types_from_typehint(typehint: type[t.Any]) -> set[type[t.Any]]:
    """Return the runtime classes a type hint refers to.

    - A hint with no origin yields ``{typehint}`` if its metaclass is exactly ``type``,
      otherwise an empty set.
    - A ``typing.Union`` yields its members; members whose ``str()`` starts with
      ``"typing."`` (e.g. ``typing.List[int]``) are resolved recursively, others are kept.
    - Any other subscripted generic yields its origin, e.g. ``typing.List[int]`` -> ``{list}``.
    """
    origin = t.get_origin(typehint)

    if origin is None:
        return {typehint} if type(typehint) == type(type) else set()

    if origin is not t.Union:
        return {origin}

    resolved = set()
    for member in t.get_args(typehint):
        if str(member).startswith("typing."):
            resolved.update(get_concrete_types_from_typehint(member))
        else:
            resolved.add(member)
    return resolved