
Add Classifier-Free Guidance sampling #24536

Closed
@Vermeille

Description


EDIT: ===========================
As I see many people copy-pasting this initial code, which was meant as a basis for discussion, here is a cleaner version (still not perfect! We are doing improvement rounds with the Hugging Face team to improve it; check the state of the PR #24654 for the latest version until it is merged).

import torch.nn.functional as F
from transformers import LogitsProcessor

class CFGLogits(LogitsProcessor):
    r"""Logits processor for Classifier-Free Guidance (CFG). The processors
    computes a weighted average across scores from prompt conditional and prompt unconditional (or negative) logits,
    parameterized by the `guidance_scale`. The unconditional scores are computed internally by prompting `model` with
    the `uncond` branch. Finally, according to CFG Rescale, the reweighted logits are interpolated back with weight
    `rescale_factor` the conditional ones to smooth the effect and increase output quality.

    See [the paper](https://arxiv.org/abs/2306.17806) for more information.

    Args:
        guidance_scale (float):
            The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
            Higher guidance scale encourages the model to generate samples that are more closely linked to the input
            prompt, usually at the expense of poorer quality.
        uncond (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary for the unconditional branch.
        model:
            The model computing the unconditional scores. It is usually the same model as the one computing the
            conditional scores. Both models must use the same tokenizer.
        rescale_factor (float, defaults to 1.0):
            The CFG Rescale interpolation weight. `1.0` returns the reweighted logits unchanged; values below `1.0`
            blend them back towards the conditional scores.
    """

    def __init__(self, guidance_scale, uncond, model, rescale_factor=1.0):
        self.guidance_scale = guidance_scale
        self.uncond = uncond
        self.model = model
        self.out = None
        self.rescale_factor = rescale_factor

    def __call__(self, input_ids, scores):
        scores = F.log_softmax(scores, dim=-1)
        if self.guidance_scale == 1:
            return scores

        # Run the unconditional branch: the full `uncond` prompt on the first step,
        # then only the last generated token, reusing the cached key/values.
        if self.out is None:
            self.out = self.model(self.uncond, use_cache=True)
        else:
            self.out = self.model(
                input_ids[:, -1:],
                use_cache=True,
                past_key_values=self.out.past_key_values,
            )
        # Note: indexing with [0] assumes batch size 1.
        unconditional_logits = F.log_softmax(self.out.logits[0][-1:], dim=-1)
        out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
        if self.rescale_factor == 1:
            return out
        # CFG Rescale: interpolate back towards the conditional distribution.
        out = F.log_softmax(out, dim=-1)
        return self.rescale_factor * out + (1 - self.rescale_factor) * scores
        

# paper usage (copying and editing @grantCelley's answer):
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopPLogitsWarper

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")

prompt = tokenizer("Today a dragon flew over Paris, France,", return_tensors='pt')
# either provide a negative prompt:
neg_prompt = tokenizer("A sad event happened,", return_tensors='pt')['input_ids']
# or don't:
# neg_prompt = prompt['input_ids'][:, -1:]

device='cuda:0'
model.to(device)
outputs = model.generate(
    input_ids=prompt['input_ids'].to(device),
    attention_mask=prompt['attention_mask'].to(device),
    max_new_tokens=125,
    logits_processor=LogitsProcessorList([
        # the unconditional branch is usually just the last token of the prompt,
        # but negative prompting (explored in the paper) is also possible
        CFGLogits(1.5, neg_prompt.to(device), model),
        TemperatureLogitsWarper(0.8),
        TopPLogitsWarper(0.95),
    ]),
    do_sample=True,
)

print(tokenizer.decode(outputs[0]))
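
For reference, PR #24654 discusses exposing CFG directly through `generate`. The exact interface may still change before the PR is merged, so the argument names below (`guidance_scale`, `negative_prompt_ids`) are an assumption based on that discussion, not a final API:

# Hypothetical interface from the PR discussion; names may change before merge.
outputs = model.generate(
    input_ids=prompt['input_ids'].to(device),
    attention_mask=prompt['attention_mask'].to(device),
    max_new_tokens=125,
    guidance_scale=1.5,                         # CFG is enabled when > 1
    negative_prompt_ids=neg_prompt.to(device),  # optional negative prompt
    do_sample=True,
    temperature=0.8,
    top_p=0.95,
)
print(tokenizer.decode(outputs[0]))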

===============================

Feature request

Hello!
I wish to contribute CFG sampling. I'm working with EleutherAI and @StellaAthena and will have a paper about it by Friday. CFG brings non-trivial improvements on many standard benchmarks. It contrasts the logits for the next token, $P(w_t|w_{..t}, prompt)$, with those of the input deprived of the prompt, $P(w_t|w_{..t})$, by defining

$$ \log P_{\text{cfg}}(w|w_{..t}, prompt) = \log P(w|w_{..t}) + \text{cfg} \cdot (\log P(w|w_{..t}, prompt) - \log P(w|w_{..t})) $$

We can then blend $\log P_{\text{cfg}}$ with $\log P(w|w_{..t}, prompt)$ to smooth the resulting distribution a bit, but that step is optional.
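
To make the combination concrete, here is a minimal sketch on toy logits (the numbers, vocabulary size, and blend weight are made up for illustration):

import torch
import torch.nn.functional as F

# toy next-token logits over a 5-token vocabulary (made-up numbers)
cond_logits = torch.tensor([2.0, 1.0, 0.5, 0.0, -1.0])    # log P(w | w_{..t}, prompt)
uncond_logits = torch.tensor([1.5, 1.2, 0.4, 0.3, -0.5])  # log P(w | w_{..t})

cfg = 1.5
cond = F.log_softmax(cond_logits, dim=-1)
uncond = F.log_softmax(uncond_logits, dim=-1)

# log P_cfg = log P_uncond + cfg * (log P_cond - log P_uncond)
cfg_logits = uncond + cfg * (cond - uncond)

# optional blend back towards the conditional distribution (CFG Rescale)
rescale = 0.7
blended = rescale * F.log_softmax(cfg_logits, dim=-1) + (1 - rescale) * cond
print(blended)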

Motivation

My current implementation is:

import torch.nn.functional as F
from transformers import LogitsWarper

class CFGLogits(LogitsWarper):

    def __init__(self, cfg, inputs, model, verbose=True):
        self.cfg = cfg
        self.inputs = inputs
        self.model = model
        self.out = None
        self.verbose = verbose

    def __call__(self, input_ids, scores):
        if self.cfg == 1:
            return F.log_softmax(scores, dim=-1)
        scores = F.log_softmax(scores, dim=-1)
        if self.out is None:
            # `device` is assumed to be defined in the enclosing scope
            self.out = self.model(self.inputs.to(device), use_cache=True)
        else:
            self.out = self.model(input_ids[:, -1:],
                                  use_cache=True,
                                  past_key_values=self.out.past_key_values)
        unconditional_logits = F.log_softmax(self.out.logits[0][-1:], dim=-1)
        out = self.cfg * (scores - unconditional_logits) + unconditional_logits
        out = F.log_softmax(out, dim=-1)
        # blend back towards the conditional distribution (CFG Rescale)
        return 0.7 * out + 0.3 * scores

# usage:

outputs = model.generate(
    input_ids=inputs['input_ids'].to(device),
    attention_mask=inputs['attention_mask'].to(device),
    max_new_tokens=l,
    logits_processor=LogitsProcessorList([
        # inputs_cfg usually is the last token of the prompt but there are
        # possibilities of negative prompting that are explored in the paper
        CFGLogits(cfg, inputs_cfg, model),
        TemperatureLogitsWarper(0.8),
        TopPLogitsWarper(0.95),
    ]),
    do_sample=True,
)
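
The names `inputs`, `inputs_cfg`, `cfg`, and `l` above are placeholders; a minimal way to define them (the model and the values are only illustrative) is:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopPLogitsWarper

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m").to(device)

inputs = tokenizer("Today a dragon flew over Paris, France,", return_tensors='pt')
# unconditional branch: just the last token of the prompt (no negative prompt)
inputs_cfg = inputs['input_ids'][:, -1:].to(device)
cfg = 1.5   # guidance scale
l = 125     # number of new tokens to generate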

I am not familiar enough with HF's design guidelines to know whether implementing this as a LogitsWarper is the right approach.

A few figures supporting the claims:

[figures omitted: FLOPs comparison and benchmark results]

Your contribution

I can contribute the code, but I will need some guidance since I don't know the exact design guidelines and overall architecture of HF.

Thank you for your time!
