Integrate ragged paged attention v2 #8791


Merged · 3 commits · Mar 5, 2025
207 changes: 85 additions & 122 deletions test/test_pallas.py
@@ -87,8 +87,9 @@ def _pagedattention_generate_qkv(
q = torch.randn(batch_size, query_len, num_heads, head_dim, dtype=dtype)
return q, k_pages, v_pages, page_indices

def _round_up_closest_multiple_of(self, x, base):
return (x + base - 1) // base * base
def _ceil_div(self, a, b):
assert b != 0
return (a + b - 1) // b
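Note: the new _ceil_div helper replaces _round_up_closest_multiple_of; rounding up to a multiple of a block size is now expressed as a ceil-division followed by a multiply (see pages_per_seq further down). A minimal standalone sketch with illustrative numbers only, not taken from the PR:

    # Ceiling division: smallest integer >= a / b.
    def ceil_div(a, b):
        assert b != 0
        return (a + b - 1) // b

    # Round a page count up to a multiple of the KV block size.
    # e.g. kv_len=1328, page_size=16 -> 83 pages; with 16-page blocks -> 96.
    pages_per_seq = ceil_div(1328, 16)                 # 83
    pages_per_seq = ceil_div(pages_per_seq, 16) * 16   # 96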

def _ragged_pagedattention_generate_qkv(
self,
@@ -97,64 +98,50 @@ def _ragged_pagedattention_generate_qkv(
head_dim,
page_size,
num_pages,
dtype=torch.float32,
num_queries_per_block=None,
pad_num_q_tokens=False,
dtype,
*,
num_kv_pages_per_block=None,
max_num_batched_tokens=None,
max_num_seqs=16,
):
num_seqs = len(seq_lens)
# Make sure the q_len is no longer than the kv_len. For example,
# seq_lens = [(1, 1328), (5, 18), (506, 463)] is not a valid test case because
# the 3rd sequence has q_len(506) > kv_len(463).
for i in range(num_seqs):
cur_q_len = seq_lens[i][0]
cur_kv_len = seq_lens[i][1]
assert cur_q_len <= cur_kv_len, f"cur_q_len must be less than or equal to cur_kv_len. Got {cur_q_len} and {cur_kv_len}"

query_lens = [seq_len[0] for seq_len in seq_lens]
actual_num_q_tokens = sum(query_lens)
num_q_tokens = self._round_up_closest_multiple_of(
actual_num_q_tokens,
num_queries_per_block) if pad_num_q_tokens else actual_num_q_tokens
kv_lens = torch.tensor([seq_len[1] for seq_len in seq_lens],
dtype=torch.int32)
num_q_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0."
queries = torch.randn((num_q_tokens, num_q_heads, head_dim), dtype=dtype)
k_pages = torch.randn((num_kv_heads, num_pages, page_size, head_dim),
cu_q_lens = [0]
kv_lens = []
for q_len, kv_len in seq_lens:
assert q_len <= kv_len
cu_q_lens.append(cu_q_lens[-1] + q_len)
kv_lens.append(kv_len)

if max_num_batched_tokens is None:
max_num_batched_tokens = cu_q_lens[-1]
else:
max_num_batched_tokens = max(cu_q_lens[-1], max_num_batched_tokens)
if max_num_seqs is None:
max_num_seqs = len(seq_lens)
else:
max_num_seqs = max(len(seq_lens), max_num_seqs)
max_kv_len = max(kv_lens)
pages_per_seq = self._ceil_div(max_kv_len, page_size)
pages_per_seq = (
self._ceil_div(pages_per_seq, num_kv_pages_per_block) *
num_kv_pages_per_block)

num_q_heads, num_kv_heads = num_heads
cu_q_lens = torch.tensor(cu_q_lens, dtype=torch.int32)
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
cu_q_lens = torch.nn.functional.pad(
cu_q_lens, (0, max_num_seqs + 1 - cu_q_lens.shape[0]), "constant", 0)
kv_lens = torch.nn.functional.pad(kv_lens,
(0, max_num_seqs - kv_lens.shape[0]),
"constant", 0)
q = torch.randn((max_num_batched_tokens, num_q_heads, head_dim),
dtype=dtype)
k_pages = torch.randn((num_pages, page_size, num_kv_heads, head_dim),
dtype=dtype)
v_pages = torch.randn((num_kv_heads, num_pages, page_size, head_dim),
v_pages = torch.randn((num_pages, page_size, num_kv_heads, head_dim),
dtype=dtype)

# Create a kv_lens: i32[num_tokens]
kv_lens_with_paddings = [0] * num_q_tokens
for i in range(num_seqs):
kv_lens_with_paddings[i] = kv_lens[i]
kv_lens_ = torch.tensor(kv_lens_with_paddings, dtype=torch.int32)

# Create a page_indices i32[num_tokens, pages_per_sequence]
max_kv_len = max([seq_len[1] for seq_len in seq_lens])
max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size

# The reason why we need to pad max_num_pages_per_seq is that
# page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0
max_num_pages_per_seq = 2**int(np.ceil(np.log2(max_num_pages_per_seq)))

# The assert below mimics the reality that each page get a unique index.
# But for testing, the assert could be omitted.
# assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}"
page_indices = torch.randint(
0, num_pages, (num_q_tokens, max_num_pages_per_seq), dtype=torch.int32)

# Create a cu_q_lens i32[num_tokens + 1]
q_lens_with_paddings = [0] * num_q_tokens
for i in range(num_seqs):
q_lens_with_paddings[i] = query_lens[i]
cu_q_lens = torch.cumsum(
torch.tensor([0] + q_lens_with_paddings, dtype=torch.int32),
dim=0,
dtype=torch.int32)
return queries, k_pages, v_pages, page_indices, cu_q_lens, kv_lens_
0, num_pages, (max_num_seqs, pages_per_seq), dtype=torch.int32)
return q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens
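Note: with v2, the KV cache pages switch to a [num_pages, page_size, num_kv_heads, head_dim] layout, and the metadata tensors are padded to static sizes (max_num_seqs sequences, max_num_batched_tokens query tokens). A self-contained sketch of the padding scheme, with illustrative values that are not part of the test file:

    import torch

    seq_lens = [(1, 1328), (5, 18)]   # (q_len, kv_len) per sequence
    max_num_seqs = 16

    cu_q_lens, kv_lens = [0], []
    for q_len, kv_len in seq_lens:
        cu_q_lens.append(cu_q_lens[-1] + q_len)
        kv_lens.append(kv_len)

    cu_q_lens = torch.tensor(cu_q_lens, dtype=torch.int32)
    kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
    # Pad to fixed shapes so the kernel sees static sizes:
    # cu_q_lens -> [max_num_seqs + 1], kv_lens -> [max_num_seqs].
    cu_q_lens = torch.nn.functional.pad(cu_q_lens, (0, max_num_seqs + 1 - cu_q_lens.shape[0]))
    kv_lens = torch.nn.functional.pad(kv_lens, (0, max_num_seqs - kv_lens.shape[0]))
    print(cu_q_lens)   # tensor([0, 1, 6, 0, ..., 0], dtype=torch.int32)
    print(kv_lens)     # tensor([1328, 18, 0, ..., 0], dtype=torch.int32)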

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tpu_custom_call_pallas_add(self):
@@ -648,7 +635,7 @@ def test_paged_attention_wrapper(self):
"This test only works on TPUv4+.")
def test_ragged_paged_attention_wrapper_without_dynamo(self):
from torch_xla.experimental.custom_kernel import ragged_paged_attention
from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as jax_ragged_paged_attention
from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention

seq_lens = [
(1, 1328),
@@ -663,18 +650,25 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
(1, 17),
(99, 123)
] # last 3 physical q blocks [(q_len, kv_len),...]
num_heads = (4, 4)
num_heads = (32, 8)
head_dim = 128
dtype = torch.float32
page_size = 16
num_pages = 32768
num_seqs = len(seq_lens)
num_kv_pages_per_block = 128
num_kv_pages_per_block = 16
num_queries_per_block = 8
block_kv_size = 256
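Note: num_heads = (32, 8) above now exercises grouped-query attention, where 32 query heads share 8 KV heads; num_q_heads is expected to be a multiple of num_kv_heads. A small illustration of the grouping (a common GQA reference pattern, assumed here for illustration rather than this kernel's internals):

    import torch

    num_q_heads, num_kv_heads, head_dim = 32, 8, 128
    assert num_q_heads % num_kv_heads == 0
    group = num_q_heads // num_kv_heads                  # 4 query heads per KV head
    k = torch.randn(10, num_kv_heads, head_dim)          # [tokens, kv_heads, head_dim]
    k_for_q = torch.repeat_interleave(k, group, dim=1)   # [tokens, 32, 128]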

q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens = self._ragged_pagedattention_generate_qkv(
seq_lens, num_heads, head_dim, page_size, num_pages, dtype=dtype)
seq_lens,
num_heads,
head_dim,
page_size,
num_pages,
dtype,
num_kv_pages_per_block=num_kv_pages_per_block,
max_num_batched_tokens=1024,
max_num_seqs=16)

q_xla = q.to("xla")
k_pages_xla = k_pages.to("xla")
@@ -693,7 +687,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
num_seqs=num_seqs,
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
use_kernel=True)
use_kernel=True)[:cu_q_lens[num_seqs]]

nonkernel_output = ragged_paged_attention(
q_xla,
Expand Down Expand Up @@ -726,7 +720,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
num_seqs=num_seqs,
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
)[1]))
)[:cu_q_lens[num_seqs]]))
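Note: because q is padded up to max_num_batched_tokens rows, each output above is padded as well; only the first cu_q_lens[num_seqs] rows correspond to real query tokens, which is why every result is sliced before comparison. A self-contained sketch of that pattern with illustrative values:

    import torch

    max_num_batched_tokens, num_q_heads, head_dim = 1024, 32, 128
    out = torch.randn(max_num_batched_tokens, num_q_heads, head_dim)  # padded kernel output
    cu_q_lens = torch.tensor([0, 1, 6, 105, 0], dtype=torch.int32)    # example prefix sums
    num_seqs = 3
    valid_out = out[: int(cu_q_lens[num_seqs])]   # only rows 0..104 hold real tokens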

self.assertTrue(
torch.allclose(
@@ -745,19 +739,25 @@ def _verify_ragged_paged_attention_with_dynamo(
dtype,
num_kv_pages_per_block,
num_queries_per_block,
pad_num_q_tokens=False,
pad_tokens_and_seqs=False,
sm_scale=1.0,
):
num_seqs = len(seq_lens)
max_num_batched_tokens = None
max_num_seqs = None
if pad_tokens_and_seqs:
max_num_batched_tokens = 1024
max_num_seqs = 16
q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens = self._ragged_pagedattention_generate_qkv(
seq_lens,
num_heads,
head_dim,
page_size,
num_pages,
dtype=dtype,
num_queries_per_block=num_queries_per_block,
pad_num_q_tokens=pad_num_q_tokens)
dtype,
num_kv_pages_per_block=num_kv_pages_per_block,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs)

q_xla = q.to("xla")
k_pages_xla = k_pages.to("xla")
@@ -766,29 +766,7 @@ def _verify_ragged_paged_attention_with_dynamo(
page_indices_xla = page_indices.to("xla")
cu_q_lens_xla = cu_q_lens.to("xla")

def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
page_indices, cu_q_lens, num_seqs,
num_kv_pages_per_block,
num_queries_per_block, use_kernel,
sm_scale):
return torch.ops.xla.ragged_paged_attention(
q,
k_pages,
v_pages,
kv_lens,
page_indices,
cu_q_lens,
num_seqs,
num_kv_pages_per_block,
num_queries_per_block,
use_kernel=use_kernel,
sm_scale=sm_scale,
)

compiled_paged_attention = torch.compile(
ragged_paged_attention_wrapper, backend="openxla")

kernel_output = compiled_paged_attention(
kernel_output = torch.ops.xla.ragged_paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
@@ -800,9 +778,9 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
num_queries_per_block=num_queries_per_block,
use_kernel=True,
sm_scale=sm_scale,
)
)[:cu_q_lens[num_seqs]]

nonkernel_output = compiled_paged_attention(
nonkernel_output = torch.ops.xla.ragged_paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
@@ -828,7 +806,7 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
cu_q_lens_jax = jnp.array(cu_q_lens.numpy(), dtype=jnp.int32)

from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as jax_ragged_paged_attention
from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention
jax_kernel_output = torch.from_numpy(
np.array(
jax_ragged_paged_attention(
@@ -842,34 +820,19 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
sm_scale=sm_scale,
)[1]))
)[:cu_q_lens[num_seqs]]))
jax_kernel_output_cpu = jax_kernel_output.cpu()

if pad_num_q_tokens:
actual_num_q_tokens = cu_q_lens[num_seqs]
self.assertTrue(
torch.allclose(
kernel_output_cpu[:actual_num_q_tokens],
nonkernel_output_cpu[:actual_num_q_tokens],
atol=2e-2,
rtol=1e-2))
self.assertTrue(
torch.allclose(
kernel_output_cpu[:actual_num_q_tokens],
jax_kernel_output_cpu[:actual_num_q_tokens],
atol=2e-2,
rtol=1e-2))
else:
self.assertTrue(
torch.allclose(
kernel_output_cpu, nonkernel_output_cpu, atol=2e-2, rtol=1e-2))
self.assertTrue(
torch.allclose(
kernel_output_cpu, jax_kernel_output_cpu, atol=2e-2, rtol=1e-2))
self.assertTrue(
torch.allclose(
kernel_output_cpu, nonkernel_output_cpu, atol=2e-2, rtol=1e-2))
self.assertTrue(
torch.allclose(
kernel_output_cpu, jax_kernel_output_cpu, atol=2e-2, rtol=1e-2))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self):
def test_ragged_paged_attention_wrapper_no_padding_with_dynamo(self):
seq_lens = [
(1, 1328),
(5, 18),
@@ -883,7 +846,7 @@ def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self):
(1, 17),
(99, 123)
] # last 3 physical q blocks [(q_len, kv_len),...]
num_heads = (4, 4)
num_heads = (32, 8)
head_dim = 128
dtype = torch.float32
page_size = 16
@@ -897,7 +860,7 @@ def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self):
page_size,
num_pages,
dtype,
num_kv_pages_per_block=128,
num_kv_pages_per_block=16,
num_queries_per_block=8,
sm_scale=sm_scale,
)
@@ -908,12 +871,12 @@ def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(
)
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_ragged_paged_attention_wrapper_with_query_padding_with_dynamo(
def test_ragged_paged_attention_wrapper_with_padding_with_dynamo(
self,
seq_lens,
num_queries_per_block,
):
num_heads = (4, 4)
num_heads = (32, 8)
head_dim = 128
dtype = torch.float32
page_size = 16
@@ -927,9 +890,9 @@ def test_ragged_paged_attention_wrapper_with_query_padding_with_dynamo(
page_size,
num_pages,
dtype,
num_kv_pages_per_block=128,
num_kv_pages_per_block=16,
num_queries_per_block=num_queries_per_block,
pad_num_q_tokens=True,
pad_tokens_and_seqs=True,
sm_scale=sm_scale,
)
