@@ -12,19 +12,17 @@
 import shutil
 import sys
 import warnings
-from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
 from pprint import pprint
-from typing import Any, Generator, Optional, Union
+from typing import Any, Optional, Union

 import torch
 import torch.nn.functional as F
 import transformers
 from datasets import load_dataset
 from lm_eval import simple_evaluate
 from lm_eval.models.optimum_lm import OptimumLM
-from lm_eval.tasks import TaskManager
 from optimum.exporters.openvino.convert import export_from_model
 from optimum.intel.openvino import OVModelForCausalLM
 from optimum.modeling_base import OptimizedModel
@@ -42,6 +40,7 @@
 from nncf.parameters import CompressionFormat
 from nncf.parameters import CompressWeightsMode
 from nncf.parameters import StripFormat
+from nncf.quantization.advanced_parameters import AdvancedAWQParameters
 from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
 from nncf.quantization.quantize_model import compress_weights
 from nncf.torch.function_hook.wrapper import get_hook_storage
@@ -76,64 +75,24 @@ def get_wikitext2(num_samples: int, seqlen: int, tokenizer: Any, device: torch.d
     return trainloader


-@contextmanager
-def create_eval_model(
-    model: AutoModelForCausalLM,
-    fast_eval: bool,
-    pretrained: str,
-    torch_dtype: torch.dtype,
-    ckpt_file: Path,
-) -> Generator[AutoModelForCausalLM, None, None]:
-    """
-    Context manager for creating an evaluation model with appropriate cleanup.
-
-    If fast_eval is True, creates a new model for evaluation that will be
-    automatically deleted when the context exits. Otherwise, uses the provided model.
-
-    :param model: Original model to use if fast_eval is False.
-    :param fast_eval: Whether to create a new optimized model for evaluation.
-    :param pretrained: Pretrained model identifier or path for AutoModelForCausalLM.
-    :param torch_dtype: PyTorch data type to use for the model (e.g., torch.bfloat16).
-    :param ckpt_file: Path to the checkpoint file to load weights from.
-    :yields: Model to use for evaluation, either the new loaded model or the given one.
-    """
-    if fast_eval:
-        eval_model = AutoModelForCausalLM.from_pretrained(pretrained, torch_dtype=torch_dtype, device_map="auto")
-        eval_model = load_checkpoint(eval_model, ckpt_file)
-        device = next(model.parameters()).device
-        example_input = {k: v.to(device) for k, v in eval_model.dummy_inputs.items()}
-        eval_model = nncf.strip(
-            eval_model, do_copy=False, strip_format=StripFormat.IN_PLACE, example_input=example_input
-        )
-        try:
-            yield eval_model
-        finally:
-            del eval_model
-    else:
-        yield model
-
-
 def measure_perplexity(
     optimum_model: OptimizedModel,
-    task_manager: TaskManager,
     max_length: Optional[int] = None,
     limit: Optional[Union[int, float]] = None,
-    task="wikitext_validation",
 ) -> float:
     """
     Measure perplexity on the Wikitext dataset, via rolling loglikelihoods for a given model.

     :param optimum_model: A model to be evaluated.
-    :param task_manager: The TaskManager instance that handles dataset loading and processing.
     :param max_length: The maximum sequence length for evaluation.
     :param limit: Limit the number of examples per task (only use this for testing).
         If <1, limit is a percentage of the total number of examples.
-    :param task: The evaluation task name to use, defaults to "wikitext_validation".
     :return: The similarity score as a float.
     """
+    task = "wikitext"
     print("#" * 50 + " Evaluate via lm-eval-harness " + "#" * 50)
     lm_obj = OptimumLM(pretrained=optimum_model, max_length=max_length)
-    results = simple_evaluate(lm_obj, tasks=[task], limit=limit, task_manager=task_manager, log_samples=False)
+    results = simple_evaluate(lm_obj, tasks=[task], limit=limit, log_samples=False)
     return results["results"][task]["word_perplexity,none"]


@@ -223,15 +182,20 @@ def set_trainable(model: nn.Module, lora_lr: float, fq_lr: float) -> list[dict[s
     return [{"params": adapters_to_train, "lr": lora_lr}, {"params": scales_to_train, "lr": fq_lr}]


-def save_checkpoint(model: nn.Module, ckpt_file: Path) -> None:
+def save_checkpoint(model: nn.Module, ckpt_file: Path, model_state: bool = True) -> None:
     """
-    Saves the state of a tuned model from a checkpoint.
+    Stores the current state of a quantized model to a checkpoint file.

-    :param model: The model to load the checkpoint into.
-    :param ckpt_file: Path to the checkpoint file.
+    :param model: The model whose state will be saved to the checkpoint.
+    :param ckpt_file: Path to store the checkpoint file.
+    :param model_state: Whether to save the complete model weights in addition to the NNCF state. Required when
+        using the AWQ method, which fuses scaling factors into the weights. When False, only the NNCF configuration
+        and state are saved, as they are maintained separately from the model's weights.
     """
     hook_storage = get_hook_storage(model)
     ckpt = {"nncf_state_dict": hook_storage.state_dict(), "nncf_config": nncf.torch.get_config(model)}
+    if model_state:
+        ckpt["model_state"] = model.state_dict()
     torch.save(ckpt, ckpt_file)


@@ -246,6 +210,8 @@ def load_checkpoint(model: nn.Module, ckpt_file: Path) -> nn.Module:
     """
     ckpt = torch.load(ckpt_file, weights_only=False, map_location="cpu")
     model = load_from_config(model, ckpt["nncf_config"])
+    if "model_state" in ckpt:
+        model.load_state_dict(ckpt["model_state"])
     hook_storage = get_hook_storage(model)
     hook_storage.load_state_dict(ckpt["nncf_state_dict"])
     return model
@@ -306,16 +272,17 @@ def get_argument_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument("--lora_rank", type=int, default=256, help="Rank of lora adapters")
     parser.add_argument(
-        "--fast_eval",
+        "--basic_init",
         action="store_true",
-        help="Enable faster evaluation by applying in-place quantization to the model weights. "
-        "This method uses additional GPU memory for memory copying. By default, evaluation is slower "
-        "but conserves GPU memory.",
+        help="Whether to initialize quantization with the basic min-max round-to-nearest scheme. By default, "
+        "advanced data-aware post-training methods are used: AWQ + Scale Estimation. These methods typically "
+        "provide better accuracy, but require a calibration dataset and additional initialization time "
+        "(~20 sec for 1B and ~80 sec for 8B models).",
     )

     # Data params
     parser.add_argument("--num_train_samples", type=int, default=1024, help="Number of training samples")
-    parser.add_argument("--calib_seqlen", type=int, default=1024, help="Calibration data context length.")
+    parser.add_argument("--train_seqlen", type=int, default=1024, help="Training data context length.")
     parser.add_argument("--eval_seqlen", type=int, default=2048, help="Evaluation data context length.")
     parser.add_argument(
         "--limit",
@@ -338,7 +305,7 @@ def get_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--microbatch_size",
         type=int,
-        default=8,
+        default=2,
         help="Size of each training microbatch. Gradients will be accumulated until the batch size is reached.",
     )
     return parser
@@ -351,61 +318,63 @@ def main(argv) -> float:
     """
     parser = get_argument_parser()
     args = parser.parse_args(argv)
-    pprint(vars(args))
     assert torch.cuda.is_available()
     transformers.set_seed(42)
     device = "cuda"
     torch_dtype = torch.bfloat16
     compression_config = dict(
         mode=CompressWeightsMode.INT4_ASYM,
         group_size=64,
+        awq=not args.basic_init,
+        scale_estimation=not args.basic_init,
         compression_format=CompressionFormat.FQ_LORA,
-        advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=args.lora_rank),
     )
-
+    pprint({"CLI arguments": vars(args), "Major compression parameters": compression_config})
+    compression_config["advanced_parameters"] = AdvancedCompressionParameters(
+        awq_params=AdvancedAWQParameters(prefer_data_aware_scaling=not args.basic_init),
+        lora_adapter_rank=args.lora_rank,
+    )

     # Configure output and log files.
     output_dir = Path(args.output_dir)
     tensorboard_dir = output_dir / "tb" / datetime.now().strftime("%Y-%m-%d__%H-%M-%S")
     last_dir = output_dir / "last"
-    best_dir = output_dir / "best"
     if not args.resume:
-        shutil.rmtree(output_dir, ignore_errors=True)
-    for path in [output_dir, tensorboard_dir, last_dir, best_dir]:
+        shutil.rmtree(last_dir, ignore_errors=True)
+    for path in [output_dir, tensorboard_dir, last_dir]:
         path.mkdir(exist_ok=True, parents=True)
     ckpt_file = last_dir / "nncf_checkpoint.pth"
     print(f"To visualize the loss and validation metrics, open Tensorboard using the logs from: {tensorboard_dir}")
     tb = SummaryWriter(tensorboard_dir, "QAT with absorbable LoRA")
-    task_manager = TaskManager(include_path=str(Path(__file__).resolve().parent / "custom_eval_tasks"))

     # Load original model and tokenizer.
     model = AutoModelForCausalLM.from_pretrained(args.pretrained, torch_dtype=torch_dtype, device_map="auto")
     tokenizer = AutoTokenizer.from_pretrained(args.pretrained)

-    # Prepare training data and pre-compute hiddens of teacher model for distillation loss.
+    # Prepare training and calibration data.
     train_loader = get_wikitext2(
-        num_samples=args.num_train_samples, seqlen=args.calib_seqlen, tokenizer=tokenizer, device=device
+        num_samples=args.num_train_samples, seqlen=args.train_seqlen, tokenizer=tokenizer, device=device
     )
+    if args.basic_init:
+        example_input = {k: v.to(device) for k, v in model.dummy_inputs.items()}
+        dataset = Dataset([example_input])
+    else:
+        calib_loader = get_wikitext2(num_samples=128, seqlen=128, tokenizer=tokenizer, device=device)
+        dataset = Dataset(map(get_model_input, calib_loader))
+
+    # Pre-compute hiddens of teacher model for distillation loss.
     orig_hiddens = calc_hiddens(model, train_loader)

     # Create or load model to tune with Fake Quantizers and absorbable LoRA adapters.
-    example_input = {k: v.to(device) for k, v in model.dummy_inputs.items()}
     if args.resume and ckpt_file.exists():
         model = load_checkpoint(model, ckpt_file)
     else:
-        model = compress_weights(model, dataset=Dataset([example_input]), **compression_config)
-        save_checkpoint(model, ckpt_file)
+        model = compress_weights(model, dataset=dataset, **compression_config)
+        save_checkpoint(model, ckpt_file, model_state=not args.basic_init)
     fq_lr = args.lr / 10
     weight_decay = args.lr
     param_to_train = set_trainable(model, lora_lr=args.lr, fq_lr=fq_lr)
     opt = torch.optim.AdamW(param_to_train, weight_decay=weight_decay)

-    with create_eval_model(model, args.fast_eval, args.pretrained, torch_dtype, ckpt_file) as eval_model:
-        initial_perplexity = best_perplexity = measure_perplexity(
-            eval_model, task_manager, args.eval_seqlen, args.limit
-        )
-        tb.add_scalar("perplexity", best_perplexity, 0)
-        print(f"Initial word perplexity on wikitext (validation) = {best_perplexity:.4f}")
-
     # Run tuning with distillation loss and validation after each epoch.
     grad_accumulation_steps = args.batch_size // args.microbatch_size
     num_samples = len(train_loader)
@@ -450,28 +419,18 @@ def form_batch(inputs: list[Tensor], model_input: bool):
             total_steps += 1
             tb.add_scalar("loss", aggregated_loss, total_steps)

-        # Keep the best checkpoint with the lowest perplexity.
-        save_checkpoint(model, ckpt_file)
-        with create_eval_model(model, args.fast_eval, args.pretrained, torch_dtype, ckpt_file) as eval_model:
-            perplexity = measure_perplexity(eval_model, task_manager, args.eval_seqlen, args.limit)
-            tb.add_scalar("perplexity", perplexity, total_steps)
-            print(f"[Epoch {epoch}], word perplexity on wikitext (validation) = {perplexity:.4f}")
-            if perplexity < best_perplexity:
-                print(f"New best word perplexity = {perplexity:.4f}")
-                best_perplexity = perplexity
-                shutil.copytree(last_dir, best_dir, dirs_exist_ok=True)
+        save_checkpoint(model, ckpt_file, model_state=not args.basic_init)

     del model
     # Export the best tuned model to OpenVINO and evaluate it using LM-Evaluation-Harness.
-    best_ckpt_file = best_dir / "nncf_checkpoint.pth"
-    model_for_eval = export_to_openvino(args.pretrained, best_ckpt_file, best_dir)
-    ov_perplexity = measure_perplexity(model_for_eval, task_manager, args.eval_seqlen, args.limit, task="wikitext")
+    model_for_eval = export_to_openvino(args.pretrained, ckpt_file, ckpt_file.parent)
+    ov_perplexity = measure_perplexity(model_for_eval, args.eval_seqlen, args.limit)
     tb.add_scalar("ov_perplexity", ov_perplexity, 0)
     print(
-        f"The finetuned model has been exported to OpenVINO and saved to: {best_dir}\n"
+        f"The finetuned model has been exported to OpenVINO and saved to: {last_dir}\n"
        f"The word perplexity on wikitext (test) = {ov_perplexity:.4f}"
     )
-    return initial_perplexity - best_perplexity, ov_perplexity
+    return ov_perplexity


 if __name__ == "__main__":
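
Note (not part of the diff): a minimal sketch of the checkpoint layout introduced by the new save_checkpoint/load_checkpoint pair above. The inspect_checkpoint helper and its path argument are hypothetical; "nncf_state_dict" and "nncf_config" are always stored, while "model_state" appears only when full weights were saved, i.e. when AWQ has fused scaling factors into the weights.

```python
import torch


def inspect_checkpoint(ckpt_file: str) -> bool:
    # Load the checkpoint the same way load_checkpoint() above does.
    ckpt = torch.load(ckpt_file, weights_only=False, map_location="cpu")
    # "model_state" is present only for checkpoints written with
    # save_checkpoint(..., model_state=True), i.e. the AWQ-initialized path.
    print("checkpoint keys:", sorted(ckpt.keys()))
    return "model_state" in ckpt
```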
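Also not part of the diff: a sketch of how the --basic_init flag toggles the initialization path in the new main(). The build_config helper is hypothetical and only restates the arguments shown above; the default path enables AWQ + Scale Estimation and therefore expects a real calibration Dataset, while --basic_init falls back to round-to-nearest with a single example input.

```python
from nncf.parameters import CompressionFormat
from nncf.parameters import CompressWeightsMode
from nncf.quantization.advanced_parameters import AdvancedAWQParameters
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters


def build_config(basic_init: bool, lora_rank: int = 256) -> dict:
    # Mirrors the compression_config assembled in main() above.
    data_aware = not basic_init
    return dict(
        mode=CompressWeightsMode.INT4_ASYM,
        group_size=64,
        awq=data_aware,               # AWQ only on the data-aware default path
        scale_estimation=data_aware,  # Scale Estimation only on the data-aware default path
        compression_format=CompressionFormat.FQ_LORA,
        advanced_parameters=AdvancedCompressionParameters(
            awq_params=AdvancedAWQParameters(prefer_data_aware_scaling=data_aware),
            lora_adapter_rank=lora_rank,
        ),
    )


# Usage sketch: model = compress_weights(model, dataset=dataset, **build_config(basic_init=False))
```

In the diff itself, advanced_parameters is attached to compression_config after the pprint call, presumably so the printed summary of "Major compression parameters" stays compact.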