
Commit 18ee10e

ywang96 authored and jimpang committed

[Kernel][CPU] Add Quick gelu to CPU (vllm-project#5717)

1 parent c08f3c5 · commit 18ee10e

File tree

4 files changed: +29 lines, -0 lines


csrc/cpu/activation.cpp

Lines changed: 19 additions & 0 deletions
@@ -59,6 +59,13 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
   return w3 * x * (ones + t);
 }

+FORCE_INLINE vec_op::FP32Vec8 gelu_quick_act(const vec_op::FP32Vec8& x) {
+  const vec_op::FP32Vec8 zeros(0.0);
+  const vec_op::FP32Vec8 ones(1.0);
+  const vec_op::FP32Vec8 w1(1.702f);
+  return x / (ones + (zeros - w1 * x).exp());
+}
+
 FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
   const vec_op::FP32Vec8 ones(1.0);
   const vec_op::FP32Vec8 w1(M_SQRT1_2);

@@ -142,3 +149,15 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
     CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
   });
 }
+
+void gelu_quick(torch::Tensor& out, torch::Tensor& input) {
+  int num_tokens = input.numel() / input.size(-1);
+  int d = input.size(-1);
+
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_quick_impl", [&] {
+    CPU_KERNEL_GUARD_IN(gelu_quick_impl)
+    activation_kernel<scalar_t, gelu_quick_act, false>(
+        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
+    CPU_KERNEL_GUARD_OUT(gelu_quick_impl)
+  });
+}
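
For reference, the vectorized gelu_quick_act computes x / (1 + exp(-1.702 * x)) per lane, which is the same value as x * sigmoid(1.702 * x). A minimal PyTorch sketch (not part of this commit; gelu_quick_ref is a hypothetical helper name) checking that identity:

    import torch

    def gelu_quick_ref(x: torch.Tensor) -> torch.Tensor:
        # Same expression as the CPU kernel: x / (1 + exp(-1.702 * x)).
        return x / (1.0 + torch.exp(-1.702 * x))

    x = torch.randn(4, 8)
    # Equivalent sigmoid form of quick GELU.
    assert torch.allclose(gelu_quick_ref(x), x * torch.sigmoid(1.702 * x))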

csrc/cpu/torch_bindings.cpp

Lines changed: 4 additions & 0 deletions
@@ -58,6 +58,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_fast", torch::kCPU, &gelu_fast);

+  // Quick GELU implementation.
+  ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
+  ops.impl("gelu_quick", torch::kCPU, &gelu_quick);
+
   // Layernorm
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(
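
The schema gelu_quick(Tensor! out, Tensor input) -> () means the op writes into a pre-allocated out tensor rather than returning one. A usage sketch, assuming the extension's ops are exposed as torch.ops._C (the actual namespace comes from TORCH_EXTENSION_NAME and requires the CPU extension to be built and loaded):

    import torch

    x = torch.randn(16, 128)         # (num_tokens, hidden_size)
    out = torch.empty_like(x)
    torch.ops._C.gelu_quick(out, x)  # assumed namespace; result is written into out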

vllm/_ipex_ops.py

Lines changed: 3 additions & 0 deletions
@@ -43,6 +43,9 @@ def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
     def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
         out.copy_(torch.nn.functional.gelu(x))

+    # TODO add implementation of gelu_quick here
+    # def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+
     def paged_attention_v1(
         out: torch.Tensor,
         query: torch.Tensor,
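
The gelu_quick implementation is left as a TODO in this commit. One possible shape, patterned on the gelu_new fallback above (a sketch, not part of the diff):

    import torch

    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
        # Quick GELU: x * sigmoid(1.702 * x), written into the pre-allocated out.
        out.copy_(x * torch.sigmoid(1.702 * x))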

vllm/model_executor/layers/activation.py

Lines changed: 3 additions & 0 deletions
@@ -155,6 +155,9 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         ops.gelu_quick(out, x)
         return out

+    # TODO implement forward_xpu for QuickGELU
+    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+

 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
