Description
Feature request
Hello, I would like to propose a feature that allows you to set a list of tokens, or even token sequences, that can be excluded from repetition penalty calculations.
The reasoning is that, given a multi-turn prompt format such as:
user: abc
assistant: def
user: ghi
assistant: jkl
Or, even worse, a format like ChatML, where it is now standard to use <|im_end|> as a stop token that is included in every turn. Since these tokens appear in every turn, repetition penalty lowers their scores each time the model should emit them again. It seems only logical that, especially when turns are only a few tokens long, repetition penalty will destroy the validity of these prompt formats.
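To make the concern concrete, here is a minimal sketch of the effect, assuming a made-up 4-token vocabulary in which id 3 stands in for an end-of-turn token such as <|im_end|>. The ids, logits, and penalty value are all hypothetical; only the penalty formula follows the usual multiply-or-divide rule.

```python
import torch

# Hypothetical vocabulary of 4 tokens; id 3 plays the role of <|im_end|>.
END_ID = 3
penalty = 1.3

# Context so far: the end-of-turn token already appeared in an earlier turn.
input_ids = torch.tensor([[0, END_ID, 1]])
# Made-up next-token logits: the model would like to end the turn now.
logits = torch.tensor([[0.5, 1.0, 2.5, 3.0]])

# Repetition penalty: shrink positive scores, amplify negative ones,
# but only for tokens that already occur in the context.
seen = torch.gather(logits, 1, input_ids)
seen = torch.where(seen < 0, seen * penalty, seen / penalty)
penalized = logits.clone().scatter_(1, input_ids, seen)

before = logits.argmax(dim=-1).item()     # token chosen without the penalty
after = penalized.argmax(dim=-1).item()   # token chosen with the penalty
print(before, after)
```

With these numbers the end-of-turn token wins before the penalty is applied and loses afterwards, purely because it was already present in the context.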
Motivation
Your contribution
I will attempt a solution soon, though I expect to get it wrong. Nonetheless, here is my theory on how to do it in TGI.
This seems to be the repetition penalty code:
    class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
        r"""
        [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
        This version allows for a separate value for each sample and runs inplace when possible.
        It doesn't validate inputs.

        Args:
            repetition_penalty (`List[float]`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
        """

        def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
            self.penalty = penalty
            self.penalty_tensor = torch.tensor(
                penalty, dtype=dtype, device=device
            ).unsqueeze(1)

        def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
            score = torch.gather(scores, 1, input_ids)

            # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
            score = torch.where(
                score < 0, score * self.penalty_tensor, score / self.penalty_tensor
            )

            scores.scatter_(1, input_ids, score)
            return scores

        def filter(self, indices):
            self.penalty = [self.penalty[i] for i in indices]
            if any([x != 1.0 for x in self.penalty]):
                self.penalty_tensor = self.penalty_tensor[indices]
                return self
            return None
My idea is to take the input_ids before they are scored and replace them with a version from which the excluded tokens have been removed. The excluded tokens would come from a new setting that holds a list of token ids, or token strings that get mapped to ids.
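A rough sketch of that idea, not a tested TGI patch. The class name `ExemptRepetitionPenaltyLogitsProcessor` and the `exempt_token_ids` parameter are hypothetical and do not exist in TGI. Because `input_ids` is a rectangular batch tensor, actually removing ids would leave ragged rows, so this sketch instead keeps the original score wherever the id is exempt, which should have the same effect. It only handles single token ids, not token sequences, and leaves out `filter` for brevity.

```python
from typing import List

import torch


class ExemptRepetitionPenaltyLogitsProcessor:
    """Hypothetical variant of the processor above that skips exempt token ids."""

    def __init__(
        self,
        penalty: List[float],
        exempt_token_ids: List[int],  # assumed new setting, not an existing TGI option
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.penalty = penalty
        self.penalty_tensor = torch.tensor(
            penalty, dtype=dtype, device=device
        ).unsqueeze(1)
        self.exempt = torch.tensor(exempt_token_ids, dtype=torch.long, device=device)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        score = torch.gather(scores, 1, input_ids)
        penalized = torch.where(
            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
        )
        # Positions holding an exempt id keep their unpenalized score.
        is_exempt = torch.isin(input_ids, self.exempt)
        score = torch.where(is_exempt, score, penalized)
        scores.scatter_(1, input_ids, score)
        return scores


# Usage with made-up values: token id 2 is exempt (think <|im_end|>).
processor = ExemptRepetitionPenaltyLogitsProcessor(
    penalty=[2.0], exempt_token_ids=[2], dtype=torch.float32, device=torch.device("cpu")
)
input_ids = torch.tensor([[0, 1, 2]])
scores = torch.tensor([[2.0, -2.0, 4.0, 1.0]])
out = processor(input_ids, scores.clone())
print(out)  # tokens 0 and 1 are penalized; tokens 2 (exempt) and 3 (unseen) are unchanged
```

Supporting token strings would need a tokenizer lookup when the request is validated, and supporting multi-token sequences would need some matching over `input_ids`; neither is attempted here.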