aten.full_like.default (#3843)
Summary:

Implement aten.full_like.default, which is required by the OCR full model. The implementation reuses that of aten.full.default:

```
func: full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
func: full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
```

The major difference between full and full_like is the first argument: full takes an integer list of sizes, while full_like takes an input tensor whose sizes the output inherits. This allows most of the code to be reused. To support dynamic reshape, resize_full_node gains a condition that determines out_sizes from whichever argument type is present.
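For context, a minimal eager-mode sketch of the two calling conventions (shapes and fill values here are illustrative, not taken from the commit):

```
import torch

x = torch.randn(2, 3)

# full: the output shape comes from an explicit size list
a = torch.full((2, 3), 7.0)

# full_like: the output shape (and, by default, dtype/device) comes from the input tensor
b = torch.full_like(x, 7.0)

assert a.shape == b.shape and bool((a == b).all())
```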

Reviewed By: yipjustin

Differential Revision: D58121891
Yujie Hui authored and facebook-github-bot committed Jun 5, 2024
1 parent f184329 commit e6daf87
Showing 4 changed files with 42 additions and 3 deletions.
1 change: 1 addition & 0 deletions backends/vulkan/partitioner/supported_ops.py
```
@@ -117,6 +117,7 @@ def __contains__(self, op):
     exir_ops.edge.aten.arange.start_step,
     exir_ops.edge.aten.clone.default,
     exir_ops.edge.aten.full.default,
+    exir_ops.edge.aten.full_like.default,
 ]
 
 
```
16 changes: 13 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/Full.cpp
```
@@ -20,14 +20,19 @@ void resize_full_node(
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& extra_args) {
   vTensorPtr out = graph->get_tensor(args[0].refs[0]);
-  std::vector<int64_t> out_sizes = *graph->get_int_list(extra_args[0]);
+  std::vector<int64_t> out_sizes;
+  if (graph->val_is_tensor(extra_args[0])) {
+    out_sizes = graph->get_tensor(extra_args[0])->sizes();
+  } else {
+    out_sizes = *graph->get_int_list(extra_args[0]);
+  }
 
   out->virtual_resize(out_sizes);
 }
 
 void add_full_node(
     ComputeGraph& graph,
-    const ValueRef size,
+    const ValueRef size_or_in, // IntListPtr when op is full, vTensorPtr when op is full_like
     const ValueRef fill_value,
     const ValueRef out) {
   float fill_value_val = graph.extract_scalar<float>(fill_value);
@@ -54,15 +54,20 @@
       {SV(t_out->packed_dim_whcn_idx())},
       // Resizing Logic
       resize_full_node,
-      {size}));
+      {size_or_in}));
 }
 
 void full(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   return add_full_node(graph, args[0], args[1], args[6]);
 }
 
+void full_like(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  return add_full_node(graph, args[0], args[1], args[7]);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.full.default, full);
+  VK_REGISTER_OP(aten.full_like.default, full_like);
 }
 
 } // namespace vkcompute
```
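A note on the args[6] and args[7] indices above: they follow from the schemas quoted in the summary. full has six schema arguments before the output ref that the graph builder appends, while full_like has seven because of the extra memory_format argument. A hedged illustration (the lists below are just the schema argument orders, not real graph-builder code):

```
# Positional layout implied by the schemas, with the output ref appended last.
full_args = ["size", "fill_value", "dtype", "layout", "device", "pin_memory", "out"]
full_like_args = [
    "self", "fill_value", "dtype", "layout", "device",
    "pin_memory", "memory_format", "out",
]

assert full_args.index("out") == 6       # args[6] in full()
assert full_like_args.index("out") == 7  # args[7] in full_like()
```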
12 changes: 12 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
```
@@ -303,6 +303,17 @@ def get_full_inputs():
     return test_suite
 
 
+def get_full_like_inputs():
+    test_suite = VkTestSuite(
+        [
+            ((S1, S2), 4.0),
+            ((M, M1, M2), -3.5),
+            ((L, M, M1, M2), 9.876),
+        ]
+    )
+    return test_suite
+
+
 def get_select_int_inputs():
     test_suite = VkTestSuite(
         [
@@ -909,6 +920,7 @@ def get_arange_inputs():
     "aten.convolution.default": get_conv_inputs(),
     "aten.native_layer_norm.default": get_native_layer_norm_inputs(),
     "aten.full.default": get_full_inputs(),
+    "aten.full_like.default": get_full_like_inputs(),
     "aten.select.int": get_select_int_inputs(),
     "aten.select_copy.int": get_select_int_inputs(),
     "aten.permute.default": get_permute_inputs(),
```
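Each VkTestSuite entry above pairs an input shape with a fill value (S1, M, L, and similar are size constants defined elsewhere in cases.py). A hedged eager-mode sketch of what one such case exercises, using placeholder sizes:

```
import torch

S1, S2 = 7, 11  # placeholder values; the real constants live in cases.py
x = torch.randn(S1, S2)
out = torch.full_like(x, 4.0)

assert out.shape == x.shape
assert bool((out == 4.0).all())
```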
16 changes: 16 additions & 0 deletions backends/vulkan/test/test_vulkan_delegate.py
```
@@ -971,6 +971,22 @@ def forward(self, x):
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
 
+    def test_vulkan_backend_full_like(self):
+        class FullLikeModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.full_like(x, 42.0)
+
+        sample_inputs = (torch.randn(size=(2, 3, 4, 5), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            FullLikeModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
     def test_vulkan_backend_reshape(self):
         class ReshapeModule(torch.nn.Module):
             def __init__(self):
```
