diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py
index 8dba0a33d9..8893c26947 100644
--- a/backends/vulkan/partitioner/supported_ops.py
+++ b/backends/vulkan/partitioner/supported_ops.py
@@ -117,6 +117,7 @@ def __contains__(self, op):
     exir_ops.edge.aten.arange.start_step,
     exir_ops.edge.aten.clone.default,
     exir_ops.edge.aten.full.default,
+    exir_ops.edge.aten.full_like.default,
 ]
diff --git a/backends/vulkan/runtime/graph/ops/impl/Full.cpp b/backends/vulkan/runtime/graph/ops/impl/Full.cpp
index 032aab88bb..3c85725e14 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Full.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Full.cpp
@@ -20,14 +20,20 @@ void resize_full_node(
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& extra_args) {
   vTensorPtr out = graph->get_tensor(args[0].refs[0]);
-  std::vector<int64_t> out_sizes = *graph->get_int_list(extra_args[0]);
+  std::vector<int64_t> out_sizes;
+  if (graph->val_is_tensor(extra_args[0])) {
+    out_sizes = graph->get_tensor(extra_args[0])->sizes();
+  } else {
+    out_sizes = *graph->get_int_list(extra_args[0]);
+  }
   out->virtual_resize(out_sizes);
 }
 
+// size_or_in is an IntList when the op is full, and a Tensor when the op is full_like
 void add_full_node(
     ComputeGraph& graph,
-    const ValueRef size,
+    const ValueRef size_or_in,
     const ValueRef fill_value,
     const ValueRef out) {
   float fill_value_val = graph.extract_scalar<float>(fill_value);
@@ -54,15 +60,20 @@ void add_full_node(
       {SV(t_out->packed_dim_whcn_idx())},
       // Resizing Logic
       resize_full_node,
-      {size}));
+      {size_or_in}));
 }
 
 void full(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   return add_full_node(graph, args[0], args[1], args[6]);
 }
 
+void full_like(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  return add_full_node(graph, args[0], args[1], args[7]);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.full.default, full);
+  VK_REGISTER_OP(aten.full_like.default, full_like);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index 3803f73a60..0361e39087 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -303,6 +303,17 @@ def get_full_inputs():
     return test_suite
 
 
+def get_full_like_inputs():
+    test_suite = VkTestSuite(
+        [
+            ((S1, S2), 4.0),
+            ((M, M1, M2), -3.5),
+            ((L, M, M1, M2), 9.876),
+        ]
+    )
+    return test_suite
+
+
 def get_select_int_inputs():
     test_suite = VkTestSuite(
         [
@@ -909,6 +920,7 @@ def get_arange_inputs():
     "aten.convolution.default": get_conv_inputs(),
     "aten.native_layer_norm.default": get_native_layer_norm_inputs(),
     "aten.full.default": get_full_inputs(),
+    "aten.full_like.default": get_full_like_inputs(),
     "aten.select.int": get_select_int_inputs(),
     "aten.select_copy.int": get_select_int_inputs(),
     "aten.permute.default": get_permute_inputs(),
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index 8210f66a9a..decc602086 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -971,6 +971,22 @@ def forward(self, x):
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
 
+    def test_vulkan_backend_full_like(self):
+        class FullLikeModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.full_like(x, 42.0)
+
+        sample_inputs = (torch.randn(size=(2, 3, 4, 5), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            FullLikeModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
     def test_vulkan_backend_reshape(self):
         class ReshapeModule(torch.nn.Module):
             def __init__(self):
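
A note on the argument indices in Full.cpp above: in the edge dialect, aten.full.default has six input arguments (size, fill_value, dtype, layout, device, pin_memory), so the appended output ref lands at args[6], while aten.full_like.default carries a seventh input (memory_format), pushing its output ref to args[7]. That offset difference is the only reason full and full_like need separate entry points around the shared add_full_node.

As a quick end-to-end sanity check outside the test harness, something like the sketch below should now route torch.full_like through the Vulkan delegate. This is a minimal sketch assuming the standard to_edge/to_backend export flow; the VulkanPartitioner import path is inferred from this tree's layout and may differ across executorch revisions.

# Minimal sketch: lower a full_like call through the Vulkan partitioner.
# Assumes the standard ExecuTorch export flow; the VulkanPartitioner
# import path below may vary across executorch revisions.
import torch
from executorch.backends.vulkan.partitioner.vulkan_partitioner import (
    VulkanPartitioner,
)
from executorch.exir import to_edge


class FullLikeModule(torch.nn.Module):
    def forward(self, x):
        # Same shape as x, every element 42.0; with this patch the op is
        # eligible for the Vulkan delegate instead of the portable kernels.
        return torch.full_like(x, 42.0)


sample_inputs = (torch.randn(2, 3, 4, 5),)
edge = to_edge(torch.export.export(FullLikeModule(), sample_inputs))
edge = edge.to_backend(VulkanPartitioner())

# After partitioning, the graph should contain a delegated call in place
# of the aten.full_like.default node.
print(edge.exported_program().graph_module.graph)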