22
22
create_immediate_value ,
23
23
create_list_scalarvalue ,
24
24
create_scalar_value ,
25
- types_to_proto ,
25
+ create_valuetype_list ,
26
+ create_valuetype_scalar ,
27
+ create_valuetype_tensor ,
26
28
types_to_proto_primitive ,
27
29
)
28
30
from coremltools .converters .mil .backend .nn .load import _set_optional_inputs
@@ -158,7 +160,7 @@ def translate_const(self, op: Operation) -> proto.MIL_pb2.Operation:
158
160
attributes = {"name" : create_scalar_value (op .name ), "val" : value },
159
161
outputs = [
160
162
proto .MIL_pb2 .NamedValueType (
161
- name = output_var .name , type = types_to_proto (output_var .sym_type )
163
+ name = output_var .name , type = self . types_to_proto (output_var .sym_type )
162
164
)
163
165
],
164
166
)
@@ -190,12 +192,58 @@ def translate_constexpr(self, op: Operation) -> proto.MIL_pb2.Operation:
190
192
attributes = attributes ,
191
193
outputs = [
192
194
proto .MIL_pb2 .NamedValueType (
193
- name = output_var .name , type = types_to_proto (output_var .sym_type )
195
+ name = output_var .name , type = self . types_to_proto (output_var .sym_type )
194
196
)
195
197
for output_var in op .outputs
196
198
],
197
199
)
198
200
201
+ def create_valuetype_dict (self , key_type : type , value_type : type ) -> proto .MIL_pb2 .ValueType :
202
+ """
203
+ Return proto.MIL_pb2.ValueType with dict (dictionaryType) set
204
+ """
205
+ v_type = proto .MIL_pb2 .ValueType ()
206
+ v_type .dictionaryType .keyType .CopyFrom (self .types_to_proto (key_type ))
207
+ v_type .dictionaryType .valueType .CopyFrom (self .types_to_proto (value_type ))
208
+ return v_type
209
+
210
+ def types_to_proto (self , valuetype : type ) -> proto .MIL_pb2 .ValueType :
211
+ """
212
+ Return proto.MIL_pb2.ValueType from PyMIL types.
213
+ """
214
+ if types .is_tensor (valuetype ):
215
+ primitive = types_to_proto_primitive (valuetype .get_primitive ())
216
+ return create_valuetype_tensor (valuetype .get_shape (), primitive )
217
+ elif types .is_tuple (valuetype ):
218
+ v_type = proto .MIL_pb2 .ValueType ()
219
+ t_type = v_type .tupleType
220
+ for t in valuetype .T :
221
+ new_v_type = t_type .types .add ()
222
+ new_v_type .CopyFrom (self .types_to_proto (t ))
223
+ return v_type
224
+ elif types .is_list (valuetype ):
225
+ elem = valuetype .T [0 ]
226
+ length = valuetype .T [1 ]
227
+ if types .is_tensor (elem ):
228
+ dtype = types_to_proto_primitive (elem .get_primitive ())
229
+ elem_shape = elem .get_shape ()
230
+ elif types .is_scalar (elem ):
231
+ dtype = types_to_proto_primitive (valuetype )
232
+ elem_shape = ()
233
+ elif types .is_str (elem ):
234
+ dtype = types_to_proto_primitive (elem )
235
+ elem_shape = ()
236
+ else :
237
+ raise NotImplementedError (
238
+ "Only list of either tensors or scalars supported. "
239
+ "Got element of type {}" .format (elem .__type_info__ ())
240
+ )
241
+ return create_valuetype_list (length = length , elem_shape = elem_shape , dtype = dtype )
242
+ elif types .is_dict (valuetype ):
243
+ return self .create_valuetype_dict (valuetype .T [0 ], valuetype .T [1 ])
244
+ else :
245
+ return create_valuetype_scalar (types_to_proto_primitive (valuetype ))
246
+
199
247
def translate_generic_op (
200
248
self , op : Operation , literal_params : Optional [List [str ]] = None
201
249
) -> proto .MIL_pb2 .Operation :
@@ -228,7 +276,7 @@ def translate_generic_op(
228
276
inputs [param_name ] = args
229
277
230
278
outputs = [
231
- proto .MIL_pb2 .NamedValueType (name = v .name , type = types_to_proto (v .sym_type ))
279
+ proto .MIL_pb2 .NamedValueType (name = v .name , type = self . types_to_proto (v .sym_type ))
232
280
for v in op .outputs
233
281
]
234
282
blocks = None
@@ -311,14 +359,18 @@ def feeds_to_only_constexprs(op: Operation) -> bool:
311
359
literal_params = ["begins" , "ends" , "end_masks" ]
312
360
proto_ops .append (self .translate_generic_op (op , literal_params ))
313
361
else :
314
- proto_ops .append (self .translate_generic_op (op ))
362
+ # A single pymil op might be decomposed into multiple ops
363
+ ops = self .translate_generic_op (op )
364
+ if not isinstance (ops , list ):
365
+ ops = [ops ]
366
+ proto_ops .extend (ops )
315
367
316
368
inputs = []
317
369
if not isinstance (block , Function ):
318
370
# Function is subclass of Block, but function's block has no input,
319
371
# and hence skipping reading the block inputs.
320
372
for var in block .inputs :
321
- proto_type = types_to_proto (var .sym_type )
373
+ proto_type = self . types_to_proto (var .sym_type )
322
374
inputs .append (proto .MIL_pb2 .NamedValueType (name = var .name , type = proto_type ))
323
375
output_names = [v .name for v in block .outputs ]
324
376
return proto .MIL_pb2 .Block (inputs = inputs , outputs = output_names , operations = proto_ops )
@@ -331,7 +383,7 @@ def convert_function(self, function: Function, opset: str) -> proto.MIL_pb2.Func
331
383
332
384
inputs = []
333
385
for name , var in function .inputs .items ():
334
- proto_type = types_to_proto (var .sym_type )
386
+ proto_type = self . types_to_proto (var .sym_type )
335
387
inputs .append (proto .MIL_pb2 .NamedValueType (name = name , type = proto_type ))
336
388
337
389
return proto .MIL_pb2 .Function (
@@ -467,6 +519,15 @@ def get_additional_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
467
519
"""
468
520
return {}
469
521
522
+ @staticmethod
523
+ def _try_convert_other_input_type (
524
+ input_var : Var , input_features : List [proto .Model_pb2 .FeatureDescription ]
525
+ ) -> bool :
526
+ """
527
+ Try to convert an input var with additional type.
528
+ """
529
+ return False
530
+
470
531
def get_func_input (self , func : mil .Function ) -> List [proto .Model_pb2 .FeatureDescription ]:
471
532
"""
472
533
Utils to get function input feature description.
@@ -554,7 +615,7 @@ def get_func_input(self, func: mil.Function) -> List[proto.Model_pb2.FeatureDesc
554
615
input_features .append (
555
616
proto .Model_pb2 .FeatureDescription (name = var .name , type = input_feature_type )
556
617
)
557
- else :
618
+ elif not self . _try_convert_other_input_type ( var , input_features ) :
558
619
raise NotImplementedError (f"Unsupported input type { var .sym_type } ." )
559
620
560
621
if not is_input_shape_symbolic :
@@ -746,6 +807,16 @@ def get_func_output(self, func: mil.Function) -> List[proto.Model_pb2.FeatureDes
746
807
747
808
return output_features
748
809
810
+ def create_model_description (
811
+ self ,
812
+ input_features : List [proto .Model_pb2 .FeatureDescription ],
813
+ output_features : List [proto .Model_pb2 .FeatureDescription ],
814
+ ) -> proto .Model_pb2 .ModelDescription :
815
+ """
816
+ Create model description from input and output features
817
+ """
818
+ return proto .Model_pb2 .ModelDescription (input = input_features , output = output_features )
819
+
749
820
def get_coreml_model (
750
821
self ,
751
822
input : Dict [str , List [proto .Model_pb2 .FeatureDescription ]],
@@ -758,7 +829,7 @@ def get_coreml_model(
758
829
# Model description
759
830
input_features = input [self ._DEFAULT_FUNCTION_NAME ]
760
831
output_features = output [self ._DEFAULT_FUNCTION_NAME ]
761
- desc = proto . Model_pb2 . ModelDescription ( input = input_features , output = output_features )
832
+ desc = self . create_model_description ( input_features , output_features )
762
833
763
834
if self .classifier_config is not None :
764
835
desc .predictedFeatureName = self .predicted_feature_name
0 commit comments