Skip to content

aten.full.default #3013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/vulkan/partitioner/vulkan_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
exir_ops.edge.aten.native_layer_norm.default,
# Other
operator.getitem,
exir_ops.edge.aten.full.default,
]
return supported

Expand Down
61 changes: 61 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/full.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

#define to_tensor_idx to_tensor_idx_${PACKING}
#define get_packed_dim get_packed_dim_${PACKING}

#include "broadcasting_utils.h"
#include "indexing_utils.h"

layout(std430) buffer;

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;

layout(set = 0, binding = 1) uniform PRECISION restrict GpuSizes {
ivec4 data;
}
gpu_sizes;

layout(set = 0, binding = 2) uniform PRECISION restrict CpuSizes {
ivec4 data;
}
cpu_sizes;

layout(set = 0, binding = 3) uniform PRECISION restrict FillVal {
float data;
}
fill_value;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
const ivec4 idx = to_tensor_idx(pos, gpu_sizes.data);

if (any(greaterThanEqual(idx, gpu_sizes.data))) {
return;
}

VEC4_T outtex = VEC4_T(fill_value.data);
const int packed_dim_size = get_packed_dim(cpu_sizes.data);
int packed_idx = get_packed_dim(idx);

if (packed_idx + 3 >= packed_dim_size) {
ivec4 packed_ind = ivec4(packed_idx) + ivec4(0, 1, 2, 3);
VEC4_T valid_idx = VEC4_T(lessThan(packed_ind, ivec4(packed_dim_size)));
outtex = outtex * valid_idx;
}

imageStore(image_out, ${get_pos[NDIM]("pos")}, outtex);
}
17 changes: 17 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/full.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

full:
parameter_names_with_default_values:
NDIM: 3
DTYPE: float
PACKING: C_packed
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
shader_variants:
- NAME: full
68 changes: 68 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Full.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void resize_full_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& extra_args) {
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
std::vector<int64_t> out_sizes = *graph->get_int_list(extra_args[0]);

out->virtual_resize(out_sizes);
}

void add_full_node(
ComputeGraph& graph,
const ValueRef size,
const ValueRef fill_value,
const ValueRef out) {
float fill_value_val = graph.extract_scalar<float>(fill_value);
vTensorPtr t_out = graph.get_tensor(out);

api::utils::uvec3 global_size = t_out->extents();
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

std::string kernel_name("full");
kernel_name.reserve(kShaderNameReserve);

add_dtype_suffix(kernel_name, *t_out);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
global_size,
local_size,
// Inputs and Outputs
{{out, api::MemoryAccessType::WRITE}},
// Shader params buffers
{t_out->gpu_sizes_ubo(),
t_out->cpu_sizes_ubo(),
graph.create_params_buffer(fill_value_val)},
// Resizing
resize_full_node,
{size}));
}

void full(ComputeGraph& graph, const std::vector<ValueRef>& args) {
return add_full_node(graph, args[0], args[1], args[6]);
}

REGISTER_OPERATORS {
VK_REGISTER_OP(aten.full.default, full);
}

} // namespace vkcompute
7 changes: 6 additions & 1 deletion backends/vulkan/serialization/vulkan_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,12 @@ def get_or_create_value_for(self, arg: _Argument):
if arg in self.node_to_value_ids:
return self.node_to_value_ids[arg]
return self.create_node_value(arg)
elif isinstance(arg, NoneType):
elif (
isinstance(arg, NoneType)
or isinstance(arg, torch.device)
or isinstance(arg, torch.dtype)
or isinstance(arg, torch.layout)
):
return self.create_null_value()
elif isinstance(arg, _ScalarType):
return self.create_scalar_value(arg)
Expand Down
13 changes: 13 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,18 @@ def get_native_layer_norm_inputs():
return test_suite


def get_full_inputs():
test_suite = VkTestSuite(
[
([S1, S2], 42.0),
([M, M1, M2], 3.14),
([L, M, M1, M2], 2.72),
]
)
test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
return test_suite


test_suites = {
"aten.add.Tensor": get_binary_elementwise_inputs(),
"aten.sub.Tensor": get_binary_elementwise_inputs(),
Expand All @@ -139,6 +151,7 @@ def get_native_layer_norm_inputs():
"aten.max_pool2d_with_indices.default": get_pool2d_inputs(),
"aten.convolution.default": get_conv2d_inputs(),
"aten.native_layer_norm.default": get_native_layer_norm_inputs(),
"aten.full.default": get_full_inputs(),
}

prepacked_args = {"aten.mm.default": {"mat2"}}
Expand Down
17 changes: 14 additions & 3 deletions backends/vulkan/test/op_tests/utils/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
AT_INT_ARRAY_REF,
AT_SCALAR,
AT_TENSOR,
AT_TENSOR_OPT,
BOOL,
CppTestFileGen,
DOUBLE,
INT,
OPT_AT_TENSOR,
OPT_BOOL,
OPT_DEVICE,
OPT_LAYOUT,
OPT_SCALARTYPE,
TestSuite,
TestSuiteGen,
THREE_TENSOR_TUPLE,
Expand Down Expand Up @@ -180,7 +184,6 @@ def create_aten_fn_call(self) -> str:
func_call = generate_static_dispatch_backend_call(
self.f_sig, self.f, TestSuiteGen.backend_key
)[7:].replace("::cpu", "")

return func_call

def create_out_src(self) -> str:
Expand All @@ -205,7 +208,7 @@ def create_value_for(self, ref: ValueRefList) -> str: # noqa: C901

cpp_type = "IOValueRef" if (ref.is_in and not prepack) else "ValueRef"

if ref.src_cpp_type == AT_TENSOR_OPT:
if ref.src_cpp_type == OPT_AT_TENSOR:
ret_str = f"{cpp_type} {ref.name} = "
ret_str += f"!{ref.src_cpp_name}.has_value() ? "
ret_str += f"{self.graph}{self.dot}add_none() : "
Expand Down Expand Up @@ -241,6 +244,13 @@ def create_value_for(self, ref: ValueRefList) -> str: # noqa: C901
ret_str += f"add_scalar<int64_t>({ref.src_cpp_name}); \n"
elif ref.src_cpp_type == DOUBLE:
ret_str += f"add_scalar<double>({ref.src_cpp_name}); \n"
elif (
ref.src_cpp_type == OPT_SCALARTYPE
or ref.src_cpp_type == OPT_LAYOUT
or ref.src_cpp_type == OPT_DEVICE
or ref.src_cpp_type == OPT_BOOL
):
ret_str += "add_none(); \n"
elif ref.src_cpp_type == TWO_TENSOR_TUPLE:
ret_str += f"add_value_list({{{ref.name}_first, {ref.name}_second}}); \n"
elif ref.src_cpp_type == THREE_TENSOR_TUPLE:
Expand Down Expand Up @@ -457,6 +467,7 @@ def gen_parameterization(self) -> str:
#include <tuple>

using namespace vkcompute;
using TensorOptions = at::TensorOptions;

api::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
switch(at_scalartype) {
Expand Down
17 changes: 14 additions & 3 deletions backends/vulkan/test/op_tests/utils/codegen_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@
AT_INT_ARRAY_REF = "at::IntArrayRef"
AT_SCALAR = "at::Scalar"
AT_TENSOR = "at::Tensor"
AT_TENSOR_OPT = "::std::optional<at::Tensor>"
BOOL = "bool"
INT = "int64_t"
DOUBLE = "double"
INT = "int64_t"
OPT_AT_TENSOR = "::std::optional<at::Tensor>"
OPT_BOOL = "::std::optional<bool>"
OPT_DEVICE = "::std::optional<at::Device>"
OPT_LAYOUT = "::std::optional<at::Layout>"
OPT_SCALARTYPE = "::std::optional<at::ScalarType>"
TWO_TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor>"
THREE_TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor,at::Tensor>"

Expand Down Expand Up @@ -120,7 +124,7 @@ def create_input_data(self, arg: Argument, data: Any) -> str:

if cpp_type == AT_TENSOR:
ret_str += f"make_rand_tensor({init_list_str(data)}, test_dtype);"
elif cpp_type == AT_TENSOR_OPT:
elif cpp_type == OPT_AT_TENSOR:
if str(data) == "None":
ret_str += "std::nullopt;"
else:
Expand All @@ -135,6 +139,13 @@ def create_input_data(self, arg: Argument, data: Any) -> str:
ret_str += f"{str(data).lower()};"
elif cpp_type == DOUBLE:
ret_str += f"{str(data).lower()};"
elif (
cpp_type == OPT_SCALARTYPE
or cpp_type == OPT_LAYOUT
or cpp_type == OPT_DEVICE
or cpp_type == OPT_BOOL
):
ret_str += "std::nullopt;"
else:
raise RuntimeError(f"Unsupported cpp type {cpp_type}")
return ret_str + "\n"
Expand Down
16 changes: 16 additions & 0 deletions backends/vulkan/test/test_vulkan_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,3 +665,19 @@ def forward(self, x):
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)

def test_vulkan_backend_full(self):
class FullModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.full(x.shape, 42.0)

sample_inputs = (torch.randn(size=(2, 3, 4, 5), dtype=torch.float32),)

self.lower_module_and_test_output(
FullModule(),
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)