
Commit 9e10936

misc: typing improvement (#555)

1. Make mypy happy.
2. Adopt the `from .x import y as y` idiom of re-exporting things in `__init__.py`. See [PEP 484](https://peps.python.org/pep-0484/#stub-files).

```
$ cd python/
$ mypy flashinfer/
Success: no issues found in 27 source files
```
1 parent 9bf916f · commit 9e10936
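For context on item 2: under PEP 484's stub-file semantics, which mypy also applies to regular modules via `--no-implicit-reexport` (implied by `--strict`), a plain `from .x import y` does not mark `y` as part of the importing module's public interface; the redundant `as y` alias does. A minimal sketch of the difference (the module and function names here are illustrative, not from this commit):

```python
# pkg/_impl.py (hypothetical module)
def helper() -> int:
    return 42


# pkg/__init__.py
# Plain form: with --no-implicit-reexport, mypy treats `helper` as
# private to pkg.__init__, so `from pkg import helper` is rejected
# downstream:
#     from ._impl import helper
#
# Redundant-alias form: explicitly marks `helper` as re-exported:
from ._impl import helper as helper
```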

7 files changed: +87 −62 lines changed

python/flashinfer/__init__.py

Lines changed: 54 additions & 35 deletions
```diff
@@ -14,57 +14,76 @@
 limitations under the License.
 """
 
-from .activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+from .activation import (
+    gelu_and_mul as gelu_and_mul,
+    gelu_tanh_and_mul as gelu_tanh_and_mul,
+    silu_and_mul as silu_and_mul,
+)
 from .cascade import (
-    BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
-    BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
-    MultiLevelCascadeAttentionWrapper,
-    merge_state,
-    merge_state_in_place,
-    merge_states,
+    BatchDecodeWithSharedPrefixPagedKVCacheWrapper as BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
+    BatchPrefillWithSharedPrefixPagedKVCacheWrapper as BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
+    MultiLevelCascadeAttentionWrapper as MultiLevelCascadeAttentionWrapper,
+    merge_state as merge_state,
+    merge_state_in_place as merge_state_in_place,
+    merge_states as merge_states,
 )
 from .decode import (
-    BatchDecodeWithPagedKVCacheWrapper,
-    CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
-    single_decode_with_kv_cache,
+    BatchDecodeWithPagedKVCacheWrapper as BatchDecodeWithPagedKVCacheWrapper,
+    CUDAGraphBatchDecodeWithPagedKVCacheWrapper as CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
+    single_decode_with_kv_cache as single_decode_with_kv_cache,
+)
+from .gemm import (
+    SegmentGEMMWrapper as SegmentGEMMWrapper,
+    bmm_fp8 as bmm_fp8,
+)
+from .norm import (
+    fused_add_rmsnorm as fused_add_rmsnorm,
+    gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm,
+    gemma_rmsnorm as gemma_rmsnorm,
+    rmsnorm as rmsnorm,
+)
+from .page import (
+    append_paged_kv_cache as append_paged_kv_cache,
 )
-from .gemm import SegmentGEMMWrapper, bmm_fp8
-from .norm import fused_add_rmsnorm, gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
-from .page import append_paged_kv_cache
 from .prefill import (
-    BatchPrefillWithPagedKVCacheWrapper,
-    BatchPrefillWithRaggedKVCacheWrapper,
-    single_prefill_with_kv_cache,
-    single_prefill_with_kv_cache_return_lse,
+    BatchPrefillWithPagedKVCacheWrapper as BatchPrefillWithPagedKVCacheWrapper,
+    BatchPrefillWithRaggedKVCacheWrapper as BatchPrefillWithRaggedKVCacheWrapper,
+    single_prefill_with_kv_cache as single_prefill_with_kv_cache,
+    single_prefill_with_kv_cache_return_lse as single_prefill_with_kv_cache_return_lse,
+)
+from .quantization import (
+    packbits as packbits,
+    segment_packbits as segment_packbits,
 )
-from .quantization import packbits, segment_packbits
 from .rope import (
-    apply_llama31_rope,
-    apply_llama31_rope_inplace,
-    apply_rope,
-    apply_rope_inplace,
+    apply_llama31_rope as apply_llama31_rope,
+    apply_llama31_rope_inplace as apply_llama31_rope_inplace,
+    apply_rope as apply_rope,
+    apply_rope_inplace as apply_rope_inplace,
 )
 from .sampling import (
-    chain_speculative_sampling,
-    min_p_sampling_from_probs,
-    sampling_from_probs,
-    top_k_mask_logits,
-    top_k_renorm_probs,
-    top_k_sampling_from_probs,
-    top_k_top_p_sampling_from_logits,
-    top_k_top_p_sampling_from_probs,
-    top_p_renorm_probs,
-    top_p_sampling_from_probs,
+    chain_speculative_sampling as chain_speculative_sampling,
+    min_p_sampling_from_probs as min_p_sampling_from_probs,
+    sampling_from_probs as sampling_from_probs,
+    top_k_mask_logits as top_k_mask_logits,
+    top_k_renorm_probs as top_k_renorm_probs,
+    top_k_sampling_from_probs as top_k_sampling_from_probs,
+    top_k_top_p_sampling_from_logits as top_k_top_p_sampling_from_logits,
+    top_k_top_p_sampling_from_probs as top_k_top_p_sampling_from_probs,
+    top_p_renorm_probs as top_p_renorm_probs,
+    top_p_sampling_from_probs as top_p_sampling_from_probs,
+)
+from .sparse import (
+    BlockSparseAttentionWrapper as BlockSparseAttentionWrapper,
 )
-from .sparse import BlockSparseAttentionWrapper
 
 try:
-    from ._build_meta import __version__
+    from ._build_meta import __version__ as __version__
 except ImportError:
     with open("version.txt") as f:
         __version__ = f.read().strip()
 
 try:
-    import aot_config
+    import aot_config as aot_config  # type: ignore[import]
 except ImportError:
     aot_config = None
```

python/flashinfer/activation.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -72,7 +72,7 @@ def get_act_and_mul_module(act_func_name: str):
     global _jit_modules
     if act_func_name not in _jit_modules:
         if has_prebuilt_ops:
-            from . import _kernels
+            from . import _kernels  # type: ignore[attr-defined]
 
             module = _kernels
         else:
```

python/flashinfer/jit/__init__.py

Lines changed: 28 additions & 22 deletions
```diff
@@ -14,38 +14,44 @@
 limitations under the License.
 """
 
+import logging
 import os
 import re
-import logging
-import subprocess
+from pathlib import Path
+from typing import List, Union
+
 import torch.utils.cpp_extension as torch_cpp_ext
 from filelock import FileLock
-from typing import List, Tuple
-from .env import (
-    FLASHINFER_WORKSPACE_DIR,
-    FLASHINFER_JIT_DIR,
-    FLASHINFER_GEN_SRC_DIR,
-    FLASHINFER_INCLUDE_DIR,
-    FLASHINFER_CSRC_DIR,
-    CUTLASS_INCLUDE_DIRS,
+
+# Re-export
+from .activation import (
+    gen_act_and_mul_cu as gen_act_and_mul_cu,
+    get_act_and_mul_cu_str as get_act_and_mul_cu_str,
 )
-from .activation import get_act_and_mul_cu_str, gen_act_and_mul_cu
 from .attention import (
-    gen_single_decode_cu,
-    get_single_decode_uri,
-    gen_batch_decode_cu,
-    get_batch_decode_uri,
-    gen_single_prefill_cu,
-    get_single_prefill_uri,
-    gen_batch_prefill_cu,
-    get_batch_prefill_uri,
+    gen_batch_decode_cu as gen_batch_decode_cu,
+    gen_batch_prefill_cu as gen_batch_prefill_cu,
+    gen_single_decode_cu as gen_single_decode_cu,
+    gen_single_prefill_cu as gen_single_prefill_cu,
+    get_batch_decode_uri as get_batch_decode_uri,
+    get_batch_prefill_uri as get_batch_prefill_uri,
+    get_single_decode_uri as get_single_decode_uri,
+    get_single_prefill_uri as get_single_prefill_uri,
+)
+from .env import (
+    CUTLASS_INCLUDE_DIRS as CUTLASS_INCLUDE_DIRS,
+    FLASHINFER_CSRC_DIR as FLASHINFER_CSRC_DIR,
+    FLASHINFER_GEN_SRC_DIR as FLASHINFER_GEN_SRC_DIR,
+    FLASHINFER_INCLUDE_DIR as FLASHINFER_INCLUDE_DIR,
+    FLASHINFER_JIT_DIR as FLASHINFER_JIT_DIR,
+    FLASHINFER_WORKSPACE_DIR as FLASHINFER_WORKSPACE_DIR,
 )
 
 try:
-    from .aot_config import prebuilt_ops_uri
+    from .aot_config import prebuilt_ops_uri as prebuilt_ops_uri  # type: ignore[import]
 
     has_prebuilt_ops = True
-except ImportError as e:
+except ImportError:
     prebuilt_ops_uri = set()
     has_prebuilt_ops = False
@@ -112,7 +118,7 @@ def remove_unwanted_pytorch_nvcc_flags():
 
 
 def load_cuda_ops(
     name: str,
-    sources: List[str],
+    sources: List[Union[str, Path]],
     extra_cflags: List[str] = [],
     extra_cuda_cflags: List[str] = [],
     extra_ldflags=None,
```
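One note on the second hunk: widening `sources` to `List[Union[str, Path]]` lets callers pass `pathlib.Path` objects (for example, paths derived from `FLASHINFER_GEN_SRC_DIR`) without converting them to `str` first. A minimal illustration of why the wider type is convenient (`collect_sources` is a hypothetical helper, not part of this commit):

```python
from pathlib import Path
from typing import List, Union

def collect_sources(gen_dir: Path, handwritten: List[str]) -> List[Union[str, Path]]:
    # Generated .cu files arrive as Path objects from glob(), while
    # hand-written sources may be plain strings; the union type accepts
    # both without str() conversions at every call site.
    return [*sorted(gen_dir.glob("*.cu")), *handwritten]
```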

python/flashinfer/prefill.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -49,7 +49,7 @@
 )
 
 if has_prebuilt_ops:
-    from . import _prefill_kernels
+    from . import _prefill_kernels  # type: ignore[attr-defined]
 
 
 def compile_single_prefill_module(
```

python/flashinfer/py.typed

Whitespace-only changes.
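The new, empty `py.typed` file is the PEP 561 marker that tells type checkers the package ships inline type annotations. For the marker to reach users it must also be included in distributions; a sketch of how that typically looks with setuptools (flashinfer's actual build configuration is not part of this diff):

```python
# Hypothetical packaging excerpt: PEP 561 requires the py.typed marker
# to ship inside the wheel, typically declared via package_data.
from setuptools import setup

setup(
    name="flashinfer",
    packages=["flashinfer"],
    package_data={"flashinfer": ["py.typed"]},
)
```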

python/flashinfer/triton/cascade.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -69,7 +69,7 @@ def merge_state_in_place(
     if mask is not None:
         check_dim(1, mask)
         assert v.size(0) == mask.size(0)
-        assert mask.device == device
+        assert mask.device == v.device
     seq_len = v.size(0)
     num_heads = v.size(1)
     head_dim = v.size(2)
```

python/flashinfer/triton/kernels/cascade.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,5 +1,5 @@
-import triton
-import triton.language as tl
+import triton  # type: ignore[import]
+import triton.language as tl  # type: ignore[import]
 
 
 @triton.jit
```
