@@ -245,6 +245,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_gelu_f32;
     vk_pipeline pipeline_gelu_quick_f32;
     vk_pipeline pipeline_silu_f32;
+    vk_pipeline pipeline_silu_back_f32;
     vk_pipeline pipeline_relu_f32;
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_tanh_f32;
@@ -2183,6 +2184,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
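Note the descriptor count: the surrounding unary activations bind 2 buffers (input and output), while the new silu_back pipeline binds 3, because the backward pass reads both the upstream gradient and the original forward input. The SPIR-V it loads (silu_back_f32_len / silu_back_f32_data) is generated from a compute shader that is not part of this excerpt. Below is a minimal sketch of what such a shader could look like, assuming src0 is the upstream gradient and src1 the forward input; the buffer names and the flat 1-D indexing are illustrative, not the actual generated source.

    #version 450

    // Hypothetical silu_back_f32 shader (illustrative only; not the generated source).
    // Push constants mirror vk_op_push_constants on the host side; only KX is used.
    layout (push_constant) uniform parameter {
        uint  KX;
        uint  KY;
        float param1;
        float param2;
    } p;

    layout (local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

    // Binding order follows the host call (src0, src1, dst).
    layout (binding = 0) readonly  buffer G { float data_g[]; };  // upstream gradient
    layout (binding = 1) readonly  buffer X { float data_x[]; };  // forward input x
    layout (binding = 2) writeonly buffer D { float data_d[]; };  // output dx

    void main() {
        // The real shaders flatten all three dispatch dimensions; the 1-D case
        // is kept here for clarity.
        const uint i = gl_GlobalInvocationID.x;
        if (i >= p.KX) {
            return;
        }
        // d/dx silu(x) = s * (1 + x * (1 - s)), where s = sigmoid(x)
        const float x = data_x[i];
        const float s = 1.0f / (1.0f + exp(-x));
        data_d[i] = data_g[i] * s * (1.0f + x * (1.0f - s));
    }

The math is the SiLU derivative: with silu(x) = x * sigmoid(x), the chain rule gives sigmoid(x) * (1 + x * (1 - sigmoid(x))), scaled by the incoming gradient.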
@@ -5286,6 +5288,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     case GGML_OP_CONT:
     case GGML_OP_DUP:
         return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
+    case GGML_OP_SILU_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_silu_back_f32;
+        }
+        return nullptr;
     case GGML_OP_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_norm_f32;
@@ -6324,6 +6331,10 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }

+static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;

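The host wrapper stays on the generic element-wise path: SILU_BACK operates pointwise on contiguous F32 tensors, so ggml_vk_op_f32 only needs the element count in the first push-constant field, and the remaining fields are zeroed. Passing src1 alongside src0 is what routes the second input through to the shader's extra binding.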
@@ -7335,6 +7346,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
+    case GGML_OP_SILU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
@@ -7395,6 +7407,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
+    case GGML_OP_SILU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
@@ -7495,6 +7508,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_DUP:
         ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);

+        break;
+    case GGML_OP_SILU_BACK:
+        ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_NORM:
         ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
@@ -7664,6 +7681,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
+    case GGML_OP_SILU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
@@ -8607,6 +8625,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
+    case GGML_OP_SILU_BACK:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
@@ -9011,6 +9030,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
         const float eps = ((float *) tensor->op_params)[0];
         tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
+    } else if (tensor->op == GGML_OP_SILU_BACK) {
+        tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
         if (src1 != nullptr) {
             tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
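When the backend is built with result checking enabled, ggml_vk_check_results_0 replays each node on the CPU with cloned inputs; the new ggml_silu_back branch above lets the Vulkan output of the op be compared against the reference CPU implementation.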