1
1
from __future__ import annotations
2
2
3
3
import functools
4
+ import operator
4
5
import typing as t
6
+ from abc import abstractmethod
5
7
from dataclasses import dataclass , field , replace
6
8
7
- from packaging .markers import Marker as _Marker
9
+ from packaging .markers import default_environment
10
+ from packaging .specifiers import InvalidSpecifier , Specifier
11
+ from packaging .version import InvalidVersion
8
12
9
13
from dep_logic .markers .any import AnyMarker
10
- from dep_logic .markers .base import BaseMarker
14
+ from dep_logic .markers .base import BaseMarker , EvaluationContext
11
15
from dep_logic .markers .empty import EmptyMarker
12
16
from dep_logic .specifiers import BaseSpecifier
13
17
from dep_logic .specifiers .base import VersionSpecifier
14
18
from dep_logic .specifiers .generic import GenericSpecifier
15
- from dep_logic .utils import DATACLASS_ARGS , OrderedSet , get_reflect_op
19
+ from dep_logic .utils import DATACLASS_ARGS , OrderedSet , get_reflect_op , normalize_name
16
20
17
21
if t .TYPE_CHECKING :
18
22
from dep_logic .markers .multi import MultiMarker
19
23
from dep_logic .markers .union import MarkerUnion
20
24
21
25
PYTHON_VERSION_MARKERS = {"python_version" , "python_full_version" }
26
+ MARKERS_ALLOWING_SET = {"extras" , "dependency_groups" }
27
+ Operator = t .Callable [[str , t .Union [str , t .Set [str ]]], bool ]
28
+ _operators : dict [str , Operator ] = {
29
+ "in" : lambda lhs , rhs : lhs in rhs ,
30
+ "not in" : lambda lhs , rhs : lhs not in rhs ,
31
+ "<" : operator .lt ,
32
+ "<=" : operator .le ,
33
+ "==" : operator .eq ,
34
+ "!=" : operator .ne ,
35
+ ">=" : operator .ge ,
36
+ ">" : operator .gt ,
37
+ }
38
+
39
+
40
+ class UndefinedComparison (ValueError ):
41
+ pass
22
42
23
43
24
44
class SingleMarker (BaseMarker ):
@@ -44,16 +64,25 @@ def only(self, *marker_names: str) -> BaseMarker:
44
64
45
65
return self
46
66
47
- def evaluate (self , environment : dict [str , str ] | None = None ) -> bool :
48
- pkg_marker = _Marker (str (self ))
49
- if self .name != "extra" or not environment or "extra" not in environment :
50
- return pkg_marker .evaluate (environment )
51
- extras = [extra ] if isinstance (extra := environment ["extra" ], str ) else extra
52
- assert isinstance (self , MarkerExpression )
53
- is_negated = self .op in ("not in" , "!=" )
54
- if is_negated :
55
- return all (pkg_marker .evaluate ({"extra" : extra }) for extra in extras )
56
- return any (pkg_marker .evaluate ({"extra" : extra }) for extra in extras )
67
+ def evaluate (
68
+ self ,
69
+ environment : dict [str , str | set [str ]] | None = None ,
70
+ context : EvaluationContext = "metadata" ,
71
+ ) -> bool :
72
+ current_environment = t .cast ("dict[str, str|set[str]]" , default_environment ())
73
+ if context == "metadata" :
74
+ current_environment ["extra" ] = ""
75
+ elif context == "lock_file" :
76
+ current_environment .update (extras = set (), dependency_groups = set ())
77
+ if environment :
78
+ current_environment .update (environment )
79
+ if "extra" in current_environment and current_environment ["extra" ] is None :
80
+ current_environment ["extra" ] = ""
81
+ return self ._evaluate (current_environment )
82
+
83
+ @abstractmethod
84
+ def _evaluate (self , environment : dict [str , str | set [str ]]) -> bool :
85
+ raise NotImplementedError
57
86
58
87
59
88
@dataclass (unsafe_hash = True , ** DATACLASS_ARGS )
@@ -141,6 +170,46 @@ def __or__(self, other: t.Any) -> BaseMarker:
141
170
142
171
return MarkerUnion (self , other )
143
172
173
+ def _evaluate (self , environment : dict [str , str | set [str ]]) -> bool :
174
+ if self .name == "extra" :
175
+ # Support batch comparison for "extra" markers
176
+ extra = environment ["extra" ]
177
+ if isinstance (extra , str ):
178
+ extra = {extra }
179
+ assert self .op in ("==" , "!=" )
180
+ value = normalize_name (self .value )
181
+ extra = {normalize_name (v ) for v in extra }
182
+ return value in extra if self .op == "==" else value not in extra
183
+
184
+ target = environment [self .name ]
185
+ if self .reversed :
186
+ lhs , rhs = self .value , target
187
+ oper = _operators .get (get_reflect_op (self .op ))
188
+ else :
189
+ lhs , rhs = target , self .value
190
+ assert isinstance (lhs , str )
191
+ oper = _operators .get (self .op )
192
+ if self .name in MARKERS_ALLOWING_SET :
193
+ lhs = normalize_name (lhs )
194
+ if isinstance (rhs , set ):
195
+ rhs = {normalize_name (v ) for v in rhs }
196
+ else :
197
+ rhs = normalize_name (rhs )
198
+ if isinstance (rhs , str ):
199
+ try :
200
+ spec = Specifier (f"{ self .op } { rhs } " )
201
+ except InvalidSpecifier :
202
+ pass
203
+ else :
204
+ try :
205
+ return spec .contains (lhs )
206
+ except InvalidVersion :
207
+ pass
208
+
209
+ if oper is None :
210
+ raise UndefinedComparison (f"Undefined comparison { self } " )
211
+ return oper (lhs , rhs )
212
+
144
213
145
214
@dataclass (frozen = True , unsafe_hash = True , ** DATACLASS_ARGS )
146
215
class EqualityMarkerUnion (SingleMarker ):
@@ -210,6 +279,9 @@ def __or__(self, other: t.Any) -> BaseMarker:
210
279
__rand__ = __and__
211
280
__ror__ = __or__
212
281
282
+ def _evaluate (self , environment : dict [str , str | set [str ]]) -> bool :
283
+ return environment [self .name ] in self .values
284
+
213
285
214
286
@dataclass (frozen = True , unsafe_hash = True , ** DATACLASS_ARGS )
215
287
class InequalityMultiMarker (SingleMarker ):
@@ -283,6 +355,9 @@ def __or__(self, other: t.Any) -> BaseMarker:
283
355
__rand__ = __and__
284
356
__ror__ = __or__
285
357
358
+ def _evaluate (self , environment : dict [str , str | set [str ]]) -> bool :
359
+ return environment [self .name ] not in self .values
360
+
286
361
287
362
@functools .lru_cache (maxsize = None )
288
363
def _merge_single_markers (
@@ -375,5 +450,5 @@ def _normalize_python_version_specifier(marker: MarkerExpression) -> BaseSpecifi
375
450
splitted [- 1 ] = str (int (splitted [- 1 ]) + 1 )
376
451
op = "<"
377
452
378
- spec = parse_version_specifier (f' { op } { "." .join (splitted )} ' )
453
+ spec = parse_version_specifier (f" { op } { '.' .join (splitted )} " )
379
454
return spec
0 commit comments