diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp index 32c6c093e34b63..041231cc8c40fb 100644 --- a/src/frontends/tensorflow/src/op_table.cpp +++ b/src/frontends/tensorflow/src/op_table.cpp @@ -263,10 +263,12 @@ const std::map get_supported_ops() { {"Switch", CreatorFunction(translate_switch_op)}, {"TensorListFromTensor", CreatorFunction(translate_tensor_list_from_tensor_op)}, {"TensorListGetItem", CreatorFunction(translate_tensor_list_get_item_op)}, + {"TensorListLength", CreatorFunction(translate_tensor_list_length_op)}, {"TensorListPushBack", CreatorFunction(translate_tensor_list_push_back_op)}, {"TensorListSetItem", CreatorFunction(translate_tensor_list_set_item_op)}, {"TensorListStack", CreatorFunction(translate_tensor_list_stack_op)}, {"TensorListReserve", CreatorFunction(translate_tensor_list_reserve_op)}, + {"TensorListResize", CreatorFunction(translate_tensor_list_resize_op)}, {"Tile", CreatorFunction(translate_tile_op)}, {"TopK", CreatorFunction(translate_top_k_op)}, {"TopKV2", CreatorFunction(translate_top_k_v2_op)}, diff --git a/src/frontends/tensorflow_common/include/common_op_table.hpp b/src/frontends/tensorflow_common/include/common_op_table.hpp index 2baab49b74a1cc..0f0fe5b1840462 100644 --- a/src/frontends/tensorflow_common/include/common_op_table.hpp +++ b/src/frontends/tensorflow_common/include/common_op_table.hpp @@ -132,10 +132,12 @@ OP_CONVERTER(translate_strided_slice_op); OP_CONVERTER(translate_sqrt_op); OP_CONVERTER(translate_tensor_list_from_tensor_op); OP_CONVERTER(translate_tensor_list_get_item_op); +OP_CONVERTER(translate_tensor_list_length_op); OP_CONVERTER(translate_tensor_list_push_back_op); OP_CONVERTER(translate_tensor_list_reserve_op); OP_CONVERTER(translate_tensor_list_set_item_op); OP_CONVERTER(translate_tensor_list_stack_op); +OP_CONVERTER(translate_tensor_list_resize_op); OP_CONVERTER(translate_tile_op); OP_CONVERTER_NAMED(translate_top_k_op); OP_CONVERTER_NAMED(translate_top_k_v2_op); diff --git a/src/frontends/tensorflow_common/src/op/tensor_list_operations.cpp b/src/frontends/tensorflow_common/src/op/tensor_list_operations.cpp index cfd62b8e5fc7e4..57976994094b17 100644 --- a/src/frontends/tensorflow_common/src/op/tensor_list_operations.cpp +++ b/src/frontends/tensorflow_common/src/op/tensor_list_operations.cpp @@ -3,12 +3,26 @@ // #include "common_op_table.hpp" -#include "openvino/opsets/opset10.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/convert_like.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/maximum.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scatter_update.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/slice.hpp" +#include "openvino/op/squeeze.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/op/unsqueeze.hpp" #include "utils.hpp" using namespace std; using namespace ov; -using namespace opset10; +using namespace ov::op; namespace ov { namespace frontend { @@ -22,7 +36,7 @@ OutputVector translate_tensor_list_reserve_op(const NodeContext& node) { // all tensor elements will be saved in the flatten form in the list // because we want to cover a case of dynamic rank tensor list // the real shape of the tensor elements will be restored by TensorListStack operations - auto empty_constant = make_shared(element_dtype, Shape{0, 0}); + auto empty_constant = make_shared(element_dtype, Shape{0, 0}); set_node_name(node.get_name(), empty_constant); return {empty_constant}; } @@ -41,14 +55,14 @@ OutputVector translate_tensor_list_stack_op(const NodeContext& node) { auto element_shape = node.get_input(1); // compute number of tensor elements in the list - Output num_elements = make_shared(input_handle, element::i32); - auto zero_const = make_shared(element::i32, Shape{1}, 0); - auto one_const = make_shared(element::i32, Shape{1}, 1); - num_elements = make_shared(num_elements, zero_const, one_const, one_const); + Output num_elements = make_shared(input_handle, element::i32); + auto zero_const = make_shared(element::i32, Shape{1}, 0); + auto one_const = make_shared(element::i32, Shape{1}, 1); + num_elements = make_shared(num_elements, zero_const, one_const, one_const); // restore the real shape of tensor elements - auto new_shape = make_shared(OutputVector{num_elements, element_shape}, 0); - auto reshape = make_shared(input_handle, new_shape, false); + auto new_shape = make_shared(OutputVector{num_elements, element_shape}, 0); + auto reshape = make_shared(input_handle, new_shape, false); set_node_name(node.get_name(), reshape); return {reshape}; @@ -62,12 +76,12 @@ OutputVector translate_tensor_list_get_item_op(const NodeContext& node) { auto element_dtype = node.get_attribute("element_dtype"); // squeeze index tensor to have a scalar - index = make_shared(index); + index = make_shared(index); // gather tensor element by the required position - auto gather_axis = make_shared(element::i32, Shape{1}, 0); - Output tensor_element = make_shared(input_handle, index, gather_axis); - tensor_element = make_shared(tensor_element, element_dtype); + auto gather_axis = make_shared(element::i32, Shape{1}, 0); + Output tensor_element = make_shared(input_handle, index, gather_axis); + tensor_element = make_shared(tensor_element, element_dtype); set_node_name(node.get_name(), tensor_element.get_node_shared_ptr()); return {tensor_element}; @@ -80,44 +94,44 @@ OutputVector translate_tensor_list_set_item_op(const NodeContext& node) { auto item = node.get_input(2); // squeeze index tensor to have a scalar - index = make_shared(index); + index = make_shared(index); // flatten item to be inserted since // the tensor list saves elements in the flatten form - auto new_item_shape = make_shared(element::i32, Shape{1}, -1); - item = make_shared(item, new_item_shape, false); - auto item_shape = make_shared(item, element::i32); + auto new_item_shape = make_shared(element::i32, Shape{1}, -1); + item = make_shared(item, new_item_shape, false); + auto item_shape = make_shared(item, element::i32); // reshape the tensor list to the shape [num_elements, -1] // that is because in the first iteration we have empty constant of a shape [0,0] - auto minus_one = make_shared(element::i32, Shape{1}, -1); - auto new_input_handle_shape = make_shared(OutputVector{minus_one, item_shape}, 0); - input_handle = make_shared(input_handle, new_input_handle_shape, false); - input_handle = make_shared(input_handle, item); + auto minus_one = make_shared(element::i32, Shape{1}, -1); + auto new_input_handle_shape = make_shared(OutputVector{minus_one, item_shape}, 0); + input_handle = make_shared(input_handle, new_input_handle_shape, false); + input_handle = make_shared(input_handle, item); // compute the current length of the list - Output list_length = make_shared(input_handle, element::i32); - auto zero_const = make_shared(element::i32, Shape{1}, 0); - auto one_const = make_shared(element::i32, Shape{1}, 1); - list_length = make_shared(list_length, zero_const, one_const, one_const); + Output list_length = make_shared(input_handle, element::i32); + auto zero_const = make_shared(element::i32, Shape{1}, 0); + auto one_const = make_shared(element::i32, Shape{1}, 1); + list_length = make_shared(list_length, zero_const, one_const, one_const); // compute a size of the dummy tensor that serves to fill holes in the list // if no tensor is inserted at this position - auto one_const_scalar = make_shared(element::i32, Shape{1}, 1); - auto index_plus_one = make_shared(index, one_const_scalar); - Output max_length = make_shared(list_length, index_plus_one); - Output dummy_tensor_size = make_shared(max_length, list_length); + auto one_const_scalar = make_shared(element::i32, Shape{1}, 1); + auto index_plus_one = make_shared(index, one_const_scalar); + Output max_length = make_shared(list_length, index_plus_one); + Output dummy_tensor_size = make_shared(max_length, list_length); // create dummy tensor and concatenate it auto zero_element = create_same_type_const_scalar(item, 0); - auto dummy_tensor_shape = make_shared(OutputVector{dummy_tensor_size, item_shape}, 0); - auto dummy_tensor = make_shared(zero_element, dummy_tensor_shape); - input_handle = make_shared(OutputVector{input_handle, dummy_tensor}, 0); + auto dummy_tensor_shape = make_shared(OutputVector{dummy_tensor_size, item_shape}, 0); + auto dummy_tensor = make_shared(zero_element, dummy_tensor_shape); + input_handle = make_shared(OutputVector{input_handle, dummy_tensor}, 0); // update the resulted tensor using ScatterUpdate - index = make_shared(index, zero_const); - item = make_shared(item, zero_const); - auto scatter_update = make_shared(input_handle, index, item, zero_const); + index = make_shared(index, zero_const); + item = make_shared(item, zero_const); + auto scatter_update = make_shared(input_handle, index, item, zero_const); set_node_name(node.get_name(), scatter_update); return {scatter_update}; @@ -132,29 +146,82 @@ OutputVector translate_tensor_list_push_back_op(const NodeContext& node) { // the tensor list saves elements in the flatten form // because we want to cover a case of dynamic rank tensor list // the real shape of the tensor elements will be restored by TensorListStack operations - auto new_tensor_shape = make_shared(element::i32, Shape{1}, -1); - tensor = make_shared(tensor, new_tensor_shape, false); - auto tensor_shape = make_shared(tensor, element::i32); + auto new_tensor_shape = make_shared(element::i32, Shape{1}, -1); + tensor = make_shared(tensor, new_tensor_shape, false); + auto tensor_shape = make_shared(tensor, element::i32); // reshape the tensor list to the shape [num_elements, -1] // that is because in the first iteration we have empty constant of a shape [0,0] - Output num_elements = make_shared(input_handle, element::i32); - auto zero_const = make_shared(element::i32, Shape{1}, 0); - auto one_const = make_shared(element::i32, Shape{1}, 1); - num_elements = make_shared(num_elements, zero_const, one_const, one_const); - auto new_input_handle_shape = make_shared(OutputVector{num_elements, tensor_shape}, 0); - input_handle = make_shared(input_handle, new_input_handle_shape, false); + Output num_elements = make_shared(input_handle, element::i32); + auto zero_const = make_shared(element::i32, Shape{1}, 0); + auto one_const = make_shared(element::i32, Shape{1}, 1); + num_elements = make_shared(num_elements, zero_const, one_const, one_const); + auto new_input_handle_shape = make_shared(OutputVector{num_elements, tensor_shape}, 0); + input_handle = make_shared(input_handle, new_input_handle_shape, false); // unsqueeze tensor to be inserted into the list - tensor = make_shared(tensor, zero_const); + tensor = make_shared(tensor, zero_const); // insert the tensor into the end - auto updated_list = make_shared(OutputVector{input_handle, tensor}, 0); + auto updated_list = make_shared(OutputVector{input_handle, tensor}, 0); set_node_name(node.get_name(), updated_list); return {updated_list}; } +OutputVector translate_tensor_list_resize_op(const NodeContext& node) { + default_op_checks(node, 2, {"TensorListResize"}); + auto input_handle = node.get_input(0); + auto size = node.get_input(1); + + // create auxiliary constants + auto zero_const = make_shared(element::i32, Shape{1}, 0); + auto one_const = make_shared(element::i32, Shape{1}, 1); + auto max_const = make_shared(element::i32, Shape{1}, numeric_limits::max()); + + // compute the current length of the list and item shape + auto tensor_list_shape = make_shared(input_handle, element::i32); + auto list_length = make_shared(tensor_list_shape, zero_const, one_const, one_const); + auto item_shape = make_shared(tensor_list_shape, one_const, max_const, one_const); + + // compute a size of the dummy tensor to resize + // and clip it by zero if it is negative + Output dummy_tensor_size = make_shared(size, list_length); + dummy_tensor_size = make_shared(dummy_tensor_size, zero_const); + + // create dummy tensor and concatenate it + auto zero_const_same_type = create_same_type_const(input_handle, vector{0.0f}, Shape{}); + auto dummy_tensor_shape = make_shared(OutputVector{dummy_tensor_size, item_shape}, 0); + auto dummy_tensor = make_shared(zero_const_same_type, dummy_tensor_shape); + input_handle = make_shared(OutputVector{input_handle, dummy_tensor}, 0); + + // reshape size to have 1D tensor with one element + auto new_size_shape = make_shared(element::i32, Shape{1}, 1); + size = make_shared(size, new_size_shape, false); + + // resize can also shrink the input tensor list + input_handle = make_shared(input_handle, zero_const, size, one_const); + + set_node_name(node.get_name(), input_handle.get_node_shared_ptr()); + return {input_handle}; +} + +OutputVector translate_tensor_list_length_op(const NodeContext& node) { + default_op_checks(node, 1, {"TensorListLength"}); + auto input_handle = node.get_input(0); + + // create auxiliary constants + auto zero_const = make_shared(element::i32, Shape{1}, 0); + auto one_const = make_shared(element::i32, Shape{1}, 1); + + // compute the current length of the list + auto tensor_list_shape = make_shared(input_handle, element::i32); + auto list_length = make_shared(tensor_list_shape, zero_const, one_const, one_const); + + set_node_name(node.get_name(), list_length); + return {list_length}; +} + } // namespace op } // namespace tensorflow } // namespace frontend diff --git a/tests/layer_tests/tensorflow_tests/test_tf_TensorListLength.py b/tests/layer_tests/tensorflow_tests/test_tf_TensorListLength.py new file mode 100644 index 00000000000000..57b17b3341d750 --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_TensorListLength.py @@ -0,0 +1,83 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import tensorflow as tf +from common.tf_layer_test_class import CommonTFLayerTest + + +class TestTensorListLength(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + assert 'x' in inputs_info + x_shape = inputs_info['x'] + inputs_data = {} + inputs_data['x'] = np.random.randint(-10, 10, x_shape).astype(self.input_type) + return inputs_data + + def create_tensor_list_length(self, input_shape, input_type): + self.input_type = input_type + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + x = tf.compat.v1.placeholder(input_type, input_shape, 'x') + tensor_list = tf.raw_ops.TensorListFromTensor(tensor=x, + element_shape=tf.constant(input_shape[1:], dtype=tf.int32)) + tf.raw_ops.TensorListLength(input_handle=tensor_list) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(input_shape=[7], input_type=np.float32), + dict(input_shape=[10, 20], input_type=np.float32), + dict(input_shape=[2, 3, 4], input_type=np.int32), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_tensor_list_length_basic(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_tensor_list_length(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) + + +class TestTensorListLengthEmptyList(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + inputs_data = {} + inputs_data['tensor_list_size'] = np.array([self.tensor_list_size], dtype=np.int32) + return inputs_data + + def create_tensor_list_length_empty_list(self, tensor_list_size, element_shape): + self.tensor_list_size = tensor_list_size + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + tensor_list_size = tf.compat.v1.placeholder(tf.int32, [1], 'tensor_list_size') + tf_element_shape = tf.constant(element_shape, dtype=tf.int32) + tensor_shape = tf.concat([tensor_list_size, tf_element_shape], 0) + tensor = tf.broadcast_to(tf.constant(0.0, dtype=tf.float32), tensor_shape) + tensor_list = tf.raw_ops.TensorListFromTensor(tensor=tensor, + element_shape=tf_element_shape) + tf.raw_ops.TensorListLength(input_handle=tensor_list) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_tensor_list_length_empty_list = [ + dict(tensor_list_size=0, element_shape=[]), + dict(tensor_list_size=0, element_shape=[2, 3]), + ] + + @pytest.mark.parametrize("params", test_data_tensor_list_length_empty_list) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_tensor_list_length_empty_list(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_tensor_list_length_empty_list(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) diff --git a/tests/layer_tests/tensorflow_tests/test_tf_TensorListResize.py b/tests/layer_tests/tensorflow_tests/test_tf_TensorListResize.py new file mode 100644 index 00000000000000..39bdca06dee004 --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_TensorListResize.py @@ -0,0 +1,49 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import tensorflow as tf +from common.tf_layer_test_class import CommonTFLayerTest + + +class TestTensorListResize(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + assert 'x' in inputs_info + x_shape = inputs_info['x'] + inputs_data = {} + inputs_data['x'] = np.random.randint(-10, 10, x_shape).astype(self.input_type) + return inputs_data + + def create_tensor_list_resize(self, input_shape, input_type, new_size): + self.input_type = input_type + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + x = tf.compat.v1.placeholder(input_type, input_shape, 'x') + tensor_list = tf.raw_ops.TensorListFromTensor(tensor=x, + element_shape=tf.constant(input_shape[1:], dtype=tf.int32)) + tf_new_size = tf.constant(new_size, dtype=tf.int32) + tensor_list_resize = tf.raw_ops.TensorListResize(input_handle=tensor_list, size=tf_new_size) + element_shape = tf.constant(input_shape[1:], dtype=tf.int32) + tf.raw_ops.TensorListStack(input_handle=tensor_list_resize, element_shape=element_shape, + element_dtype=input_type) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(input_shape=[7], input_type=np.float32, new_size=3), + dict(input_shape=[10, 20], input_type=np.float32, new_size=20), + dict(input_shape=[2, 3, 4], input_type=np.int32, new_size=5), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_tensor_list_resize_basic(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_tensor_list_resize(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api)