ggml : add ggml_gelu_erf() #13667


Merged on May 21, 2025 (5 commits)
ggml/include/ggml.h: 12 additions & 1 deletion

@@ -528,14 +528,15 @@ extern "C" {
         GGML_UNARY_OP_STEP,
         GGML_UNARY_OP_TANH,
         GGML_UNARY_OP_ELU,
+        GGML_UNARY_OP_RELU,
         GGML_UNARY_OP_SIGMOID,
         GGML_UNARY_OP_GELU,
+        GGML_UNARY_OP_GELU_ERF,
         GGML_UNARY_OP_GELU_QUICK,
         GGML_UNARY_OP_SILU,
         GGML_UNARY_OP_HARDSWISH,
         GGML_UNARY_OP_HARDSIGMOID,
         GGML_UNARY_OP_EXP,
-        GGML_UNARY_OP_RELU,

         GGML_UNARY_OP_COUNT,
     };
@@ -1024,6 +1025,16 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // GELU using erf (error function) when possible
+    // some backends may fall back to an approximation based on the Abramowitz and Stegun formula
+    GGML_API struct ggml_tensor * ggml_gelu_erf(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_gelu_erf_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_gelu_quick(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
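For context, a minimal sketch of how the new operator is driven through the public API; the context size, tensor shape, and thread count are illustrative values for this writeup (not from the patch), and ggml-cpu.h is assumed for ggml_graph_compute_with_ctx:

    #include "ggml.h"
    #include "ggml-cpu.h"

    #include <stdio.h>

    int main(void) {
        // small scratch context; 16 MiB is an arbitrary size for this sketch
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        for (int i = 0; i < 8; ++i) {
            ((float *) a->data)[i] = 0.5f*i - 2.0f; // inputs in [-2.0, 1.5]
        }

        // exact-GELU node; ggml_gelu_erf_inplace(ctx, a) would overwrite a instead
        struct ggml_tensor * out = ggml_gelu_erf(ctx, a);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

        for (int i = 0; i < 8; ++i) {
            printf("gelu_erf(%+.2f) = %+.6f\n", ((float *) a->data)[i], ((float *) out->data)[i]);
        }

        ggml_free(ctx);
        return 0;
    }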
ggml/src/ggml-cpu/ggml-cpu.c: 1 addition

@@ -2202,6 +2202,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 } break;
 
             case GGML_UNARY_OP_GELU:
+            case GGML_UNARY_OP_GELU_ERF:
             case GGML_UNARY_OP_GELU_QUICK:
             case GGML_UNARY_OP_SILU:
                 {
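For scheduling, this puts GGML_UNARY_OP_GELU_ERF in the same bucket as GELU, GELU_QUICK, and SILU: ggml_get_n_tasks() reports n_threads tasks for these ops, so the row loop in the CPU kernel below is split across all worker threads, whereas most other unary ops run as a single task.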
ggml/src/ggml-cpu/ops.cpp: 107 additions

@@ -2691,6 +2691,109 @@ static void ggml_compute_forward_gelu(
     }
 }
 
+// ggml_compute_forward_gelu_erf
+
+static void ggml_compute_forward_gelu_erf_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_erf_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_gelu_erf_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_erf_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_gelu_erf(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gelu_erf_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_gelu_erf_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_gelu_quick
 
 static void ggml_compute_forward_gelu_quick_f32(
@@ -7749,6 +7852,10 @@ void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_gelu(params, dst);
             } break;
+        case GGML_UNARY_OP_GELU_ERF:
+            {
+                ggml_compute_forward_gelu_erf(params, dst);
+            } break;
         case GGML_UNARY_OP_GELU_QUICK:
             {
                 ggml_compute_forward_gelu_quick(params, dst);
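Both kernels above use ggml's standard row-partitioning pattern: ceil-divide the rows across the threads, then clamp the last range to the row count. A standalone sketch of just that arithmetic, with illustrative values that are not part of the patch:

    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    int main(void) {
        const int nr  = 10; // total rows, as returned by ggml_nrows(src0)
        const int nth = 4;  // number of worker threads

        const int dr = (nr + nth - 1)/nth; // rows per thread, rounded up: (10 + 3)/4 = 3

        for (int ith = 0; ith < nth; ith++) {
            const int ir0 = dr*ith;            // first row of this thread
            const int ir1 = MIN(ir0 + dr, nr); // one past the last row, clamped
            printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
        }
        // prints [0, 3), [3, 6), [6, 9), [9, 10): the last thread absorbs the remainder
        return 0;
    }

Threads whose ir0 lands at or past nr simply get an empty range, so no extra bounds handling is needed.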
ggml/src/ggml-cpu/vec.h: 16 additions

@@ -428,6 +428,7 @@ inline static void ggml_vec_exp_f16 (const int n, ggml_fp16_t * y, const ggml_fp
 static const float GELU_COEF_A     = 0.044715f;
 static const float GELU_QUICK_COEF = -1.702f;
 static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+static const float SQRT_2_INV      = 0.70710678118654752440084436210484f;
 
 inline static float ggml_gelu_f32(float x) {
     return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
@@ -440,6 +441,14 @@ inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp
     }
 }
 
+inline static void ggml_vec_gelu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        float xi = GGML_FP16_TO_FP32(x[i]);
+        float res = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
+        y[i] = GGML_FP32_TO_FP16(res);
+    }
+}
+
 #ifdef GGML_GELU_FP16
 inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
     uint16_t t;
@@ -463,6 +472,13 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 }
 #endif
 
+inline static void ggml_vec_gelu_erf_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        float xi = x[i];
+        y[i] = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
+    }
+}
+
 inline static float ggml_gelu_quick_f32(float x) {
     return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
 }
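For reference, ggml_vec_gelu_erf_f32 computes the exact GELU via the Gaussian CDF, while ggml_gelu_f32 above keeps the usual tanh approximation; in the constants, SQRT_2_INV is 1/sqrt(2) and SQRT_2_OVER_PI is sqrt(2/pi):

    \mathrm{GELU}(x) = x\,\Phi(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)

    \mathrm{GELU}_{\tanh}(x) \approx \frac{x}{2}\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\left(x + 0.044715\,x^{3}\right)\right)\right)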
ggml/src/ggml-metal/ggml-metal.m: 24 additions

@@ -149,6 +149,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_SIGMOID,
     GGML_METAL_KERNEL_TYPE_GELU,
     GGML_METAL_KERNEL_TYPE_GELU_4,
+    GGML_METAL_KERNEL_TYPE_GELU_ERF,
+    GGML_METAL_KERNEL_TYPE_GELU_ERF_4,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     GGML_METAL_KERNEL_TYPE_SILU,
@@ -1103,6 +1105,8 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID,      sigmoid,      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,         gelu,         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4,       gelu_4,       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF,     gelu_erf,     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF_4,   gelu_erf_4,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK,   gelu_quick,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,         silu,         true);
@@ -1613,6 +1617,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             case GGML_UNARY_OP_RELU:
             case GGML_UNARY_OP_SIGMOID:
             case GGML_UNARY_OP_GELU:
+            case GGML_UNARY_OP_GELU_ERF:
             case GGML_UNARY_OP_GELU_QUICK:
             case GGML_UNARY_OP_SILU:
             case GGML_UNARY_OP_ELU:
@@ -2251,6 +2256,25 @@ static bool ggml_metal_encode_node(

                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
+        case GGML_UNARY_OP_GELU_ERF:
+            {
+                int64_t n = ggml_nelements(dst);
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                if (n % 4 == 0) {
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline;
+                    n /= 4;
+                } else {
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline;
+                }
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
         case GGML_UNARY_OP_GELU_QUICK:
             {
                 int64_t n = ggml_nelements(dst);
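Kernel selection follows the existing GELU cases: when ggml_nelements(dst) is divisible by 4, the float4 variant is chosen and n is divided by 4, so for example n = 4096 dispatches 1024 single-thread threadgroups of kernel_gelu_erf_4, each consuming one float4; any other n falls back to the scalar kernel_gelu_erf with one element per threadgroup.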
ggml/src/ggml-metal/ggml-metal.metal: 37 additions

@@ -856,6 +856,7 @@ kernel void kernel_tanh(
 constant float GELU_COEF_A     = 0.044715f;
 constant float GELU_QUICK_COEF = -1.702f;
 constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+constant float SQRT_2_INV      = 0.70710678118654752440084436210484f;
 
 kernel void kernel_gelu(
     device const float * src0,
@@ -897,6 +898,42 @@ kernel void kernel_gelu_quick_4(
     dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
 }
 
+// based on Abramowitz and Stegun formula 7.1.26 (a Hastings-type approximation)
+// ref: https://www.johndcook.com/blog/python_erf/
+constant float p_erf  = 0.3275911f;
+constant float a1_erf = 0.254829592f;
+constant float a2_erf = -0.284496736f;
+constant float a3_erf = 1.421413741f;
+constant float a4_erf = -1.453152027f;
+constant float a5_erf = 1.061405429f;
+
+template<typename T>
+T erf_approx(T x) {
+    T sign_x = sign(x);
+    x = fabs(x);
+    T t = 1.0f / (1.0f + p_erf * x);
+    T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
+    return sign_x * y;
+}
+
+kernel void kernel_gelu_erf(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
+}
+
+kernel void kernel_gelu_erf_4(
+    device const float4 * src0,
+    device       float4 * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
+}
+
 kernel void kernel_silu(
     device const float * src0,
     device       float * dst,
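The polynomial is Abramowitz and Stegun formula 7.1.26, whose stated maximum absolute error is about 1.5e-7, well below what an f32 GELU can resolve over typical activation ranges. A quick host-side check of the same constants against libm's erff; this is plain C written for this writeup, not code from the patch:

    #include <math.h>
    #include <stdio.h>

    // same coefficients as the Metal kernel above
    static const float p_erf  =  0.3275911f;
    static const float a1_erf =  0.254829592f;
    static const float a2_erf = -0.284496736f;
    static const float a3_erf =  1.421413741f;
    static const float a4_erf = -1.453152027f;
    static const float a5_erf =  1.061405429f;

    static float erf_approx(float x) {
        float sign_x = x < 0.0f ? -1.0f : 1.0f;
        x = fabsf(x);
        float t = 1.0f/(1.0f + p_erf*x);
        float y = 1.0f - (((((a5_erf*t + a4_erf)*t) + a3_erf)*t + a2_erf)*t + a1_erf)*t*expf(-x*x);
        return sign_x*y;
    }

    int main(void) {
        float max_err = 0.0f;
        for (int i = -6000; i <= 6000; ++i) {
            const float x   = i*1e-3f;
            const float err = fabsf(erf_approx(x) - erff(x));
            if (err > max_err) max_err = err;
        }
        printf("max |erf_approx - erff| = %g\n", max_err); // expect on the order of 1e-7
        return 0;
    }

The explicit branch for the sign differs from Metal's sign(), which returns 0 at x = 0, but the result there is the same: the polynomial evaluates to roughly 1 at t = 1, so erf_approx(0) is essentially 0 either way.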
ggml/src/ggml.c: 16 additions & 1 deletion

@@ -1099,9 +1099,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "HARDSWISH",
     "HARDSIGMOID",
     "EXP",
+    "GELU_ERF",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
+static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -2501,6 +2502,20 @@ struct ggml_tensor * ggml_gelu_inplace(
     return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU);
 }
 
+// ggml_gelu_erf
+
+struct ggml_tensor * ggml_gelu_erf(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_ERF);
+}
+
+struct ggml_tensor * ggml_gelu_erf_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_ERF);
+}
+
 // ggml_gelu_quick
 
 struct ggml_tensor * ggml_gelu_quick(
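The bumped static_assert is the guard for exactly this kind of change: GGML_UNARY_OP_NAME is sized by GGML_UNARY_OP_COUNT, so whoever adds the next unary op is forced to revisit this table instead of leaving a hole that would only show up later as a missing name in debug output.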