sqlmesh.core.snapshot.execution_tracker
1from __future__ import annotations 2 3import typing as t 4from contextlib import contextmanager 5from threading import local 6from dataclasses import dataclass, field 7from sqlmesh.core.snapshot import SnapshotIdBatch 8 9 10@dataclass 11class QueryExecutionStats: 12 snapshot_id_batch: SnapshotIdBatch 13 total_rows_processed: t.Optional[int] = None 14 total_bytes_processed: t.Optional[int] = None 15 16 17@dataclass 18class QueryExecutionContext: 19 """ 20 Container for tracking rows processed or other execution information during snapshot evaluation. 21 22 It accumulates statistics from multiple cursor.execute() calls during a single snapshot evaluation. 23 24 Attributes: 25 snapshot_id_batch: Identifier linking this context to a specific snapshot evaluation 26 stats: Running sum of cursor.rowcount and possibly bytes processed from all executed queries during evaluation 27 """ 28 29 snapshot_id_batch: SnapshotIdBatch 30 stats: QueryExecutionStats = field(init=False) 31 32 def __post_init__(self) -> None: 33 self.stats = QueryExecutionStats(snapshot_id_batch=self.snapshot_id_batch) 34 35 def add_execution( 36 self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int] 37 ) -> None: 38 if row_count is not None and row_count >= 0: 39 if self.stats.total_rows_processed is None: 40 self.stats.total_rows_processed = row_count 41 else: 42 self.stats.total_rows_processed += row_count 43 44 # conditional on row_count because we should only count bytes corresponding to 45 # DML actions whose rows were captured 46 if bytes_processed is not None: 47 if self.stats.total_bytes_processed is None: 48 self.stats.total_bytes_processed = bytes_processed 49 else: 50 self.stats.total_bytes_processed += bytes_processed 51 52 def get_execution_stats(self) -> QueryExecutionStats: 53 return self.stats 54 55 56class QueryExecutionTracker: 57 """Thread-local context manager for snapshot execution statistics, such as rows processed.""" 58 59 def __init__(self) -> None: 60 
self._thread_local = local() 61 self._contexts: t.Dict[SnapshotIdBatch, QueryExecutionContext] = {} 62 63 def get_execution_context( 64 self, snapshot_id_batch: SnapshotIdBatch 65 ) -> t.Optional[QueryExecutionContext]: 66 return self._contexts.get(snapshot_id_batch) 67 68 def is_tracking(self) -> bool: 69 return getattr(self._thread_local, "context", None) is not None 70 71 @contextmanager 72 def track_execution( 73 self, snapshot_id_batch: SnapshotIdBatch 74 ) -> t.Iterator[t.Optional[QueryExecutionContext]]: 75 """Context manager for tracking snapshot execution statistics such as row counts and bytes processed.""" 76 context = QueryExecutionContext(snapshot_id_batch=snapshot_id_batch) 77 self._thread_local.context = context 78 self._contexts[snapshot_id_batch] = context 79 80 try: 81 yield context 82 finally: 83 self._thread_local.context = None 84 85 def record_execution( 86 self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int] 87 ) -> None: 88 context = getattr(self._thread_local, "context", None) 89 if context is not None: 90 context.add_execution(sql, row_count, bytes_processed) 91 92 def get_execution_stats( 93 self, snapshot_id_batch: SnapshotIdBatch 94 ) -> t.Optional[QueryExecutionStats]: 95 context = self._contexts.get(snapshot_id_batch) 96 self._contexts.pop(snapshot_id_batch, None) 97 return context.get_execution_stats() if context else None
@dataclass
class
QueryExecutionStats:
@dataclass
class QueryExecutionStats:
    """Totals accumulated for one snapshot evaluation batch.

    Both counters default to None and stay None until a value is recorded for them.
    """

    # Identifier of the snapshot batch these totals belong to.
    snapshot_id_batch: SnapshotIdBatch
    # Sum of the row counts recorded so far; None when nothing was recorded.
    total_rows_processed: t.Optional[int] = None
    # Sum of the byte counts recorded so far; None when nothing was recorded.
    total_bytes_processed: t.Optional[int] = None
QueryExecutionStats( snapshot_id_batch: sqlmesh.core.snapshot.definition.SnapshotIdBatch, total_rows_processed: Optional[int] = None, total_bytes_processed: Optional[int] = None)
snapshot_id_batch: sqlmesh.core.snapshot.definition.SnapshotIdBatch
@dataclass
class
QueryExecutionContext:
@dataclass
class QueryExecutionContext:
    """
    Accumulator for the execution statistics of a single snapshot evaluation.

    Each statement executed during the evaluation is reported through add_execution(),
    which folds its row count and, when present, its bytes processed into `stats`.

    Attributes:
        snapshot_id_batch: Identifier linking this context to a specific snapshot evaluation
        stats: Running totals of rows and bytes processed; created in __post_init__
    """

    snapshot_id_batch: SnapshotIdBatch
    stats: QueryExecutionStats = field(init=False)

    def __post_init__(self) -> None:
        # `stats` is excluded from __init__, so it is built here from the batch id.
        self.stats = QueryExecutionStats(snapshot_id_batch=self.snapshot_id_batch)

    def add_execution(
        self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int]
    ) -> None:
        """Fold the result of one executed statement into the running totals.

        A call whose row_count is None or negative changes nothing, including the
        byte total.
        """
        if row_count is None or row_count < 0:
            return

        totals = self.stats
        rows_so_far = totals.total_rows_processed
        totals.total_rows_processed = row_count if rows_so_far is None else rows_so_far + row_count

        # conditional on row_count because we should only count bytes corresponding to
        # DML actions whose rows were captured
        if bytes_processed is None:
            return
        bytes_so_far = totals.total_bytes_processed
        totals.total_bytes_processed = (
            bytes_processed if bytes_so_far is None else bytes_so_far + bytes_processed
        )

    def get_execution_stats(self) -> QueryExecutionStats:
        """Return the statistics object held by this context (not a copy)."""
        return self.stats
Container for tracking rows processed or other execution information during snapshot evaluation.
It accumulates statistics from multiple cursor.execute() calls during a single snapshot evaluation.
Attributes:
- snapshot_id_batch: Identifier linking this context to a specific snapshot evaluation
- stats: Running sum of cursor.rowcount and possibly bytes processed from all executed queries during evaluation
QueryExecutionContext(snapshot_id_batch: sqlmesh.core.snapshot.definition.SnapshotIdBatch)
snapshot_id_batch: sqlmesh.core.snapshot.definition.SnapshotIdBatch
stats: QueryExecutionStats
def
add_execution( self, sql: str, row_count: Optional[int], bytes_processed: Optional[int]) -> None:
def add_execution(
    self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int]
) -> None:
    """Fold the result of one executed statement into the running totals.

    A call whose row_count is None or negative changes nothing, including the
    byte total. The `sql` argument is accepted but not used by the accumulation.
    """
    if row_count is None or row_count < 0:
        return

    totals = self.stats
    rows_so_far = totals.total_rows_processed
    totals.total_rows_processed = row_count if rows_so_far is None else rows_so_far + row_count

    # conditional on row_count because we should only count bytes corresponding to
    # DML actions whose rows were captured
    if bytes_processed is None:
        return
    bytes_so_far = totals.total_bytes_processed
    totals.total_bytes_processed = (
        bytes_processed if bytes_so_far is None else bytes_so_far + bytes_processed
    )
class
QueryExecutionTracker:
class QueryExecutionTracker:
    """Thread-local context manager for snapshot execution statistics, such as rows processed.

    The context that is currently active is stored per thread, while the registry that
    maps a snapshot batch id to its context is a plain dict on the tracker instance.
    """

    def __init__(self) -> None:
        self._thread_local = local()
        self._contexts: t.Dict[SnapshotIdBatch, QueryExecutionContext] = {}

    def get_execution_context(
        self, snapshot_id_batch: SnapshotIdBatch
    ) -> t.Optional[QueryExecutionContext]:
        """Return the registered context for `snapshot_id_batch`, or None if there is none."""
        return self._contexts.get(snapshot_id_batch)

    def is_tracking(self) -> bool:
        """Return True when the calling thread currently has an active context."""
        return getattr(self._thread_local, "context", None) is not None

    @contextmanager
    def track_execution(
        self, snapshot_id_batch: SnapshotIdBatch
    ) -> t.Iterator[t.Optional[QueryExecutionContext]]:
        """Context manager for tracking snapshot execution statistics such as row counts and bytes processed."""
        context = QueryExecutionContext(snapshot_id_batch=snapshot_id_batch)
        # Remember whatever was active on this thread so that a nested call does not
        # wipe out the enclosing context when it exits.
        previous = getattr(self._thread_local, "context", None)
        self._thread_local.context = context
        self._contexts[snapshot_id_batch] = context

        try:
            yield context
        finally:
            self._thread_local.context = previous

    def record_execution(
        self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int]
    ) -> None:
        """Forward one execution to the calling thread's active context; no-op without one."""
        context = getattr(self._thread_local, "context", None)
        if context is not None:
            context.add_execution(sql, row_count, bytes_processed)

    def get_execution_stats(
        self, snapshot_id_batch: SnapshotIdBatch
    ) -> t.Optional[QueryExecutionStats]:
        """Remove the context registered for `snapshot_id_batch` and return its stats.

        Returns None when no context is registered for the batch. Because the entry is
        removed, a second call for the same batch returns None.
        """
        # Single pop instead of get-then-pop: one lookup, and the entry is handed out once.
        context = self._contexts.pop(snapshot_id_batch, None)
        return context.get_execution_stats() if context is not None else None
Thread-local context manager for snapshot execution statistics, such as rows processed.
def
get_execution_context( self, snapshot_id_batch: sqlmesh.core.snapshot.definition.SnapshotIdBatch) -> Optional[QueryExecutionContext]:
@contextmanager
def
track_execution( self, snapshot_id_batch: sqlmesh.core.snapshot.definition.SnapshotIdBatch) -> Iterator[Optional[QueryExecutionContext]]:
@contextmanager
def track_execution(
    self, snapshot_id_batch: SnapshotIdBatch
) -> t.Iterator[t.Optional[QueryExecutionContext]]:
    """Context manager for tracking snapshot execution statistics such as row counts and bytes processed.

    A new QueryExecutionContext is created, made the calling thread's active context,
    and registered under `snapshot_id_batch` (replacing any entry already stored for
    that key). The context is yielded to the caller. On exit the thread's previously
    active context is restored; the registry entry is left in place.
    """
    context = QueryExecutionContext(snapshot_id_batch=snapshot_id_batch)
    # Remember whatever was active on this thread so that a nested call does not
    # wipe out the enclosing context when it exits.
    previous = getattr(self._thread_local, "context", None)
    self._thread_local.context = context
    self._contexts[snapshot_id_batch] = context

    try:
        yield context
    finally:
        self._thread_local.context = previous
Context manager for tracking snapshot execution statistics such as row counts and bytes processed.
def
record_execution( self, sql: str, row_count: Optional[int], bytes_processed: Optional[int]) -> None:
def
get_execution_stats( self, snapshot_id_batch: sqlmesh.core.snapshot.definition.SnapshotIdBatch) -> Optional[QueryExecutionStats]: