Calculate log-probabilities in mbrs-generate (#12)
* Calculate lprobs in mbrs-generate

* Reduce the GPU memory usage when calculating lprobs

* Prenormalize the score tensors

* Implement memory efficient compute_transition_scores

* Add --report argument to mbrs-generate
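
For reference, a minimal sketch (hypothetical values, not taken from the diff) of how the two new output files relate; it mirrors the length normalization applied in `decode()` below:

```python
# Hypothetical numbers; the relation mirrors decode() in mbrs/cli/generate.py.
lprob = -12.5          # sum of per-token log-probabilities, written to --lprobs
sequence_length = 10   # number of generated tokens, excluding padding
length_penalty = 1.0   # model.generation_config.length_penalty unless --length_penalty is set
length_normalized_lprob = lprob / (sequence_length ** length_penalty)
print(length_normalized_lprob)  # -1.25, written to --length_normalized_lprobs
```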
de9uch1 authored Jul 27, 2024
1 parent 060d4a3 commit 2aa1c21
Showing 1 changed file with 231 additions and 16 deletions.
247 changes: 231 additions & 16 deletions mbrs/cli/generate.py
@@ -1,8 +1,9 @@
#!/usr/bin/env python3

from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, FileType, Namespace
from dataclasses import dataclass
from itertools import chain
from typing import Any, Iterable
from typing import Any, Generator, Iterable, Optional

import torch
from tabulate import tabulate, tabulate_formats
@@ -11,7 +12,9 @@
AutoModelForSeq2SeqLM,
AutoTokenizer,
M2M100ForConditionalGeneration,
set_seed,
)
from transformers.generation.utils import GenerateOutput, GenerationMixin

from mbrs import timer

@@ -35,6 +38,12 @@ def get_argparser() -> ArgumentParser:
help="Input file. If not specified, read from stdin.")
parser.add_argument("--output", "-o", default="-", type=FileType("w"),
help="Output file.")
parser.add_argument("--lprobs", default=None, type=FileType("w"),
help="Reference log-probabilities file. "
"This option is useful for the model-based estimation.")
parser.add_argument("--length_normalized_lprobs", default=None, type=FileType("w"),
help="Length-normalized reference log-probabilities file. "
"This option is useful for the model-based estimation.")
parser.add_argument("--model", "-m", type=str, default="facebook/m2m100_418M",
help="Model name or path.")
parser.add_argument("--num_candidates", "-n", type=int, default=1,
@@ -52,18 +61,26 @@ def get_argparser() -> ArgumentParser:
help="Maximum length of an output sentence.")
parser.add_argument("--min_length", type=int, default=1,
help="Minimum length of an output sentence.")
parser.add_argument("--length_penalty", type=float, default=None,
help="Length penalty.")
parser.add_argument("--batch_size", "-b", type=int, default=8,
help="Batch size.")
parser.add_argument("--sampling_size", type=int, default=8,
help="Sampling size in a single inference.")
help="Sampling size in a single inference. "
"The model generates this number of samples at a time "
"until the total number of samples reaches `--num_candidates`.")
parser.add_argument("--fp16", action="store_true",
help="Use float16.")
parser.add_argument("--bf16", action="store_true",
help="Use bfloat16.")
parser.add_argument("--cpu", action="store_true",
help="Force to use CPU.")
parser.add_argument("--seed", type=int, default=0,
help="Random number seed.")
parser.add_argument("--quiet", "-q", action="store_true",
help="No report statistics.")
parser.add_argument("--report", default="-", type=FileType("w"),
help="Report file.")
parser.add_argument("--report_format", type=str, default="rounded_outline",
choices=tabulate_formats,
help="Report runtime statistics.")
@@ -77,7 +94,142 @@ def parse_args() -> Namespace:
return get_argparser().parse_args()


@dataclass
class Sample:
"""A sample generated by a model."""

text: str
lprob: Optional[float] = None
length_normalized_lprob: Optional[float] = None


def memory_efficient_compute_transition_scores(
model: GenerationMixin,
sequences: torch.Tensor,
scores: tuple[torch.Tensor, ...],
beam_indices: Optional[torch.Tensor] = None,
normalize_logits: bool = False,
) -> torch.Tensor:
"""
Computes the transition scores of sequences given the generation scores (and beam indices, if beam search was
used). This is a convenient method to quickly obtain the scores of the selected tokens at generation time.
This is an alternative, memory-efficient implementation of `GenerationMixin.compute_transition_scores`.
Parameters:
model (`GenerationMixin`): Generation model.
sequences (`torch.LongTensor`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or
shorter if all batches finished early due to the `eos_token_id`.
scores (`tuple(torch.FloatTensor)`):
Transition scores for each vocabulary token at each generation step. For beam search, these consist of
log-probabilities of tokens conditioned on the previously generated tokens in the beam.
Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
beam_indices (`torch.LongTensor`, *optional*):
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
`(batch_size*num_return_sequences, sequence_length)`. Only required if `num_beams>1` at
generation time.
normalize_logits (`bool`, *optional*, defaults to `False`):
Whether to normalize the logits (which, for legacy reasons, may be unnormalized).
Return:
`torch.Tensor`: A `torch.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)` containing
the transition scores (logits).
Examples:
```python
>>> from transformers import GPT2Tokenizer, AutoModelForCausalLM
>>> import numpy as np
>>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> tokenizer.pad_token_id = tokenizer.eos_token_id
>>> inputs = tokenizer(["Today is"], return_tensors="pt")
>>> # Example 1: Print the scores for each token generated with Greedy Search
>>> outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True)
>>> transition_scores = memory_efficient_compute_transition_scores(
... model, outputs.sequences, outputs.scores, normalize_logits=True
... )
>>> # input_length is the length of the input prompt for decoder-only models, like the GPT family, and 1 for
>>> # encoder-decoder models, like BART or T5.
>>> input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
>>> generated_tokens = outputs.sequences[:, input_length:]
>>> for tok, score in zip(generated_tokens[0], transition_scores[0]):
... # | token | token string | log probability | probability
... print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")
| 262 | the | -1.414 | 24.33%
| 1110 | day | -2.609 | 7.36%
| 618 | when | -2.010 | 13.40%
| 356 | we | -1.859 | 15.58%
| 460 | can | -2.508 | 8.14%
>>> # Example 2: Reconstruct the sequence scores from Beam Search
>>> outputs = model.generate(
... **inputs,
... max_new_tokens=5,
... num_beams=4,
... num_return_sequences=4,
... return_dict_in_generate=True,
... output_scores=True,
... )
>>> transition_scores = memory_efficient_compute_transition_scores(
... model, outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
... )
>>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores.
>>> # Tip 1: recomputing the scores is only guaranteed to match with `normalize_logits=False`. Depending on the
>>> # use case, you might want to recompute it with `normalize_logits=True`.
>>> # Tip 2: the output length does NOT include the input length
>>> output_length = np.sum(transition_scores.numpy() < 0, axis=1)
>>> length_penalty = model.generation_config.length_penalty
>>> reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty)
>>> print(np.allclose(outputs.sequences_scores, reconstructed_scores))
True
```"""
# 1. In the absence of `beam_indices`, we can assume that we come from e.g. greedy search, which is equivalent
# to a beam search approach where the first (and only) beam is always selected
if beam_indices is None:
beam_indices = torch.arange(scores[0].shape[0]).view(-1, 1).to(sequences.device)
beam_indices = beam_indices.expand(-1, len(scores))

# 2. Optionally normalize the logits (across the vocab dimension)
if normalize_logits:
scores = tuple(s.log_softmax(dim=1) for s in scores)

# 3. cut beam_indices to longest beam length
beam_indices_mask = beam_indices < 0
max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max()
beam_indices = beam_indices.clone()[:, :max_beam_length]
beam_indices_mask = beam_indices_mask[:, :max_beam_length]

# 4. Set indices of beams that finished early to 0; such indices will be masked correctly afterwards
beam_indices[beam_indices_mask] = 0

# 5. multiply beam_indices with vocab size to gather correctly from scores
beam_sequence_indices = beam_indices * model.config.vocab_size

# 6. Define which indices contributed to scores
cut_idx = sequences.shape[-1] - max_beam_length
indices = sequences[:, cut_idx:] + beam_sequence_indices

# 7. Compute scores
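#    Note: the upstream `compute_transition_scores` stacks all per-step score tensors into one
#    large tensor before gathering; gathering step by step below avoids that extra copy, which
#    is where the memory saving of this reimplementation comes from.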
transition_scores = torch.zeros(
indices.size(), dtype=torch.float32, device=scores[0].device
)
for step, step_scores in enumerate(scores):
transition_scores[:, step] = step_scores.view(-1).gather(0, indices[:, step])

# 8. Mask out transition_scores of beams that stopped early
transition_scores[beam_indices_mask] = 0

return transition_scores


def main(args: Namespace) -> None:
set_seed(args.seed)

src_lang, tgt_lang = tuple(args.lang_pair.split("-"))
tokenizer = AutoTokenizer.from_pretrained(args.model)
model = AutoModelForSeq2SeqLM.from_pretrained(args.model)
@@ -86,13 +238,24 @@ def main(args: Namespace) -> None:
for param in model.parameters():
param.requires_grad = False
if torch.cuda.is_available() and not args.cpu:
model.cuda()
if args.fp16:
model.half()
elif args.bf16:
model.bfloat16()
model.cuda()

generation_kwargs = {"max_length": args.max_length, "min_length": args.min_length}
generation_kwargs = {
"max_length": args.max_length,
"min_length": args.min_length,
"return_dict_in_generate": True,
}
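# Length penalty used to length-normalize the log-probabilities: take the value from the
# model's generation config (defaulting to 1.0) unless --length_penalty overrides it.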
length_penalty = getattr(model.generation_config, "length_penalty", 1.0)
if args.length_penalty is not None:
generation_kwargs["length_penalty"] = args.length_penalty
length_penalty = args.length_penalty

if args.lprobs is not None or args.length_normalized_lprobs is not None:
generation_kwargs["output_scores"] = True

if isinstance(model, M2M100ForConditionalGeneration):
tokenizer.src_lang = src_lang
@@ -102,40 +265,92 @@ def main(args: Namespace) -> None:
generation_kwargs["do_sample"] = True
generation_kwargs["epsilon_cutoff"] = args.epsilon
generation_kwargs["num_beams"] = 1
generation_kwargs["early_stopping"] = False
else:
generation_kwargs["num_beams"] = max(args.beam_size, args.num_candidates)

def decode(
inputs: list[str], num_candidates: int, generation_kwargs: dict[str, Any]
) -> list[str]:
) -> list[Sample]:
model_inputs = tokenizer(inputs, return_tensors="pt", padding=True).to(
device=model.device
)
with timer.measure("generate"):
model_outputs = model.generate(
model_outputs: GenerateOutput = model.generate(
**model_inputs, **generation_kwargs, num_return_sequences=num_candidates
)
return tokenizer.batch_decode(model_outputs, skip_special_tokens=True)

def generate(inputs: list[str]) -> list[str]:
if model_outputs.scores is None:
return [
Sample(s)
for s in tokenizer.batch_decode(
model_outputs.sequences, skip_special_tokens=True
)
]

sequences = model_outputs.sequences
scores = model_outputs.scores
max_length = len(scores)
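# Pre-normalize the per-step scores into log-probabilities. Setting the pad column to -inf
# before log_softmax excludes padding from the normalization, and resetting it to 0.0
# afterwards makes padded positions contribute nothing to the summed sequence log-probability.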
if getattr(tokenizer, "pad_token_id", None) is not None:
pad_token_id = tokenizer.pad_token_id
for s in scores:
s[:, pad_token_id] = float("-inf")
scores = tuple(s.log_softmax(dim=1) for s in scores)
for s in scores:
s[:, pad_token_id] = 0.0
sequence_lengths = max_length - (sequences.eq(pad_token_id)).sum(dim=-1)
else:
scores = tuple(s.log_softmax(dim=1) for s in scores)
sequence_lengths = torch.full(
(sequences.size(0),), fill_value=max_length, device=sequences.device
)

transition_scores = memory_efficient_compute_transition_scores(
model,
sequences,
scores,
beam_indices=model_outputs.get("beam_indices", None),
)
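# Since the per-step scores were already log-softmaxed above, `normalize_logits` was left
# at its default (False). Summing the per-token scores gives the sequence log-probability,
# which is then normalized by sequence length raised to the configured length penalty.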
lprobs = transition_scores.sum(dim=-1)
length_normalized_lprobs = lprobs / (sequence_lengths**length_penalty)
return [
Sample(text, lprob=lprob, length_normalized_lprob=length_normalized_lprob)
for text, lprob, length_normalized_lprob in zip(
tokenizer.batch_decode(sequences, skip_special_tokens=True),
lprobs.cpu().tolist(),
length_normalized_lprobs.cpu().tolist(),
)
]

def generate(inputs: list[str]) -> Generator[Sample, None, None]:
if (
not generation_kwargs.get("do_sample", False)
or generation_kwargs.get("num_beams", 1) != 1
):
return decode(inputs, args.num_candidates, generation_kwargs)
yield from decode(inputs, args.num_candidates, generation_kwargs)
else:
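# Sampling path: draw --sampling_size samples per call until --num_candidates have been
# collected for each input, then regroup the flat batch output per input sentence.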
outputs: list[list[str]] = [[] for _ in range(args.batch_size)]
samples: list[list[Sample]] = [[] for _ in range(args.batch_size)]
for n in range(0, args.num_candidates, args.sampling_size):
sampling_size = min(args.sampling_size, args.num_candidates - n)
samples = decode(inputs, sampling_size, generation_kwargs)
shards = decode(inputs, sampling_size, generation_kwargs)
for i in range(args.batch_size):
outputs[i] += samples[i * sampling_size : (i + 1) * sampling_size]
return list(chain.from_iterable(outputs))
samples[i] += shards[i * sampling_size : (i + 1) * sampling_size]
yield from chain.from_iterable(samples)

num_sentences = 0
for lines in buffer_lines(args.input, buffer_size=args.batch_size):
for output in generate(lines):
print(output.strip(), file=args.output)
for sample in generate(lines):
print(sample.text.strip(), file=args.output)
if sample.lprob is not None and args.lprobs is not None:
print(str(sample.lprob), file=args.lprobs)
if (
sample.length_normalized_lprob is not None
and args.length_normalized_lprobs is not None
):
print(
str(sample.length_normalized_lprob),
file=args.length_normalized_lprobs,
)
num_sentences += 1

if not args.quiet:
@@ -146,7 +361,7 @@ def generate(inputs: list[str]) -> list[str]:
tablefmt=args.report_format,
floatfmt=f".{args.width}f",
)
print(table)
print(table, file=args.report)


def cli_main():