[executorch][llama] support mqa #3080

Closed · wants to merge 2 commits
14 changes: 14 additions & 0 deletions examples/models/llama2/custom_ops/TARGETS
@@ -1,8 +1,22 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()

runtime.python_test(
name = "test_sdpa_with_kv_cache",
srcs = [
"test_sdpa_with_kv_cache.py",
],
preload_deps = [
":custom_ops_aot_lib",
],
deps = [
"//caffe2:torch",
],
)
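
The new python_test preloads ":custom_ops_aot_lib" so that torch.ops.llama.sdpa_with_kv_cache resolves when the test runs under Buck. Outside of Buck, a minimal sketch of the equivalent setup, assuming a locally built AOT library (the shared-library path below is hypothetical and depends on your build output):

import torch

# Hypothetical path to the AOT custom-op library; Buck's preload_deps
# performs the equivalent loading step automatically for the test target.
torch.ops.load_library(
    "cmake-out/examples/models/llama2/custom_ops/libcustom_ops_aot_lib.so"
)

# Once loaded, the custom op is reachable under the llama namespace:
print(torch.ops.llama.sdpa_with_kv_cache)
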
21 changes: 19 additions & 2 deletions examples/models/llama2/custom_ops/op_sdpa.cpp
@@ -219,13 +219,29 @@ void cpu_flash_attention(
int64_t qSize = query.size(2);
int64_t headSize = query.size(3);
int64_t kvSize = value.size(2);
int64_t num_heads_kv = key.size(1);

if (is_with_kv_cache) {
num_head = query.size(2);
num_heads_kv = key.size(2);
qSize = query.size(1);
kvSize = value.size(1);
}

ET_CHECK_MSG(
num_heads_kv <= num_head,
"FlashAttention does not support num kv heads > num query heads.Got num query heads=%" PRId64
" num key heads:%" PRId64,
num_head,
num_heads_kv);
ET_CHECK_MSG(
num_head % num_heads_kv == 0,
"FlashAttention: num qyery heads must be divisible by num kv heads but got num query heads=%" PRId64
" and num kv heads=%" PRId64,
num_head,
num_heads_kv);
int64_t num_reps = num_head / num_heads_kv;

bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
if (has_attn_mask) {
/*
@@ -365,6 +381,7 @@ void cpu_flash_attention(
fill_stub(
qk_max_data, -std::numeric_limits<accum_t>::infinity(), qBlockSize);
int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
auto j_kv = j / num_reps;
for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
// Calculate scale * q @ k.T
@@ -376,7 +393,7 @@
qBlockSize,
headSize,
static_cast<accum_t>(1),
k_data + i * kStrideB + j * kStrideH + n * kStrideN,
k_data + i * kStrideB + j_kv * kStrideH + n * kStrideN,
kStrideN,
q_data + i * qStrideB + j * qStrideH + m * qStrideM,
qStrideM,
@@ -460,7 +477,7 @@ void cpu_flash_attention(
qBlockSize,
kvBlockSize,
static_cast<accum_t>(1),
v_data + i * vStrideB + j * vStrideH + n * vStrideN,
v_data + i * vStrideB + j_kv * vStrideH + n * vStrideN,
vStrideN,
conditional_data_ptr(qk_data, qk_reduced_data),
kvBlockSize,
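
The kernel changes above add multi-query/grouped-query support without materializing repeated K and V: each query head j reads kv head j_kv = j / num_reps in the strided k_data and v_data accesses. A minimal Python sketch of that index arithmetic, with assumed example shapes (this mirrors the mapping; it is not the C++ kernel):

import torch

num_heads_q, num_heads_kv = 8, 4        # assumed example head counts
num_reps = num_heads_q // num_heads_kv  # mirrors num_reps in cpu_flash_attention

k = torch.rand(num_heads_kv, 16, 4)     # (kv heads, kv seq len, head dim)

# Reading kv head j // num_reps for every query head j ...
gathered = torch.stack([k[j // num_reps] for j in range(num_heads_q)])

# ... matches repeating each kv head num_reps times, which is what the
# Python reference in the new test does with repeat_interleave.
repeated = k.repeat_interleave(num_reps, dim=0)
assert torch.equal(gathered, repeated)
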
203 changes: 203 additions & 0 deletions examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
@@ -0,0 +1,203 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torch.nn.functional as F


class SDPATest(unittest.TestCase):

def setUp(self):
torch.manual_seed(42)
self.k_cache = torch.zeros((1, 5, 8, 4))
self.v_cache = torch.zeros((1, 5, 8, 4))
self.mask = torch.full(
(5, 5),
float("-inf"),
)
self.mask = torch.triu(self.mask, diagonal=1)

def _sdpa_with_kv_cache_ref(self, q, k, v, k_cache, v_cache, mask, start_pos):
print(f"at start_pos:{start_pos}")
print(q)
print(k)
print(v)
attn_mask = mask[start_pos].view((1, -1))
attn_mask = attn_mask[:, : start_pos + 1]
q = q.transpose(1, 2)
k_cache[:, start_pos] = k
v_cache[:, start_pos] = v
sliced_k_cache = k_cache[:, : start_pos + 1, :, :]
sliced_v_cache = v_cache[:, : start_pos + 1, :, :]
sliced_k_cache = sliced_k_cache.transpose(1, 2)
sliced_v_cache = sliced_v_cache.transpose(1, 2)
out = F.scaled_dot_product_attention(
q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask
)
out = out.transpose(1, 2)
print(out)
print(f"-------- start pos {start_pos} done -----")
return out

def test_sdpa_with_cache_no_mqa_1(self):
q = torch.rand((1, 1, 8, 4))
k = torch.rand((1, 1, 8, 4))
v = torch.rand((1, 1, 8, 4))
ref_output = self._sdpa_with_kv_cache_ref(
q, k, v, self.k_cache, self.v_cache, self.mask, 0
)
op_output = torch.ops.llama.sdpa_with_kv_cache(
q, k, v, self.k_cache, self.v_cache, 0, 1, None, 0, False
)
self.assertTrue(torch.allclose(ref_output, op_output))

def test_sdpa_with_cache_no_mqa_2(self):
q = torch.rand((1, 1, 8, 4))
k = torch.rand((1, 1, 8, 4))
v = torch.rand((1, 1, 8, 4))

ref_output = self._sdpa_with_kv_cache_ref(
q, k, v, self.k_cache, self.v_cache, self.mask, 1
)
op_output = torch.ops.llama.sdpa_with_kv_cache(
q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
)
self.assertTrue(torch.allclose(ref_output, op_output))

def test_sdpa_with_cache_no_mqa_3(self):
q = torch.rand((1, 1, 8, 4))
k = torch.rand((1, 1, 8, 4))
v = torch.rand((1, 1, 8, 4))

ref_output = self._sdpa_with_kv_cache_ref(
q, k, v, self.k_cache, self.v_cache, self.mask, 2
)
op_output = torch.ops.llama.sdpa_with_kv_cache(
q, k, v, self.k_cache, self.v_cache, 2, 1, None, 0, False
)
self.assertTrue(torch.allclose(ref_output, op_output))

def test_sdpa_with_cache_no_mqa_4(self):
q = torch.rand((1, 1, 8, 4))
k = torch.rand((1, 1, 8, 4))
v = torch.rand((1, 1, 8, 4))

ref_output = self._sdpa_with_kv_cache_ref(
q, k, v, self.k_cache, self.v_cache, self.mask, 3
)
op_output = torch.ops.llama.sdpa_with_kv_cache(
q, k, v, self.k_cache, self.v_cache, 3, 1, None, 0, False
)
self.assertTrue(torch.allclose(ref_output, op_output))


class SDPATestWithMQA(unittest.TestCase):

def setup_caches(self):
self.k_cache = torch.zeros((1, 5, self.n_heads_kv, 4))
self.v_cache = torch.zeros((1, 5, self.n_heads_kv, 4))

def setUp(self):
torch.manual_seed(42)
self.n_heads_kv = 4
self.n_heads_q = 8
self.setup_caches()
self.mask = torch.full(
(5, 5),
float("-inf"),
)
self.mask = torch.triu(self.mask, diagonal=1)

def _sdpa_with_kv_cache_ref(self, q, k, v, k_cache, v_cache, mask, start_pos):
print(f"at start_pos:{start_pos}")
print(q)
print(k)
print(v)
attn_mask = mask[start_pos].view((1, -1))
attn_mask = attn_mask[:, : start_pos + 1]
q = q.transpose(1, 2)
k_cache[:, start_pos] = k
v_cache[:, start_pos] = v
sliced_k_cache = k_cache[:, : start_pos + 1, :, :]
sliced_v_cache = v_cache[:, : start_pos + 1, :, :]
sliced_k_cache = sliced_k_cache.transpose(1, 2)
sliced_v_cache = sliced_v_cache.transpose(1, 2)
num_heads_q = q.size(1)
num_heads_kv = sliced_k_cache.size(1)
if num_heads_q != num_heads_kv:
assert (
num_heads_q % num_heads_kv == 0
), f"{num_heads_q} not divisible by {num_heads_kv}"
n_reps = num_heads_q // num_heads_kv
if n_reps > 1:
sliced_k_cache = sliced_k_cache.repeat_interleave(n_reps, dim=1)
sliced_v_cache = sliced_v_cache.repeat_interleave(n_reps, dim=1)
out = F.scaled_dot_product_attention(
q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask
)
out = out.transpose(1, 2)
print(out)
print(f"-------- start pos {start_pos} done -----")
return out

def test_sdpa_with_cache_mqa_1(self):
q = torch.rand((1, 1, self.n_heads_q, 4))
k = torch.rand((1, 1, self.n_heads_kv, 4))
v = torch.rand((1, 1, self.n_heads_kv, 4))
ref_output = self._sdpa_with_kv_cache_ref(
q, k, v, self.k_cache, self.v_cache, self.mask, 0
)
op_output = torch.ops.llama.sdpa_with_kv_cache(
q, k, v, self.k_cache, self.v_cache, 0, 1, None, 0, False
)
self.assertTrue(torch.allclose(ref_output, op_output))

def test_sdpa_with_cache_mqa_2(self):
q = torch.rand((1, 1, self.n_heads_q, 4))
k = torch.rand((1, 1, self.n_heads_kv, 4))
v = torch.rand((1, 1, self.n_heads_kv, 4))
ref_output = self._sdpa_with_kv_cache_ref(
q, k, v, self.k_cache, self.v_cache, self.mask, 1
)
op_output = torch.ops.llama.sdpa_with_kv_cache(
q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
)
self.assertTrue(torch.allclose(ref_output, op_output))

def test_sdpa_with_cache_mqa_3(self):
self.n_heads_q = 14
self.n_heads_kv = 7
self.setup_caches()
q = torch.rand((1, 1, self.n_heads_q, 4))
k = torch.rand((1, 1, self.n_heads_kv, 4))
v = torch.rand((1, 1, self.n_heads_kv, 4))
ref_output = self._sdpa_with_kv_cache_ref(
q, k, v, self.k_cache, self.v_cache, self.mask, 1
)
op_output = torch.ops.llama.sdpa_with_kv_cache(
q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
)
self.assertTrue(torch.allclose(ref_output, op_output))
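
For reference, a hedged sketch of calling the op across several decode steps, mirroring the positional arguments used in the tests above (the trailing None, 0, False are passed exactly as the tests pass them; their semantics are not documented in this diff, and the custom-op library must already be loaded as sketched after the TARGETS change):

import torch

# Shapes follow SDPATest.setUp: (batch, max seq len, kv heads, head dim).
k_cache = torch.zeros((1, 5, 8, 4))
v_cache = torch.zeros((1, 5, 8, 4))

for start_pos in range(4):
    q = torch.rand((1, 1, 8, 4))
    k = torch.rand((1, 1, 8, 4))
    v = torch.rand((1, 1, 8, 4))
    out = torch.ops.llama.sdpa_with_kv_cache(
        q, k, v, k_cache, v_cache, start_pos, 1, None, 0, False
    )
    # As in the tests, out is compared against a (batch, seq, query heads,
    # head dim) reference, i.e. the same layout as q.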