[PT FE]: extend logical operations support (openvinotoolkit#19981)
* [PT FE]: extend logical operations support

* tests

* more tests
eaidova authored and alvoron committed Nov 6, 2023
1 parent b7ac034 commit dd99d0f
Showing 3 changed files with 108 additions and 4 deletions.
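
For context, these are the PyTorch-level operations the new translators cover, including the optional out= tensor that the translators forward via context.mutate_input. A minimal illustration (the tensors and values below are made up for this note, not taken from the commit):

    import torch

    a = torch.tensor([0, 1, 2], dtype=torch.int32)
    b = torch.tensor([1, 0, 2], dtype=torch.int32)

    torch.logical_and(a, b)   # tensor([False, False,  True])
    torch.logical_or(a, b)    # tensor([ True,  True,  True])
    torch.logical_xor(a, b)   # tensor([ True,  True, False])
    torch.logical_not(a)      # tensor([ True, False, False])

    # The out= variant writes into a preallocated bool tensor; this extra
    # argument is what the new optional third input (second, for logical_not)
    # in the translators corresponds to.
    out = torch.empty(3, dtype=torch.bool)
    torch.logical_xor(a, b, out=out)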
42 changes: 38 additions & 4 deletions src/frontends/pytorch/src/op/logical.cpp
@@ -4,7 +4,9 @@
 
 #include "openvino/frontend/pytorch/node_context.hpp"
 #include "openvino/op/logical_and.hpp"
+#include "openvino/op/logical_not.hpp"
 #include "openvino/op/logical_or.hpp"
+#include "openvino/op/logical_xor.hpp"
 #include "utils.hpp"
 
 namespace ov {
@@ -15,25 +17,57 @@ namespace op {
 using namespace ov::op;
 
 OutputVector translate_or(const NodeContext& context) {
-    num_inputs_check(context, 2, 2);
+    num_inputs_check(context, 2, 3);
     auto x = context.get_input(0);
     auto y = context.get_input(1);
     x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
     y = context.mark_node(std::make_shared<v0::Convert>(y, element::boolean));
     // TODO: use bitwise op here when it is supported by OpenVINO
     auto or_node = context.mark_node(std::make_shared<v1::LogicalOr>(x, y));
+    if (!context.input_is_none(2)) {
+        context.mutate_input(2, or_node);
+    }
     return {or_node};
 };
 
 OutputVector translate_and(const NodeContext& context) {
-    num_inputs_check(context, 2, 2);
+    num_inputs_check(context, 2, 3);
     auto x = context.get_input(0);
     auto y = context.get_input(1);
     x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
     y = context.mark_node(std::make_shared<v0::Convert>(y, element::boolean));
     // TODO: use bitwise op here when it is supported by OpenVINO
-    auto or_node = context.mark_node(std::make_shared<v1::LogicalAnd>(x, y));
-    return {or_node};
+    auto and_node = context.mark_node(std::make_shared<v1::LogicalAnd>(x, y));
+    if (!context.input_is_none(2)) {
+        context.mutate_input(2, and_node);
+    }
+    return {and_node};
 };
 
+OutputVector translate_not(const NodeContext& context) {
+    num_inputs_check(context, 1, 2);
+    auto x = context.get_input(0);
+    x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
+    // TODO: use bitwise op here when it is supported by OpenVINO
+    auto not_node = context.mark_node(std::make_shared<v1::LogicalNot>(x));
+    if (!context.input_is_none(1)) {
+        context.mutate_input(1, not_node);
+    }
+    return {not_node};
+};
+
+OutputVector translate_xor(const NodeContext& context) {
+    num_inputs_check(context, 2, 3);
+    auto x = context.get_input(0);
+    auto y = context.get_input(1);
+    x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
+    y = context.mark_node(std::make_shared<v0::Convert>(y, element::boolean));
+    // TODO: use bitwise op here when it is supported by OpenVINO
+    auto xor_node = context.mark_node(std::make_shared<v1::LogicalXor>(x, y));
+    if (!context.input_is_none(2)) {
+        context.mutate_input(2, xor_node);
+    }
+    return {xor_node};
+};
+
 } // namespace op
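
A note on the Convert nodes above: each translator casts its inputs to element::boolean before applying the logical op, which matches PyTorch semantics, where the logical_* functions accept any dtype, treat non-zero values as True, and return a bool tensor. A quick, illustrative check in PyTorch (values are arbitrary):

    import torch

    x = torch.tensor([0.0, 0.5, 2.0])              # float32
    y = torch.tensor([1, 0, 3], dtype=torch.int8)  # int8
    r = torch.logical_and(x, y)
    print(r.dtype)  # torch.bool
    print(r)        # tensor([False, False,  True])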
6 changes: 6 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
@@ -111,6 +111,7 @@ OP_CONVERTER(translate_new_zeros);
 OP_CONVERTER(translate_nms);
 OP_CONVERTER(translate_nonzero);
 OP_CONVERTER(translate_norm);
+OP_CONVERTER(translate_not);
 OP_CONVERTER(translate_numel);
 OP_CONVERTER(translate_one_hot);
 OP_CONVERTER(translate_ones);
@@ -188,6 +189,7 @@ OP_CONVERTER(translate_quantized_cat);
 OP_CONVERTER(translate_quantized_convnd);
 OP_CONVERTER(translate_quantized_convnd_relu);
 OP_CONVERTER(translate_quantized_linear);
+OP_CONVERTER(translate_xor);
 // Torch FX Translations
 OP_CONVERTER(translate_arange_fx);
 OP_CONVERTER(translate_batch_norm_fx);
@@ -343,6 +345,10 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::linspace", op::translate_linspace},
         {"aten::log", op::translate_log},
         {"aten::log_", op::inplace_op<op::translate_log>},
+        {"aten::logical_and", op::translate_and},
+        {"aten::logical_or", op::translate_or},
+        {"aten::logical_not", op::translate_not},
+        {"aten::logical_xor", op::translate_xor},
         {"aten::log_softmax", op::translate_log_softmax},
         {"aten::log2", op::translate_log2},
         {"aten::log2_", op::inplace_op<op::translate_log2>},
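
With the aten::logical_* entries registered in get_supported_ops_ts(), a traced or scripted model that uses these ops can be converted by the PyTorch frontend. A rough end-to-end sketch, assuming a recent OpenVINO build where openvino.convert_model accepts a torch.nn.Module directly (the model below is illustrative, not part of the commit):

    import torch
    import openvino as ov

    class XorModel(torch.nn.Module):
        def forward(self, a, b):
            return torch.logical_xor(a, b)

    example = (torch.randint(0, 2, (1, 10)), torch.randint(0, 2, (1, 10)))
    ov_model = ov.convert_model(XorModel(), example_input=example)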
64 changes: 64 additions & 0 deletions tests/layer_tests/pytorch_tests/test_logical_ops.py
@@ -0,0 +1,64 @@
import numpy as np
import pytest
from pytorch_layer_test_class import PytorchLayerTest


class TestLogicalOp(PytorchLayerTest):

    def _prepare_input(self, out, unary, first_dtype, second_dtype):
        x = np.random.randint(1, 5, (1, 10)).astype(first_dtype)
        if unary:
            return (x,) if not out else (x, np.zeros_like(x).astype(bool))
        y = np.random.randint(1, 5, (1, 10)).astype(second_dtype)
        if not out:
            return x, y
        return x, y, np.zeros_like(x).astype(bool)

    def create_model(self, op_name, out):
        import torch

        ops = {
            "and": torch.logical_and,
            "or": torch.logical_or,
            "xor": torch.logical_xor,
            "not": torch.logical_not
        }
        op = ops[op_name]

        class aten_logical(torch.nn.Module):

            def __init__(self, op, out) -> None:
                super().__init__()
                self.op = op
                if op == torch.logical_not:
                    self.forward = self.forward_not
                if out:
                    self.forward = self.forward_out if not op == torch.logical_not else self.forward_not_out

            def forward(self, tensor_a, tensor_b):
                return self.op(tensor_a, tensor_b)

            def forward_out(self, tensor_a, tensor_b, out):
                return self.op(tensor_a, tensor_b, out=out), out

            def forward_not(self, tensor_a):
                return self.op(tensor_a)

            def forward_not_out(self, tensor_a, out):
                return self.op(tensor_a, out=out), out

        ref_net = None

        return aten_logical(op, out), ref_net, f"aten::logical_{op_name}"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("op_type", ["and", "or", "not", "xor"])
    @pytest.mark.parametrize("first_dtype", ["bool", "int32", "int8", "float32"])
    @pytest.mark.parametrize("second_dtype", ["bool", "int32", "int8", "float32"])
    @pytest.mark.parametrize("out", [True, False])
    def test_logical(self, op_type, out, first_dtype, second_dtype, ie_device, precision, ir_version):
        self._test(*self.create_model(op_type, out),
                   ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"out": out, "unary": op_type == "not",
                                            "first_dtype": first_dtype, "second_dtype": second_dtype})
