Description
EDIT: ===========================
As I see many people copy-pasting this initial code, which was only meant to be a basis for discussion, here is a cleaner version (yet not perfect! We're still doing improvement rounds with the Hugging Face team to improve it! Check the state of PR #24654 until it is merged).
import torch
import torch.nn.functional as F
from transformers import LogitsProcessor
class CFGLogits(LogitsProcessor):
r"""Logits processor for Classifier-Free Guidance (CFG). The processors
computes a weighted average across scores from prompt conditional and prompt unconditional (or negative) logits,
parameterized by the `guidance_scale`. The unconditional scores are computed internally by prompting `model` with
the `uncond` branch. Finally, according to CFG Rescale, the reweighted logits are interpolated back with weight
`rescale_factor` the conditional ones to smooth the effect and increase output quality.
See [the paper](https://arxiv.org/abs/2306.17806) for more information.
Args:
guidance_scale (float):
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
prompt, usually at the expense of poorer quality.
uncond (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary for the unconditional branch.
model:
The LM computing the unconditional scores. Supposedly the same as the one computing the conditional scores.
Both models must use the same tokenizer.
"""
    def __init__(self, guidance_scale, uncond, model, rescale_factor=1.0):
        self.guidance_scale = guidance_scale
        self.uncond = uncond
        self.model = model
        self.out = None
        self.rescale_factor = rescale_factor
    def __call__(self, input_ids, scores):
        scores = F.log_softmax(scores, dim=-1)
        if self.guidance_scale == 1:
            return scores
        # run the unconditional branch, reusing its KV cache on subsequent steps
        if self.out is None:
            self.out = self.model(self.uncond, use_cache=True)
        else:
            self.out = self.model(
                input_ids[:, -1:],
                use_cache=True,
                past_key_values=self.out.past_key_values,
            )
        unconditional_logits = F.log_softmax(self.out.logits[0][-1:], dim=-1)
        # CFG: contrast the conditional and unconditional log-probs
        out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
        if self.rescale_factor == 1:
            return out
        # CFG Rescale: interpolate back with the conditional scores
        out = F.log_softmax(out, dim=-1)
        return self.rescale_factor * out + (1 - self.rescale_factor) * scores
# paper usage (copying and editing @grantCelley's answer):
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopPLogitsWarper
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")
prompt = tokenizer("Today a dragon flew over Paris, France,", return_tensors='pt')
# either provide a negative prompt:
neg_prompt = tokenizer("A sad event happened,", return_tensors='pt')['input_ids']
# or don't:
# neg_prompt = prompt['input_ids'][:, -1:]
device='cuda:0'
model.to(device)
outputs = model.generate(
    input_ids=prompt['input_ids'].to(device),
    attention_mask=prompt['attention_mask'].to(device),
    max_new_tokens=125,
    logits_processor=LogitsProcessorList([
        # the unconditional branch is usually just the last token of the prompt, but there are
        # possibilities of negative prompting that are explored in the paper
        CFGLogits(1.5, neg_prompt.to(device), model),
        TemperatureLogitsWarper(0.8),
        TopPLogitsWarper(0.95),
    ]),
    do_sample=True,
)
print(tokenizer.decode(outputs[0]))
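If you also want to try the CFG Rescale interpolation described in the docstring, the processor above accepts an optional `rescale_factor` (a minimal sketch; 0.7 is only an illustrative value, and `rescale_factor=1.0` leaves the behavior unchanged):

# illustrative: interpolate the guided log-probs back with the conditional ones
cfg_with_rescale = CFGLogits(1.5, neg_prompt.to(device), model, rescale_factor=0.7)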
===============================
Feature request
Hello!
I wish to contribute CFG sampling. I'm working with EleutherAI and @StellaAthena, and we will have a paper about it by Friday. CFG brings non-trivial improvements on many standard benchmarks. It contrasts the logits for the next token conditioned on the full prompt with the logits from an unconditional (or negative-prompt) branch of the same model, and then blends the two distributions according to a guidance scale, as sketched below.
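In log-probability space, the blend implemented by the processors below looks roughly like this (a minimal sketch; the variable names are illustrative, not part of any API):

# l_cond: next-token log-probs conditioned on the full prompt
# l_uncond: log-probs from the unconditional (or negative-prompt) branch
# gamma: guidance scale; gamma > 1 strengthens the effect of the prompt
l_cfg = gamma * (l_cond - l_uncond) + l_uncond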
Motivation
My current implementation is:
class CFGLogits(LogitsWarper):

    def __init__(self, cfg, inputs, model, verbose=True):
        self.cfg = cfg
        self.inputs = inputs
        self.model = model
        self.out = None
        self.verbose = verbose

    def __call__(self, input_ids, scores):
        if self.cfg == 1:
            return F.log_softmax(scores, dim=-1)
        scores = F.log_softmax(scores, dim=-1)
        if self.out is None:
            self.out = self.model(self.inputs.to(device), use_cache=True)
        else:
            self.out = self.model(input_ids[:, -1:],
                                  use_cache=True,
                                  past_key_values=self.out.past_key_values)
        unconditional_logits = F.log_softmax(self.out.logits[0][-1:], dim=-1)
        out = self.cfg * (scores - unconditional_logits) + unconditional_logits
        out = F.log_softmax(out, dim=-1)
        # hard-coded rescale blend (0.7 guided / 0.3 conditional)
        return 0.7 * out + 0.3 * scores
# usage:
outputs = model.generate(
    input_ids=inputs['input_ids'].to(device),
    attention_mask=inputs['attention_mask'].to(device),
    max_new_tokens=l,
    logits_processor=LogitsProcessorList([
        # inputs_cfg usually is the last token of the prompt but there are
        # possibilities of negative prompting that are explored in the paper
        CFGLogits(cfg, inputs_cfg, model),
        TemperatureLogitsWarper(0.8),
        TopPLogitsWarper(0.95),
    ]),
    do_sample=True,
)
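For reference, a minimal sketch of how `inputs_cfg` can be built when no explicit negative prompt is used, mirroring the commented-out option in the cleaner version above (the prompt text is just an example):

inputs = tokenizer("Today a dragon flew over Paris, France,", return_tensors='pt')
# use only the last token of the prompt as the "unconditional" branch
inputs_cfg = inputs['input_ids'][:, -1:]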
I am not familiar enough with the design guidelines of HF to know if this implementation as a LogitsWarper is satisfactory.
A few figures supporting these claims are attached.
Your contribution
I can contribute the code but I need to be guided as I don't know the exact design guidelines and overall architecture of HF.
Thank you for your time!