Skip to content

Commit a97069a

Browse files
Neha Abbas
authored and committed
optimized set_rows-- each thread can copy over vec4 at once if possible
1 parent 74b8fc1 commit a97069a

File tree

2 files changed

+63
-6
lines changed

2 files changed

+63
-6
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -249,6 +249,7 @@ struct webgpu_context_struct {
249249
webgpu_pipeline memset_pipeline;
250250
webgpu_pipeline mul_mat_pipeline[30][2];
251251
webgpu_pipeline set_rows_pipeline;
252+
webgpu_pipeline set_rows_f32_no_vec_pipeline;
252253
webgpu_pipeline get_rows_pipeline[30];
253254
webgpu_pipeline get_rows_f32_no_vec_pipeline;
254255
webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type
@@ -767,7 +768,12 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
767768
size_t max_wg_size = ctx->max_wg_size_x;
768769
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
769770

770-
return ggml_backend_webgpu_build(ctx, ctx->set_rows_pipeline, params, entries, wg_x, error_bufs);
771+
webgpu_pipeline pipeline = ctx->set_rows_pipeline;
772+
// if the row length (dst->ne[0]) is not evenly divisible by 4, use the non-vectorized version
773+
if (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 != 0) {
774+
pipeline = ctx->set_rows_f32_no_vec_pipeline;
775+
}
776+
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, error_bufs);
771777
}
772778

773779
static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
@@ -1613,7 +1619,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
16131619
}
16141620

16151621
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
1616-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
1622+
// create_pipeline(device, pipeline, shader_code, label, constants)
1623+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_f32_no_vec_pipeline, wgsl_set_rows_f32, "set_rows_f32",
1624+
ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
1625+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows_f32_vec, "set_rows_f32_vec",
16171626
ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
16181627
}
16191628

ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl renamed to ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl

Lines changed: 52 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,58 @@
1+
2+
#define(VARIANTS)
3+
4+
[
5+
{
6+
"SHADER_SUFFIX": "f32_vec",
7+
"REPLS": {
8+
"TYPE" : "vec4<f32>",
9+
"DST_TYPE": "vec4<f16>",
10+
"BLOCK_SIZE": 4
11+
},
12+
"DECLS": ["F32_VEC"]
13+
},
14+
{
15+
"REPLS": {
16+
"TYPE" : "f32",
17+
"DST_TYPE": "f16",
18+
"BLOCK_SIZE": 1
19+
},
20+
"DECLS": ["F32"]
21+
}
22+
]
23+
24+
#end(VARIANTS)
25+
26+
#define(DECLS)
27+
28+
#decl(F32_VEC)
29+
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
30+
dst[(dst_base / 4) + offset] = vec4<f16>(src[(src_base / 4) + offset]);
31+
}
32+
#enddecl(F32_VEC)
33+
34+
#decl(F32)
35+
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
36+
dst[dst_base + offset] = f16(src[src_base + offset]);
37+
}
38+
#enddecl(F32)
39+
40+
#end(DECLS)
41+
42+
#define(SHADER)
43+
144
enable f16;
245

46+
DECLS
47+
348
@group(0) @binding(0)
4-
var<storage, read_write> src: array<f32>;
49+
var<storage, read_write> src: array<{{TYPE}}>;
550

651
@group(0) @binding(1)
752
var<storage, read_write> idx: array<u32>;
853

954
@group(0) @binding(2)
10-
var<storage, read_write> dst: array<f16>;
55+
var<storage, read_write> dst: array<{{DST_TYPE}}>;
1156

1257
@group(0) @binding(3)
1358
var<storage, read_write> error: atomic<u32>;
@@ -75,7 +120,10 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
75120
let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
76121
let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
77122

78-
for (var i: u32 = 0; i < params.ne0; i++) {
79-
dst[i_dst_row + i] = f16(src[i_src_row + i]);
123+
for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) {
124+
copy_elements(i_src_row, i_dst_row, i);
80125
}
81126
}
127+
128+
#end(SHADER)
129+

0 commit comments

Comments
 (0)