Skip to content

Commit 255322f

Browse files
authored
dispatch model with real max memory (#1022)
Signed-off-by: He, Xin3 <xin3.he@intel.com>
1 parent a3d422d commit 255322f

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

auto_round/compressors/base.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@
2525
import accelerate
2626
import torch
2727
from accelerate.big_modeling import dispatch_model, infer_auto_device_map
28-
from accelerate.utils import get_balanced_memory
28+
from accelerate.utils import get_max_memory
2929
from torch import autocast
3030
from tqdm import tqdm
3131
from transformers import set_seed
@@ -1992,11 +1992,7 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
19921992
if str(self.model.device) == "cpu" and (not self.device.startswith("hpu")):
19931993
no_split_modules = getattr(self.model, "_no_split_modules", [])
19941994
devices = parse_available_devices(self.device_map)
1995-
max_memory = get_balanced_memory(
1996-
self.model,
1997-
max_memory=None,
1998-
no_split_module_classes=no_split_modules,
1999-
)
1995+
max_memory = get_max_memory()
20001996
new_max_memory = {}
20011997
for device in devices:
20021998
if ":" in device:

0 commit comments

Comments
 (0)