add CFG for .generate()
Vermeille committed Jul 5, 2023
1 parent cd4584e commit 8141f46
Showing 3 changed files with 75 additions and 4 deletions.
12 changes: 11 additions & 1 deletion src/transformers/generation/configuration_utils.py
@@ -188,7 +188,15 @@ class GenerationConfig(PushToHubMixin):
guidance_scale (`float`, *optional*):
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
prompt, usually at the expense of poorer quality.
prompt, usually at the expense of poorer quality. If `negative_prompt` is unset, the last token of the
input prompt is used as the unconditional input.
negative_prompt (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
The negative prompt to use for CFG. An error is raised if it is set while `guidance_scale` is `None` or
not greater than 1. Its batch size must match the input batch size.
cfg_rescale (`float`, *optional*, defaults to 1.0):
Rescale factor for CFG, used to improve output quality. Expected values are in the range [0, 1]. Reducing
it smooths the CFG effect and recovers some of the lost quality; lower values also allow for higher
guidance scales. Has no effect unless `guidance_scale` is also set.
> Parameters that define the output variables of `generate`
@@ -270,6 +278,8 @@ def __init__(self, **kwargs):
self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
self.sequence_bias = kwargs.pop("sequence_bias", None)
self.guidance_scale = kwargs.pop("guidance_scale", None)
self.negative_prompt = kwargs.pop("negative_prompt", None)
self.cfg_rescale = kwargs.pop("cfg_rescale", 1.0)

# Parameters that define the output variables of `generate`
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
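A minimal usage sketch of the new configuration fields (not part of the diff; the model name, prompts, and values are illustrative, and it assumes `generate()` forwards these keyword arguments into `GenerationConfig` as the diff suggests):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = tokenizer("Today, a dragon flew over Paris, France,", return_tensors="pt")
negative = tokenizer("A dull weather report:", return_tensors="pt")

output = model.generate(
    **prompt,
    guidance_scale=1.5,                     # CFG is active because guidance_scale > 1
    negative_prompt=negative["input_ids"],  # optional; defaults to the last prompt token
    cfg_rescale=0.7,                        # < 1.0 smooths the CFG effect
    max_new_tokens=40,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))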
54 changes: 54 additions & 0 deletions src/transformers/generation/logits_process.py
@@ -19,6 +19,7 @@

import numpy as np
import torch
import torch.nn.functional as F

from ..utils import add_start_docstrings
from ..utils.logging import get_logger
@@ -1102,3 +1103,56 @@ def __call__(self, input_ids, scores):
cond_logits, uncond_logits = scores.split(unguided_bsz, dim=0)
scores = uncond_logits + (cond_logits - uncond_logits) * self.guidance_scale
return scores


class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
r"""Logits processor for Classifier-Free Guidance (CFG). The processors
computes a weighted average across scores from prompt conditional and prompt unconditional (or negative) logits,
parameterized by the `guidance_scale`. The unconditional scores are computed internally by prompting `model` with
the `uncond` branch. Finally, according to CFG Rescale, the reweighted logits are interpolated back with weight
`rescale_factor` the conditional ones to smooth the effect and increase output quality.
See [the paper](https://arxiv.org/abs/2306.17806) for more information.
Args:
guidance_scale (float):
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
prompt, usually at the expense of poorer quality.
uncond (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary for the unconditional branch.
model:
The LM computing the unconditional scores. Usually the same model as the one computing the conditional
scores. Both models must use the same tokenizer.
rescale_factor (float):
The interpolation weight for CFG Rescale. 1 means no rescaling, 0 reduces the result to the conditional
scores without CFG. Lower it if the output degenerates.
"""

def __init__(self, guidance_scale, uncond, model, rescale_factor=0.7):
self.guidance_scale = guidance_scale
self.uncond = uncond
self.model = model
self.out = None
self.rescale_factor = rescale_factor

def __call__(self, input_ids, scores):
scores = F.log_softmax(scores, dim=-1)
if self.guidance_scale == 1:
return scores

if self.out is None:
self.out = self.model(self.uncond, attention_mask=torch.ones_like(self.uncond), use_cache=True)
else:
self.out = self.model(
input_ids[:, -1:],
attention_mask=torch.ones_like(input_ids[:, -1:]),
use_cache=True,
past_key_values=self.out.past_key_values,
)
unconditional_logits = F.log_softmax(self.out.logits[:, -1, :], dim=-1)
out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
out = F.log_softmax(out, dim=-1)
if self.rescale_factor == 1:
return out
return self.rescale_factor * out + (1 - self.rescale_factor) * scores
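For reference, a self-contained sketch (not part of the diff; all names are illustrative) of the score combination implemented in `__call__` above, applied to dummy log-probabilities:

import torch
import torch.nn.functional as F

guidance_scale, rescale_factor = 1.5, 0.7
cond = F.log_softmax(torch.randn(1, 8), dim=-1)    # conditional scores
uncond = F.log_softmax(torch.randn(1, 8), dim=-1)  # unconditional scores

# Extrapolate away from the unconditional distribution, then renormalize.
guided = F.log_softmax(guidance_scale * (cond - uncond) + uncond, dim=-1)

# CFG Rescale: interpolate back toward the conditional scores.
rescaled = rescale_factor * guided + (1 - rescale_factor) * cond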
13 changes: 10 additions & 3 deletions src/transformers/generation/utils.py
@@ -38,7 +38,6 @@
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .configuration_utils import GenerationConfig
from .logits_process import (
ClassifierFreeGuidanceLogitsProcessor,
EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
EpsilonLogitsWarper,
@@ -64,6 +63,7 @@
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
from .stopping_criteria import (
MaxLengthCriteria,
@@ -844,6 +844,15 @@ def _get_logits_processor(
# instantiate processors list
processors = LogitsProcessorList()

if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
neg_prompt = encoder_input_ids[:, -1:]
if generation_config.negative_prompt is not None:
neg_prompt = generation_config.negative_prompt
processors.append(
UnbatchedClassifierFreeGuidanceLogitsProcessor(
generation_config.guidance_scale, neg_prompt, self, generation_config.cfg_rescale
)
)
if generation_config.sequence_bias is not None:
processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))

@@ -941,8 +950,6 @@ def _get_logits_processor(
)
if generation_config.forced_decoder_ids is not None:
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
processors.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
processors = self._merge_criteria_processor_list(processors, logits_processor)
# `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True:
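As an alternative to the config-driven path wired up above, the new processor can also be built by hand and passed to `generate()` through `logits_processor`. A sketch under the assumption that the class is importable from `transformers.generation.logits_process` at this commit; the model name, prompts, and values are illustrative:

from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
from transformers.generation.logits_process import UnbatchedClassifierFreeGuidanceLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The capital of France is", return_tensors="pt")
negative = tokenizer("Unrelated rambling:", return_tensors="pt")

cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(
    guidance_scale=1.5,
    uncond=negative["input_ids"],
    model=model,
    rescale_factor=0.7,
)
output = model.generate(
    **inputs,
    logits_processor=LogitsProcessorList([cfg]),
    max_new_tokens=20,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))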
