Simplify the workflow for inference (OpenNMT#2413)
* Simplify the code to make the inference workflow easier to understand; remove unused code.
vince62s authored Jun 16, 2023
1 parent 367ee54 commit eef722a
Showing 17 changed files with 55 additions and 110 deletions.
31 changes: 7 additions & 24 deletions eval_llm/MMLU/run_mmlu_opennmt.py
@@ -5,14 +5,12 @@
 import os
 import time
 import pandas as pd
-import torch
 from onmt.utils.logging import init_logger
 from onmt.translate.translator import build_translator
-from onmt.inputters.dynamic_iterator import DynamicDatasetIter
+from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
 from onmt.inputters.inputter import IterOnDevice
-from onmt.inputters.text_corpus import ParallelCorpus
-from onmt.transforms import get_transforms_cls, TransformPipe, make_transforms
-from onmt.constants import CorpusTask, CorpusName
+from onmt.transforms import get_transforms_cls
+from onmt.constants import CorpusTask
 import onmt.opts as opts
 from onmt.utils.parse import ArgumentParser
 from onmt.utils.misc import use_gpu, set_random_seed
@@ -153,9 +151,6 @@ def evaluate(opt):
     translator = build_translator(opt, logger=logger, report_score=True)
     # Build the transforms (along with the tokenizer)
     transforms_cls = get_transforms_cls(opt._all_transform)
-    transforms = make_transforms(
-        opt, transforms_cls, translator.vocabs
-    )  # transforms is a dictionary of each transform and its instance
 
     data_dir = "eval_llm/MMLU/data/"
     ntrain = 5  # nshots from dev
@@ -185,27 +180,15 @@ def evaluate(opt):
         records.append({"prompt": prompt, "answer": label})
         src.append(prompt.replace("\n", "⦅newline⦆"))
 
-    corpora = {}
-    corpora["infer"] = ParallelCorpus(CorpusName.INFER, src, None)
-    infer_iter = DynamicDatasetIter.from_opt(
-        corpora, transforms, translator.vocabs, opt, CorpusTask.INFER, False
+    infer_iter = build_dynamic_dataset_iter(
+        opt, transforms_cls, translator.vocabs, task=CorpusTask.INFER, src=src
     )
-    infer_iter.num_workers = 0
-    infer_iter._init_datasets(0)
-
-    data_transform = [
-        infer_iter.transforms[name]
-        for name in opt.transforms
-        if name in infer_iter.transforms
-    ]
-    transform = TransformPipe.build_from(data_transform)
-
-    if infer_iter is not None:
-        infer_iter = IterOnDevice(infer_iter, opt.gpu)
+    infer_iter = IterOnDevice(infer_iter, opt.gpu)
 
     scores, preds = translator._translate(
         infer_iter,
-        transform=transform,
+        transform=infer_iter.transform,
         attn_debug=opt.attn_debug,
         align_debug=opt.align_debug,
     )
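Taken together, these hunks reduce MMLU scoring to one iterator-building call. Below is a minimal sketch of the resulting workflow, assuming a fully parsed `opt` and an OpenNMT-py checkpoint; the helper name `score_prompts` and its prompt list are illustrative, not part of the commit.

from onmt.constants import CorpusTask
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.inputters.inputter import IterOnDevice
from onmt.transforms import get_transforms_cls
from onmt.translate.translator import build_translator


def score_prompts(opt, prompts):
    """Translate a list of in-memory prompts with a single iterator build."""
    translator = build_translator(opt, report_score=True)
    transforms_cls = get_transforms_cls(opt._all_transform)
    # One call replaces the old ParallelCorpus / DynamicDatasetIter boilerplate.
    infer_iter = build_dynamic_dataset_iter(
        opt, transforms_cls, translator.vocabs, task=CorpusTask.INFER, src=prompts
    )
    # IterOnDevice moves batches to the GPU and exposes the composed transform.
    infer_iter = IterOnDevice(infer_iter, opt.gpu)
    scores, preds = translator._translate(
        infer_iter,
        transform=infer_iter.transform,
        attn_debug=opt.attn_debug,
        align_debug=opt.align_debug,
    )
    return scores, preds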
14 changes: 3 additions & 11 deletions onmt/bin/translate.py
@@ -4,7 +4,7 @@
 from onmt.translate.translator import build_translator
 from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
 from onmt.inputters.inputter import IterOnDevice
-from onmt.transforms import get_transforms_cls, TransformPipe
+from onmt.transforms import get_transforms_cls
 from onmt.constants import CorpusTask
 import onmt.opts as opts
 from onmt.utils.parse import ArgumentParser
@@ -32,19 +32,11 @@ def translate(opt):
         copy=translator.copy_attn,
     )
 
-    data_transform = [
-        infer_iter.transforms[name]
-        for name in opt.transforms
-        if name in infer_iter.transforms
-    ]
-    transform = TransformPipe.build_from(data_transform)
-
-    if infer_iter is not None:
-        infer_iter = IterOnDevice(infer_iter, opt.gpu)
+    infer_iter = IterOnDevice(infer_iter, opt.gpu)
 
     _, _ = translator._translate(
         infer_iter,
-        transform=transform,
+        transform=infer_iter.transform,
         attn_debug=opt.attn_debug,
         align_debug=opt.align_debug,
     )
15 changes: 13 additions & 2 deletions onmt/inputters/dynamic_iterator.py
@@ -343,10 +343,21 @@ def __iter__(self):


 def build_dynamic_dataset_iter(
-    opt, transforms_cls, vocabs, copy=False, task=CorpusTask.TRAIN, stride=1, offset=0
+    opt,
+    transforms_cls,
+    vocabs,
+    copy=False,
+    task=CorpusTask.TRAIN,
+    stride=1,
+    offset=0,
+    src=None,
+    tgt=None,
+    align=None,
 ):
     """
     Build `DynamicDatasetIter` from opt.
+    If src, tgt, align are passed, the dataset is built from those lists
+    instead of opt.[src, tgt, align].
     Typically this function is called for CorpusTask.[TRAIN,VALID,INFER]
     from the main train / translate scripts
     We disable automatic batching in the DataLoader.
@@ -360,7 +371,7 @@ def build_dynamic_dataset_iter(
     advance to avoid the GPU waiting during the refilling of the bucket.
     """
     transforms = make_transforms(opt, transforms_cls, vocabs)
-    corpora = get_corpora(opt, task)
+    corpora = get_corpora(opt, task, src=src, tgt=tgt, align=align)
     if corpora is None:
         assert task != CorpusTask.TRAIN, "only valid corpus is ignorable."
         return None
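The new keyword arguments let the same builder serve both config-driven and in-memory use. A hedged sketch of the two call styles, assuming `opt`, `transforms_cls`, and `vocabs` are prepared as in the scripts above; the sample sentence is made up:

from onmt.constants import CorpusTask
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter

# 1) File-backed inference: with no src list, get_corpora falls back to opt.src.
infer_iter = build_dynamic_dataset_iter(
    opt, transforms_cls, vocabs, task=CorpusTask.INFER
)

# 2) In-memory inference: the src list overrides opt.src entirely.
infer_iter = build_dynamic_dataset_iter(
    opt, transforms_cls, vocabs, task=CorpusTask.INFER, src=["Hello world ."]
)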
5 changes: 5 additions & 0 deletions onmt/inputters/inputter.py
@@ -15,6 +15,11 @@ def __init__(self, iterable, device_id):
         super(IterOnDevice).__init__()
         self.iterable = iterable
         self.device_id = device_id
+        # temporary as long as translation_server and scoring_preparator still use lists
+        if hasattr(iterable, "transforms"):
+            self.transform = TransformPipe.build_from(
+                [iterable.transforms[name] for name in iterable.transforms]
+            )
 
     @staticmethod
     def batch_to_device(tensor_batch, device_id):
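Wrapping an iterator in IterOnDevice now also pre-composes its transforms, so callers no longer build the pipe by hand. A short sketch, assuming `infer_iter` comes from build_dynamic_dataset_iter and therefore carries a `transforms` dict:

from onmt.inputters.inputter import IterOnDevice

infer_iter = IterOnDevice(infer_iter, opt.gpu)
# The composed pipeline is a ready-made attribute instead of being rebuilt by
# every caller; it is what gets passed to translator._translate as `transform`.
pipe = infer_iter.transform  # a TransformPipe over all active transforms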
7 changes: 4 additions & 3 deletions onmt/inputters/text_corpus.py
@@ -112,7 +112,7 @@ def __str__(self):
         )
 
 
-def get_corpora(opts, task=CorpusTask.TRAIN):
+def get_corpora(opts, task=CorpusTask.TRAIN, src=None, tgt=None, align=None):
     corpora_dict = {}
     if task == CorpusTask.TRAIN:
         for corpus_id, corpus_dict in opts.data.items():
@@ -140,8 +140,9 @@ def get_corpora(opts, task=CorpusTask.TRAIN):
     else:
         corpora_dict[CorpusName.INFER] = ParallelCorpus(
             CorpusName.INFER,
-            opts.src,
-            opts.tgt,
+            src if src else opts.src,
+            tgt if tgt else opts.tgt,
+            align if align else None,
             n_src_feats=opts.n_src_feats,
             src_feats_defaults=opts.src_feats_defaults,
         )
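The old MMLU code already passed a plain Python list into ParallelCorpus, so the class appears to accept either file paths or in-memory lines; get_corpora now just forwards whichever it receives. A sketch of what the INFER branch builds when a list is supplied, with a made-up sample line:

from onmt.constants import CorpusName
from onmt.inputters.text_corpus import ParallelCorpus

corpus = ParallelCorpus(
    CorpusName.INFER,
    ["A source sentence ."],  # src: in-memory lines instead of the opts.src path
    None,                     # tgt: nothing to read at inference time
    None,                     # align
    n_src_feats=opts.n_src_feats,
    src_feats_defaults=opts.src_feats_defaults,
)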
5 changes: 0 additions & 5 deletions onmt/model_builder.py
@@ -2,9 +2,6 @@
 This file is for models creation, which consults options
 and creates each encoder and decoder accordingly.
 """
-import re
-import os
-import importlib
 import torch
 import torch.nn as nn
 from torch.nn.init import xavier_uniform_, zeros_, uniform_
@@ -268,8 +265,6 @@ def build_base_model(model_opt, vocabs):
"%s compression of layer %s" % (model_opt.quant_type, nonlora_to_quant)
)
try:
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
import bitsandbytes as bnb
from onmt.modules.bnb_linear import replace_bnb_linear
except ImportError:
raise ImportError("Install bitsandbytes to use 4/8bit compression")
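With the duplicate guard removed from replace_bnb_linear below, the bitsandbytes availability check lives in this one place. A condensed view of the consolidated guard, with the error message as it appears in the diff:

try:
    # Importing the onmt wrapper fails if bitsandbytes itself is missing.
    from onmt.modules.bnb_linear import replace_bnb_linear
except ImportError:
    raise ImportError("Install bitsandbytes to use 4/8bit compression")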
2 changes: 0 additions & 2 deletions onmt/modules/__init__.py
@@ -1,6 +1,4 @@
""" Attention and normalization modules """
import importlib
import os
from onmt.modules.util_class import Elementwise
from onmt.modules.gate import context_gate_factory, ContextGate
from onmt.modules.global_attention import GlobalAttention
5 changes: 0 additions & 5 deletions onmt/modules/bnb_linear.py
@@ -191,11 +191,6 @@ def replace_bnb_linear(
     threshold=6.0,
     compute_dtype=torch.float16,  # we could also use bfloat16 when available
 ):
-    try:
-        os.environ["BITSANDBYTES_NOWELCOME"] = "1"
-        import bitsandbytes as bnb
-    except ImportError:
-        raise ImportError("Install bitsandbytes to use 4/8bit compression")
     for name, module in model.named_children():
         if len(list(module.children())) > 0:
             replace_bnb_linear(
3 changes: 1 addition & 2 deletions onmt/modules/lora.py
@@ -8,9 +8,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import math
-import importlib
 from torch.utils.checkpoint import checkpoint
-from typing import List, Dict
+from typing import Dict
 import os


2 changes: 1 addition & 1 deletion onmt/tests/test_data_prepare.py
@@ -45,7 +45,7 @@ def __init__(self, *args, **kwargs):

     def dataset_build(self, opt):
         try:
-            prepare_transforms_vocabs(opt)
+            prepare_transforms_vocabs(opt, {})
         except SystemExit as err:
             print(err)
         except IOError as err:
53 changes: 19 additions & 34 deletions onmt/train_single.py
@@ -25,10 +25,8 @@
 from onmt.modules.embeddings import prepare_pretrained_embeddings
 
 
-def prepare_transforms_vocabs(opt):
+def prepare_transforms_vocabs(opt, transforms_cls):
     """Prepare or dump transforms before training."""
-    transforms_cls = get_transforms_cls(opt._all_transform)
-
     # if transform + options set in 'valid' we need to copy in main
     # transform / options for scoring considered as inference
     validset_transforms = opt.data.get("valid", {}).get("transforms", None)
@@ -66,18 +64,22 @@ def prepare_transforms_vocabs(opt):
f"{vocabs_to_dict(vocabs)['src'][0:10]}"
)
logger.info(f"The decoder start token is: {opt.decoder_start_token}")
return vocabs, transforms_cls
return vocabs


def _init_train(opt):
"""Common initilization stuff for all training process."""
"""Common initilization stuff for all training process.
We need to build or rebuild the vocab in 3 cases:
- training from scratch (train_from is false)
- resume training but transforms have changed
- resume training but vocab file has been modified
"""
ArgumentParser.validate_prepare_opts(opt)

transforms_cls = get_transforms_cls(opt._all_transform)
if opt.train_from:
# Load checkpoint if we resume from a previous training.
checkpoint = load_checkpoint(ckpt_path=opt.train_from)
vocabs = dict_to_vocabs(checkpoint["vocab"])
transforms_cls = get_transforms_cls(opt._all_transform)
if (
hasattr(checkpoint["opt"], "_all_transform")
and len(
@@ -95,13 +97,13 @@ def _init_train(opt):
             if len(old_transf) != 0:
                 _msg += f" -{old_transf}."
             logger.warning(_msg)
-            vocabs, transforms_cls = prepare_transforms_vocabs(opt)
+            vocabs = prepare_transforms_vocabs(opt, transforms_cls)
         if opt.update_vocab:
             logger.info("Updating checkpoint vocabulary with new vocabulary")
-            vocabs, transforms_cls = prepare_transforms_vocabs(opt)
+            vocabs = prepare_transforms_vocabs(opt, transforms_cls)
     else:
         checkpoint = None
-        vocabs, transforms_cls = prepare_transforms_vocabs(opt)
+        vocabs = prepare_transforms_vocabs(opt, transforms_cls)
 
     return checkpoint, vocabs, transforms_cls

@@ -149,28 +151,6 @@ def _get_model_opts(opt, checkpoint=None):
     return model_opt
 
 
-def _build_valid_iter(opt, transforms_cls, vocabs):
-    """Build iterator used for validation."""
-    valid_iter = build_dynamic_dataset_iter(
-        opt, transforms_cls, vocabs, task=CorpusTask.VALID, copy=opt.copy_attn
-    )
-    return valid_iter
-
-
-def _build_train_iter(opt, transforms_cls, vocabs, stride=1, offset=0):
-    """Build training iterator."""
-    train_iter = build_dynamic_dataset_iter(
-        opt,
-        transforms_cls,
-        vocabs,
-        task=CorpusTask.TRAIN,
-        copy=opt.copy_attn,
-        stride=stride,
-        offset=offset,
-    )
-    return train_iter
-
-
 def main(opt, device_id):
     """Start training on `device_id`."""
     # NOTE: It's important that ``opt`` has been validated and updated
@@ -222,16 +202,21 @@ def main(opt, device_id):
         opt, device_id, model, vocabs, optim, model_saver=model_saver
     )
 
-    _train_iter = _build_train_iter(
+    _train_iter = build_dynamic_dataset_iter(
         opt,
         transforms_cls,
         vocabs,
+        task=CorpusTask.TRAIN,
+        copy=opt.copy_attn,
         stride=max(1, len(opt.gpu_ranks)),
         offset=max(0, device_id),
     )
     train_iter = IterOnDevice(_train_iter, device_id)
 
-    valid_iter = _build_valid_iter(opt, transforms_cls, vocabs)
+    valid_iter = build_dynamic_dataset_iter(
+        opt, transforms_cls, vocabs, task=CorpusTask.VALID, copy=opt.copy_attn
+    )
 
     if valid_iter is not None:
         valid_iter = IterOnDevice(valid_iter, device_id)
 
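With the one-line wrappers gone, main() spells out both iterator builds itself. A condensed sketch of the resulting flow, assuming `opt`, `vocabs`, `transforms_cls`, and `device_id` come from _init_train and the process launcher:

from onmt.constants import CorpusTask
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.inputters.inputter import IterOnDevice

_train_iter = build_dynamic_dataset_iter(
    opt,
    transforms_cls,
    vocabs,
    task=CorpusTask.TRAIN,
    copy=opt.copy_attn,
    stride=max(1, len(opt.gpu_ranks)),  # shard the corpus across GPUs
    offset=max(0, device_id),           # this GPU reads every stride-th example
)
train_iter = IterOnDevice(_train_iter, device_id)

valid_iter = build_dynamic_dataset_iter(
    opt, transforms_cls, vocabs, task=CorpusTask.VALID, copy=opt.copy_attn
)
if valid_iter is not None:  # the validation corpus is optional
    valid_iter = IterOnDevice(valid_iter, device_id)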
2 changes: 0 additions & 2 deletions onmt/transforms/inlinetags.py
@@ -1,11 +1,9 @@
-from onmt.utils.logging import logger
 from onmt.transforms import register_transform
 from .transform import Transform
 
 import random
 import ahocorasick
 import string
-from typing import Tuple
 
 
 class InlineTagger(object):
1 change: 0 additions & 1 deletion setup.cfg
@@ -3,5 +3,4 @@ max-line-length = 100
 ignore =
     E203,
     E731,
-    F401,
     W503,
14 changes: 2 additions & 12 deletions tools/LM_scoring.py
@@ -10,7 +10,7 @@
 from onmt.inputters.inputter import IterOnDevice
 from onmt.utils.loss import LossCompute
 from onmt.constants import DefaultTokens, CorpusTask
-from onmt.transforms import get_transforms_cls, TransformPipe
+from onmt.transforms import get_transforms_cls
 from onmt.model_builder import load_test_model
 
 """
@@ -71,20 +71,10 @@ def main():
     valid_loss.to(device)
 
     transforms_cls = get_transforms_cls(opt._all_transform)
-
     infer_iter = build_dynamic_dataset_iter(
         opt, transforms_cls, vocabs, task=CorpusTask.INFER, copy=False
     )
-
-    data_transform = [
-        infer_iter.transforms[name]
-        for name in opt.transforms
-        if name in infer_iter.transforms
-    ]
-    _ = TransformPipe.build_from(data_transform)
-
-    if infer_iter is not None:
-        infer_iter = IterOnDevice(infer_iter, opt.gpu)
+    infer_iter = IterOnDevice(infer_iter, opt.gpu)
 
     model.to(device)
     model.eval()
2 changes: 0 additions & 2 deletions tools/convert_falcon.py
@@ -5,8 +5,6 @@
 import pyonmttok
 from argparse import Namespace
 from onmt.inputters.inputter import vocabs_to_dict
-from onmt.constants import DefaultTokens
-from sentencepiece import SentencePieceProcessor
 import os
 from safetensors.torch import save_file

2 changes: 0 additions & 2 deletions tools/convert_mpt.py
@@ -5,8 +5,6 @@
 import pyonmttok
 from argparse import Namespace
 from onmt.inputters.inputter import vocabs_to_dict
-from onmt.constants import DefaultTokens
-from sentencepiece import SentencePieceProcessor
 import os
 from transformers import AutoModelForCausalLM
 from safetensors.torch import save_file
2 changes: 0 additions & 2 deletions tools/convert_redpajama.py
@@ -5,8 +5,6 @@
 import pyonmttok
 from argparse import Namespace
 from onmt.inputters.inputter import vocabs_to_dict
-from onmt.constants import DefaultTokens
-from sentencepiece import SentencePieceProcessor
 import os
 from transformers import AutoModelForCausalLM
 from safetensors.torch import save_file
