Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT FE]: support aten::minimum aten::maximum #19996

Merged
merged 1 commit into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/frontends/pytorch/src/op/min_max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,38 @@ OutputVector translate_min(const NodeContext& context) {
return {values, indicies};
};

OutputVector translate_maximum(const NodeContext& context) {
// aten::maximum(Tensor self, Tensor other) -> Tensor

// aten::maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)

num_inputs_check(context, 2, 3);
auto x = context.get_input(0);
auto y = context.get_input(1);
align_eltwise_input_types(context, x, y, true);
auto res = context.mark_node(std::make_shared<v1::Maximum>(x, y));
if (!context.input_is_none(2)) {
context.mutate_input(2, res);
}
return {res};
}

OutputVector translate_minimum(const NodeContext& context) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These can be a generic function in utils

// aten::minimum(Tensor self, Tensor other) -> Tensor

// aten::minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)

num_inputs_check(context, 2, 3);
auto x = context.get_input(0);
auto y = context.get_input(1);
align_eltwise_input_types(context, x, y, true);
auto res = context.mark_node(std::make_shared<v1::Minimum>(x, y));
if (!context.input_is_none(2)) {
context.mutate_input(2, res);
}
return {res};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
Expand Down
4 changes: 4 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,12 @@ OP_CONVERTER(translate_loop);
OP_CONVERTER(translate_masked_fill);
OP_CONVERTER(translate_masked_scatter);
OP_CONVERTER(translate_max);
OP_CONVERTER(translate_maximum);
OP_CONVERTER(translate_max_poolnd);
OP_CONVERTER(translate_mean);
OP_CONVERTER(translate_meshgrid);
OP_CONVERTER(translate_min);
OP_CONVERTER(translate_minimum);
OP_CONVERTER(translate_narrow);
OP_CONVERTER(translate_native_multi_head_attention);
OP_CONVERTER(translate_neg);
Expand Down Expand Up @@ -350,12 +352,14 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::masked_scatter_", op::inplace_op<op::translate_masked_scatter>},
{"aten::matmul", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::max", op::translate_max},
{"aten::maximum", op::translate_maximum},
{"aten::max_pool1d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool2d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool3d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::mean", op::quantizable_op<op::translate_mean>},
{"aten::meshgrid", op::translate_meshgrid},
{"aten::min", op::translate_min},
{"aten::minimum", op::translate_minimum},
{"aten::mm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::mul", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
{"aten::mul_", op::inplace_op<op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>>},
Expand Down
74 changes: 74 additions & 0 deletions tests/layer_tests/pytorch_tests/test_min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,77 @@ def forward(self, x: float, y: float):
def test_min(self, case, kwargs_to_prepare_input, ie_device, precision, ir_version):
self._test(*self.create_model(case),
ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input, use_mo_convert=False)


class TestMinimumMaximum(PytorchLayerTest):
def _prepare_input(self, input_dtype="float32", second_input_dtype="float32", out=False):
import numpy as np
x = np.random.randn(1, 3, 10, 10).astype(input_dtype)
y = np.random.randn(1, 3, 10, 10).astype(second_input_dtype)
if not out:
return x, y
return (x, y, np.zeros_like(x).astype(input_dtype))

def create_model(self, op_type, dtypes=("float32", "float32"), out=False):
import torch
op_types = {
"maximum": torch.maximum,
"minimum": torch.minimum
}

dtypes_map = {
"float32": torch.float32,
"int32": torch.int32,
"int64": torch.int64,
"float64": torch.float64
}

op = op_types[op_type]

class aten_minimum_maximum(torch.nn.Module):
def __init__(self, op, l_dtype, r_dtype, out):
super(aten_minimum_maximum, self).__init__()
self.op = op
self.l_dtype = l_dtype
self.r_dtype = r_dtype
if out:
self.forward = self.forward_out

def forward_out(self, x, y, z):
return self.op(x.to(self.l_dtype), y.to(self.r_dtype), out=z), z

def forward(self, x, y):
return self.op(x.to(self.l_dtype), y.to(self.r_dtype))

l_dtype = dtypes_map[dtypes[0]]
r_dtype = dtypes_map[dtypes[1]]
model_cls = aten_minimum_maximum(op, l_dtype, r_dtype, out)

return model_cls, None, f"aten::{op_type}"

@pytest.mark.parametrize("op_type", ["minimum", "maximum"])
@pytest.mark.parametrize("second_input_dtype", ["float32", "int32", "int64", "float64"])
@pytest.mark.parametrize("first_input_dtype", ["float32", "int32", "int64", "float64"])
@pytest.mark.nightly
@pytest.mark.precommit
def test_minimum_maximum(
self, op_type, first_input_dtype, second_input_dtype, ie_device, precision, ir_version
):
self._test(*self.create_model(op_type, dtypes=(first_input_dtype, second_input_dtype), out=False),
ie_device, precision, ir_version, kwargs_to_prepare_input=
{"input_dtype": first_input_dtype, "second_input_dtype": second_input_dtype, "out": False}
)


@pytest.mark.parametrize("op_type", ['minimum', 'maximum'])
@pytest.mark.parametrize("input_dtype", ["float32", "int32", "int64", "float64"])
@pytest.mark.nightly
@pytest.mark.precommit
def test_minimum_maximum_out(
self, op_type, input_dtype, ie_device, precision, ir_version
):
self._test(*self.create_model(op_type, dtypes=(input_dtype, input_dtype), out=True),
ie_device, precision, ir_version, kwargs_to_prepare_input=
{"input_dtype": input_dtype, "second_input_dtype": input_dtype,
"out": True}
)