Commit 7b5d2c4

refine the device
1 parent: 9c8e66b

2 files changed: 6 additions, 1 deletion

test/quantization/test_gptq.py

Lines changed: 5 additions & 0 deletions
@@ -27,6 +27,7 @@

 class TestGPTQ(TestCase):
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_gptq_quantizer_int4_weight_only(self):
         from torchao._models._eval import (
             LMEvalInputRecorder,
@@ -105,6 +106,7 @@ def test_gptq_quantizer_int4_weight_only(self):


 class TestMultiTensorFlow(TestCase):
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_add_tensors(self):
         from torchao.quantization.GPTQ import MultiTensor

@@ -116,6 +118,7 @@ def test_multitensor_add_tensors(self):
         self.assertTrue(torch.equal(mt.values[0], tensor1))
         self.assertTrue(torch.equal(mt.values[1], tensor2))

+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_pad_unpad(self):
         from torchao.quantization.GPTQ import MultiTensor

@@ -126,6 +129,7 @@ def test_multitensor_pad_unpad(self):
         mt.unpad()
         self.assertEqual(mt.count, 1)

+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_inplace_operation(self):
         from torchao.quantization.GPTQ import MultiTensor

@@ -136,6 +140,7 @@ def test_multitensor_inplace_operation(self):


 class TestMultiTensorInputRecorder(TestCase):
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_input_recorder(self):
         from torchao.quantization.GPTQ import MultiTensor, MultiTensorInputRecorder

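Each added decorator uses the same guard: unittest.skipIf with torch.accelerator.is_available() skips the test on hosts without a usable accelerator instead of letting it fail. Below is a minimal, self-contained sketch of that pattern; the test class and assertion are illustrative only, not code from this repository.

import unittest

import torch


class TestNeedsAccelerator(unittest.TestCase):
    # Same guard as in the diff above: skip rather than fail when no
    # accelerator backend (CUDA, XPU, MPS, ...) is present on the host.
    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
    def test_tensor_math_on_accelerator(self):
        device = torch.accelerator.current_accelerator()
        x = torch.ones(2, 2, device=device)
        self.assertTrue(torch.equal(x + x, torch.full((2, 2), 2.0, device=device)))


if __name__ == "__main__":
    unittest.main()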
torchao/utils.py

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ def auto_detect_device():
     if torch.accelerator.is_available():
         return torch.accelerator.current_accelerator()
     else:
-        return "cpu"
+        return None


 def get_compute_capability():

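With this change, auto_detect_device() returns None instead of the string "cpu" when torch.accelerator reports no accelerator, leaving the CPU fallback decision to the caller. A minimal sketch of how a caller might consume the new return value follows; resolve_device is a hypothetical helper written for illustration, not part of torchao.

import torch

from torchao.utils import auto_detect_device


def resolve_device(explicit=None):
    # Hypothetical caller-side helper: honor an explicit request first,
    # then the auto-detected accelerator, and only fall back to CPU when
    # auto_detect_device() returns None (no accelerator available).
    if explicit is not None:
        return torch.device(explicit)
    detected = auto_detect_device()
    return torch.device("cpu") if detected is None else torch.device(detected)


if __name__ == "__main__":
    print(resolve_device())       # accelerator device if present, else cpu
    print(resolve_device("cpu"))  # an explicit choice always wins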