Skip to content

Commit 69a75f7

Browse files
committed
Refactor field repetition
* Support `<T>`, `Optional[<T>]` and `List[<T>]` as repetition levels * `default=None` or absence means `None` default value for optional fields * Otherwise the field is required
1 parent 78919cc commit 69a75f7

File tree

9 files changed

+366
-250
lines changed

9 files changed

+366
-250
lines changed

python/coglet/adt.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from coglet import api
77

88

9-
class Type(Enum):
9+
class PrimitiveType(Enum):
1010
BOOL = 1
1111
FLOAT = 2
1212
INTEGER = 3
@@ -15,10 +15,22 @@ class Type(Enum):
1515
SECRET = 6
1616

1717

18-
NUMERIC_TYPES = {Type.FLOAT, Type.INTEGER}
19-
PATH_TYPES = {Type.STRING, Type.PATH}
20-
SECRET_TYPES = {Type.STRING, Type.SECRET}
21-
CHOICE_TYPES = {Type.INTEGER, Type.STRING}
18+
class Repetition(Enum):
19+
REQUIRED = 1
20+
OPTIONAL = 2
21+
REPEATED = 3
22+
23+
24+
@dataclass(frozen=True)
25+
class FieldType:
26+
primitive: PrimitiveType
27+
repetition: Repetition
28+
29+
30+
NUMERIC_TYPES = {PrimitiveType.FLOAT, PrimitiveType.INTEGER}
31+
PATH_TYPES = {PrimitiveType.STRING, PrimitiveType.PATH}
32+
SECRET_TYPES = {PrimitiveType.STRING, PrimitiveType.SECRET}
33+
CHOICE_TYPES = {PrimitiveType.INTEGER, PrimitiveType.STRING}
2234

2335

2436
class Kind(Enum):
@@ -37,31 +49,31 @@ class Kind(Enum):
3749

3850
# Python types to Cog types
3951
PYTHON_TO_COG = {
40-
bool: Type.BOOL,
41-
float: Type.FLOAT,
42-
int: Type.INTEGER,
43-
str: Type.STRING,
44-
api.Path: Type.PATH,
45-
api.Secret: Type.SECRET,
52+
bool: PrimitiveType.BOOL,
53+
float: PrimitiveType.FLOAT,
54+
int: PrimitiveType.INTEGER,
55+
str: PrimitiveType.STRING,
56+
api.Path: PrimitiveType.PATH,
57+
api.Secret: PrimitiveType.SECRET,
4658
}
4759

4860
# Cog types to JSON types
4961
COG_TO_JSON = {
50-
Type.BOOL: 'boolean',
51-
Type.FLOAT: 'number',
52-
Type.INTEGER: 'integer',
53-
Type.STRING: 'string',
54-
Type.PATH: 'string',
55-
Type.SECRET: 'string',
62+
PrimitiveType.BOOL: 'boolean',
63+
PrimitiveType.FLOAT: 'number',
64+
PrimitiveType.INTEGER: 'integer',
65+
PrimitiveType.STRING: 'string',
66+
PrimitiveType.PATH: 'string',
67+
PrimitiveType.SECRET: 'string',
5668
}
5769

5870
# JSON types to Cog types
5971
# PATH and SECRET depend on format field
6072
JSON_TO_COG = {
61-
'boolean': Type.BOOL,
62-
'number': Type.FLOAT,
63-
'integer': Type.INTEGER,
64-
'string': Type.STRING,
73+
'boolean': PrimitiveType.BOOL,
74+
'number': PrimitiveType.FLOAT,
75+
'integer': PrimitiveType.INTEGER,
76+
'string': PrimitiveType.STRING,
6577
}
6678

6779
# Python container types to Cog types
@@ -76,8 +88,7 @@ class Kind(Enum):
7688
class Input:
7789
name: str
7890
order: int
79-
type: Type
80-
is_list: bool
91+
type: FieldType
8192
default: Any = None
8293
description: Optional[str] = None
8394
ge: Optional[Union[int, float]] = None
@@ -91,8 +102,8 @@ class Input:
91102
@dataclass(frozen=True)
92103
class Output:
93104
kind: Kind
94-
type: Optional[Type] = None
95-
fields: Optional[Dict[str, Type]] = None
105+
type: Optional[PrimitiveType] = None
106+
fields: Optional[Dict[str, FieldType]] = None
96107

97108

98109
@dataclass(frozen=True)

python/coglet/inspector.py

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,11 @@ def _validate_predict(f: Callable) -> None:
3636
assert spec.annotations.get('return') is not None, 'predict() must not return None'
3737

3838

39-
def _validate_input(
40-
name: str, cog_t: adt.Type, is_list: bool, cog_in: api.Input
41-
) -> None:
39+
def _validate_input(name: str, ft: adt.FieldType, cog_in: api.Input) -> None:
4240
defaults = []
41+
cog_t = ft.primitive
4342
if cog_in.default is not None:
44-
if is_list:
43+
if ft.repetition is adt.Repetition.REPEATED:
4544
assert type(cog_in.default) is list, (
4645
f'default must be a list for input: {name}'
4746
)
@@ -67,7 +66,7 @@ def _validate_input(
6766
)
6867

6968
if cog_in.min_length is not None or cog_in.max_length is not None:
70-
assert cog_t is adt.Type.STRING, (
69+
assert cog_t is adt.PrimitiveType.STRING, (
7170
f'incompatible input type for min_length/max_length: {name}'
7271
)
7372
if cog_in.min_length is not None:
@@ -80,7 +79,9 @@ def _validate_input(
8079
)
8180

8281
if cog_in.regex is not None:
83-
assert cog_t is adt.Type.STRING, f'incompatible input type for regex: {name}'
82+
assert cog_t is adt.PrimitiveType.STRING, (
83+
f'incompatible input type for regex: {name}'
84+
)
8485
regex = re.compile(cog_in.regex)
8586
assert all(regex.match(x) for x in defaults), (
8687
f'not all defaults match regex for input: {name}'
@@ -103,29 +104,28 @@ def _validate_input(
103104
def _input_adt(
104105
order: int, name: str, tpe: type, cog_in: Optional[api.Input]
105106
) -> adt.Input:
106-
cog_t, is_list = util.check_cog_type(tpe)
107-
assert cog_t is not None, f'unsupported input type for {name}'
107+
ft = util.get_field_type(tpe)
108108
if cog_in is None:
109109
return adt.Input(
110110
name=name,
111111
order=order,
112-
type=cog_t,
113-
is_list=is_list,
112+
type=ft,
114113
)
115114
else:
116-
_validate_input(name, cog_t, is_list, cog_in)
115+
_validate_input(name, ft, cog_in)
117116
if cog_in.default is None:
118117
default = None
119118
else:
120-
if is_list:
121-
default = [util.normalize_value(cog_t, x) for x in cog_in.default]
119+
if ft.repetition is adt.Repetition.REPEATED:
120+
default = [
121+
util.normalize_value(ft.primitive, x) for x in cog_in.default
122+
]
122123
else:
123-
default = util.normalize_value(cog_t, cog_in.default)
124+
default = util.normalize_value(ft.primitive, cog_in.default)
124125
return adt.Input(
125126
name=name,
126127
order=order,
127-
type=cog_t,
128-
is_list=is_list,
128+
type=ft,
129129
default=default,
130130
description=cog_in.description,
131131
ge=float(cog_in.ge) if cog_in.ge is not None else None,
@@ -142,9 +142,11 @@ def _output_adt(tpe: type) -> adt.Output:
142142
assert tpe.__name__ == 'Output', 'output type must be named Output'
143143
fields = {}
144144
for name, t in tpe.__annotations__.items():
145-
cog_t, is_list = util.check_cog_type(t)
146-
assert not is_list, f'output field must not be list: {name}'
147-
fields[name] = cog_t
145+
ft = util.get_field_type(t)
146+
assert ft.repetition is not adt.Repetition.REPEATED, (
147+
f'output field must not be list: {name}'
148+
)
149+
fields[name] = ft
148150
return adt.Output(kind=adt.Kind.OBJECT, fields=fields)
149151

150152
kind = adt.CONTAINER_TO_COG.get(typing.get_origin(tpe)) or adt.Kind.SINGLE
@@ -201,8 +203,8 @@ def check_input(
201203
for name, value in inputs.items():
202204
assert name in adt_ins, f'unknown field: {name}'
203205
adt_in = adt_ins[name]
204-
cog_t = adt_in.type
205-
if adt_in.is_list:
206+
cog_t = adt_in.type.primitive
207+
if adt_in.type.repetition is adt.Repetition.REPEATED:
206208
assert all(util.check_value(cog_t, v) for v in value), (
207209
f'incompatible value for field: {name}={value}'
208210
)
@@ -215,12 +217,18 @@ def check_input(
215217
kwargs[name] = value
216218
for name, adt_in in adt_ins.items():
217219
if name not in kwargs:
218-
assert adt_in.default is not None, (
219-
f'missing default value for field: {name}'
220-
)
220+
# default=None is only allowed on `Optional[<type>]`
221+
if adt_in.type.repetition is not adt.Repetition.OPTIONAL:
222+
assert adt_in.default is not None or adt_in, (
223+
f'missing default value for field: {name}'
224+
)
221225
kwargs[name] = adt_in.default
222226

223-
values = kwargs[name] if adt_in.is_list else [kwargs[name]]
227+
values = (
228+
kwargs[name]
229+
if adt_in.type.repetition is adt.Repetition.REPEATED
230+
else [kwargs[name]]
231+
)
224232
v = kwargs[name]
225233
if adt_in.ge is not None:
226234
assert (x >= adt_in.ge for x in values), (
@@ -264,15 +272,20 @@ def check_output(adt_out: adt.Output, output: Any) -> Any:
264272
)
265273
output[i] = util.normalize_value(adt_out.type, x)
266274
return output
267-
elif adt_out.kind == adt.Kind.OBJECT:
275+
elif adt_out.kind is adt.Kind.OBJECT:
268276
assert adt_out.fields is not None, 'missing output fields'
269277
for name, tpe in adt_out.fields.items():
270278
assert hasattr(output, name), f'missing output field: {name}'
271279
value = getattr(output, name)
272-
assert util.check_value(tpe, value), (
273-
f'incompatible output for field: {name}={value}'
274-
)
275-
setattr(output, name, util.normalize_value(tpe, value))
280+
if value is None:
281+
assert tpe.repetition is adt.Repetition.OPTIONAL, (
282+
f'missing value for output field: {name}'
283+
)
284+
else:
285+
assert util.check_value(tpe.primitive, value), (
286+
f'incompatible output for field: {name}={value}'
287+
)
288+
setattr(output, name, util.normalize_value(tpe.primitive, value))
276289
return output
277290

278291

0 commit comments

Comments
 (0)