Edit on GitHub

sqlmesh.core.snapshot.execution_tracker

 1from __future__ import annotations
 2
 3import typing as t
 4from contextlib import contextmanager
 5from threading import local
 6from dataclasses import dataclass, field
 7from sqlmesh.core.snapshot import SnapshotIdBatch
 8
 9
@dataclass
class QueryExecutionStats:
    """Execution totals associated with one snapshot id batch.

    Attributes:
        snapshot_id_batch: The batch these totals belong to.
        total_rows_processed: Total rows recorded so far, or None when nothing was recorded.
        total_bytes_processed: Total bytes recorded so far, or None when nothing was recorded.
    """

    snapshot_id_batch: SnapshotIdBatch
    total_rows_processed: t.Optional[int] = field(default=None)
    total_bytes_processed: t.Optional[int] = field(default=None)
15
16
@dataclass
class QueryExecutionContext:
    """
    Accumulator for execution information gathered while one snapshot batch is evaluated.

    Every recorded query execution is folded into a single QueryExecutionStats instance, so the
    totals cover all cursor.execute() calls reported during the evaluation.

    Attributes:
        snapshot_id_batch: Identifier linking this context to a specific snapshot evaluation
        stats: Running totals of rows and, when reported, bytes processed
    """

    snapshot_id_batch: SnapshotIdBatch
    stats: QueryExecutionStats = field(init=False)

    def __post_init__(self) -> None:
        # `stats` is not a constructor argument; it always starts out empty for this batch.
        self.stats = QueryExecutionStats(snapshot_id_batch=self.snapshot_id_batch)

    def add_execution(
        self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int]
    ) -> None:
        """Fold one query execution into the running totals.

        Args:
            sql: The executed statement. Currently unused by this method.
            row_count: Rows reported for the statement; skipped when None or negative.
            bytes_processed: Bytes reported for the statement; only counted when the
                row count was counted as well.
        """
        if row_count is None or not row_count >= 0:
            return

        totals = self.stats
        rows_so_far = totals.total_rows_processed
        totals.total_rows_processed = row_count if rows_so_far is None else rows_so_far + row_count

        # Bytes are only attributed to DML actions whose rows were captured, which is why
        # this part is reached only after the row_count guard above.
        if bytes_processed is None:
            return
        bytes_so_far = totals.total_bytes_processed
        totals.total_bytes_processed = (
            bytes_processed if bytes_so_far is None else bytes_so_far + bytes_processed
        )

    def get_execution_stats(self) -> QueryExecutionStats:
        """Return the live stats object held by this context (not a copy)."""
        return self.stats
54
55
class QueryExecutionTracker:
    """Thread-local context manager for snapshot execution statistics, such as rows processed.

    The context that is currently collecting executions is stored per thread. Every context is
    additionally registered in a dictionary keyed by snapshot id batch, where it stays until
    get_execution_stats() removes it.
    """

    def __init__(self) -> None:
        self._thread_local = local()
        self._contexts: t.Dict[SnapshotIdBatch, QueryExecutionContext] = {}

    def get_execution_context(
        self, snapshot_id_batch: SnapshotIdBatch
    ) -> t.Optional[QueryExecutionContext]:
        """Return the context registered for the batch, or None. The entry is left in place."""
        return self._contexts.get(snapshot_id_batch)

    def is_tracking(self) -> bool:
        """Return True when the calling thread currently has an active context."""
        return getattr(self._thread_local, "context", None) is not None

    @contextmanager
    def track_execution(
        self, snapshot_id_batch: SnapshotIdBatch
    ) -> t.Iterator[t.Optional[QueryExecutionContext]]:
        """Context manager for tracking snapshot execution statistics such as row counts and bytes processed.

        A fresh context is created for the batch, made the calling thread's active context and
        registered under the batch, replacing any context registered for the same batch before.

        Args:
            snapshot_id_batch: The batch whose executions should be collected.

        Yields:
            The newly created context.
        """
        previous = getattr(self._thread_local, "context", None)
        context = QueryExecutionContext(snapshot_id_batch=snapshot_id_batch)
        self._thread_local.context = context
        self._contexts[snapshot_id_batch] = context

        try:
            yield context
        finally:
            # Restore instead of clearing: when calls are nested on one thread, leaving the
            # inner block must not switch off tracking for the enclosing one. Without
            # nesting `previous` is None, so the thread simply stops tracking.
            self._thread_local.context = previous

    def record_execution(
        self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int]
    ) -> None:
        """Forward one execution to the calling thread's active context, if there is one.

        Args:
            sql: The executed statement.
            row_count: Rows reported for the statement, if any.
            bytes_processed: Bytes reported for the statement, if any.
        """
        context = getattr(self._thread_local, "context", None)
        if context is not None:
            context.add_execution(sql, row_count, bytes_processed)

    def get_execution_stats(
        self, snapshot_id_batch: SnapshotIdBatch
    ) -> t.Optional[QueryExecutionStats]:
        """Remove the batch's context from the registry and return its stats.

        Args:
            snapshot_id_batch: The batch whose stats are requested.

        Returns:
            The stats of the registered context, or None when no context is registered
            for the batch (for example because the stats were already fetched).
        """
        # A single pop both looks the context up and unregisters it.
        context = self._contexts.pop(snapshot_id_batch, None)
        if context is None:
            return None
        return context.get_execution_stats()
@dataclass
class QueryExecutionStats:
@dataclass
class QueryExecutionStats:
    """Execution totals associated with one snapshot id batch.

    Attributes:
        snapshot_id_batch: The batch these totals belong to.
        total_rows_processed: Total rows recorded so far, or None when nothing was recorded.
        total_bytes_processed: Total bytes recorded so far, or None when nothing was recorded.
    """

    snapshot_id_batch: SnapshotIdBatch
    total_rows_processed: t.Optional[int] = field(default=None)
    total_bytes_processed: t.Optional[int] = field(default=None)
QueryExecutionStats( snapshot_id_batch: sqlmesh.core.snapshot.definition.SnapshotIdBatch, total_rows_processed: Optional[int] = None, total_bytes_processed: Optional[int] = None)
total_rows_processed: Optional[int] = None
total_bytes_processed: Optional[int] = None
@dataclass
class QueryExecutionContext:
@dataclass
class QueryExecutionContext:
    """
    Accumulator for execution information gathered while one snapshot batch is evaluated.

    Every recorded query execution is folded into a single QueryExecutionStats instance, so the
    totals cover all cursor.execute() calls reported during the evaluation.

    Attributes:
        snapshot_id_batch: Identifier linking this context to a specific snapshot evaluation
        stats: Running totals of rows and, when reported, bytes processed
    """

    snapshot_id_batch: SnapshotIdBatch
    stats: QueryExecutionStats = field(init=False)

    def __post_init__(self) -> None:
        # `stats` is not a constructor argument; it always starts out empty for this batch.
        self.stats = QueryExecutionStats(snapshot_id_batch=self.snapshot_id_batch)

    def add_execution(
        self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int]
    ) -> None:
        """Fold one query execution into the running totals.

        Args:
            sql: The executed statement. Currently unused by this method.
            row_count: Rows reported for the statement; skipped when None or negative.
            bytes_processed: Bytes reported for the statement; only counted when the
                row count was counted as well.
        """
        if row_count is None or not row_count >= 0:
            return

        totals = self.stats
        rows_so_far = totals.total_rows_processed
        totals.total_rows_processed = row_count if rows_so_far is None else rows_so_far + row_count

        # Bytes are only attributed to DML actions whose rows were captured, which is why
        # this part is reached only after the row_count guard above.
        if bytes_processed is None:
            return
        bytes_so_far = totals.total_bytes_processed
        totals.total_bytes_processed = (
            bytes_processed if bytes_so_far is None else bytes_so_far + bytes_processed
        )

    def get_execution_stats(self) -> QueryExecutionStats:
        """Return the live stats object held by this context (not a copy)."""
        return self.stats

Container for tracking rows processed or other execution information during snapshot evaluation.

It accumulates statistics from multiple cursor.execute() calls during a single snapshot evaluation.

Attributes:
  • snapshot_id_batch: Identifier linking this context to a specific snapshot evaluation
  • stats: Running sum of cursor.rowcount and possibly bytes processed from all executed queries during evaluation
QueryExecutionContext(snapshot_id_batch: sqlmesh.core.snapshot.definition.SnapshotIdBatch)
def add_execution( self, sql: str, row_count: Optional[int], bytes_processed: Optional[int]) -> None:
36    def add_execution(
37        self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int]
38    ) -> None:
39        if row_count is not None and row_count >= 0:
40            if self.stats.total_rows_processed is None:
41                self.stats.total_rows_processed = row_count
42            else:
43                self.stats.total_rows_processed += row_count
44
45            # conditional on row_count because we should only count bytes corresponding to
46            # DML actions whose rows were captured
47            if bytes_processed is not None:
48                if self.stats.total_bytes_processed is None:
49                    self.stats.total_bytes_processed = bytes_processed
50                else:
51                    self.stats.total_bytes_processed += bytes_processed
def get_execution_stats(self) -> QueryExecutionStats:
53    def get_execution_stats(self) -> QueryExecutionStats:
54        return self.stats
class QueryExecutionTracker:
class QueryExecutionTracker:
    """Thread-local context manager for snapshot execution statistics, such as rows processed.

    The context that is currently collecting executions is stored per thread. Every context is
    additionally registered in a dictionary keyed by snapshot id batch, where it stays until
    get_execution_stats() removes it.
    """

    def __init__(self) -> None:
        self._thread_local = local()
        self._contexts: t.Dict[SnapshotIdBatch, QueryExecutionContext] = {}

    def get_execution_context(
        self, snapshot_id_batch: SnapshotIdBatch
    ) -> t.Optional[QueryExecutionContext]:
        """Return the context registered for the batch, or None. The entry is left in place."""
        return self._contexts.get(snapshot_id_batch)

    def is_tracking(self) -> bool:
        """Return True when the calling thread currently has an active context."""
        return getattr(self._thread_local, "context", None) is not None

    @contextmanager
    def track_execution(
        self, snapshot_id_batch: SnapshotIdBatch
    ) -> t.Iterator[t.Optional[QueryExecutionContext]]:
        """Context manager for tracking snapshot execution statistics such as row counts and bytes processed.

        A fresh context is created for the batch, made the calling thread's active context and
        registered under the batch, replacing any context registered for the same batch before.

        Args:
            snapshot_id_batch: The batch whose executions should be collected.

        Yields:
            The newly created context.
        """
        previous = getattr(self._thread_local, "context", None)
        context = QueryExecutionContext(snapshot_id_batch=snapshot_id_batch)
        self._thread_local.context = context
        self._contexts[snapshot_id_batch] = context

        try:
            yield context
        finally:
            # Restore instead of clearing: when calls are nested on one thread, leaving the
            # inner block must not switch off tracking for the enclosing one. Without
            # nesting `previous` is None, so the thread simply stops tracking.
            self._thread_local.context = previous

    def record_execution(
        self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int]
    ) -> None:
        """Forward one execution to the calling thread's active context, if there is one.

        Args:
            sql: The executed statement.
            row_count: Rows reported for the statement, if any.
            bytes_processed: Bytes reported for the statement, if any.
        """
        context = getattr(self._thread_local, "context", None)
        if context is not None:
            context.add_execution(sql, row_count, bytes_processed)

    def get_execution_stats(
        self, snapshot_id_batch: SnapshotIdBatch
    ) -> t.Optional[QueryExecutionStats]:
        """Remove the batch's context from the registry and return its stats.

        Args:
            snapshot_id_batch: The batch whose stats are requested.

        Returns:
            The stats of the registered context, or None when no context is registered
            for the batch (for example because the stats were already fetched).
        """
        # A single pop both looks the context up and unregisters it.
        context = self._contexts.pop(snapshot_id_batch, None)
        if context is None:
            return None
        return context.get_execution_stats()

Thread-local context manager for snapshot execution statistics, such as rows processed.

def get_execution_context( self, snapshot_id_batch: sqlmesh.core.snapshot.definition.SnapshotIdBatch) -> Optional[QueryExecutionContext]:
64    def get_execution_context(
65        self, snapshot_id_batch: SnapshotIdBatch
66    ) -> t.Optional[QueryExecutionContext]:
67        return self._contexts.get(snapshot_id_batch)
def is_tracking(self) -> bool:
69    def is_tracking(self) -> bool:
70        return getattr(self._thread_local, "context", None) is not None
@contextmanager
def track_execution( self, snapshot_id_batch: sqlmesh.core.snapshot.definition.SnapshotIdBatch) -> Iterator[Optional[QueryExecutionContext]]:
72    @contextmanager
73    def track_execution(
74        self, snapshot_id_batch: SnapshotIdBatch
75    ) -> t.Iterator[t.Optional[QueryExecutionContext]]:
76        """Context manager for tracking snapshot execution statistics such as row counts and bytes processed."""
77        context = QueryExecutionContext(snapshot_id_batch=snapshot_id_batch)
78        self._thread_local.context = context
79        self._contexts[snapshot_id_batch] = context
80
81        try:
82            yield context
83        finally:
84            self._thread_local.context = None

Context manager for tracking snapshot execution statistics such as row counts and bytes processed.

def record_execution( self, sql: str, row_count: Optional[int], bytes_processed: Optional[int]) -> None:
86    def record_execution(
87        self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int]
88    ) -> None:
89        context = getattr(self._thread_local, "context", None)
90        if context is not None:
91            context.add_execution(sql, row_count, bytes_processed)
def get_execution_stats( self, snapshot_id_batch: sqlmesh.core.snapshot.definition.SnapshotIdBatch) -> Optional[QueryExecutionStats]:
93    def get_execution_stats(
94        self, snapshot_id_batch: SnapshotIdBatch
95    ) -> t.Optional[QueryExecutionStats]:
96        context = self._contexts.get(snapshot_id_batch)
97        self._contexts.pop(snapshot_id_batch, None)
98        return context.get_execution_stats() if context else None