Skip to content

Commit fd72b03

Browse files
authored
Allow Union and List input types (#1311)
* Allow Union and List input types Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Update TypeError messages to note support for Union and List types Signed-off-by: Mattt Zmuda <mattt@replicate.com> --------- Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent 6123056 commit fd72b03

File tree

4 files changed

+67
-15
lines changed

4 files changed

+67
-15
lines changed

python/cog/predictor.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,19 @@ def get_predict(predictor: Any) -> Callable:
214214
return predictor.predict
215215
return predictor
216216

217+
def validate_input_type(type: Type, name: str) -> None:
218+
if type is inspect.Signature.empty:
219+
raise TypeError(
220+
f"No input type provided for parameter `{name}`. Supported input types are: {readable_types_list(ALLOWED_INPUT_TYPES)}, or a Union or List of those types."
221+
)
222+
elif type not in ALLOWED_INPUT_TYPES:
223+
if hasattr(type, "__origin__") and (type.__origin__ is Union or type.__origin__ is list):
224+
for t in get_args(type):
225+
validate_input_type(t, name)
226+
else:
227+
raise TypeError(
228+
f"Unsupported input type {human_readable_type_name(type)} for parameter `{name}`. Supported input types are: {readable_types_list(ALLOWED_INPUT_TYPES)}, or a Union or List of those types."
229+
)
217230

218231
def get_input_type(predictor: BasePredictor) -> Type[BaseInput]:
219232
"""
@@ -238,14 +251,7 @@ class Input(BaseModel):
238251
for name, parameter in signature.parameters.items():
239252
InputType = parameter.annotation
240253

241-
if InputType is inspect.Signature.empty:
242-
raise TypeError(
243-
f"No input type provided for parameter `{name}`. Supported input types are: {readable_types_list(ALLOWED_INPUT_TYPES)}."
244-
)
245-
elif InputType not in ALLOWED_INPUT_TYPES:
246-
raise TypeError(
247-
f"Unsupported input type {human_readable_type_name(InputType)} for parameter `{name}`. Supported input types are: {readable_types_list(ALLOWED_INPUT_TYPES)}."
248-
)
254+
validate_input_type(InputType, name)
249255

250256
# if no default is specified, create an empty, required input
251257
if parameter.default is inspect.Signature.empty:
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from cog import BasePredictor, Path
2+
3+
from typing import List, Union
4+
5+
class Predictor(BasePredictor):
6+
def predict(self, args: Union[int, List[int]]) -> int:
7+
if isinstance(args, int):
8+
return args
9+
else:
10+
return sum(args)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from cog import BasePredictor, Path
2+
3+
from typing import List, Union
4+
5+
class Predictor(BasePredictor):
6+
def predict(self, args: Union[str, List[str]]) -> str:
7+
if isinstance(args, str):
8+
return args
9+
else:
10+
return "".join(args)

python/tests/server/test_http_input.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,6 @@ def test_empty_input(client, match):
3030
assert resp.json() == match({"status": "succeeded", "output": "foobar"})
3131

3232

33-
@uses_predictor("input_string")
34-
def test_good_str_input(client, match):
35-
resp = client.post("/predictions", json={"input": {"text": "baz"}})
36-
assert resp.status_code == 200
37-
assert resp.json() == match({"status": "succeeded", "output": "baz"})
38-
39-
4033
@uses_predictor("input_integer")
4134
def test_good_int_input(client, match):
4235
resp = client.post("/predictions", json={"input": {"num": 3}})
@@ -212,6 +205,39 @@ def test_choices_int(client):
212205
assert resp.status_code == 422
213206

214207

208+
@uses_predictor("input_union_string_or_list_of_strings")
209+
def test_union_strings(client):
210+
resp = client.post("/predictions", json={"input": {"args": "abc"}})
211+
assert resp.status_code == 200
212+
assert resp.json()["output"] == "abc"
213+
214+
resp = client.post("/predictions", json={"input": {"args": ["a", "b", "c"]}})
215+
assert resp.status_code == 200
216+
assert resp.json()["output"] == "abc"
217+
218+
# FIXME: Numbers are successfully cast to strings, but maybe shouldn't be
219+
# resp = client.post("/predictions", json={"input": {"args": 123}})
220+
# assert resp.status_code == 422
221+
# resp = client.post("/predictions", json={"input": {"args": [1, 2, 3]}})
222+
# assert resp.status_code == 422
223+
224+
225+
@uses_predictor("input_union_integer_or_list_of_integers")
226+
def test_union_integers(client):
227+
resp = client.post("/predictions", json={"input": {"args": 123}})
228+
assert resp.status_code == 200
229+
assert resp.json()["output"] == 123
230+
231+
resp = client.post("/predictions", json={"input": {"args": [1, 2, 3]}})
232+
assert resp.status_code == 200
233+
assert resp.json()["output"] == 6
234+
235+
resp = client.post("/predictions", json={"input": {"args": "abc"}})
236+
assert resp.status_code == 422
237+
resp = client.post("/predictions", json={"input": {"args": ["a", "b", "c"]}})
238+
assert resp.status_code == 422
239+
240+
215241
def test_untyped_inputs():
216242
with pytest.raises(TypeError):
217243
make_client("input_untyped")

0 commit comments

Comments
 (0)