Skip to content

Commit

Permalink
[PT FE]: extend logical operations support
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Sep 21, 2023
1 parent 4a0e3fc commit a24c7ec
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
42 changes: 38 additions & 4 deletions src/frontends/pytorch/src/op/logical.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/logical_and.hpp"
#include "openvino/op/logical_not.hpp"
#include "openvino/op/logical_or.hpp"
#include "openvino/op/logical_xor.hpp"
#include "utils.hpp"

namespace ov {
Expand All @@ -15,25 +17,57 @@ namespace op {
using namespace ov::op;

OutputVector translate_or(const NodeContext& context) {
num_inputs_check(context, 2, 2);
num_inputs_check(context, 2, 3);
auto x = context.get_input(0);
auto y = context.get_input(1);
x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
y = context.mark_node(std::make_shared<v0::Convert>(y, element::boolean));
// TODO: use bitwise op here when will be supported by openvino
auto or_node = context.mark_node(std::make_shared<v1::LogicalOr>(x, y));
if (!context.input_is_none(2)) {
context.mutate_input(2, or_node);
}
return {or_node};
};

OutputVector translate_and(const NodeContext& context) {
num_inputs_check(context, 2, 2);
num_inputs_check(context, 2, 3);
auto x = context.get_input(0);
auto y = context.get_input(1);
x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
y = context.mark_node(std::make_shared<v0::Convert>(y, element::boolean));
// TODO: use bitwise op here when will be supported by openvino
auto or_node = context.mark_node(std::make_shared<v1::LogicalAnd>(x, y));
return {or_node};
auto and_node = context.mark_node(std::make_shared<v1::LogicalAnd>(x, y));
if (!context.input_is_none(2)) {
context.mutate_input(2, and_node);
}
return {and_node};
};

OutputVector translate_not(const NodeContext& context) {
num_inputs_check(context, 1, 2);
auto x = context.get_input(0);
x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
// TODO: use bitwise op here when will be supported by openvino
auto not_node = context.mark_node(std::make_shared<v1::LogicalNot>(x));
if (!context.input_is_none(1)) {
context.mutate_input(1, not_node);
}
return {not_node};
};

OutputVector translate_xor(const NodeContext& context) {
num_inputs_check(context, 2, 3);
auto x = context.get_input(0);
auto y = context.get_input(1);
x = context.mark_node(std::make_shared<v0::Convert>(x, element::boolean));
y = context.mark_node(std::make_shared<v0::Convert>(y, element::boolean));
// TODO: use bitwise op here when will be supported by openvino
auto xor_node = context.mark_node(std::make_shared<v1::LogicalXor>(x, y));
if (!context.input_is_none(2)) {
context.mutate_input(2, xor_node);
}
return {xor_node};
};

} // namespace op
Expand Down
6 changes: 6 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ OP_CONVERTER(translate_new_zeros);
OP_CONVERTER(translate_nms);
OP_CONVERTER(translate_nonzero);
OP_CONVERTER(translate_norm);
OP_CONVERTER(translate_not);
OP_CONVERTER(translate_numel);
OP_CONVERTER(translate_one_hot);
OP_CONVERTER(translate_ones);
Expand Down Expand Up @@ -185,6 +186,7 @@ OP_CONVERTER(translate_quantized_cat);
OP_CONVERTER(translate_quantized_convnd);
OP_CONVERTER(translate_quantized_convnd_relu);
OP_CONVERTER(translate_quantized_linear);
OP_CONVERTER(translate_xor);
// Torch FX Translations
OP_CONVERTER(translate_arange_fx);
OP_CONVERTER(translate_batch_norm_fx);
Expand Down Expand Up @@ -340,6 +342,10 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::linspace", op::translate_linspace},
{"aten::log", op::translate_log},
{"aten::log_", op::inplace_op<op::translate_log>},
{"aten::logical_and", op::translate_and},
{"aten::logical_or", op::translate_or},
{"aten::logical_not", op::translate_not},
{"aten::logical_xor", op::translate_xor},
{"aten::log_softmax", op::translate_log_softmax},
{"aten::log2", op::translate_log2},
{"aten::log2_", op::inplace_op<op::translate_log2>},
Expand Down

0 comments on commit a24c7ec

Please sign in to comment.