Skip to content

Commit 38ece90

Browse files
committed
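refine the device_module: resolve it via torch.get_device_module(_DEVICE) inside each test that uses it instead of at module import time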
1 parent f10b37b commit 38ece90

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

test/quantization/test_quant_api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@
7474
)
7575

7676
_DEVICE = auto_detect_device()
77-
device_module = torch.get_device_module(_DEVICE)
7877

7978
try:
8079
import gemlite # noqa: F401
@@ -499,6 +498,7 @@ def test_quantized_tensor_subclass_save_load_map_location(self):
499498

500499
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
501500
def test_quantized_model_streaming(self):
501+
device_module = torch.get_device_module(_DEVICE)
502502
def reset_memory():
503503
gc.collect()
504504
device_module.empty_cache()
@@ -1109,6 +1109,7 @@ def test_non_fqn_config_filter_fn_none(self):
11091109

11101110
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
11111111
def test_quantized_model_streaming_fqn_config(self):
1112+
device_module = torch.get_device_module(_DEVICE)
11121113
def reset_memory():
11131114
gc.collect()
11141115
device_module.empty_cache()

test/quantization/test_quant_primitives.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@
4242
torch.manual_seed(_SEED)
4343

4444
_DEVICE = auto_detect_device()
45-
device_module = torch.get_device_module(_DEVICE)
46-
4745

4846
# Helper function to run a function twice
4947
# and verify that the result is the same.
@@ -599,6 +597,7 @@ def test_choose_qparams_tensor_asym_eps(self):
599597
def test_get_group_qparams_symmetric_memory(self):
600598
"""Check the memory usage of the op"""
601599
weight = torch.randn(1024, 1024).to(device=_DEVICE)
600+
device_module = torch.get_device_module(_DEVICE)
602601
original_mem_use = device_module.memory_allocated()
603602
n_bit = 4
604603
groupsize = 128

0 commit comments

Comments
 (0)