Commit 2c9d1c3

zhyncs and yzh119 authored
feat: support fused gelu tanh mul (#434)
cc @yzh119

```
pytest python/tests/test_activation.py
=================================== test session starts ===================================
platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
rootdir: /flashinfer/python
plugins: anyio-4.2.0
collected 630 items

python/tests/test_activation.py .................................................... [ 18%]
...................................................................................... [ 41%]
...................................................................................... [ 64%]
...................................................................................... [ 88%]
.......................................................................... [100%]

=========================== 630 passed in 146.89s (0:02:26) ===========================
```

---------

Co-authored-by: Zihao Ye <expye@outlook.com>
1 parent 949c328 commit 2c9d1c3

File tree

7 files changed: +75, -2 lines changed


include/flashinfer/activation.cuh

Lines changed: 8 additions & 0 deletions
```diff
@@ -17,6 +17,7 @@
 #ifndef FLASHINFER_ACTIVATION_CUH_
 #define FLASHINFER_ACTIVATION_CUH_
 
+#include "math.cuh"
 #include "utils.cuh"
 #include "vec_dtypes.cuh"
 
@@ -30,6 +31,13 @@ __device__ __forceinline__ float silu_kernel(const float& val) {
   return val / (1.0f + __expf(-val));
 }
 
+template <typename T>
+__device__ __forceinline__ T gelu_tanh_kernel(const T& val) {
+  const float cdf =
+      0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val))));
+  return val * cdf;
+}
+
 template <typename T, float (*Activation)(const float&)>
 __global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
   constexpr uint32_t vec_size = 16 / sizeof(T);
```
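
For readers who want to check the math outside CUDA: the kernel above implements the tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), where 0.7978845608028654 is sqrt(2/pi). Below is a minimal PyTorch reference sketch (not part of this commit; `gelu_tanh_ref` is a hypothetical helper) that reproduces the same formula:

```python
import torch


def gelu_tanh_ref(x: torch.Tensor) -> torch.Tensor:
    # Same tanh approximation as gelu_tanh_kernel above:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    # where 0.7978845608028654 == sqrt(2 / pi).
    cdf = 0.5 * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x * x * x)))
    return x * cdf


# Sanity check against PyTorch's built-in tanh approximation.
x = torch.randn(1024, dtype=torch.float32)
torch.testing.assert_close(
    gelu_tanh_ref(x), torch.nn.functional.gelu(x, approximate="tanh")
)
```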

python/csrc/activation.cu

Lines changed: 19 additions & 0 deletions
```diff
@@ -40,3 +40,22 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
     return true;
   });
 }
+
+void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) {
+  int d = input.size(-1) / 2;
+  int64_t num_tokens = input.numel() / input.size(-1);
+  dim3 grid(num_tokens);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
+    uint32_t vec_size = 16 / sizeof(c_type);
+    dim3 block(std::min(d / vec_size, 1024U));
+    flashinfer::activation::act_and_mul_kernel<c_type,
+                                               flashinfer::activation::gelu_tanh_kernel>
+        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()),
+                                     static_cast<c_type*>(input.data_ptr()), d);
+
+    return true;
+  });
+}
```
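
As a rough illustration of the launch configuration chosen above: the grid has one block per token, each thread handles one 16-byte vector, and the block size is capped at 1024 threads. The sketch below mirrors that arithmetic in Python only for illustration; `launch_config` is a hypothetical helper, not code from this commit:

```python
def launch_config(shape, itemsize):
    d = shape[-1] // 2                 # input.size(-1) / 2: per-token output width
    num_tokens = 1
    for s in shape[:-1]:               # input.numel() / input.size(-1)
        num_tokens *= s
    vec_size = 16 // itemsize          # 16-byte vectorized loads/stores
    block = min(d // vec_size, 1024)   # threads per block, capped at 1024
    return num_tokens, block           # grid = num_tokens blocks of `block` threads


# e.g. a (8, 128, 8192) float16 input: d = 4096, grid = 1024 blocks, block = 512 threads
print(launch_config((8, 128, 8192), itemsize=2))
```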

python/csrc/flashinfer_ops.cu

Lines changed: 1 addition & 0 deletions
```diff
@@ -40,6 +40,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
   m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization");
   m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul");
+  m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul");
   m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place");
   m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace,
         "Apply Llama 3.1 style RoPE in-place");
```

python/csrc/flashinfer_ops.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -78,6 +78,8 @@ void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tenso
 
 void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
+void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
+
 void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
                         torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta);
 
```

python/flashinfer/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -26,7 +26,7 @@
     CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
     single_decode_with_kv_cache,
 )
-from .activation import silu_and_mul
+from .activation import gelu_tanh_and_mul, silu_and_mul
 from .group_gemm import SegmentGEMMWrapper
 from .norm import fused_add_rmsnorm, rmsnorm
 from .page import append_paged_kv_cache
```

python/flashinfer/activation.py

Lines changed: 32 additions & 1 deletion
```diff
@@ -14,9 +14,10 @@
 limitations under the License.
 """
 
-import torch
 from typing import Optional
 
+import torch
+
 # mypy: disable-error-code="attr-defined"
 try:
     from . import _kernels
@@ -69,3 +70,33 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
         )
     _kernels.silu_and_mul(out, input)
     return out
+
+
+def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
+    r"""Fused GeLU Tanh and Mul operation.
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (..., 2 * hidden_size).
+
+    out: Optional[torch.Tensor]
+        The output tensor, if specified, the kernel will update this tensor inplace.
+
+    Returns
+    -------
+    output: torch.Tensor
+        Output tensor, shape (..., hidden_size).
+    """
+    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
+        raise ValueError("The pointers must be multiple of 16 bytes.")
+    if out is not None:
+        _check_shape(input, out)
+    else:
+        out = torch.empty(
+            input.shape[:-1] + (input.shape[-1] // 2,),
+            device=input.device,
+            dtype=input.dtype,
+        )
+    _kernels.gelu_tanh_and_mul(out, input)
+    return out
```
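
A minimal usage sketch of the new Python entry point (assuming a CUDA device and the extension built from this commit). It follows the convention exercised by the test below: the activation is applied to the first half of the last dimension and multiplied elementwise by the second half:

```python
import torch

import flashinfer

hidden = 4096
x = torch.randn(4, 2 * hidden, dtype=torch.float16, device="cuda")

# Let the kernel allocate the output...
y = flashinfer.gelu_tanh_and_mul(x)

# ...or pass a preallocated tensor via `out` to have it updated in place.
out = torch.empty(4, hidden, dtype=torch.float16, device="cuda")
flashinfer.gelu_tanh_and_mul(x, out=out)

# Eager-mode reference, same convention as the test below.
y_ref = x[..., hidden:] * torch.nn.functional.gelu(x[..., :hidden], approximate="tanh")
torch.testing.assert_close(y, y_ref, rtol=1e-3, atol=1e-3)
```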

python/tests/test_activation.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -31,3 +31,15 @@ def test_fused_silu_mul(dim, batch_size, seq_len):
     numpy.testing.assert_allclose(
         y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
     )
+
+
+@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
+@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
+@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
+def test_fused_gelu_tanh_mul(dim, batch_size, seq_len):
+    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
+    y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="tanh")
+    y = flashinfer.activation.gelu_tanh_and_mul(x)
+    numpy.testing.assert_allclose(
+        y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
+    )
```
