Skip to content

Commit a3d422d

Browse files
authored
add num_device check for set_auto_device_map_for_block_with_tuning (#1021)
Signed-off-by: He, Xin3 <xin3.he@intel.com>
1 parent 6d03ec8 commit a3d422d

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

auto_round/utils/device.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -989,20 +989,22 @@ def set_auto_device_map_for_block_with_tuning(
989989
Note:
990990
This function is intended for internal use in device memory management and tuning.
991991
"""
992-
if not (device_map == "auto" or ((isinstance(device_map, str) and "," in device_map))):
993-
block = block.to(output_device)
994-
card_0_in_high_risk = False # card 0 contains weight, clear_memory will not help much
995-
loss_device = output_device
996-
return card_0_in_high_risk, loss_device
997-
992+
card_0_in_high_risk, loss_device = False, output_device
998993
if torch.cuda.is_available():
999994
num_devices = torch.cuda.device_count()
1000995
device_name = "cuda"
1001996
elif torch.xpu.is_available():
1002997
num_devices = torch.xpu.device_count()
1003998
device_name = "xpu"
1004999
else:
1005-
return
1000+
return card_0_in_high_risk, loss_device
1001+
1002+
if not (
1003+
device_map == "auto" or ((isinstance(device_map, str) and "," in device_map)) or num_devices > 1
1004+
): # Only 1 card is available or non-auto device map
1005+
block = block.to(output_device)
1006+
return card_0_in_high_risk, loss_device
1007+
10061008
device_list = None
10071009
if isinstance(device_map, str) and "," in device_map:
10081010
device_list = [int(dev) for dev in device_map.split(",") if dev.isdigit()]

0 commit comments

Comments
 (0)