Edit on GitHub

sqlmesh.core.linter.definition

  1from __future__ import annotations
  2
  3import operator as op
  4import typing as t
  5from collections.abc import Iterator, Iterable, Set, Mapping, Callable
  6from functools import reduce
  7
  8from sqlmesh.core.config.linter import LinterConfig
  9from sqlmesh.core.console import LinterConsole, get_console
 10from sqlmesh.core.linter.rule import Rule, RuleViolation, Range, Fix
 11from sqlmesh.core.model import Model
 12from sqlmesh.utils.errors import raise_config_error
 13
 14if t.TYPE_CHECKING:
 15    from sqlmesh.core.context import GenericContext
 16
 17
 18def select_rules(all_rules: RuleSet, rule_names: t.Set[str]) -> RuleSet:
 19    if "all" in rule_names:
 20        return all_rules
 21
 22    rules = set()
 23    for rule_name in rule_names:
 24        if rule_name not in all_rules:
 25            raise_config_error(f"Rule {rule_name} could not be found")
 26
 27        rules.add(all_rules[rule_name])
 28
 29    return RuleSet(rules)
 30
 31
 32class Linter:
 33    def __init__(
 34        self, enabled: bool, all_rules: RuleSet, rules: RuleSet, warn_rules: RuleSet
 35    ) -> None:
 36        self.enabled = enabled
 37        self.all_rules = all_rules
 38        self.rules = rules
 39        self.warn_rules = warn_rules
 40
 41        if overlapping := rules.intersection(warn_rules):
 42            overlapping_rules = ", ".join(rule for rule in overlapping)
 43            raise_config_error(
 44                f"Rules cannot simultaneously warn and raise an error: [{overlapping_rules}]"
 45            )
 46
 47    @classmethod
 48    def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter:
 49        ignored_rules = select_rules(all_rules, config.ignored_rules)
 50        included_rules = all_rules.difference(ignored_rules)
 51
 52        rules = select_rules(included_rules, config.rules)
 53        warn_rules = select_rules(included_rules, config.warn_rules)
 54
 55        return Linter(config.enabled, all_rules, rules, warn_rules)
 56
 57    def lint_model(
 58        self, model: Model, context: GenericContext, console: LinterConsole = get_console()
 59    ) -> t.Tuple[bool, t.List[AnnotatedRuleViolation]]:
 60        if not self.enabled:
 61            return False, []
 62
 63        ignored_rules = select_rules(self.all_rules, model.ignored_rules)
 64
 65        rules = self.rules.difference(ignored_rules)
 66        warn_rules = self.warn_rules.difference(ignored_rules)
 67
 68        error_violations = rules.check_model(model, context)
 69        warn_violations = warn_rules.check_model(model, context)
 70
 71        all_violations: t.List[AnnotatedRuleViolation] = [
 72            AnnotatedRuleViolation(
 73                rule=violation.rule,
 74                violation_msg=violation.violation_msg,
 75                model=model,
 76                violation_type="error",
 77                violation_range=violation.violation_range,
 78                fixes=violation.fixes,
 79            )
 80            for violation in error_violations
 81        ] + [
 82            AnnotatedRuleViolation(
 83                rule=violation.rule,
 84                violation_msg=violation.violation_msg,
 85                model=model,
 86                violation_type="warning",
 87                violation_range=violation.violation_range,
 88                fixes=violation.fixes,
 89            )
 90            for violation in warn_violations
 91        ]
 92
 93        if warn_violations:
 94            console.show_linter_violations(warn_violations, model)
 95        if error_violations:
 96            console.show_linter_violations(error_violations, model, is_error=True)
 97            return True, all_violations
 98
 99        return False, all_violations
100
101
class RuleSet(Mapping[str, type[Rule]]):
    """A read-only mapping from rule name to rule class.

    The set-style helpers (``union``, ``intersection``, ``difference``) compare
    rules by name. When both operands contain a rule with the same name, the
    class from the right-hand operand is the one kept.
    """

    def __init__(self, rules: Iterable[type[Rule]] = ()) -> None:
        # Keyed by ``Rule.name``; a later rule with a duplicate name replaces an earlier one.
        self._underlying = {rule.name: rule for rule in rules}

    def check_model(self, model: Model, context: GenericContext) -> t.List[RuleViolation]:
        """Run every rule in the set against ``model`` and collect the violations.

        Each rule class is instantiated with ``context`` on every call. A rule's
        ``check_model`` may return a single ``RuleViolation``, a collection of
        them, or a falsy value when nothing was found.
        """
        violations: t.List[RuleViolation] = []
        for rule in self._underlying.values():
            result = rule(context).check_model(model)
            if isinstance(result, RuleViolation):
                violations.append(result)
            elif result:
                violations.extend(result)
        return violations

    def __iter__(self) -> Iterator[str]:
        return iter(self._underlying)

    def __len__(self) -> int:
        return len(self._underlying)

    def __getitem__(self, rule: str | type[Rule]) -> type[Rule]:
        """Look up a rule by name, or by a rule class (using its ``name``)."""
        key = rule if isinstance(rule, str) else rule.name
        return self._underlying[key]

    def __op(
        self,
        op: Callable[[Set[str], Set[str]], Set[str]],
        other: RuleSet,
        /,
    ) -> RuleSet:
        # Combine by rule *name*, not class identity. With identity-based sets two
        # distinct classes sharing a name never matched, so e.g. ``difference`` kept
        # the rule and then swapped in ``other``'s class instead of removing it.
        names = op(self.keys(), other.keys())
        # Walk self's names, then other's new names, so the resulting order is
        # deterministic rather than following set iteration order.
        ordered = [*self, *(name for name in other if name not in self)]
        return RuleSet(
            other[name] if name in other else self[name] for name in ordered if name in names
        )

    def union(self, *others: RuleSet) -> RuleSet:
        """Rules present in this set or any of ``others``."""
        return reduce(lambda lhs, rhs: lhs.__op(op.or_, rhs), (self, *others))

    def intersection(self, *others: RuleSet) -> RuleSet:
        """Rules present in this set and every one of ``others``."""
        return reduce(lambda lhs, rhs: lhs.__op(op.and_, rhs), (self, *others))

    def difference(self, *others: RuleSet) -> RuleSet:
        """Rules in this set that are in none of ``others``."""
        return reduce(lambda lhs, rhs: lhs.__op(op.sub, rhs), (self, *others))
148
149
class AnnotatedRuleViolation(RuleViolation):
    """A ``RuleViolation`` that also records its model and severity."""

    def __init__(
        self,
        rule: Rule,
        violation_msg: str,
        model: Model,
        violation_type: t.Literal["error", "warning"],
        violation_range: t.Optional[Range] = None,
        fixes: t.Optional[t.List[Fix]] = None,
    ) -> None:
        """
        Args:
            rule, violation_msg, violation_range, fixes: Forwarded to ``RuleViolation``.
            model: The model the violation is attached to.
            violation_type: Severity, ``"error"`` or ``"warning"``.
        """
        super().__init__(rule, violation_msg, violation_range, fixes)
        # Extra context layered on top of the base violation.
        self.model, self.violation_type = model, violation_type
def select_rules( all_rules: RuleSet, rule_names: Set[str]) -> RuleSet:
def select_rules(all_rules: RuleSet, rule_names: t.Set[str]) -> RuleSet:
    """Return the subset of ``all_rules`` named by ``rule_names``.

    The special name ``"all"`` short-circuits and returns ``all_rules`` itself.

    Args:
        all_rules: Pool of rules to pick from.
        rule_names: Names of the rules wanted.

    Returns:
        A new ``RuleSet`` with the requested rules, or ``all_rules`` when ``"all"``
        is among the names.

    Unknown names are reported through ``raise_config_error``.
    """
    if "all" in rule_names:
        return all_rules

    def _lookup(name: str) -> type[Rule]:
        # Report a missing name before attempting the lookup.
        if name not in all_rules:
            raise_config_error(f"Rule {name} could not be found")
        return all_rules[name]

    return RuleSet({_lookup(name) for name in rule_names})
class Linter:
 33class Linter:
 34    def __init__(
 35        self, enabled: bool, all_rules: RuleSet, rules: RuleSet, warn_rules: RuleSet
 36    ) -> None:
 37        self.enabled = enabled
 38        self.all_rules = all_rules
 39        self.rules = rules
 40        self.warn_rules = warn_rules
 41
 42        if overlapping := rules.intersection(warn_rules):
 43            overlapping_rules = ", ".join(rule for rule in overlapping)
 44            raise_config_error(
 45                f"Rules cannot simultaneously warn and raise an error: [{overlapping_rules}]"
 46            )
 47
 48    @classmethod
 49    def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter:
 50        ignored_rules = select_rules(all_rules, config.ignored_rules)
 51        included_rules = all_rules.difference(ignored_rules)
 52
 53        rules = select_rules(included_rules, config.rules)
 54        warn_rules = select_rules(included_rules, config.warn_rules)
 55
 56        return Linter(config.enabled, all_rules, rules, warn_rules)
 57
 58    def lint_model(
 59        self, model: Model, context: GenericContext, console: LinterConsole = get_console()
 60    ) -> t.Tuple[bool, t.List[AnnotatedRuleViolation]]:
 61        if not self.enabled:
 62            return False, []
 63
 64        ignored_rules = select_rules(self.all_rules, model.ignored_rules)
 65
 66        rules = self.rules.difference(ignored_rules)
 67        warn_rules = self.warn_rules.difference(ignored_rules)
 68
 69        error_violations = rules.check_model(model, context)
 70        warn_violations = warn_rules.check_model(model, context)
 71
 72        all_violations: t.List[AnnotatedRuleViolation] = [
 73            AnnotatedRuleViolation(
 74                rule=violation.rule,
 75                violation_msg=violation.violation_msg,
 76                model=model,
 77                violation_type="error",
 78                violation_range=violation.violation_range,
 79                fixes=violation.fixes,
 80            )
 81            for violation in error_violations
 82        ] + [
 83            AnnotatedRuleViolation(
 84                rule=violation.rule,
 85                violation_msg=violation.violation_msg,
 86                model=model,
 87                violation_type="warning",
 88                violation_range=violation.violation_range,
 89                fixes=violation.fixes,
 90            )
 91            for violation in warn_violations
 92        ]
 93
 94        if warn_violations:
 95            console.show_linter_violations(warn_violations, model)
 96        if error_violations:
 97            console.show_linter_violations(error_violations, model, is_error=True)
 98            return True, all_violations
 99
100        return False, all_violations
Linter( enabled: bool, all_rules: RuleSet, rules: RuleSet, warn_rules: RuleSet)
34    def __init__(
35        self, enabled: bool, all_rules: RuleSet, rules: RuleSet, warn_rules: RuleSet
36    ) -> None:
37        self.enabled = enabled
38        self.all_rules = all_rules
39        self.rules = rules
40        self.warn_rules = warn_rules
41
42        if overlapping := rules.intersection(warn_rules):
43            overlapping_rules = ", ".join(rule for rule in overlapping)
44            raise_config_error(
45                f"Rules cannot simultaneously warn and raise an error: [{overlapping_rules}]"
46            )
enabled
all_rules
rules
warn_rules
@classmethod
def from_rules( cls, all_rules: RuleSet, config: sqlmesh.core.config.linter.LinterConfig) -> Linter:
48    @classmethod
49    def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter:
50        ignored_rules = select_rules(all_rules, config.ignored_rules)
51        included_rules = all_rules.difference(ignored_rules)
52
53        rules = select_rules(included_rules, config.rules)
54        warn_rules = select_rules(included_rules, config.warn_rules)
55
56        return Linter(config.enabled, all_rules, rules, warn_rules)
 58    def lint_model(
 59        self, model: Model, context: GenericContext, console: LinterConsole = get_console()
 60    ) -> t.Tuple[bool, t.List[AnnotatedRuleViolation]]:
 61        if not self.enabled:
 62            return False, []
 63
 64        ignored_rules = select_rules(self.all_rules, model.ignored_rules)
 65
 66        rules = self.rules.difference(ignored_rules)
 67        warn_rules = self.warn_rules.difference(ignored_rules)
 68
 69        error_violations = rules.check_model(model, context)
 70        warn_violations = warn_rules.check_model(model, context)
 71
 72        all_violations: t.List[AnnotatedRuleViolation] = [
 73            AnnotatedRuleViolation(
 74                rule=violation.rule,
 75                violation_msg=violation.violation_msg,
 76                model=model,
 77                violation_type="error",
 78                violation_range=violation.violation_range,
 79                fixes=violation.fixes,
 80            )
 81            for violation in error_violations
 82        ] + [
 83            AnnotatedRuleViolation(
 84                rule=violation.rule,
 85                violation_msg=violation.violation_msg,
 86                model=model,
 87                violation_type="warning",
 88                violation_range=violation.violation_range,
 89                fixes=violation.fixes,
 90            )
 91            for violation in warn_violations
 92        ]
 93
 94        if warn_violations:
 95            console.show_linter_violations(warn_violations, model)
 96        if error_violations:
 97            console.show_linter_violations(error_violations, model, is_error=True)
 98            return True, all_violations
 99
100        return False, all_violations
class RuleSet(collections.abc.Mapping[str, type[sqlmesh.core.linter.rule.Rule]]):
class RuleSet(Mapping[str, type[Rule]]):
    """A read-only mapping from rule name to rule class.

    The set-style helpers (``union``, ``intersection``, ``difference``) compare
    rules by name. When both operands contain a rule with the same name, the
    class from the right-hand operand is the one kept.
    """

    def __init__(self, rules: Iterable[type[Rule]] = ()) -> None:
        # Keyed by ``Rule.name``; a later rule with a duplicate name replaces an earlier one.
        self._underlying = {rule.name: rule for rule in rules}

    def check_model(self, model: Model, context: GenericContext) -> t.List[RuleViolation]:
        """Run every rule in the set against ``model`` and collect the violations.

        Each rule class is instantiated with ``context`` on every call. A rule's
        ``check_model`` may return a single ``RuleViolation``, a collection of
        them, or a falsy value when nothing was found.
        """
        violations: t.List[RuleViolation] = []
        for rule in self._underlying.values():
            result = rule(context).check_model(model)
            if isinstance(result, RuleViolation):
                violations.append(result)
            elif result:
                violations.extend(result)
        return violations

    def __iter__(self) -> Iterator[str]:
        return iter(self._underlying)

    def __len__(self) -> int:
        return len(self._underlying)

    def __getitem__(self, rule: str | type[Rule]) -> type[Rule]:
        """Look up a rule by name, or by a rule class (using its ``name``)."""
        key = rule if isinstance(rule, str) else rule.name
        return self._underlying[key]

    def __op(
        self,
        op: Callable[[Set[str], Set[str]], Set[str]],
        other: RuleSet,
        /,
    ) -> RuleSet:
        # Combine by rule *name*, not class identity. With identity-based sets two
        # distinct classes sharing a name never matched, so e.g. ``difference`` kept
        # the rule and then swapped in ``other``'s class instead of removing it.
        names = op(self.keys(), other.keys())
        # Walk self's names, then other's new names, so the resulting order is
        # deterministic rather than following set iteration order.
        ordered = [*self, *(name for name in other if name not in self)]
        return RuleSet(
            other[name] if name in other else self[name] for name in ordered if name in names
        )

    def union(self, *others: RuleSet) -> RuleSet:
        """Rules present in this set or any of ``others``."""
        return reduce(lambda lhs, rhs: lhs.__op(op.or_, rhs), (self, *others))

    def intersection(self, *others: RuleSet) -> RuleSet:
        """Rules present in this set and every one of ``others``."""
        return reduce(lambda lhs, rhs: lhs.__op(op.and_, rhs), (self, *others))

    def difference(self, *others: RuleSet) -> RuleSet:
        """Rules in this set that are in none of ``others``."""
        return reduce(lambda lhs, rhs: lhs.__op(op.sub, rhs), (self, *others))

A Mapping is a generic container for associating key/value pairs.

This class provides concrete generic implementations of all methods except for __getitem__, __iter__, and __len__.

RuleSet( rules: collections.abc.Iterable[type[sqlmesh.core.linter.rule.Rule]] = ())
104    def __init__(self, rules: Iterable[type[Rule]] = ()) -> None:
105        self._underlying = {rule.name: rule for rule in rules}
107    def check_model(self, model: Model, context: GenericContext) -> t.List[RuleViolation]:
108        violations = []
109
110        for rule in self._underlying.values():
111            violation = rule(context).check_model(model)
112            if isinstance(violation, RuleViolation):
113                violation = [violation]
114            if violation:
115                violations.extend(violation)
116
117        return violations
def union( self, *others: RuleSet) -> RuleSet:
141    def union(self, *others: RuleSet) -> RuleSet:
142        return reduce(lambda lhs, rhs: lhs.__op(op.or_, rhs), (self, *others))
def intersection( self, *others: RuleSet) -> RuleSet:
144    def intersection(self, *others: RuleSet) -> RuleSet:
145        return reduce(lambda lhs, rhs: lhs.__op(op.and_, rhs), (self, *others))
def difference( self, *others: RuleSet) -> RuleSet:
147    def difference(self, *others: RuleSet) -> RuleSet:
148        return reduce(lambda lhs, rhs: lhs.__op(op.sub, rhs), (self, *others))
Inherited Members
collections.abc.Mapping
get
keys
items
values
class AnnotatedRuleViolation(sqlmesh.core.linter.rule.RuleViolation):
class AnnotatedRuleViolation(RuleViolation):
    """A ``RuleViolation`` that also records its model and severity."""

    def __init__(
        self,
        rule: Rule,
        violation_msg: str,
        model: Model,
        violation_type: t.Literal["error", "warning"],
        violation_range: t.Optional[Range] = None,
        fixes: t.Optional[t.List[Fix]] = None,
    ) -> None:
        """
        Args:
            rule, violation_msg, violation_range, fixes: Forwarded to ``RuleViolation``.
            model: The model the violation is attached to.
            violation_type: Severity, ``"error"`` or ``"warning"``.
        """
        super().__init__(rule, violation_msg, violation_range, fixes)
        # Extra context layered on top of the base violation.
        self.model, self.violation_type = model, violation_type
AnnotatedRuleViolation( rule: sqlmesh.core.linter.rule.Rule, violation_msg: str, model: Union[sqlmesh.core.model.definition.SqlModel, sqlmesh.core.model.definition.SeedModel, sqlmesh.core.model.definition.PythonModel, sqlmesh.core.model.definition.ExternalModel], violation_type: Literal['error', 'warning'], violation_range: Optional[sqlmesh.core.linter.rule.Range] = None, fixes: Optional[List[sqlmesh.core.linter.rule.Fix]] = None)
152    def __init__(
153        self,
154        rule: Rule,
155        violation_msg: str,
156        model: Model,
157        violation_type: t.Literal["error", "warning"],
158        violation_range: t.Optional[Range] = None,
159        fixes: t.Optional[t.List[Fix]] = None,
160    ) -> None:
161        super().__init__(rule, violation_msg, violation_range, fixes)
162        self.model = model
163        self.violation_type = violation_type
model
violation_type