Distributed inference of 70B awq model (#2531)
* Distributed inference of 70B awq model
* fix overflow
vince62s authored Dec 4, 2023
1 parent a147137 commit 1e5ed31
Showing 10 changed files with 340 additions and 171 deletions.
29 changes: 4 additions & 25 deletions onmt/bin/translate.py
@@ -1,10 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-from onmt.utils.logging import init_logger
-from onmt.translate.translator import build_translator
-from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
-from onmt.transforms import get_transforms_cls
-from onmt.constants import CorpusTask
+from onmt.inference_engine import InferenceEnginePY
 from onmt.opts import config_opts, translate_opts
 from onmt.utils.parse import ArgumentParser
 from onmt.utils.misc import use_gpu, set_random_seed
@@ -17,29 +13,12 @@ def translate(opt):
     ArgumentParser._get_all_transform_translate(opt)
     ArgumentParser._validate_transforms_opts(opt)
     ArgumentParser.validate_translate_opts_dynamic(opt)
-    logger = init_logger(opt.log_file)
 
     set_random_seed(opt.seed, use_gpu(opt))
 
-    translator = build_translator(opt, logger=logger, report_score=False)
-
-    transforms_cls = get_transforms_cls(opt._all_transform)
-
-    infer_iter = build_dynamic_dataset_iter(
-        opt,
-        transforms_cls,
-        translator.vocabs,
-        task=CorpusTask.INFER,
-        copy=translator.copy_attn,
-        device_id=opt.gpu,
-    )
-
-    _, _ = translator._translate(
-        infer_iter,
-        transform=infer_iter.transforms,
-        attn_debug=opt.attn_debug,
-        align_debug=opt.align_debug,
-    )
+    engine = InferenceEnginePY(opt)
+    _, _ = engine.infer_file()
+    engine.terminate()
 
 
 def _get_parser():
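The net effect: translate.py no longer builds the translator, transforms, and iterator itself; it hands the parsed options to InferenceEnginePY, which chooses between in-process translation and spawned per-GPU workers. A minimal sketch of the new flow; `run_inference` is a hypothetical wrapper, not part of the commit, and `opt` is assumed to come from translate.py's own `_get_parser()`:

```python
# Minimal sketch of the new entry-point flow. `run_inference` is a
# hypothetical helper; `opt` is assumed to be the namespace produced
# by translate.py's _get_parser().
from onmt.inference_engine import InferenceEnginePY

def run_inference(opt):
    # Picks in-process translation (world_size <= 1) or one spawned
    # worker per GPU (world_size > 1) behind a single interface.
    engine = InferenceEnginePY(opt)
    scores, preds = engine.infer_file()  # translate opt.src end to end
    engine.terminate()                   # tell spawned workers to stop
    return scores, preds
```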
16 changes: 6 additions & 10 deletions onmt/inference_engine.py
@@ -2,7 +2,7 @@
 from onmt.constants import CorpusTask, DefaultTokens, ModelTask
 from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
 from onmt.utils.distributed import ErrorHandler, spawned_infer
-from onmt.utils.logging import logger
+from onmt.utils.logging import init_logger
 from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe
 
 
@@ -82,6 +82,7 @@ def __init__(self, opt):
 
         super().__init__(opt)
         self.opt = opt
+        self.logger = init_logger(opt.log_file)
 
         if opt.world_size > 1:
             mp = torch.multiprocessing.get_context("spawn")
@@ -92,10 +93,6 @@ def __init__(self, opt):
             self.queue_result = []
             self.procs = []
 
-            print("world_size: ", opt.world_size)
-            print("gpu_ranks: ", opt.gpu_ranks)
-            print("opt.gpu: ", opt.gpu)
-
             for device_id in range(opt.world_size):
                 self.queue_instruct.append(mp.Queue())
                 self.queue_result.append(mp.Queue())
@@ -113,12 +110,11 @@ def __init__(self, opt):
                     )
                 )
                 self.procs[device_id].start()
-                print(" Starting process pid: %d " % self.procs[device_id].pid)
                 self.error_handler.add_child(self.procs[device_id].pid)
         else:
-            self.device_id = 0 if opt.world_size == 1 else -1
+            self.device_id = opt.gpu
             self.translator = build_translator(
-                opt, self.device_id, logger=logger, report_score=True
+                opt, self.device_id, logger=self.logger, report_score=True
             )
             self.transforms_cls = get_transforms_cls(opt._all_transform)
             self.vocabs = self.translator.vocabs
@@ -168,9 +164,9 @@ def __init__(self, opt):
 
         super().__init__(opt)
         self.opt = opt
-        self.logger = logger
+        self.logger = init_logger(opt.log_file)
         assert self.opt.world_size <= 1, "World size must be less than 1."
-        self.device_id = 0 if opt.world_size == 1 else -1
+        self.device_id = opt.gpu
         if opt.world_size == 1:
            self.device_index = opt.gpu_ranks
            self.device = "cuda"
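For world_size > 1 the engine keeps one instruction queue and one result queue per device and spawns one process per GPU. Reduced to its essentials, the pattern looks like the sketch below; this is illustrative only, with made-up queue and message names, and the real worker loop lives in onmt.utils.distributed.spawned_infer:

```python
# Illustrative sketch of the per-device spawn/queue pattern used by the
# engine; names are made up, the real worker loop is spawned_infer().
import torch.multiprocessing

def worker(device_id, q_in, q_out):
    while True:
        instruction = q_in.get()
        if instruction[0] == "stop":
            break
        # ... build the iterator and run translation on this GPU ...
        q_out.put(f"scores from device {device_id}")
        q_out.put(f"preds from device {device_id}")

def main(world_size=2):
    mp = torch.multiprocessing.get_context("spawn")
    q_ins, q_outs, procs = [], [], []
    for device_id in range(world_size):
        q_in, q_out = mp.Queue(), mp.Queue()
        p = mp.Process(target=worker, args=(device_id, q_in, q_out), daemon=True)
        p.start()
        q_ins.append(q_in)
        q_outs.append(q_out)
        procs.append(p)
    for q in q_ins:                     # broadcast one instruction to all ranks
        q.put(("infer_file",))
    scores = [q.get() for q in q_outs]  # collect per-device results
    preds = [q.get() for q in q_outs]
    for q in q_ins:                     # shut the workers down
        q.put(("stop",))
    for p in procs:
        p.join()

if __name__ == "__main__":
    main()
```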
22 changes: 17 additions & 5 deletions onmt/inputters/text_utils.py
@@ -246,10 +246,19 @@ def tensorify(vocabs, minibatch, device, left_pad=False):
         )
 
     if minibatch[0][0]["tgt"] is not None:
-        tbatchtgt = [
-            torch.tensor(ex["tgt"]["tgt_ids"], dtype=torch.long, device=device)
-            for ex, indice in minibatch
-        ]
+        if left_pad:
+            tbatchtgt = [
+                torch.tensor(
+                    ex["tgt"]["tgt_ids"], dtype=torch.long, device=device
+                ).flip(dims=[0])
+                for ex, indice in minibatch
+            ]
+        else:
+            tbatchtgt = [
+                torch.tensor(ex["tgt"]["tgt_ids"], dtype=torch.long, device=device)
+                for ex, indice in minibatch
+            ]
 
         padidx = vocabs["tgt"][DefaultTokens.PAD]
         tbatchtgt = pad_sequence(tbatchtgt, batch_first=True, padding_value=padidx)
         tbatchtgt = tbatchtgt[:, :, None]
@@ -258,7 +267,10 @@ def tensorify(vocabs, minibatch, device, left_pad=False):
             dtype=torch.long,
             device=device,
         )
-        tensor_batch["tgt"] = tbatchtgt
+        if left_pad:
+            tensor_batch["tgt"] = tbatchtgt.flip(dims=[1])
+        else:
+            tensor_batch["tgt"] = tbatchtgt
         tensor_batch["tgtlen"] = tbatchtgtlen
 
     if "align" in minibatch[0][0].keys() and minibatch[0][0]["align"] is not None:
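The left-padding trick deserves a worked example: each target sequence is reversed before right-padding with pad_sequence, and the padded batch is then flipped back along the time axis, which leaves all pad tokens at the front. A self-contained demo in plain PyTorch, with an illustrative pad id of 1:

```python
# Demo of left padding via flip -> pad_sequence -> flip back,
# mirroring the tensorify() change above. Pad id 1 is illustrative.
import torch
from torch.nn.utils.rnn import pad_sequence

padidx = 1
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

# Old behavior: right padding only.
right = pad_sequence(seqs, batch_first=True, padding_value=padidx)
print(right)  # tensor([[5, 6, 7], [8, 9, 1]])

# New left_pad=True path: reverse each sequence, right-pad, flip back.
flipped = [s.flip(dims=[0]) for s in seqs]
left = pad_sequence(flipped, batch_first=True, padding_value=padidx).flip(dims=[1])
print(left)   # tensor([[5, 6, 7], [1, 8, 9]])
```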
6 changes: 5 additions & 1 deletion onmt/model_builder.py
@@ -100,7 +100,11 @@ def load_test_model(opt, device_id=0, model_path=None):
         "aawq_gemm",
         "aawq_gemv",
     ]:  # if the loaded model is a awq quantized one, inference config cannot overwrite this
-        if hasattr(opt, "quant_type") and opt.quant_type != model_opt.quant_type:
+        if (
+            hasattr(opt, "quant_type")
+            and opt.quant_type != ""
+            and opt.quant_type != model_opt.quant_type
+        ):
             raise ValueError(
                 "Model is a awq quantized model, cannot overwrite with another quant method"
             )
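This pairs with the --quant_type change in onmt/opts.py further down: the default is now the empty string, meaning "inherit the checkpoint's quantization", while an explicit, different value on an AWQ checkpoint is rejected. A hypothetical helper, not part of the commit, condensing the decision table:

```python
# Hypothetical helper restating the guard for AWQ checkpoints;
# "" (the new CLI default) means "inherit from the checkpoint".
def resolve_quant_type(opt_quant_type: str, model_quant_type: str) -> str:
    if opt_quant_type != "" and opt_quant_type != model_quant_type:
        raise ValueError("cannot overwrite an awq model's quant method")
    return model_quant_type

assert resolve_quant_type("", "llm_awq") == "llm_awq"         # "" inherits
assert resolve_quant_type("llm_awq", "llm_awq") == "llm_awq"  # match is fine
# resolve_quant_type("bnb_8bit", "llm_awq")                   # -> ValueError
```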
27 changes: 20 additions & 7 deletions onmt/models/model.py
@@ -46,6 +46,10 @@ def count_parameters(self, log=print):
         raise NotImplementedError
 
     def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset):
+        if module.__class__.__name__ == "WQLinear_GEMM":
+            # ugly patch because in_feat and out_feat are reversed in WQLinear_GEMM
+            param.data = param.data.transpose(0, 1)
+            ckpt_t = ckpt_t.transpose(0, 1)
         if name.split(".")[-1] in [
             "linear_keys",
             "linear_values",
@@ -73,13 +77,22 @@ def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset)
                 ].size()
             ), "An error in model's partition and checkpoint's slice was detected"
             if name + "." + param_name in buf_list:
-                module.register_buffer(
-                    param_name,
-                    ckpt_t[
-                        col_slice_start:col_slice_end,
-                        row_slice_start:row_slice_end,
-                    ],
-                )
+                if module.__class__.__name__ == "WQLinear_GEMM":
+                    module.register_buffer(
+                        param_name,
+                        ckpt_t[
+                            col_slice_start:col_slice_end,
+                            row_slice_start:row_slice_end,
+                        ].transpose(0, 1),
+                    )
+                else:
+                    module.register_buffer(
+                        param_name,
+                        ckpt_t[
+                            col_slice_start:col_slice_end,
+                            row_slice_start:row_slice_end,
+                        ],
+                    )
             else:
                 param.data = ckpt_t[
                     col_slice_start:col_slice_end,
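The transposes compensate for WQLinear_GEMM storing its weight with in_feat and out_feat swapped relative to nn.Linear, as the in-code comment notes; without them, the tensor-parallel slice above would split the wrong axis. A small shape-level illustration, with all sizes made up:

```python
# Why _load_param transposes around the slice for WQLinear_GEMM:
# the checkpoint tensor is laid out (in_feat, out_feat) instead of
# nn.Linear's (out_feat, in_feat). All sizes here are made up.
import torch

world_size, rank = 2, 0
out_feat, in_feat = 8, 4

w_linear = torch.arange(out_feat * in_feat).reshape(out_feat, in_feat)
rows = slice(rank * out_feat // world_size, (rank + 1) * out_feat // world_size)
shard = w_linear[rows, :]                      # correct shard, shape (4, 4)

w_awq = w_linear.transpose(0, 1).contiguous()  # AWQ layout: (in_feat, out_feat)
# Naively slicing the AWQ tensor with the same offsets cuts the wrong axis:
assert not torch.equal(w_awq[rows, :], shard.transpose(0, 1))
# Transpose, slice with the shared offsets, transpose back -> correct shard:
shard_awq = w_awq.transpose(0, 1)[rows, :].transpose(0, 1)
assert torch.equal(shard_awq, shard.transpose(0, 1))  # same shard, AWQ layout
```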
12 changes: 10 additions & 2 deletions onmt/opts.py
@@ -1564,8 +1564,16 @@ def _add_quant_opts(parser):
     group.add(
         "--quant_type",
         "-quant_type",
-        default="bnb_8bit",
-        choices=["bnb_8bit", "bnb_FP4", "bnb_NF4", "llm_awq", "aawq_gemm", "aawq_gemv"],
+        default="",
+        choices=[
+            "",
+            "bnb_8bit",
+            "bnb_FP4",
+            "bnb_NF4",
+            "llm_awq",
+            "aawq_gemm",
+            "aawq_gemv",
+        ],
         type=str,
         help="Type of compression.",
     )
5 changes: 4 additions & 1 deletion onmt/translate/translator.py
@@ -619,7 +619,10 @@ def _report_score(self, name, score_total, nb_sentences):
             msg = "%s No translations" % (name,)
         else:
             score = score_total / nb_sentences
-            ppl = exp(-score_total / nb_sentences)
+            try:
+                ppl = exp(-score_total / nb_sentences)
+            except OverflowError:
+                ppl = float("inf")
             msg = "%s SCORE: %.4f, %s PPL: %.2f NB SENTENCES: %d" % (
                 name,
                 score,
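This is the "fix overflow" half of the commit message: for a strongly negative average score, math.exp raises OverflowError rather than returning infinity once the result leaves double range (exponents above roughly 709.78). A quick demonstration with an illustrative score:

```python
# math.exp raises OverflowError instead of returning inf once the
# result exceeds double range, hence the try/except above.
from math import exp

score_total, nb_sentences = -800.0, 1  # illustrative: a very poor score
try:
    ppl = exp(-score_total / nb_sentences)
except OverflowError:
    ppl = float("inf")
print(ppl)  # inf
```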
3 changes: 1 addition & 2 deletions onmt/utils/distributed.py
@@ -197,7 +197,6 @@ def spawned_infer(opt, device_id, error_queue, queue_instruct, queue_result):
     init_logger(opt.log_file)
     translator = build_translator(opt, device_id, logger=logger, report_score=True)
     transforms_cls = get_transforms_cls(opt._all_transform)
-    print("Device_id: ", device_id, " translator built")
     while True:
         instruction = queue_instruct.get()
         if instruction[0] == "stop":
@@ -227,7 +226,7 @@ def spawned_infer(opt, device_id, error_queue, queue_instruct, queue_result):
                 device_id=device_id,
             )
             scores, preds = translator._translate(
-                infer_iter, infer_iter.transform, opt.attn_debug, opt.align_debug
+                infer_iter, infer_iter.transforms, opt.attn_debug, opt.align_debug
             )
             queue_result.put(scores)
             queue_result.put(preds)
15 changes: 12 additions & 3 deletions tools/LM_scoring.py
@@ -86,12 +86,12 @@ def main():
     cumul_length = 0
     # Now we can pipe the full file through the model using the Iterator
 
-    for i, batch in enumerate(infer_iter):
+    for i, (batch, bucket_idx) in enumerate(infer_iter):
         # reminder a batch includes .src .tgt .indices and it is sorted
         batch_size = len(batch["srclen"])
         src = batch["src"]
         src_len = batch["srclen"]
 
-        # print(batch)
         outputs, attns = model(src, None, src_len, with_align=False)
         # Compute and retrieve the loss for EACH sentence
         loss, _ = valid_loss(batch, outputs, attns)
@@ -102,7 +102,16 @@ def main():
         cumul_length += batch["tgt"][:, 1:, 0].ne(padding_idx).sum().cpu()
         # Now we need to rearrange the batch of ppl
         # in the original order with indices
-        sent_ppl_orig = ppl.gather(0, batch["cid_line_number"].argsort(0))
+        sent_ppl_orig = ppl.gather(
+            0,
+            torch.tensor(
+                sorted(
+                    range(len(batch["cid_line_number"])),
+                    key=lambda k: batch["cid_line_number"][k],
+                ),
+                device=ppl.device,
+            ),
+        )
         for j in range(batch_size):
             ppl_file.write(str(sent_ppl_orig[j].item()) + "\n")
         logger.info(
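The rewritten gather builds the same inverse permutation that argsort produced, but with Python's sorted over list indices, which works whether batch["cid_line_number"] is a tensor or, presumably the motivating case here, a plain list. A worked example of restoring corpus order, with illustrative values:

```python
# Worked example of the inverse-permutation gather above; the ppl and
# line-number values are illustrative.
import torch

ppl = torch.tensor([10.0, 20.0, 30.0])  # per-sentence ppl, batch order
cid_line_number = [2, 0, 1]             # original corpus line of each row

order = sorted(range(len(cid_line_number)), key=lambda k: cid_line_number[k])
# order == [1, 2, 0]: corpus line 0 came from batch row 1,
# line 1 from row 2, line 2 from row 0.
sent_ppl_orig = ppl.gather(0, torch.tensor(order, device=ppl.device))
print(sent_ppl_orig)  # tensor([20., 30., 10.])
```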