Commit 7beac1b

Fix rtn tuning_device issue (#893)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
1 parent 747b7af commit 7beac1b

2 files changed: +14 -1 lines changed


auto_round/compressors/base.py

Lines changed: 1 addition & 1 deletion
@@ -1422,7 +1422,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
             m.zp = None
         else:
             try:
-                m = m.to(self.device)
+                m = m.to(m.tuning_device if hasattr(m, "tuning_device") else self.device)
                 m = WrapperLinear(
                     m,
                     enable_minmax_tuning=False,
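
The one-line change above can be read as a device-selection fallback: a layer that carries a `tuning_device` attribute is moved there, and any other layer still goes to the compressor-wide `self.device`. The sketch below isolates that expression in a hypothetical helper, `resolve_rtn_device`, which is not part of auto_round. The `SimpleNamespace` objects are stand-ins for real layers, and the device strings are illustrative only. The diff does not show where `tuning_device` is set; that it comes from a per-layer `device_map` assignment is an assumption based on the accompanying test.

```python
from types import SimpleNamespace


def resolve_rtn_device(module, default_device):
    """Return the module's `tuning_device` if it has one, else the default.

    Hypothetical helper that mirrors the patched expression:
        m.tuning_device if hasattr(m, "tuning_device") else self.device
    """
    return module.tuning_device if hasattr(module, "tuning_device") else default_device


# Stand-ins for layers (assumption: real layers are torch.nn.Module objects).
mapped_layer = SimpleNamespace(tuning_device="cuda:1")  # layer with an assigned device
unmapped_layer = SimpleNamespace()                       # layer without the attribute

print(resolve_rtn_device(mapped_layer, "cuda:0"))
print(resolve_rtn_device(unmapped_layer, "cuda:0"))
```

Before the fix, both layers would have been moved to the default device, which appears to be the `tuning_device` issue named in the commit title.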

test/test_cuda/test_multiple_card.py

Lines changed: 13 additions & 0 deletions
@@ -242,6 +242,19 @@ def test_device_map_dict(self):
         )
         autoround.quantize()
 
+        # test rtn
+        autoround = AutoRound(
+            model_name,
+            tokenizer,
+            bits=bits,
+            group_size=group_size,
+            sym=sym,
+            iters=0,
+            seqlen=2,
+            device_map=device_map,
+        )
+        autoround.quantize()
+
     @multi_card
     @require_greater_than_050
     def test_device_map_for_triton(self):
