
Commit 6733681

Unify GPTQ dataloader with fixed/unfixed length data (#1212)
Signed-off-by: YIYANGCAI <yiyang.cai@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent cca57d3 commit 6733681

6 files changed: 76 additions, 134 deletions

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py

Lines changed: 5 additions & 10 deletions
@@ -6,6 +6,7 @@
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
+from torch.nn.functional import pad

 import transformers
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -218,13 +219,6 @@ def skip(*args, **kwargs):
 model.eval()

 # dataset
-# original method of loading data, only load the sequence whose length > model.seqlen
-# ================================================
-# dataloader, testloader = get_loaders(
-#     args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model_name_or_path, seqlen=model.seqlen
-# )
-# dataloader = INCDataloader(dataloader)
-# ================================================
 tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
 calib_dataset = load_dataset(args.dataset, split="train")  # default
 # calib_dataset = datasets.load_from_disk('/your/local/pile-10k/') # use this if trouble with connecting to HF
@@ -244,7 +238,6 @@ def skip(*args, **kwargs):

 model = model.to(DEV)

-print('Starting ...')
 if args.sym:
     sym_opt = "sym"
 else:
@@ -276,7 +269,8 @@ def skip(*args, **kwargs):
 #         'act_order':args.act_order,
 #         'block_size': args.block_size,
 #         'nsampeles': args.nsamples,
-#         'use_max_length': args.use_max_length
+#         'use_max_length': args.use_max_length,
+#         'pad_max_length': args.pad_max_length
 #     },
 # },
 # )
@@ -296,7 +290,8 @@ def skip(*args, **kwargs):
     weight_config=conf,
     dataloader=calib_dataloader,
     nsamples = args.nsamples,
-    use_max_length = args.use_max_length
+    use_max_length = args.use_max_length,
+    pad_max_length = args.pad_max_length
 )

 results = lm_evaluate(
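
The quantization call in the example now forwards args.use_max_length and args.pad_max_length. The flag definitions sit outside these hunks, so the argparse fragment below is only a hedged sketch inferred from those attribute accesses; the flag names mirror the attributes, and the default of 2048 is borrowed from the library-side default rather than copied from the script.

import argparse

# Hypothetical flag definitions inferred from args.use_max_length / args.pad_max_length above;
# names and defaults are assumptions, not taken from this commit's hunks.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--use_max_length",
    action="store_true",
    help="calibrate only on samples cropped to one unified sequence length (GPTQ-official style)",
)
parser.add_argument(
    "--pad_max_length",
    type=int,
    default=2048,
    help="unified sequence length used when --use_max_length is set",
)
args = parser.parse_args([])  # parse an empty list so the sketch runs without CLI input
print(args.use_max_length, args.pad_max_length)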

neural_compressor/adaptor/pytorch.py

Lines changed: 7 additions & 1 deletion
@@ -4597,9 +4597,15 @@ def gptq_quantize(self, model, tune_cfg, dataloader):
            }
        nsamples = self.recipes["gptq_args"].get("nsamples", 128)
        use_max_length = self.recipes["gptq_args"].get("use_max_length", False)
+       pad_max_length = self.recipes["gptq_args"].get("pad_max_length", 2048)
+       if use_max_length and "pad_max_length" not in self.recipes["gptq_args"]:
+           logger.warning(
+               "You choose to use unified sequence length for calibration, \
+               but you have not set length value. Default sequence length is 2048 and this might cause inference error!"
+           )
        # tune_cfg => weight_config
        model, quantization_perm = gptq_quantize(
-           model, weight_config, dataloader, nsamples, use_max_length, self.device
+           model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, self.device
        )
        return model, quantization_perm
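
On the adaptor side, the new key is read from the same gptq_args recipe dictionary as the existing options, and a warning is logged when use_max_length is requested without an explicit length. A minimal sketch of a recipes dictionary that exercises this path (key names come from the diff above; the concrete values are only illustrative):

# Recipe keys consumed by the adaptor's gptq_quantize() above; values are illustrative.
gptq_recipes = {
    "gptq_args": {
        "percdamp": 0.01,
        "act_order": False,
        "nsamples": 128,
        "use_max_length": True,   # fixed-length calibration path
        "pad_max_length": 2048,   # omit this key and the adaptor logs the warning shown above
    }
}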

neural_compressor/adaptor/torch_utils/gptq.py

Lines changed: 49 additions & 111 deletions
@@ -166,7 +166,16 @@ class GPTQuantizer(object):
     url: https://arxiv.org/abs/2210.17323
     """

-    def __init__(self, model, weight_config={}, dataloader=None, nsamples=128, use_max_length=True, device=None):
+    def __init__(
+        self,
+        model,
+        weight_config={},
+        dataloader=None,
+        nsamples=128,
+        use_max_length=True,
+        pad_max_length=2048,
+        device=None,
+    ):
         """
         Args:
             model: the fp32 model to quantize
@@ -211,44 +220,29 @@ def __init__(self, model, weight_config={}, dataloader=None, nsamples=128, use_m

         # dataloader
         self.use_max_length = use_max_length
+        self.pad_max_length = pad_max_length
         self.dataloader_original = dataloader
         self.dataloader = []
         self.nsamples = nsamples
         self.prepare_dataloader()

     def prepare_dataloader(self):
         if self.use_max_length:
-            # (Recommend) only take sequence whose length exceeds model.seqlen,
+            # (Recommend) only take sequence whose length exceeds self.pad_max_length,
             # which perserves calibration's tokens are all valid
             # This is GPTQ official dataloader implementation
             self.obtain_first_n_samples_fulllength()
-            # initialize buffers which are essential for gptq computation.
-            self.model_hidden_size = 2048
-            self.initialize_inp_buffersize()
-            try:
-                # Since length is unified, we can allocate a continous space to store inputs
-                self.inp = torch.zeros(
-                    (len(self.dataloader), self.model.seqlen, self.model_hidden_size),
-                    dtype=self.dtype,
-                    device=self.device,
-                )
-                self.cache = {"i": 0}
-                self.out = torch.zeros_like(self.inp)
-                self.is_ready = True
-            except:
-                logger.warning("GPTQ Quantizer initialization failed!")
-                pass
         else:
             # general selection, no padding, not GPTQ original implementation.
             self.obtain_first_n_samples()
-            try:
-                self.inp = [torch.zeros(1) for _ in range(len(self.dataloader))]
-                self.cache = {"i": 0}
-                self.out = [torch.zeros(1) for _ in range(len(self.dataloader))]
-                self.is_ready = True
-            except:
-                logger.warning("GPTQ Quantizer initialization failed!")
-                pass
+        try:
+            self.inp = [torch.zeros(1) for _ in range(len(self.dataloader))]
+            self.cache = {"i": 0}
+            self.out = [torch.zeros(1) for _ in range(len(self.dataloader))]
+            self.is_ready = True
+        except:
+            logger.warning("GPTQ Quantizer initialization failed!")
+            pass

     def obtain_first_n_samples(self, seed=0):
         """Get first nsample data as the real calibration dataset."""
@@ -257,12 +251,13 @@ def obtain_first_n_samples(self, seed=0):
         for batch in self.dataloader_original:
             # process data, depends on its data type.
             if len(self.dataloader) == self.nsamples:
+                logger.info(f"Successfully collect {self.nsamples} calibration samples.")
                 break
             # list, tuple
             if isinstance(batch, list) or isinstance(batch, tuple):
-                if batch[0].shape[-1] > self.model.seqlen:
-                    i = random.randint(0, batch[0].shape[-1] - self.model.seqlen - 1)
-                    j = i + self.model.seqlen
+                if batch[0].shape[-1] > self.pad_max_length:
+                    i = random.randint(0, batch[0].shape[-1] - self.pad_max_length - 1)
+                    j = i + self.pad_max_length
                     batch_final = batch[0][:, i:j]
                 else:
                     batch_final = batch[0]
@@ -274,9 +269,9 @@ def obtain_first_n_samples(self, seed=0):
                     logger.warning("Please make sure your dict'like data contains key of 'input_ids'.")
                     continue
                 batch_final = {}
-                if length > self.model.seqlen:
-                    i = random.randint(0, length - self.model.seqlen - 1)
-                    j = i + self.model.seqlen
+                if length > self.pad_max_length:
+                    i = random.randint(0, length - self.pad_max_length - 1)
+                    j = i + self.pad_max_length
                     # may have to slice every sequence related data
                     for key in batch.keys():
                         if isinstance(batch[key], torch.Tensor):
@@ -287,9 +282,9 @@ def obtain_first_n_samples(self, seed=0):
                     batch_final = batch
             # tensor
             else:
-                if batch.shape[-1] > self.model.seqlen:
-                    i = random.randint(0, batch.shape[-1] - self.model.seqlen - 1)
-                    j = i + self.model.seqlen
+                if batch.shape[-1] > self.pad_max_length:
+                    i = random.randint(0, batch.shape[-1] - self.pad_max_length - 1)
+                    j = i + self.pad_max_length
                     batch_final = batch[:, i:j]
                 else:
                     batch_final = batch
@@ -301,9 +296,10 @@ def obtain_first_n_samples(self, seed=0):
     def obtain_first_n_samples_fulllength(self, seed=0):
         self.dataloader.clear()
         random.seed(seed)
-        unified_length = self.model.seqlen
+        unified_length = self.pad_max_length
         for batch in self.dataloader_original:
             if len(self.dataloader) == self.nsamples:
+                logger.info(f"Successfully collect {self.nsamples} calibration samples.")
                 break
             # list & tuple
             if isinstance(batch, list) or isinstance(batch, tuple):
@@ -325,11 +321,11 @@ def obtain_first_n_samples_fulllength(self, seed=0):
                     logger.warning("Please make sure your dict'like data contains key of 'input_ids'.")
                     continue
                 batch_final = {}
-                if length == self.model.seqlen:
+                if length == self.pad_max_length:
                     batch_final = batch
-                elif length > self.model.seqlen:
-                    i = random.randint(0, length - self.model.seqlen - 1)
-                    j = i + self.model.seqlen
+                elif length > self.pad_max_length:
+                    i = random.randint(0, length - self.pad_max_length - 1)
+                    j = i + self.pad_max_length
                     # may have to slice every sequence related data
                     for key in batch.keys():
                         if isinstance(batch[key], torch.Tensor):
@@ -354,53 +350,9 @@ def obtain_first_n_samples_fulllength(self, seed=0):
         if len(self.dataloader) < self.nsamples:  # pragma: no cover
             logger.warning(
                 f"Trying to allocate {self.nsamples} data with fixed length {unified_length}, \
-                but only {len(self.dataloader)} samples satisfy your setting. You may choose smaller 'model.seqlen' value."
+                but only {len(self.dataloader)} samples are found. Please use smaller 'self.pad_max_length' value."
             )

-    @torch.no_grad()
-    def initialize_inp_buffersize(self):
-        # Run a forward and generate proper buffer tensor
-        # Thus, no need to pass hidden_states dimension parameters of model.config
-        # e.g. OPT's hidden_states dimension can be called by model.config.hidden_size
-        # but mpt's hidden_states dimension can be called by model.config.d_model
-        def forward(layer, hidden_states, **kwargs):
-            # inputs[inputs_info['idx']] = input_ids # TODO solve the problem of batchsize!=1
-            logger.info(f"The hidden_states shape along transformers blocks is {hidden_states.shape}.")
-            self.model_hidden_size = hidden_states.shape[-1]
-            raise ValueError
-
-        # Step1: fetch the embeddings and other layers before the transformer stack.
-        for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
-            embedding_layer = embedding_layer.to(self.device)
-
-        # Step2: modify the first transformer block's forward function to obtain inputs for calibration
-        self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device)
-        forward_cache = self.gptq_related_blocks["transformers"][0].forward
-        self.gptq_related_blocks["transformers"][0].forward = partial(
-            forward, self.gptq_related_blocks["transformers"][0]
-        )
-
-        # Step3: run forward to obtain calibration datasets
-        logger.info("Collecting calibration inputs...")
-        for batch in self.dataloader:
-            batch = move_input_to_device(batch, self.device)
-            try:
-                if isinstance(batch, tuple) or isinstance(batch, list):
-                    self.model(batch[0])
-                elif isinstance(batch, dict):
-                    self.model(**batch)
-                else:
-                    self.model(batch.to(self.device))
-            except ValueError:
-                break
-
-        # Step 4: restore original forward function, relocate layers back to cpu.
-        self.gptq_related_blocks["transformers"][0].forward = forward_cache
-        self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
-        for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
-            embedding_layer.to(self.device)
-        torch.cuda.empty_cache()
-
     def get_full_layer_name(self, sub_layer_name, block_idx):
         transformer_name = self.gptq_related_blocks["transformers_name"]
         return ".".join([transformer_name, str(block_idx), sub_layer_name])
@@ -459,18 +411,12 @@ def forward(layer, hidden_states, **kwargs):
             self.cache["i"] += 1
             for arg in kwargs:
                 # TODO: investigate include parameters
-                if self.use_max_length:
-                    if isinstance(kwargs[arg], torch.Tensor) or arg == "alibi":
-                        self.cache[arg] = kwargs[arg]
-                    else:
-                        continue
-                else:
-                    # each outputs can be different shape, hence also use list to store
-                    if isinstance(kwargs[arg], torch.Tensor) or arg == "alibi":
-                        if self.cache.get(arg, None) is None:
-                            self.cache[arg] = []
-                        self.cache[arg].append(kwargs[arg])
-                    continue
+                # each outputs can be different shape, hence also use list to store
+                if isinstance(kwargs[arg], torch.Tensor) or arg == "alibi":
+                    if self.cache.get(arg, None) is None:
+                        self.cache[arg] = []
+                    self.cache[arg].append(kwargs[arg])
+                continue
             raise ValueError

         # Step1: fetch the embeddings and other layers before the transformer stack.
@@ -572,13 +518,9 @@ def tmp(_, inp, out):
                 handles.append(sub_layers[layer_name].register_forward_hook(add_batch(layer_name)))
             idx = self.cache.pop("i")
             for j in range(len(self.dataloader)):
-                if self.use_max_length:
-                    # self.inp[j] shape: [seq_len, hidden_size]
-                    self.out[j] = transformer_block(self.inp[j].unsqueeze(0), **self.cache)[0]
-                else:
-                    # self.inp[j] shape: [1, seq_len, hidden_size] (batchsize is 1 by default)
-                    cache_batch = self.gather_single_batch_from_dict(self.cache, j)
-                    self.out[j] = transformer_block(self.inp[j], **cache_batch)[0]
+                # self.inp[j] shape: [1, seq_len, hidden_size] (batchsize is 1 by default)
+                cache_batch = self.gather_single_batch_from_dict(self.cache, j)
+                self.out[j] = transformer_block(self.inp[j], **cache_batch)[0]
             self.cache["i"] = idx
             for h in handles:
                 h.remove()
@@ -607,13 +549,9 @@ def tmp(_, inp, out):
             # Step 2.5: replace output data with quantized weights
             idx = self.cache.pop("i")
             for j in range(len(self.dataloader)):
-                if self.use_max_length:
-                    # self.inp[j] shape: [seq_len, hidden_size]
-                    self.out[j] = transformer_block(self.inp[j].unsqueeze(0), **self.cache)[0]
-                else:
-                    # self.inp[j] shape: [1, seq_len, hidden_size] (batchsize is 1 by default)
-                    cache_batch = self.gather_single_batch_from_dict(self.cache, j)
-                    self.out[j] = transformer_block(self.inp[j], **cache_batch)[0]
+                # self.inp[j] shape: [1, seq_len, hidden_size] (batchsize is 1 by default)
+                cache_batch = self.gather_single_batch_from_dict(self.cache, j)
+                self.out[j] = transformer_block(self.inp[j], **cache_batch)[0]
             self.cache["i"] = idx
             self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu()
             del gptq_for_this_block
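
With the dataloader unified, every length decision above is keyed off self.pad_max_length rather than model.seqlen, and both the fixed-length and the general path now share the same list-based input/output buffers. The snippet below restates the tensor-branch sampling rule of obtain_first_n_samples as a standalone function purely for illustration; crop_to_window is a made-up name, not part of the module.

import random

import torch

def crop_to_window(batch: torch.Tensor, pad_max_length: int = 2048) -> torch.Tensor:
    """Crop a [batch, seq_len] tensor to a random window of at most pad_max_length tokens.

    Mirrors the tensor branch of obtain_first_n_samples: longer sequences are
    cropped to a random window, shorter ones are kept unchanged (no padding on this path).
    """
    if batch.shape[-1] > pad_max_length:
        i = random.randint(0, batch.shape[-1] - pad_max_length - 1)
        return batch[:, i : i + pad_max_length]
    return batch

sample = torch.ones([1, 4096], dtype=torch.long)
print(crop_to_window(sample, 2048).shape)  # torch.Size([1, 2048])
print(crop_to_window(torch.ones([1, 10], dtype=torch.long), 2048).shape)  # torch.Size([1, 10])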

neural_compressor/adaptor/torch_utils/weight_only.py

Lines changed: 4 additions & 2 deletions
@@ -469,13 +469,15 @@ def rtn_quantize(
     return model


-def gptq_quantize(model, weight_config={}, dataloader=None, nsamples=128, use_max_length=True, device=None):
+def gptq_quantize(
+    model, weight_config={}, dataloader=None, nsamples=128, use_max_length=True, pad_max_length=2048, device=None
+):
     """Run weight-only quantization with."""
     # TODO: unify weight_config keys, add docstring, and support default config
     assert isinstance(model, torch.nn.Module), "only support torch module"
     from .gptq import GPTQuantizer

-    gptq_quantizer = GPTQuantizer(model, weight_config, dataloader, nsamples, use_max_length, device)
+    gptq_quantizer = GPTQuantizer(model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, device)
     fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization()
     logger.info("GPTQ quantizing done.")
     return fp32_modified_model, gptq_config
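
For direct callers of this helper, a hypothetical call site with the extended signature might look like the sketch below; fp32_model, conf, and calib_dataloader stand for objects the caller already has (see the example script above) and are not defined by this commit.

from neural_compressor.adaptor.torch_utils.weight_only import gptq_quantize

# fp32_model: a torch.nn.Module; conf: a weight_config dict; calib_dataloader: an
# iterable of calibration batches. All three are placeholders for this sketch.
q_model, gptq_config = gptq_quantize(
    fp32_model,
    weight_config=conf,
    dataloader=calib_dataloader,
    nsamples=128,
    use_max_length=True,   # keep only samples that reach the unified length
    pad_max_length=2048,   # unified length applied when use_max_length is True
    device="cpu",
)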

test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py

Lines changed: 2 additions & 3 deletions
@@ -68,7 +68,6 @@ def setUpClass(self):
         self.gptj_no_jit = transformers.AutoModelForCausalLM.from_pretrained(
             "hf-internal-testing/tiny-random-GPTJForCausalLM",
         )
-        self.gptj.seqlen = 512
         self.llm_dataloader = LLMDataLoader()
         self.lm_input = torch.ones([1, 10], dtype=torch.long)

@@ -502,7 +501,7 @@ def __iter__(self):
                },
            },
            recipes={
-               "gptq_args": {"percdamp": 0.01, "act_order": False},
+               "gptq_args": {"percdamp": 0.01, "act_order": False, "use_max_length": True, "pad_max_length": 512},
            },
        )

@@ -608,7 +607,7 @@ def __iter__(self):
                },
            },
            recipes={
-               "gptq_args": {"percdamp": 0.01, "act_order": False, "use_max_length": True},
+               "gptq_args": {"percdamp": 0.01, "act_order": False, "use_max_length": False, "pad_max_length": 512},
            },
        )

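
End to end, the updated tests feed the new keys through recipes. A condensed sketch of that flow is shown below, assuming the usual INC 2.x weight-only API (PostTrainingQuantConfig plus quantization.fit); the op_type_dict values and the toy dataloader are illustrative stand-ins for fixtures the test defines elsewhere, not copies of them.

import torch
import transformers
from neural_compressor import PostTrainingQuantConfig, quantization

# Illustrative stand-ins for the fixtures used in the test file.
tiny_model = transformers.AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-GPTJForCausalLM"
)

class ToyDataLoader:
    batch_size = 1
    def __iter__(self):
        for _ in range(4):
            # sequences of exactly pad_max_length tokens satisfy the fixed-length path
            yield torch.ones([1, 512], dtype=torch.long)

conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={".*": {"weight": {"bits": 4, "group_size": 128, "scheme": "sym", "algorithm": "GPTQ"}}},
    recipes={
        "gptq_args": {"percdamp": 0.01, "act_order": False, "use_max_length": True, "pad_max_length": 512},
    },
)
q_model = quantization.fit(tiny_model, conf, calib_dataloader=ToyDataLoader())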
