Skip to content

Commit c1f4d11

Browse files
committed
Add test for Python 3.10+ unions
1 parent 0e91bf9 commit c1f4d11

File tree

4 files changed

+68
-9
lines changed

4 files changed

+68
-9
lines changed

python/coglet/adt.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
def _is_union(tpe: type) -> bool:
1414
if typing.get_origin(tpe) is Union:
1515
return True
16-
if sys.version_info[0] > 3 or (
17-
sys.version_info[0] == 3 and sys.version_info[1] >= 10
18-
):
16+
if sys.version_info >= (3, 10):
1917
from types import UnionType
2018

2119
if typing.get_origin(tpe) is UnionType:
@@ -144,7 +142,7 @@ def from_type(tpe: type):
144142
assert len(t_args) == 2 and type(None) in t_args, (
145143
f'unsupported union type {tpe}'
146144
)
147-
elem_t = t_args[0] if t_args[1] is type(None) else t_args[0]
145+
elem_t = t_args[0] if t_args[1] is type(None) else t_args[1]
148146
# Fail fast to avoid the cryptic "unsupported Cog type" error later with elem_t
149147
nested_t = typing.get_origin(elem_t)
150148
assert nested_t is None, f'Optional cannot have nested type {nested_t}'

python/tests/runners/unions.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from typing import Optional
2+
3+
from cog import BasePredictor
4+
from tests.util import check_python_version
5+
6+
check_python_version(min_version=(3, 10))
7+
8+
FIXTURE = [
9+
(
10+
{
11+
'os1': 'foo0',
12+
'os2': 'bar0',
13+
'os3': 'baz0',
14+
},
15+
'foo0-bar0-baz0',
16+
),
17+
]
18+
19+
20+
class Predictor(BasePredictor):
21+
test_inputs = {
22+
'os1': 'foo',
23+
'os2': 'bar',
24+
'os3': 'baz',
25+
}
26+
setup_done = False
27+
28+
def setup(self) -> None:
29+
self.setup_done = True
30+
31+
def predict(
32+
self,
33+
os1: Optional[str],
34+
os2: str | None,
35+
os3: None | str,
36+
) -> str:
37+
return f'{os1}-{os2}-{os3}'

python/tests/test_test_inputs.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
from coglet import inspector, runner, scope
8+
from tests.util import PythonVersionError
89

910

1011
def get_predictors() -> List[str]:
@@ -21,9 +22,12 @@ async def test_test_inputs(predictor):
2122
if predictor.startswith('function_'):
2223
entrypoint = 'predict'
2324

24-
p = inspector.create_predictor(module_name, entrypoint)
25-
r = runner.Runner(p)
25+
try:
26+
p = inspector.create_predictor(module_name, entrypoint)
27+
r = runner.Runner(p)
2628

27-
# Some predictors calls current_scope() and requires ctx_pid
28-
scope.ctx_pid.set(predictor)
29-
assert await r.test()
29+
# Some predictors calls current_scope() and requires ctx_pid
30+
scope.ctx_pid.set(predictor)
31+
assert await r.test()
32+
except PythonVersionError as e:
33+
pytest.skip(reason=str(e))

python/tests/util.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import sys
2+
from typing import Optional, Tuple
3+
4+
5+
def check_python_version(
6+
min_version: Optional[Tuple[int, int]] = None,
7+
max_version: Optional[Tuple[int, int]] = None,
8+
) -> None:
9+
if min_version is not None and sys.version_info < min_version:
10+
raise PythonVersionError(
11+
f'Python version must be >= {min_version[0]}.{min_version[1]}'
12+
)
13+
if max_version is not None and sys.version_info > max_version:
14+
raise PythonVersionError(
15+
f'Python version must be <= {max_version[0]}.{max_version[1]}'
16+
)
17+
18+
19+
class PythonVersionError(Exception):
20+
pass

0 commit comments

Comments
 (0)