Commit a5e72f0

vulkan: implement GGML_OP_SILU_BACK
1 parent d510fb6 · commit a5e72f0

File tree

3 files changed: +48 −0 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 21 additions & 0 deletions
@@ -245,6 +245,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_gelu_f32;
     vk_pipeline pipeline_gelu_quick_f32;
     vk_pipeline pipeline_silu_f32;
+    vk_pipeline pipeline_silu_back_f32;
     vk_pipeline pipeline_relu_f32;
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_tanh_f32;
@@ -2183,6 +2184,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
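
For context, the new pipeline is registered with 3 descriptor bindings (incoming gradient, original input, output) instead of the 2 used by the forward unary ops, but it reuses the same vk_op_push_constants payload, of which only the element count matters here (ggml_vk_silu_back below passes { nelements, 0, 0.0f, 0.0f }). A minimal sketch of the assumed push-constant layout, with field names inferred from the p.KX bound check in the new shader:

// Assumed layout of vk_op_push_constants shared by the element-wise Vulkan ops
// in ggml-vulkan.cpp (field names inferred from generic_head.comp / p.KX).
struct vk_op_push_constants {
    uint32_t KX;     // total element count, checked against the invocation index
    uint32_t KY;     // unused by silu_back
    float    param1; // unused by silu_back
    float    param2; // unused by silu_back
};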
@@ -5286,6 +5288,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     case GGML_OP_CONT:
     case GGML_OP_DUP:
         return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
+    case GGML_OP_SILU_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_silu_back_f32;
+        }
+        return nullptr;
     case GGML_OP_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_norm_f32;
@@ -6324,6 +6331,10 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
+static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
 
@@ -7335,6 +7346,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
+    case GGML_OP_SILU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
@@ -7395,6 +7407,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
+    case GGML_OP_SILU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
@@ -7495,6 +7508,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_DUP:
         ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_SILU_BACK:
+        ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_NORM:
         ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
@@ -7664,6 +7681,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
+    case GGML_OP_SILU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
@@ -8607,6 +8625,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
+    case GGML_OP_SILU_BACK:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
@@ -9011,6 +9030,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
         const float eps = ((float *) tensor->op_params)[0];
         tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
+    } else if (tensor->op == GGML_OP_SILU_BACK) {
+        tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
         if (src1 != nullptr) {
             tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer G {A_TYPE data_g[];};
+layout (binding = 1) readonly buffer X {B_TYPE data_x[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    // Compute derivative of SiLU(x): 1/(1+exp(-x)) + x*exp(-x)/(1+exp(-x))^2
+
+    const float xi = float(data_x[i]);
+    const float s = 1.0f / (1.0f + exp(-xi));
+    data_d[i] = D_TYPE(data_g[i] * (s + xi * s * (1 - s)));
+}
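
For reference, the last assignment is the chain-rule backward pass of SiLU; a short derivation, writing \sigma for the logistic sigmoid, x_i for data_x[i], and g_i for the incoming gradient data_g[i]:

\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad \operatorname{silu}(x) = x\,\sigma(x)

\frac{d}{dx}\,\operatorname{silu}(x) = \sigma(x) + x\,\sigma(x)\bigl(1 - \sigma(x)\bigr)

\texttt{data\_d}[i] = g_i\,\bigl(s + x_i\,s\,(1 - s)\bigr), \qquad s = \sigma(x_i)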

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
@@ -478,6 +478,7 @@ void process_shaders() {
     string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
