@@ -193,12 +193,38 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
 
 class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
     r"""
-    [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
+    [`LogitsProcessor`] that prevents the repetition of previous tokens through an exponential penalty. This technique
+    shares some similarities with coverage mechanisms and others aimed at reducing repetition. During the text
+    generation process, the probability distribution for the next token is determined using a formula that incorporates
+    token scores based on their occurrence in the generated sequence. Tokens with higher scores are less likely to be
+    selected. The formula can be seen in the original [paper](https://arxiv.org/pdf/1909.05858.pdf). According to the
+    paper, a penalty of around 1.2 yields a good balance between truthful generation and lack of repetition.
 
     Args:
         repetition_penalty (`float`):
             The parameter for repetition penalty. 1.0 means no penalty. See [this
             paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+
+    Examples:
+
+    ```py
+    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+    >>> # Initializing the model and tokenizer for it
+    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    >>> inputs = tokenizer(["I'm not going to"], return_tensors="pt")
+
+    >>> # This shows a normal generate without any specific parameters
+    >>> summary_ids = model.generate(inputs["input_ids"], max_length=20)
+    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
+    I'm not going to lie, I'm not going to lie. I'm not going to lie
+    >>> # This applies a repetition penalty to discourage repeated tokens
+    >>> penalized_ids = model.generate(inputs["input_ids"], max_length=20, repetition_penalty=1.2)
+    >>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
+    I'm not going to lie, I was really excited about this. It's a great game
+    ```
     """
 
     def __init__(self, penalty: float):
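
For reference, the penalty described in the new docstring can be applied to raw logits roughly as follows. This is a minimal sketch following the formulation in the CTRL paper (https://arxiv.org/pdf/1909.05858.pdf); the function name and exact tensor handling are illustrative assumptions and are not part of this diff.

```py
import torch


def apply_repetition_penalty(
    input_ids: torch.LongTensor, scores: torch.FloatTensor, penalty: float
) -> torch.FloatTensor:
    # Hypothetical helper, not the library's implementation.
    # Gather the logits of tokens that already appear in the generated sequence
    score = torch.gather(scores, 1, input_ids)
    # Positive logits are divided by the penalty, negative ones multiplied,
    # so previously seen tokens become less likely in both cases
    score = torch.where(score < 0, score * penalty, score / penalty)
    # Write the penalized logits back into the full score tensor
    scores.scatter_(1, input_ids, score)
    return scores
```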