Commit 3613a5b
feat: JIT compilation (#507)
This PR implements JIT compilation (#170) for flashinfer. After this
PR, flashinfer compiles kernels just-in-time for different input
data types and shapes and caches them on disk, instead of
pre-compiling a set of kernels into the wheel.
# Motivation
The pip wheel size is exploding as we add support for more data types,
more head dimensions, more attention variants, and more kernel
implementations. Pre-compiling everything is not sustainable and impedes
development speed.
This PR refactors the codebase to use torch's [JIT Compiling
Extensions](https://pytorch.org/tutorials/advanced/cpp_extension.html#jit-compiling-extensions)
feature instead of pre-compiling kernels into the wheel.
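The core idea is compile-on-first-use: derive a cache key from the kernel configuration, compile only on a cache miss, and reload the artifact from disk afterwards. The sketch below illustrates this pattern in plain Python; all names here (`kernel_cache_key`, `JITKernelCache`) are hypothetical, not flashinfer's actual API, and the `compile_fn` stands in for a real compilation step such as torch's `cpp_extension.load`.

```python
import hashlib
import os
import tempfile

def kernel_cache_key(dtype_q, dtype_kv, head_dim, variant):
    """Derive a stable cache key from the kernel configuration."""
    spec = f"{dtype_q}/{dtype_kv}/{head_dim}/{variant}"
    return hashlib.sha256(spec.encode()).hexdigest()[:16]

class JITKernelCache:
    """Toy disk cache: compile a kernel once, then reload it from disk."""

    def __init__(self, cache_dir, compile_fn):
        self.cache_dir = cache_dir
        self.compile_fn = compile_fn   # slow path, e.g. invokes nvcc
        self.compile_count = 0

    def get(self, **config):
        key = kernel_cache_key(**config)
        path = os.path.join(self.cache_dir, key + ".bin")
        if not os.path.exists(path):
            artifact = self.compile_fn(**config)   # cache miss: compile
            self.compile_count += 1
            with open(path, "w") as f:
                f.write(artifact)
        with open(path) as f:                      # cache hit: load from disk
            return f.read()
```

Because the key covers data types, head dimension, and variant, each distinct configuration is compiled exactly once per cache directory, even across processes.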
## Attention Variants
Following [FlexAttention](https://pytorch.org/blog/flexattention/),
we describe every attention variant as a template class; each instance
of the struct can carry closure variables defined in local or
shared memory. Below are two examples (logits soft cap and ALiBi
attention; the programming interface is tentative and will be updated as
we improve the programmability of the JIT template):
```cuda
template <typename ParamsT>
struct LogitsSoftCap {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;

  uint32_t qo_len, kv_len;
  uint32_t window_left;

  __device__ __host__ LogitsSoftCap(const ParamsT& params, uint32_t batch_idx, uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
  }

  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    return float(q) * params.sm_scale * math::ptx_rcp(params.logits_soft_cap);
  }

  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx,
                                               uint32_t qo_idx, uint32_t kv_idx,
                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return params.logits_soft_cap * math::log2e * float(math::tanh(logits));
  }

  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx,
                                             uint32_t kv_head_idx) {
    return true;
  }
};

template <typename ParamsT>
struct ALIBIAttention {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;
  using IdType = typename ParamsT::IdType;

  uint32_t qo_len, kv_len;
  uint32_t window_left;

  __device__ __host__ ALIBIAttention(const ParamsT& params, uint32_t batch_idx, uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
  }

  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    return float(q) * params.sm_scale * math::log2e;
  }

  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx,
                                               uint32_t qo_idx, uint32_t kv_idx,
                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return logits + params.alibi_slopes[qo_head_idx] * float(int(kv_idx) - int(qo_idx));
  }

  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx,
                                             uint32_t kv_head_idx) {
    return true;
  }
};
```
Users can define their own attention variants by customizing the
`ParamsT` class and the variant class. We hope this refactor makes the
codebase more concise and extensible.
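In scalar form, the two example variants above compute the following transforms. The sketch below is a plain-Python reference (hypothetical function names, not flashinfer's API); the `math::log2e` factor in the CUDA code exists only because the kernel exponentiates with `exp2` rather than `exp`, so it is omitted here.

```python
import math

def soft_cap_logits(score, sm_scale, logits_soft_cap):
    """Logits soft cap: QueryTransform divides the scaled score by the cap,
    LogitsTransform multiplies the cap back after tanh, so the combined
    effect is cap * tanh(score * sm_scale / cap), bounded in (-cap, cap)."""
    scaled = score * sm_scale / logits_soft_cap
    return logits_soft_cap * math.tanh(scaled)

def alibi_logits(score, sm_scale, slope, qo_idx, kv_idx):
    """ALiBi: add a per-head linear bias proportional to the
    key/query position distance to the scaled score."""
    return score * sm_scale + slope * (kv_idx - qo_idx)
```

Splitting each variant into `QueryTransform` (applied once per query) and `LogitsTransform` (applied per logit) lets the kernel hoist the cheap per-query work out of the inner loop.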
# Roadmap
After this PR, we will add support for:
1. PyPI wheels: #153
2. FP8 tensor-core attention: #502
3. Different head dimensions: #142 #454 #455
4. FlashAttention-3: #369
5. Multi-head latent attention: #237
6. Generating `ParamsT` and attention variant descriptions from a Python DSL
The development of these features has been blocked by the wheel size
limit (a binary size >= 2 GB triggers linking issues during the build).
I hope this PR makes development easier in the future.