@@ -989,20 +989,22 @@ def set_auto_device_map_for_block_with_tuning(
989989 Note:
990990 This function is intended for internal use in device memory management and tuning.
991991 """
992- if not (device_map == "auto" or ((isinstance (device_map , str ) and "," in device_map ))):
993- block = block .to (output_device )
994- card_0_in_high_risk = False # card 0 contains weight, clear_memory will not help much
995- loss_device = output_device
996- return card_0_in_high_risk , loss_device
997-
992+ card_0_in_high_risk , loss_device = False , output_device
998993 if torch .cuda .is_available ():
999994 num_devices = torch .cuda .device_count ()
1000995 device_name = "cuda"
1001996 elif torch .xpu .is_available ():
1002997 num_devices = torch .xpu .device_count ()
1003998 device_name = "xpu"
1004999 else :
1005- return
1000+ return card_0_in_high_risk , loss_device
1001+
1002+ if not (
1003+ device_map == "auto" or ((isinstance (device_map , str ) and "," in device_map )) or num_devices > 1
1004+ ): # Only 1 card is available or non-auto device map
1005+ block = block .to (output_device )
1006+ return card_0_in_high_risk , loss_device
1007+
10061008 device_list = None
10071009 if isinstance (device_map , str ) and "," in device_map :
10081010 device_list = [int (dev ) for dev in device_map .split ("," ) if dev .isdigit ()]
0 commit comments