
Commit

Revert "Update error messages to reflect why test is skipped (pytorch…
Browse files Browse the repository at this point in the history
…#95049)"

This reverts commit 22e797a.
seemethere committed Feb 17, 2023
1 parent 8928e7b commit e44737e
Showing 1 changed file with 11 additions and 11 deletions.
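
For context, a minimal sketch of the decorator pattern this revert touches. The flag names and skip strings below come from the diff itself; the stub values and the class name are assumptions for illustration, not PyTorch's real definitions.

import unittest

# Assumption: stand-ins for the capability checks used by test/test_transformers.py.
PLATFORM_SUPPORTS_FUSED_SDPA = False
SM80OrLater = False

class TransformersSkipSketch(unittest.TestCase):
    # After this revert, the skip reason is the generic message again.
    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "CUDA unavailable")
    def test_is_causal_gpu(self):
        pass  # real test body elided; it exercises the fused SDPA kernels

if __name__ == "__main__":
    unittest.main()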
22 changes: 11 additions & 11 deletions test/test_transformers.py
@@ -1056,7 +1056,7 @@ def ones_tensor(*shape):
_ = mha_f(qkv_f, qkv_f, qkv_f, need_weights=False, is_causal=True)
torch.cuda.synchronize()

- @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Platform does not supposrt fused SDPA or pre-SM80 hardware")
+ @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "CUDA unavailable")
def test_is_causal_gpu(self):
device = 'cuda'
self.is_causal_kernels(["math", "meff"], device)
@@ -1473,7 +1473,7 @@ def test_fused_sdp_choice(self, type: str):

assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION

- @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Platform does not support fused SDPA")
+ @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "CUDA unavailable")
@parametrize("warn_only", [True, False])
def test_sdp_choice_with_determinism(self, warn_only):
# If we are only warning we still expect that efficient_attention will still be called.
@@ -1487,7 +1487,7 @@ def test_sdp_choice_with_determinism(self, warn_only):
assert torch._fused_sdp_choice(query, key, value) == (
SDPBackend.EFFICIENT_ATTENTION if warn_only else SDPBackend.MATH)

- @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not isSM86Device, "Does not support fused SDPA or not SM86 hardware")
+ @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not isSM86Device, "CUDA unavailable")
def test_memory_efficeint_sm86_failure(self):
device = 'cuda'
dtype = torch.float16
@@ -1499,7 +1499,7 @@ def test_memory_efficeint_sm86_failure(self):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False))

- @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not isSM86Device, "Does not support fused SDPA or not SM86 hardware")
+ @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not isSM86Device, "CUDA unavailable")
def test_flash_backward_sm86_headdim128(self):
device = 'cuda'
dtype = torch.float16
@@ -1518,7 +1518,7 @@ def test_flash_backward_sm86_headdim128(self):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False))

- @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Platform does not support fused scaled dot product attention")
+ @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
def test_dispatch_fails_no_backend(self):
dtype = torch.float16
device = "cuda"
@@ -1619,7 +1619,7 @@ def test_invalid_fused_inputs_attn_mask_present(self, kernel: SDPBackend):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, torch.ones_like(q), 0.0, False))

- @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support fused SDPA or pre-SM80 hardware")
+ @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "CUDA unavailable")
def test_unaligned_tensors(self):
# The alignment is depdent on arch so we specifiy SM80OrLater
device = 'cuda'
@@ -1631,7 +1631,7 @@ def test_unaligned_tensors(self):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False))

- @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support fused SDPA or pre-SM80 hardware")
+ @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "CUDA unavailable")
def test_flash_fail_fp32(self):
device = 'cuda'
dtype = torch.float
@@ -1642,7 +1642,7 @@ def test_flash_fail_fp32(self):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False))

- @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support SDPA or pre-SM80 hardware")
+ @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "CUDA unavailable")
def test_flash_autocast_fp32_float16(self):
device = 'cuda'
dtype = torch.float
@@ -1654,7 +1654,7 @@ def test_flash_autocast_fp32_float16(self):
_ = torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False)

- @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support SDPA or pre-SM80 hardware")
+ @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "CUDA unavailable")
def test_flash_autocast_fp32_bfloat16(self):
device = 'cuda'
dtype = torch.float
@@ -1684,7 +1684,7 @@ def func():

self.assertRaises(RuntimeError, func)

- @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support SDPA or pre-SM80 hardware")
+ @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "CUDA unavailable")
@parametrize("batch_size", [1, 8])
@parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048])
@parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048])
@@ -1768,7 +1768,7 @@ def test_mem_efficient_attention_vs_math_ref_grads(self, batch_size: int, seq_le
self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)

- @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support SDPA or pre-SM80 hardware")
+ @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "CUDA unavailable")
@parametrize("batch_size", [1, 8])
@parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048])
@parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048])

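Usage note (an assumption about local setup, not part of this commit): on a machine without the required GPU support, running one of the affected tests with skip reporting enabled surfaces the reason string from the decorator, which after this revert is the generic "CUDA unavailable" text rather than the descriptive message from pytorch#95049.

# Illustrative only; assumes pytest and a local PyTorch checkout. "-rs" adds skip reasons to the summary.
import pytest
pytest.main(["-rs", "test/test_transformers.py", "-k", "test_is_causal_gpu"])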