
Commit 26e790d

clean up device checks in float8 unit test files (#923)
Summary: While working on rowwise scaling I noticed that some of the CUDA device capability checks in the test files did not make sense; this commit cleans them up.

Test Plan: Tests pass on my H100. CI should skip fewer tests now, since CI only has CUDA capability 8.9.

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 858205a commit 26e790d
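
For context, the tests gate on module-level capability flags rather than ad-hoc checks inside each test body. The flag names below appear in the hunks that follow; their exact definitions here are a minimal sketch (an assumption), not copied from the repo:

import torch

# Sketch of the capability flags referenced in the diffs below; the real
# definitions in test/float8/ may differ slightly.
is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)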

File tree: 2 files changed (+2, -24 lines)


test/float8/test_base.py (-23 lines)
@@ -231,15 +231,6 @@ def test_linear(
         linear_dtype: torch.dtype,
         linear_bias: bool,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)

@@ -287,16 +278,6 @@ def test_autocast_outputs(
         emulate: bool,
         linear_dtype: torch.dtype,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-
         m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         config = Float8LinearConfig(
             cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
@@ -334,10 +315,6 @@ def test_autocast_outputs(
     @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
-        emulate = (
-            not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0)
-        )
-
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         config = Float8LinearConfig(emulate=emulate)
         m = Float8Linear.from_float(copy.deepcopy(m), config)
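
With the inline guards removed, skipping and emulation are handled entirely by the decorators: emulate=False is only parametrized on GPUs that can run real float8 matmuls, and the parametrized value now flows straight into the config (previously test_type_cast overrode it unconditionally). A minimal pytest-style sketch of that flow, with import paths assumed from the identifiers shown above:

import copy

import pytest
import torch
import torch.nn as nn

# Assumed import paths; the test file's actual imports may differ.
from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear import Float8Linear

# Assumed flag definition, mirroring the sketch near the top of this page.
is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


# emulate=False is only generated when the GPU supports real float8 matmuls,
# so no extra capability check is needed inside the test body.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
def test_type_cast_sketch(emulate: bool):
    m = nn.Linear(32, 16, device="cuda", dtype=torch.bfloat16)
    config = Float8LinearConfig(emulate=emulate)
    m = Float8Linear.from_float(copy.deepcopy(m), config)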

test/float8/test_compile.py (+2, -1 lines)
@@ -224,7 +224,8 @@ def forward(self, x):
                 return x_hp
             return x_fp8
 
-    @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+    # TODO(future): figure out why the test below fails on CUDA capability 8.9
+    @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with capability 9.0 or greater not available")
     def test_float8_with_graph_break_in_the_middle(self):
         """Test that having Float8Tensor object at the boundary of a subgraph"""
         cnts = CompileCounterWithBackend("inductor")
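
For reference, here is the reworded guard from test/float8/test_compile.py in a self-contained form. Only the decorator, the TODO comment, and the skip message come from the diff; the is_H100 definition and the surrounding class are assumptions for illustration:

import unittest

import torch

# Assumed helper; the real definition lives in the test module's imports.
is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


class GraphBreakSketch(unittest.TestCase):
    # TODO(future): figure out why the test below fails on CUDA capability 8.9
    @unittest.skipIf(
        not torch.cuda.is_available() or not is_H100,
        "CUDA with capability 9.0 or greater not available",
    )
    def test_float8_with_graph_break_in_the_middle(self):
        # Body elided; the real test compiles a float8 module around a graph break.
        pass


if __name__ == "__main__":
    unittest.main()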

0 commit comments
