From f846f71fece9d24e2d713867fbc45d068b51f173 Mon Sep 17 00:00:00 2001
From: Xiongfei Wei
Date: Tue, 29 Oct 2024 04:39:28 +0000
Subject: [PATCH] add torch.compile support.

---
 test/test_pallas.py                     | 96 +++++++++++++++++++++++++
 torch_xla/experimental/custom_kernel.py | 31 +++++++-
 2 files changed, 126 insertions(+), 1 deletion(-)

diff --git a/test/test_pallas.py b/test/test_pallas.py
index 47cb7a05e0b..c13e195e9d1 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -603,6 +603,17 @@ def test_paged_attention_multi_queries_wrapper(self):
         num_queries_per_compute_block=num_queries_per_compute_block,
     )
 
+    nonkernel_output = multi_queries_paged_attention(
+        q_xla,
+        k_pages_xla,
+        v_pages_xla,
+        kv_seq_lens_xla,
+        page_indices_xla,
+        num_kv_pages_per_compute_block=block_kv_size // page_size,
+        num_queries_per_compute_block=num_queries_per_compute_block,
+        use_kernel=False,
+    )
+
     q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
     k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
     v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
@@ -623,6 +634,91 @@ def test_paged_attention_multi_queries_wrapper(self):
     self.assertTrue(
         torch.allclose(
             output.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5))
+    self.assertTrue(
+        torch.allclose(
+            output.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2))
+
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "This test only works on TPUv4+.")
+  def test_paged_attention_multi_queries_wrapper_with_dynamo(self):
+    from torch_xla.experimental.custom_kernel import multi_queries_paged_attention
+    from torch_xla.experimental.pallas_kernels.multi_queries_paged_attention_kernel import paged_attention as jax_multi_queries_paged_attention
+
+    dtype = torch.float32
+    page_size = 16
+    num_kv_heads = 8
+    q_kv_head_ratio = 4
+    head_dim = 256
+    num_queries_per_compute_block = 32
+    block_kv_size = 256
+
+    max_kv_len = 2048
+    query_len = 64
+    kv_seq_lens = torch.randint(query_len, max_kv_len, (3,), dtype=torch.int32)
+    assert query_len <= max_kv_len
+    for cur_kv_seq in kv_seq_lens:
+      assert query_len <= cur_kv_seq, f'{query_len} should be less than or equal to the kv_len {cur_kv_seq} in the current sequence.'
+    batch_size = len(kv_seq_lens)
+    pages_per_sequence = max_kv_len // page_size
+    total_num_pages = batch_size * pages_per_sequence
+    assert max_kv_len <= total_num_pages * page_size
+
+    q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
+        kv_seq_lens,
+        page_size,
+        max_kv_len,
+        num_kv_heads,
+        num_kv_heads * q_kv_head_ratio,
+        head_dim,
+        dtype=dtype,
+        query_len=query_len,
+    )
+
+    q_xla = q.to("xla")
+    k_pages_xla = k_pages.to("xla")
+    v_pages_xla = v_pages.to("xla")
+    kv_seq_lens_xla = kv_seq_lens.to("xla")
+    page_indices_xla = page_indices.to("xla")
+
+    def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
+                                              page_indices,
+                                              num_kv_pages_per_compute_block,
+                                              num_queries_per_compute_block,
+                                              use_kernel):
+      return torch.ops.xla.multi_queries_paged_attention(
+          q,
+          k_pages,
+          v_pages,
+          kv_seq_lens,
+          page_indices,
+          num_kv_pages_per_compute_block,
+          num_queries_per_compute_block,
+          use_kernel=use_kernel,
+      )
+
+    compiled_paged_attention = torch.compile(
+        multi_queries_paged_attention_wrapper, backend="openxla")
+
+    output = compiled_paged_attention(
+        q_xla,
+        k_pages_xla,
+        v_pages_xla,
+        kv_seq_lens_xla,
+        page_indices_xla,
+        num_kv_pages_per_compute_block=block_kv_size // page_size,
+        num_queries_per_compute_block=num_queries_per_compute_block,
+        use_kernel=True,
+    )
+
+    nonkernel_output = compiled_paged_attention(
+        q_xla,
+        k_pages_xla,
+        v_pages_xla,
+        kv_seq_lens_xla,
+        page_indices_xla,
+        num_kv_pages_per_compute_block=block_kv_size // page_size,
+        num_queries_per_compute_block=num_queries_per_compute_block,
+        use_kernel=False,
+    )
+
+    self.assertTrue(
+        torch.allclose(
+            output.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2))
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() != 4,
                    "This test only works on TPUv4 and TPUv5p.")
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 47ec27ab6f8..dc9b15cb7a9 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -528,7 +528,7 @@ def _multi_queries_paged_attention_nonkernel(
     attn = torch.einsum("qhd,khd->hqk", q[i],
                         k)  # [num_query_heads, query_len, kv_len]
     attn = attn.float()
-    empty_mask = torch.ones(query_len, kv_len)
+    empty_mask = torch.ones(query_len, kv_len, device=attn.device)
     mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
     attn.masked_fill_(mask, float("-inf"))
     attn = torch.softmax(
@@ -1083,6 +1083,35 @@ def paged_attention_non_xla(q: torch.Tensor,
                             attn_logits_soft_cap: float = None):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
 
+
+XLA_LIB.define(
+    "multi_queries_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int num_kv_pages_per_compute_block, int num_queries_per_compute_block, bool use_kernel) -> Tensor",
+)
+
+
+@impl(XLA_LIB, "multi_queries_paged_attention", "XLA")
+def multi_queries_paged_attention_xla(q: torch.Tensor,
+                                      k_pages: torch.Tensor,
+                                      v_pages: torch.Tensor,
+                                      lengths: torch.Tensor,
+                                      page_indices: torch.Tensor,
+                                      num_kv_pages_per_compute_block: int,
+                                      num_queries_per_compute_block: int,
+                                      use_kernel: bool):
+  return multi_queries_paged_attention(q, k_pages, v_pages, lengths,
+                                       page_indices,
+                                       num_kv_pages_per_compute_block,
+                                       num_queries_per_compute_block,
+                                       use_kernel)
+
+
+@impl(XLA_LIB, "multi_queries_paged_attention", "CompositeExplicitAutograd")
+def multi_queries_paged_attention_non_xla(q: torch.Tensor,
+                                          k_pages: torch.Tensor,
+                                          v_pages: torch.Tensor,
+                                          lengths: torch.Tensor,
+                                          page_indices: torch.Tensor,
+                                          num_kv_pages_per_compute_block: int,
+                                          num_queries_per_compute_block: int,
+                                          use_kernel: bool):
+  return non_xla_attetion(q, k_pages, v_pages, "paged")
+
 
 XLA_LIB.define(
     "gmm(Tensor lhs, Tensor rhs, Tensor group_sizes, int[]? tiling=None) -> Tensor",
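
The custom_kernel.py changes follow the usual op-registration pattern that exposes a torch_xla custom kernel to torch.compile: XLA_LIB.define declares the schema, the "XLA" registration dispatches to the real Pallas-backed wrapper, and the "CompositeExplicitAutograd" registration supplies a non-XLA fallback so the op can still be traced or run off-device. Below is a minimal sketch of the same pattern using a hypothetical toy_scale op in a made-up mylib namespace; the namespace, op, and function names are illustrative and not part of torch_xla:

import torch
from torch.library import Library, impl

# Hypothetical namespace and op, used only to illustrate the pattern.
mylib = Library("mylib", "DEF")
mylib.define("toy_scale(Tensor x, float factor) -> Tensor")


@impl(mylib, "toy_scale", "XLA")
def toy_scale_xla(x, factor):
  # Real implementation, dispatched when the inputs are XLA tensors
  # (in the patch above, this role is played by the Pallas kernel wrapper).
  return x * factor


@impl(mylib, "toy_scale", "CompositeExplicitAutograd")
def toy_scale_non_xla(x, factor):
  # Fallback for non-XLA tensors and tracing: return a tensor with the
  # expected shape and dtype, similar in spirit to non_xla_attetion.
  return torch.empty_like(x)


# The op is now addressable through torch.ops and can be compiled;
# the "openxla" backend is available once torch_xla is imported.
compiled = torch.compile(
    lambda x: torch.ops.mylib.toy_scale(x, 2.0), backend="openxla")

This mirrors how the new dynamo test drives torch.ops.xla.multi_queries_paged_attention through torch.compile(..., backend="openxla") with both use_kernel=True and use_kernel=False and compares the two outputs.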