
Commit e6daf87

Yujie Hui authored and facebook-github-bot committed
aten.full_like.default (#3843)
Summary: Implement aten.full_like.default, which is required by the OCR full model. The implementation reuses aten.full.default:

```
func: full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
func: full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
```

The major difference between full and full_like is the first argument: full takes an integer size list, while full_like takes an input tensor whose sizes define the output shape, so most of the existing code can be reused. To support dynamic reshape, resize_full_node gains a condition that determines out_sizes from either the size list or the input tensor.

Reviewed By: yipjustin

Differential Revision: D58121891
1 parent f184329 commit e6daf87
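
To make the caller-visible difference between the two ops concrete, here is a minimal eager PyTorch sketch (the shapes and fill value are illustrative, not taken from this commit):

```python
import torch

# full: the output shape is given explicitly as a size list.
a = torch.full((2, 3), 7.0)

# full_like: the output shape (and, by default, dtype/device) comes from an
# existing tensor; only the fill value is supplied.
x = torch.randn(2, 3)
b = torch.full_like(x, 7.0)

assert a.shape == b.shape and bool((a == b).all())
```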

File tree: 4 files changed, +42 −3 lines changed


backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -117,6 +117,7 @@ def __contains__(self, op):
     exir_ops.edge.aten.arange.start_step,
     exir_ops.edge.aten.clone.default,
     exir_ops.edge.aten.full.default,
+    exir_ops.edge.aten.full_like.default,
 ]
```
backends/vulkan/runtime/graph/ops/impl/Full.cpp

Lines changed: 13 additions & 3 deletions
```diff
@@ -20,14 +20,19 @@ void resize_full_node(
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& extra_args) {
   vTensorPtr out = graph->get_tensor(args[0].refs[0]);
-  std::vector<int64_t> out_sizes = *graph->get_int_list(extra_args[0]);
+  std::vector<int64_t> out_sizes;
+  if (graph->val_is_tensor(extra_args[0])) {
+    out_sizes = graph->get_tensor(extra_args[0])->sizes();
+  } else {
+    out_sizes = *graph->get_int_list(extra_args[0]);
+  }
 
   out->virtual_resize(out_sizes);
 }
 
 void add_full_node(
     ComputeGraph& graph,
-    const ValueRef size,
+    const ValueRef size_or_in, // IntListPtr when op is full and vTensorPtr if is full_like
     const ValueRef fill_value,
     const ValueRef out) {
   float fill_value_val = graph.extract_scalar<float>(fill_value);
@@ -54,15 +59,20 @@ void add_full_node(
       {SV(t_out->packed_dim_whcn_idx())},
       // Resizing Logic
       resize_full_node,
-      {size}));
+      {size_or_in}));
 }
 
 void full(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   return add_full_node(graph, args[0], args[1], args[6]);
 }
 
+void full_like(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  return add_full_node(graph, args[0], args[1], args[7]);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.full.default, full);
+  VK_REGISTER_OP(aten.full_like.default, full_like);
 }
 
 } // namespace vkcompute
```
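
One detail worth calling out: full() reads the output ValueRef at args[6] while full_like() reads it at args[7]. Assuming the graph passes the schema arguments in order followed by the output value (an assumption about the argument layout, not something stated in the diff), those indices line up with the schemas quoted in the summary, as this small Python reading aid shows:

```python
# Hypothetical reading aid, assuming args = [<schema args in order>, out].
full_args = ["size", "fill_value", "dtype", "layout", "device", "pin_memory"]
full_like_args = ["self", "fill_value", "dtype", "layout",
                  "device", "pin_memory", "memory_format"]

print(len(full_args))       # 6 -> out sits at args[6] for aten.full.default
print(len(full_like_args))  # 7 -> out sits at args[7] for aten.full_like.default
```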

backends/vulkan/test/op_tests/cases.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -303,6 +303,17 @@ def get_full_inputs():
     return test_suite
 
 
+def get_full_like_inputs():
+    test_suite = VkTestSuite(
+        [
+            ((S1, S2), 4.0),
+            ((M, M1, M2), -3.5),
+            ((L, M, M1, M2), 9.876),
+        ]
+    )
+    return test_suite
+
+
 def get_select_int_inputs():
     test_suite = VkTestSuite(
         [
@@ -909,6 +920,7 @@ def get_arange_inputs():
     "aten.convolution.default": get_conv_inputs(),
     "aten.native_layer_norm.default": get_native_layer_norm_inputs(),
     "aten.full.default": get_full_inputs(),
+    "aten.full_like.default": get_full_like_inputs(),
     "aten.select.int": get_select_int_inputs(),
     "aten.select_copy.int": get_select_int_inputs(),
     "aten.permute.default": get_permute_inputs(),
```

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -971,6 +971,22 @@ def forward(self, x):
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
 
+    def test_vulkan_backend_full_like(self):
+        class FullLikeModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.full_like(x, 42.0)
+
+        sample_inputs = (torch.randn(size=(2, 3, 4, 5), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            FullLikeModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
     def test_vulkan_backend_reshape(self):
         class ReshapeModule(torch.nn.Module):
             def __init__(self):
```