
Commit cfe14e2

refine the device
1 parent 9c8e66b commit cfe14e2

2 files changed: +10 −3 lines changed


test/quantization/test_gptq.py

Lines changed: 5 additions & 0 deletions
@@ -27,6 +27,7 @@
 
 class TestGPTQ(TestCase):
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_gptq_quantizer_int4_weight_only(self):
         from torchao._models._eval import (
             LMEvalInputRecorder,
@@ -105,6 +106,7 @@ def test_gptq_quantizer_int4_weight_only(self):
 
 
 class TestMultiTensorFlow(TestCase):
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_add_tensors(self):
         from torchao.quantization.GPTQ import MultiTensor
 
@@ -116,6 +118,7 @@ def test_multitensor_add_tensors(self):
         self.assertTrue(torch.equal(mt.values[0], tensor1))
         self.assertTrue(torch.equal(mt.values[1], tensor2))
 
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_pad_unpad(self):
         from torchao.quantization.GPTQ import MultiTensor
 
@@ -126,6 +129,7 @@ def test_multitensor_pad_unpad(self):
         mt.unpad()
         self.assertEqual(mt.count, 1)
 
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_inplace_operation(self):
         from torchao.quantization.GPTQ import MultiTensor
 
@@ -136,6 +140,7 @@ def test_multitensor_inplace_operation(self):
 
 
 class TestMultiTensorInputRecorder(TestCase):
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_input_recorder(self):
         from torchao.quantization.GPTQ import MultiTensor, MultiTensorInputRecorder
 
torchao/utils.py

Lines changed: 5 additions & 3 deletions
@@ -149,10 +149,12 @@ def get_available_devices():
 
 
 def auto_detect_device():
-    if torch.accelerator.is_available():
-        return torch.accelerator.current_accelerator()
+    if torch.cuda.is_available():
+        return "cuda"
+    elif torch.xpu.is_available():
+        return "xpu"
     else:
-        return "cpu"
+        return None


 def get_compute_capability():
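With this change auto_detect_device() returns the string "cuda" or "xpu", or None when neither backend is available, instead of falling back to "cpu". A hedged sketch, not from the commit, of how a call site might handle the new None case; pick_device and its fallback parameter are hypothetical, not part of torchao:

# sketch only: caller-side fallback for the new return contract
from torchao.utils import auto_detect_device


def pick_device(fallback: str = "cpu") -> str:
    # auto_detect_device() may now return None, so supply a default here.
    device = auto_detect_device()
    return device if device is not None else fallback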
