add torch.compile support.
vanbasten23 committed Oct 29, 2024
1 parent 52c0ab0 commit f846f71
Showing 2 changed files with 126 additions and 1 deletion.
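
The gist of the change, before the diff: the commit registers multi_queries_paged_attention as a torch.library custom op so that code calling it can be captured and compiled with the openxla dynamo backend. A minimal sketch of that usage, assuming the block sizes (16 KV pages and 32 queries per compute block) are illustrative and that inputs are built as XLA-device tensors the same way the new test below builds them:

import torch
import torch_xla.experimental.custom_kernel  # importing this module registers the op

def paged_attention_fn(q, k_pages, v_pages, kv_seq_lens, page_indices):
  # Arguments mirror the op schema added in this commit; the concrete block
  # sizes here are illustrative, not prescribed by the commit.
  return torch.ops.xla.multi_queries_paged_attention(
      q, k_pages, v_pages, kv_seq_lens, page_indices,
      16,  # num_kv_pages_per_compute_block
      32,  # num_queries_per_compute_block
      use_kernel=True)

# Compiling with the openxla backend keeps the op in the captured graph instead
# of falling back to eager; the new test below drives exactly this path.
compiled_fn = torch.compile(paged_attention_fn, backend="openxla")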
test/test_pallas.py (96 additions, 0 deletions)
@@ -603,6 +603,17 @@ def test_paged_attention_multi_queries_wrapper(self):
num_queries_per_compute_block=num_queries_per_compute_block,
)

nonkernel_output = multi_queries_paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
kv_seq_lens_xla,
page_indices_xla,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
use_kernel=False,
)

q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
@@ -623,6 +634,91 @@ def test_paged_attention_multi_queries_wrapper(self):
self.assertTrue(
torch.allclose(
output.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5))
self.assertTrue(
torch.allclose(
output.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_paged_attention_multi_queries_wrapper_with_dynamo(self):
from torch_xla.experimental.custom_kernel import multi_queries_paged_attention
from torch_xla.experimental.pallas_kernels.multi_queries_paged_attention_kernel import paged_attention as jax_multi_queries_paged_attention

dtype = torch.float32
page_size = 16
num_kv_heads = 8
q_kv_head_ratio = 4
head_dim = 256
num_queries_per_compute_block = 32
block_kv_size = 256

max_kv_len = 2048
query_len = 64
kv_seq_lens = torch.randint(query_len, max_kv_len, (3,), dtype=torch.int32)
assert query_len <= max_kv_len
for cur_kv_seq in kv_seq_lens:
assert query_len <= cur_kv_seq, f'{query_len} should be less than or equal to the kv_len {cur_kv_seq} in the current sequence.'
batch_size = len(kv_seq_lens)
pages_per_sequence = max_kv_len // page_size
total_num_pages = batch_size * pages_per_sequence
assert max_kv_len <= total_num_pages * page_size

q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
kv_seq_lens,
page_size,
max_kv_len,
num_kv_heads,
num_kv_heads * q_kv_head_ratio,
head_dim,
dtype=dtype,
query_len=query_len,
)

q_xla = q.to("xla")
k_pages_xla = k_pages.to("xla")
v_pages_xla = v_pages.to("xla")
kv_seq_lens_xla = kv_seq_lens.to("xla")
page_indices_xla = page_indices.to("xla")

def multi_queries_paged_attention_wrapper(
    q, k_pages, v_pages, kv_seq_lens, page_indices,
    num_kv_pages_per_compute_block, num_queries_per_compute_block,
    use_kernel):
return torch.ops.xla.multi_queries_paged_attention(
q,
k_pages,
v_pages,
kv_seq_lens,
page_indices,
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel=use_kernel,
)

compiled_paged_attention = torch.compile(
    multi_queries_paged_attention_wrapper, backend="openxla")

output = compiled_paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
kv_seq_lens_xla,
page_indices_xla,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
use_kernel=True,
)

nonkernel_output = compiled_paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
kv_seq_lens_xla,
page_indices_xla,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
use_kernel=False,
)

self.assertTrue(
torch.allclose(
output.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() != 4,
"This test only works on TPUv4 and TPUv5p.")
torch_xla/experimental/custom_kernel.py (30 additions, 1 deletion)
@@ -528,7 +528,7 @@ def _multi_queries_paged_attention_nonkernel(
attn = torch.einsum("qhd,khd->hqk", q[i],
k) # [num_query_heads, query_len, kv_len]
attn = attn.float()
- empty_mask = torch.ones(query_len, kv_len)
+ empty_mask = torch.ones(query_len, kv_len, device=attn.device)
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
attn.masked_fill_(mask, float("-inf"))
attn = torch.softmax(
@@ -1083,6 +1083,35 @@ def paged_attention_non_xla(q: torch.Tensor,
attn_logits_soft_cap: float = None):
return non_xla_attetion(q, k_pages, v_pages, "paged")

XLA_LIB.define(
    "multi_queries_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, "
    "Tensor lengths, Tensor page_indices, int num_kv_pages_per_compute_block, "
    "int num_queries_per_compute_block, bool use_kernel) -> Tensor",
)


@impl(XLA_LIB, "multi_queries_paged_attention", "XLA")
def multi_queries_paged_attention_xla(q: torch.Tensor,
k_pages: torch.Tensor,
v_pages: torch.Tensor,
lengths: torch.Tensor,
page_indices: torch.Tensor,
num_kv_pages_per_compute_block: int,
num_queries_per_compute_block: int,
use_kernel: bool):
return multi_queries_paged_attention(
    q, k_pages, v_pages, lengths, page_indices,
    num_kv_pages_per_compute_block, num_queries_per_compute_block, use_kernel)


@impl(XLA_LIB, "multi_queries_paged_attention", "CompositeExplicitAutograd")
def multi_queries_paged_attention_non_xla(q: torch.Tensor,
k_pages: torch.Tensor,
v_pages: torch.Tensor,
lengths: torch.Tensor,
page_indices: torch.Tensor,
num_kv_pages_per_compute_block: int,
num_queries_per_compute_block: int,
use_kernel: bool):
return non_xla_attetion(q, k_pages, v_pages, "paged")


XLA_LIB.define(
"gmm(Tensor lhs, Tensor rhs, Tensor group_sizes, int[]? tiling=None) -> Tensor",
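For context on the registrations above: the "XLA" impl forwards XLA tensors to the multi_queries_paged_attention Python wrapper (Pallas kernel or pure-PyTorch reference, depending on use_kernel), while the "CompositeExplicitAutograd" impl routes non-XLA tensors, e.g. during tracing, through the same non_xla_attetion fallback the existing paged_attention op uses. As a sketch that is not part of the commit, the two paths of the registered op can also be compared eagerly with the tolerances the new test uses, assuming q_xla, k_pages_xla, v_pages_xla, kv_seq_lens_xla and page_indices_xla are XLA-device tensors and the 16/32 block sizes are illustrative:

import torch

out_kernel = torch.ops.xla.multi_queries_paged_attention(
    q_xla, k_pages_xla, v_pages_xla, kv_seq_lens_xla, page_indices_xla,
    16, 32, use_kernel=True)
out_reference = torch.ops.xla.multi_queries_paged_attention(
    q_xla, k_pages_xla, v_pages_xla, kv_seq_lens_xla, page_indices_xla,
    16, 32, use_kernel=False)
# Same tolerances as the new test's kernel-vs-nonkernel check.
torch.testing.assert_close(
    out_kernel.cpu(), out_reference.cpu(), atol=1e-2, rtol=1e-2)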
