
Commit d094d8d

voidism and gante authored
Generate: Add new decoding strategy "DoLa" in .generate() (huggingface#29619)
Co-authored-by: Joao Gante <joao@huggingface.co>
1 parent 99c0e55 commit d094d8d

File tree

7 files changed: +530 −5 lines


docs/source/en/generation_strategies.md

+60 −4
@@ -178,7 +178,7 @@ An increasing sequence: one, two, three, four, five, six, seven, eight, nine, te

The `generate()` method supports caching keys and values to enhance efficiency and avoid re-computations. However the key and value
cache can occupy a large portion of memory, becoming a bottleneck for long-context generation, especially for Large Language Models.
Quantizing the cache when using `generate()` can significantly reduce memory requirements at the cost of speed.

KV Cache quantization in `transformers` is largely inspired by the paper [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache](https://arxiv.org/abs/2402.02750) and currently supports `quanto` and `HQQ` as backends. For more information on the inner workings see the paper.
@@ -213,11 +213,11 @@ I like rock music because it's loud and energetic. I like to listen to it when I

## Watermarking

The `generate()` method supports watermarking the generated text by randomly marking a portion of tokens as "green".
When generating, the "green" tokens will have a small 'bias' value added to their logits, thus having a higher chance to be generated.
The watermarked text can be detected by calculating the proportion of "green" tokens in the text and estimating how likely it is
statistically to obtain that amount of "green" tokens for human-generated text. This watermarking strategy was proposed in the paper
["On the Reliability of Watermarks for Large Language Models"](https://arxiv.org/abs/2306.04634). For more information on
the inner functioning of watermarking, it is recommended to refer to the paper.

The watermarking can be used with any generative model in `transformers` and does not require an extra classification model
@@ -484,3 +484,59 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t

Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).

### DoLa Decoding

**D**ecoding by C**o**ntrasting **La**yers (DoLa) is a contrastive decoding strategy to improve the factuality and reduce the
hallucinations of LLMs, as described in the ICLR 2024 paper [DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language Models](https://arxiv.org/abs/2309.03883).

DoLa is achieved by contrasting the differences in logits obtained from the final
layer versus earlier layers, thus amplifying the factual knowledge localized to particular parts of the transformer layers.

Do the following two steps to activate DoLa decoding when calling the `model.generate` function:
1. Set the `dola_layers` argument, which can be either a string or a list of integers.
    - If set to a string, it can be one of `low` or `high`.
    - If set to a list of integers, it should be a list of layer indices between 0 and the total number of layers in the model. The 0-th layer is the word embedding layer, the 1st layer is the first transformer layer, and so on.
2. Setting `repetition_penalty = 1.2` is suggested to reduce repetition in DoLa decoding.

See the following examples for DoLa decoding with the 32-layer LLaMA-7B model.

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
>>> import torch

>>> tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
>>> model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", torch_dtype=torch.float16)
>>> device = 'cuda' if torch.cuda.is_available() else 'cpu'
>>> model.to(device)
>>> set_seed(42)

>>> text = "On what date was the Declaration of Independence officially signed?"
>>> inputs = tokenizer(text, return_tensors="pt").to(device)

# Vanilla greedy decoding
>>> vanilla_output = model.generate(**inputs, do_sample=False, max_new_tokens=50)
>>> tokenizer.batch_decode(vanilla_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
['\nThe Declaration of Independence was signed on July 4, 1776.\nWhat was the date of the signing of the Declaration of Independence?\nThe Declaration of Independence was signed on July 4,']

# DoLa decoding with contrasting higher part of layers (layers 16,18,...,30)
>>> dola_high_output = model.generate(**inputs, do_sample=False, max_new_tokens=50, dola_layers='high')
>>> tokenizer.batch_decode(dola_high_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
['\nJuly 4, 1776, when the Continental Congress voted to separate from Great Britain. The 56 delegates to the Continental Congress signed the Declaration on August 2, 1776.']

# DoLa decoding with contrasting specific layers (layers 28 and 30)
>>> dola_custom_output = model.generate(**inputs, do_sample=False, max_new_tokens=50, dola_layers=[28,30], repetition_penalty=1.2)
>>> tokenizer.batch_decode(dola_custom_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
['\nIt was officially signed on 2 August 1776, when 56 members of the Second Continental Congress, representing the original 13 American colonies, voted unanimously for the resolution for independence. The 2']
```

#### Understanding the `dola_layers` argument

`dola_layers` stands for the candidate layers in premature layer selection, as described in the DoLa paper. The selected premature layer will be contrasted with the final layer.

Setting `dola_layers` to `'low'` or `'high'` will select the lower or higher part of the layers to contrast, respectively (a short sketch of this selection logic follows the list below).
- For `N`-layer models with `N <= 40` layers, the layers of `range(0, N // 2, 2)` and `range(N // 2, N, 2)` are used for `'low'` and `'high'` layers, respectively.
- For models with `N > 40` layers, the layers of `range(0, 20, 2)` and `range(N - 20, N, 2)` are used for `'low'` and `'high'` layers, respectively.
- If the model has tied word embeddings, we skip the word embedding (0-th) layer and start from the 2nd layer, as the early exit from the word embedding layer would become an identity function.
- Set `dola_layers` to a list of integer layer indices to contrast manually specified layers. For example, setting `dola_layers=[28,30]` will contrast the final layer (the 32nd layer) with the 28th and 30th layers.
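
For concreteness, the selection rules above can be mirrored with a small helper. This is only an illustrative sketch; the function `candidate_premature_layers` is hypothetical and not part of the `transformers` API:

```python
def candidate_premature_layers(num_layers: int, dola_layers, tie_word_embeddings: bool = False):
    """Illustrative re-statement of the `'low'`/`'high'` selection rules described above."""
    if isinstance(dola_layers, str):
        # Skip the word-embedding "layer 0" when input and output embeddings are tied.
        start = 2 if tie_word_embeddings else 0
        if num_layers <= 40:
            low, high = range(start, num_layers // 2, 2), range(num_layers // 2, num_layers, 2)
        else:
            low, high = range(start, 20, 2), range(num_layers - 20, num_layers, 2)
        return list(low) if dola_layers == "low" else list(high)
    # A list of integers is taken as the candidate layer indices directly.
    return list(dola_layers)

print(candidate_premature_layers(32, "high"))  # [16, 18, 20, 22, 24, 26, 28, 30]
print(candidate_premature_layers(32, "low"))   # [0, 2, 4, 6, 8, 10, 12, 14]
```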

The paper suggests contrasting `'high'` layers to improve short-answer tasks like TruthfulQA, and contrasting `'low'` layers to improve all other long-answer reasoning tasks, such as GSM8K, StrategyQA, FACTOR, and VicunaQA. Applying DoLa to smaller models like GPT-2 is not recommended, as shown by the results in Appendix N of the paper.
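
To give a rough idea of what the contrast computes, below is a simplified sketch of the core DoLa step for a single next-token distribution. It is only a conceptual sketch, not the actual `transformers` implementation, and it omits the paper's adaptive plausibility constraint; shapes and inputs are stated in the comments:

```python
import torch
import torch.nn.functional as F

def dola_contrast(final_logits: torch.Tensor, premature_logits_per_layer: dict) -> torch.Tensor:
    """Contrast the final-layer distribution with the most-divergent premature layer.

    final_logits: shape (vocab_size,), logits from the LM head applied to the final layer.
    premature_logits_per_layer: {layer_index: logits of shape (vocab_size,)} obtained by
    applying the same LM head to the hidden states of the candidate premature layers.
    """
    final_log_probs = F.log_softmax(final_logits, dim=-1)

    def js_divergence(p_log: torch.Tensor, q_log: torch.Tensor) -> torch.Tensor:
        # Jensen-Shannon divergence between two distributions given as log-probabilities.
        m_log = torch.log(0.5 * (p_log.exp() + q_log.exp()))
        kl_pm = F.kl_div(m_log, p_log, log_target=True, reduction="sum")
        kl_qm = F.kl_div(m_log, q_log, log_target=True, reduction="sum")
        return 0.5 * (kl_pm + kl_qm)

    # 1. Dynamically pick the premature layer whose distribution diverges most from the final layer.
    best_layer = max(
        premature_logits_per_layer,
        key=lambda i: js_divergence(final_log_probs, F.log_softmax(premature_logits_per_layer[i], dim=-1)),
    )
    premature_log_probs = F.log_softmax(premature_logits_per_layer[best_layer], dim=-1)

    # 2. Contrast: boost tokens whose probability grows between the premature and the final layer.
    return final_log_probs - premature_log_probs
```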

src/transformers/generation/configuration_utils.py

+38
@@ -60,6 +60,7 @@ class GenerationMode(ExplicitEnum):
    GREEDY_SEARCH = "greedy_search"
    SAMPLE = "sample"
    ASSISTED_GENERATION = "assisted_generation"
    DOLA_GENERATION = "dola_generation"
    # Beam methods
    BEAM_SEARCH = "beam_search"
    BEAM_SAMPLE = "beam_sample"
@@ -81,6 +82,7 @@ class GenerationConfig(PushToHubMixin):
        - *diverse beam-search decoding* if `num_beams>1` and `num_beam_groups>1`
        - *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
        - *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
        - *dola decoding* if `dola_layers` is passed to `.generate()`

    To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
@@ -305,6 +307,18 @@ class GenerationConfig(PushToHubMixin):
        max_matching_ngram_size (`int`, *optional*, default to `None`):
            The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided.

        > Generation parameters exclusive to [DoLa decoding](https://arxiv.org/abs/2309.03883)

        dola_layers (`str` or `List[int]`, *optional*):
            The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
            be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively.
            "low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the
            layers up to the last 20 layers.
            If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa.
            The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
            `'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
            or [the paper](https://arxiv.org/abs/2309.03883) for more details.

        > Parameters specific to the caching mechanism:

        cache_implementation (`str`, *optional*, default to `None`):
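
As a minimal usage sketch, the `dola_layers` parameter documented in this hunk can also be bundled into a `GenerationConfig` object instead of being passed directly to `.generate()`. The checkpoint below simply reuses the one from the documentation example above; nothing here is specific to it:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")

# Bundle the DoLa-related settings in a reusable GenerationConfig.
generation_config = GenerationConfig(
    do_sample=False,
    max_new_tokens=50,
    dola_layers="low",       # or "high", or an explicit list such as [28, 30]
    repetition_penalty=1.2,  # recommended for DoLa to reduce repetition
)

inputs = tokenizer("On what date was the Declaration of Independence officially signed?", return_tensors="pt")
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True))
```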
@@ -397,6 +411,9 @@ def __init__(self, **kwargs):
        self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
        self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")

        # DoLa generation
        self.dola_layers = kwargs.pop("dola_layers", None)

        # Cache implementation
        self.cache_implementation = kwargs.pop("cache_implementation", None)
        self.cache_config = kwargs.pop("cache_config", None)
@@ -495,6 +512,16 @@ def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = Non
                    "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
                    "is only supported with Greedy Search and Sample."
                )

        # DoLa generation may extend some generation modes
        if self.dola_layers is not None:
            if generation_mode in ("greedy_search", "sample"):
                generation_mode = GenerationMode.DOLA_GENERATION
            else:
                raise ValueError(
                    "You've set `dola_layers`, which triggers DoLa generate. Currently, DoLa generate "
                    "is only supported with Greedy Search and Sample."
                )
        return generation_mode

    def validate(self, is_init=False):
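
The following is a small sketch of how the mode resolution in this hunk behaves; it only assumes that `GenerationMode` is importable from this module, as defined above:

```python
from transformers import GenerationConfig
from transformers.generation.configuration_utils import GenerationMode

# Greedy search combined with `dola_layers` resolves to the new DoLa generation mode.
config = GenerationConfig(dola_layers="high", repetition_penalty=1.2)
assert config.get_generation_mode() == GenerationMode.DOLA_GENERATION

# Combining `dola_layers` with beam search is rejected by the check above.
config = GenerationConfig(dola_layers="high", repetition_penalty=1.2, num_beams=4)
try:
    config.get_generation_mode()
except ValueError as err:
    print(err)  # DoLa generate is only supported with Greedy Search and Sample
```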
@@ -700,6 +727,17 @@ def validate(self, is_init=False):
                    "`generate()` (or a pipeline) directly."
                )

        # 6. if dola_layers is set, check if repetition_penalty is set to >= 1.2
        if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
            dola_decoding_wrong_parameter_msg = (
                "`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of {repetition_penalty}, "
                "which could induce unwanted repetition. The recommended value for DoLa decoding is `repetition_penalty>=1.2`."
            )
            warnings.warn(
                dola_decoding_wrong_parameter_msg.format(repetition_penalty=self.repetition_penalty),
                UserWarning,
            )

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
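
To illustrate the new check, here is a minimal sketch, assuming (as in the library) that constructing a `GenerationConfig` runs `validate()`: setting `dola_layers` while leaving `repetition_penalty` at its default of 1.0 emits the warning added above.

```python
import warnings

from transformers import GenerationConfig

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    GenerationConfig(dola_layers="low")  # repetition_penalty defaults to 1.0 (< 1.2)

for w in caught:
    print(w.category.__name__, w.message)
# UserWarning: `dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of 1.0, ...
```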
