Skip to content
This repository was archived by the owner on Feb 19, 2023. It is now read-only.

Commit 18d81ef

Browse files
committed
added checks for return type annotations and variable assignment annotations
1 parent 9e1998f commit 18d81ef

File tree

2 files changed

+137
-3
lines changed

2 files changed

+137
-3
lines changed

pandas_dev_flaker/_plugins_tree/disallow_argument_types.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,57 @@
66
MSG = "PDF026 found union between Series and AnyArrayLike in type hint"
77

88

9+
# for function arguments annotations
910
@register(ast.FunctionDef)
10-
def visit_arg(
11+
def visit_FunctionDef(
1112
state: State,
1213
node: ast.FunctionDef,
1314
parent: ast.AST,
1415
) -> Iterator[Tuple[int, int, str]]:
1516
arguments = node.args.args
1617
for arg in arguments:
17-
if isinstance(arg.annotation, ast.BinOp) and parse_annotation(
18+
if isinstance(
19+
arg.annotation,
20+
ast.BinOp,
21+
) and _contains_series_and_arraylike(
1822
arg.annotation,
1923
):
2024
yield arg.lineno, arg.col_offset, MSG
2125

2226

23-
def parse_annotation(node: ast.BinOp) -> bool:
27+
# for return arguments annotations
28+
@register(ast.FunctionDef)
29+
def visit_args(
30+
state: State,
31+
node: ast.FunctionDef,
32+
parent: ast.AST,
33+
) -> Iterator[Tuple[int, int, str]]:
34+
return_arguments = node.returns # BinOp, Name, Subscript
35+
if (
36+
isinstance(
37+
return_arguments,
38+
ast.BinOp,
39+
)
40+
and _contains_series_and_arraylike(return_arguments)
41+
):
42+
yield node.lineno, node.col_offset, MSG
43+
44+
45+
# for annotations defined outside function args & return args
46+
@register(ast.AnnAssign)
47+
def visit_AnnAssign(
48+
state: State,
49+
node: ast.AnnAssign,
50+
parent: ast.AST,
51+
) -> Iterator[Tuple[int, int, str]]:
52+
annotations = node.annotation
53+
if isinstance(annotations, ast.BinOp) and _contains_series_and_arraylike(
54+
annotations,
55+
):
56+
yield node.lineno, node.col_offset, MSG
57+
58+
59+
def _contains_series_and_arraylike(node: ast.BinOp) -> bool:
2460
series, any_array_like = "Series", "AnyArrayLike"
2561
is_series, is_array_like = False, False
2662

tests/disallow_argument_types_test.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ def results(s):
4949
"def f(a: Callable[..., T] | DataFrame | list[int]): pass",
5050
id="Function annotation containing Subscript type",
5151
),
52+
pytest.param(
53+
"def f(a: DataFrame | list[int]) -> int | str: pass",
54+
id="Function return annotation containing Subscript type",
55+
),
5256
),
5357
)
5458
def test_noop(source):
@@ -96,3 +100,97 @@ def test_noop(source):
96100
def test_violation(source, expected):
97101
(result,) = results(source)
98102
assert result == expected
103+
104+
105+
@pytest.mark.parametrize(
106+
"source",
107+
(
108+
pytest.param(
109+
"def f(foo) -> int | str | bool: pass",
110+
id="Function with multiple return type annotations",
111+
),
112+
pytest.param(
113+
"def foo(bar: list[int]): pass",
114+
id="Function with no return type",
115+
),
116+
pytest.param(
117+
"def foo(self, bar: int) -> int: pass",
118+
id="Function with one return type annotation",
119+
),
120+
),
121+
)
122+
def test_noop2(source):
123+
assert not results(source)
124+
125+
126+
@pytest.mark.parametrize(
127+
"source, expected",
128+
(
129+
pytest.param(
130+
"def bar(foo, other: tuple[Callable[..., T]] | "
131+
"Series | list[int]) -> Series | AnyArrayLike | "
132+
"DataFrame: pass",
133+
"1:0: PDF026 found union between Series and "
134+
"AnyArrayLike in "
135+
"type hint",
136+
id="found union between Series and AnyArrayLike "
137+
"in return annotations",
138+
),
139+
pytest.param(
140+
"def bar(foo: int, other: tuple[Callable[..., T]] | "
141+
"Series | list[int]) -> Series | AnyArrayLike: pass",
142+
"1:0: PDF026 found union between Series and "
143+
"AnyArrayLike in "
144+
"type hint",
145+
id="found union between Series and AnyArrayLike "
146+
"in return annotations",
147+
),
148+
),
149+
)
150+
def test_violation2(source, expected):
151+
(result,) = results(source)
152+
assert result == expected
153+
154+
155+
@pytest.mark.parametrize(
156+
"source",
157+
(
158+
pytest.param(
159+
"foo: str = 'string variable'",
160+
id="Assignment with one annotation",
161+
),
162+
pytest.param(
163+
"self.bar: DataFrame | Timezone = [1, 2, 3]",
164+
id="Assignment with multiple annotations",
165+
),
166+
pytest.param("cls.foo = 3", id="Assignment with no annotation"),
167+
),
168+
)
169+
def test_noop3(source):
170+
assert not results(source)
171+
172+
173+
@pytest.mark.parametrize(
174+
"source, expected",
175+
(
176+
pytest.param(
177+
"self.foo: AnyArrayLike | Timezone | Series = 2",
178+
"1:0: PDF026 found union between Series and "
179+
"AnyArrayLike in "
180+
"type hint",
181+
id="found union between Series and AnyArrayLike "
182+
"in variable assignment",
183+
),
184+
pytest.param(
185+
"cls.foo: AnyArrayLike | Series = 2",
186+
"1:0: PDF026 found union between Series and "
187+
"AnyArrayLike in "
188+
"type hint",
189+
id="found union between Series and AnyArrayLike "
190+
"in variable assignment",
191+
),
192+
),
193+
)
194+
def test_violation3(source, expected):
195+
(result,) = results(source)
196+
assert result == expected

0 commit comments

Comments
 (0)