From e6daf8792a547668aee8475a0775efc4d0e51583 Mon Sep 17 00:00:00 2001 From: Yujie Hui Date: Tue, 4 Jun 2024 17:22:17 -0700 Subject: [PATCH] aten.full_like.default (#3843) Summary: Implement aten.full_like.default, which is required in OCR full model. Reuse the implementation of aten.full.default ``` func: full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor func: full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor ``` The major difference between full and full_like is the first argument, which full is an integer list and full_like is an input tensor. We can reuse lots of code here. And to support dynamic reshape, just add a condition in resize_full_node to determine the out_sizes. Reviewed By: yipjustin Differential Revision: D58121891 --- backends/vulkan/partitioner/supported_ops.py | 1 + backends/vulkan/runtime/graph/ops/impl/Full.cpp | 16 +++++++++++++--- backends/vulkan/test/op_tests/cases.py | 12 ++++++++++++ backends/vulkan/test/test_vulkan_delegate.py | 16 ++++++++++++++++ 4 files changed, 42 insertions(+), 3 deletions(-) diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py index 8dba0a33d94..8893c26947c 100644 --- a/backends/vulkan/partitioner/supported_ops.py +++ b/backends/vulkan/partitioner/supported_ops.py @@ -117,6 +117,7 @@ def __contains__(self, op): exir_ops.edge.aten.arange.start_step, exir_ops.edge.aten.clone.default, exir_ops.edge.aten.full.default, + exir_ops.edge.aten.full_like.default, ] diff --git a/backends/vulkan/runtime/graph/ops/impl/Full.cpp b/backends/vulkan/runtime/graph/ops/impl/Full.cpp index 032aab88bb4..2077fb1b9a8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Full.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Full.cpp @@ -20,14 +20,19 @@ void resize_full_node( const std::vector& args, const std::vector& extra_args) { vTensorPtr out = graph->get_tensor(args[0].refs[0]); - std::vector out_sizes = *graph->get_int_list(extra_args[0]); + std::vector out_sizes; + if (graph->val_is_tensor(extra_args[0])) { + out_sizes = graph->get_tensor(extra_args[0])->sizes(); + } else { + out_sizes = *graph->get_int_list(extra_args[0]); + } out->virtual_resize(out_sizes); } void add_full_node( ComputeGraph& graph, - const ValueRef size, + const ValueRef size_or_in, //IntListPtr when op is full and vTensorPtr if is full_like const ValueRef fill_value, const ValueRef out) { float fill_value_val = graph.extract_scalar(fill_value); @@ -54,15 +59,20 @@ void add_full_node( {SV(t_out->packed_dim_whcn_idx())}, // Resizing Logic resize_full_node, - {size})); + {size_or_in})); } void full(ComputeGraph& graph, const std::vector& args) { return add_full_node(graph, args[0], args[1], args[6]); } +void full_like(ComputeGraph& graph, const std::vector& args) { + return add_full_node(graph, args[0], args[1], args[7]); +} + REGISTER_OPERATORS { VK_REGISTER_OP(aten.full.default, full); + VK_REGISTER_OP(aten.full_like.default, full_like); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 3803f73a602..0361e390874 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -303,6 +303,17 @@ def get_full_inputs(): return test_suite +def get_full_like_inputs(): + test_suite = VkTestSuite( + [ + ((S1, S2), 4.0), + ((M, M1, M2), -3.5), + ((L, M, M1, M2), 9.876), + ] + ) + return test_suite + + def get_select_int_inputs(): test_suite = VkTestSuite( [ @@ -909,6 +920,7 @@ def get_arange_inputs(): "aten.convolution.default": get_conv_inputs(), "aten.native_layer_norm.default": get_native_layer_norm_inputs(), "aten.full.default": get_full_inputs(), + "aten.full_like.default": get_full_like_inputs(), "aten.select.int": get_select_int_inputs(), "aten.select_copy.int": get_select_int_inputs(), "aten.permute.default": get_permute_inputs(), diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 8210f66a9ac..decc602086b 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -971,6 +971,22 @@ def forward(self, x): memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], ) + def test_vulkan_backend_full_like(self): + class FullLikeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.full_like(x, 42.0) + + sample_inputs = (torch.randn(size=(2, 3, 4, 5), dtype=torch.float32),) + + self.lower_module_and_test_output( + FullLikeModule(), + sample_inputs, + memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], + ) + def test_vulkan_backend_reshape(self): class ReshapeModule(torch.nn.Module): def __init__(self):