sqlmesh.core.linter.definition
"""Model linting: rule selection, rule sets, and the ``Linter`` that applies them."""

from __future__ import annotations

import operator as op
import typing as t
from collections.abc import Iterator, Iterable, Set, Mapping, Callable
from functools import reduce

from sqlmesh.core.config.linter import LinterConfig
from sqlmesh.core.console import LinterConsole, get_console
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range, Fix
from sqlmesh.core.model import Model
from sqlmesh.utils.errors import raise_config_error

if t.TYPE_CHECKING:
    from sqlmesh.core.context import GenericContext


def select_rules(all_rules: RuleSet, rule_names: t.Set[str]) -> RuleSet:
    """Return the subset of ``all_rules`` named in ``rule_names``.

    The special name ``"all"`` selects every rule in ``all_rules``.

    Args:
        all_rules: The rules to choose from.
        rule_names: Names of the rules to select.

    Returns:
        ``all_rules`` itself when ``"all"`` is requested, otherwise a new
        ``RuleSet`` holding the named rules.

    An unknown name is reported through ``raise_config_error``.
    """
    if "all" in rule_names:
        return all_rules

    rules = set()
    for rule_name in rule_names:
        if rule_name not in all_rules:
            raise_config_error(f"Rule {rule_name} could not be found")

        rules.add(all_rules[rule_name])

    return RuleSet(rules)


class Linter:
    """Applies a configured set of error and warning rules to models."""

    def __init__(
        self, enabled: bool, all_rules: RuleSet, rules: RuleSet, warn_rules: RuleSet
    ) -> None:
        """
        Args:
            enabled: When False, ``lint_model`` returns immediately without linting.
            all_rules: Every known rule; used to resolve a model's ``ignored_rules``.
            rules: Rules whose violations are reported as errors.
            warn_rules: Rules whose violations are reported as warnings.

        A rule present in both ``rules`` and ``warn_rules`` is reported through
        ``raise_config_error``.
        """
        self.enabled = enabled
        self.all_rules = all_rules
        self.rules = rules
        self.warn_rules = warn_rules

        if overlapping := rules.intersection(warn_rules):
            # The intersection is built from a set of rule classes, so its iteration
            # order is not stable; sort the names for a reproducible message.
            overlapping_rules = ", ".join(sorted(overlapping))
            raise_config_error(
                f"Rules cannot simultaneously warn and raise an error: [{overlapping_rules}]"
            )

    @classmethod
    def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter:
        """Build a linter from ``all_rules`` filtered through ``config``.

        Rules named in ``config.ignored_rules`` are removed first. ``config.rules``
        and ``config.warn_rules`` are then selected from what remains, so naming an
        ignored rule in either list is reported as an unknown rule.
        """
        ignored_rules = select_rules(all_rules, config.ignored_rules)
        included_rules = all_rules.difference(ignored_rules)

        rules = select_rules(included_rules, config.rules)
        warn_rules = select_rules(included_rules, config.warn_rules)

        # ``cls`` rather than ``Linter`` so subclasses construct their own type.
        return cls(config.enabled, all_rules, rules, warn_rules)

    def lint_model(
        self,
        model: Model,
        context: GenericContext,
        console: t.Optional[LinterConsole] = None,
    ) -> t.Tuple[bool, t.List[AnnotatedRuleViolation]]:
        """Lint ``model`` and show any violations on ``console``.

        Args:
            model: The model to lint; rules in ``model.ignored_rules`` are skipped.
            context: Passed to each rule's constructor.
            console: Where violations are shown. ``None`` (the default) means the
                console returned by ``get_console()`` at call time.

        Returns:
            ``(has_errors, violations)``. ``has_errors`` is True when at least one
            error-level rule was violated. ``violations`` holds every error
            violation followed by every warning violation.
        """
        if not self.enabled:
            return False, []

        # Resolve at call time: a ``get_console()`` default would be evaluated once
        # at import and would ignore any console configured afterwards.
        if console is None:
            console = get_console()

        ignored_rules = select_rules(self.all_rules, model.ignored_rules)

        rules = self.rules.difference(ignored_rules)
        warn_rules = self.warn_rules.difference(ignored_rules)

        error_violations = rules.check_model(model, context)
        warn_violations = warn_rules.check_model(model, context)

        def annotate(
            violations: t.List[RuleViolation],
            violation_type: t.Literal["error", "warning"],
        ) -> t.List[AnnotatedRuleViolation]:
            # Attach the model and severity to each raw violation.
            return [
                AnnotatedRuleViolation(
                    rule=violation.rule,
                    violation_msg=violation.violation_msg,
                    model=model,
                    violation_type=violation_type,
                    violation_range=violation.violation_range,
                    fixes=violation.fixes,
                )
                for violation in violations
            ]

        all_violations = annotate(error_violations, "error") + annotate(
            warn_violations, "warning"
        )

        if warn_violations:
            console.show_linter_violations(warn_violations, model)
        if error_violations:
            console.show_linter_violations(error_violations, model, is_error=True)
            return True, all_violations

        return False, all_violations


class RuleSet(Mapping[str, type[Rule]]):
    """Read-only mapping from rule name to ``Rule`` class, with set-style operations.

    Lookups accept either a rule name or a ``Rule`` class (resolved by its ``name``).
    """

    def __init__(self, rules: Iterable[type[Rule]] = ()) -> None:
        # Keyed by rule name; a later rule with a duplicate name replaces an earlier one.
        self._underlying = {rule.name: rule for rule in rules}

    def check_model(self, model: Model, context: GenericContext) -> t.List[RuleViolation]:
        """Run every rule against ``model`` and collect the violations.

        Each rule class is instantiated with ``context``. Its ``check_model`` result may
        be a single ``RuleViolation``, a collection of them, or a falsy value for none.
        """
        violations: t.List[RuleViolation] = []
        for rule in self._underlying.values():
            result = rule(context).check_model(model)
            if isinstance(result, RuleViolation):
                violations.append(result)
            elif result:
                violations.extend(result)
        return violations

    def __iter__(self) -> Iterator[str]:
        return iter(self._underlying)

    def __len__(self) -> int:
        return len(self._underlying)

    def __getitem__(self, rule: str | type[Rule]) -> type[Rule]:
        key = rule if isinstance(rule, str) else rule.name
        return self._underlying[key]

    def __repr__(self) -> str:
        return f"{type(self).__name__}({sorted(self._underlying)!r})"

    def __op(
        self,
        set_op: Callable[[Set[type[Rule]], Set[type[Rule]]], Set[type[Rule]]],
        other: RuleSet,
        /,
    ) -> RuleSet:
        """Apply ``set_op`` to the rule classes of ``self`` and ``other``.

        Classes are combined by identity. Each resulting class is then resolved by
        name, preferring the entry in ``other`` when both sets define that name.
        """
        # Named ``set_op`` so it does not shadow the module-level ``operator`` alias.
        combined = set_op(set(self.values()), set(other.values()))
        return RuleSet({other[rule] if rule in other else self[rule] for rule in combined})

    def union(self, *others: RuleSet) -> RuleSet:
        """Rules present in this set or any of ``others``."""
        return reduce(lambda lhs, rhs: lhs.__op(op.or_, rhs), (self, *others))

    def intersection(self, *others: RuleSet) -> RuleSet:
        """Rules present in this set and in every one of ``others``."""
        return reduce(lambda lhs, rhs: lhs.__op(op.and_, rhs), (self, *others))

    def difference(self, *others: RuleSet) -> RuleSet:
        """Rules in this set that are absent from all of ``others``."""
        return reduce(lambda lhs, rhs: lhs.__op(op.sub, rhs), (self, *others))


class AnnotatedRuleViolation(RuleViolation):
    """A ``RuleViolation`` tagged with the model it was found in and its severity."""

    def __init__(
        self,
        rule: Rule,
        violation_msg: str,
        model: Model,
        violation_type: t.Literal["error", "warning"],
        violation_range: t.Optional[Range] = None,
        fixes: t.Optional[t.List[Fix]] = None,
    ) -> None:
        super().__init__(rule, violation_msg, violation_range, fixes)
        self.model = model
        # "error" for violations of ``Linter.rules``, "warning" for ``Linter.warn_rules``.
        self.violation_type = violation_type
Callable[[Set[type[Rule]], Set[type[Rule]]], Set[type[Rule]]], 131 other: RuleSet, 132 /, 133 ) -> RuleSet: 134 rules = set() 135 for rule in op(set(self.values()), set(other.values())): 136 rules.add(other[rule] if rule in other else self[rule]) 137 138 return RuleSet(rules) 139 140 def union(self, *others: RuleSet) -> RuleSet: 141 return reduce(lambda lhs, rhs: lhs.__op(op.or_, rhs), (self, *others)) 142 143 def intersection(self, *others: RuleSet) -> RuleSet: 144 return reduce(lambda lhs, rhs: lhs.__op(op.and_, rhs), (self, *others)) 145 146 def difference(self, *others: RuleSet) -> RuleSet: 147 return reduce(lambda lhs, rhs: lhs.__op(op.sub, rhs), (self, *others)) 148 149 150class AnnotatedRuleViolation(RuleViolation): 151 def __init__( 152 self, 153 rule: Rule, 154 violation_msg: str, 155 model: Model, 156 violation_type: t.Literal["error", "warning"], 157 violation_range: t.Optional[Range] = None, 158 fixes: t.Optional[t.List[Fix]] = None, 159 ) -> None: 160 super().__init__(rule, violation_msg, violation_range, fixes) 161 self.model = model 162 self.violation_type = violation_type
def select_rules(all_rules: RuleSet, rule_names: t.Set[str]) -> RuleSet:
    """Pick the rules named in ``rule_names`` out of ``all_rules``.

    Passing the name ``"all"`` short-circuits and hands back ``all_rules`` unchanged.
    Any other unknown name is reported through ``raise_config_error``.
    """
    if "all" in rule_names:
        return all_rules

    def lookup(rule_name):
        # Validate the name before indexing so the config error is reported first.
        if rule_name not in all_rules:
            raise_config_error(f"Rule {rule_name} could not be found")
        return all_rules[rule_name]

    return RuleSet({lookup(rule_name) for rule_name in rule_names})
class
Linter:
class Linter:
    """Applies a configured set of error and warning rules to models."""

    def __init__(
        self, enabled: bool, all_rules: RuleSet, rules: RuleSet, warn_rules: RuleSet
    ) -> None:
        """
        Args:
            enabled: When False, ``lint_model`` returns immediately without linting.
            all_rules: Every known rule; used to resolve a model's ``ignored_rules``.
            rules: Rules whose violations are reported as errors.
            warn_rules: Rules whose violations are reported as warnings.

        A rule present in both ``rules`` and ``warn_rules`` is reported through
        ``raise_config_error``.
        """
        self.enabled = enabled
        self.all_rules = all_rules
        self.rules = rules
        self.warn_rules = warn_rules

        if overlapping := rules.intersection(warn_rules):
            # The intersection is built from a set of rule classes, so its iteration
            # order is not stable; sort the names for a reproducible message.
            overlapping_rules = ", ".join(sorted(overlapping))
            raise_config_error(
                f"Rules cannot simultaneously warn and raise an error: [{overlapping_rules}]"
            )

    @classmethod
    def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter:
        """Build a linter from ``all_rules`` filtered through ``config``.

        Rules named in ``config.ignored_rules`` are removed first. ``config.rules``
        and ``config.warn_rules`` are then selected from what remains, so naming an
        ignored rule in either list is reported as an unknown rule.
        """
        ignored_rules = select_rules(all_rules, config.ignored_rules)
        included_rules = all_rules.difference(ignored_rules)

        rules = select_rules(included_rules, config.rules)
        warn_rules = select_rules(included_rules, config.warn_rules)

        # ``cls`` rather than ``Linter`` so subclasses construct their own type.
        return cls(config.enabled, all_rules, rules, warn_rules)

    def lint_model(
        self,
        model: Model,
        context: GenericContext,
        console: t.Optional[LinterConsole] = None,
    ) -> t.Tuple[bool, t.List[AnnotatedRuleViolation]]:
        """Lint ``model`` and show any violations on ``console``.

        Args:
            model: The model to lint; rules in ``model.ignored_rules`` are skipped.
            context: Passed to each rule's constructor.
            console: Where violations are shown. ``None`` (the default) means the
                console returned by ``get_console()`` at call time.

        Returns:
            ``(has_errors, violations)``. ``has_errors`` is True when at least one
            error-level rule was violated. ``violations`` holds every error
            violation followed by every warning violation.
        """
        if not self.enabled:
            return False, []

        # Resolve at call time: a ``get_console()`` default would be evaluated once
        # at import and would ignore any console configured afterwards.
        if console is None:
            console = get_console()

        ignored_rules = select_rules(self.all_rules, model.ignored_rules)

        rules = self.rules.difference(ignored_rules)
        warn_rules = self.warn_rules.difference(ignored_rules)

        error_violations = rules.check_model(model, context)
        warn_violations = warn_rules.check_model(model, context)

        def annotate(
            violations: t.List[RuleViolation],
            violation_type: t.Literal["error", "warning"],
        ) -> t.List[AnnotatedRuleViolation]:
            # Attach the model and severity to each raw violation.
            return [
                AnnotatedRuleViolation(
                    rule=violation.rule,
                    violation_msg=violation.violation_msg,
                    model=model,
                    violation_type=violation_type,
                    violation_range=violation.violation_range,
                    fixes=violation.fixes,
                )
                for violation in violations
            ]

        all_violations = annotate(error_violations, "error") + annotate(
            warn_violations, "warning"
        )

        if warn_violations:
            console.show_linter_violations(warn_violations, model)
        if error_violations:
            console.show_linter_violations(error_violations, model, is_error=True)
            return True, all_violations

        return False, all_violations
console.show_linter_violations(warn_violations, model) 96 if error_violations: 97 console.show_linter_violations(error_violations, model, is_error=True) 98 return True, all_violations 99 100 return False, all_violations
def __init__(
    self, enabled: bool, all_rules: RuleSet, rules: RuleSet, warn_rules: RuleSet
) -> None:
    """Store the linter configuration and reject rules that both warn and error.

    Args:
        enabled: When False, ``lint_model`` returns immediately without linting.
        all_rules: Every known rule; used to resolve a model's ``ignored_rules``.
        rules: Rules whose violations are reported as errors.
        warn_rules: Rules whose violations are reported as warnings.

    A rule present in both ``rules`` and ``warn_rules`` is reported through
    ``raise_config_error``.
    """
    self.enabled = enabled
    self.all_rules = all_rules
    self.rules = rules
    self.warn_rules = warn_rules

    if overlapping := rules.intersection(warn_rules):
        # The intersection is built from a set of rule classes, so its iteration
        # order is not stable; sort the names for a reproducible message.
        overlapping_rules = ", ".join(sorted(overlapping))
        raise_config_error(
            f"Rules cannot simultaneously warn and raise an error: [{overlapping_rules}]"
        )
@classmethod
def
from_rules( cls, all_rules: RuleSet, config: sqlmesh.core.config.linter.LinterConfig) -> Linter:
@classmethod
def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter:
    """Build a linter from ``all_rules`` filtered through ``config``.

    Rules named in ``config.ignored_rules`` are removed first. ``config.rules`` and
    ``config.warn_rules`` are then selected from what remains, so naming an ignored
    rule in either list is reported as an unknown rule.
    """
    ignored_rules = select_rules(all_rules, config.ignored_rules)
    included_rules = all_rules.difference(ignored_rules)

    rules = select_rules(included_rules, config.rules)
    warn_rules = select_rules(included_rules, config.warn_rules)

    # ``cls`` rather than ``Linter`` so subclasses construct their own type.
    return cls(config.enabled, all_rules, rules, warn_rules)
def
lint_model( self, model: Union[sqlmesh.core.model.definition.SqlModel, sqlmesh.core.model.definition.SeedModel, sqlmesh.core.model.definition.PythonModel, sqlmesh.core.model.definition.ExternalModel], context: sqlmesh.core.context.GenericContext, console: sqlmesh.core.console.LinterConsole = <sqlmesh.core.console.NoopConsole object>) -> Tuple[bool, List[AnnotatedRuleViolation]]:
def lint_model(
    self,
    model: Model,
    context: GenericContext,
    console: t.Optional[LinterConsole] = None,
) -> t.Tuple[bool, t.List[AnnotatedRuleViolation]]:
    """Lint ``model`` and show any violations on ``console``.

    Args:
        model: The model to lint; rules in ``model.ignored_rules`` are skipped.
        context: Passed to each rule's constructor.
        console: Where violations are shown. ``None`` (the default) means the
            console returned by ``get_console()`` at call time.

    Returns:
        ``(has_errors, violations)``. ``has_errors`` is True when at least one
        error-level rule was violated. ``violations`` holds every error violation
        followed by every warning violation.
    """
    if not self.enabled:
        return False, []

    # Resolve at call time: a ``get_console()`` default would be evaluated once at
    # import and would ignore any console configured afterwards.
    if console is None:
        console = get_console()

    ignored_rules = select_rules(self.all_rules, model.ignored_rules)

    rules = self.rules.difference(ignored_rules)
    warn_rules = self.warn_rules.difference(ignored_rules)

    error_violations = rules.check_model(model, context)
    warn_violations = warn_rules.check_model(model, context)

    def annotate(
        violations: t.List[RuleViolation],
        violation_type: t.Literal["error", "warning"],
    ) -> t.List[AnnotatedRuleViolation]:
        # Attach the model and severity to each raw violation.
        return [
            AnnotatedRuleViolation(
                rule=violation.rule,
                violation_msg=violation.violation_msg,
                model=model,
                violation_type=violation_type,
                violation_range=violation.violation_range,
                fixes=violation.fixes,
            )
            for violation in violations
        ]

    all_violations = annotate(error_violations, "error") + annotate(
        warn_violations, "warning"
    )

    if warn_violations:
        console.show_linter_violations(warn_violations, model)
    if error_violations:
        console.show_linter_violations(error_violations, model, is_error=True)
        return True, all_violations

    return False, all_violations
class
RuleSet(collections.abc.Mapping[str, type[sqlmesh.core.linter.rule.Rule]]):
class RuleSet(Mapping[str, type[Rule]]):
    """Read-only mapping from rule name to ``Rule`` class, with set-style operations.

    Lookups accept either a rule name or a ``Rule`` class (resolved by its ``name``).
    """

    def __init__(self, rules: Iterable[type[Rule]] = ()) -> None:
        # Keyed by rule name; a later rule with a duplicate name replaces an earlier one.
        self._underlying = {rule.name: rule for rule in rules}

    def check_model(self, model: Model, context: GenericContext) -> t.List[RuleViolation]:
        """Run every rule against ``model`` and collect the violations.

        Each rule class is instantiated with ``context``. Its ``check_model`` result may
        be a single ``RuleViolation``, a collection of them, or a falsy value for none.
        """
        violations: t.List[RuleViolation] = []
        for rule in self._underlying.values():
            result = rule(context).check_model(model)
            if isinstance(result, RuleViolation):
                violations.append(result)
            elif result:
                violations.extend(result)
        return violations

    def __iter__(self) -> Iterator[str]:
        return iter(self._underlying)

    def __len__(self) -> int:
        return len(self._underlying)

    def __getitem__(self, rule: str | type[Rule]) -> type[Rule]:
        key = rule if isinstance(rule, str) else rule.name
        return self._underlying[key]

    def __repr__(self) -> str:
        return f"{type(self).__name__}({sorted(self._underlying)!r})"

    def __op(
        self,
        set_op: Callable[[Set[type[Rule]], Set[type[Rule]]], Set[type[Rule]]],
        other: RuleSet,
        /,
    ) -> RuleSet:
        """Apply ``set_op`` to the rule classes of ``self`` and ``other``.

        Classes are combined by identity. Each resulting class is then resolved by
        name, preferring the entry in ``other`` when both sets define that name.
        """
        # Named ``set_op`` so it does not shadow the module-level ``operator`` alias.
        combined = set_op(set(self.values()), set(other.values()))
        return RuleSet({other[rule] if rule in other else self[rule] for rule in combined})

    def union(self, *others: RuleSet) -> RuleSet:
        """Rules present in this set or any of ``others``."""
        return reduce(lambda lhs, rhs: lhs.__op(op.or_, rhs), (self, *others))

    def intersection(self, *others: RuleSet) -> RuleSet:
        """Rules present in this set and in every one of ``others``."""
        return reduce(lambda lhs, rhs: lhs.__op(op.and_, rhs), (self, *others))

    def difference(self, *others: RuleSet) -> RuleSet:
        """Rules in this set that are absent from all of ``others``."""
        return reduce(lambda lhs, rhs: lhs.__op(op.sub, rhs), (self, *others))
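A read-only mapping from rule name to `Rule` class (lookups also accept a `Rule` class, resolved by its `name`), with set-style `union`, `intersection`, and `difference` operations and a `check_model` helper.
As a `collections.abc.Mapping` subclass, it implements only `__getitem__`, `__iter__`, and `__len__` and inherits the remaining mapping methods.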
RuleSet( rules: collections.abc.Iterable[type[sqlmesh.core.linter.rule.Rule]] = ())
def
check_model( self, model: Union[sqlmesh.core.model.definition.SqlModel, sqlmesh.core.model.definition.SeedModel, sqlmesh.core.model.definition.PythonModel, sqlmesh.core.model.definition.ExternalModel], context: sqlmesh.core.context.GenericContext) -> List[sqlmesh.core.linter.rule.RuleViolation]:
def check_model(self, model: Model, context: GenericContext) -> t.List[RuleViolation]:
    """Instantiate each rule with ``context``, run it on ``model``, and gather results.

    A rule may report a single ``RuleViolation``, a collection of them, or a falsy
    value when it finds nothing; all reported violations are flattened into one list.
    """
    collected = []
    for rule_cls in self._underlying.values():
        outcome = rule_cls(context).check_model(model)
        if isinstance(outcome, RuleViolation):
            collected.append(outcome)
        elif outcome:
            collected.extend(outcome)
    return collected
Inherited Members
- collections.abc.Mapping
- get
- keys
- items
- values
class AnnotatedRuleViolation(RuleViolation):
    """A ``RuleViolation`` that also records the offending model and its severity."""

    def __init__(
        self,
        rule: Rule,
        violation_msg: str,
        model: Model,
        violation_type: t.Literal["error", "warning"],
        violation_range: t.Optional[Range] = None,
        fixes: t.Optional[t.List[Fix]] = None,
    ) -> None:
        super().__init__(rule, violation_msg, violation_range, fixes)
        # Context carried in addition to the base violation: where it was found
        # and whether it is reported as an "error" or a "warning".
        self.model, self.violation_type = model, violation_type
AnnotatedRuleViolation( rule: sqlmesh.core.linter.rule.Rule, violation_msg: str, model: Union[sqlmesh.core.model.definition.SqlModel, sqlmesh.core.model.definition.SeedModel, sqlmesh.core.model.definition.PythonModel, sqlmesh.core.model.definition.ExternalModel], violation_type: Literal['error', 'warning'], violation_range: Optional[sqlmesh.core.linter.rule.Range] = None, fixes: Optional[List[sqlmesh.core.linter.rule.Fix]] = None)
def __init__(
    self,
    rule: Rule,
    violation_msg: str,
    model: Model,
    violation_type: t.Literal["error", "warning"],
    violation_range: t.Optional[Range] = None,
    fixes: t.Optional[t.List[Fix]] = None,
) -> None:
    """Initialise the base violation, then attach the model and severity."""
    super().__init__(rule, violation_msg, violation_range, fixes)
    # Context carried in addition to the base violation: where it was found and
    # whether it is reported as an "error" or a "warning".
    self.model, self.violation_type = model, violation_type