Commit 83bdfa6

FQLoRA: use last checkpoint with AWQ+SE as initialization (#3577)
### Changes

- Fixed support for more advanced compression methods (AWQ + Scale Estimation) as initialization for QAT with absorbable LoRA. Previously the RTN baseline was used; now, with data-aware AWQ + Scale Estimation, tuning starts from a more accurate model and reaches better final accuracy:

  ![image](https://github.com/user-attachments/assets/24c5d912-8563-47c1-85a4-62d17222fa1a)

- Removed selection of the best checkpoint based on the validation set. This significantly reduces overall tuning time and peak allocated memory. The best results are slightly worse on wikitext, but the tuning pipeline becomes fairer and faster (from 32 min to 25 min for SmolLM-1.7B).

### Reason for changes

Improvements in QAT + absorbable LoRA.

### Related tickets

- 169140

### Tests

- test_examples: https://github.com/openvinotoolkit/nncf/actions/runs/16223789950
1 parent de884e0 commit 83bdfa6
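
The sketch below summarizes the initialization path this commit enables. It is a minimal illustration assembled from the API calls visible in the diff (`compress_weights`, `CompressionFormat.FQ_LORA`, `AdvancedCompressionParameters`, `AdvancedAWQParameters`); the wrapper function `init_fq_lora` and its defaults are illustrative, not part of the commit.

```python
# Minimal sketch (not part of the commit): how the FQ_LORA model is now initialized.
# API names mirror the diff below; `init_fq_lora` and its defaults are illustrative.
from nncf import Dataset
from nncf.parameters import CompressionFormat
from nncf.parameters import CompressWeightsMode
from nncf.quantization.advanced_parameters import AdvancedAWQParameters
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
from nncf.quantization.quantize_model import compress_weights


def init_fq_lora(model, calibration_dataset: Dataset, lora_rank: int = 256, basic_init: bool = False):
    # basic_init=False (the new default) runs data-aware AWQ + Scale Estimation,
    # so QAT starts from a more accurate INT4 model; basic_init=True keeps the
    # previous round-to-nearest (RTN) baseline.
    return compress_weights(
        model,
        dataset=calibration_dataset,
        mode=CompressWeightsMode.INT4_ASYM,
        group_size=64,
        awq=not basic_init,
        scale_estimation=not basic_init,
        compression_format=CompressionFormat.FQ_LORA,
        advanced_parameters=AdvancedCompressionParameters(
            awq_params=AdvancedAWQParameters(prefer_data_aware_scaling=not basic_init),
            lora_adapter_rank=lora_rank,
        ),
    )
```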

8 files changed: +112 / -151 lines changed


examples/llm_compression/torch/distillation_qat_with_lora/README.md

Lines changed: 17 additions & 17 deletions
@@ -55,37 +55,37 @@ Where:
 
 The training dataset comprises 1024 samples (each 1024 tokens long) from the training split of the `wikitext-2-raw-v1` dataset. Validation occurs after each epoch using [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) on the validation split of the same dataset.
 Final perplexity measurements are conducted using OpenVINO with dynamic quantization and float16 KV-cache enabled on the test split of `wikitext-2-raw-v1`.
-For `HuggingFaceTB/SmolLM-1.7B-Instruct` model training with evaluation for 10 epochs requires about 30 minutes on a single A100 GPU or approximately 60 minutes using three RTX 3090 GPUs.
+For `HuggingFaceTB/SmolLM-1.7B-Instruct` model training with evaluation for 10 epochs requires about 25 minutes on a single A100 GPU or approximately 50 minutes using three RTX 3090 GPUs.
 
 All quantization methods compressed the models to `INT4_ASYM` precision with a group size of `64`.
 
 | Model | Precision | Wikitext,<br>word_ppl | Improvement |
 |-------------------------------------|-------------------|-----------------------|-------------|
-| google/gemma-2-2b-it | BF16 | 15.02 | |
-| google/gemma-2-2b-it | INT4 (QAT + LoRA) | 15.09 | 91% |
+| google/gemma-2-2b-it | BF16 | 15.05 | |
+| google/gemma-2-2b-it | INT4 (QAT + LoRA) | 15.28 | 69% |
 | google/gemma-2-2b-it | INT4 (best PTWC) | 15.80 | |
 | HuggingFaceTB/SmolLM-1.7B-Instruct | BF16 | 19.11 | |
-| HuggingFaceTB/SmolLM-1.7B-Instruct | INT4 (QAT + LoRA) | 19.25 | 79% |
+| HuggingFaceTB/SmolLM-1.7B-Instruct | INT4 (QAT + LoRA) | 19.57 | 30% |
 | HuggingFaceTB/SmolLM-1.7B-Instruct | INT4 (best PTWC) | 19.77 | |
-| meta-llama/Llama-3.2-1B-Instruct | BF16 | 16.30 | |
-| meta-llama/Llama-3.2-1B-Instruct | INT4 (QAT + LoRA) | 17.01 | 41% |
+| meta-llama/Llama-3.2-1B-Instruct | BF16 | 16.29 | |
+| meta-llama/Llama-3.2-1B-Instruct | INT4 (QAT + LoRA) | 17.02 | 40% |
 | meta-llama/Llama-3.2-1B-Instruct | INT4 (best PTWC) | 17.51 | |
 | meta-llama/Llama-3.2-3B-Instruct | BF16 | 12.67 | |
-| meta-llama/Llama-3.2-3B-Instruct | INT4 (QAT + LoRA) | 13.03 | 33% |
+| meta-llama/Llama-3.2-3B-Instruct | INT4 (QAT + LoRA) | 13.00 | 40% |
 | meta-llama/Llama-3.2-3B-Instruct | INT4 (best PTWC) | 13.22 | |
-| meta-llama/Meta-Llama-3-8B-Instruct | BF16 | 10.22 | |
-| meta-llama/Meta-Llama-3-8B-Instruct | INT4 (QAT + LoRA) | 10.30 | 64% |
+| meta-llama/Meta-Llama-3-8B-Instruct | BF16 | 10.20 | |
+| meta-llama/Meta-Llama-3-8B-Instruct | INT4 (QAT + LoRA) | 10.34 | 44% |
 | meta-llama/Meta-Llama-3-8B-Instruct | INT4 (best PTWC) | 10.45 | |
-| microsoft/phi3.5-mini-instruct | BF16 | 10.00 | |
-| microsoft/phi3.5-mini-instruct | INT4 (QAT + LoRA) | 10.52 | 26% |
+| microsoft/phi3.5-mini-instruct | BF16 | 9.98 | |
+| microsoft/phi3.5-mini-instruct | INT4 (QAT + LoRA) | 10.46 | 34% |
 | microsoft/phi3.5-mini-instruct | INT4 (best PTWC) | 10.71 | |
-| microsoft/phi3-mini-4k-instruct | BF16 | 9.49 | |
-| microsoft/phi3-mini-4k-instruct | INT4 (QAT + LoRA) | 10.04 | 26% |
+| microsoft/phi3-mini-4k-instruct | BF16 | 9.48 | |
+| microsoft/phi3-mini-4k-instruct | INT4 (QAT + LoRA) | 10.03 | 28% |
 | microsoft/phi3-mini-4k-instruct | INT4 (best PTWC) | 10.24 | |
 | mistralai/Mistral-7B-v0.3 | BF16 | 8.21 | |
-| mistralai/Mistral-7B-v0.3 | INT4 (QAT + LoRA) | 8.36 | 21% |
+| mistralai/Mistral-7B-v0.3 | INT4 (QAT + LoRA) | 8.35 | 23% |
 | mistralai/Mistral-7B-v0.3 | INT4 (best PTWC) | 8.40 | |
-| Qwen/Qwen2.5-3B-Instruct | BF16 | 11.01 | |
-| Qwen/Qwen2.5-3B-Instruct | INT4 (QAT + LoRA) | 11.45 | 29% |
+| Qwen/Qwen2.5-3B-Instruct | BF16 | 11.02 | |
+| Qwen/Qwen2.5-3B-Instruct | INT4 (QAT + LoRA) | 11.48 | 27% |
 | Qwen/Qwen2.5-3B-Instruct | INT4 (best PTWC) | 11.64 | |
-| | | Average | 46% |
+| | | Average | 39% |
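
For context on how the numbers in the table above are produced: the word perplexity comes from lm-eval-harness run against the exported OpenVINO model. Below is a minimal, hedged sketch of that measurement; the APIs mirror those used in `main.py` further down, while the model directory path and `max_length` value (the example's default `--eval_seqlen`) are assumptions, not taken from the diff.

```python
# Hedged sketch of the final perplexity measurement (not part of the commit).
from lm_eval import simple_evaluate
from lm_eval.models.optimum_lm import OptimumLM
from optimum.intel.openvino import OVModelForCausalLM

ov_model_dir = "output/last"  # assumption: directory with the exported OpenVINO model
ov_model = OVModelForCausalLM.from_pretrained(ov_model_dir)

# Wrap the OpenVINO model for lm-eval-harness and score the wikitext test split.
lm_obj = OptimumLM(pretrained=ov_model, max_length=2048)
results = simple_evaluate(lm_obj, tasks=["wikitext"], log_samples=False)
word_ppl = results["results"]["wikitext"]["word_perplexity,none"]
print(f"word_ppl on wikitext (test) = {word_ppl:.4f}")
```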

examples/llm_compression/torch/distillation_qat_with_lora/main.py

Lines changed: 48 additions & 89 deletions
@@ -12,19 +12,17 @@
 import shutil
 import sys
 import warnings
-from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
 from pprint import pprint
-from typing import Any, Generator, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch.nn.functional as F
 import transformers
 from datasets import load_dataset
 from lm_eval import simple_evaluate
 from lm_eval.models.optimum_lm import OptimumLM
-from lm_eval.tasks import TaskManager
 from optimum.exporters.openvino.convert import export_from_model
 from optimum.intel.openvino import OVModelForCausalLM
 from optimum.modeling_base import OptimizedModel
@@ -42,6 +40,7 @@
 from nncf.parameters import CompressionFormat
 from nncf.parameters import CompressWeightsMode
 from nncf.parameters import StripFormat
+from nncf.quantization.advanced_parameters import AdvancedAWQParameters
 from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
 from nncf.quantization.quantize_model import compress_weights
 from nncf.torch.function_hook.wrapper import get_hook_storage
@@ -76,64 +75,24 @@ def get_wikitext2(num_samples: int, seqlen: int, tokenizer: Any, device: torch.d
     return trainloader
 
 
-@contextmanager
-def create_eval_model(
-    model: AutoModelForCausalLM,
-    fast_eval: bool,
-    pretrained: str,
-    torch_dtype: torch.dtype,
-    ckpt_file: Path,
-) -> Generator[AutoModelForCausalLM, None, None]:
-    """
-    Context manager for creating an evaluation model with appropriate cleanup.
-
-    If fast_eval is True, creates a new model for evaluation that will be
-    automatically deleted when the context exits. Otherwise, uses the provided model.
-
-    :param model: Original model to use if fast_eval is False.
-    :param fast_eval: Whether to create a new optimized model for evaluation.
-    :param pretrained: Pretrained model identifier or path for AutoModelForCausalLM.
-    :param torch_dtype: PyTorch data type to use for the model (e.g., torch.bfloat16).
-    :param ckpt_file: Path to the checkpoint file to load weights from.
-    :yields: Model to use for evaluation, either the new loaded model or the given one.
-    """
-    if fast_eval:
-        eval_model = AutoModelForCausalLM.from_pretrained(pretrained, torch_dtype=torch_dtype, device_map="auto")
-        eval_model = load_checkpoint(eval_model, ckpt_file)
-        device = next(model.parameters()).device
-        example_input = {k: v.to(device) for k, v in eval_model.dummy_inputs.items()}
-        eval_model = nncf.strip(
-            eval_model, do_copy=False, strip_format=StripFormat.IN_PLACE, example_input=example_input
-        )
-        try:
-            yield eval_model
-        finally:
-            del eval_model
-    else:
-        yield model
-
-
 def measure_perplexity(
     optimum_model: OptimizedModel,
-    task_manager: TaskManager,
     max_length: Optional[int] = None,
     limit: Optional[Union[int, float]] = None,
-    task="wikitext_validation",
 ) -> float:
     """
     Measure perplexity on the Wikitext dataset, via rolling loglikelihoods for a given model.
 
     :param optimum_model: A model to be evaluated.
-    :param task_manager: The TaskManager instance that handles dataset loading and processing.
     :param max_length: The maximum sequence length for evaluation.
     :param limit: Limit the number of examples per task (only use this for testing).
         If <1, limit is a percentage of the total number of examples.
-    :param task: The evaluation task name to use, defaults to "wikitext_validation".
     :return: The similarity score as a float.
     """
+    task = "wikitext"
     print("#" * 50 + " Evaluate via lm-eval-harness " + "#" * 50)
     lm_obj = OptimumLM(pretrained=optimum_model, max_length=max_length)
-    results = simple_evaluate(lm_obj, tasks=[task], limit=limit, task_manager=task_manager, log_samples=False)
+    results = simple_evaluate(lm_obj, tasks=[task], limit=limit, log_samples=False)
     return results["results"][task]["word_perplexity,none"]
 
 
@@ -223,15 +182,20 @@ def set_trainable(model: nn.Module, lora_lr: float, fq_lr: float) -> list[dict[s
     return [{"params": adapters_to_train, "lr": lora_lr}, {"params": scales_to_train, "lr": fq_lr}]
 
 
-def save_checkpoint(model: nn.Module, ckpt_file: Path) -> None:
+def save_checkpoint(model: nn.Module, ckpt_file: Path, model_state: bool = True) -> None:
     """
-    Saves the state of a tuned model from a checkpoint.
+    Stores the current state of a quantized model to a checkpoint file.
 
-    :param model: The model to load the checkpoint into.
-    :param ckpt_file: Path to the checkpoint file.
+    :param model: The model whose state will be saved to checkpoint.
+    :param ckpt_file: Path to store the checkpoint file.
+    :param model_state: Whether to save the complete model weights in addition to NNCF state. Required when using
+        AWQ method which fuses scaling factors into weights. When False, only NNCF configuration and state are saved,
+        as they're maintained separately from the model's weights.
     """
     hook_storage = get_hook_storage(model)
     ckpt = {"nncf_state_dict": hook_storage.state_dict(), "nncf_config": nncf.torch.get_config(model)}
+    if model_state:
+        ckpt["model_state"] = model.state_dict()
     torch.save(ckpt, ckpt_file)
 
 
@@ -246,6 +210,8 @@ def load_checkpoint(model: nn.Module, ckpt_file: Path) -> nn.Module:
     """
     ckpt = torch.load(ckpt_file, weights_only=False, map_location="cpu")
     model = load_from_config(model, ckpt["nncf_config"])
+    if "model_state" in ckpt:
+        model.load_state_dict(ckpt["model_state"])
     hook_storage = get_hook_storage(model)
     hook_storage.load_state_dict(ckpt["nncf_state_dict"])
     return model
@@ -306,16 +272,17 @@ def get_argument_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument("--lora_rank", type=int, default=256, help="Rank of lora adapters")
     parser.add_argument(
-        "--fast_eval",
+        "--basic_init",
         action="store_true",
-        help="Enable faster evaluation by applying in-place quantization to the model weights. "
-        "This method uses additional GPU memory for memory copying. By default, evaluation is slower "
-        "but conserves GPU memory.",
+        help="Whether to initialize quantization with basic min-max round-to-nearest schema. By default, advanced "
+        "data-aware post-training methods are used: AWQ + Scale Estimation. These methods typically provide better "
+        "accuracy, but require a calibration dataset and additional initialization time "
+        "(~20 sec for 1B and ~80 sec for 8B models).",
     )
 
     # Data params
     parser.add_argument("--num_train_samples", type=int, default=1024, help="Number of training samples")
-    parser.add_argument("--calib_seqlen", type=int, default=1024, help="Calibration data context length.")
+    parser.add_argument("--train_seqlen", type=int, default=1024, help="Train data context length.")
    parser.add_argument("--eval_seqlen", type=int, default=2048, help="Evaluation data context length.")
     parser.add_argument(
         "--limit",
@@ -338,7 +305,7 @@
     parser.add_argument(
         "--microbatch_size",
         type=int,
-        default=8,
+        default=2,
         help="Size of each training microbatch. Gradients will be accumulated until the batch size is reached.",
     )
     return parser
@@ -351,61 +318,63 @@ def main(argv) -> float:
     """
     parser = get_argument_parser()
     args = parser.parse_args(argv)
-    pprint(vars(args))
     assert torch.cuda.is_available()
     transformers.set_seed(42)
     device = "cuda"
     torch_dtype = torch.bfloat16
     compression_config = dict(
         mode=CompressWeightsMode.INT4_ASYM,
         group_size=64,
+        awq=not args.basic_init,
+        scale_estimation=not args.basic_init,
         compression_format=CompressionFormat.FQ_LORA,
-        advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=args.lora_rank),
     )
-
+    pprint({"CLI arguments": vars(args), "Major compression parameters": compression_config})
+    compression_config["advanced_parameters"] = AdvancedCompressionParameters(
+        awq_params=AdvancedAWQParameters(prefer_data_aware_scaling=not args.basic_init),
+        lora_adapter_rank=args.lora_rank,
+    )
     # Configure output and log files.
     output_dir = Path(args.output_dir)
     tensorboard_dir = output_dir / "tb" / datetime.now().strftime("%Y-%m-%d__%H-%M-%S")
     last_dir = output_dir / "last"
-    best_dir = output_dir / "best"
     if not args.resume:
-        shutil.rmtree(output_dir, ignore_errors=True)
-    for path in [output_dir, tensorboard_dir, last_dir, best_dir]:
+        shutil.rmtree(last_dir, ignore_errors=True)
+    for path in [output_dir, tensorboard_dir, last_dir]:
         path.mkdir(exist_ok=True, parents=True)
     ckpt_file = last_dir / "nncf_checkpoint.pth"
     print(f"To visualize the loss and validation metrics, open Tensorboard using the logs from: {tensorboard_dir}")
     tb = SummaryWriter(tensorboard_dir, "QAT with absorbable LoRA")
-    task_manager = TaskManager(include_path=str(Path(__file__).resolve().parent / "custom_eval_tasks"))
 
     # Load original model and tokenizer.
     model = AutoModelForCausalLM.from_pretrained(args.pretrained, torch_dtype=torch_dtype, device_map="auto")
     tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
 
-    # Prepare training data and pre-compute hiddens of teacher model for distillation loss.
+    # Prepare training and calibration data
     train_loader = get_wikitext2(
-        num_samples=args.num_train_samples, seqlen=args.calib_seqlen, tokenizer=tokenizer, device=device
+        num_samples=args.num_train_samples, seqlen=args.train_seqlen, tokenizer=tokenizer, device=device
     )
+    if args.basic_init:
+        example_input = {k: v.to(device) for k, v in model.dummy_inputs.items()}
+        dataset = Dataset([example_input])
+    else:
+        calib_loader = get_wikitext2(num_samples=128, seqlen=128, tokenizer=tokenizer, device=device)
+        dataset = Dataset(map(get_model_input, calib_loader))
+
+    # Pre-compute hiddens of teacher model for distillation loss.
     orig_hiddens = calc_hiddens(model, train_loader)
 
     # Create or load model to tune with Fake Quantizers and absorbable LoRA adapters.
-    example_input = {k: v.to(device) for k, v in model.dummy_inputs.items()}
     if args.resume and ckpt_file.exists():
         model = load_checkpoint(model, ckpt_file)
     else:
-        model = compress_weights(model, dataset=Dataset([example_input]), **compression_config)
-        save_checkpoint(model, ckpt_file)
+        model = compress_weights(model, dataset=dataset, **compression_config)
+        save_checkpoint(model, ckpt_file, model_state=not args.basic_init)
     fq_lr = args.lr / 10
     weight_decay = args.lr
     param_to_train = set_trainable(model, lora_lr=args.lr, fq_lr=fq_lr)
     opt = torch.optim.AdamW(param_to_train, weight_decay=weight_decay)
 
-    with create_eval_model(model, args.fast_eval, args.pretrained, torch_dtype, ckpt_file) as eval_model:
-        initial_perplexity = best_perplexity = measure_perplexity(
-            eval_model, task_manager, args.eval_seqlen, args.limit
-        )
-        tb.add_scalar("perplexity", best_perplexity, 0)
-        print(f"Initial word perplexity on wikitext (validation) = {best_perplexity:.4f}")
-
     # Run tuning with distillation loss and validation after each epoch.
     grad_accumulation_steps = args.batch_size // args.microbatch_size
     num_samples = len(train_loader)
@@ -450,28 +419,18 @@ def form_batch(inputs: list[Tensor], model_input: bool):
                 total_steps += 1
                 tb.add_scalar("loss", aggregated_loss, total_steps)
 
-        # Keep the best checkpoint with the lowest perplexity.
-        save_checkpoint(model, ckpt_file)
-        with create_eval_model(model, args.fast_eval, args.pretrained, torch_dtype, ckpt_file) as eval_model:
-            perplexity = measure_perplexity(eval_model, task_manager, args.eval_seqlen, args.limit)
-            tb.add_scalar("perplexity", perplexity, total_steps)
-            print(f"[Epoch {epoch}], word perplexity on wikitext (validation) = {perplexity:.4f}")
-            if perplexity < best_perplexity:
-                print(f"New best word perplexity = {perplexity:.4f}")
-                best_perplexity = perplexity
-                shutil.copytree(last_dir, best_dir, dirs_exist_ok=True)
+        save_checkpoint(model, ckpt_file, model_state=not args.basic_init)
 
     del model
     # Export the best tuned model to OpenVINO and evaluate it using LM-Evaluation-Harness.
-    best_ckpt_file = best_dir / "nncf_checkpoint.pth"
-    model_for_eval = export_to_openvino(args.pretrained, best_ckpt_file, best_dir)
-    ov_perplexity = measure_perplexity(model_for_eval, task_manager, args.eval_seqlen, args.limit, task="wikitext")
+    model_for_eval = export_to_openvino(args.pretrained, ckpt_file, ckpt_file.parent)
+    ov_perplexity = measure_perplexity(model_for_eval, args.eval_seqlen, args.limit)
     tb.add_scalar("ov_perplexity", ov_perplexity, 0)
     print(
-        f"The finetuned model has been exported to OpenVINO and saved to: {best_dir}\n"
+        f"The finetuned model has been exported to OpenVINO and saved to: {last_dir}\n"
        f"The word perplexity on wikitext (test) = {ov_perplexity:.4f}"
     )
-    return initial_perplexity - best_perplexity, ov_perplexity
+    return ov_perplexity
 
 
 if __name__ == "__main__":
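
Finally, a hypothetical invocation of the updated example. Only `--basic_init`, `--lora_rank`, `--train_seqlen`, and `--microbatch_size` are visible in the diff above; the remaining flag names are inferred from attribute accesses in `main()` (`args.pretrained`, `args.output_dir`) and may differ in the actual script.

```python
# Hypothetical usage sketch (not part of the commit). Assumes the current working
# directory is examples/llm_compression/torch/distillation_qat_with_lora.
from main import main

# Default run: data-aware AWQ + Scale Estimation initialization; after tuning,
# the last checkpoint is exported to OpenVINO and scored on wikitext (test).
ov_ppl = main([
    "--pretrained", "HuggingFaceTB/SmolLM-1.7B-Instruct",  # flag name is an assumption
    "--output_dir", "output",                              # flag name is an assumption
])

# Ablation: fall back to the basic round-to-nearest (RTN) initialization.
ov_ppl_rtn = main([
    "--pretrained", "HuggingFaceTB/SmolLM-1.7B-Instruct",
    "--output_dir", "output_rtn",
    "--basic_init",
])
```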