
Commit 664f2af

changwangss authored and chensuyue committed
[SW-207579] support load vLLM compatible FP8 model (#18)
Support loading vLLM-compatible FP8 models on both Gaudi2 (G2) and Gaudi3 (G3), in both single-card and multi-card setups. --------- Signed-off-by: changwang <changwang@habana.ai>
1 parent fd8aed0 commit 664f2af

File tree

3 files changed: +204 -11 lines changed

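For context, a minimal usage sketch of the loading path this commit enables, mirroring the test added below (the checkpoint name is taken from that test; a Gaudi host with the Habana PyTorch bridge, neural_compressor, and the model's tokenizer dependencies is assumed):

import transformers
from neural_compressor.torch.quantization import load

# Load a vLLM-compatible FP8 checkpoint published on the neuralmagic hub directly onto HPU.
model_name = "neuralmagic/Qwen2-0.5B-Instruct-FP8"
model = load(model_name, format="huggingface", device="hpu")

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer("There existed a little girl, who liked to have adventures.", return_tensors="pt").input_ids.to("hpu")
gen_ids = model.generate(input_ids, max_new_tokens=5, do_sample=False)
print(tokenizer.decode(gen_ids[0]))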

neural_compressor/torch/algorithms/fp8_quant/save_load.py

Lines changed: 177 additions & 10 deletions
@@ -20,8 +20,7 @@
 import torch
 
 from ._quant_common.quant_config import local_rank, world_size
-from neural_compressor.torch.utils import get_accelerator
-
+from neural_compressor.torch.utils import get_accelerator, is_optimum_habana_available
 
 MAX_FILE_SIZE = 5 # GB
 cur_accelerator = get_accelerator()
@@ -153,12 +152,36 @@ def load_empty_raw_model(model_name_or_path, **kwargs):
     """Initialize BF16 model with meta tensor."""
     import transformers
     from accelerate import init_empty_weights
+    config = transformers.AutoConfig.from_pretrained(model_name_or_path, **kwargs)
+    # fp8 model provided by neuralmagic.
+    if (
+        "quant_method" in config.quantization_config
+        and config.quantization_config["quant_method"] in ["fp8", "compressed-tensors"]
+    ):
+        from_neuralmagic = True
+        if (
+            "kv_cache_scheme" in config.quantization_config
+            and config.quantization_config["kv_cache_scheme"] is not None
+        ):
+            from_neuralmagic_with_kv = True
+        else:
+            from_neuralmagic_with_kv = False
+    else:
+        from_neuralmagic = False
+        from_neuralmagic_with_kv = False
+
+    if from_neuralmagic_with_kv:
+        config.flash_attention_fp8 = True
+        if is_optimum_habana_available:
+            from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+            adapt_transformers_to_gaudi()
+        else:
+            raise ValueError("Please install optimum-habana to load fp8 kv cache model.")
+
     from neural_compressor.torch.utils import get_non_persistent_buffers, load_non_persistent_buffers
 
     if world_size > 1:
         import deepspeed
-
-        config = transformers.AutoConfig.from_pretrained(model_name_or_path, **kwargs)
         with init_empty_weights(include_buffers=False):
             model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
         # TODO: [SW-199728] [DeepSpeed] Buffers initialized by model are not correct after tensor parallel
@@ -172,10 +195,9 @@ def load_empty_raw_model(model_name_or_path, **kwargs):
         model = model.module
         load_non_persistent_buffers(model, non_persistent_buffers)
     else:
-        config = transformers.AutoConfig.from_pretrained(model_name_or_path, **kwargs)
         with init_empty_weights(include_buffers=False):
             model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
-    return model
+    return model, from_neuralmagic, from_neuralmagic_with_kv
 
 
 def find_safetensors_files(model_name_or_path, **kwargs):
@@ -205,6 +227,9 @@ def find_safetensors_files(model_name_or_path, **kwargs):
         resolved_archive_file,
         **kwargs,
     )
+    # for models with only one model.safetensors file.
+    if isinstance(resolved_archive_file, str):
+        resolved_archive_file = [resolved_archive_file]
     return resolved_archive_file
 
 
@@ -219,6 +244,57 @@ def shard_state_dict(state_dict):
         rank_state_dict[k] = v.to("hpu")
     return rank_state_dict
 
+def split_rank_state_dict(model, gathered_state_dict):
+    """Split the state_dict for the current local_rank."""
+    rank_state_dict = {}
+    for name, param in model.named_parameters():
+        if name in gathered_state_dict:
+            full_weight = gathered_state_dict[name]
+            if len(param.shape) != 0 and full_weight.shape != param.shape:
+                if full_weight.shape[0] != param.shape[0]:
+                    split_weight = split_weights(full_weight, world_size, local_rank, split_axis=0)
+                elif full_weight.shape[1] != param.shape[1]:
+                    split_weight = split_weights(full_weight, world_size, local_rank, split_axis=1)
+                else:
+                    split_weight = split_weights(full_weight, world_size, local_rank, split_axis=0)
+            else:
+                split_weight = full_weight
+            rank_state_dict[name] = split_weight
+
+    return rank_state_dict
+
+
+def get_inc_fp8config(model, from_neuralmagic=False, from_neuralmagic_with_kv=False):
+    """Get INC FP8 Config.
+
+    Args:
+        model: empty model.
+        from_neuralmagic (bool, optional): whether the model is provided by the neuralmagic model hub.
+        from_neuralmagic_with_kv (bool, optional): whether the model is provided by the neuralmagic model hub with quantized kv_cache.
+
+    Returns:
+        INC FP8 Config.
+    """
+    from neural_compressor.torch.quantization import FP8Config
+    if from_neuralmagic:
+        if "ignore" in model.config.quantization_config.keys():
+            blocklist = {"types": [], "names": model.config.quantization_config["ignore"]}
+        elif "ignored_layers" in model.config.quantization_config.keys():
+            blocklist = {"types": [], "names": model.config.quantization_config["ignored_layers"]}
+        else:
+            blocklist = {"types": [], "names": ["lm_head"]}
+        if "target" in model.config.quantization_config.keys():
+            allowlist = {"types": model.config.quantization_config["target"], "names": []}
+        else:
+            if from_neuralmagic_with_kv:
+                allowlist = {"types": ["Linear", "LinearLayer", "LinearAllreduce", "KVCache"], "names": []}
+            else:
+                allowlist = {"types": ["Linear", "LinearLayer", "LinearAllreduce"], "names": []}
+        qconfig = FP8Config(mode="LOAD", allowlist=allowlist, blocklist=blocklist, scale_format="CONST")
+    else:
+        qconfig = FP8Config.from_dict(model.config.quantization_config)
+    return qconfig
+
 
 def load(model_name_or_path, format="huggingface", device="hpu", **kwargs):
     """Load FP8 model.
@@ -236,12 +312,12 @@ def load(model_name_or_path, format="huggingface", device="hpu", **kwargs):
     assert device == "hpu", "Currently, only hpu device is supported for FP8 model."
     from safetensors.torch import load_file as safe_load_file
 
-    model = load_empty_raw_model(model_name_or_path, **kwargs)
     from neural_compressor.torch.algorithms.fp8_quant import prep_model
-    from neural_compressor.torch.quantization import FP8Config
 
-    qconfig = FP8Config.from_dict(model.config.quantization_config)
+    model, from_neuralmagic, from_neuralmagic_with_kv = load_empty_raw_model(model_name_or_path, **kwargs)
+    qconfig = get_inc_fp8config(model, from_neuralmagic, from_neuralmagic_with_kv)
     qconfig.save_temp_json_file() # generate qconfig.json_file
+
     # replace modules to patched modules
     prep_model(model, qconfig.json_file)
     # get the safetensors file list from one folder
@@ -250,15 +326,106 @@ def load(model_name_or_path, format="huggingface", device="hpu", **kwargs):
     for file_name in files_list:
         cur_file = os.path.join(model_name_or_path, file_name)
         gathered_state_dict = safe_load_file(cur_file)
+        if from_neuralmagic or from_neuralmagic_with_kv:
+            import habana_frameworks.torch.utils.experimental as htexp
+            gathered_state_dict = convert_weight_to_inc(
+                state_dict=gathered_state_dict,
+                on_gaudi2=htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2
+            )
         if world_size > 0:
             # only return state_dict for the current local_rank
-            rank_state_dict = shard_state_dict(gathered_state_dict)
+            if from_neuralmagic or from_neuralmagic_with_kv:
+                rank_state_dict = split_rank_state_dict(model, gathered_state_dict)
+            else:
+                rank_state_dict = shard_state_dict(gathered_state_dict)
             model.load_state_dict(rank_state_dict, assign=True, strict=False)
         else:
             model.load_state_dict(gathered_state_dict, assign=True, strict=False)
+
+    if from_neuralmagic or from_neuralmagic_with_kv:
+        model.tie_weights()
     model = model.eval()
     model = model.to(cur_accelerator.name())
+
     cur_accelerator.synchronize()
     # make sure cpu and hpu memory are all released.
     gc.collect()
     return model
+
+
+def convert_weight_to_inc(state_dict, on_gaudi2=False):
+    """Convert vLLM-compatible FP8 model weights to INC format:
+    operator names differ between the two formats, and on G2 the weights must be
+    re-quantized because the hardware FP8 range is limited to [-240, 240].
+
+    Args:
+        state_dict (dict): state_dict from the model hub.
+        on_gaudi2 (bool, optional): whether running on Gaudi2. Defaults to False.
+
+    Returns:
+        state_dict with weights and scales adapted to INC format.
+    """
+    key_name = state_dict.keys()
+    for key in list(key_name):
+        if "weight_scale" in key:
+            scale_weight = key.replace("weight_scale", "scale_weight")
+            if on_gaudi2:
+                # dequant_weight
+                weight_key = key.replace("weight_scale", "weight")
+                qweight = state_dict[weight_key].t().to(torch.bfloat16).to("hpu")
+                scale = state_dict[key].to("hpu")
+                dequant_weight = qweight * scale
+                # recompute scale, qweight
+                recompute_scale = scale * (torch.finfo(torch.float8_e4m3fn).max /
+                                           torch.finfo(torch.float8_e4m3fnuz).max)
+                qweight = torch.ops.hpu.cast_to_fp8_v2(dequant_weight, 1.0 / recompute_scale, False, False, torch.float8_e4m3fn)[0]
+                state_dict[weight_key] = qweight
+                state_dict[scale_weight] = recompute_scale
+            else:
+                state_dict[scale_weight] = state_dict[key].to("hpu")
+            state_dict.pop(key)
+        elif "kv_scale" in key:
+            k_scale_inv = key.replace("kv_scale", "k_cache.quant_input.scale_inv")
+            v_scale_inv = key.replace("kv_scale", "v_cache.quant_input.scale_inv")
+            k_scale = key.replace("kv_scale", "k_cache.dequant_output.scale")
+            v_scale = key.replace("kv_scale", "v_cache.dequant_output.scale")
+            state_dict[k_scale_inv] = 1 / state_dict[key].to("hpu")
+            state_dict[v_scale_inv] = 1 / state_dict[key].to("hpu")
+            state_dict[k_scale] = state_dict[key].to("hpu")
+            state_dict[v_scale] = state_dict[key].to("hpu")
+            state_dict.pop(key)
+        elif "input_scale" in key:
+            scale_input_inv = key.replace("input_scale", "quant_input.scale_inv")
+            scale_input = key.replace("input_scale", "scale_input")
+            state_dict[scale_input_inv] = 1 / state_dict[key].to("hpu")
+            state_dict[scale_input] = state_dict[key].to("hpu")
+            state_dict.pop(key)
+        elif "proj.weight" in key and not on_gaudi2:
+            state_dict[key] = state_dict[key].detach().t().to("hpu")
+        else:
+            pass
+    return state_dict
+
+
+def split_weights(weight, tp_size, tp_rank, split_axis=0):
+    """Return the weight slice owned by the given tensor-parallel rank.
+    Args:
+        weight (torch.Tensor): weight tensor.
+        tp_size (int): tensor parallel size.
+        tp_rank (int): tensor parallel rank.
+        split_axis (int): axis to split along, 0 (row) or 1 (column).
+    Returns:
+        torch.Tensor: split weight tensor.
+    """
+    split_size = weight.shape[split_axis] // tp_size
+    start_idx = tp_rank * split_size
+    end_idx = (tp_rank + 1) * split_size
+
+    if len(weight.shape) == 1:
+        return weight[start_idx:end_idx]
+    elif split_axis == 0:
+        return weight[start_idx:end_idx, :]
+    elif split_axis == 1:
+        return weight[:, start_idx:end_idx]
+    else:
+        raise ValueError("split_axis must be 0 (row) or 1 (column).")

neural_compressor/torch/quantization/save_load_entry.py

Lines changed: 8 additions & 1 deletion
@@ -126,7 +126,14 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu", **kwargs):
         import transformers
         config = transformers.AutoConfig.from_pretrained(model_name_or_path, **kwargs)
         # use config to check which algorithm is used.
-        if "fp8_config" in config.quantization_config:
+        if (
+            "fp8_config" in config.quantization_config or
+            # for FP8 LLMs for vLLM (https://huggingface.co/neuralmagic).
+            (
+                "quant_method" in config.quantization_config and
+                config.quantization_config["quant_method"] in ["fp8", "compressed-tensors"]
+            )
+        ):
             from neural_compressor.torch.algorithms import fp8_quant
             return fp8_quant.load(model_name_or_path, format=format, device=device, **kwargs)
         else:
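The dispatch above keys off config.quantization_config. A hypothetical quantization_config for such a checkpoint (only the fields inspected by this commit are shown; the values are illustrative, not taken from a real model card) and the same membership test, runnable standalone:

# Hypothetical quantization_config carried by a vLLM-compatible FP8 checkpoint's config.json.
quantization_config = {
    "quant_method": "compressed-tensors",  # or "fp8"
    "ignore": ["lm_head"],                 # becomes the INC blocklist in get_inc_fp8config
    "kv_cache_scheme": None,               # a non-None scheme would enable the FP8 KV-cache path
}

routes_to_fp8_quant = "fp8_config" in quantization_config or (
    "quant_method" in quantization_config
    and quantization_config["quant_method"] in ["fp8", "compressed-tensors"]
)
print(routes_to_fp8_quant)  # True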

test/3x/torch/quantization/fp8_quant/test_multi_device.py

Lines changed: 19 additions & 0 deletions
@@ -7,6 +7,7 @@
 
 from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import local_rank, world_size
 from neural_compressor.torch.quantization import FP8Config, convert, load, prepare, save
+from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import PatchedLinear
 
 
 def get_hpu_used_mem():
@@ -24,6 +25,23 @@ def calib_func(model):
     model(example_inputs)
 
 
+def test_load_model_provided_by_neuralmagic():
+    model_name_or_path = "neuralmagic/Qwen2-0.5B-Instruct-FP8"
+    model = load(model_name_or_path, format="huggingface", device="hpu")
+    assert isinstance(model, torch.nn.Module)
+    assert isinstance(model.model.layers[0].self_attn.q_proj, PatchedLinear)
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
+    prompt = "There existed a little girl, who liked to have adventures."
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("hpu")
+    generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=1)
+    gen_ids = model.generate(
+        input_ids,
+        max_new_tokens=5,
+        **generate_kwargs,
+    )
+    assert isinstance(gen_ids, torch.Tensor)
+
+
 def test_multi_cards_save_load():
     name = "facebook/opt-350m"
     if world_size > 0:
@@ -58,3 +76,4 @@ def test_multi_cards_save_load():
 
 if __name__ == "__main__":
     test_multi_cards_save_load()
+    test_load_model_provided_by_neuralmagic()
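For multi-card runs with these checkpoints, save_load.py shards each full weight across ranks via split_rank_state_dict/split_weights instead of shard_state_dict. A minimal CPU sketch of that row/column slicing for a hypothetical 2-rank setup, mirroring the split_weights helper added above:

import torch

def split_weights(weight, tp_size, tp_rank, split_axis=0):
    # Mirror of the helper in save_load.py: keep this rank's contiguous slice along split_axis.
    split_size = weight.shape[split_axis] // tp_size
    start_idx, end_idx = tp_rank * split_size, (tp_rank + 1) * split_size
    if len(weight.shape) == 1:
        return weight[start_idx:end_idx]
    return weight[start_idx:end_idx, :] if split_axis == 0 else weight[:, start_idx:end_idx]

full_weight = torch.arange(12.0).reshape(4, 3)  # hypothetical full weight of shape [out=4, in=3]
print(split_weights(full_weight, tp_size=2, tp_rank=0, split_axis=0).shape)  # torch.Size([2, 3]), row-parallel slice
print(split_weights(full_weight, tp_size=2, tp_rank=1, split_axis=1).shape)  # torch.Size([4, 1]), column-parallel slice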
