
Commit 95f4121

larekrow authored and EduardoPach committed
Update logits_process.py docstrings to clarify penalty and reward cases (attempt #2) (huggingface#26784)
* Update logits_process.py docstrings + match arg fields to __init__'s

* Ran `make style`
1 parent 1f23591 commit 95f4121

File tree

1 file changed (+19 −7 lines changed)


src/transformers/generation/logits_process.py

Lines changed: 19 additions & 7 deletions
@@ -276,9 +276,14 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
     selected. The formula can be seen in the original [paper](https://arxiv.org/pdf/1909.05858.pdf). According to the
     paper a penalty of around 1.2 yields a good balance between truthful generation and lack of repetition.
 
+    This technique can also be used to reward and thus encourage repetition in a similar manner. To penalize and reduce
+    repetition, use `penalty` values above 1.0, where a higher value penalizes more strongly. To reward and encourage
+    repetition, use `penalty` values between 0.0 and 1.0, where a lower value rewards more strongly.
+
     Args:
-        repetition_penalty (`float`):
-            The parameter for repetition penalty. 1.0 means no penalty. See [this
+        penalty (`float`):
+            The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
+            tokens. Between 0.0 and 1.0 rewards previously generated tokens. See [this
             paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
 
     Examples:
@@ -313,7 +318,7 @@ def __init__(self, penalty: float):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         score = torch.gather(scores, 1, input_ids)
 
-        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+        # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
         score = torch.where(score < 0, score * self.penalty, score / self.penalty)
 
         scores.scatter_(1, input_ids, score)
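
For reference, a minimal sketch of the behaviour the clarified docstring describes, not part of the commit itself: with `penalty` above 1.0 the logits of already-generated tokens are pushed down, and with `penalty` between 0.0 and 1.0 they are pushed up. The tensor values below are made up for illustration.

```python
# Minimal sketch (toy logits, not from this commit) of the behaviour the updated
# RepetitionPenaltyLogitsProcessor docstring describes.
import torch

from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

input_ids = torch.tensor([[2, 5]])                         # tokens generated so far
scores = torch.tensor([[0.0, 1.0, 2.0, -1.0, 0.5, -2.0]])  # raw logits over a toy 6-token vocabulary

penalize = RepetitionPenaltyLogitsProcessor(penalty=1.2)   # > 1.0: discourage repeating tokens 2 and 5
reward = RepetitionPenaltyLogitsProcessor(penalty=0.8)     # in (0.0, 1.0): encourage repeating them

# per the formula above: positive logits are divided by the penalty, negative ones multiplied by it
print(penalize(input_ids, scores.clone()))  # indices 2 and 5 become less likely (~1.67 and -2.4)
print(reward(input_ids, scores.clone()))    # indices 2 and 5 become more likely (2.5 and -1.6)
```

In everyday use the same effect comes from the `repetition_penalty` argument of `generate()`, which builds this processor internally.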
@@ -322,11 +327,18 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
 
 class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
     r"""
-    [`LogitsProcessor`] enforcing an exponential penalty on tokens that are not in the original input.
+    [`LogitsProcessor`] that avoids hallucination by boosting the probabilities of tokens found within the original
+    input.
+
+    This technique can also be used to reward and thus encourage hallucination (or creativity) in a similar manner. To
+    penalize and reduce hallucination, use `penalty` values above 1.0, where a higher value penalizes more strongly. To
+    reward and encourage hallucination, use `penalty` values between 0.0 and 1.0, where a lower value rewards more
+    strongly.
 
     Args:
-        hallucination_penalty (`float`):
-            The parameter for hallucination penalty. 1.0 means no penalty.
+        penalty (`float`):
+            The parameter for hallucination penalty. 1.0 means no penalty. Above 1.0 penalizes hallucination. Between
+            0.0 and 1.0 rewards hallucination.
         encoder_input_ids (`torch.LongTensor`):
             The encoder_input_ids that should be repeated within the decoder ids.
     """
@@ -342,7 +354,7 @@ def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         score = torch.gather(scores, 1, self.encoder_input_ids)
 
-        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+        # if score < 0 then hallucination penalty has to be multiplied to increase the token probabilities
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
 
         scores.scatter_(1, self.encoder_input_ids, score)
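
And a similar sketch for `EncoderRepetitionPenaltyLogitsProcessor`, again with made-up values and not part of the commit. The class applies the reciprocal of the user-facing `penalty` internally, which is why the comment above says the negative-score branch increases the token probabilities: `penalty` above 1.0 boosts tokens from the original input, `penalty` between 0.0 and 1.0 suppresses them.

```python
# Minimal sketch (toy logits, not from this commit) of the behaviour the updated
# EncoderRepetitionPenaltyLogitsProcessor docstring describes.
import torch

from transformers.generation.logits_process import EncoderRepetitionPenaltyLogitsProcessor

encoder_input_ids = torch.tensor([[1, 3]])           # tokens from the original (encoder) input
input_ids = torch.tensor([[4]])                      # decoder tokens generated so far
scores = torch.tensor([[0.0, -1.0, 2.0, 1.5, 0.3]])  # raw logits over a toy 5-token vocabulary

# penalty=2.0 (> 1.0) penalizes hallucination, i.e. boosts tokens present in encoder_input_ids
boost = EncoderRepetitionPenaltyLogitsProcessor(penalty=2.0, encoder_input_ids=encoder_input_ids)
print(boost(input_ids, scores.clone()))  # logits at indices 1 and 3 are made more likely
```

In `generate()` this processor is enabled through the `encoder_repetition_penalty` argument.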
