From 2aa1c215ea2ce896bc1ec836d02d633a8c22b102 Mon Sep 17 00:00:00 2001
From: Hiroyuki Deguchi
Date: Sat, 27 Jul 2024 14:41:30 +0900
Subject: [PATCH] Calculate log-probabilities in mbrs-generate (#12)

* Calculate lprobs in mbrs-generate

* Reduce the GPU memory usage when calculating lprobs

* Prenormalize the score tensors

* Implement memory-efficient compute_transition_scores

* Add --report argument to mbrs-generate
---
 mbrs/cli/generate.py | 247 ++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 231 insertions(+), 16 deletions(-)

diff --git a/mbrs/cli/generate.py b/mbrs/cli/generate.py
index 0bed6c6..5b1082e 100644
--- a/mbrs/cli/generate.py
+++ b/mbrs/cli/generate.py
@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
 
 from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, FileType, Namespace
+from dataclasses import dataclass
 from itertools import chain
-from typing import Any, Iterable
+from typing import Any, Generator, Iterable, Optional
 
 import torch
 from tabulate import tabulate, tabulate_formats
@@ -11,7 +12,9 @@
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
     M2M100ForConditionalGeneration,
+    set_seed,
 )
+from transformers.generation.utils import GenerateOutput, GenerationMixin
 
 from mbrs import timer
 
@@ -35,6 +38,12 @@ def get_argparser() -> ArgumentParser:
                         help="Input file. If not specified, read from stdin.")
     parser.add_argument("--output", "-o", default="-", type=FileType("w"),
                         help="Output file.")
+    parser.add_argument("--lprobs", default=None, type=FileType("w"),
+                        help="Reference log-probabilities file. "
+                        "This option is useful for model-based estimation.")
+    parser.add_argument("--length_normalized_lprobs", default=None, type=FileType("w"),
+                        help="Length-normalized reference log-probabilities file. "
+                        "This option is useful for model-based estimation.")
     parser.add_argument("--model", "-m", type=str, default="facebook/m2m100_418M",
                         help="Model name or path.")
     parser.add_argument("--num_candidates", "-n", type=int, default=1,
@@ -52,18 +61,26 @@ def get_argparser() -> ArgumentParser:
                         help="Maximum length of an output sentence.")
     parser.add_argument("--min_length", type=int, default=1,
                         help="Minimum length of an output sentence.")
+    parser.add_argument("--length_penalty", type=float, default=None,
+                        help="Length penalty.")
     parser.add_argument("--batch_size", "-b", type=int, default=8,
                         help="Batch size.")
     parser.add_argument("--sampling_size", type=int, default=8,
-                        help="Sampling size in a single inference.")
+                        help="Sampling size in a single inference. "
" + "The model generates this number of samples at a time " + "until the total number of samples reaches `--num_candidates`.") parser.add_argument("--fp16", action="store_true", help="Use float16.") parser.add_argument("--bf16", action="store_true", help="Use bfloat16.") parser.add_argument("--cpu", action="store_true", help="Force to use CPU.") + parser.add_argument("--seed", type=int, default=0, + help="Random number seed.") parser.add_argument("--quiet", "-q", action="store_true", help="No report statistics.") + parser.add_argument("--report", default="-", type=FileType("w"), + help="Report file.") parser.add_argument("--report_format", type=str, default="rounded_outline", choices=tabulate_formats, help="Report runtime statistics.") @@ -77,7 +94,142 @@ def parse_args() -> Namespace: return get_argparser().parse_args() +@dataclass +class Sample: + """A sample generated by a model.""" + + text: str + lprob: Optional[float] = None + length_normalized_lprob: Optional[float] = None + + +def memory_efficient_compute_transition_scores( + model: GenerationMixin, + sequences: torch.Tensor, + scores: tuple[torch.Tensor, ...], + beam_indices: Optional[torch.Tensor] = None, + normalize_logits: bool = False, +) -> torch.Tensor: + """ + Computes the transition scores of sequences given the generation scores (and beam indices, if beam search was + used). This is a convenient method to quicky obtain the scores of the selected tokens at generation time. + + This is another implementation for memory efficiency. + + Parameters: + model (`GenerationMixin`): Generation model. + sequences (`torch.LongTensor`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or + shorter if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)`): + Transition scores for each vocabulary token at each generation step. Beam transition scores consisting + of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. + Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), + with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. + beam_indices (`torch.LongTensor`, *optional*): + Beam indices of generated token id at each generation step. `torch.LongTensor` of shape + `(batch_size*num_return_sequences, sequence_length)`. Only required if a `num_beams>1` at + generate-time. + normalize_logits (`bool`, *optional*, defaults to `False`): + Whether to normalize the logits (which, for legacy reasons, may be unnormalized). + + Return: + `torch.Tensor`: A `torch.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)` containing + the transition scores (logits) + + Examples: + + ```python + >>> from transformers import GPT2Tokenizer, AutoModelForCausalLM + >>> import numpy as np + + >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer.pad_token_id = tokenizer.eos_token_id + >>> inputs = tokenizer(["Today is"], return_tensors="pt") + + >>> # Example 1: Print the scores for each token generated with Greedy Search + >>> outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True) + >>> transition_scores = memory_efficient_compute_transition_scores( + ... model, outputs.sequences, outputs.scores, normalize_logits=True + ... 
+    ... )
+    >>> # input_length is the length of the input prompt for decoder-only models, like the GPT family, and 1 for
+    >>> # encoder-decoder models, like BART or T5.
+    >>> input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
+    >>> generated_tokens = outputs.sequences[:, input_length:]
+    >>> for tok, score in zip(generated_tokens[0], transition_scores[0]):
+    ...     # | token | token string | log probability | probability
+    ...     print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")
+    |   262 |  the     | -1.414 | 24.33%
+    |  1110 |  day     | -2.609 | 7.36%
+    |   618 |  when    | -2.010 | 13.40%
+    |   356 |  we      | -1.859 | 15.58%
+    |   460 |  can     | -2.508 | 8.14%
+
+    >>> # Example 2: Reconstruct the sequence scores from Beam Search
+    >>> outputs = model.generate(
+    ...     **inputs,
+    ...     max_new_tokens=5,
+    ...     num_beams=4,
+    ...     num_return_sequences=4,
+    ...     return_dict_in_generate=True,
+    ...     output_scores=True,
+    ... )
+    >>> transition_scores = memory_efficient_compute_transition_scores(
+    ...     model, outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
+    ... )
+    >>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores.
+    >>> # Tip 1: recomputing the scores is only guaranteed to match with `normalize_logits=False`. Depending on the
+    >>> # use case, you might want to recompute it with `normalize_logits=True`.
+    >>> # Tip 2: the output length does NOT include the input length
+    >>> output_length = np.sum(transition_scores.numpy() < 0, axis=1)
+    >>> length_penalty = model.generation_config.length_penalty
+    >>> reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty)
+    >>> print(np.allclose(outputs.sequences_scores, reconstructed_scores))
+    True
+    ```"""
+    # 1. In the absence of `beam_indices`, we can assume that we come from e.g. greedy search, which is equivalent
+    # to a beam search approach where the first (and only) beam is always selected
+    if beam_indices is None:
+        beam_indices = torch.arange(scores[0].shape[0]).view(-1, 1).to(sequences.device)
+        beam_indices = beam_indices.expand(-1, len(scores))
+
+    # 2. Optionally normalize the logits (across the vocab dimension)
+    if normalize_logits:
+        scores = tuple(s.log_softmax(dim=1) for s in scores)
+
+    # 3. Cut beam_indices to the longest beam length
+    beam_indices_mask = beam_indices < 0
+    max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max()
+    beam_indices = beam_indices.clone()[:, :max_beam_length]
+    beam_indices_mask = beam_indices_mask[:, :max_beam_length]
+
+    # 4. Set indices of beams that finished early to 0; such indices will be masked correctly afterwards
+    beam_indices[beam_indices_mask] = 0
+
+    # 5. Multiply beam_indices by the vocab size to gather correctly from scores
+    beam_sequence_indices = beam_indices * model.config.vocab_size
+
+    # 6. Define which indices contributed to scores
+    cut_idx = sequences.shape[-1] - max_beam_length
+    indices = sequences[:, cut_idx:] + beam_sequence_indices
+
+    # 7. Compute scores
+    transition_scores = torch.zeros(
+        indices.size(), dtype=torch.float32, device=scores[0].device
+    )
+    for step, step_scores in enumerate(scores):
+        transition_scores[:, step] = step_scores.view(-1).gather(0, indices[:, step])
+
+    # 8. Mask out transition_scores of beams that stopped early
+    transition_scores[beam_indices_mask] = 0
+
+    return transition_scores
+
+
 def main(args: Namespace) -> None:
+    set_seed(args.seed)
+
     src_lang, tgt_lang = tuple(args.lang_pair.split("-"))
     tokenizer = AutoTokenizer.from_pretrained(args.model)
     model = AutoModelForSeq2SeqLM.from_pretrained(args.model)
@@ -86,13 +238,24 @@ def main(args: Namespace) -> None:
     for param in model.parameters():
         param.requires_grad = False
     if torch.cuda.is_available() and not args.cpu:
-        model.cuda()
         if args.fp16:
             model.half()
         elif args.bf16:
             model.bfloat16()
+        model.cuda()
 
-    generation_kwargs = {"max_length": args.max_length, "min_length": args.min_length}
+    generation_kwargs = {
+        "max_length": args.max_length,
+        "min_length": args.min_length,
+        "return_dict_in_generate": True,
+    }
+    length_penalty = getattr(model.generation_config, "length_penalty", 1.0)
+    if args.length_penalty is not None:
+        generation_kwargs["length_penalty"] = args.length_penalty
+        length_penalty = args.length_penalty
+
+    if args.lprobs is not None or args.length_normalized_lprobs is not None:
+        generation_kwargs["output_scores"] = True
 
     if isinstance(model, M2M100ForConditionalGeneration):
         tokenizer.src_lang = src_lang
@@ -102,40 +265,92 @@ def main(args: Namespace) -> None:
         generation_kwargs["do_sample"] = True
         generation_kwargs["epsilon_cutoff"] = args.epsilon
         generation_kwargs["num_beams"] = 1
+        generation_kwargs["early_stopping"] = False
     else:
         generation_kwargs["num_beams"] = max(args.beam_size, args.num_candidates)
 
     def decode(
         inputs: list[str], num_candidates: int, generation_kwargs: dict[str, Any]
-    ) -> list[str]:
+    ) -> list[Sample]:
         model_inputs = tokenizer(inputs, return_tensors="pt", padding=True).to(
             device=model.device
         )
         with timer.measure("generate"):
-            model_outputs = model.generate(
+            model_outputs: GenerateOutput = model.generate(
                 **model_inputs, **generation_kwargs, num_return_sequences=num_candidates
             )
-        return tokenizer.batch_decode(model_outputs, skip_special_tokens=True)
 
-    def generate(inputs: list[str]) -> list[str]:
+        if model_outputs.scores is None:
+            return [
+                Sample(s)
+                for s in tokenizer.batch_decode(
+                    model_outputs.sequences, skip_special_tokens=True
+                )
+            ]
+
+        sequences = model_outputs.sequences
+        scores = model_outputs.scores
+        max_length = len(scores)
+        if tokenizer.pad_token_id is not None:
+            pad_token_id = tokenizer.pad_token_id
+            for s in scores:
+                s[:, pad_token_id] = float("-inf")
+            scores = tuple(s.log_softmax(dim=1) for s in scores)
+            for s in scores:
+                s[:, pad_token_id] = 0.0
+            sequence_lengths = max_length - (sequences.eq(pad_token_id)).sum(dim=-1)
+        else:
+            scores = tuple(s.log_softmax(dim=1) for s in scores)
+            sequence_lengths = torch.full(
+                (sequences.size(0),), fill_value=max_length, device=sequences.device
+            )
+
+        transition_scores = memory_efficient_compute_transition_scores(
+            model,
+            sequences,
+            scores,
+            beam_indices=model_outputs.get("beam_indices", None),
+        )
+        lprobs = transition_scores.sum(dim=-1)
+        length_normalized_lprobs = lprobs / (sequence_lengths**length_penalty)
+        return [
+            Sample(text, lprob=lprob, length_normalized_lprob=length_normalized_lprob)
+            for text, lprob, length_normalized_lprob in zip(
+                tokenizer.batch_decode(sequences, skip_special_tokens=True),
+                lprobs.cpu().tolist(),
+                length_normalized_lprobs.cpu().tolist(),
+            )
+        ]
+
+    def generate(inputs: list[str]) -> Generator[Sample, None, None]:
         if (
             not generation_kwargs.get("do_sample", False)
            or generation_kwargs.get("num_beams", 1) != 1
         ):
-            return decode(inputs, args.num_candidates, generation_kwargs)
+            yield from decode(inputs, args.num_candidates, generation_kwargs)
         else:
-            outputs: list[list[str]] = [[] for _ in range(args.batch_size)]
+            samples: list[list[Sample]] = [[] for _ in range(args.batch_size)]
             for n in range(0, args.num_candidates, args.sampling_size):
                 sampling_size = min(args.sampling_size, args.num_candidates - n)
-                samples = decode(inputs, sampling_size, generation_kwargs)
+                shards = decode(inputs, sampling_size, generation_kwargs)
                 for i in range(args.batch_size):
-                    outputs[i] += samples[i * sampling_size : (i + 1) * sampling_size]
-            return list(chain.from_iterable(outputs))
+                    samples[i] += shards[i * sampling_size : (i + 1) * sampling_size]
+            yield from chain.from_iterable(samples)
 
     num_sentences = 0
     for lines in buffer_lines(args.input, buffer_size=args.batch_size):
-        for output in generate(lines):
-            print(output.strip(), file=args.output)
+        for sample in generate(lines):
+            print(sample.text.strip(), file=args.output)
+            if sample.lprob is not None and args.lprobs is not None:
+                print(str(sample.lprob), file=args.lprobs)
+            if (
+                sample.length_normalized_lprob is not None
+                and args.length_normalized_lprobs is not None
+            ):
+                print(
+                    str(sample.length_normalized_lprob),
+                    file=args.length_normalized_lprobs,
+                )
             num_sentences += 1
 
     if not args.quiet:
@@ -146,7 +361,7 @@ def generate(inputs: list[str]) -> list[str]:
             tablefmt=args.report_format,
             floatfmt=f".{args.width}f",
        )
-        print(table)
+        print(table, file=args.report)
 
 
 def cli_main():
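
A quick sanity check for the new helper is to compare it against the stock `GenerationMixin.compute_transition_scores` from transformers, which it closely mirrors; only the score-gathering step (step 7 in the helper) differs, reading one per-step `(batch_size*num_beams, vocab_size)` tensor at a time instead of materializing the full stacked score matrix. The sketch below is not part of the patch: it assumes the patched mbrs is installed (so the helper is importable from `mbrs.cli.generate`), and the GPT-2 checkpoint is illustrative only.

```python
# Minimal sketch: the memory-efficient helper should reproduce the output of
# transformers' GenerationMixin.compute_transition_scores on a beam search run.
# Assumptions: the patched mbrs is installed; the checkpoint is illustrative.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from mbrs.cli.generate import memory_efficient_compute_transition_scores

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id

inputs = tokenizer(["Today is"], return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=5,
    num_beams=4,
    num_return_sequences=4,
    return_dict_in_generate=True,
    output_scores=True,
    pad_token_id=tokenizer.pad_token_id,
)

# Reference: internally stacks all per-step scores into one large tensor.
reference = model.compute_transition_scores(
    outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=True
)
# Memory-efficient variant: gathers from one per-step tensor at a time.
efficient = memory_efficient_compute_transition_scores(
    model, outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=True
)
print(torch.allclose(reference, efficient))  # expected: True
```

With `--lprobs` and/or `--length_normalized_lprobs` set, `mbrs-generate` writes one value per candidate in the same order as the text output, so the score files line up with the hypotheses for model-based estimation.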