Edit on GitHub

sqlmesh.core.state_sync.cache

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlmesh.core.model import SeedModel
  6from sqlmesh.core.snapshot import (
  7    Snapshot,
  8    SnapshotId,
  9    SnapshotIdLike,
 10    SnapshotIdAndVersionLike,
 11    SnapshotInfoLike,
 12)
 13from sqlmesh.core.snapshot.definition import Interval, SnapshotIntervals
 14from sqlmesh.core.state_sync.base import DelegatingStateSync, StateSync
 15from sqlmesh.core.state_sync.common import ExpiredBatchRange
 16from sqlmesh.utils.date import TimeLike, now_timestamp
 17
 18
 19class CachingStateSync(DelegatingStateSync):
 20    """In memory cache for snapshots that implements the state sync api.
 21
 22    Args:
 23        state_sync: The base state sync.
 24        ttl: The number of seconds a snapshot should be cached.
 25    """
 26
 27    def __init__(self, state_sync: StateSync, ttl: int = 120):
 28        super().__init__(state_sync)
 29        # The cache can contain a snapshot or False or None.
 30        # False means that the snapshot does not exist in the state sync but has been requested before
 31        # None means that the snapshot has not been requested.
 32        self.snapshot_cache: t.Dict[
 33            SnapshotId, t.Tuple[t.Optional[Snapshot | t.Literal[False]], int]
 34        ] = {}
 35
 36        self.ttl = ttl
 37
 38    def _from_cache(
 39        self, snapshot_id: SnapshotId, now: int
 40    ) -> t.Optional[Snapshot | t.Literal[False]]:
 41        snapshot: t.Optional[Snapshot | t.Literal[False]] = None
 42        snapshot_expiration = self.snapshot_cache.get(snapshot_id)
 43
 44        if snapshot_expiration and snapshot_expiration[1] >= now:
 45            snapshot = snapshot_expiration[0]
 46
 47        return snapshot
 48
 49    def get_snapshots(
 50        self, snapshot_ids: t.Iterable[SnapshotIdLike]
 51    ) -> t.Dict[SnapshotId, Snapshot]:
 52        existing = {}
 53        missing = set()
 54        now = now_timestamp()
 55        expire_at = now + self.ttl * 1000
 56
 57        for s in snapshot_ids:
 58            snapshot_id = s.snapshot_id
 59            snapshot = self._from_cache(snapshot_id, now)
 60
 61            if snapshot is None:
 62                self.snapshot_cache[snapshot_id] = (False, expire_at)
 63                missing.add(snapshot_id)
 64            elif snapshot:
 65                existing[snapshot_id] = snapshot
 66
 67        if missing:
 68            existing.update(self.state_sync.get_snapshots(missing))
 69
 70        for snapshot_id, snapshot in existing.items():
 71            cached = self._from_cache(snapshot_id, now)
 72            if cached and (not isinstance(cached.node, SeedModel) or cached.node.is_hydrated):
 73                continue
 74            self.snapshot_cache[snapshot_id] = (snapshot, expire_at)
 75
 76        return existing
 77
 78    def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]:
 79        existing = set()
 80        missing = set()
 81        now = now_timestamp()
 82
 83        for s in snapshot_ids:
 84            snapshot_id = s.snapshot_id
 85            snapshot = self._from_cache(snapshot_id, now)
 86            if snapshot:
 87                existing.add(snapshot_id)
 88            elif snapshot is None:
 89                missing.add(snapshot_id)
 90
 91        if missing:
 92            existing.update(self.state_sync.snapshots_exist(missing))
 93
 94        return existing
 95
 96    def push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None:
 97        snapshots = tuple(snapshots)
 98
 99        for snapshot in snapshots:
100            self.snapshot_cache.pop(snapshot.snapshot_id, None)
101
102        self.state_sync.push_snapshots(snapshots)
103
104    def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None:
105        snapshot_ids = tuple(snapshot_ids)
106
107        for s in snapshot_ids:
108            self.snapshot_cache.pop(s.snapshot_id, None)
109        self.state_sync.delete_snapshots(snapshot_ids)
110
111    def delete_expired_snapshots(
112        self,
113        batch_range: ExpiredBatchRange,
114        ignore_ttl: bool = False,
115        current_ts: t.Optional[int] = None,
116    ) -> None:
117        self.snapshot_cache.clear()
118        self.state_sync.delete_expired_snapshots(
119            batch_range=batch_range,
120            ignore_ttl=ignore_ttl,
121            current_ts=current_ts,
122        )
123
124    def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None:
125        for snapshot_intervals in snapshots_intervals:
126            if snapshot_intervals.snapshot_id:
127                self.snapshot_cache.pop(snapshot_intervals.snapshot_id, None)
128            else:
129                # Evict all snapshots that share the same name
130                self.snapshot_cache = {
131                    snapshot_id: value
132                    for snapshot_id, value in self.snapshot_cache.items()
133                    if snapshot_id.name != snapshot_intervals.name
134                }
135        self.state_sync.add_snapshots_intervals(snapshots_intervals)
136
137    def remove_intervals(
138        self,
139        snapshot_intervals: t.Sequence[t.Tuple[SnapshotIdAndVersionLike, Interval]],
140        remove_shared_versions: bool = False,
141    ) -> None:
142        for s, _ in snapshot_intervals:
143            self.snapshot_cache.pop(s.snapshot_id, None)
144        self.state_sync.remove_intervals(snapshot_intervals, remove_shared_versions)
145
146    def unpause_snapshots(
147        self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike
148    ) -> None:
149        self.snapshot_cache.clear()
150        self.state_sync.unpause_snapshots(snapshots, unpaused_dt)
151
152    def clear_cache(self) -> None:
153        self.snapshot_cache.clear()
class CachingStateSync(sqlmesh.core.state_sync.base.DelegatingStateSync):
 20class CachingStateSync(DelegatingStateSync):
 21    """In memory cache for snapshots that implements the state sync api.
 22
 23    Args:
 24        state_sync: The base state sync.
 25        ttl: The number of seconds a snapshot should be cached.
 26    """
 27
 28    def __init__(self, state_sync: StateSync, ttl: int = 120):
 29        super().__init__(state_sync)
 30        # The cache can contain a snapshot or False or None.
 31        # False means that the snapshot does not exist in the state sync but has been requested before
 32        # None means that the snapshot has not been requested.
 33        self.snapshot_cache: t.Dict[
 34            SnapshotId, t.Tuple[t.Optional[Snapshot | t.Literal[False]], int]
 35        ] = {}
 36
 37        self.ttl = ttl
 38
 39    def _from_cache(
 40        self, snapshot_id: SnapshotId, now: int
 41    ) -> t.Optional[Snapshot | t.Literal[False]]:
 42        snapshot: t.Optional[Snapshot | t.Literal[False]] = None
 43        snapshot_expiration = self.snapshot_cache.get(snapshot_id)
 44
 45        if snapshot_expiration and snapshot_expiration[1] >= now:
 46            snapshot = snapshot_expiration[0]
 47
 48        return snapshot
 49
 50    def get_snapshots(
 51        self, snapshot_ids: t.Iterable[SnapshotIdLike]
 52    ) -> t.Dict[SnapshotId, Snapshot]:
 53        existing = {}
 54        missing = set()
 55        now = now_timestamp()
 56        expire_at = now + self.ttl * 1000
 57
 58        for s in snapshot_ids:
 59            snapshot_id = s.snapshot_id
 60            snapshot = self._from_cache(snapshot_id, now)
 61
 62            if snapshot is None:
 63                self.snapshot_cache[snapshot_id] = (False, expire_at)
 64                missing.add(snapshot_id)
 65            elif snapshot:
 66                existing[snapshot_id] = snapshot
 67
 68        if missing:
 69            existing.update(self.state_sync.get_snapshots(missing))
 70
 71        for snapshot_id, snapshot in existing.items():
 72            cached = self._from_cache(snapshot_id, now)
 73            if cached and (not isinstance(cached.node, SeedModel) or cached.node.is_hydrated):
 74                continue
 75            self.snapshot_cache[snapshot_id] = (snapshot, expire_at)
 76
 77        return existing
 78
 79    def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]:
 80        existing = set()
 81        missing = set()
 82        now = now_timestamp()
 83
 84        for s in snapshot_ids:
 85            snapshot_id = s.snapshot_id
 86            snapshot = self._from_cache(snapshot_id, now)
 87            if snapshot:
 88                existing.add(snapshot_id)
 89            elif snapshot is None:
 90                missing.add(snapshot_id)
 91
 92        if missing:
 93            existing.update(self.state_sync.snapshots_exist(missing))
 94
 95        return existing
 96
 97    def push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None:
 98        snapshots = tuple(snapshots)
 99
100        for snapshot in snapshots:
101            self.snapshot_cache.pop(snapshot.snapshot_id, None)
102
103        self.state_sync.push_snapshots(snapshots)
104
105    def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None:
106        snapshot_ids = tuple(snapshot_ids)
107
108        for s in snapshot_ids:
109            self.snapshot_cache.pop(s.snapshot_id, None)
110        self.state_sync.delete_snapshots(snapshot_ids)
111
112    def delete_expired_snapshots(
113        self,
114        batch_range: ExpiredBatchRange,
115        ignore_ttl: bool = False,
116        current_ts: t.Optional[int] = None,
117    ) -> None:
118        self.snapshot_cache.clear()
119        self.state_sync.delete_expired_snapshots(
120            batch_range=batch_range,
121            ignore_ttl=ignore_ttl,
122            current_ts=current_ts,
123        )
124
125    def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None:
126        for snapshot_intervals in snapshots_intervals:
127            if snapshot_intervals.snapshot_id:
128                self.snapshot_cache.pop(snapshot_intervals.snapshot_id, None)
129            else:
130                # Evict all snapshots that share the same name
131                self.snapshot_cache = {
132                    snapshot_id: value
133                    for snapshot_id, value in self.snapshot_cache.items()
134                    if snapshot_id.name != snapshot_intervals.name
135                }
136        self.state_sync.add_snapshots_intervals(snapshots_intervals)
137
138    def remove_intervals(
139        self,
140        snapshot_intervals: t.Sequence[t.Tuple[SnapshotIdAndVersionLike, Interval]],
141        remove_shared_versions: bool = False,
142    ) -> None:
143        for s, _ in snapshot_intervals:
144            self.snapshot_cache.pop(s.snapshot_id, None)
145        self.state_sync.remove_intervals(snapshot_intervals, remove_shared_versions)
146
147    def unpause_snapshots(
148        self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike
149    ) -> None:
150        self.snapshot_cache.clear()
151        self.state_sync.unpause_snapshots(snapshots, unpaused_dt)
152
153    def clear_cache(self) -> None:
154        self.snapshot_cache.clear()

In memory cache for snapshots that implements the state sync api.

Arguments:
  • state_sync: The base state sync.
  • ttl: The number of seconds a snapshot should be cached.
CachingStateSync(state_sync: sqlmesh.core.state_sync.base.StateSync, ttl: int = 120)
28    def __init__(self, state_sync: StateSync, ttl: int = 120):
29        super().__init__(state_sync)
30        # The cache can contain a snapshot or False or None.
31        # False means that the snapshot does not exist in the state sync but has been requested before
32        # None means that the snapshot has not been requested.
33        self.snapshot_cache: t.Dict[
34            SnapshotId, t.Tuple[t.Optional[Snapshot | t.Literal[False]], int]
35        ] = {}
36
37        self.ttl = ttl
snapshot_cache: Dict[sqlmesh.core.snapshot.definition.SnapshotId, Tuple[Union[sqlmesh.core.snapshot.definition.Snapshot, Literal[False], NoneType], int]]
ttl
50    def get_snapshots(
51        self, snapshot_ids: t.Iterable[SnapshotIdLike]
52    ) -> t.Dict[SnapshotId, Snapshot]:
53        existing = {}
54        missing = set()
55        now = now_timestamp()
56        expire_at = now + self.ttl * 1000
57
58        for s in snapshot_ids:
59            snapshot_id = s.snapshot_id
60            snapshot = self._from_cache(snapshot_id, now)
61
62            if snapshot is None:
63                self.snapshot_cache[snapshot_id] = (False, expire_at)
64                missing.add(snapshot_id)
65            elif snapshot:
66                existing[snapshot_id] = snapshot
67
68        if missing:
69            existing.update(self.state_sync.get_snapshots(missing))
70
71        for snapshot_id, snapshot in existing.items():
72            cached = self._from_cache(snapshot_id, now)
73            if cached and (not isinstance(cached.node, SeedModel) or cached.node.is_hydrated):
74                continue
75            self.snapshot_cache[snapshot_id] = (snapshot, expire_at)
76
77        return existing

Bulk fetch snapshots given the corresponding snapshot ids.

Arguments:
  • snapshot_ids: Iterable of snapshot ids to get.
Returns:

A dictionary of snapshot ids to snapshots for ones that could be found.

79    def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]:
80        existing = set()
81        missing = set()
82        now = now_timestamp()
83
84        for s in snapshot_ids:
85            snapshot_id = s.snapshot_id
86            snapshot = self._from_cache(snapshot_id, now)
87            if snapshot:
88                existing.add(snapshot_id)
89            elif snapshot is None:
90                missing.add(snapshot_id)
91
92        if missing:
93            existing.update(self.state_sync.snapshots_exist(missing))
94
95        return existing

Checks if multiple snapshots exist in the state sync.

Arguments:
  • snapshot_ids: Iterable of snapshot ids to bulk check.
Returns:

A set of all the existing snapshot ids.

def push_snapshots( self, snapshots: Iterable[sqlmesh.core.snapshot.definition.Snapshot]) -> None:
 97    def push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None:
 98        snapshots = tuple(snapshots)
 99
100        for snapshot in snapshots:
101            self.snapshot_cache.pop(snapshot.snapshot_id, None)
102
103        self.state_sync.push_snapshots(snapshots)

Push snapshots into the state sync.

This method only allows for pushing new snapshots. If existing snapshots are found, this method should raise an error.

Raises:
  • SQLMeshError when existing snapshots are pushed.
Arguments:
  • snapshots: A list of snapshots to save in the state sync.
105    def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None:
106        snapshot_ids = tuple(snapshot_ids)
107
108        for s in snapshot_ids:
109            self.snapshot_cache.pop(s.snapshot_id, None)
110        self.state_sync.delete_snapshots(snapshot_ids)

Delete snapshots from the state sync.

Arguments:
  • snapshot_ids: A list of snapshot like objects to delete.
def delete_expired_snapshots( self, batch_range: sqlmesh.core.state_sync.common.ExpiredBatchRange, ignore_ttl: bool = False, current_ts: Optional[int] = None) -> None:
112    def delete_expired_snapshots(
113        self,
114        batch_range: ExpiredBatchRange,
115        ignore_ttl: bool = False,
116        current_ts: t.Optional[int] = None,
117    ) -> None:
118        self.snapshot_cache.clear()
119        self.state_sync.delete_expired_snapshots(
120            batch_range=batch_range,
121            ignore_ttl=ignore_ttl,
122            current_ts=current_ts,
123        )

Removes expired snapshots.

Expired snapshots are snapshots that have exceeded their time-to-live and are no longer in use within an environment.

Arguments:
  • batch_range: The range of snapshots to delete in this batch.
  • ignore_ttl: Ignore the TTL on the snapshot when considering it expired. This has the effect of deleting all snapshots that are not referenced in any environment.
  • current_ts: Timestamp used to evaluate expiration.
def add_snapshots_intervals( self, snapshots_intervals: Sequence[sqlmesh.core.snapshot.definition.SnapshotIntervals]) -> None:
125    def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None:
126        for snapshot_intervals in snapshots_intervals:
127            if snapshot_intervals.snapshot_id:
128                self.snapshot_cache.pop(snapshot_intervals.snapshot_id, None)
129            else:
130                # Evict all snapshots that share the same name
131                self.snapshot_cache = {
132                    snapshot_id: value
133                    for snapshot_id, value in self.snapshot_cache.items()
134                    if snapshot_id.name != snapshot_intervals.name
135                }
136        self.state_sync.add_snapshots_intervals(snapshots_intervals)

Add snapshot intervals to the state.

Arguments:
  • snapshots_intervals: The intervals to add.
def remove_intervals( self, snapshot_intervals: Sequence[Tuple[Union[sqlmesh.core.snapshot.definition.SnapshotIdAndVersion, sqlmesh.core.snapshot.definition.SnapshotTableInfo, sqlmesh.core.snapshot.definition.Snapshot], Tuple[int, int]]], remove_shared_versions: bool = False) -> None:
138    def remove_intervals(
139        self,
140        snapshot_intervals: t.Sequence[t.Tuple[SnapshotIdAndVersionLike, Interval]],
141        remove_shared_versions: bool = False,
142    ) -> None:
143        for s, _ in snapshot_intervals:
144            self.snapshot_cache.pop(s.snapshot_id, None)
145        self.state_sync.remove_intervals(snapshot_intervals, remove_shared_versions)

Remove an interval from a list of snapshots and sync it to the store.

Because multiple snapshots can be pointing to the same version or physical table, this method can also grab all snapshots tied to the passed in version.

Arguments:
  • snapshot_intervals: The snapshot intervals to remove.
  • remove_shared_versions: Whether to remove intervals for snapshots that share the same version with the target snapshots.
def unpause_snapshots( self, snapshots: Collection[Union[sqlmesh.core.snapshot.definition.SnapshotTableInfo, sqlmesh.core.snapshot.definition.Snapshot]], unpaused_dt: Union[datetime.date, datetime.datetime, str, int, float]) -> None:
147    def unpause_snapshots(
148        self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike
149    ) -> None:
150        self.snapshot_cache.clear()
151        self.state_sync.unpause_snapshots(snapshots, unpaused_dt)

Unpauses target snapshots.

Unpaused snapshots are scheduled for evaluation on a recurring basis. Once unpaused a snapshot can't be paused again.

Arguments:
  • snapshots: Target snapshots.
  • unpaused_dt: The datetime object which indicates when target snapshots were unpaused.
def clear_cache(self) -> None:
153    def clear_cache(self) -> None:
154        self.snapshot_cache.clear()