Skip to content

Commit e743cdd

Browse files
authored
cuda : add ELU support (#14657)
1 parent 05fec5b commit e743cdd

File tree

3 files changed

+13
-0
lines changed

3 files changed

+13
-0
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2303,6 +2303,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
23032303
case GGML_UNARY_OP_EXP:
23042304
ggml_cuda_op_exp(ctx, dst);
23052305
break;
2306+
case GGML_UNARY_OP_ELU:
2307+
ggml_cuda_op_elu(ctx, dst);
2308+
break;
23062309
default:
23072310
return false;
23082311
}
@@ -3116,6 +3119,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
31163119
case GGML_UNARY_OP_GELU_QUICK:
31173120
case GGML_UNARY_OP_TANH:
31183121
case GGML_UNARY_OP_EXP:
3122+
case GGML_UNARY_OP_ELU:
31193123
return ggml_is_contiguous(op->src[0]);
31203124
default:
31213125
return false;

ggml/src/ggml-cuda/unary.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ static __device__ __forceinline__ float op_log(float x) {
8383
return logf(x);
8484
}
8585

86+
static __device__ __forceinline__ float op_elu(float x) {
87+
return (x > 0.f) ? x : expm1f(x);
88+
}
89+
8690
template <float (*op)(float), typename T>
8791
static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
8892
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -196,6 +200,9 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
196200
ggml_cuda_op_unary<op_log>(ctx, dst);
197201
}
198202

203+
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
204+
ggml_cuda_op_unary<op_elu>(ctx, dst);
205+
}
199206
/* gated ops */
200207

201208
template <float (*op)(float), typename T>

ggml/src/ggml-cuda/unary.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
5959

6060
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
6161

62+
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
63+
6264
void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
6365

6466
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

0 commit comments

Comments
 (0)