Commit f0ca56d

linter

vanbasten23 committed Oct 29, 2024
1 parent f846f71
Showing 3 changed files with 34 additions and 29 deletions.
27 changes: 16 additions & 11 deletions test/test_pallas.py
@@ -680,19 +680,24 @@ def test_paged_attention_multi_queries_wrapper_with_dynamo(self):
     kv_seq_lens_xla = kv_seq_lens.to("xla")
     page_indices_xla = page_indices.to("xla")
 
-    def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens, page_indices, num_kv_pages_per_compute_block, num_queries_per_compute_block, use_kernel):
+    def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
+                                              page_indices,
+                                              num_kv_pages_per_compute_block,
+                                              num_queries_per_compute_block,
+                                              use_kernel):
       return torch.ops.xla.multi_queries_paged_attention(
-        q,
-        k_pages,
-        v_pages,
-        kv_seq_lens,
-        page_indices,
-        num_kv_pages_per_compute_block,
-        num_queries_per_compute_block,
-        use_kernel=use_kernel,
+          q,
+          k_pages,
+          v_pages,
+          kv_seq_lens,
+          page_indices,
+          num_kv_pages_per_compute_block,
+          num_queries_per_compute_block,
+          use_kernel=use_kernel,
       )
 
-    compiled_paged_attention = torch.compile(multi_queries_paged_attention_wrapper, backend="openxla")
+    compiled_paged_attention = torch.compile(
+        multi_queries_paged_attention_wrapper, backend="openxla")
 
     output = compiled_paged_attention(
         q_xla,
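Note: the hunk above is formatting-only; it re-wraps the nested wrapper definition and the torch.compile call. For reference, a minimal self-contained sketch of the same torch.compile + openxla pattern the test exercises, with a toy function standing in for the paged-attention wrapper (toy_fn, the tensor shapes, and the xm.xla_device() setup are illustrative assumptions, not part of this commit):

```python
import torch
import torch_xla.core.xla_model as xm


def toy_fn(a, b):
  # Stand-in for multi_queries_paged_attention_wrapper in the test above.
  return torch.matmul(a, b)


# Compile with the Dynamo "openxla" backend, as the test does.
compiled_fn = torch.compile(toy_fn, backend="openxla")

device = xm.xla_device()
a = torch.randn(8, 16, device=device)
b = torch.randn(16, 4, device=device)
out = compiled_fn(a, b)
```

The sketch only runs on a machine where torch_xla can create an XLA device (e.g. a TPU host); elsewhere it fails at device creation.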
2 changes: 2 additions & 0 deletions test/test_tpu_paged_attention_kernel.py
@@ -91,6 +91,7 @@ class PagedAttentionKernelTest(jtu.JaxTestCase):
   def setUp(self):
     super().setUp()
 
+
   # def test_paged_attention(
   #     self,
   # ):
@@ -101,6 +102,7 @@ def setUp(self):
   #   head_dim = 256
   #   num_queries_per_compute_block = 32
   #   block_kv_size = 256
+
   @parameterized.product(
       dtype=(jnp.float32, jnp.bfloat16),
       page_size=(16, 32, 64),
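Note: the @parameterized.product decorator shown above (its argument list continues past this hunk) expands into one test case per combination of the listed values. A minimal sketch of that absl pattern with a toy test class (ToyProductTest and its assertions are illustrative, not from this file):

```python
from absl.testing import absltest, parameterized
import jax.numpy as jnp


class ToyProductTest(parameterized.TestCase):

  @parameterized.product(
      dtype=(jnp.float32, jnp.bfloat16),
      page_size=(16, 32, 64),
  )
  def test_zeros(self, dtype, page_size):
    # One test case is generated per (dtype, page_size) combination.
    x = jnp.zeros((page_size,), dtype=dtype)
    self.assertEqual(x.shape, (page_size,))
    self.assertEqual(x.dtype, dtype)


if __name__ == "__main__":
  absltest.main()
```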
34 changes: 16 additions & 18 deletions torch_xla/experimental/custom_kernel.py
@@ -1083,33 +1083,31 @@ def paged_attention_non_xla(q: torch.Tensor,
                             attn_logits_soft_cap: float = None):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
 
 
 XLA_LIB.define(
     "multi_queries_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int num_kv_pages_per_compute_block, int num_queries_per_compute_block, bool use_kernel) -> Tensor",
 )
 
 
 @impl(XLA_LIB, "multi_queries_paged_attention", "XLA")
-def multi_queries_paged_attention_xla(q: torch.Tensor,
-                                      k_pages: torch.Tensor,
-                                      v_pages: torch.Tensor,
-                                      lengths: torch.Tensor,
-                                      page_indices: torch.Tensor,
-                                      num_kv_pages_per_compute_block: int,
-                                      num_queries_per_compute_block: int,
-                                      use_kernel: bool):
-  return multi_queries_paged_attention(q, k_pages, v_pages, lengths, page_indices,
-                                       num_kv_pages_per_compute_block, num_queries_per_compute_block, use_kernel)
+def multi_queries_paged_attention_xla(
+    q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
+    lengths: torch.Tensor, page_indices: torch.Tensor,
+    num_kv_pages_per_compute_block: int, num_queries_per_compute_block: int,
+    use_kernel: bool):
+  return multi_queries_paged_attention(q, k_pages, v_pages, lengths,
+                                       page_indices,
+                                       num_kv_pages_per_compute_block,
+                                       num_queries_per_compute_block,
+                                       use_kernel)
 
 
 @impl(XLA_LIB, "multi_queries_paged_attention", "CompositeExplicitAutograd")
-def multi_queries_paged_attention_non_xla(q: torch.Tensor,
-                                          k_pages: torch.Tensor,
-                                          v_pages: torch.Tensor,
-                                          lengths: torch.Tensor,
-                                          page_indices: torch.Tensor,
-                                          num_kv_pages_per_compute_block: int,
-                                          num_queries_per_compute_block: int,
-                                          use_kernel: bool):
+def multi_queries_paged_attention_non_xla(
+    q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
+    lengths: torch.Tensor, page_indices: torch.Tensor,
+    num_kv_pages_per_compute_block: int, num_queries_per_compute_block: int,
+    use_kernel: bool):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
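Note: this file registers the custom op through PyTorch's torch.library machinery: XLA_LIB.define declares the operator schema, and @impl binds an implementation per dispatch key ("XLA" for XLA tensors, "CompositeExplicitAutograd" for the non-XLA fallback). A minimal sketch of that registration pattern with a toy operator (the myns namespace, the toy_scale op, and its body are illustrative assumptions, not part of torch_xla):

```python
import torch
from torch.library import Library, impl

# "DEF" lets us define new operator schemas in this namespace.
MY_LIB = Library("myns", "DEF")

MY_LIB.define("toy_scale(Tensor x, float s) -> Tensor")


@impl(MY_LIB, "toy_scale", "CompositeExplicitAutograd")
def toy_scale_default(x: torch.Tensor, s: float):
  # Generic fallback; a backend-specific kernel could additionally be
  # registered under a dispatch key such as "XLA", as the file above does.
  return x * s


out = torch.ops.myns.toy_scale(torch.ones(3), 2.0)
print(out)  # tensor([2., 2., 2.])
```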
