
Commit 4562cf4

fix gguf fp8 input model and vram issue (#844)
1 parent: 162424d

File tree: 4 files changed (+13, -8 lines)

auto_round/compressors/base.py

Lines changed: 2 additions & 4 deletions
@@ -1195,9 +1195,6 @@ def _quant_rtn_with_imatrix(self, all_to_quantized_module_names: list[str]) -> N
         # Load dataset
         from auto_round.calib_dataset import get_dataloader
 
-        if _is_fp8_model(self.model):
-            convert_fp8_model_to_16b_model(self.model, self.amp_dtype)
-
         if isinstance(self.dataset, str):
             if self.tokenizer is None:
                 raise ValueError("A tokenizer must be set for the model when using a dataset string.")
@@ -1244,6 +1241,8 @@ def get_imatrix_hook(module, input, output):
             dispatch_model(self.model, self.model.hf_device_map)
         else:
             model = model.to(self.device)
+        if _is_fp8_model(self.model):
+            convert_fp8_model_to_16b_model(self.model, self.amp_dtype)
         cnt = 0
 
         # Run forward pass to accumulate imatrix
@@ -1422,7 +1421,6 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
         """
         m = get_module(self.model, name)
 
-        # if m.__class__.__name__ == "FP8Linear":
         if _is_fp8_linear(m):
            m = convert_fp8_layer_to_linear(m, self.amp_dtype)
            set_module(self.model, name, m)
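
In the first two hunks above, the FP8 -> 16-bit conversion no longer happens before the calibration dataset is prepared; it now runs after the model has been dispatched across its device map or moved to the target device, which is part of the VRAM fix in this commit. The sketch below illustrates that ordering; the wrapper function name is hypothetical, while convert_fp8_model_to_16b_model and dispatch_model are the helpers the diff actually calls.

    # Minimal sketch of the reordered flow, not the repository's exact code.
    import torch
    from accelerate import dispatch_model
    from auto_round.utils import convert_fp8_model_to_16b_model  # defined in auto_round/utils.py (see below)

    def place_then_dequant(model, device="cuda", amp_dtype=torch.bfloat16):
        # 1) Place the model first: multi-GPU device map or a single device.
        if getattr(model, "hf_device_map", None) and len(model.hf_device_map) > 1:
            dispatch_model(model, model.hf_device_map)
        else:
            model = model.to(device)
        # 2) Only then expand FP8Linear layers to 16-bit (a no-op for non-FP8 models),
        #    mirroring the hunk that moves this call after device placement.
        convert_fp8_model_to_16b_model(model, amp_dtype)
        return model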

auto_round/export/export_to_gguf/convert.py

Lines changed: 4 additions & 1 deletion
@@ -394,7 +394,10 @@ def prepare_tensors(cls):
         clean_weight_list = []
 
         modify_name = _special_name_handle(cls, name)
+        orig_device = data_torch.device
+        data_torch = data_torch.to("cpu")
         for new_name, data_torch in cls.modify_tensors(data_torch, modify_name, bid):
+            data_torch.to(orig_device)
             skip = False
             for tensor_info in cls.gguf_writer.tensors:
                 if new_name in tensor_info:
@@ -545,7 +548,7 @@ def prepare_tensors(cls):
                 attr_tensor = getattr(module, attr)
             else:
                 attr_tensor = getattr(module, "w_" + attr)
-            if attr_tensor is None:
+            if attr_tensor is None or not isinstance(attr_tensor, torch.Tensor):
                 continue
             kv_b = attr_tensor.view(n_head_kv, v_head_dim + qk_nope_head_dim, -1)
             k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
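
The first hunk parks each tensor on the CPU before cls.modify_tensors runs, so the GGUF export path does not keep the working copies in GPU memory, and the original device is remembered so it can be restored afterwards. Below is a hedged, standalone illustration of that round-trip pattern (the helper name is hypothetical); note that torch.Tensor.to is not in-place, so the returned tensor is what carries the restored device.

    # Illustrative sketch only, assuming a generic transform callable.
    import torch

    def transform_on_cpu(data_torch: torch.Tensor, transform):
        orig_device = data_torch.device               # remember where the tensor lived
        cpu_result = transform(data_torch.to("cpu"))  # do the work on a CPU copy
        return cpu_result.to(orig_device)             # .to() returns a new tensor; use the result

    # e.g. cleaned = transform_on_cpu(weight, lambda t: t.float().contiguous())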

auto_round/export/export_to_gguf/convert_hf_to_gguf.py

Lines changed: 2 additions & 0 deletions
@@ -41,9 +41,11 @@
 
 if "NO_LOCAL_GGUF" not in os.environ:
     sys.path.insert(1, str(Path(__file__).parent / "gguf-py"))
+
 from auto_round.utils import LazyImport
 
 gguf = LazyImport("gguf")
+
 MistralTokenizerType = LazyImport("gguf.vocab.MistralTokenizerType")
 MistralVocab = LazyImport("gguf.vocab.MistralVocab")
 DATASET_MEAN = LazyImport("mistral_common.tokens.tokenizers.multimodal.DATASET_MEAN")
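
These gguf symbols are bound through auto_round's LazyImport, so the underlying modules are only imported the first time they are used. The sketch below shows the general shape of such a lazy-import helper; it is an assumption for illustration, not auto_round's actual implementation (which also accepts dotted attribute paths such as "gguf.vocab.MistralVocab").

    import importlib

    class LazyModule:
        # Hypothetical stand-in for a LazyImport-style helper: defer the real
        # import until an attribute is first accessed.
        def __init__(self, module_name):
            self._module_name = module_name
            self._module = None

        def __getattr__(self, attr):
            if self._module is None:
                self._module = importlib.import_module(self._module_name)
            return getattr(self._module, attr)

    # e.g. gguf = LazyModule("gguf"); gguf.GGUFWriter is only resolved on first access.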

auto_round/utils.py

Lines changed: 5 additions & 3 deletions
@@ -79,8 +79,7 @@ def __getitem__(self, key):
 
 # Changed to str as it relies on triton or others lib to load this
 INNER_SUPPORTED_LAYER_TYPES = ("FP8Linear",)
-# INNER_SUPPORTED_LAYER_TYPES = (transformers.integrations.finegrained_fp8.FP8Linear,)
-
+# transformers.integrations.finegrained_fp8.FP8Linear
 if deepspeed_exists:
     from deepspeed.module_inject import LinearAllreduce, LinearLayer
 
@@ -2298,10 +2297,14 @@ def convert_fp8_model_to_16b_model(model, dtype=torch.bfloat16):
     Convert a model with FP8 quantized layers to a model with 16-bit linear layers.
     This is useful for compatibility with other frameworks or for further processing.
     """
+    cnt = 0
     for n, m in model.named_modules():
         if m.__class__.__name__ == "FP8Linear":
             new_module = convert_fp8_layer_to_linear(m, dtype=dtype)
             set_module(model, n, new_module)
+            cnt += 1
+            if cnt % 10 == 0:  # Tricky setting
+                clear_memory()
     return model
 
 
@@ -2344,7 +2347,6 @@ def download_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None):
     """Download hugging face model from hf hub."""
     from huggingface_hub.constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE
     from huggingface_hub.file_download import REGEX_COMMIT_HASH, repo_folder_name
-    from huggingface_hub.utils import EntryNotFoundError
 
     if cache_dir is None:
         cache_dir = HUGGINGFACE_HUB_CACHE