attention kernel does not work for sequence length = 1 #6

@adamomainz

Description

Currently the attention kernel breaks when the query sequence length is 1 (the decode case). An example is the following shapes:
q.shape = torch.Size([4, 32, 1, 128])
k.shape = torch.Size([4, 32, 20, 128])
v.shape = torch.Size([4, 32, 20, 128])

https://github.com/triton-lang/kernels/blob/main/kernels/flash_attention.py#L23
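For context, here is a minimal NumPy sketch of what the kernel should compute for these shapes: plain softmax attention with a single query position attending over 20 cached key/value positions. This is an illustration of the expected math, not the Triton kernel code, and `attention_reference` is a hypothetical name used only for this example:

```python
import numpy as np

def attention_reference(q, k, v):
    """Plain softmax attention; a NumPy reference for the q_len == 1 case.

    q:    (batch, heads, q_len,  head_dim)
    k, v: (batch, heads, kv_len, head_dim)
    """
    d = q.shape[-1]
    # (batch, heads, q_len, kv_len) score matrix
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d)
    # numerically stable softmax over the key dimension
    scores -= scores.max(axis=-1, keepdims=True)
    p = np.exp(scores)
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v  # (batch, heads, q_len, head_dim)

# Shapes from the report: q_len = 1, kv_len = 20
rng = np.random.default_rng(0)
q = rng.standard_normal((4, 32, 1, 128))
k = rng.standard_normal((4, 32, 20, 128))
v = rng.standard_normal((4, 32, 20, 128))
out = attention_reference(q, k, v)
assert out.shape == (4, 32, 1, 128)
```

A fixed kernel should match this reference (up to floating-point tolerance) for these shapes.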

Repro steps:

  1. Add the Triton kernel here: https://github.com/triton-lang/kernels/blob/main/models/llama/llama/math_ops.py#L64
  2. Run:
    CUDA_LAUNCH_BLOCKING=1 python3.9 -m main llama_chat_completion --profile=False --benchmark=False --ckpt_dir="models/llama/meta-llama/Meta-Llama-3-8B-Instruct/original" --tokenizer_path="models/llama/meta-llama/Meta-Llama-3-8B-Instruct/original/tokenizer.model" --use_triton=True
