From fb3d4e971b1acfa53f175f8736f8f570ba586c89 Mon Sep 17 00:00:00 2001
From: Ekaterina Aidova
Date: Fri, 22 Sep 2023 13:54:44 +0400
Subject: [PATCH] [PT FE]: support aten::broadcast_tensors (#19994)

* broadcast tensors

* [PT FE]: support aten::broadcast_tensors

* apply review comments

* remove add
---
 .../transforms/prim_list_unpack_replacer.cpp | 23 ++++++++++
 .../pytorch_tests/test_broadcast_tensors.py  | 45 +++++++++++++++++++
 2 files changed, 68 insertions(+)
 create mode 100644 tests/layer_tests/pytorch_tests/test_broadcast_tensors.py

diff --git a/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp
index 312ed1457c985d..e5fa463af31d00 100644
--- a/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp
+++ b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp
@@ -173,6 +173,29 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
             }
         }
 
+        if (auto broadcast_tensors = cast_fw_node(input_node, "aten::broadcast_tensors")) {
+            auto tensors = cast_fw_node(broadcast_tensors->input_value(0).get_node_shared_ptr(), "prim::ListConstruct");
+            if (!tensors) {
+                add_exception_to_fw_node(input_node,
+                                         "aten::broadcast_tensors: only prim::ListConstruct supported as input.");
+                return false;
+            }
+            Output<Node> final_shape_t = opset10::Constant::create(element::i32, Shape{}, {0});
+            for (auto input : tensors->inputs()) {
+                auto tensor_shape = rg.make<opset10::ShapeOf>(input.get_source_output(), element::i32);
+                final_shape_t =
+                    rg.make<opset10::Broadcast>(final_shape_t, tensor_shape, ov::op::BroadcastType::BIDIRECTIONAL);
+            }
+            auto final_shape = rg.make<opset10::ShapeOf>(final_shape_t, element::i32);
+            OutputVector outputs;
+            for (auto input : tensors->inputs()) {
+                outputs.push_back(rg.make<opset10::Broadcast>(input.get_source_output(), final_shape));
+            }
+            copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
+            replace_node(list_unpack, outputs);
+            return true;
+        }
+
         if (auto unbind = cast_fw_node(input_node, "aten::unbind")) {
             const auto input = unbind->get_input_source_output(0);
             const auto axis = unbind->get_input_source_output(1);
diff --git a/tests/layer_tests/pytorch_tests/test_broadcast_tensors.py b/tests/layer_tests/pytorch_tests/test_broadcast_tensors.py
new file mode 100644
index 00000000000000..b405f0114b5a4b
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_broadcast_tensors.py
@@ -0,0 +1,45 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestBroadcastTensors(PytorchLayerTest):
+    def _prepare_input(self, x_shape, y_shape, z_shape, x_dtype, y_dtype, z_dtype):
+        import numpy as np
+        return (
+            np.random.randn(*x_shape).astype(x_dtype),
+            np.random.randn(*y_shape).astype(y_dtype),
+            np.random.randn(*z_shape).astype(z_dtype))
+
+    def create_model(self):
+        import torch
+
+        class aten_broadcast_tensors(torch.nn.Module):
+            def __init__(self):
+                super(aten_broadcast_tensors, self).__init__()
+
+            def forward(self, x, y, z):
+                x1, y1, z1 = torch.broadcast_tensors(x, y, z)
+                return x1, y1, z1
+
+        ref_net = None
+
+        return aten_broadcast_tensors(), ref_net, ("prim::ListConstruct", "aten::broadcast_tensors", "prim::ListUnpack")
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.parametrize("x_shape", [[1, ], [2, 1], [2, 2, 1]])
+    @pytest.mark.parametrize("y_shape", [[2, ], [1, 2], [1, 2, 1]])
+    @pytest.mark.parametrize("z_shape", [[1, 2], [2, 2], [1, 2, 1, 1]])
+    @pytest.mark.parametrize("x_dtype", ["float32", "int32"])
+    @pytest.mark.parametrize("y_dtype", ["float32", "int32"])
+    @pytest.mark.parametrize("z_dtype", ["float32", "int32"])
+    def test_broadcast_tensors(self, x_shape, y_shape, z_shape, x_dtype, y_dtype, z_dtype, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={
+            "x_shape": x_shape, "x_dtype": x_dtype,
+            "y_shape": y_shape, "y_dtype": y_dtype,
+            "z_shape": z_shape, "z_dtype": z_dtype,
+        })
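
Reviewer note (not part of the applied patch): a minimal eager-PyTorch sketch of the semantics the transform above reproduces. The shapes are illustrative, and the scalar "probe" loop is only an analogy for the Constant -> bidirectional Broadcast -> ShapeOf chain in the C++ hunk, not code from this PR.

    import torch

    # torch.broadcast_tensors expands every input to the common broadcast shape.
    x = torch.randn(2, 1)
    y = torch.randn(1, 3)
    z = torch.randn(3)
    x1, y1, z1 = torch.broadcast_tensors(x, y, z)
    assert x1.shape == y1.shape == z1.shape == (2, 3)

    # The transform derives that common shape the same way: it broadcasts a
    # scalar zero probe against each input's shape in turn (bidirectional
    # mode), then reads the probe's shape back via ShapeOf. That shape then
    # drives a final Broadcast of every input.
    probe = torch.zeros((), dtype=torch.int32)  # mirrors the i32 scalar Constant
    for t in (x, y, z):
        probe, _ = torch.broadcast_tensors(probe, t)
    assert probe.shape == (2, 3)

Because only prim::ListConstruct inputs are handled, any other producer of the tensor list falls through to add_exception_to_fw_node, which keeps the framework node and reports the limitation instead of failing conversion outright.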