@@ -84,7 +84,7 @@ struct Params final {
84
84
float beta;
85
85
};
86
86
87
- void add_addmm_naive_node (
87
+ void add_addmm_naive_texture_node (
88
88
ComputeGraph& graph,
89
89
const ValueRef self_data,
90
90
const ValueRef mat1,
@@ -134,6 +134,69 @@ void add_addmm_naive_node(
134
134
{mat2_is_transposed}));
135
135
}
136
136
137
+ void add_addmm_naive_buffer_node (
138
+ ComputeGraph& graph,
139
+ const ValueRef self_data,
140
+ const ValueRef mat1,
141
+ const ValueRef mat2_data,
142
+ const ValueRef beta,
143
+ const ValueRef alpha,
144
+ const ValueRef out,
145
+ const Params& params,
146
+ const ValueRef mat2_is_transposed) {
147
+ (void )beta;
148
+ (void )alpha;
149
+ ValueRef mat2 = prepack_standard (
150
+ graph,
151
+ mat2_data,
152
+ graph.storage_type_of (out),
153
+ utils::kHeightPacked ,
154
+ /* passthrough = */ true );
155
+ ValueRef self = prepack_standard (
156
+ graph,
157
+ self_data,
158
+ graph.storage_type_of (out),
159
+ utils::kWidthPacked ,
160
+ /* passthrough = */ true );
161
+
162
+ std::string kernel_name = " addmm_naive_buffer" ;
163
+ add_dtype_suffix (kernel_name, graph.dtype_of (out));
164
+
165
+ utils::uvec3 global_size = {
166
+ graph.size_at <uint32_t >(-1 , out),
167
+ graph.size_at <uint32_t >(-2 , out),
168
+ graph.size_at <uint32_t >(-3 , out) * graph.size_at <uint32_t >(-4 , out)};
169
+
170
+ int mat2_is_transposed_val = (mat2_is_transposed != kDummyValueRef &&
171
+ graph.get_bool (mat2_is_transposed))
172
+ ? 1
173
+ : 0 ;
174
+
175
+ graph.execute_nodes ().emplace_back (new DispatchNode (
176
+ graph,
177
+ VK_KERNEL_FROM_STR (kernel_name),
178
+ global_size,
179
+ graph.create_local_wg_size (global_size),
180
+ // Inputs and Outputs
181
+ {{out, vkapi::kWrite }, {{mat1, mat2, self}, vkapi::kRead }},
182
+ // Shader params buffers
183
+ {
184
+ graph.sizes_ubo (out),
185
+ graph.strides_ubo (out),
186
+ graph.sizes_ubo (mat1),
187
+ graph.strides_ubo (mat1),
188
+ graph.sizes_ubo (mat2),
189
+ graph.strides_ubo (mat2),
190
+ graph.numel_ubo (out),
191
+ graph.create_params_buffer (params),
192
+ },
193
+ // Specialization Constants
194
+ {mat2_is_transposed_val},
195
+ // Resizing Logic
196
+ resize_addmm_node,
197
+ {mat2_is_transposed}));
198
+ }
199
+
137
200
void add_addmm_optimized_node (
138
201
ComputeGraph& graph,
139
202
const ValueRef self_data,
@@ -246,11 +309,14 @@ void add_addmm_node(
246
309
}
247
310
248
311
Params params = {alpha_val, beta_val};
249
- if (graph.packed_dim_of (mat1) == WHCN::kChannelsDim ) {
312
+ if (graph.is_buffer_storage (out)) {
313
+ add_addmm_naive_buffer_node (
314
+ graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed);
315
+ } else if (graph.packed_dim_of (mat1) == WHCN::kChannelsDim ) {
250
316
add_addmm_optimized_node (
251
317
graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed);
252
318
} else if (graph.packed_dim_of (mat1) == WHCN::kWidthDim ) {
253
- add_addmm_naive_node (
319
+ add_addmm_naive_texture_node (
254
320
graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed);
255
321
} else {
256
322
VK_THROW (" Input should be channel packed or width packed." );
@@ -283,8 +349,6 @@ void linear(ComputeGraph& graph, const std::vector<ValueRef>& args) {
283
349
if (graph.val_is_none (bias)) {
284
350
return add_matmul_node (graph, input, weight, out, mat2_is_transposed);
285
351
} else {
286
- // Buffer implementation does not yet support biases
287
- VK_CHECK_COND (!graph.is_buffer_storage (out));
288
352
return add_addmm_node (
289
353
graph,
290
354
bias,
0 commit comments