
Commit 5f1879b

clean up device checks in float8 unit test files
Summary: While working on rowwise scaling I noticed that some of the CUDA device capability checks in the test files did not make sense; this cleans them up.

Test Plan: Tests pass on my H100. CI should skip fewer tests now, since CI only has CUDA capability 8.9.
1 parent 53b6b78 commit 5f1879b
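For context, the pattern these files converge on is to compute device-capability flags once at module scope and gate tests with decorators, rather than warning and skipping inside each test body. A minimal sketch of that pattern, assuming a pytest-style test class (`TestExample` and `test_something` are illustrative names; the `is_cuda_8_9` flag mirrors the one already defined in these files):

import unittest

import pytest
import torch

# Compute the capability flag once at import time, mirroring the
# is_cuda_8_9 flag that already exists in these test files.
is_cuda_8_9 = (
    torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
)


class TestExample:
    # Only generate the emulate=False case on float8-capable GPUs, and skip
    # entirely when CUDA is missing; no warn-and-skip logic in the body.
    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_something(self, emulate: bool):
        # emulate=False only ever reaches here on SM 8.9+ hardware.
        assert emulate or torch.cuda.get_device_capability() >= (8, 9)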

2 files changed (+1, -25 lines)

test/float8/test_base.py

Lines changed: 0 additions & 23 deletions
@@ -231,15 +231,6 @@ def test_linear(
         linear_dtype: torch.dtype,
         linear_bias: bool,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)

@@ -287,16 +278,6 @@ def test_autocast_outputs(
         emulate: bool,
         linear_dtype: torch.dtype,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-
         m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         config = Float8LinearConfig(
             cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
@@ -334,10 +315,6 @@ def test_autocast_outputs(
     @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
-        emulate = (
-            not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0)
-        )
-
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         config = Float8LinearConfig(emulate=emulate)
         m = Float8Linear.from_float(copy.deepcopy(m), config)
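The `test_type_cast` hunk above is the one behavioral fix in this file: the deleted lines recomputed `emulate` inside the body and overwrote the parametrized value, so on anything below SM 9.0 the `emulate=False` case silently ran in emulation anyway. The gating itself is just lexicographic comparison of the `(major, minor)` capability tuple; a small illustrative check (the hard-coded capability values below are examples, not taken from the diff):

import torch

# Device capability is a (major, minor) tuple and compares lexicographically.
sm89 = (8, 9)   # e.g. Ada-class GPUs, which do support float8 kernels
sm90 = (9, 0)   # H100

assert sm89 < (9, 0)        # the removed "< (9, 0)" checks skipped/emulated here
assert sm89 >= (8, 9)       # the retained is_cuda_8_9 gate still lets SM 8.9 run
assert sm90 >= (8, 9)       # H100 passes the 8.9 gate as well

if torch.cuda.is_available():
    # Real hardware reports the same kind of tuple, e.g. (9, 0) on an H100.
    print(torch.cuda.get_device_capability())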

test/float8/test_compile.py

Lines changed: 1 addition & 2 deletions
@@ -32,7 +32,6 @@
 from torch._dynamo.test_case import TestCase as DynamoTestCase
 from torch._dynamo.testing import CompileCounterWithBackend

-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

 def _test_compile_base(
@@ -224,7 +223,7 @@ def forward(self, x):
                 return x_hp
             return x_fp8

-    @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+    @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
     def test_float8_with_graph_break_in_the_middle(self):
         """Test that having Float8Tensor object at the boundary of a subgraph"""
         cnts = CompileCounterWithBackend("inductor")
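To reproduce the Test Plan on a single machine, one option is to invoke pytest on the two touched files. A hedged runner sketch (paths taken from the file tree above; the flag and entry point are illustrative, and it assumes execution from the repository root):

# Minimal runner for the two touched test files.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["-q", "test/float8/test_base.py", "test/float8/test_compile.py"]))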
