PALU MLRD (Feature) #4
Open
This PR implements PALU on top of the existing XFormers (CLA) attention backend decode and prefill kernels.
Our implementation follows Figure 2 from the paper and uses the MLRD (multi-head low-rank decomposition) variant, which is easier to integrate with the existing paged attention kernels. For example,
Grid: (num_heads, num_seqs, max_num_partitions)
- is the launch configuration for the paged attention kernel, meaning that each thread block (indexed by `blockIdx.x`) handles a single head, so during up projection it is easy to work with one head at a time. The kernel implementations below are responsible only for the `(QK^T) @ V` portion of the computation; the fused output projection is handled at the model layer.
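For intuition, here is a minimal PyTorch reference of the per-(sequence, head) work this grid maps to. Shapes are illustrative assumptions, and partitioning along `max_num_partitions` is ignored:

```python
import torch

# Reference of the per-(sequence, head) computation the decode kernel owns:
# only (QK^T) @ V, with V kept in its compressed (PALU) form. The fused
# output projection is applied later at the model layer. Shapes below are
# illustrative assumptions, not the real kernel's layout.
num_seqs, num_heads, seq_len = 2, 8, 16
head_size, palu_head_size = 128, 64

q = torch.randn(num_seqs, num_heads, head_size)                    # current token's query
k = torch.randn(num_seqs, num_heads, seq_len, head_size)           # keys after up projection
v_low = torch.randn(num_seqs, num_heads, seq_len, palu_head_size)  # compressed cached values

out = torch.empty(num_seqs, num_heads, palu_head_size)
for s in range(num_seqs):        # maps to grid dim 1 (num_seqs)
    for h in range(num_heads):   # maps to grid dim 0 (num_heads), i.e. blockIdx.x
        scores = torch.softmax(q[s, h] @ k[s, h].T / head_size**0.5, dim=-1)
        out[s, h] = scores @ v_low[s, h]  # (QK^T) @ V only; O_proj happens later
```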
- query - keeps the original `head_size` without compression, since it is computed fresh every time and never cached.
- key - down-projected by the fused `Kd_proj` at the model layer before caching. Inside the attention kernel it is up-projected on the fly and RoPE is applied.
- value - similar to the key, it is down-projected by the fused `Vd_proj` at the model layer before caching, but it does not require an up projection inside the kernel, since we use a fused output projection layer `O_proj` at the model layer.
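A minimal PyTorch sketch of this data flow at the model layer (the per-head shapes and einsum layout are assumptions for illustration, not the actual fused modules):

```python
import torch

# Illustrative MLRD data flow at the model layer (shapes are assumptions).
num_heads, head_size, palu_head_size = 8, 128, 64
hidden = num_heads * head_size

x = torch.randn(1, hidden)  # one token's hidden state

Kd_proj = torch.randn(num_heads, hidden, palu_head_size)  # fused per-head down projections
Vd_proj = torch.randn(num_heads, hidden, palu_head_size)
K_up = torch.randn(num_heads, palu_head_size, head_size)  # applied inside the attention kernel

# What gets written to the paged KV cache: compressed per-head keys and values.
k_low = torch.einsum('bh,nhr->bnr', x, Kd_proj)  # [1, num_heads, palu_head_size]
v_low = torch.einsum('bh,nhr->bnr', x, Vd_proj)

# Keys are up-projected on the fly inside the kernel (RoPE would be applied there too) ...
k_full = torch.einsum('bnr,nrd->bnd', k_low, K_up)  # [1, num_heads, head_size]

# ... while values stay compressed: their up projection is absorbed into O_proj,
# so the kernel output (attention @ v_low) is projected directly at the model layer.
O_proj = torch.randn(num_heads * palu_head_size, hidden)
attn_out = torch.randn(1, num_heads, palu_head_size)  # stand-in for the kernel output
y = attn_out.reshape(1, -1) @ O_proj                  # [1, hidden]
```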
1) PALU Paged Attention Decode CUDA Kernel
Implemented `csrc/attention/attention_kernels_palu.cu` based on `csrc/attention/attention_kernels.cu`. Followed the implementation details from the docs.
Only `BLOCK_SIZE=32` is supported (this is the paged attention block size, not the CUDA grid!) so that it equals `WARP_SIZE`, which ensures `THREAD_GROUP_SIZE=1`. In that case each thread processes all the elements of one key token of one head at a given time, so we can up-project a single key token of a given head with a single thread. This also keeps the implementation simple and avoids synchronization across multiple threads during the dot product of the up projection. Added initial tests in a notebook, which currently fail; a reference sketch of the per-thread computation follows the TODO list below.
TODO:
- Fix the implementation and pass the tests.
- Add RoPE.
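As a reference for those tests, here is a small PyTorch sketch of what a single kernel thread computes when `THREAD_GROUP_SIZE=1`: it owns one cached (compressed) key token of one head, up-projects it, and takes the dot product with the query. Names and shapes are assumptions; RoPE is omitted:

```python
import torch

# Per-thread reference when THREAD_GROUP_SIZE == 1 (illustrative shapes).
head_size, palu_head_size = 128, 64

q = torch.randn(head_size)                     # query for this head (uncompressed)
k_low = torch.randn(palu_head_size)            # one cached key token (compressed)
K_up = torch.randn(palu_head_size, head_size)  # this head's up-projection matrix

# Up-project on the fly, then take the QK dot product. Because one thread owns
# the whole token, no cross-thread synchronization is needed for these dot products.
k_full = k_low @ K_up
qk = torch.dot(q, k_full)  # this token's contribution to the QK^T row
```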
2) PALU Paged Attention Prefill Triton Kernel
TODO.
3) Remaining changes required at a higher level
Such as handling paged attention KV cache allocation based on `palu_head_size`, which can be passed as a config param, along with other model-related code changes as needed.
TODO.
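A rough sketch of the kind of change meant here (the config name is taken from above, but the cache sizing below is an assumption for illustration, not the final vLLM API):

```python
# Sketch: size paged KV cache blocks with the compressed head size instead of
# the full head size. The exact cache layout is an illustrative assumption.
def palu_kv_block_numel(block_size: int, num_kv_heads: int,
                        palu_head_size: int) -> int:
    # One block stores `block_size` tokens of compressed keys plus values.
    return 2 * block_size * num_kv_heads * palu_head_size

# Example: 32-token blocks, 8 KV heads, keys/values compressed to rank 64.
print(palu_kv_block_numel(block_size=32, num_kv_heads=8, palu_head_size=64))
```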