diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py
index bd6b16d16d..666604e64e 100644
--- a/trl/trainer/ppo_config.py
+++ b/trl/trainer/ppo_config.py
@@ -41,121 +41,157 @@ class PPOConfig(TrainingArguments):
     command line.
 
     Parameters:
+        exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[: -len(".py")]`):
+            Name of this experiment.
+        log_with (`Optional[Literal["wandb", "tensorboard"]]`, *optional*, defaults to `None`):
+            Log with either `"wandb"` or `"tensorboard"`. Check
+            [tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
+        task_name (`Optional[str]`, *optional*, defaults to `None`):
+            Name of task to use - used only for tracking purposes.
+        model_name (`Optional[str]`, *optional*, defaults to `"gpt2"`):
+            Name of model to use - used only for tracking purposes.
+        query_dataset (`Optional[str]`, *optional*, defaults to `"imdb"`):
+            Name of dataset to query - used only for tracking purposes.
+        reward_model (`Optional[str]`, *optional*, defaults to `"sentiment-analysis:lvwerra/distilbert-imdb"`):
+            Reward model to use - used only for tracking purposes.
+        tracker_kwargs (`JSONDict`, *optional*, defaults to `{}`):
+            Keyword arguments for the tracker (e.g. `python ppo.py --tracker_kwargs='{"wandb": {"entity": "my_wandb_entity", "name": "my_exp_name"}}'`).
+        accelerator_kwargs (`JSONDict`, *optional*, defaults to `{}`):
+            Keyword arguments for the accelerator.
+        project_kwargs (`JSONDict`, *optional*, defaults to `{}`):
+            Keyword arguments for the accelerator project config (e.g. `logging_dir`).
+        tracker_project_name (`str`, *optional*, defaults to `"trl"`):
+            Name of project to use for tracking.
+        push_to_hub_if_best_kwargs (`JSONDict`, *optional*, defaults to `{}`):
+            Keyword arguments for pushing model to the hub during training (e.g. repo_id).
+        steps (`int`, *optional*, defaults to `20000`):
+            Number of training steps.
+        adap_kl_ctrl (`bool`, *optional*, defaults to `True`):
+            Use adaptive KL control, otherwise linear.
+        init_kl_coef (`Optional[float]`, *optional*, defaults to `0.2`):
+            Initial KL penalty coefficient (used for adaptive and linear control).
+        kl_penalty (`Literal["kl", "abs", "mse", "full"]`, *optional*, defaults to `"kl"`):
+            kl penalty options. Possible values are:
+
+            - `"kl"`: model_logp - ref_logp
+            - `"abs"`: abs(kl)
+            - `"mse"`: mean squared error mse(kl)
+            - `"full"`: the actual kl for all tokens in the distribution.
+
+        target (`float`, *optional*, defaults to `6.0`):
+            Target KL value for adaptive KL control.
+        horizon (`float`, *optional*, defaults to `10000.0`):
+            Horizon for adaptive KL control.
+        gamma (`float`, *optional*, defaults to `1.0`):
+            Gamma parameter for advantage calculation.
+        lam (`float`, *optional*, defaults to `0.95`):
+            Lambda parameter for advantage calculation.
+        cliprange (`float`, *optional*, defaults to `0.2`):
+            Range for clipping in PPO policy gradient loss.
+        cliprange_value (`float`, *optional*, defaults to `0.2`):
+            Range for clipping values in loss calculation.
+        vf_coef (`float`, *optional*, defaults to `0.1`):
+            Scaling factor for value loss.
+        batch_size (`int`, *optional*, defaults to `128`):
+            Number of samples per optimisation step.
+        forward_batch_size (`Optional[int]`, *optional*, defaults to `None`):
+            DEPRECATED: use `mini_batch_size` instead, which does the same thing.
+        mini_batch_size (`int`, *optional*, defaults to `128`):
+            Number of samples optimized in each mini batch.
+        ppo_epochs (`int`, *optional*, defaults to `4`):
+            Number of optimisation epochs per batch of samples.
+        optimize_device_cache (`bool`, *optional*, defaults to `False`):
+            Optimize device cache for slightly more memory-efficient training.
+        early_stopping (`bool`, *optional*, defaults to `False`):
+            Whether to stop the PPO optimization loop early if the KL is too high.
+        target_kl (`float`, *optional*, defaults to `1.0`):
+            Stop early if we exceed this value by over 50%.
+        compare_steps (`int`, *optional*, defaults to `1`):
+            Compare the current step with the previous `compare_steps` steps.
+        ratio_threshold (`float`, *optional*, defaults to `10.0`):
+            Skip mini-batches with high PPO ratios that can cause loss spikes.
+        use_score_scaling (`bool`, *optional*, defaults to `False`):
+            Use score scaling.
+        use_score_norm (`bool`, *optional*, defaults to `False`):
+            Use score normalization. Only applicable if `use_score_scaling` is True.
+        score_clip (`Optional[float]`, *optional*, defaults to `None`):
+            Score clipping.
+        whiten_rewards (`bool`, *optional*, defaults to `False`):
+            Whiten the rewards before computing advantages.
+        is_encoder_decoder (`Optional[bool]`, *optional*, defaults to `None`):
+            Whether the model is an encoder-decoder model.
+        is_peft_model (`Optional[bool]`, *optional*, defaults to `None`):
+            Whether the model is a PEFT model.
+        backward_batch_size (`Optional[int]`, *optional*, defaults to `None`):
+            Number of samples optimized in an `optimizer.step()` call.
+        global_backward_batch_size (`Optional[int]`, *optional*, defaults to `None`):
+            Effective `backward_batch_size` across all processes.
+        global_batch_size (`Optional[int]`, *optional*, defaults to `None`):
+            Effective `batch_size` across all processes.
+        dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
+            Number of processes to use for the dataset.
     """
-
-    # common parameters
     exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")]
-    """the name of this experiment (by default is the file name without the extension name)"""
     log_with: Optional[Literal["wandb", "tensorboard"]] = None
-    """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details"""
     task_name: Optional[str] = None
-    """Name of task to use - used only for tracking purposes"""
-    model_name: Optional[str] = "gpt2"
-    """Name of model to use - used only for tracking purposes"""
-    query_dataset: Optional[str] = "imdb"
-    """Name of dataset to query - used only for tracking purposes"""
-    reward_model: Optional[str] = "sentiment-analysis:lvwerra/distilbert-imdb"
-    """The reward model to use - used only for tracking purposes"""
-    remove_unused_columns: bool = True
-    """Remove unused columns from the dataset if `datasets.Dataset` is used"""
+    model_name: str = "gpt2"
+    query_dataset: str = "imdb"
+    reward_model: str = "sentiment-analysis:lvwerra/distilbert-imdb"
     tracker_kwargs: JSONDict = field(default_factory=dict)
-    """Keyword arguments for the tracker (e.g. python ppo.py --tracker_kwargs='{"wandb": {"entity": "my_wandb_entity", "name": "my_exp_name"}}'"""
     accelerator_kwargs: JSONDict = field(default_factory=dict)
-    """Keyword arguments for the accelerator"""
     project_kwargs: JSONDict = field(default_factory=dict)
-    """Keyword arguments for the accelerator project config (e.g. `logging_dir`)"""
     tracker_project_name: str = "trl"
-    """Name of project to use for tracking"""
     push_to_hub_if_best_kwargs: JSONDict = field(default_factory=dict)
-    """Keyword arguments for pushing model to the hub during training (e.g. repo_id)"""
-
-    # hyperparameters
     steps: int = 20000
-    """Number of training steps"""
-    learning_rate: float = 1.41e-5
-    """Adam learning rate"""
     adap_kl_ctrl: bool = True
-    """Use adaptive KL control, otherwise linear"""
-    init_kl_coef: Optional[float] = 0.2
+    init_kl_coef: float = 0.2
     kl_penalty: Literal["kl", "abs", "mse", "full"] = "kl"
-    """kl penalty options: 'kl': model_logp - ref_logp, 'abs': abs(kl), 'mse': mean squared error mse(kl) and 'full': the actual kl for all tokens in the distribution"""
-    target: Optional[float] = 6
-    """Target KL value for adaptive KL control"""
-    horizon: Optional[float] = 10000
-    """Horizon for adaptive KL control"""
-    gamma: float = 1
-    """Gamma parameter for advantage calculation"""
+    target: float = 6.0
+    horizon: float = 10000.0
+    gamma: float = 1.0
     lam: float = 0.95
-    """Lambda parameter for advantage calculation"""
     cliprange: float = 0.2
-    """Range for clipping in PPO policy gradient loss"""
     cliprange_value: float = 0.2
-    """Range for clipping values in loss calculation"""
     vf_coef: float = 0.1
-    """Scaling factor for value loss"""
     batch_size: int = 128
-    """Number of samples per optimisation step"""
     forward_batch_size: Optional[int] = None
-    """DEPRECATED: use `mini_batch_size` instead, which does the same thing."""
     mini_batch_size: int = 128
-    """Number of samples optimized in each mini batch"""
     gradient_accumulation_steps: int = 1
-    """The number of gradient accumulation steps"""
-    world_size: tyro.conf.Suppress[int] = None
-    """The world size for distributed training"""
     ppo_epochs: int = 4
-    """Number of optimisation epochs per batch of samples"""
     max_grad_norm: Optional[float] = None
-    """Maximum gradient norm for gradient clipping"""
     optimize_cuda_cache: Optional[bool] = None
-    """DEPRECATED: use `optimize_device_cache` instead, which does the same thing."""
-    optimize_device_cache: Optional[bool] = False
-    """Optimize device cache for slightly more memory-efficient training"""
+    optimize_device_cache: bool = False
     early_stopping: bool = False
-    """Whether to stop the PPO optimization loop early is the KL too high"""
-    target_kl: float = 1
-    """Stop early if we exceed this value by over 50%"""
+    target_kl: float = 1.0
     compare_steps: int = 1
-    """Number of steps between comparison of the current reward with the best seen so far"""
     ratio_threshold: float = 10.0
-    """Skip mini-batches with high PPO ratios that can cause loss spikes"""
     use_score_scaling: bool = False
-    """Use score scaling"""
     use_score_norm: bool = False
-    """Use score normalization. Only applicable if use_score_scaling is True"""
     score_clip: Optional[float] = None
-    """Score clipping"""
     whiten_rewards: bool = False
-    """Whiten the rewards before compute advantages"""
     gradient_checkpointing: bool = False
-    """Enable gradient checkpointing"""
-
-    # computed hyperparameters at runtime; we use `tyro.conf.Suppress` to hide them from the help text
     is_encoder_decoder: Optional[tyro.conf.Suppress[bool]] = None
-    """TO BE FILLED In RUNTIME: Whether the model is an encoder-decoder model"""
     is_peft_model: Optional[tyro.conf.Suppress[bool]] = None
-    """TO BE FILLED In RUNTIME: Whether the model is a PEFT model"""
     backward_batch_size: tyro.conf.Suppress[int] = None
-    """TO BE FILLED In RUNTIME: Number of samples optimized in an `optimizer.step()` call"""
-    global_backward_batch_size: tyro.conf.Suppress[int] = None
-    """TO BE FILLED In RUNTIME: the effective `backward_batch_size` across all processes"""
+    global_backward_batch_size: Optional[tyro.conf.Suppress[int]] = None
     global_batch_size: tyro.conf.Suppress[int] = None
-    """TO BE FILLED In RUNTIME: the effective `batch_size` across all processes"""
-
     dataset_num_proc: Optional[int] = None
-    if optimize_cuda_cache is not None:
-        warnings.warn(
-            "The `optimize_cuda_cache` argument will be deprecated soon, please use `optimize_device_cache` instead."
-        )
-
-        if optimize_device_cache is True:
-            raise ValueError("Both `optimize_device_cache` and `optimize_cuda_cache` were provided")
-        optimize_device_cache = optimize_cuda_cache
 
     def __post_init__(self):
+        super().__post_init__()
+
+        if self.optimize_cuda_cache is not None:
+            warnings.warn(
+                "The `optimize_cuda_cache` argument will be deprecated soon, please use `optimize_device_cache` instead."
+            )
+
+            if self.optimize_device_cache is True:
+                raise ValueError("Both `optimize_device_cache` and `optimize_cuda_cache` were provided")
+
+            self.optimize_device_cache = self.optimize_cuda_cache
+
         if self.forward_batch_size is not None:
             warnings.warn(
                 "Note that using `forward_batch_size` is deprecated, use `mini_batch_size` instead. By setting it you overwrite `mini_batch_size` which affects both the batch size during forward passes and also the mini batch size for PPO optimization."