Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add new sequence_bias param type for SequenceBiasLogitsProcessor #33362

8 changes: 4 additions & 4 deletions docs/source/en/chat_templating.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ Not all models require generation prompts. Some models, like LLaMA, don't have a
special tokens before bot responses. In these cases, the `add_generation_prompt` argument will have no effect. The exact
effect that `add_generation_prompt` has will depend on the template being used.

## What does "continue_last_message" do?
## What does "continue_final_message" do?
VladOS95-cyber marked this conversation as resolved.
Show resolved Hide resolved

When passing a list of messages to `apply_chat_template` or `TextGenerationPipeline`, you can choose
to format the chat so the model will continue the final message in the chat instead of starting a new one. This is done
Expand All @@ -211,15 +211,15 @@ chat = [
{"role": "assistant", "content": '{"name": "'},
]

formatted_chat = tokenizer.apply_chat_template(chat, tokenize=True, return_dict=True, continue_last_message=True)
formatted_chat = tokenizer.apply_chat_template(chat, tokenize=True, return_dict=True, continue_final_message=True)
model.generate(**formatted_chat)
```

The model will generate text that continues the JSON string, rather than starting a new message. This approach
can be very useful for improving the accuracy of the model's instruction-following when you know how you want
it to start its replies.

Because `add_generation_prompt` adds the tokens that start a new message, and `continue_last_message` removes any
Because `add_generation_prompt` adds the tokens that start a new message, and `continue_final_message` removes any
end-of-message tokens from the final message, it does not make sense to use them together. As a result, you'll
get an error if you try!

Expand All @@ -228,7 +228,7 @@ get an error if you try!
The default behaviour of `TextGenerationPipeline` is to set `add_generation_prompt=True` so that it starts a new
message. However, if the final message in the input chat has the "assistant" role, it will assume that this message is
a prefill and switch to `continue_final_message=True` instead, because most models do not support multiple
consecutive assistant messages. You can override this behaviour by explicitly passing the `continue_last_message`
consecutive assistant messages. You can override this behaviour by explicitly passing the `continue_final_message`
argument when calling the pipeline.

</Tip>
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/community.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,4 @@ This page regroups resources around 🤗 Transformers developed by the community
| [Detect objects in an image with DETR](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/DETR/DETR_minimal_example_(with_DetrFeatureExtractor).ipynb) | How to use a trained *DetrForObjectDetection* model to detect objects in an image and visualize attention | [Niels Rogge](https://github.com/NielsRogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/DETR/DETR_minimal_example_(with_DetrFeatureExtractor).ipynb) |
| [Fine-tune DETR on a custom object detection dataset](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/DETR/Fine_tuning_DetrForObjectDetection_on_custom_dataset_(balloon).ipynb) | How to fine-tune *DetrForObjectDetection* on a custom object detection dataset | [Niels Rogge](https://github.com/NielsRogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/DETR/Fine_tuning_DetrForObjectDetection_on_custom_dataset_(balloon).ipynb) |
| [Finetune T5 for Named Entity Recognition](https://github.com/ToluClassics/Notebooks/blob/main/T5_Ner_Finetuning.ipynb) | How to fine-tune *T5* on a Named Entity Recognition Task | [Ogundepo Odunayo](https://github.com/ToluClassics) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1obr78FY_cBmWY5ODViCmzdY6O1KB65Vc?usp=sharing) |
| [Fine-Tuning Open-Source LLM using QLoRA with MLflow and PEFT](https://github.com/mlflow/mlflow/blob/master/docs/source/llms/transformers/tutorials/fine-tuning/transformers-peft.ipynb) | How to use [QLoRA](https://github.com/artidoro/qlora) and [PEFT](https://huggingface.co/docs/peft/en/index) to fine-tune an LLM in a memory-efficient way, while using [MLflow](https://mlflow.org/docs/latest/llms/transformers/index.html) to manage experiment tracking | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mlflow/mlflow/blob/master/docs/source/llms/transformers/tutorials/fine-tuning/transformers-peft.ipynb) |
| [Fine-Tuning Open-Source LLM using QLoRA with MLflow and PEFT](https://github.com/mlflow/mlflow/blob/master/docs/source/llms/transformers/tutorials/fine-tuning/transformers-peft.ipynb) | How to use [QLoRA](https://github.com/artidoro/qlora) and [PEFT](https://huggingface.co/docs/peft/en/index) to fine-tune an LLM in a memory-efficient way, while using [MLflow](https://mlflow.org/docs/latest/llms/transformers/index.html) to manage experiment tracking | [Yuki Watanabe](https://github.com/B-Step62) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mlflow/mlflow/blob/master/docs/source/llms/transformers/tutorials/fine-tuning/transformers-peft.ipynb) |
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
* [mBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
Expand Down
49 changes: 37 additions & 12 deletions src/transformers/generation/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,8 +1064,8 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
</Tip>

Args:
sequence_bias (`Dict[Tuple[int], float]`):
Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
sequence_bias (`List[List[Union[List[int], float]]]`):
List of lists that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias
will always be applied. Otherwise, the bias will only be applied if the sequence in question is about to be
completed (in the token selection step after this processor is applied).
Expand All @@ -1087,12 +1087,12 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
>>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("openai-community/gpt2", add_prefix_space=True)


>>> def get_tokens_as_tuple(word):
... return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0])
>>> def get_tokens(word):
... return tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]


>>> # If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations
>>> sequence_bias = {get_tokens_as_tuple("Trump"): -10.0}
>>> sequence_bias = [get_tokens("Trump"), -10.0]
>>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias)
>>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
The full name of Donald is Donald J. Donald,
Expand All @@ -1102,16 +1102,17 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
The full name of Donald is Donald Rumsfeld,

>>> # We can also add a positive bias to nudge the model towards specific tokens or continuations
>>> sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0}
>>> sequence_bias = [get_tokens("Donald Duck"): 10.0]
>>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
>>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
The full name of Donald is Donald Duck.
```
"""

def __init__(self, sequence_bias: Dict[Tuple[int], float]):
def __init__(self, sequence_bias: List[List[Union[List[int], float]]]):
self.sequence_bias = sequence_bias
self._validate_arguments()
self._convert_list_arguments_into_dict()

# Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
# is infered in the first usage, which inhibits initializing here)
Expand Down Expand Up @@ -1178,11 +1179,15 @@ def _prepare_bias_variables(self, scores: torch.FloatTensor):

def _validate_arguments(self):
sequence_bias = self.sequence_bias
if not isinstance(sequence_bias, dict) or len(sequence_bias) == 0:
raise ValueError(f"`sequence_bias` has to be a non-empty dictionary, but is {sequence_bias}.")
if any(not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()):
if not isinstance(sequence_bias, dict) and not isinstance(sequence_bias, list) or len(sequence_bias) == 0:
raise ValueError(
f"`sequence_bias` has to be a non-empty dictionary, or non-empty list of lists but is {sequence_bias}."
)
if isinstance(sequence_bias, dict) and any(
not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()
):
raise ValueError(f"`sequence_bias` has to be a dict with tuples as keys, but is {sequence_bias}.")
if any(
if isinstance(sequence_bias, dict) and any(
any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in sequence_ids)
or len(sequence_ids) == 0
for sequence_ids in sequence_bias.keys()
Expand All @@ -1191,9 +1196,29 @@ def _validate_arguments(self):
f"Each key in `sequence_bias` has to be a non-empty tuple of positive integers, but is "
f"{sequence_bias}."
)
if any(not isinstance(bias, float) for bias in sequence_bias.values()):

all_token_bias_pairs_are_valid = lambda token_bias_pair: (
isinstance(token_bias_pair, list)
and all(isinstance(token_id, (int, np.integer)) and token_id > 0 for token_id in token_bias_pair)
or isinstance(token_bias_pair, float)
)
if isinstance(sequence_bias, list) and any(
(not all_token_bias_pairs_are_valid(token_bias_pair) for token_bias_pair in sequence) or len(sequence) == 0
for sequence in sequence_bias
):
raise ValueError(
f"Each element in `sequence_bias` has to be a non-empty list of lists of positive integers and float, but is "
f"{sequence_bias}."
)
if isinstance(sequence_bias, dict) and any(not isinstance(bias, float) for bias in sequence_bias.values()):
raise ValueError(f"`sequence_bias` has to be a dict with floats as values, but is {sequence_bias}.")

def _convert_list_arguments_into_dict(self):
VladOS95-cyber marked this conversation as resolved.
Show resolved Hide resolved
"""BC: we used to accept `dict{tuple of tokens: float}`"""
if isinstance(self.sequence_bias, list):
temp_sequence = self.sequence_bias
self.sequence_bias = {tuple(sublist[0]): sublist[1] for sublist in temp_sequence}


class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
"""
Expand Down
Loading