Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-garvey committed Sep 18, 2024
1 parent 087d8ac commit a44fc4c
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5370,6 +5370,35 @@ def forward(self, query, key, value):
@register_test_case(
module_factory=lambda: ScaledDotProductAttentionDifferentCausalModule()
)
def ScaledDotProductAttentionDifferentDynamicCausalModule_basic(module, tu: TestUtils):
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
module.forward(query, key, value)


class ScaledDotProductAttentionDifferentDynamicCausalModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([2, 3, -1, 16], torch.float32, True),
([2, 3, -1, 16], torch.float32, True),
([2, 3, -1, 20], torch.float32, True),
]
)
def forward(self, query, key, value):
return torch.ops.aten.scaled_dot_product_attention(
query, key, value, is_causal=True
)


@register_test_case(
module_factory=lambda: ScaledDotProductAttentionDifferentDynamicCausalModule()
)
def ScaledDotProductAttentionDifferentCausalModule_basic(module, tu: TestUtils):
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
Expand Down

0 comments on commit a44fc4c

Please sign in to comment.