Commit 0c41765

[DOCS] Example for LogitsProcessor class (#24848)
* make docs
* fixup
* resolved
* remove debugs
* Revert "fixup" (this reverts commit 5e0f636)
* prev (ignore)
* fixup broke some files
* remove files
* reverting modeling_reformer
* lang fix
1 parent 35c0459 commit 0c41765

File tree

1 file changed: +27, -1 lines


src/transformers/generation/logits_process.py

Lines changed: 27 additions & 1 deletion
@@ -193,12 +193,38 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
 
 class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
     r"""
-    [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
+    [`LogitsProcessor`] that prevents the repetition of previous tokens through an exponential penalty. This technique
+    shares some similarities with coverage mechanisms and others aimed at reducing repetition. During text
+    generation, the probability distribution for the next token is determined using a formula that incorporates
+    token scores based on their occurrence in the generated sequence. Tokens with higher scores are less likely to be
+    selected. The formula can be seen in the original [paper](https://arxiv.org/pdf/1909.05858.pdf). According to the
+    paper, a penalty of around 1.2 yields a good balance between truthful generation and lack of repetition.
 
     Args:
         repetition_penalty (`float`):
             The parameter for repetition penalty. 1.0 means no penalty. See [this
             paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+
+    Examples:
+
+    ```py
+    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+    >>> # Initializing the model and tokenizer for it
+    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    >>> inputs = tokenizer(["I'm not going to"], return_tensors="pt")
+
+    >>> # This shows a normal generate without any specific parameters
+    >>> summary_ids = model.generate(inputs["input_ids"], max_length=20)
+    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
+    I'm not going to lie, I'm not going to lie. I'm not going to lie
+
+    >>> # This applies a penalty to repeated tokens
+    >>> penalized_ids = model.generate(inputs["input_ids"], max_length=20, repetition_penalty=1.2)
+    >>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
+    I'm not going to lie, I was really excited about this. It's a great game
+    ```
     """
 
     def __init__(self, penalty: float):
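The exponential penalty the new docstring describes (Keskar et al., 2019) divides a positive logit by the penalty and multiplies a negative one by it, so a token that has already been generated always loses probability mass when the penalty is above 1.0. A minimal plain-Python sketch of that rule; the helper name `apply_repetition_penalty` is illustrative, not the library's actual implementation (which operates on `torch` tensors):

```python
def apply_repetition_penalty(scores, generated_token_ids, penalty=1.2):
    """Exponentially penalize tokens that already occur in the generated sequence.

    scores: raw logits, indexed by token id.
    generated_token_ids: token ids produced so far.
    A positive score is divided by the penalty, a negative score is multiplied
    by it, so for penalty > 1.0 a repeated token always becomes less likely.
    """
    penalized = list(scores)
    for token_id in set(generated_token_ids):
        if penalized[token_id] > 0:
            penalized[token_id] /= penalty
        else:
            penalized[token_id] *= penalty
    return penalized


# Tokens 0 and 2 were already generated: the positive logit at index 0 shrinks,
# the negative logit at index 2 grows more negative, index 1 is untouched.
print(apply_repetition_penalty([2.0, 1.0, -1.0], [0, 2], penalty=2.0))
```

Note that a penalty below 1.0 has the opposite effect and rewards repetition, which is why the docstring recommends a value around 1.2 rather than values near or under 1.0.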
