Commit 95ce3e6

[ET-VK] support biases in buffer-based linear shader
Differential Revision: D69247282
Pull Request resolved: #8284
1 parent dad2ba0 commit 95ce3e6
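
For reference, the bias branch added by this change follows the standard addmm contract: out = alpha * (mat1 @ mat2) + beta * bias. A minimal PyTorch sketch of that contract, with made-up shapes (not taken from this commit):

import torch

# Tensor-level statement of what the new HAS_BIAS shader variant computes.
mat1 = torch.randn(4, 8)   # (M, K)
mat2 = torch.randn(8, 6)   # (K, N)
bias = torch.randn(6)      # one value per output column
alpha, beta = 1.0, 1.0

expected = alpha * (mat1 @ mat2) + beta * bias
assert torch.allclose(expected, torch.addmm(bias, mat1, mat2, beta=beta, alpha=alpha))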

5 files changed, +116 −36 lines changed

backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl renamed to backends/vulkan/runtime/graph/ops/glsl/addmm_naive_buffer.glsl

+17-6
@@ -10,6 +10,9 @@
 
 #define PRECISION ${PRECISION}
 
+$if HAS_BIAS:
+  #define HAS_BIAS
+
 #define T ${buffer_scalar_type(DTYPE)}
 
 ${define_required_extensions(DTYPE)}
@@ -19,13 +22,17 @@ layout(std430) buffer;
 ${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
 ${layout_declare_tensor(B, "r", "t_mat1", DTYPE, "buffer")}
 ${layout_declare_tensor(B, "r", "t_mat2", DTYPE, "buffer")}
+$if HAS_BIAS:
+  ${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer")}
 ${layout_declare_ubo(B, "ivec4", "out_sizes")}
 ${layout_declare_ubo(B, "ivec4", "out_strides")}
 ${layout_declare_ubo(B, "ivec4", "mat1_sizes")}
 ${layout_declare_ubo(B, "ivec4", "mat1_strides")}
 ${layout_declare_ubo(B, "ivec4", "mat2_sizes")}
 ${layout_declare_ubo(B, "ivec4", "mat2_strides")}
 ${layout_declare_ubo(B, "int", "out_numel")}
+$if HAS_BIAS:
+  ${layout_declare_ubo(B, "float", "alpha", "float", "beta")}
 
 #include "indexing_utils.h"
 
@@ -34,25 +41,25 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 ${layout_declare_spec_const(C, "int", "mat2_is_transposed", "0")}
 
 void main() {
-  const ivec4 out_bufix = ivec4(
+  const ivec4 out_tidx = ivec4(
       gl_GlobalInvocationID.x,
       gl_GlobalInvocationID.y,
       gl_GlobalInvocationID.z % out_sizes.z,
       gl_GlobalInvocationID.z / out_sizes.z);
 
-  if (any(greaterThanEqual(out_bufix, out_sizes))) {
+  if (any(greaterThanEqual(out_tidx, out_sizes))) {
     return;
   }
 
   int mat1_bufi = tidx_to_bufi(
-      ivec4(0, out_bufix.y, out_bufix.z, out_bufix.w), mat1_strides);
+      ivec4(0, out_tidx.y, out_tidx.z, out_tidx.w), mat1_strides);
   int mat2_bufi;
   if (mat2_is_transposed > 0) {
     mat2_bufi = tidx_to_bufi(
-        ivec4(0, out_bufix.x, 0, 0), mat2_strides);
+        ivec4(0, out_tidx.x, 0, 0), mat2_strides);
   } else {
     mat2_bufi = tidx_to_bufi(
-        ivec4(out_bufix.x, 0, out_bufix.z, out_bufix.w), mat2_strides);
+        ivec4(out_tidx.x, 0, out_tidx.z, out_tidx.w), mat2_strides);
   }
 
   int mat2_stride;
@@ -70,6 +77,10 @@ void main() {
     mat2_bufi += mat2_stride;
   }
 
-  const int out_bufi = tidx_to_bufi(out_bufix, out_strides);
+  const int out_bufi = tidx_to_bufi(out_tidx, out_strides);
+#ifdef HAS_BIAS
+  t_out[out_bufi] = T(alpha) * T(sum) + T(beta) * t_bias[out_tidx.x];
+#else
   t_out[out_bufi] = T(sum);
+#endif // HAS_BIAS
 }
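
In scalar terms, each shader invocation writes one output element; when HAS_BIAS is defined, the bias is indexed by the output's width coordinate (out_tidx.x), so one bias value is shared by every row of that output column. A rough pure-Python restatement of that per-element logic (illustrative only; it ignores strides, batching, and the transposed-mat2 path):

def addmm_naive_buffer_ref(mat1, mat2, bias, alpha=1.0, beta=1.0):
    # mat1 is M x K, mat2 is K x N, bias has length N (indexed by output column x).
    M, K, N = len(mat1), len(mat2), len(mat2[0])
    out = [[0.0] * N for _ in range(M)]
    for y in range(M):          # out_tidx.y
        for x in range(N):      # out_tidx.x
            acc = sum(mat1[y][k] * mat2[k][x] for k in range(K))
            out[y][x] = alpha * acc + beta * bias[x]
    return out

# Example: addmm_naive_buffer_ref([[1, 2]], [[3], [4]], [10]) == [[21.0]]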

backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/addmm_naive_buffer.yaml

+4-1
@@ -4,13 +4,16 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-matmul_naive_buffer:
+addmm_naive_buffer:
   parameter_names_with_default_values:
     DTYPE: float
     STORAGE: buffer
+    HAS_BIAS: false
   generate_variant_forall:
     DTYPE:
       - VALUE: float
       - VALUE: half
   shader_variants:
     - NAME: matmul_naive_buffer
+    - NAME: addmm_naive_buffer
+      HAS_BIAS: true
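
Because HAS_BIAS is now just another template parameter, one GLSL source yields both the unbiased matmul_naive_buffer variant and the biased addmm_naive_buffer variant, and each is further expanded over the listed DTYPE values. A small sketch of the resulting variant set, assuming the codegen appends the dtype as a name suffix (the exact naming scheme is an assumption here, not shown in this diff):

variants = {"matmul_naive_buffer": False, "addmm_naive_buffer": True}
dtypes = ["float", "half"]

# Hypothetical expansion: one compiled shader per (variant, dtype) pair.
for name, has_bias in variants.items():
    for dtype in dtypes:
        print(f"{name}_{dtype}  (HAS_BIAS={has_bias})")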

backends/vulkan/runtime/graph/ops/impl/Linear.cpp

+69-5
@@ -84,7 +84,7 @@ struct Params final {
   float beta;
 };
 
-void add_addmm_naive_node(
+void add_addmm_naive_texture_node(
     ComputeGraph& graph,
     const ValueRef self_data,
     const ValueRef mat1,
@@ -134,6 +134,69 @@ void add_addmm_naive_node(
       {mat2_is_transposed}));
 }
 
+void add_addmm_naive_buffer_node(
+    ComputeGraph& graph,
+    const ValueRef self_data,
+    const ValueRef mat1,
+    const ValueRef mat2_data,
+    const ValueRef beta,
+    const ValueRef alpha,
+    const ValueRef out,
+    const Params& params,
+    const ValueRef mat2_is_transposed) {
+  (void)beta;
+  (void)alpha;
+  ValueRef mat2 = prepack_standard(
+      graph,
+      mat2_data,
+      graph.storage_type_of(out),
+      utils::kHeightPacked,
+      /*passthrough = */ true);
+  ValueRef self = prepack_standard(
+      graph,
+      self_data,
+      graph.storage_type_of(out),
+      utils::kWidthPacked,
+      /*passthrough = */ true);
+
+  std::string kernel_name = "addmm_naive_buffer";
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
+
+  utils::uvec3 global_size = {
+      graph.size_at<uint32_t>(-1, out),
+      graph.size_at<uint32_t>(-2, out),
+      graph.size_at<uint32_t>(-3, out) * graph.size_at<uint32_t>(-4, out)};
+
+  int mat2_is_transposed_val = (mat2_is_transposed != kDummyValueRef &&
+                                graph.get_bool(mat2_is_transposed))
+      ? 1
+      : 0;
+
+  graph.execute_nodes().emplace_back(new DispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      global_size,
+      graph.create_local_wg_size(global_size),
+      // Inputs and Outputs
+      {{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}},
+      // Shader params buffers
+      {
+          graph.sizes_ubo(out),
+          graph.strides_ubo(out),
+          graph.sizes_ubo(mat1),
+          graph.strides_ubo(mat1),
+          graph.sizes_ubo(mat2),
+          graph.strides_ubo(mat2),
+          graph.numel_ubo(out),
+          graph.create_params_buffer(params),
+      },
+      // Specialization Constants
+      {mat2_is_transposed_val},
+      // Resizing Logic
+      resize_addmm_node,
+      {mat2_is_transposed}));
+}
+
 void add_addmm_optimized_node(
     ComputeGraph& graph,
     const ValueRef self_data,
@@ -246,11 +309,14 @@ void add_addmm_node(
   }
 
   Params params = {alpha_val, beta_val};
-  if (graph.packed_dim_of(mat1) == WHCN::kChannelsDim) {
+  if (graph.is_buffer_storage(out)) {
+    add_addmm_naive_buffer_node(
+        graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed);
+  } else if (graph.packed_dim_of(mat1) == WHCN::kChannelsDim) {
     add_addmm_optimized_node(
         graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed);
   } else if (graph.packed_dim_of(mat1) == WHCN::kWidthDim) {
-    add_addmm_naive_node(
+    add_addmm_naive_texture_node(
        graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed);
   } else {
     VK_THROW("Input should be channel packed or width packed.");
@@ -283,8 +349,6 @@ void linear(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   if (graph.val_is_none(bias)) {
     return add_matmul_node(graph, input, weight, out, mat2_is_transposed);
   } else {
-    // Buffer implementation does not yet support biases
-    VK_CHECK_COND(!graph.is_buffer_storage(out));
     return add_addmm_node(
         graph,
         bias,
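
The net effect in add_addmm_node is that storage type is now checked before packing layout: buffer-backed outputs take the new naive buffer shader (with bias support), while texture-backed outputs keep the existing optimized and naive texture paths. A condensed Python restatement of that branch order (a summary of the C++ above, not a real API):

def pick_addmm_impl(is_buffer_storage: bool, packed_dim: str) -> str:
    # Mirrors the updated dispatch in add_addmm_node (Linear.cpp).
    if is_buffer_storage:
        return "addmm_naive_buffer"    # new buffer path, bias now supported
    if packed_dim == "channels":
        return "addmm_optimized"       # texture, channels-packed input
    if packed_dim == "width":
        return "addmm_naive_texture"   # texture, width-packed input
    raise ValueError("Input should be channel packed or width packed.")

# e.g. pick_addmm_impl(True, "width") -> "addmm_naive_buffer"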

backends/vulkan/test/op_tests/cases.py

+3-24
@@ -126,7 +126,8 @@ def get_addmm_inputs():
     ]
 
 
-def get_linear_texture_inputs():
+@register_test_suite("aten.linear.default")
+def get_linear_inputs():
     MKN_list = common_MKN_list
 
     inputs_list = [((M, K), (N, K), None) for M, K, N in MKN_list]
@@ -141,32 +142,10 @@ def get_linear_texture_inputs():
         "utils::kWidthPacked",
         "utils::kChannelsPacked",
     ]
-    test_suite.test_name_suffix = "texture"
-    return test_suite
-
-
-def get_linear_buffer_inputs():
-    MKN_list = common_MKN_list
-
-    inputs_list = [((M, K), (N, K), None) for M, K, N in MKN_list]
-    inputs_list += [((3, M, K), (N, K), None) for M, K, N in MKN_list]
-
-    test_suite = VkTestSuite(inputs_list)
-    test_suite.dtypes = ["at::kFloat"]
-    test_suite.layouts = [
-        "utils::kWidthPacked",
-        "utils::kChannelsPacked",
-    ]
-    test_suite.storage_types = ["utils::kBuffer"]
-    test_suite.test_name_suffix = "buffer"
+    test_suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"]
     return test_suite
 
 
-@register_test_suite("aten.linear.default")
-def get_linear_test_suites():
-    return [get_linear_texture_inputs(), get_linear_buffer_inputs()]
-
-
 @register_test_suite("aten._weight_int8pack_mm.default")
 def get_weight_int8pack_mm_inputs():
     MKN_list = [
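
Folding the two suites into get_linear_inputs means each (M, K, N) entry is exercised with both a 2D and a batched 3D input across both layouts and both storage types. A rough count of the combinations, assuming the test generator takes the cross product and using a placeholder list (common_MKN_list itself is defined elsewhere and not shown in this diff):

from itertools import product

mkn_list = [(3, 4, 5), (16, 32, 64)]  # placeholder, not the real common_MKN_list
shapes_per_mkn = 2                    # (M, K) input and (3, M, K) batched input
layouts = ["utils::kWidthPacked", "utils::kChannelsPacked"]
storage_types = ["utils::kBuffer", "utils::kTexture3D"]

num_cases = len(mkn_list) * shapes_per_mkn * len(list(product(layouts, storage_types)))
print(num_cases)  # 2 * 2 * 4 = 16 with this placeholder list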

backends/vulkan/test/test_vulkan_delegate.py

+23
@@ -1711,3 +1711,26 @@ def forward(self, x):
             (torch.ones(size=[5, 4, 1, 2, 6]),),
             expect_no_delegates=True,
         )
+
+    def test_vulkan_backend_large_linear_layer(self):
+        class LinearModel(torch.nn.Module):
+            def __init__(
+                self, n_pca_basis: int, n_sh_basis: int, n_gaussians: int
+            ) -> None:
+                super(LinearModel, self).__init__()
+                self.fc1 = torch.nn.Linear(
+                    n_pca_basis, (n_sh_basis + 3 + 3 + 4) * n_gaussians
+                )
+
+            def forward(self, x: torch.Tensor):
+                out = self.fc1(x)
+                return out
+
+        n_pca_basis = 64
+        n_sh_basis = 6
+        n_gaussians = 2**16
+
+        self.lower_module_and_test_output(
+            LinearModel(n_pca_basis, n_sh_basis, n_gaussians),
+            (torch.ones(n_pca_basis),),
+        )
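
The new test appears to be sized so that the weight and output are far too wide for texture storage, which presumably forces the delegate onto the buffer-based linear path that this commit teaches to handle a bias. The arithmetic behind the layer size, shown for clarity:

n_pca_basis = 64
n_sh_basis = 6
n_gaussians = 2 ** 16

out_features = (n_sh_basis + 3 + 3 + 4) * n_gaussians
print(out_features)                # 16 * 65536 = 1_048_576 output features
print(n_pca_basis * out_features)  # 67_108_864 weight elements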
