Skip to content

Commit cba6848

Browse files
committed
device
1 parent 3db5d9a commit cba6848

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

torchao/testing/utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,11 +223,9 @@ def test_linear_compile(self, device, dtype):
223223
NUM_DEVICES,
224224
)
225225

226-
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
227226
class TorchAOTensorParallelTestCase(DTensorTestBase):
228227
"""Basic test case for tensor subclasses
229228
"""
230-
COMMON_DEVICES = (["cuda"] if torch.cuda.is_available() else [])
231229
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
232230

233231
TENSOR_SUBCLASS = AffineQuantizedTensor
@@ -279,10 +277,11 @@ def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
279277
quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS))
280278
return m
281279

282-
@common_utils.parametrize("device", COMMON_DEVICES)
283280
@common_utils.parametrize("dtype", COMMON_DTYPES)
284281
@with_comms
285-
def test_tp(self, device, dtype):
282+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
283+
def test_tp(self, dtype):
284+
device = "cuda"
286285
# To make sure different ranks create the same module
287286
torch.manual_seed(5)
288287

0 commit comments

Comments
 (0)