@@ -36,12 +36,11 @@ def _validate_predict(f: Callable) -> None:
3636 assert spec .annotations .get ('return' ) is not None , 'predict() must not return None'
3737
3838
39- def _validate_input (
40- name : str , cog_t : adt .Type , is_list : bool , cog_in : api .Input
41- ) -> None :
39+ def _validate_input (name : str , ft : adt .FieldType , cog_in : api .Input ) -> None :
4240 defaults = []
41+ cog_t = ft .primitive
4342 if cog_in .default is not None :
44- if is_list :
43+ if ft . repetition is adt . Repetition . REPEATED :
4544 assert type (cog_in .default ) is list , (
4645 f'default must be a list for input: { name } '
4746 )
@@ -67,7 +66,7 @@ def _validate_input(
6766 )
6867
6968 if cog_in .min_length is not None or cog_in .max_length is not None :
70- assert cog_t is adt .Type .STRING , (
69+ assert cog_t is adt .PrimitiveType .STRING , (
7170 f'incompatible input type for min_length/max_length: { name } '
7271 )
7372 if cog_in .min_length is not None :
@@ -80,7 +79,9 @@ def _validate_input(
8079 )
8180
8281 if cog_in .regex is not None :
83- assert cog_t is adt .Type .STRING , f'incompatible input type for regex: { name } '
82+ assert cog_t is adt .PrimitiveType .STRING , (
83+ f'incompatible input type for regex: { name } '
84+ )
8485 regex = re .compile (cog_in .regex )
8586 assert all (regex .match (x ) for x in defaults ), (
8687 f'not all defaults match regex for input: { name } '
@@ -103,29 +104,28 @@ def _validate_input(
103104def _input_adt (
104105 order : int , name : str , tpe : type , cog_in : Optional [api .Input ]
105106) -> adt .Input :
106- cog_t , is_list = util .check_cog_type (tpe )
107- assert cog_t is not None , f'unsupported input type for { name } '
107+ ft = util .get_field_type (tpe )
108108 if cog_in is None :
109109 return adt .Input (
110110 name = name ,
111111 order = order ,
112- type = cog_t ,
113- is_list = is_list ,
112+ type = ft ,
114113 )
115114 else :
116- _validate_input (name , cog_t , is_list , cog_in )
115+ _validate_input (name , ft , cog_in )
117116 if cog_in .default is None :
118117 default = None
119118 else :
120- if is_list :
121- default = [util .normalize_value (cog_t , x ) for x in cog_in .default ]
119+ if ft .repetition is adt .Repetition .REPEATED :
120+ default = [
121+ util .normalize_value (ft .primitive , x ) for x in cog_in .default
122+ ]
122123 else :
123- default = util .normalize_value (cog_t , cog_in .default )
124+ default = util .normalize_value (ft . primitive , cog_in .default )
124125 return adt .Input (
125126 name = name ,
126127 order = order ,
127- type = cog_t ,
128- is_list = is_list ,
128+ type = ft ,
129129 default = default ,
130130 description = cog_in .description ,
131131 ge = float (cog_in .ge ) if cog_in .ge is not None else None ,
@@ -142,9 +142,11 @@ def _output_adt(tpe: type) -> adt.Output:
142142 assert tpe .__name__ == 'Output' , 'output type must be named Output'
143143 fields = {}
144144 for name , t in tpe .__annotations__ .items ():
145- cog_t , is_list = util .check_cog_type (t )
146- assert not is_list , f'output field must not be list: { name } '
147- fields [name ] = cog_t
145+ ft = util .get_field_type (t )
146+ assert ft .repetition is not adt .Repetition .REPEATED , (
147+ f'output field must not be list: { name } '
148+ )
149+ fields [name ] = ft
148150 return adt .Output (kind = adt .Kind .OBJECT , fields = fields )
149151
150152 kind = adt .CONTAINER_TO_COG .get (typing .get_origin (tpe )) or adt .Kind .SINGLE
@@ -201,8 +203,8 @@ def check_input(
201203 for name , value in inputs .items ():
202204 assert name in adt_ins , f'unknown field: { name } '
203205 adt_in = adt_ins [name ]
204- cog_t = adt_in .type
205- if adt_in .is_list :
206+ cog_t = adt_in .type . primitive
207+ if adt_in .type . repetition is adt . Repetition . REPEATED :
206208 assert all (util .check_value (cog_t , v ) for v in value ), (
207209 f'incompatible value for field: { name } ={ value } '
208210 )
@@ -215,12 +217,18 @@ def check_input(
215217 kwargs [name ] = value
216218 for name , adt_in in adt_ins .items ():
217219 if name not in kwargs :
218- assert adt_in .default is not None , (
219- f'missing default value for field: { name } '
220- )
220+ # default=None is only allowed on `Optional[<type>]`
221+ if adt_in .type .repetition is not adt .Repetition .OPTIONAL :
222+ assert adt_in .default is not None or adt_in , (
223+ f'missing default value for field: { name } '
224+ )
221225 kwargs [name ] = adt_in .default
222226
223- values = kwargs [name ] if adt_in .is_list else [kwargs [name ]]
227+ values = (
228+ kwargs [name ]
229+ if adt_in .type .repetition is adt .Repetition .REPEATED
230+ else [kwargs [name ]]
231+ )
224232 v = kwargs [name ]
225233 if adt_in .ge is not None :
226234 assert (x >= adt_in .ge for x in values ), (
@@ -264,15 +272,20 @@ def check_output(adt_out: adt.Output, output: Any) -> Any:
264272 )
265273 output [i ] = util .normalize_value (adt_out .type , x )
266274 return output
267- elif adt_out .kind == adt .Kind .OBJECT :
275+ elif adt_out .kind is adt .Kind .OBJECT :
268276 assert adt_out .fields is not None , 'missing output fields'
269277 for name , tpe in adt_out .fields .items ():
270278 assert hasattr (output , name ), f'missing output field: { name } '
271279 value = getattr (output , name )
272- assert util .check_value (tpe , value ), (
273- f'incompatible output for field: { name } ={ value } '
274- )
275- setattr (output , name , util .normalize_value (tpe , value ))
280+ if value is None :
281+ assert tpe .repetition is adt .Repetition .OPTIONAL , (
282+ f'missing value for output field: { name } '
283+ )
284+ else :
285+ assert util .check_value (tpe .primitive , value ), (
286+ f'incompatible output for field: { name } ={ value } '
287+ )
288+ setattr (output , name , util .normalize_value (tpe .primitive , value ))
276289 return output
277290
278291
0 commit comments