Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions csrc/cpu/activation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
return w3 * x * (ones + t);
}

FORCE_INLINE vec_op::FP32Vec8 gelu_quick_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 zeros(0.0);
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(1.702f);
return x / (ones + (zeros - w1 * x).exp());
}

FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT1_2);
Expand Down Expand Up @@ -142,3 +149,15 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
});
}

void gelu_quick(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1);

VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_quick_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_quick_impl)
activation_kernel<scalar_t, gelu_quick_act, false>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_quick_impl)
});
}
4 changes: 4 additions & 0 deletions csrc/cpu/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_fast", torch::kCPU, &gelu_fast);

// Quick GELU implementation.
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_quick", torch::kCPU, &gelu_quick);

// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
Expand Down
3 changes: 3 additions & 0 deletions vllm/_ipex_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
out.copy_(torch.nn.functional.gelu(x))

# TODO add implementation of gelu_quick here
# def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:

def paged_attention_v1(
out: torch.Tensor,
query: torch.Tensor,
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
ops.gelu_quick(out, x)
return out

# TODO implement forward_xpu for QuickGELU
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters.
Expand Down