Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 54 additions & 35 deletions python/flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,57 +14,76 @@
limitations under the License.
"""

from .activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from .activation import (
gelu_and_mul as gelu_and_mul,
gelu_tanh_and_mul as gelu_tanh_and_mul,
silu_and_mul as silu_and_mul,
)
from .cascade import (
BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
MultiLevelCascadeAttentionWrapper,
merge_state,
merge_state_in_place,
merge_states,
BatchDecodeWithSharedPrefixPagedKVCacheWrapper as BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
BatchPrefillWithSharedPrefixPagedKVCacheWrapper as BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
MultiLevelCascadeAttentionWrapper as MultiLevelCascadeAttentionWrapper,
merge_state as merge_state,
merge_state_in_place as merge_state_in_place,
merge_states as merge_states,
)
from .decode import (
BatchDecodeWithPagedKVCacheWrapper,
CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
single_decode_with_kv_cache,
BatchDecodeWithPagedKVCacheWrapper as BatchDecodeWithPagedKVCacheWrapper,
CUDAGraphBatchDecodeWithPagedKVCacheWrapper as CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
single_decode_with_kv_cache as single_decode_with_kv_cache,
)
from .gemm import (
SegmentGEMMWrapper as SegmentGEMMWrapper,
bmm_fp8 as bmm_fp8,
)
from .norm import (
fused_add_rmsnorm as fused_add_rmsnorm,
gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm,
gemma_rmsnorm as gemma_rmsnorm,
rmsnorm as rmsnorm,
)
from .page import (
append_paged_kv_cache as append_paged_kv_cache,
)
from .gemm import SegmentGEMMWrapper, bmm_fp8
from .norm import fused_add_rmsnorm, gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
from .page import append_paged_kv_cache
from .prefill import (
BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
single_prefill_with_kv_cache,
single_prefill_with_kv_cache_return_lse,
BatchPrefillWithPagedKVCacheWrapper as BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper as BatchPrefillWithRaggedKVCacheWrapper,
single_prefill_with_kv_cache as single_prefill_with_kv_cache,
single_prefill_with_kv_cache_return_lse as single_prefill_with_kv_cache_return_lse,
)
from .quantization import (
packbits as packbits,
segment_packbits as segment_packbits,
)
from .quantization import packbits, segment_packbits
from .rope import (
apply_llama31_rope,
apply_llama31_rope_inplace,
apply_rope,
apply_rope_inplace,
apply_llama31_rope as apply_llama31_rope,
apply_llama31_rope_inplace as apply_llama31_rope_inplace,
apply_rope as apply_rope,
apply_rope_inplace as apply_rope_inplace,
)
from .sampling import (
chain_speculative_sampling,
min_p_sampling_from_probs,
sampling_from_probs,
top_k_mask_logits,
top_k_renorm_probs,
top_k_sampling_from_probs,
top_k_top_p_sampling_from_logits,
top_k_top_p_sampling_from_probs,
top_p_renorm_probs,
top_p_sampling_from_probs,
chain_speculative_sampling as chain_speculative_sampling,
min_p_sampling_from_probs as min_p_sampling_from_probs,
sampling_from_probs as sampling_from_probs,
top_k_mask_logits as top_k_mask_logits,
top_k_renorm_probs as top_k_renorm_probs,
top_k_sampling_from_probs as top_k_sampling_from_probs,
top_k_top_p_sampling_from_logits as top_k_top_p_sampling_from_logits,
top_k_top_p_sampling_from_probs as top_k_top_p_sampling_from_probs,
top_p_renorm_probs as top_p_renorm_probs,
top_p_sampling_from_probs as top_p_sampling_from_probs,
)
from .sparse import (
BlockSparseAttentionWrapper as BlockSparseAttentionWrapper,
)
from .sparse import BlockSparseAttentionWrapper

# Resolve the package version: take `__version__` from the `_build_meta`
# module when it can be imported, otherwise read it from `version.txt`.
# NOTE(review): this text is a diff rendering; the two import lines below
# appear to be the removed and added forms of the same statement — confirm.
try:
from ._build_meta import __version__
from ._build_meta import __version__ as __version__
except ImportError:
# Relative path, so it is resolved against the current working directory;
# surrounding whitespace/newline is stripped from the file contents.
with open("version.txt") as f:
__version__ = f.read().strip()

# Optional import: when `aot_config` cannot be imported, the name is bound to
# None instead of letting the ImportError propagate.
# NOTE(review): the two import lines look like the old and new sides of the
# diff (the newer one adds the explicit `as` re-export and a type-ignore).
try:
import aot_config
import aot_config as aot_config # type: ignore[import]
except ImportError:
aot_config = None
2 changes: 1 addition & 1 deletion python/flashinfer/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_act_and_mul_module(act_func_name: str):
global _jit_modules
if act_func_name not in _jit_modules:
if has_prebuilt_ops:
from . import _kernels
from . import _kernels # type: ignore[attr-defined]

module = _kernels
else:
Expand Down
50 changes: 28 additions & 22 deletions python/flashinfer/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,38 +14,44 @@
limitations under the License.
"""

import logging
import os
import re
import logging
import subprocess
from pathlib import Path
from typing import List, Union

import torch.utils.cpp_extension as torch_cpp_ext
from filelock import FileLock
from typing import List, Tuple
from .env import (
FLASHINFER_WORKSPACE_DIR,
FLASHINFER_JIT_DIR,
FLASHINFER_GEN_SRC_DIR,
FLASHINFER_INCLUDE_DIR,
FLASHINFER_CSRC_DIR,
CUTLASS_INCLUDE_DIRS,

# Re-export
from .activation import (
gen_act_and_mul_cu as gen_act_and_mul_cu,
get_act_and_mul_cu_str as get_act_and_mul_cu_str,
)
from .activation import get_act_and_mul_cu_str, gen_act_and_mul_cu
from .attention import (
gen_single_decode_cu,
get_single_decode_uri,
gen_batch_decode_cu,
get_batch_decode_uri,
gen_single_prefill_cu,
get_single_prefill_uri,
gen_batch_prefill_cu,
get_batch_prefill_uri,
gen_batch_decode_cu as gen_batch_decode_cu,
gen_batch_prefill_cu as gen_batch_prefill_cu,
gen_single_decode_cu as gen_single_decode_cu,
gen_single_prefill_cu as gen_single_prefill_cu,
get_batch_decode_uri as get_batch_decode_uri,
get_batch_prefill_uri as get_batch_prefill_uri,
get_single_decode_uri as get_single_decode_uri,
get_single_prefill_uri as get_single_prefill_uri,
)
from .env import (
CUTLASS_INCLUDE_DIRS as CUTLASS_INCLUDE_DIRS,
FLASHINFER_CSRC_DIR as FLASHINFER_CSRC_DIR,
FLASHINFER_GEN_SRC_DIR as FLASHINFER_GEN_SRC_DIR,
FLASHINFER_INCLUDE_DIR as FLASHINFER_INCLUDE_DIR,
FLASHINFER_JIT_DIR as FLASHINFER_JIT_DIR,
FLASHINFER_WORKSPACE_DIR as FLASHINFER_WORKSPACE_DIR,
)

# Detect prebuilt ops: if `.aot_config` imports, `prebuilt_ops_uri` comes from
# it and `has_prebuilt_ops` is True; on ImportError `prebuilt_ops_uri` becomes
# an empty set and `has_prebuilt_ops` is False.
# NOTE(review): the paired import lines and the paired `except` lines look like
# the old and new sides of the diff (the newer `except` drops the unused `e`).
try:
from .aot_config import prebuilt_ops_uri
from .aot_config import prebuilt_ops_uri as prebuilt_ops_uri # type: ignore[import]

has_prebuilt_ops = True
except ImportError as e:
except ImportError:
prebuilt_ops_uri = set()
has_prebuilt_ops = False

Expand Down Expand Up @@ -112,7 +118,7 @@ def remove_unwanted_pytorch_nvcc_flags():

def load_cuda_ops(
name: str,
sources: List[str],
sources: List[Union[str, Path]],
extra_cflags: List[str] = [],
extra_cuda_cflags: List[str] = [],
extra_ldflags=None,
Expand Down
2 changes: 1 addition & 1 deletion python/flashinfer/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
)

if has_prebuilt_ops:
from . import _prefill_kernels
from . import _prefill_kernels # type: ignore[attr-defined]


def compile_single_prefill_module(
Expand Down
Empty file added python/flashinfer/py.typed
Empty file.
2 changes: 1 addition & 1 deletion python/flashinfer/triton/cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def merge_state_in_place(
if mask is not None:
check_dim(1, mask)
assert v.size(0) == mask.size(0)
assert mask.device == device
assert mask.device == v.device
seq_len = v.size(0)
num_heads = v.size(1)
head_dim = v.size(2)
Expand Down
4 changes: 2 additions & 2 deletions python/flashinfer/triton/kernels/cascade.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import triton
import triton.language as tl
import triton # type: ignore[import]
import triton.language as tl # type: ignore[import]


@triton.jit
Expand Down