diff --git a/trl/trainer/judges.py b/trl/trainer/judges.py index e0638565c9..af56ec3d9b 100644 --- a/trl/trainer/judges.py +++ b/trl/trainer/judges.py @@ -76,7 +76,7 @@ class BaseRankJudge(ABC): """ Base class for LLM ranking judges. - Example: + **Example**: ```python class MyRankJudge(BaseRankJudge): def judge(self, prompts, completions, shuffle_order=True): @@ -96,13 +96,18 @@ def judge(self, prompts: List[str], completions: List[List[str]], shuffle_order: Judge the completion for the given prompts and return the ranks of each completion. Args: - prompts (`List[str]`): List of prompts. - completions (`List[List[str]]`): List of completions list, where each element is a list of completions for the corresponding prompt. - shuffle_order (`bool`): Whether to shuffle the order of the completions to avoid positional bias. + prompts (`List[str]`): + List of prompts. + completions (`List[List[str]]`): + List of completions list, where each element is a list of completions for the corresponding prompt. + shuffle_order (`bool`, *optional*, defaults to `True`): + Whether to shuffle the order of the completions to avoid positional bias. Returns: - List of lists of idxs, where each list contains the ranks of the completions for the corresponding prompt. - E.g., [1, 2, 0] means that the second completion (idx=1) is the best, followed by the third, and then the first. + `List[List[int]]`: + List of lists of idxs, where each list contains the ranks of the completions for the corresponding + prompt. E.g., `[1, 2, 0]` means that the second completion (`idx=1`) is the best, followed by the + third, and then the first. """ raise NotImplementedError("Judge subclasses must implement the `judge` method.") @@ -118,18 +123,23 @@ def judge(self, prompts: List[str], completions: List[List[str]], shuffle_order: Judge the completion pairs for the given prompts. Args: - prompts (`List[str]`): List of prompts. 
- completions (`List[List[str]]`): List of completions pairs, where each element is a pair of completions for the corresponding prompt. - shuffle_order (`bool`): Whether to shuffle the order of the completions to avoid positional bias. + prompts (`List[str]`): + List of prompts. + completions (`List[List[str]]`): + List of completions pairs, where each element is a pair of completions for the corresponding prompt. + shuffle_order (`bool`, *optional*, defaults to `True`): + Whether to shuffle the order of the completions to avoid positional bias. Returns: - List of idxs, where each idx is the rank of the best completion for the corresponding prompt. - E.g., 1 means that the second completion (idx=1) is the best. + `List[int]`: + List of idxs, where each idx is the rank of the best completion for the corresponding prompt. + E.g., `1` means that the second completion (`idx=1`) is the best. Note: - If the judge returns -1 for any prompt, it indicates that the inner process used to compute the preference has failed. - For instance, this could occur if the underlying language model returned an invalid answer. - In such cases, the caller should handle these invalid indices appropriately, possibly by implementing fallback logic or error handling. + If the judge returns `-1` for any prompt, it indicates that the inner process used to compute the + preference has failed. For instance, this could occur if the underlying language model returned an invalid + answer. In such cases, the caller should handle these invalid indices appropriately, possibly by + implementing fallback logic or error handling. """ raise NotImplementedError("Judge subclasses must implement the `judge` method.") @@ -157,30 +167,34 @@ class PairRMJudge(BasePairwiseJudge): """ LLM judge based on the PairRM model from AllenAI. - This judge uses the PairRM model to rank pairs of completions for given prompts. - It's designed for pairwise comparison of language model outputs. 
- - The PairRM model is loaded using the llm-blender library and runs on the + This judge uses the PairRM model to rank pairs of completions for given prompts. It's designed for pairwise + comparison of language model outputs. The PairRM model is loaded using the llm-blender library and runs on the default Accelerator device. - Attributes: - blender (llm_blender.Blender): An instance of the Blender class from llm-blender. + **Attributes**: + + blender (`llm_blender.Blender`): + An instance of the Blender class from llm-blender. + + **Example**: + ```python + >>> pairrm_judge = PairRMJudge() + >>> prompts = ["Translate 'hello' to French", "What's the capital of Japan?"] + >>> completions = [["Bonjour", "Salut"], ["Kyoto", "Tokyo"]] + >>> results = pairrm_judge.judge(prompts, completions) + >>> print(results) # [0, 1] (indicating the first completion is preferred for the first prompt and the second) + ``` + + - Example: - >>> pairrm_judge = PairRMJudge() - >>> prompts = ["Translate 'hello' to French", "What's the capital of Japan?"] - >>> completions = [["Bonjour", "Salut"], ["Kyoto", "Tokyo"]] - >>> results = pairrm_judge.judge(prompts, completions) - >>> print(results) # [0, 1] (indicating the first completion is preferred for the first prompt and the second) + This class requires the llm-blender library to be installed. Install it with: `pip install llm-blender`. - Note: - This class requires the llm-blender library to be installed. - Install it with: pip install llm-blender + """ def __init__(self): if not is_llm_blender_available(): - raise ValueError("llm-blender is not installed. Please install it with 'pip install llm-blender'.") + raise ValueError("llm-blender is not installed. Please install it with `pip install llm-blender`.") self.blender = llm_blender.Blender() self.blender.loadranker("llm-blender/PairRM", device=Accelerator().device) @@ -196,25 +210,29 @@ def judge( Judge the completion pairs for the given prompts using the PairRM model. 
Args: - prompts (List[str]): List of prompts to judge. - completions (List[List[str]]): List of completion pairs for each prompt. - shuffle_order (bool, optional): Whether to shuffle the order of completions - to avoid positional bias. Defaults to True. - return_scores (bool, optional): If True, return probability scores instead of ranks (i.e. a soft-judge). - Defaults to False. - temperature (float, optional): Temperature for scaling logits if return_scores - is True. Defaults to 1.0. + prompts (`List[str]`): + List of prompts to judge. + completions (`List[List[str]]`): + List of completion pairs for each prompt. + shuffle_order (`bool`, *optional*, defaults to `True`): + Whether to shuffle the order of the completions to avoid positional bias. + return_scores (`bool`, *optional*, defaults to `False`): + If `True`, return probability scores of the first completion instead of ranks (i.e. a *soft-judge*). + temperature (`float`, *optional*, defaults to `1.0`): + Temperature for scaling logits if `return_scores` is `True`. Returns: - List[Union[int, float]]: List of ranks (0 or 1) or scores for each prompt, - indicating which completion is preferred or its score. + `List[Union[int, float]]`: + If `return_scores` is `False`, returns a list of ranks (`0` or `1`) for each prompt, indicating which + completion is preferred. + If `return_scores` is `True`, returns softmax probabilities for the first completion. Raises: - ValueError: If the number of completions per prompt is not exactly 2. + `ValueError`: + If the number of completions per prompt is not exactly 2. Note: - - Ranks are 0-indexed (0 means the first completion is preferred). - - If return_scores is True, returns softmax probabilities for the first completion. + Unlike llm-blender, ranks are 0-indexed (`0` means the first completion is preferred). 
""" if len(completions[0]) != 2: @@ -254,11 +272,15 @@ class HfPairwiseJudge(BasePairwiseJudge): This judge is relevant for assessing the quality chat models, where the completion is a response to a given prompt. Args: - model (`str`, *optional*): The model to use for the judge. Defaults to "meta-llama/Meta-Llama-3-70B-Instruct". - token (`str`, *optional*): The Hugging Face API token to use for the InferenceClient. - system_prompt (`str`, *optional*): The system prompt to be used for the judge. If not provided, a default prompt is used. - Note that the system prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`. - Also, the inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token response. + model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3-70B-Instruct"`): + Model to use for the judge. + token (`str`, *optional*): + Hugging Face API token to use for the [`huggingface_hub.InferenceClient`]. + system_prompt (`str` or `None`, *optional*, defaults to `None`): + The system prompt to be used for the judge. If not provided, a default prompt is used. Note that the system + prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`. Also, the + inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token + response. """ def __init__( @@ -306,11 +328,15 @@ class OpenAIPairwiseJudge(BasePairwiseJudge): This judge is relevant for assessing the quality chat models, where the completion is a response to a given prompt. Args: - model (`str`, *optional*): The model to use for the judge. Defaults to `"gpt-4-turbo-preview"`. - system_prompt (`str`, *optional*): The system prompt to be used for the judge. If not provided, a default prompt is used. - Note that the system prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`. 
- Also, the inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token response. - max_requests (`int`, *optional*): The maximum number of requests to make to the OpenAI API. Defaults to 1000. If set to `None`, there is no limit. + model (`str`, *optional*, defaults to `"gpt-4-turbo-preview"`): + Model to use for the judge. + system_prompt (`str` or `None`, *optional*, defaults to `None`): + System prompt to be used for the judge. If not provided, a default prompt is used. Note that the system + prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`. Also, the + inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token + response. + max_requests (`int` or `None`, *optional*, defaults to `1000`): + Maximum number of requests to make to the OpenAI API. If set to `None`, there is no limit. """ def __init__(