Skip to content

Commit 7521b68

Browse files
YifanShenSZ and yifan_shen3 authored
7.2 release (#2196)
Co-authored-by: yifan_shen3 <yifan_shen3@apple.com>
1 parent c8f7e77 commit 7521b68

File tree

16 files changed

+395
-183
lines changed

16 files changed

+395
-183
lines changed

coremltools/converters/mil/backend/mil/helper.py

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,6 @@ def create_valuetype_list(length, elem_shape, dtype):
4242
update_listtype(v_type.listType, length, elem_shape, dtype)
4343
return v_type
4444

45-
def create_valuetype_dict(key_type, value_type):
46-
"""
47-
Return proto.MIL_pb2.ValueType with dict (dictionaryType) set
48-
"""
49-
v_type = proto.MIL_pb2.ValueType()
50-
v_type.dictionaryType.keyType.CopyFrom(types_to_proto(key_type))
51-
v_type.dictionaryType.valueType.CopyFrom(types_to_proto(value_type))
52-
return v_type
53-
54-
5545
def create_valuetype_tensor(shape, data_type):
5646
"""
5747
Return proto.MIL_pb2.ValueType with tensor (TensorType) set.
@@ -261,40 +251,6 @@ def types_to_proto_primitive(valuetype):
261251
)
262252
return types.BUILTIN_TO_PROTO_TYPES[valuetype]
263253

264-
265-
def types_to_proto(valuetype):
266-
if types.is_tensor(valuetype):
267-
primitive = types_to_proto_primitive(valuetype.get_primitive())
268-
return create_valuetype_tensor(valuetype.get_shape(), primitive)
269-
elif types.is_tuple(valuetype):
270-
v_type = proto.MIL_pb2.ValueType()
271-
t_type = v_type.tupleType
272-
for t in valuetype.T:
273-
new_v_type = t_type.types.add()
274-
new_v_type.CopyFrom(types_to_proto(t))
275-
return v_type
276-
elif types.is_list(valuetype):
277-
elem = valuetype.T[0]
278-
length = valuetype.T[1]
279-
if types.is_tensor(elem):
280-
dtype = types_to_proto_primitive(elem.get_primitive())
281-
elem_shape = elem.get_shape()
282-
elif types.is_scalar(elem):
283-
dtype = types_to_proto_primitive(valuetype)
284-
elem_shape = ()
285-
elif types.is_str(elem):
286-
dtype = types_to_proto_primitive(elem)
287-
elem_shape = ()
288-
else:
289-
raise NotImplementedError("Only list of either tensors or scalars supported. "
290-
"Got element of type {}".format(elem.__type_info__()))
291-
return create_valuetype_list(length=length, elem_shape=elem_shape, dtype=dtype)
292-
elif types.is_dict(valuetype):
293-
return create_valuetype_dict(valuetype.T[0], valuetype.T[1])
294-
else:
295-
return create_valuetype_scalar(types_to_proto_primitive(valuetype))
296-
297-
298254
def _get_offset_by_writing_data(output_var, blob_writer):
299255
if output_var.val.dtype.kind == 'f' and output_var.val.dtype.itemsize == 4:
300256
offset = blob_writer.write_float_data(np.ascontiguousarray(output_var.val.flatten()))

coremltools/converters/mil/backend/mil/load.py

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
create_immediate_value,
2323
create_list_scalarvalue,
2424
create_scalar_value,
25-
types_to_proto,
25+
create_valuetype_list,
26+
create_valuetype_scalar,
27+
create_valuetype_tensor,
2628
types_to_proto_primitive,
2729
)
2830
from coremltools.converters.mil.backend.nn.load import _set_optional_inputs
@@ -158,7 +160,7 @@ def translate_const(self, op: Operation) -> proto.MIL_pb2.Operation:
158160
attributes={"name": create_scalar_value(op.name), "val": value},
159161
outputs=[
160162
proto.MIL_pb2.NamedValueType(
161-
name=output_var.name, type=types_to_proto(output_var.sym_type)
163+
name=output_var.name, type=self.types_to_proto(output_var.sym_type)
162164
)
163165
],
164166
)
@@ -190,12 +192,58 @@ def translate_constexpr(self, op: Operation) -> proto.MIL_pb2.Operation:
190192
attributes=attributes,
191193
outputs=[
192194
proto.MIL_pb2.NamedValueType(
193-
name=output_var.name, type=types_to_proto(output_var.sym_type)
195+
name=output_var.name, type=self.types_to_proto(output_var.sym_type)
194196
)
195197
for output_var in op.outputs
196198
],
197199
)
198200

201+
def create_valuetype_dict(self, key_type: type, value_type: type) -> proto.MIL_pb2.ValueType:
    """
    Build a proto.MIL_pb2.ValueType whose dictionaryType field is populated.

    The key type and the value type are each converted with
    ``self.types_to_proto`` and copied into the dictionary's ``keyType`` and
    ``valueType`` sub-messages respectively.
    """
    dict_value_type = proto.MIL_pb2.ValueType()
    dict_proto = dict_value_type.dictionaryType
    dict_proto.keyType.CopyFrom(self.types_to_proto(key_type))
    dict_proto.valueType.CopyFrom(self.types_to_proto(value_type))
    return dict_value_type
209+
210+
def types_to_proto(self, valuetype: type) -> proto.MIL_pb2.ValueType:
211+
"""
212+
Return proto.MIL_pb2.ValueType from PyMIL types.
213+
"""
214+
if types.is_tensor(valuetype):
215+
primitive = types_to_proto_primitive(valuetype.get_primitive())
216+
return create_valuetype_tensor(valuetype.get_shape(), primitive)
217+
elif types.is_tuple(valuetype):
218+
v_type = proto.MIL_pb2.ValueType()
219+
t_type = v_type.tupleType
220+
for t in valuetype.T:
221+
new_v_type = t_type.types.add()
222+
new_v_type.CopyFrom(self.types_to_proto(t))
223+
return v_type
224+
elif types.is_list(valuetype):
225+
elem = valuetype.T[0]
226+
length = valuetype.T[1]
227+
if types.is_tensor(elem):
228+
dtype = types_to_proto_primitive(elem.get_primitive())
229+
elem_shape = elem.get_shape()
230+
elif types.is_scalar(elem):
231+
dtype = types_to_proto_primitive(valuetype)
232+
elem_shape = ()
233+
elif types.is_str(elem):
234+
dtype = types_to_proto_primitive(elem)
235+
elem_shape = ()
236+
else:
237+
raise NotImplementedError(
238+
"Only list of either tensors or scalars supported. "
239+
"Got element of type {}".format(elem.__type_info__())
240+
)
241+
return create_valuetype_list(length=length, elem_shape=elem_shape, dtype=dtype)
242+
elif types.is_dict(valuetype):
243+
return self.create_valuetype_dict(valuetype.T[0], valuetype.T[1])
244+
else:
245+
return create_valuetype_scalar(types_to_proto_primitive(valuetype))
246+
199247
def translate_generic_op(
200248
self, op: Operation, literal_params: Optional[List[str]] = None
201249
) -> proto.MIL_pb2.Operation:
@@ -228,7 +276,7 @@ def translate_generic_op(
228276
inputs[param_name] = args
229277

230278
outputs = [
231-
proto.MIL_pb2.NamedValueType(name=v.name, type=types_to_proto(v.sym_type))
279+
proto.MIL_pb2.NamedValueType(name=v.name, type=self.types_to_proto(v.sym_type))
232280
for v in op.outputs
233281
]
234282
blocks = None
@@ -311,14 +359,18 @@ def feeds_to_only_constexprs(op: Operation) -> bool:
311359
literal_params = ["begins", "ends", "end_masks"]
312360
proto_ops.append(self.translate_generic_op(op, literal_params))
313361
else:
314-
proto_ops.append(self.translate_generic_op(op))
362+
# A single pymil op might be decomposed into multiple ops
363+
ops = self.translate_generic_op(op)
364+
if not isinstance(ops, list):
365+
ops = [ops]
366+
proto_ops.extend(ops)
315367

316368
inputs = []
317369
if not isinstance(block, Function):
318370
# Function is subclass of Block, but function's block has no input,
319371
# and hence skipping reading the block inputs.
320372
for var in block.inputs:
321-
proto_type = types_to_proto(var.sym_type)
373+
proto_type = self.types_to_proto(var.sym_type)
322374
inputs.append(proto.MIL_pb2.NamedValueType(name=var.name, type=proto_type))
323375
output_names = [v.name for v in block.outputs]
324376
return proto.MIL_pb2.Block(inputs=inputs, outputs=output_names, operations=proto_ops)
@@ -331,7 +383,7 @@ def convert_function(self, function: Function, opset: str) -> proto.MIL_pb2.Func
331383

332384
inputs = []
333385
for name, var in function.inputs.items():
334-
proto_type = types_to_proto(var.sym_type)
386+
proto_type = self.types_to_proto(var.sym_type)
335387
inputs.append(proto.MIL_pb2.NamedValueType(name=name, type=proto_type))
336388

337389
return proto.MIL_pb2.Function(
@@ -467,6 +519,15 @@ def get_additional_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
467519
"""
468520
return {}
469521

522+
@staticmethod
def _try_convert_other_input_type(
    input_var: Var, input_features: List[proto.Model_pb2.FeatureDescription]
) -> bool:
    """
    Try to convert an input var with additional type.

    This implementation supports no additional input types: it leaves
    ``input_features`` untouched and always reports failure. The caller in
    ``get_func_input`` raises NotImplementedError when this returns False.

    NOTE(review): presumably an extension hook that a subclass overrides to
    append a feature description and return True -- confirm against subclasses.
    """
    return False
530+
470531
def get_func_input(self, func: mil.Function) -> List[proto.Model_pb2.FeatureDescription]:
471532
"""
472533
Utils to get function input feature description.
@@ -554,7 +615,7 @@ def get_func_input(self, func: mil.Function) -> List[proto.Model_pb2.FeatureDesc
554615
input_features.append(
555616
proto.Model_pb2.FeatureDescription(name=var.name, type=input_feature_type)
556617
)
557-
else:
618+
elif not self._try_convert_other_input_type(var, input_features):
558619
raise NotImplementedError(f"Unsupported input type {var.sym_type}.")
559620

560621
if not is_input_shape_symbolic:
@@ -746,6 +807,16 @@ def get_func_output(self, func: mil.Function) -> List[proto.Model_pb2.FeatureDes
746807

747808
return output_features
748809

810+
def create_model_description(
811+
self,
812+
input_features: List[proto.Model_pb2.FeatureDescription],
813+
output_features: List[proto.Model_pb2.FeatureDescription],
814+
) -> proto.Model_pb2.ModelDescription:
815+
"""
816+
Create model description from input and output features
817+
"""
818+
return proto.Model_pb2.ModelDescription(input=input_features, output=output_features)
819+
749820
def get_coreml_model(
750821
self,
751822
input: Dict[str, List[proto.Model_pb2.FeatureDescription]],
@@ -758,7 +829,7 @@ def get_coreml_model(
758829
# Model description
759830
input_features = input[self._DEFAULT_FUNCTION_NAME]
760831
output_features = output[self._DEFAULT_FUNCTION_NAME]
761-
desc = proto.Model_pb2.ModelDescription(input=input_features, output=output_features)
832+
desc = self.create_model_description(input_features, output_features)
762833

763834
if self.classifier_config is not None:
764835
desc.predictedFeatureName = self.predicted_feature_name

coremltools/converters/mil/backend/nn/load.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
ImageType,
1313
RangeDim,
1414
Shape,
15+
TensorType,
1516
)
1617
from coremltools.converters.mil.mil import types
1718
from coremltools.converters.mil.mil.types.symbolic import any_symbolic, any_variadic, is_symbolic
@@ -169,7 +170,7 @@ def _set_optional_inputs(proto, input_types):
169170
# Set default values for optional input_types
170171
default_map = {}
171172
for input_type in input_types:
172-
if isinstance(input_type, ImageType):
173+
if not isinstance(input_type, TensorType):
173174
continue
174175
if input_type.default_value is not None:
175176
default_map[input_type.name] = input_type.default_value

coremltools/converters/mil/frontend/_utils.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -512,33 +512,36 @@ def _concat_dims(dims, none_if_empty=False):
512512
return ab
513513

514514

515-
def _lower_scaled_dot_product_attention(q: Var, k: Var, v: Var, mask: Var, name: str) -> Var:
515+
def _lower_scaled_dot_product_attention(
516+
q: Var, k: Var, v: Var, mask: Var, name: str, before_op: Optional[Operation] = None
517+
) -> Var:
516518
# scale the query input
517519
embed_size = q.shape[-1]
518520
if is_symbolic(embed_size):
519521
raise ValueError(
520522
"The embedding size, i.e. last dimension of the shape of query tensor"
521523
" cannot be symbolic, in scaled_dot_product_attention op"
522524
)
525+
526+
q, k, v = promote_input_dtypes([q, k, v])
523527
multiplicative_scale_factor = 1 / math.sqrt(embed_size)
524-
q, k, v, multiplicative_scale_factor = promote_input_dtypes(
525-
[q, k, v, multiplicative_scale_factor]
526-
)
527-
q = mb.mul(x=q, y=multiplicative_scale_factor)
528+
if types.builtin_to_string(q.dtype) == "fp16":
529+
multiplicative_scale_factor = _np.float16(multiplicative_scale_factor)
530+
q = mb.mul(x=q, y=multiplicative_scale_factor, before_op=before_op)
528531

529532
# multiply query and key input tensors
530533
# shape of output: (target_seq, source_seq) or (B,...,target_seq, source_seq)
531-
attn_weights = mb.matmul(x=q, y=k, transpose_y=True)
534+
attn_weights = mb.matmul(x=q, y=k, transpose_y=True, before_op=before_op)
532535

533536
# add mask if applicable
534537
if mask is not None:
535-
attn_weights = mb.add(x=attn_weights, y=mask)
538+
attn_weights = mb.add(x=attn_weights, y=mask, before_op=before_op)
536539

537540
# do softmax
538-
attn_weights_normalized = mb.softmax(x=attn_weights, axis=-1)
541+
attn_weights_normalized = mb.softmax(x=attn_weights, axis=-1, before_op=before_op)
539542

540543
# multiply attn_weights and value tensor
541-
res = mb.matmul(x=attn_weights_normalized, y=v, name=name)
544+
res = mb.matmul(x=attn_weights_normalized, y=v, name=name, before_op=before_op)
542545
return res
543546

544547

@@ -549,7 +552,7 @@ def _construct_constexpr_affine_op(
549552
axis: Optional[Union[Var, int]] = None,
550553
name: Optional[str] = None,
551554
before_op: Optional[Operation] = None,
552-
) -> Operation:
555+
) -> Var:
553556
"""Constructs the constexpr op to represent the dequantized weight from PyTorch's data."""
554557
# The constexpr_affine_dequantize op requires axis.
555558
if axis is None:

coremltools/converters/mil/frontend/torch/exir_utils.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,27 @@ def extract_inputs_from_exir_program(
8787
val = node.meta["val"]
8888
assert isinstance(val, torch.Tensor), "placeholder val must be a tensor or fake tensor"
8989
user_inputs.append(to_coreml_tensor_type(node.name, val))
90+
9091
elif input_spec.kind == torch.export.graph_signature.InputKind.PARAMETER:
9192
lifted_parameters[input_spec.arg.name] = parameters[input_spec.target]
93+
9294
elif input_spec.kind == torch.export.graph_signature.InputKind.BUFFER:
93-
lifted_buffers[input_spec.arg.name] = buffers[input_spec.target]
95+
# This is a workaround on mutable buffer: Core ML does not support stateful execution,
96+
# so ExecuTorch will pass mutable buffers as inputs/outputs to Core ML delegation,
97+
# then in-place copy Core ML outputs into buffers
98+
# On Core ML side, we do not have to do anything special with outputs,
99+
# but for inputs we will need to identify ExecuTorch lifted mutable buffers
100+
# as Core ML user inputs
101+
if input_spec.target in exported_program.graph_signature.buffers_to_mutate.values():
102+
user_inputs.append(
103+
to_coreml_tensor_type(input_spec.arg.name, buffers[input_spec.target])
104+
)
105+
else:
106+
lifted_buffers[input_spec.arg.name] = buffers[input_spec.target]
107+
94108
elif input_spec.kind == torch.export.graph_signature.InputKind.CONSTANT_TENSOR:
95109
lifted_constants[input_spec.arg.name] = exported_program.constants[input_spec.target]
110+
96111
else:
97112
raise NotImplementedError(
98113
"Only 4 types of inputs handled yet: user input, parameter, buffer, constant. "

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1678,7 +1678,7 @@ def view(context, node):
16781678
x = inputs[0]
16791679
shape = inputs[1]
16801680

1681-
if np.prod(shape.shape) == 0:
1681+
if isinstance(shape, Var) and np.prod(shape.shape) == 0:
16821682
# Reshape to empty shape (works only for scalar) is a no op
16831683
assert (
16841684
np.prod(x.shape) <= 1
@@ -6694,21 +6694,15 @@ def _get_causal_attn_mask(is_causal: bool, query_var: Var, key_var: Var) -> Var:
66946694

66956695
def _cast_bool_attn_mask(attn_mask: Var, query_var: Var) -> Var:
66966696
"""
6697-
compute float mask as:
6698-
mask = cast(bool_mask) + (1-cast(bool_mask)) * -30k*ones(shape(bool_mask))
6697+
compute float mask as (1 - cast(bool_mask)) * -30k
66996698
"""
67006699
assert is_bool(attn_mask.dtype)
67016700

6702-
shape = mb.shape(x=attn_mask)
6703-
negative_inf = mb.fill(
6704-
shape=shape, value=_np.array([-3e4]).astype(types.nptype_from_builtin(query_var.dtype))
6705-
)
67066701
mask = mb.cast(x=attn_mask, dtype=types.builtin_to_string(query_var.dtype))
67076702
compliment_of_mask = mb.sub(
67086703
x=_np.array([1.0]).astype(types.nptype_from_builtin(mask.dtype)), y=mask
67096704
)
6710-
compliment_of_mask = mb.mul(x=negative_inf, y=compliment_of_mask)
6711-
return mb.add(x=mask, y=compliment_of_mask)
6705+
return mb.mul(x=-3e4, y=compliment_of_mask)
67126706

67136707
@register_torch_op
67146708
def scaled_dot_product_attention(context, node):

0 commit comments

Comments
 (0)