@@ -249,6 +249,7 @@ struct webgpu_context_struct {
249249 webgpu_pipeline memset_pipeline;
250250 webgpu_pipeline mul_mat_pipeline[30 ][2 ];
251251 webgpu_pipeline set_rows_pipeline;
252+ webgpu_pipeline set_rows_f32_no_vec_pipeline;
252253 webgpu_pipeline get_rows_pipeline[30 ];
253254 webgpu_pipeline get_rows_f32_no_vec_pipeline;
254255 webgpu_pipeline cpy_pipeline[2 ][2 ]; // src type, dst type
@@ -767,7 +768,12 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
767768 size_t max_wg_size = ctx->max_wg_size_x ;
768769 uint32_t wg_x = (src->ne [1 ] * src->ne [2 ] * src->ne [3 ] + max_wg_size - 1 ) / max_wg_size;
769770
770- return ggml_backend_webgpu_build (ctx, ctx->set_rows_pipeline , params, entries, wg_x, error_bufs);
771+ webgpu_pipeline pipeline = ctx->set_rows_pipeline ;
772+ // if dst->ne[0] is not evenly divisible by 4, use the non-vectorized f32 version
773+ if (src->type == GGML_TYPE_F32 && dst->ne [0 ] % 4 != 0 ) {
774+ pipeline = ctx->set_rows_f32_no_vec_pipeline ;
775+ }
776+ return ggml_backend_webgpu_build (ctx, pipeline, params, entries, wg_x, error_bufs);
771777}
772778
773779static webgpu_command ggml_webgpu_get_rows (webgpu_context & ctx,
@@ -1613,7 +1619,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
16131619}
16141620
// Creates the set_rows pipelines on webgpu_ctx->device.
// With the added (+) lines, two pipelines are built, each receiving the
// constants returned by ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x):
//   - set_rows_f32_no_vec_pipeline, from the wgsl_set_rows_f32 shader
//   - set_rows_pipeline, from the wgsl_set_rows_f32_vec shader
// The removed (-) line previously built set_rows_pipeline from wgsl_set_rows.
16151621static void ggml_webgpu_init_set_rows_pipeline (webgpu_context & webgpu_ctx) {
1616- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->set_rows_pipeline , wgsl_set_rows, " set_rows" ,
1622+ // create_pipeline(device, pipeline, shader_code, label, constants)
1623+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->set_rows_f32_no_vec_pipeline , wgsl_set_rows_f32, " set_rows_f32" ,
1624+ ggml_webgpu_wg_size_entry (webgpu_ctx->max_wg_size_x ));
1625+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->set_rows_pipeline , wgsl_set_rows_f32_vec, " set_rows_f32_vec" ,
16171626 ggml_webgpu_wg_size_entry (webgpu_ctx->max_wg_size_x ));
16181627}
16191628
0 commit comments