Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/flashinfer/activation.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#ifndef FLASHINFER_ACTIVATION_CUH_
#define FLASHINFER_ACTIVATION_CUH_

#include "math.cuh"
#include "utils.cuh"
#include "vec_dtypes.cuh"

Expand All @@ -30,6 +31,13 @@ __device__ __forceinline__ float silu_kernel(const float& val) {
return val / (1.0f + __expf(-val));
}

template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& val) {
const float cdf =
0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val))));
return val * cdf;
}

template <typename T, float (*Activation)(const float&)>
__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
constexpr uint32_t vec_size = 16 / sizeof(T);
Expand Down
19 changes: 19 additions & 0 deletions python/csrc/activation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,22 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
return true;
});
}

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) {
int d = input.size(-1) / 2;
int64_t num_tokens = input.numel() / input.size(-1);
dim3 grid(num_tokens);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
uint32_t vec_size = 16 / sizeof(c_type);
dim3 block(std::min(d / vec_size, 1024U));
flashinfer::activation::act_and_mul_kernel<c_type,
flashinfer::activation::gelu_tanh_kernel>
<<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()),
static_cast<c_type*>(input.data_ptr()), d);

return true;
});
}
1 change: 1 addition & 0 deletions python/csrc/flashinfer_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization");
m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul");
m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul");
m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place");
m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace,
"Apply Llama 3.1 style RoPE in-place");
Expand Down
2 changes: 2 additions & 0 deletions python/csrc/flashinfer_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tenso

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta);

Expand Down
2 changes: 1 addition & 1 deletion python/flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
single_decode_with_kv_cache,
)
from .activation import silu_and_mul
from .activation import gelu_tanh_and_mul, silu_and_mul
from .group_gemm import SegmentGEMMWrapper
from .norm import fused_add_rmsnorm, rmsnorm
from .page import append_paged_kv_cache
Expand Down
33 changes: 32 additions & 1 deletion python/flashinfer/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
limitations under the License.
"""

import torch
from typing import Optional

import torch

# mypy: disable-error-code="attr-defined"
try:
from . import _kernels
Expand Down Expand Up @@ -69,3 +70,33 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
)
_kernels.silu_and_mul(out, input)
return out


def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
r"""Fused GeLU Tanh and Mul operation.

Parameters
----------
input: torch.Tensor
Input tensor, shape (..., 2 * hidden_size).

out: Optional[torch.Tensor]
The the output tensor, if specified, the kernel will update this tensor inplace.

Returns
-------
output: torch.Tensor
Output tensor, shape (..., hidden_size).
"""
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The pointers must be multiple of 16 bytes.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
_kernels.gelu_tanh_and_mul(out, input)
return out
12 changes: 12 additions & 0 deletions python/tests/test_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,15 @@ def test_fused_silu_mul(dim, batch_size, seq_len):
numpy.testing.assert_allclose(
y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
)


@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_gelu_tanh_mul(dim, batch_size, seq_len):
x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="tanh")
y = flashinfer.activation.gelu_tanh_and_mul(x)
numpy.testing.assert_allclose(
y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
)