From a24c7ec6e06d3360ea4f0837910d62eb79274652 Mon Sep 17 00:00:00 2001 From: eaidova Date: Thu, 21 Sep 2023 11:11:59 +0400 Subject: [PATCH] [PT FE]: extend logical operations support --- src/frontends/pytorch/src/op/logical.cpp | 42 +++++++++++++++++++++--- src/frontends/pytorch/src/op_table.cpp | 6 ++++ 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/src/frontends/pytorch/src/op/logical.cpp b/src/frontends/pytorch/src/op/logical.cpp index 0c5a93e2c91933..b094067dbbd05d 100644 --- a/src/frontends/pytorch/src/op/logical.cpp +++ b/src/frontends/pytorch/src/op/logical.cpp @@ -4,7 +4,9 @@ #include "openvino/frontend/pytorch/node_context.hpp" #include "openvino/op/logical_and.hpp" +#include "openvino/op/logical_not.hpp" #include "openvino/op/logical_or.hpp" +#include "openvino/op/logical_xor.hpp" #include "utils.hpp" namespace ov { @@ -15,25 +17,57 @@ namespace op { using namespace ov::op; OutputVector translate_or(const NodeContext& context) { - num_inputs_check(context, 2, 2); + num_inputs_check(context, 2, 3); auto x = context.get_input(0); auto y = context.get_input(1); x = context.mark_node(std::make_shared(x, element::boolean)); y = context.mark_node(std::make_shared(y, element::boolean)); // TODO: use bitwise op here when will be supported by openvino auto or_node = context.mark_node(std::make_shared(x, y)); + if (!context.input_is_none(2)) { + context.mutate_input(2, or_node); + } return {or_node}; }; OutputVector translate_and(const NodeContext& context) { - num_inputs_check(context, 2, 2); + num_inputs_check(context, 2, 3); auto x = context.get_input(0); auto y = context.get_input(1); x = context.mark_node(std::make_shared(x, element::boolean)); y = context.mark_node(std::make_shared(y, element::boolean)); // TODO: use bitwise op here when will be supported by openvino - auto or_node = context.mark_node(std::make_shared(x, y)); - return {or_node}; + auto and_node = context.mark_node(std::make_shared(x, y)); + if (!context.input_is_none(2)) { + context.mutate_input(2, and_node); + } + return {and_node}; +}; + +OutputVector translate_not(const NodeContext& context) { + num_inputs_check(context, 1, 2); + auto x = context.get_input(0); + x = context.mark_node(std::make_shared(x, element::boolean)); + // TODO: use bitwise op here when will be supported by openvino + auto not_node = context.mark_node(std::make_shared(x)); + if (!context.input_is_none(1)) { + context.mutate_input(1, not_node); + } + return {not_node}; +}; + +OutputVector translate_xor(const NodeContext& context) { + num_inputs_check(context, 2, 3); + auto x = context.get_input(0); + auto y = context.get_input(1); + x = context.mark_node(std::make_shared(x, element::boolean)); + y = context.mark_node(std::make_shared(y, element::boolean)); + // TODO: use bitwise op here when will be supported by openvino + auto xor_node = context.mark_node(std::make_shared(x, y)); + if (!context.input_is_none(2)) { + context.mutate_input(2, xor_node); + } + return {xor_node}; }; } // namespace op diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 4dd60e01b71d7f..8e552bc0184c81 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -109,6 +109,7 @@ OP_CONVERTER(translate_new_zeros); OP_CONVERTER(translate_nms); OP_CONVERTER(translate_nonzero); OP_CONVERTER(translate_norm); +OP_CONVERTER(translate_not); OP_CONVERTER(translate_numel); OP_CONVERTER(translate_one_hot); OP_CONVERTER(translate_ones); @@ -185,6 +186,7 @@ OP_CONVERTER(translate_quantized_cat); OP_CONVERTER(translate_quantized_convnd); OP_CONVERTER(translate_quantized_convnd_relu); OP_CONVERTER(translate_quantized_linear); +OP_CONVERTER(translate_xor); // Torch FX Translations OP_CONVERTER(translate_arange_fx); OP_CONVERTER(translate_batch_norm_fx); @@ -340,6 +342,10 @@ const std::map get_supported_ops_ts() { {"aten::linspace", op::translate_linspace}, {"aten::log", op::translate_log}, {"aten::log_", op::inplace_op}, + {"aten::logical_and", op::translate_and}, + {"aten::logical_or", op::translate_or}, + {"aten::logical_not", op::translate_not}, + {"aten::logical_xor", op::translate_xor}, {"aten::log_softmax", op::translate_log_softmax}, {"aten::log2", op::translate_log2}, {"aten::log2_", op::inplace_op},