Skip to content

Commit

Permalink
🧽 Fix judge documentation (#2318)
Browse files Browse the repository at this point in the history
* Update judge examples and documentation

* without ':'

* Clean doc

* Fix typo in example code

* Add space after Attributes

* Update attribute name in judges.py

* Add installation instructions for llm-blender library

* Update PairRMJudge attributes documentation

* Fix return type in PairRMJudge
  • Loading branch information
qgallouedec authored Nov 4, 2024
1 parent 54b106d commit 337005d
Showing 1 changed file with 79 additions and 53 deletions.
132 changes: 79 additions & 53 deletions trl/trainer/judges.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class BaseRankJudge(ABC):
"""
Base class for LLM ranking judges.
Example:
**Example**:
```python
class MyRankJudge(BaseRankJudge):
def judge(self, prompts, completions, shuffle_order=True):
Expand All @@ -96,13 +96,18 @@ def judge(self, prompts: List[str], completions: List[List[str]], shuffle_order:
Judge the completion for the given prompts and return the ranks of each completion.
Args:
prompts (`List[str]`): List of prompts.
completions (`List[List[str]]`): List of completions list, where each element is a list of completions for the corresponding prompt.
shuffle_order (`bool`): Whether to shuffle the order of the completions to avoid positional bias.
prompts (`List[str]`):
List of prompts.
completions (`List[List[str]]`):
List of completions list, where each element is a list of completions for the corresponding prompt.
shuffle_order (`bool`, *optional*, defaults to `True`):
Whether to shuffle the order of the completions to avoid positional bias.
Returns:
List of lists of idxs, where each list contains the ranks of the completions for the corresponding prompt.
E.g., [1, 2, 0] means that the second completion (idx=1) is the best, followed by the third, and then the first.
`List[List[int]]`:
List of lists of idxs, where each list contains the ranks of the completions for the corresponding
prompt. E.g., `[1, 2, 0]` means that the second completion (`idx=1`) is the best, followed by the
third, and then the first.
"""
raise NotImplementedError("Judge subclasses must implement the `judge` method.")

Expand All @@ -118,18 +123,23 @@ def judge(self, prompts: List[str], completions: List[List[str]], shuffle_order:
Judge the completion pairs for the given prompts.
Args:
prompts (`List[str]`): List of prompts.
completions (`List[List[str]]`): List of completions pairs, where each element is a pair of completions for the corresponding prompt.
shuffle_order (`bool`): Whether to shuffle the order of the completions to avoid positional bias.
prompts (`List[str]`):
List of prompts.
completions (`List[List[str]]`):
List of completions pairs, where each element is a pair of completions for the corresponding prompt.
shuffle_order (`bool`, *optional*, defaults to `True`):
Whether to shuffle the order of the completions to avoid positional bias.
Returns:
List of idxs, where each idx is the rank of the best completion for the corresponding prompt.
E.g., 1 means that the second completion (idx=1) is the best.
`List[int]`:
List of idxs, where each idx is the rank of the best completion for the corresponding prompt.
E.g., `1` means that the second completion (`idx=1`) is the best.
Note:
If the judge returns -1 for any prompt, it indicates that the inner process used to compute the preference has failed.
For instance, this could occur if the underlying language model returned an invalid answer.
In such cases, the caller should handle these invalid indices appropriately, possibly by implementing fallback logic or error handling.
If the judge returns `-1` for any prompt, it indicates that the inner process used to compute the
preference has failed. For instance, this could occur if the underlying language model returned an invalid
answer. In such cases, the caller should handle these invalid indices appropriately, possibly by
implementing fallback logic or error handling.
"""
raise NotImplementedError("Judge subclasses must implement the `judge` method.")

Expand Down Expand Up @@ -157,30 +167,34 @@ class PairRMJudge(BasePairwiseJudge):
"""
LLM judge based on the PairRM model from AllenAI.
This judge uses the PairRM model to rank pairs of completions for given prompts.
It's designed for pairwise comparison of language model outputs.
The PairRM model is loaded using the llm-blender library and runs on the
This judge uses the PairRM model to rank pairs of completions for given prompts. It's designed for pairwise
comparison of language model outputs. The PairRM model is loaded using the llm-blender library and runs on the
default Accelerator device.
Attributes:
blender (llm_blender.Blender): An instance of the Blender class from llm-blender.
**Attributes**:
blender (`llm_blender.Blender`):
An instance of the Blender class from llm-blender.
**Example**:
```python
>>> pairrm_judge = PairRMJudge()
>>> prompts = ["Translate 'hello' to French", "What's the capital of Japan?"]
>>> completions = [["Bonjour", "Salut"], ["Kyoto", "Tokyo"]]
>>> results = pairrm_judge.judge(prompts, completions)
>>> print(results) # [0, 1] (indicating the first completion is preferred for the first prompt and the second)
```
<Tip>
Example:
>>> pairrm_judge = PairRMJudge()
>>> prompts = ["Translate 'hello' to French", "What's the capital of Japan?"]
>>> completions = [["Bonjour", "Salut"], ["Kyoto", "Tokyo"]]
>>> results = pairrm_judge.judge(prompts, completions)
>>> print(results) # [0, 1] (indicating the first completion is preferred for the first prompt and the second)
This class requires the llm-blender library to be installed. Install it with: `pip install llm-blender`.
Note:
This class requires the llm-blender library to be installed.
Install it with: pip install llm-blender
</Tip>
"""

def __init__(self):
if not is_llm_blender_available():
raise ValueError("llm-blender is not installed. Please install it with 'pip install llm-blender'.")
raise ValueError("llm-blender is not installed. Please install it with `pip install llm-blender`.")
self.blender = llm_blender.Blender()
self.blender.loadranker("llm-blender/PairRM", device=Accelerator().device)

Expand All @@ -196,25 +210,29 @@ def judge(
Judge the completion pairs for the given prompts using the PairRM model.
Args:
prompts (List[str]): List of prompts to judge.
completions (List[List[str]]): List of completion pairs for each prompt.
shuffle_order (bool, optional): Whether to shuffle the order of completions
to avoid positional bias. Defaults to True.
return_scores (bool, optional): If True, return probability scores instead of ranks (i.e. a soft-judge).
Defaults to False.
temperature (float, optional): Temperature for scaling logits if return_scores
is True. Defaults to 1.0.
prompts (`List[str]`):
List of prompts to judge.
completions (`List[List[str]]`):
List of completion pairs for each prompt.
shuffle_order (`bool`, *optional*, defaults to `True`):
Whether to shuffle the order of the completions to avoid positional bias.
return_scores (`bool`, *optional*, defaults to `False`):
If `True`, return probability scores of the first completion instead of ranks (i.e. a *soft-judge*).
temperature (`float`, *optional*, defaults to `1.0`):
Temperature for scaling logits if `return_scores` is True.
Returns:
List[Union[int, float]]: List of ranks (0 or 1) or scores for each prompt,
indicating which completion is preferred or its score.
`Union[List[int, float]]`:
If `return_scores` is `False`, returns a list of ranks (`0` or `1`) for each prompt, indicating which
completion is preferred.
If `return_scores` is `True`, returns softmax probabilities for the first completion.
Raises:
ValueError: If the number of completions per prompt is not exactly 2.
`ValueError`:
If the number of completions per prompt is not exactly 2.
Note:
- Ranks are 0-indexed (0 means the first completion is preferred).
- If return_scores is True, returns softmax probabilities for the first completion.
Unlike llm-blender, ranks are 0-indexed (`0` means the first completion is preferred).
"""

if len(completions[0]) != 2:
Expand Down Expand Up @@ -254,11 +272,15 @@ class HfPairwiseJudge(BasePairwiseJudge):
This judge is relevant for assessing the quality chat models, where the completion is a response to a given prompt.
Args:
model (`str`, *optional*): The model to use for the judge. Defaults to "meta-llama/Meta-Llama-3-70B-Instruct".
token (`str`, *optional*): The Hugging Face API token to use for the InferenceClient.
system_prompt (`str`, *optional*): The system prompt to be used for the judge. If not provided, a default prompt is used.
Note that the system prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`.
Also, the inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token response.
model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3-70B-Instruct"`):
Model to use for the judge.
token (`str`, *optional*):
Hugging Face API token to use for the [`huggingface_hub.InferenceClient`].
system_prompt (`str` or `None`, *optional*, defaults to `None`):
The system prompt to be used for the judge. If not provided, a default prompt is used. Note that the system
prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`. Also, the
inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token
response.
"""

def __init__(
Expand Down Expand Up @@ -306,11 +328,15 @@ class OpenAIPairwiseJudge(BasePairwiseJudge):
This judge is relevant for assessing the quality chat models, where the completion is a response to a given prompt.
Args:
model (`str`, *optional*): The model to use for the judge. Defaults to `"gpt-4-turbo-preview"`.
system_prompt (`str`, *optional*): The system prompt to be used for the judge. If not provided, a default prompt is used.
Note that the system prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`.
Also, the inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token response.
max_requests (`int`, *optional*): The maximum number of requests to make to the OpenAI API. Defaults to 1000. If set to `None`, there is no limit.
model (`str`, *optional*, defaults to `"gpt-4-turbo-preview"`):
Model to use for the judge.
system_prompt (`str` or `None`, *optional*, defaults to `None`):
System prompt to be used for the judge. If not provided, a default prompt is used. Note that the system
prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`. Also, the
inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token
response.
max_requests (`int` or `None`, *optional*, defaults to `1000`):
Maximum number of requests to make to the OpenAI API. If set to `None`, there is no limit.
"""

def __init__(
Expand Down

0 comments on commit 337005d

Please sign in to comment.