 import torch
 
 from ._quant_common.quant_config import local_rank, world_size
-from neural_compressor.torch.utils import get_accelerator
-
+from neural_compressor.torch.utils import get_accelerator, is_optimum_habana_available
 
 MAX_FILE_SIZE = 5  # GB
 cur_accelerator = get_accelerator()
@@ -153,12 +152,36 @@ def load_empty_raw_model(model_name_or_path, **kwargs):
     """Initialize BF16 model with meta tensor."""
     import transformers
     from accelerate import init_empty_weights
+    config = transformers.AutoConfig.from_pretrained(model_name_or_path, **kwargs)
+    # fp8 model provided by neuralmagic.
+    if (
+        "quant_method" in config.quantization_config
+        and config.quantization_config["quant_method"] in ["fp8", "compressed-tensors"]
+    ):
+        from_neuralmagic = True
+        if (
+            "kv_cache_scheme" in config.quantization_config
+            and config.quantization_config["kv_cache_scheme"] is not None
+        ):
+            from_neuralmagic_with_kv = True
+        else:
+            from_neuralmagic_with_kv = False
+    else:
+        from_neuralmagic = False
+        from_neuralmagic_with_kv = False
+
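+    # For illustration, a quantization_config such as
+    #     {"quant_method": "compressed-tensors", "kv_cache_scheme": {"num_bits": 8, "type": "float"}, ...}
+    # would set both from_neuralmagic and from_neuralmagic_with_kv to True.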
+    if from_neuralmagic_with_kv:
+        config.flash_attention_fp8 = True
+        if is_optimum_habana_available():
+            from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+            adapt_transformers_to_gaudi()
+        else:
+            raise ValueError("Please install optimum-habana to load fp8 kv cache model.")
+
     from neural_compressor.torch.utils import get_non_persistent_buffers, load_non_persistent_buffers
 
     if world_size > 1:
         import deepspeed
-
-        config = transformers.AutoConfig.from_pretrained(model_name_or_path, **kwargs)
         with init_empty_weights(include_buffers=False):
             model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
         # TODO: [SW-199728] [DeepSpeed] Buffers initialized by model are not correct after tensor parallel
@@ -172,10 +195,9 @@ def load_empty_raw_model(model_name_or_path, **kwargs):
         model = model.module
         load_non_persistent_buffers(model, non_persistent_buffers)
     else:
-        config = transformers.AutoConfig.from_pretrained(model_name_or_path, **kwargs)
         with init_empty_weights(include_buffers=False):
             model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
-    return model
+    return model, from_neuralmagic, from_neuralmagic_with_kv
 
 
 def find_safetensors_files(model_name_or_path, **kwargs):
@@ -205,6 +227,9 @@ def find_safetensors_files(model_name_or_path, **kwargs):
             resolved_archive_file,
             **kwargs,
         )
+    # for models that ship only a single model.safetensors file.
+    if isinstance(resolved_archive_file, str):
+        resolved_archive_file = [resolved_archive_file]
     return resolved_archive_file
 
 
@@ -219,6 +244,57 @@ def shard_state_dict(state_dict):
             rank_state_dict[k] = v.to("hpu")
     return rank_state_dict
 
+def split_rank_state_dict(model, gathered_state_dict):
+    """Split the state_dict so each process keeps only the shard for its local_rank."""
+    rank_state_dict = {}
+    for name, param in model.named_parameters():
+        if name in gathered_state_dict:
+            full_weight = gathered_state_dict[name]
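+            # Infer the tensor-parallel split axis by comparing the full checkpoint shape
+            # with the local (already sharded) parameter shape.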
+            if len(param.shape) != 0 and full_weight.shape != param.shape:
+                if full_weight.shape[0] != param.shape[0]:
+                    split_weight = split_weights(full_weight, world_size, local_rank, split_axis=0)
+                elif full_weight.shape[1] != param.shape[1]:
+                    split_weight = split_weights(full_weight, world_size, local_rank, split_axis=1)
+                else:
+                    split_weight = split_weights(full_weight, world_size, local_rank, split_axis=0)
+            else:
+                split_weight = full_weight
+            rank_state_dict[name] = split_weight
+
+    return rank_state_dict
+
+
+def get_inc_fp8config(model, from_neuralmagic=False, from_neuralmagic_with_kv=False):
+    """Get the INC FP8 config.
+
+    Args:
+        model: empty model.
+        from_neuralmagic (bool, optional): whether the model comes from the neuralmagic model hub.
+        from_neuralmagic_with_kv (bool, optional): whether the model comes from the neuralmagic model hub and has a quantized kv_cache.
+
+    Returns:
+        INC FP8 config.
+    """
+    from neural_compressor.torch.quantization import FP8Config
+    if from_neuralmagic:
+        if "ignore" in model.config.quantization_config.keys():
+            blocklist = {"types": [], "names": model.config.quantization_config["ignore"]}
+        elif "ignored_layers" in model.config.quantization_config.keys():
+            blocklist = {"types": [], "names": model.config.quantization_config["ignored_layers"]}
+        else:
+            blocklist = {"types": [], "names": ["lm_head"]}
+        if "target" in model.config.quantization_config.keys():
+            allowlist = {"types": model.config.quantization_config["target"], "names": []}
+        else:
+            if from_neuralmagic_with_kv:
+                allowlist = {"types": ["Linear", "LinearLayer", "LinearAllreduce", "KVCache"], "names": []}
+            else:
+                allowlist = {"types": ["Linear", "LinearLayer", "LinearAllreduce"], "names": []}
+        qconfig = FP8Config(mode="LOAD", allowlist=allowlist, blocklist=blocklist, scale_format="CONST")
+    else:
+        qconfig = FP8Config.from_dict(model.config.quantization_config)
+    return qconfig
+
 
 def load(model_name_or_path, format="huggingface", device="hpu", **kwargs):
     """Load FP8 model.
@@ -236,12 +312,12 @@ def load(model_name_or_path, format="huggingface", device="hpu", **kwargs):
     assert device == "hpu", "Currently, only hpu device is supported for FP8 model."
     from safetensors.torch import load_file as safe_load_file
 
-    model = load_empty_raw_model(model_name_or_path, **kwargs)
     from neural_compressor.torch.algorithms.fp8_quant import prep_model
-    from neural_compressor.torch.quantization import FP8Config
 
-    qconfig = FP8Config.from_dict(model.config.quantization_config)
+    model, from_neuralmagic, from_neuralmagic_with_kv = load_empty_raw_model(model_name_or_path, **kwargs)
+    qconfig = get_inc_fp8config(model, from_neuralmagic, from_neuralmagic_with_kv)
     qconfig.save_temp_json_file()  # generate qconfig.json_file
+
     # replace modules to patched modules
     prep_model(model, qconfig.json_file)
     # get the safetensors file list from one folder
@@ -250,15 +326,106 @@ def load(model_name_or_path, format="huggingface", device="hpu", **kwargs):
     for file_name in files_list:
         cur_file = os.path.join(model_name_or_path, file_name)
         gathered_state_dict = safe_load_file(cur_file)
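+        # neuralmagic checkpoints use vLLM-style tensor names and scales; convert them to
+        # INC naming, and requantize weights on Gaudi2 (see convert_weight_to_inc below).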
+        if from_neuralmagic or from_neuralmagic_with_kv:
+            import habana_frameworks.torch.utils.experimental as htexp
+            gathered_state_dict = convert_weight_to_inc(
+                state_dict=gathered_state_dict,
+                on_gaudi2=htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2
+            )
         if world_size > 0:
             # only return state_dict for the current local_rank
-            rank_state_dict = shard_state_dict(gathered_state_dict)
+            if from_neuralmagic or from_neuralmagic_with_kv:
+                rank_state_dict = split_rank_state_dict(model, gathered_state_dict)
+            else:
+                rank_state_dict = shard_state_dict(gathered_state_dict)
             model.load_state_dict(rank_state_dict, assign=True, strict=False)
         else:
             model.load_state_dict(gathered_state_dict, assign=True, strict=False)
+
+    if from_neuralmagic or from_neuralmagic_with_kv:
+        model.tie_weights()
     model = model.eval()
     model = model.to(cur_accelerator.name())
+
     cur_accelerator.synchronize()
     # make sure cpu and hpu memory are all released.
     gc.collect()
     return model
+
+
+def convert_weight_to_inc(state_dict, on_gaudi2=False):
+    """Convert vLLM-compatible FP8 model weights to the INC format: operator names
+    differ, and on Gaudi2 the weights must be requantized because its FP8 range is
+    limited to [-240, 240].
+
+    Args:
+        state_dict (dict): state_dict from the model hub.
+        on_gaudi2 (bool, optional): whether running on Gaudi2. Defaults to False.
+
+    Returns:
+        dict: state_dict with weights and scales adapted to the INC format.
+    """
+    key_name = state_dict.keys()
+    for key in list(key_name):
+        if "weight_scale" in key:
+            scale_weight = key.replace("weight_scale", "scale_weight")
+            if on_gaudi2:
+                # dequantize the FP8 weight to bf16 with the stored scale
+                weight_key = key.replace("weight_scale", "weight")
+                qweight = state_dict[weight_key].t().to(torch.bfloat16).to("hpu")
+                scale = state_dict[key].to("hpu")
+                dequant_weight = qweight * scale
+                # recompute the scale and requantize so the weight fits Gaudi2's FP8 range (max 240 instead of 448)
+                recompute_scale = scale * (torch.finfo(torch.float8_e4m3fn).max /
+                                           torch.finfo(torch.float8_e4m3fnuz).max)
+                qweight = torch.ops.hpu.cast_to_fp8_v2(dequant_weight, 1.0 / recompute_scale, False, False, torch.float8_e4m3fn)[0]
+                state_dict[weight_key] = qweight
+                state_dict[scale_weight] = recompute_scale
+            else:
+                state_dict[scale_weight] = state_dict[key].to("hpu")
+            state_dict.pop(key)
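+        # INC's FP8 KVCache patch expects separate k/v scales (quant_input.scale_inv and
+        # dequant_output.scale); derive all four from the single vLLM kv_scale tensor.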
+        elif "kv_scale" in key:
+            k_scale_inv = key.replace("kv_scale", "k_cache.quant_input.scale_inv")
+            v_scale_inv = key.replace("kv_scale", "v_cache.quant_input.scale_inv")
+            k_scale = key.replace("kv_scale", "k_cache.dequant_output.scale")
+            v_scale = key.replace("kv_scale", "v_cache.dequant_output.scale")
+            state_dict[k_scale_inv] = 1 / state_dict[key].to("hpu")
+            state_dict[v_scale_inv] = 1 / state_dict[key].to("hpu")
+            state_dict[k_scale] = state_dict[key].to("hpu")
+            state_dict[v_scale] = state_dict[key].to("hpu")
+            state_dict.pop(key)
+        elif "input_scale" in key:
+            scale_input_inv = key.replace("input_scale", "quant_input.scale_inv")
+            scale_input = key.replace("input_scale", "scale_input")
+            state_dict[scale_input_inv] = 1 / state_dict[key].to("hpu")
+            state_dict[scale_input] = state_dict[key].to("hpu")
+            state_dict.pop(key)
+        elif "proj.weight" in key and not on_gaudi2:
+            state_dict[key] = state_dict[key].detach().t().to("hpu")
+        else:
+            pass
+    return state_dict
+
+
+def split_weights(weight, tp_size, tp_rank, split_axis=0):
+    """Slice a full weight tensor into the shard owned by the given tensor-parallel rank.
+
+    Args:
+        weight (torch.Tensor): weight tensor.
+        tp_size (int): tensor parallel size.
+        tp_rank (int): tensor parallel rank.
+        split_axis (int): axis to split along, 0 (rows) or 1 (columns).
+    Returns:
+        torch.Tensor: split weight tensor.
+    """
+    split_size = weight.shape[split_axis] // tp_size
+    start_idx = tp_rank * split_size
+    end_idx = (tp_rank + 1) * split_size
+
+    if len(weight.shape) == 1:
+        return weight[start_idx:end_idx]
+    elif split_axis == 0:
+        return weight[start_idx:end_idx, :]
+    elif split_axis == 1:
+        return weight[:, start_idx:end_idx]
+    else:
+        raise ValueError("split_axis must be 0 (row) or 1 (column).")