PALU MLRD (Feature) #4

Open · wants to merge 8 commits into torchao

Conversation

@KeremTurgutlu commented Oct 17, 2024

This PR implements PALU based on the existing XFormers (CLA) attention backend decode and prefill kernels:

[Screenshot: 2024-10-18 at 6:01:18 PM]

Our implementation follows Figure 2 from the paper and uses its MLRD (multi-head low-rank decomposition) variant, which is easier to integrate with the existing paged attention kernels. For example, the paged attention kernel is launched with Grid: (num_heads, num_seqs, max_num_partitions), so each block in the x dimension corresponds to a single head, which makes it easier to work with a single head during the up projection.
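
For intuition, here is a minimal sketch of the per-head low-rank decomposition that MLRD refers to, assuming a truncated SVD of each head's key projection weight. The sizes, the rank, and the names (`B_k`, `A_k`) are illustrative and not taken from this PR.

```python
import torch

# Illustrative sizes, not taken from this PR.
hidden_size, num_heads, head_size, rank = 4096, 32, 128, 64

# Original per-head key projection weight: (num_heads, head_size, hidden_size).
W_k = torch.randn(num_heads, head_size, hidden_size)

# MLRD idea: factor each head's W_k into an up matrix B_k (head_size x rank) and a
# down matrix A_k (rank x hidden_size) via truncated SVD, so W_k ~= B_k @ A_k.
U, S, Vh = torch.linalg.svd(W_k, full_matrices=False)
B_k = U[..., :rank] * S[..., None, :rank]   # (num_heads, head_size, rank)   -> up projection (in-kernel)
A_k = Vh[..., :rank, :]                     # (num_heads, rank, hidden_size) -> fused down projection (Kd_proj)

x = torch.randn(1, hidden_size)                      # one token's hidden state
k_low = torch.einsum('hrd,td->thr', A_k, x)          # down-projected key, this is what gets cached
k_full = torch.einsum('her,thr->the', B_k, k_low)    # up-projected on the fly at attention time

# B_k @ A_k is only the rank-r approximation of W_k; for random weights the
# truncation error is large, while real projection weights are closer to low rank.
k_ref = torch.einsum('hed,td->the', W_k, x)
print((k_full - k_ref).abs().max())
```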

The kernel implementations below are responsible only for the (QK^T) @ V portion of the computation; the fused output projection is handled at the model layer.

query - Keeps the original head_size without compression, since it is computed every time rather than cached.

key - Down projected by the fused Kd_proj at the model layer before caching. Inside the attention kernel it is up projected on the fly and RoPE is applied.

value - Like the key, the value is down projected by the fused Vd_proj at the model layer before caching, but it does not require an up projection inside the kernel, since the up projection is absorbed into the fused output projection layer O_proj at the model layer.
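
To make the query/key/value handling above concrete, below is a rough PyTorch reference of the intended model-layer data flow. Kd_proj, Vd_proj, and O_proj follow the description above; the rank value, module shapes, and the exact way O_proj absorbs the value up projection are my assumptions, and RoPE and the causal mask are omitted.

```python
import torch

# Illustrative sizes; palu_rank plays the role of the compressed "palu_head_size".
hidden, num_heads, head_size, palu_rank, seq_len = 1024, 8, 128, 64, 16

q_proj  = torch.nn.Linear(hidden, num_heads * head_size, bias=False)  # queries keep the full head_size
kd_proj = torch.nn.Linear(hidden, num_heads * palu_rank, bias=False)  # fused key down projection, output is cached
vd_proj = torch.nn.Linear(hidden, num_heads * palu_rank, bias=False)  # fused value down projection, output is cached
k_up    = torch.randn(num_heads, head_size, palu_rank)                # per-head key up projection, applied in-kernel
o_proj  = torch.nn.Linear(num_heads * palu_rank, hidden, bias=False)  # output projection absorbing the value up projection

x = torch.randn(seq_len, hidden)
q = q_proj(x).view(seq_len, num_heads, head_size)
k_low = kd_proj(x).view(seq_len, num_heads, palu_rank)   # written to the KV cache
v_low = vd_proj(x).view(seq_len, num_heads, palu_rank)   # written to the KV cache

# Reference of what the attention kernel computes per head (RoPE and causal mask omitted):
k = torch.einsum('her,shr->she', k_up, k_low)                            # keys up-projected on the fly
attn = torch.softmax(torch.einsum('qhe,she->hqs', q, k) / head_size ** 0.5, dim=-1)
out_low = torch.einsum('hqs,shr->qhr', attn, v_low)                      # (QK^T) @ V stays in the low-rank value space
out = o_proj(out_low.reshape(seq_len, num_heads * palu_rank))            # fused O_proj restores the hidden size
```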

1) PALU Paged Attention Decode CUDA Kernel

  • Implemented csrc/attention/attention_kernels_palu.cu based on csrc/attention/attention_kernels.cu.

  • Followed implementation details from the docs.

  • Only BLOCK_SIZE=32 is supported (this is the paged attention block size, not the CUDA grid!), which makes it equal to WARP_SIZE and ensures THREAD_GROUP_SIZE=1, in which case each thread processes all the elements of one key token of one head at a given time. This way we can up project the elements of a single key token of a given head using one thread, which keeps the implementation simple and avoids synchronizing across multiple threads during the dot product of the up projection (see the sketch after this list).

  • Added initial tests in a notebook which currently fail.

  • Fix implementation and pass the tests.

  • Add RoPE.
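
As a sanity reference for that per-thread up projection, here is a small PyTorch sketch of what one thread computes for one key token of one head when THREAD_GROUP_SIZE=1; shapes and names are illustrative, not the kernel's.

```python
import torch

# Emulates what a single thread does for one key token of one head when
# THREAD_GROUP_SIZE == 1: up-project the cached low-rank key, then take the
# dot product with the query.
head_size, palu_rank = 128, 64

q_head   = torch.randn(head_size)             # this head's query (full head_size)
k_cached = torch.randn(palu_rank)             # low-rank key token read from the paged KV cache
k_up     = torch.randn(head_size, palu_rank)  # this head's up-projection matrix

k_token = k_up @ k_cached                                 # on-the-fly up projection
score = torch.dot(q_head, k_token) / head_size ** 0.5     # this token's contribution to QK^T

# Equivalent reformulation: fold the up projection into the query once per head,
# so the per-token work stays a rank-sized dot product.
q_up = k_up.t() @ q_head
assert torch.allclose(torch.dot(q_up, k_cached) / head_size ** 0.5, score, rtol=1e-4, atol=1e-4)
```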

2) PALU Paged Attention Prefill Triton Kernel

TODO.

3) Remaining changes required at a higher level:

This includes handling the paged attention KV cache allocation based on palu_head_size, which can be passed as a config param, plus other model-related code changes as needed.

TODO.
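
For illustration only, a rough sketch of how a cache-allocation helper could size the PALU KV cache from a palu_head_size config value; the function, its arguments, and the simplified cache layout are hypothetical and not vLLM's actual cache API.

```python
import torch

def allocate_palu_kv_cache(num_blocks: int, block_size: int, num_heads: int,
                           palu_head_size: int, dtype=torch.float16, device="cuda"):
    """Hypothetical helper (not vLLM's real cache API): size the paged KV cache
    with the compressed palu_head_size instead of the full head_size, so each
    block stores the down-projected keys/values produced by Kd_proj / Vd_proj."""
    key_cache = torch.empty(num_blocks, num_heads, palu_head_size, block_size,
                            dtype=dtype, device=device)
    value_cache = torch.empty(num_blocks, num_heads, palu_head_size, block_size,
                              dtype=dtype, device=device)
    return key_cache, value_cache

# Example (block_size must be 32 for the decode kernel in this PR):
# key_cache, value_cache = allocate_palu_kv_cache(
#     num_blocks=256, block_size=32, num_heads=32, palu_head_size=64)
```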

@KeremTurgutlu KeremTurgutlu changed the base branch from kv_cache_sharing to torchao October 18, 2024 14:50
@KeremTurgutlu KeremTurgutlu changed the title PALU (Feature) PALU MLRD (Feature) Oct 18, 2024