Edit on GitHub

sqlmesh.core.model.seed

  1from __future__ import annotations
  2
  3import logging
  4import typing as t
  5import zlib
  6from io import StringIO
  7from pathlib import Path
  8
  9from sqlglot import exp
 10from sqlglot.dialects.dialect import UNESCAPED_SEQUENCES
 11from sqlglot.helper import seq_get
 12from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
 13
 14from sqlmesh.core.model.common import parse_bool
 15from sqlmesh.utils.pandas import columns_to_types_from_df
 16from sqlmesh.utils.pydantic import PydanticModel, field_validator
 17
 18if t.TYPE_CHECKING:
 19    import pandas as pd
 20
 21logger = logging.getLogger(__name__)
 22
 23NaHashables = t.List[t.Union[int, str, bool, t.Literal[None]]]
 24NaValues = t.Union[NaHashables, t.Dict[str, NaHashables]]
 25
 26
 27class CsvSettings(PydanticModel):
 28    """Settings for CSV seeds."""
 29
 30    delimiter: t.Optional[str] = None
 31    quotechar: t.Optional[str] = None
 32    doublequote: t.Optional[bool] = None
 33    escapechar: t.Optional[str] = None
 34    skipinitialspace: t.Optional[bool] = None
 35    lineterminator: t.Optional[str] = None
 36    encoding: t.Optional[str] = None
 37    na_values: t.Optional[NaValues] = None
 38    keep_default_na: t.Optional[bool] = None
 39
 40    @field_validator("doublequote", "skipinitialspace", "keep_default_na", mode="before")
 41    @classmethod
 42    def _bool_validator(cls, v: t.Any) -> t.Optional[bool]:
 43        if v is None:
 44            return v
 45        return parse_bool(v)
 46
 47    @field_validator(
 48        "delimiter", "quotechar", "escapechar", "lineterminator", "encoding", mode="before"
 49    )
 50    @classmethod
 51    def _str_validator(cls, v: t.Any) -> t.Optional[str]:
 52        if v is None or not isinstance(v, exp.Expr):
 53            return v
 54
 55        # SQLGlot parses escape sequences like \t as \\t for dialects that don't treat \ as
 56        # an escape character, so we map them back to the corresponding escaped sequence
 57        v = v.this
 58        return UNESCAPED_SEQUENCES.get(v, v)
 59
 60    @field_validator("na_values", mode="before")
 61    @classmethod
 62    def _na_values_validator(cls, v: t.Any) -> t.Optional[NaValues]:
 63        if v is None or not isinstance(v, exp.Expr):
 64            return v
 65
 66        try:
 67            if isinstance(v, exp.Paren) or not isinstance(v, (exp.Tuple, exp.Array)):
 68                v = exp.Tuple(expressions=[v.unnest()])
 69
 70            expressions = v.expressions
 71            if isinstance(seq_get(expressions, 0), (exp.PropertyEQ, exp.EQ)):
 72                return {
 73                    e.left.name: [
 74                        rhs_val.to_py()
 75                        for rhs_val in (
 76                            [e.right.unnest()]
 77                            if isinstance(e.right, exp.Paren)
 78                            else e.right.expressions
 79                        )
 80                    ]
 81                    for e in expressions
 82                }
 83
 84            return [e.to_py() for e in expressions]
 85        except ValueError as e:
 86            logger.warning(f"Failed to coerce na_values '{v}', proceeding with defaults. {str(e)}")
 87
 88        return None
 89
 90
 91class CsvSeedReader:
 92    def __init__(self, content: str, dialect: str, settings: CsvSettings):
 93        self.content = content
 94        self.dialect = dialect
 95        self.settings = settings
 96        self._df: t.Optional[pd.DataFrame] = None
 97
 98    @property
 99    def columns_to_types(self) -> t.Dict[str, exp.DataType]:
100        return columns_to_types_from_df(self._get_df())
101
102    @property
103    def column_hashes(self) -> t.Dict[str, str]:
104        df = self._get_df()
105        return {
106            column_name: str(zlib.crc32(df[column_name].to_json().encode("utf-8")))
107            for column_name in df.columns
108        }
109
110    def read(self, batch_size: t.Optional[int] = None) -> t.Generator[pd.DataFrame, None, None]:
111        df = self._get_df()
112
113        batch_size = batch_size or df.size
114        batch_start = 0
115        while batch_start < df.shape[0]:
116            yield df.iloc[batch_start : batch_start + batch_size, :]
117            batch_start += batch_size
118
119    def _get_df(self) -> pd.DataFrame:
120        import pandas as pd
121
122        if self._df is None:
123            self._df = pd.read_csv(
124                StringIO(self.content),
125                index_col=False,
126                on_bad_lines="error",
127                low_memory=False,
128                **{k: v for k, v in self.settings.dict().items() if v is not None},
129            )
130            self._df = self._df.rename(
131                columns={
132                    col: normalize_identifiers(col, dialect=self.dialect).name
133                    for col in self._df.columns
134                },
135            )
136
137        return self._df
138
139
class Seed(PydanticModel):
    """Holds the raw text of a seed file.

    Only the CSV format is supported at present.
    """

    content: str

    def reader(self, dialect: str = "", settings: t.Optional[CsvSettings] = None) -> CsvSeedReader:
        """Build a ``CsvSeedReader`` over this seed's content.

        A default ``CsvSettings()`` is used when ``settings`` is falsy.
        """
        csv_settings = settings or CsvSettings()
        return CsvSeedReader(self.content, dialect, csv_settings)
150
151
def create_seed(path: str | Path) -> Seed:
    """Read the file at ``path`` as UTF-8 text and wrap its contents in a ``Seed``."""
    seed_path = Path(path)
    with seed_path.open("r", encoding="utf-8") as handle:
        content = handle.read()
    return Seed(content=content)
logger = <Logger sqlmesh.core.model.seed (WARNING)>
NaHashables = typing.List[typing.Union[int, str, bool, typing.Literal[None]]]
class CsvSettings(sqlmesh.utils.pydantic.PydanticModel):
class CsvSettings(PydanticModel):
    """CSV parsing options for seed files.

    Every field defaults to ``None``. The validators run in ``before`` mode and
    normalize raw inputs (including SQLGlot expressions) into plain Python values.
    """

    delimiter: t.Optional[str] = None
    quotechar: t.Optional[str] = None
    doublequote: t.Optional[bool] = None
    escapechar: t.Optional[str] = None
    skipinitialspace: t.Optional[bool] = None
    lineterminator: t.Optional[str] = None
    encoding: t.Optional[str] = None
    na_values: t.Optional[NaValues] = None
    keep_default_na: t.Optional[bool] = None

    @field_validator("doublequote", "skipinitialspace", "keep_default_na", mode="before")
    @classmethod
    def _bool_validator(cls, v: t.Any) -> t.Optional[bool]:
        """Pass ``None`` through untouched; hand every other value to ``parse_bool``."""
        return None if v is None else parse_bool(v)

    @field_validator(
        "delimiter", "quotechar", "escapechar", "lineterminator", "encoding", mode="before"
    )
    @classmethod
    def _str_validator(cls, v: t.Any) -> t.Optional[str]:
        """Unwrap a SQLGlot expression into its string payload.

        Values that are not SQLGlot expressions (``None`` included) are returned as-is.
        """
        if not isinstance(v, exp.Expr):
            return v

        # For dialects where \ is not an escape character, SQLGlot keeps sequences
        # such as \t as the two characters \\t; translate those back to the real
        # escaped character, leaving any other text unchanged.
        raw = v.this
        return UNESCAPED_SEQUENCES.get(raw, raw)

    @field_validator("na_values", mode="before")
    @classmethod
    def _na_values_validator(cls, v: t.Any) -> t.Optional[NaValues]:
        """Convert a SQLGlot expression into a list or per-column mapping of NA markers.

        Values that are not SQLGlot expressions (``None`` included) are returned as-is.
        If a ``ValueError`` is raised during conversion, a warning is logged and
        ``None`` is returned.
        """
        if not isinstance(v, exp.Expr):
            return v

        try:
            is_sequence = isinstance(v, (exp.Tuple, exp.Array))
            if isinstance(v, exp.Paren) or not is_sequence:
                # A lone (possibly parenthesized) value becomes a one-element tuple.
                v = exp.Tuple(expressions=[v.unnest()])

            items = v.expressions
            if isinstance(seq_get(items, 0), (exp.PropertyEQ, exp.EQ)):
                # Mapping form: each item is `column = values`.
                per_column = {}
                for item in items:
                    column = item.left.name
                    rhs = item.right
                    if isinstance(rhs, exp.Paren):
                        candidates = [rhs.unnest()]
                    else:
                        candidates = rhs.expressions
                    per_column[column] = [candidate.to_py() for candidate in candidates]
                return per_column

            return [item.to_py() for item in items]
        except ValueError as e:
            logger.warning(f"Failed to coerce na_values '{v}', proceeding with defaults. {str(e)}")

        return None

Settings for CSV seeds.

delimiter: Optional[str]
quotechar: Optional[str]
doublequote: Optional[bool]
escapechar: Optional[str]
skipinitialspace: Optional[bool]
lineterminator: Optional[str]
encoding: Optional[str]
na_values: Union[List[Union[int, str, bool, Literal[None]]], Dict[str, List[Union[int, str, bool, Literal[None]]]], NoneType]
keep_default_na: Optional[bool]
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
class CsvSeedReader:
 92class CsvSeedReader:
 93    def __init__(self, content: str, dialect: str, settings: CsvSettings):
 94        self.content = content
 95        self.dialect = dialect
 96        self.settings = settings
 97        self._df: t.Optional[pd.DataFrame] = None
 98
 99    @property
100    def columns_to_types(self) -> t.Dict[str, exp.DataType]:
101        return columns_to_types_from_df(self._get_df())
102
103    @property
104    def column_hashes(self) -> t.Dict[str, str]:
105        df = self._get_df()
106        return {
107            column_name: str(zlib.crc32(df[column_name].to_json().encode("utf-8")))
108            for column_name in df.columns
109        }
110
111    def read(self, batch_size: t.Optional[int] = None) -> t.Generator[pd.DataFrame, None, None]:
112        df = self._get_df()
113
114        batch_size = batch_size or df.size
115        batch_start = 0
116        while batch_start < df.shape[0]:
117            yield df.iloc[batch_start : batch_start + batch_size, :]
118            batch_start += batch_size
119
120    def _get_df(self) -> pd.DataFrame:
121        import pandas as pd
122
123        if self._df is None:
124            self._df = pd.read_csv(
125                StringIO(self.content),
126                index_col=False,
127                on_bad_lines="error",
128                low_memory=False,
129                **{k: v for k, v in self.settings.dict().items() if v is not None},
130            )
131            self._df = self._df.rename(
132                columns={
133                    col: normalize_identifiers(col, dialect=self.dialect).name
134                    for col in self._df.columns
135                },
136            )
137
138        return self._df
CsvSeedReader( content: str, dialect: str, settings: CsvSettings)
93    def __init__(self, content: str, dialect: str, settings: CsvSettings):
94        self.content = content
95        self.dialect = dialect
96        self.settings = settings
97        self._df: t.Optional[pd.DataFrame] = None
content
dialect
settings
columns_to_types: Dict[str, sqlglot.expressions.datatypes.DataType]
 99    @property
100    def columns_to_types(self) -> t.Dict[str, exp.DataType]:
101        return columns_to_types_from_df(self._get_df())
column_hashes: Dict[str, str]
103    @property
104    def column_hashes(self) -> t.Dict[str, str]:
105        df = self._get_df()
106        return {
107            column_name: str(zlib.crc32(df[column_name].to_json().encode("utf-8")))
108            for column_name in df.columns
109        }
def read( self, batch_size: Optional[int] = None) -> Generator[pandas.core.frame.DataFrame, NoneType, NoneType]:
111    def read(self, batch_size: t.Optional[int] = None) -> t.Generator[pd.DataFrame, None, None]:
112        df = self._get_df()
113
114        batch_size = batch_size or df.size
115        batch_start = 0
116        while batch_start < df.shape[0]:
117            yield df.iloc[batch_start : batch_start + batch_size, :]
118            batch_start += batch_size
class Seed(sqlmesh.utils.pydantic.PydanticModel):
class Seed(PydanticModel):
    """Holds the raw text of a seed file.

    Only the CSV format is supported at present.
    """

    content: str

    def reader(self, dialect: str = "", settings: t.Optional[CsvSettings] = None) -> CsvSeedReader:
        """Build a ``CsvSeedReader`` over this seed's content.

        A default ``CsvSettings()`` is used when ``settings`` is falsy.
        """
        csv_settings = settings or CsvSettings()
        return CsvSeedReader(self.content, dialect, csv_settings)

Represents content of a seed.

Presently only CSV format is supported.

content: str
def reader( self, dialect: str = '', settings: Optional[CsvSettings] = None) -> CsvSeedReader:
149    def reader(self, dialect: str = "", settings: t.Optional[CsvSettings] = None) -> CsvSeedReader:
150        return CsvSeedReader(self.content, dialect, settings or CsvSettings())
model_config = {'json_encoders': {<class 'sqlglot.expressions.core.Expr'>: <function _expression_encoder>, <class 'sqlglot.expressions.datatypes.DataType'>: <function _expression_encoder>, <class 'sqlglot.expressions.query.Tuple'>: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery]: <function _expression_encoder>, typing.Union[sqlglot.expressions.query.Query, sqlmesh.core.dialect.JinjaQuery, sqlmesh.core.dialect.MacroFunc]: <function _expression_encoder>, <class 'datetime.tzinfo'>: <function PydanticModel.<lambda>>}, 'arbitrary_types_allowed': True, 'extra': 'forbid', 'protected_namespaces': ()}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

Inherited Members
pydantic.main.BaseModel
BaseModel
model_fields
model_computed_fields
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
parse_file
from_orm
construct
schema
schema_json
validate
update_forward_refs
sqlmesh.utils.pydantic.PydanticModel
dict
json
copy
fields_set
parse_obj
parse_raw
missing_required_fields
extra_fields
all_fields
all_field_infos
required_fields
def create_seed(path: str | pathlib.Path) -> Seed:
def create_seed(path: str | Path) -> Seed:
    """Read the file at ``path`` as UTF-8 text and wrap its contents in a ``Seed``."""
    seed_path = Path(path)
    with seed_path.open("r", encoding="utf-8") as handle:
        content = handle.read()
    return Seed(content=content)