Skip to content

Commit 03cd23e

Browse files
authored
Supports complex data type for the add op. (#2061)
Supports complex data type for the `add` op.
1 parent 4480c25 commit 03cd23e

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -869,10 +869,17 @@ def add(context, node):
869869
# rdar://60175736
870870
if len(add_inputs) > 2 and add_inputs[2].val != 1:
871871
raise ValueError("ADD does not support scale factor param")
872-
x, y = promote_input_dtypes(add_inputs[:2])
872+
x, y = add_inputs[:2]
873873
if types.is_bool(x.dtype) and types.is_bool(y.dtype):
874874
add_node = mb.logical_or(x=x, y=y, name=node.name)
875+
elif types.is_complex(x.dtype) or types.is_complex(y.dtype):
876+
x_real = mb.complex_real(data=x) if types.is_complex(x.dtype) else x
877+
x_imag = mb.complex_imag(data=x) if types.is_complex(x.dtype) else 0.0
878+
y_real = mb.complex_real(data=y) if types.is_complex(y.dtype) else y
879+
y_imag = mb.complex_imag(data=y) if types.is_complex(y.dtype) else 0.0
880+
add_node = mb.complex(real_data=mb.add(x=x_real, y=y_real), imag_data=mb.add(x=x_imag, y=y_imag), name=node.name)
875881
else:
882+
x, y = promote_input_dtypes([x, y])
876883
add_node = mb.add(x=x, y=y, name=node.name)
877884
context.add(add_node)
878885

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3752,6 +3752,31 @@ def forward(self, x, y):
37523752
input_as_shape=False,
37533753
)
37543754

3755+
@pytest.mark.parametrize(
3756+
"compute_unit, backend, x_complex, y_complex",
3757+
itertools.product(
3758+
compute_units,
3759+
backends,
3760+
(True, False),
3761+
(True, False),
3762+
),
3763+
)
3764+
def test_add_complex(self, compute_unit, backend, x_complex, y_complex):
3765+
class TestAddComplexModel(nn.Module):
3766+
def forward(self, x, y):
3767+
if x_complex:
3768+
x = torch.complex(x, x)
3769+
if y_complex:
3770+
y = torch.complex(y, y)
3771+
return torch.add(x, y).abs()
3772+
3773+
self.run_compare_torch(
3774+
[(2, 3), (2, 3)],
3775+
TestAddComplexModel(),
3776+
compute_unit=compute_unit,
3777+
backend=backend,
3778+
)
3779+
37553780

37563781
class TestFull(TorchBaseTest):
37573782
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)