14 | 14 | limitations under the License. |
15 | 15 | """ |
16 | 16 |
17 | | -from .activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul |
| 17 | +from .activation import ( |
| 18 | + gelu_and_mul as gelu_and_mul, |
| 19 | + gelu_tanh_and_mul as gelu_tanh_and_mul, |
| 20 | + silu_and_mul as silu_and_mul, |
| 21 | +) |
18 | 22 | from .cascade import ( |
19 | | - BatchDecodeWithSharedPrefixPagedKVCacheWrapper, |
20 | | - BatchPrefillWithSharedPrefixPagedKVCacheWrapper, |
21 | | - MultiLevelCascadeAttentionWrapper, |
22 | | - merge_state, |
23 | | - merge_state_in_place, |
24 | | - merge_states, |
| 23 | + BatchDecodeWithSharedPrefixPagedKVCacheWrapper as BatchDecodeWithSharedPrefixPagedKVCacheWrapper, |
| 24 | + BatchPrefillWithSharedPrefixPagedKVCacheWrapper as BatchPrefillWithSharedPrefixPagedKVCacheWrapper, |
| 25 | + MultiLevelCascadeAttentionWrapper as MultiLevelCascadeAttentionWrapper, |
| 26 | + merge_state as merge_state, |
| 27 | + merge_state_in_place as merge_state_in_place, |
| 28 | + merge_states as merge_states, |
25 | 29 | ) |
26 | 30 | from .decode import ( |
27 | | - BatchDecodeWithPagedKVCacheWrapper, |
28 | | - CUDAGraphBatchDecodeWithPagedKVCacheWrapper, |
29 | | - single_decode_with_kv_cache, |
| 31 | + BatchDecodeWithPagedKVCacheWrapper as BatchDecodeWithPagedKVCacheWrapper, |
| 32 | + CUDAGraphBatchDecodeWithPagedKVCacheWrapper as CUDAGraphBatchDecodeWithPagedKVCacheWrapper, |
| 33 | + single_decode_with_kv_cache as single_decode_with_kv_cache, |
| 34 | +) |
| 35 | +from .gemm import ( |
| 36 | + SegmentGEMMWrapper as SegmentGEMMWrapper, |
| 37 | + bmm_fp8 as bmm_fp8, |
| 38 | +) |
| 39 | +from .norm import ( |
| 40 | + fused_add_rmsnorm as fused_add_rmsnorm, |
| 41 | + gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm, |
| 42 | + gemma_rmsnorm as gemma_rmsnorm, |
| 43 | + rmsnorm as rmsnorm, |
| 44 | +) |
| 45 | +from .page import ( |
| 46 | + append_paged_kv_cache as append_paged_kv_cache, |
30 | 47 | ) |
31 | | -from .gemm import SegmentGEMMWrapper, bmm_fp8 |
32 | | -from .norm import fused_add_rmsnorm, gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm |
33 | | -from .page import append_paged_kv_cache |
34 | 48 | from .prefill import ( |
35 | | - BatchPrefillWithPagedKVCacheWrapper, |
36 | | - BatchPrefillWithRaggedKVCacheWrapper, |
37 | | - single_prefill_with_kv_cache, |
38 | | - single_prefill_with_kv_cache_return_lse, |
| 49 | + BatchPrefillWithPagedKVCacheWrapper as BatchPrefillWithPagedKVCacheWrapper, |
| 50 | + BatchPrefillWithRaggedKVCacheWrapper as BatchPrefillWithRaggedKVCacheWrapper, |
| 51 | + single_prefill_with_kv_cache as single_prefill_with_kv_cache, |
| 52 | + single_prefill_with_kv_cache_return_lse as single_prefill_with_kv_cache_return_lse, |
| 53 | +) |
| 54 | +from .quantization import ( |
| 55 | + packbits as packbits, |
| 56 | + segment_packbits as segment_packbits, |
39 | 57 | ) |
40 | | -from .quantization import packbits, segment_packbits |
41 | 58 | from .rope import ( |
42 | | - apply_llama31_rope, |
43 | | - apply_llama31_rope_inplace, |
44 | | - apply_rope, |
45 | | - apply_rope_inplace, |
| 59 | + apply_llama31_rope as apply_llama31_rope, |
| 60 | + apply_llama31_rope_inplace as apply_llama31_rope_inplace, |
| 61 | + apply_rope as apply_rope, |
| 62 | + apply_rope_inplace as apply_rope_inplace, |
46 | 63 | ) |
47 | 64 | from .sampling import ( |
48 | | - chain_speculative_sampling, |
49 | | - min_p_sampling_from_probs, |
50 | | - sampling_from_probs, |
51 | | - top_k_mask_logits, |
52 | | - top_k_renorm_probs, |
53 | | - top_k_sampling_from_probs, |
54 | | - top_k_top_p_sampling_from_logits, |
55 | | - top_k_top_p_sampling_from_probs, |
56 | | - top_p_renorm_probs, |
57 | | - top_p_sampling_from_probs, |
| 65 | + chain_speculative_sampling as chain_speculative_sampling, |
| 66 | + min_p_sampling_from_probs as min_p_sampling_from_probs, |
| 67 | + sampling_from_probs as sampling_from_probs, |
| 68 | + top_k_mask_logits as top_k_mask_logits, |
| 69 | + top_k_renorm_probs as top_k_renorm_probs, |
| 70 | + top_k_sampling_from_probs as top_k_sampling_from_probs, |
| 71 | + top_k_top_p_sampling_from_logits as top_k_top_p_sampling_from_logits, |
| 72 | + top_k_top_p_sampling_from_probs as top_k_top_p_sampling_from_probs, |
| 73 | + top_p_renorm_probs as top_p_renorm_probs, |
| 74 | + top_p_sampling_from_probs as top_p_sampling_from_probs, |
| 75 | +) |
| 76 | +from .sparse import ( |
| 77 | + BlockSparseAttentionWrapper as BlockSparseAttentionWrapper, |
58 | 78 | ) |
59 | | -from .sparse import BlockSparseAttentionWrapper |
60 | 79 |
61 | 80 | try: |
62 | | - from ._build_meta import __version__ |
| 81 | + from ._build_meta import __version__ as __version__ |
63 | 82 | except ImportError: |
64 | 83 | with open("version.txt") as f: |
65 | 84 | __version__ = f.read().strip() |
66 | 85 |
67 | 86 | try: |
68 | | - import aot_config |
| 87 | + import aot_config as aot_config # type: ignore[import] |
69 | 88 | except ImportError: |
70 | 89 | aot_config = None |
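
A brief note on the pattern applied throughout this diff: every import in the package `__init__.py` is rewritten as a redundant alias (`name as name`). Under PEP 484's re-export rules, a name imported into `__init__.py` without such an alias (and not listed in `__all__`) is treated as private to that module, so strict type checkers (e.g. mypy with `--no-implicit-reexport`, or pyright checking a `py.typed` package) could flag `from flashinfer import rmsnorm` at call sites. The sketch below illustrates the idiom with hypothetical names (`mypkg`, a toy `rmsnorm`); it is not flashinfer code, just a minimal stand-in.

```python
# Minimal sketch of the "import X as X" re-export idiom used in this diff.
# Package/module names (mypkg, norm) and the toy rmsnorm are illustrative only.

# --- mypkg/norm.py -----------------------------------------------------------
def rmsnorm(x: list[float], eps: float = 1e-6) -> list[float]:
    """Toy RMSNorm, included only to make the example self-contained."""
    scale = (sum(v * v for v in x) / len(x) + eps) ** -0.5
    return [v * scale for v in x]

# --- mypkg/__init__.py -------------------------------------------------------
# A bare `from .norm import rmsnorm` is considered private under strict
# re-export checking; aliasing the name to itself marks it as public:
#
#     from .norm import rmsnorm as rmsnorm

# --- user code ---------------------------------------------------------------
# from mypkg import rmsnorm   # accepted under strict re-export checking
print(rmsnorm([1.0, 2.0, 3.0]))
```

Listing the names in `__all__` would have the same effect for type checkers; the redundant-alias form keeps each export declaration next to its import.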