Skip to content

vulkan : add GELU_ERF #14455

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ struct vk_device_struct {

// [src/dst 0=fp32,1=fp16]
vk_pipeline pipeline_gelu[2];
vk_pipeline pipeline_gelu_erf[2];
vk_pipeline pipeline_gelu_quick[2];
vk_pipeline pipeline_silu[2];
vk_pipeline pipeline_relu[2];
Expand Down Expand Up @@ -2761,6 +2762,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

CREATE_UNARY(gelu)
CREATE_UNARY(gelu_erf)
CREATE_UNARY(gelu_quick)
CREATE_UNARY(silu)
CREATE_UNARY(relu)
Expand Down Expand Up @@ -6481,6 +6483,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_GELU:
return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_GELU_ERF:
return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_GELU_QUICK:
return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_RELU:
Expand Down Expand Up @@ -8827,6 +8831,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
switch (ggml_get_unary_op(node)) {
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_TANH:
Expand Down Expand Up @@ -9072,6 +9077,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
switch (ggml_get_unary_op(node)) {
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_TANH:
Expand Down Expand Up @@ -9289,6 +9295,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
switch (ggml_get_unary_op(tensor)) {
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_TANH:
Expand Down Expand Up @@ -10095,6 +10102,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
Expand Down Expand Up @@ -10835,6 +10843,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
case GGML_UNARY_OP_GELU:
tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_GELU_ERF:
tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_GELU_QUICK:
tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
break;
Expand Down
39 changes: 39 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

void main() {
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
const float p_erf = 0.3275911f;
const float a1_erf = 0.254829592f;
const float a2_erf = -0.284496736f;
const float a3_erf = 1.421413741f;
const float a4_erf = -1.453152027f;
const float a5_erf = 1.061405429f;

const float SQRT_2_INV = 0.70710678118654752440084436210484f;
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

if (i >= p.KX) {
return;
}

const float a = float(data_a[i]);
const float a_div_sqr2 = a * SQRT_2_INV;
const float sign_x = sign(a_div_sqr2);
const float x = abs(a_div_sqr2);
const float t = 1.0f / (1.0f + p_erf * x);
const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
const float erf_approx = sign_x * y;

data_d[i] = D_TYPE(0.5f * a * (1.0f + erf_approx));
}
2 changes: 2 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,8 @@ void process_shaders() {

string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("gelu_erf_f32", "gelu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
Expand Down
Loading