Skip to content

Commit daebcee

Browse files
authored
Ensure protected names (except for classes) are not imported, mainly to avoid importing _S and other TypeVars (#83)
1 parent 6062e85 commit daebcee

File tree

9 files changed

+132
-62
lines changed

9 files changed

+132
-62
lines changed

src/spellbind/float_values.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from typing_extensions import TYPE_CHECKING
1010

1111
from spellbind.bool_values import BoolValue
12-
from spellbind.functions import _clamp_float, _multiply_all_floats
12+
from spellbind.numbers import multiply_all_floats, clamp_float
1313
from spellbind.values import Value, SimpleVariable, OneToOneValue, DerivedValueBase, Constant, \
14-
NotConstantError, ThreeToOneValue, _create_value_getter, get_constant_of_generic_like
14+
NotConstantError, ThreeToOneValue, create_value_getter, get_constant_of_generic_like
1515

1616
if TYPE_CHECKING:
1717
from spellbind.int_values import IntValue, IntLike # pragma: no cover
@@ -41,10 +41,10 @@ def __rsub__(self, other: int | float) -> FloatValue:
4141
return FloatValue.derive_from_two(operator.sub, other, self)
4242

4343
def __mul__(self, other: FloatLike) -> FloatValue:
44-
return FloatValue.derive_from_many(_multiply_all_floats, self, other, is_associative=True)
44+
return FloatValue.derive_from_many(multiply_all_floats, self, other, is_associative=True)
4545

4646
def __rmul__(self, other: int | float) -> FloatValue:
47-
return FloatValue.derive_from_many(_multiply_all_floats, other, self, is_associative=True)
47+
return FloatValue.derive_from_many(multiply_all_floats, other, self, is_associative=True)
4848

4949
def __truediv__(self, other: FloatLike) -> FloatValue:
5050
return FloatValue.derive_from_two(operator.truediv, self, other)
@@ -110,7 +110,7 @@ def __pos__(self) -> Self:
110110
return self
111111

112112
def clamp(self, min_value: FloatLike, max_value: FloatLike) -> FloatValue:
113-
return FloatValue.derive_from_three_floats(_clamp_float, self, min_value, max_value)
113+
return FloatValue.derive_from_three_floats(clamp_float, self, min_value, max_value)
114114

115115
def decompose_float_operands(self, operator_: Callable[..., float]) -> Sequence[FloatLike]:
116116
return (self,)
@@ -204,7 +204,7 @@ def sum_floats(*values: FloatLike) -> FloatValue:
204204

205205

206206
def multiply_floats(*values: FloatLike) -> FloatValue:
207-
return FloatValue.derive_from_many(_multiply_all_floats, *values, is_associative=True)
207+
return FloatValue.derive_from_many(multiply_all_floats, *values, is_associative=True)
208208

209209

210210
class OneToFloatValue(Generic[_S], OneToOneValue[_S, float], FloatValue):
@@ -327,7 +327,7 @@ def __init__(self, transformer: Callable[[float, int], _S],
327327
self._of_first = first
328328
self._of_second = second
329329
self._first_getter = _create_float_getter(first)
330-
self._second_getter = _create_value_getter(second)
330+
self._second_getter = create_value_getter(second)
331331
super().__init__(*[v for v in (first, second) if isinstance(v, Value)])
332332

333333
@override

src/spellbind/functions.py

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import inspect
22
from inspect import Parameter
3-
from typing import Callable, Iterable, Any
3+
from typing import Callable, Any
44

55

6-
def _is_positional_parameter(param: Parameter) -> bool:
6+
def is_positional_parameter(param: Parameter) -> bool:
77
return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
88

99

@@ -14,16 +14,16 @@ def has_var_args(function: Callable[..., Any]) -> bool:
1414

1515
def count_positional_parameters(function: Callable[..., Any]) -> int:
1616
parameters = inspect.signature(function).parameters
17-
return sum(1 for parameter in parameters.values() if _is_positional_parameter(parameter))
17+
return sum(1 for parameter in parameters.values() if is_positional_parameter(parameter))
1818

1919

20-
def _is_required_positional_parameter(param: Parameter) -> bool:
21-
return param.default == param.empty and _is_positional_parameter(param)
20+
def is_required_positional_parameter(param: Parameter) -> bool:
21+
return param.default == param.empty and is_positional_parameter(param)
2222

2323

2424
def count_non_default_parameters(function: Callable[..., Any]) -> int:
2525
parameters = inspect.signature(function).parameters
26-
return sum(1 for param in parameters.values() if _is_required_positional_parameter(param))
26+
return sum(1 for param in parameters.values() if is_required_positional_parameter(param))
2727

2828

2929
def assert_parameter_max_count(callable_: Callable[..., Any], max_count: int) -> None:
@@ -36,33 +36,3 @@ def assert_parameter_max_count(callable_: Callable[..., Any], max_count: int) ->
3636
callable_name = str(callable_) # pragma: no cover
3737
raise ValueError(f"Callable {callable_name} has too many non-default parameters: "
3838
f"{count_non_default_parameters(callable_)} > {max_count}")
39-
40-
41-
def _multiply_all_ints(vals: Iterable[int]) -> int:
42-
result = 1
43-
for val in vals:
44-
result *= val
45-
return result
46-
47-
48-
def _multiply_all_floats(vals: Iterable[float]) -> float:
49-
result = 1.
50-
for val in vals:
51-
result *= val
52-
return result
53-
54-
55-
def _clamp_int(value: int, min_value: int, max_value: int) -> int:
56-
if value < min_value:
57-
return min_value
58-
elif value > max_value:
59-
return max_value
60-
return value
61-
62-
63-
def _clamp_float(value: float, min_value: float, max_value: float) -> float:
64-
if value < min_value:
65-
return min_value
66-
elif value > max_value:
67-
return max_value
68-
return value

src/spellbind/int_collections.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,20 @@
33
import operator
44
from abc import ABC, abstractmethod
55
from functools import cached_property
6-
from typing import Iterable, Callable, Any
6+
from typing import Iterable, Callable, Any, TypeVar
77

88
from typing_extensions import TypeIs, override
99

1010
from spellbind.int_values import IntValue, IntConstant
1111
from spellbind.observable_collections import ObservableCollection, ReducedValue, CombinedValue, ValueCollection
12-
from spellbind.observable_sequences import ObservableList, _S, TypedValueList, ValueSequence, UnboxedValueSequence, \
12+
from spellbind.observable_sequences import ObservableList, TypedValueList, ValueSequence, UnboxedValueSequence, \
1313
ObservableSequence
1414
from spellbind.values import Value
1515

1616

17+
_S = TypeVar("_S")
18+
19+
1720
class ObservableIntCollection(ObservableCollection[int], ABC):
1821
@property
1922
def summed(self) -> IntValue:

src/spellbind/int_values.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from spellbind.bool_values import BoolValue
1010
from spellbind.float_values import FloatValue, \
1111
CompareNumbersValues
12-
from spellbind.functions import _clamp_int, _multiply_all_ints, _multiply_all_floats
12+
from spellbind.numbers import multiply_all_ints, multiply_all_floats, clamp_int
1313
from spellbind.values import Value, SimpleVariable, TwoToOneValue, OneToOneValue, Constant, \
1414
ThreeToOneValue, NotConstantError, ManyToSameValue, get_constant_of_generic_like
1515

@@ -75,8 +75,8 @@ def __mul__(self, other: float | FloatValue) -> FloatValue: ...
7575

7676
def __mul__(self, other: FloatLike) -> IntValue | FloatValue:
7777
if isinstance(other, (float, FloatValue)):
78-
return FloatValue.derive_from_many(_multiply_all_floats, self, other, is_associative=True)
79-
return IntValue.derive_from_many(_multiply_all_ints, self, other, is_associative=True)
78+
return FloatValue.derive_from_many(multiply_all_floats, self, other, is_associative=True)
79+
return IntValue.derive_from_many(multiply_all_ints, self, other, is_associative=True)
8080

8181
@overload
8282
def __rmul__(self, other: int) -> IntValue: ...
@@ -86,8 +86,8 @@ def __rmul__(self, other: float) -> FloatValue: ...
8686

8787
def __rmul__(self, other: int | float) -> IntValue | FloatValue:
8888
if isinstance(other, float):
89-
return FloatValue.derive_from_many(_multiply_all_floats, other, self, is_associative=True)
90-
return IntValue.derive_from_many(_multiply_all_ints, other, self, is_associative=True)
89+
return FloatValue.derive_from_many(multiply_all_floats, other, self, is_associative=True)
90+
return IntValue.derive_from_many(multiply_all_ints, other, self, is_associative=True)
9191

9292
def __truediv__(self, other: FloatLike) -> FloatValue:
9393
return FloatValue.derive_from_two(operator.truediv, self, other)
@@ -135,7 +135,7 @@ def __pos__(self) -> Self:
135135
return self
136136

137137
def clamp(self, min_value: IntLike, max_value: IntLike) -> IntValue:
138-
return IntValue.derive_from_three(_clamp_int, self, min_value, max_value)
138+
return IntValue.derive_from_three(clamp_int, self, min_value, max_value)
139139

140140
@classmethod
141141
def derive_from_one(cls, operator_: Callable[[_S], int], value: _S | Value[_S]) -> IntValue:

src/spellbind/numbers.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from typing import Iterable
2+
3+
4+
def multiply_all_ints(vals: Iterable[int]) -> int:
5+
result = 1
6+
for val in vals:
7+
result *= val
8+
return result
9+
10+
11+
def multiply_all_floats(vals: Iterable[float]) -> float:
12+
result = 1.
13+
for val in vals:
14+
result *= val
15+
return result
16+
17+
18+
def clamp_int(value: int, min_value: int, max_value: int) -> int:
19+
if value < min_value:
20+
return min_value
21+
elif value > max_value:
22+
return max_value
23+
return value
24+
25+
26+
def clamp_float(value: float, min_value: float, max_value: float) -> float:
27+
if value < min_value:
28+
return min_value
29+
elif value > max_value:
30+
return max_value
31+
return value

src/spellbind/str_collections.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
from abc import ABC
2-
from typing import Iterable, Callable, Any
2+
from typing import Iterable, Callable, Any, TypeVar
33

44
from typing_extensions import TypeIs
55

66
from spellbind.int_values import IntValue
77
from spellbind.observable_collections import ObservableCollection, ReducedValue, CombinedValue
8-
from spellbind.observable_sequences import ObservableList, _S, TypedValueList
8+
from spellbind.observable_sequences import ObservableList, TypedValueList
99
from spellbind.str_values import StrValue, StrConstant
1010
from spellbind.values import Value
1111

1212

13+
_S = TypeVar("_S")
14+
15+
1316
class ObservableStrCollection(ObservableCollection[str], ABC):
1417
@property
1518
def concatenated(self) -> StrValue:

src/spellbind/values.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
_W = TypeVar("_W")
2828

2929

30-
def _create_value_getter(value: Value[_S] | _S) -> Callable[[], _S]:
30+
def create_value_getter(value: Value[_S] | _S) -> Callable[[], _S]:
3131
if isinstance(value, Value):
3232
return lambda: value.value
3333
else:
@@ -385,7 +385,7 @@ class OneToOneValue(DerivedValueBase[_T], Generic[_S, _T]):
385385
_getter: Callable[[], _S]
386386

387387
def __init__(self, transformer: Callable[[_S], _T], of: Value[_S]) -> None:
388-
self._getter = _create_value_getter(of)
388+
self._getter = create_value_getter(of)
389389
self._of = of
390390
self._transformer = transformer
391391
super().__init__(*[v for v in (of,) if isinstance(v, Value)])
@@ -398,7 +398,7 @@ def _calculate_value(self) -> _T:
398398
class ManyToOneValue(DerivedValueBase[_T], Generic[_S, _T]):
399399
def __init__(self, transformer: Callable[[Iterable[_S]], _T], *values: _S | Value[_S]):
400400
self._input_values = tuple(values)
401-
self._value_getters = [_create_value_getter(v) for v in self._input_values]
401+
self._value_getters = [create_value_getter(v) for v in self._input_values]
402402
self._transformer = transformer
403403
super().__init__(*[v for v in self._input_values if isinstance(v, Value)])
404404

@@ -422,8 +422,8 @@ def __init__(self, transformer: Callable[[_S, _T], _U],
422422
self._transformer = transformer
423423
self._of_first = first
424424
self._of_second = second
425-
self._first_getter = _create_value_getter(first)
426-
self._second_getter = _create_value_getter(second)
425+
self._first_getter = create_value_getter(first)
426+
self._second_getter = create_value_getter(second)
427427
super().__init__(*[v for v in (first, second) if isinstance(v, Value)])
428428

429429
@override
@@ -438,9 +438,9 @@ def __init__(self, transformer: Callable[[_S, _T, _U], _V],
438438
self._of_first = first
439439
self._of_second = second
440440
self._of_third = third
441-
self._first_getter = _create_value_getter(first)
442-
self._second_getter = _create_value_getter(second)
443-
self._third_getter = _create_value_getter(third)
441+
self._first_getter = create_value_getter(first)
442+
self._second_getter = create_value_getter(second)
443+
self._third_getter = create_value_getter(third)
444444
super().__init__(*[v for v in (first, second, third) if isinstance(v, Value)])
445445

446446
@override

tests/conftest.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import ast
12
from contextlib import contextmanager
2-
from typing import Any, Sequence, Callable
3+
from pathlib import Path
4+
from typing import Any, Sequence, Callable, Generator, Tuple
35
from typing import Iterable, overload, Collection
46
from unittest.mock import Mock
57

@@ -13,6 +15,54 @@
1315
_S = TypeVar("_S")
1416

1517

18+
PROJECT_ROOT_PATH = Path(__file__).parent.parent.resolve()
19+
SOURCE_PATH = PROJECT_ROOT_PATH / "src"
20+
21+
22+
def iter_python_files(source_path: Path) -> Generator[Path, None, None]:
23+
yield from source_path.rglob("*.py")
24+
25+
26+
def is_class_definition(module_path: Path, object_name: str) -> bool:
27+
text = module_path.read_text(encoding="utf-8")
28+
node = ast.parse(text, filename=str(module_path))
29+
for item in node.body:
30+
if hasattr(item, "name") and getattr(item, "name") == object_name:
31+
if isinstance(item, ast.ClassDef):
32+
return True
33+
else:
34+
return False
35+
return False
36+
37+
38+
def resolve_module_path(base_path: Path, module: str) -> Path:
39+
unfinished_module_path = base_path / Path(*module.split("."))
40+
init_path = unfinished_module_path / "__init__.py"
41+
if init_path.exists():
42+
return init_path
43+
file_path = unfinished_module_path.with_suffix(".py")
44+
return file_path
45+
46+
47+
def is_class_import(alias: ast.alias, import_: ast.ImportFrom, source_root: Path = SOURCE_PATH) -> bool:
48+
module = import_.module
49+
if module is None:
50+
return False
51+
module_path = resolve_module_path(source_root, module)
52+
if module_path is None:
53+
return False
54+
return is_class_definition(module_path, alias.name)
55+
56+
57+
def iter_imported_aliases(file_path: Path) -> Generator[Tuple[ast.alias, ast.ImportFrom], None, None]:
58+
text = file_path.read_text(encoding="utf-8")
59+
node = ast.parse(text, filename=str(file_path))
60+
for statement in ast.walk(node):
61+
if isinstance(statement, ast.ImportFrom):
62+
for alias_ in statement.names:
63+
yield alias_, statement
64+
65+
1666
class Call:
1767
def __init__(self, *args, **kwargs):
1868
self.args = args

tests/test_imports.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import pytest
2+
3+
from conftest import iter_imported_aliases, SOURCE_PATH, iter_python_files, is_class_import
4+
5+
6+
def test_no_protected_imports_except_for_classes():
7+
lines = []
8+
for file in iter_python_files(SOURCE_PATH):
9+
for alias_, statement in iter_imported_aliases(file):
10+
if alias_.name.startswith("_") and not is_class_import(alias_, statement):
11+
lines.append(f"{file}:{statement.lineno}: imports protected name '{alias_.name}'")
12+
if len(lines) > 0:
13+
pytest.fail(f"Found {len(lines)} protected imports\n" + "\n".join(lines))

0 commit comments

Comments
 (0)