Currently the attention kernel does not work correctly in certain special cases, for example a single-token query attending over a longer key/value sequence, as with the following shapes:
q.shape = torch.Size([4, 32, 1, 128])
k.shape = torch.Size([4, 32, 20, 128])
v.shape = torch.Size([4, 32, 20, 128])
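For reference, the expected result for these shapes can be sketched with a plain NumPy scaled-dot-product attention (a hypothetical baseline for comparison, not the Triton kernel itself; dimension names below are assumptions):

```python
import numpy as np

# Shapes from the report: a single-token query (decode step) against a
# 20-token key/value sequence, which the Triton kernel mishandles.
B, H, Lq, Lkv, D = 4, 32, 1, 20, 128

rng = np.random.default_rng(0)
q = rng.standard_normal((B, H, Lq, D)).astype(np.float32)
k = rng.standard_normal((B, H, Lkv, D)).astype(np.float32)
v = rng.standard_normal((B, H, Lkv, D)).astype(np.float32)

# Standard scaled dot-product attention: softmax(q @ k^T / sqrt(D)) @ v
scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(D)   # (B, H, Lq, Lkv)
scores -= scores.max(axis=-1, keepdims=True)        # numerical stability
weights = np.exp(scores)
weights /= weights.sum(axis=-1, keepdims=True)
out = weights @ v                                   # (B, H, Lq, D)

print(out.shape)
```

A correct kernel should produce an output of shape `(4, 32, 1, 128)` matching this reference for these inputs.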
https://github.com/triton-lang/kernels/blob/main/kernels/flash_attention.py#L23
Repro steps:
- add the Triton attention kernel here: https://github.com/triton-lang/kernels/blob/main/models/llama/llama/math_ops.py#L64
- run:
CUDA_LAUNCH_BLOCKING=1 python3.9 -m main llama_chat_completion --profile=False --benchmark=False --ckpt_dir="models/llama/meta-llama/Meta-Llama-3-8B-Instruct/original" --tokenizer_path="models/llama/meta-llama/Meta-Llama-3-8B-Instruct/original/tokenizer.model" --use_triton=True