Clean configs documentation #1944

Merged
merged 81 commits into main from clean-config on Sep 4, 2024
Changes from 63 commits
Commits (81)
c2d9a62
Clean BCO
qgallouedec Aug 18, 2024
e3083f1
Optional[int]
qgallouedec Aug 18, 2024
c7b2fbc
fix sft config
qgallouedec Aug 19, 2024
e7a80bb
Merge branch 'main' into clean-config
qgallouedec Aug 19, 2024
50dbc86
alignprop config
qgallouedec Aug 20, 2024
b718fba
Merge branch 'main' into clean-config
qgallouedec Aug 20, 2024
4a8aba6
upadte tempfile to work with output_dir
qgallouedec Aug 20, 2024
6ae94e9
Merge branch 'clean-config' of https://github.com/huggingface/trl int…
qgallouedec Aug 20, 2024
3ed49fd
Merge branch 'main' into clean-config
qgallouedec Aug 21, 2024
f847f56
clean kto config
qgallouedec Aug 21, 2024
69525f9
intro docstring
qgallouedec Aug 21, 2024
c73f43a
style
qgallouedec Aug 21, 2024
11f6e7e
reward config
qgallouedec Aug 22, 2024
946e2e5
orpo config
qgallouedec Aug 22, 2024
21df122
Merge branch 'main' into clean-config
qgallouedec Aug 26, 2024
a1bff9c
warning in trainer, not in config
qgallouedec Aug 26, 2024
006a454
cpo config
qgallouedec Aug 26, 2024
c9264ee
Merge branch 'main' into clean-config
qgallouedec Aug 27, 2024
01d8814
ppo v2
qgallouedec Aug 27, 2024
5cd9eef
Merge branch 'clean-config' of https://github.com/huggingface/trl int…
qgallouedec Aug 27, 2024
9bef508
model config
qgallouedec Aug 27, 2024
0a49bca
ddpo and per_device_train_batch_size (instead of (train_batch_size)
qgallouedec Aug 27, 2024
1c9bba7
Merge branch 'main' into clean-config
qgallouedec Aug 27, 2024
216856a
rloo
qgallouedec Aug 27, 2024
7270936
Online config
qgallouedec Aug 27, 2024
05bacaf
tmp_dir in test_ddpo
qgallouedec Aug 27, 2024
451b4fc
style
qgallouedec Aug 27, 2024
9e6f0a0
remove to_dict and fix post-init
qgallouedec Aug 28, 2024
2aa4544
batch size in test ddpo
qgallouedec Aug 28, 2024
97738c8
Merge branch 'main' into clean-config
qgallouedec Aug 28, 2024
098ca6a
Merge branch 'main' into clean-config
qgallouedec Aug 28, 2024
02b78ec
dpo
qgallouedec Aug 28, 2024
92ff078
style
qgallouedec Aug 28, 2024
63679fe
Merge branch 'main' into clean-config
qgallouedec Aug 29, 2024
4957a8c
`Args` -> `Parameters`
qgallouedec Aug 29, 2024
bd3693b
parameters
qgallouedec Aug 29, 2024
10468e9
ppo config
qgallouedec Aug 29, 2024
d289982
dont overwrite world size
qgallouedec Aug 29, 2024
d94985a
style
qgallouedec Aug 29, 2024
1bc063a
Merge branch 'main' into clean-config
qgallouedec Aug 29, 2024
00d2faf
outputdir in test ppo
qgallouedec Aug 29, 2024
aa98e42
output dir in ppo config
qgallouedec Aug 29, 2024
66dc235
Merge branch 'clean-config' of https://github.com/huggingface/trl int…
qgallouedec Aug 29, 2024
79234d1
revert non-core change (1/n)
qgallouedec Sep 3, 2024
9b3b3a7
revert non-core changes (2/n)
qgallouedec Sep 3, 2024
6aeba64
revert non-core change (3/n)
qgallouedec Sep 3, 2024
fc4d223
Merge branch 'main' into clean-config
qgallouedec Sep 3, 2024
23fbfc6
uniform max_length
qgallouedec Sep 3, 2024
136cfdc
fix uniform max_length
qgallouedec Sep 3, 2024
640999c
beta uniform
qgallouedec Sep 3, 2024
3d5618c
Merge branch 'clean-config' of https://github.com/huggingface/trl int…
qgallouedec Sep 3, 2024
cfe9b22
style
qgallouedec Sep 3, 2024
358b026
link to `ConstantLengthDataset`
qgallouedec Sep 3, 2024
2190bf1
uniform `dataset_num_proc`
qgallouedec Sep 3, 2024
5434969
uniform `disable_dropout`
qgallouedec Sep 3, 2024
a7e537a
`eval_packing` doc
qgallouedec Sep 3, 2024
1a86078
try latex and α in doc
qgallouedec Sep 3, 2024
7065562
try title first
qgallouedec Sep 3, 2024
2d93d3d
doesn't work
qgallouedec Sep 3, 2024
42acd10
reorganize doc
qgallouedec Sep 3, 2024
92a2206
overview
qgallouedec Sep 3, 2024
81d5147
better latex
qgallouedec Sep 3, 2024
71c110a
is_encoder_decoder uniform
qgallouedec Sep 3, 2024
e60c3b0
proper ticks
qgallouedec Sep 3, 2024
a964090
fix latex
qgallouedec Sep 3, 2024
45d4f99
uniform generate_during_eval
qgallouedec Sep 3, 2024
3bc2d30
uniform truncation_mode
qgallouedec Sep 3, 2024
66a4861
ref_model_mixup_alpha
qgallouedec Sep 3, 2024
e2d8f7f
ref_model_mixup_alpha and ref_model_sync_steps
qgallouedec Sep 3, 2024
79347d9
Uniform `model_init_kwargs` and `ref_model_init_kwargs`
qgallouedec Sep 3, 2024
9ba37a9
rpo_alpha
qgallouedec Sep 3, 2024
52f69b1
Update maximum length argument names in config files
qgallouedec Sep 3, 2024
0fabc42
Update loss_type descriptions in config files
qgallouedec Sep 3, 2024
e1abc3a
Update max_target_length to max_completion_length in CPOConfig and CP…
qgallouedec Sep 3, 2024
d618f0c
Update padding value in config files
qgallouedec Sep 3, 2024
594677c
Update precompute_ref_log_probs flag documentation
qgallouedec Sep 3, 2024
5dee9ab
Fix typos and update comments in dpo_config.py and sft_config.py
qgallouedec Sep 3, 2024
47431f8
Merge branch 'main' into clean-config
qgallouedec Sep 4, 2024
19af1fa
post init warning for `max_target_length`
qgallouedec Sep 4, 2024
34b38b0
Merge branch 'clean-config' of https://github.com/huggingface/trl int…
qgallouedec Sep 4, 2024
07c9cab
Merge branch 'main' into clean-config
qgallouedec Sep 4, 2024
68 changes: 35 additions & 33 deletions docs/source/_toctree.yml
@@ -17,44 +17,46 @@
title: Understanding Logs
title: Get started
- sections:
- sections:
- local: trainer
title: Overview
- local: alignprop_trainer
title: AlignProp
- local: bco_trainer
title: BCO
- local: cpo_trainer
title: CPO
- local: ddpo_trainer
title: DDPO
- local: dpo_trainer
title: DPO
- local: online_dpo_trainer
title: Online DPO
- local: orpo_trainer
title: ORPO
- local: kto_trainer
title: KTO
- local: ppo_trainer
title: PPO
- local: ppov2_trainer
title: PPOv2
- local: rloo_trainer
title: RLOO
- local: sft_trainer
title: SFT
- local: iterative_sft_trainer
title: Iterative SFT
- local: reward_trainer
title: Reward Model
title: Trainers
- local: models
title: Model Classes
- local: trainer
title: Trainer Classes
- local: reward_trainer
title: Reward Model Training
- local: sft_trainer
title: Supervised Fine-Tuning
- local: ppo_trainer
title: PPO Trainer
- local: ppov2_trainer
title: PPOv2 Trainer
- local: rloo_trainer
title: RLOO Trainer
- local: best_of_n
title: Best of N Sampling
- local: dpo_trainer
title: DPO Trainer
- local: online_dpo_trainer
title: Online DPO Trainer
- local: kto_trainer
title: KTO Trainer
- local: bco_trainer
title: BCO Trainer
- local: cpo_trainer
title: CPO Trainer
- local: ddpo_trainer
title: Denoising Diffusion Policy Optimization
- local: alignprop_trainer
title: AlignProp Trainer
- local: orpo_trainer
title: ORPO Trainer
- local: iterative_sft_trainer
title: Iterative Supervised Fine-Tuning
- local: callbacks
title: Callback Classes
- local: judges
title: Judge Classes
title: Judges
- local: callbacks
title: Callbacks
- local: text_environments
title: Text Environments
title: API
124 changes: 78 additions & 46 deletions trl/trainer/alignprop_config.py
@@ -2,85 +2,117 @@
import sys
import warnings
from dataclasses import dataclass, field
from typing import Literal, Optional
from typing import Any, Dict, Literal, Optional, Tuple

from ..core import flatten_dict
from ..import_utils import is_bitsandbytes_available, is_torchvision_available


@dataclass
class AlignPropConfig:
"""
Configuration class for AlignPropTrainer
r"""
Configuration class for the [`AlignPropTrainer`].

Using [`~transformers.HfArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.

Parameters:
exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
Name of this experiment (defaults to the file name without the extension).
run_name (`str`, *optional*, defaults to `""`):
Name of this run.
log_with (`Optional[Literal["wandb", "tensorboard"]]`, *optional*, defaults to `None`):
Log with either `"wandb"` or `"tensorboard"`. Check
[tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
log_image_freq (`int`, *optional*, defaults to `1`):
Frequency for logging images.
tracker_kwargs (`Dict[str, Any]`, *optional*, defaults to `{}`):
Keyword arguments for the tracker (e.g., `wandb_project`).
accelerator_kwargs (`Dict[str, Any]`, *optional*, defaults to `{}`):
Keyword arguments for the accelerator.
project_kwargs (`Dict[str, Any]`, *optional*, defaults to `{}`):
Keyword arguments for the accelerator project config (e.g., `logging_dir`).
tracker_project_name (`str`, *optional*, defaults to `"trl"`):
Name of project to use for tracking.
logdir (`str`, *optional*, defaults to `"logs"`):
Top-level logging directory for checkpoint saving.
num_epochs (`int`, *optional*, defaults to `100`):
Number of epochs to train.
save_freq (`int`, *optional*, defaults to `1`):
Number of epochs between saving model checkpoints.
num_checkpoint_limit (`int`, *optional*, defaults to `5`):
Number of checkpoints to keep before overwriting old ones.
mixed_precision (`str`, *optional*, defaults to `"fp16"`):
Mixed precision training.
allow_tf32 (`bool`, *optional*, defaults to `True`):
Allow `tf32` on Ampere GPUs.
resume_from (`str`, *optional*, defaults to `""`):
Path to resume training from a checkpoint.
sample_num_steps (`int`, *optional*, defaults to `50`):
Number of sampler inference steps.
sample_eta (`float`, *optional*, defaults to `1.0`):
Eta parameter for the DDIM sampler.
sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
Classifier-free guidance weight.
train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
Whether to use the 8bit Adam optimizer from `bitsandbytes`.
train_learning_rate (`float`, *optional*, defaults to `1e-3`):
Learning rate.
train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
Beta1 for Adam optimizer.
train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
Beta2 for Adam optimizer.
train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
Weight decay for Adam optimizer.
train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
Epsilon value for Adam optimizer.
train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
Number of gradient accumulation steps.
train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
Maximum gradient norm for gradient clipping.
negative_prompts (`Optional[str]`, *optional*, defaults to `None`):
Comma-separated list of prompts to use as negative examples.
truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
If `True`, randomized truncation to different diffusion timesteps is used.
truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
truncated_rand_backprop_minmax (`Tuple[int, int]`, *optional*, defaults to `(0, 50)`):
Range of diffusion timesteps for randomized truncated backpropagation.
"""

# common parameters
exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")]
"""the name of this experiment (by default is the file name without the extension name)"""
run_name: Optional[str] = ""
"""Run name for wandb logging and checkpoint saving."""
run_name: str = ""
seed: int = 0
"""Seed value for random generations"""
log_with: Optional[Literal["wandb", "tensorboard"]] = None
"""Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details"""
log_image_freq = 1
"""Logging Frequency for images"""
tracker_kwargs: dict = field(default_factory=dict)
"""Keyword arguments for the tracker (e.g. wandb_project)"""
accelerator_kwargs: dict = field(default_factory=dict)
"""Keyword arguments for the accelerator"""
project_kwargs: dict = field(default_factory=dict)
"""Keyword arguments for the accelerator project config (e.g. `logging_dir`)"""
log_image_freq: int = 1
tracker_kwargs: Dict[str, Any] = field(default_factory=dict)
accelerator_kwargs: Dict[str, Any] = field(default_factory=dict)
project_kwargs: Dict[str, Any] = field(default_factory=dict)
tracker_project_name: str = "trl"
"""Name of project to use for tracking"""
logdir: str = "logs"
"""Top-level logging directory for checkpoint saving."""

# hyperparameters
num_epochs: int = 100
"""Number of epochs to train."""
save_freq: int = 1
"""Number of epochs between saving model checkpoints."""
num_checkpoint_limit: int = 5
"""Number of checkpoints to keep before overwriting old ones."""
mixed_precision: str = "fp16"
"""Mixed precision training."""
allow_tf32: bool = True
"""Allow tf32 on Ampere GPUs."""
resume_from: Optional[str] = ""
"""Resume training from a checkpoint."""
resume_from: str = ""
sample_num_steps: int = 50
"""Number of sampler inference steps."""
sample_eta: float = 1.0
"""Eta parameter for the DDIM sampler."""
sample_guidance_scale: float = 5.0
"""Classifier-free guidance weight."""
train_batch_size: int = 1
"""Batch size (per GPU!) to use for training."""
train_use_8bit_adam: bool = False
"""Whether to use the 8bit Adam optimizer from bitsandbytes."""
train_learning_rate: float = 1e-3
"""Learning rate."""
train_adam_beta1: float = 0.9
"""Adam beta1."""
train_adam_beta2: float = 0.999
"""Adam beta2."""
train_adam_weight_decay: float = 1e-4
"""Adam weight decay."""
train_adam_epsilon: float = 1e-8
"""Adam epsilon."""
train_gradient_accumulation_steps: int = 1
"""Number of gradient accumulation steps."""
train_max_grad_norm: float = 1.0
"""Maximum gradient norm for gradient clipping."""
negative_prompts: Optional[str] = ""
"""Comma-separated list of prompts to use as negative examples."""
negative_prompts: Optional[str] = None
truncated_backprop_rand: bool = True
"""Truncated Randomized Backpropation randomizes truncation to different diffusion timesteps"""
truncated_backprop_timestep: int = 49
"""Absolute timestep to which the gradients are being backpropagated. If truncated_backprop_rand is False"""
truncated_rand_backprop_minmax: tuple = (0, 50)
"""Range of diffusion timesteps for randomized truncated backprop."""
truncated_rand_backprop_minmax: Tuple[int, int] = (0, 50)

def to_dict(self):
output_dict = {}
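The rewritten docstring notes that [`~transformers.HfArgumentParser`] can turn `AlignPropConfig` into command-line arguments. A minimal sketch of that workflow is shown below; the script name and flag values are hypothetical, and only fields documented in the diff are used.

```python
# Minimal sketch: exposing AlignPropConfig on the command line via HfArgumentParser.
# Script name and flag values are illustrative only.
from transformers import HfArgumentParser

from trl import AlignPropConfig

parser = HfArgumentParser(AlignPropConfig)
(config,) = parser.parse_args_into_dataclasses()

# e.g. `python alignprop.py --num_epochs 50 --train_learning_rate 1e-4 --mixed_precision bf16`
print(config.num_epochs, config.train_learning_rate, config.mixed_precision)
```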
97 changes: 42 additions & 55 deletions trl/trainer/bco_config.py
@@ -12,87 +12,74 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Any, Dict, Optional

from transformers import TrainingArguments

from ..import_utils import is_sklearn_available


@dataclass
class BCOConfig(TrainingArguments):
r"""
BCOConfig collects all training arguments related to the [`BCOTrainer`] class.
Configuration class for the [`BCOTrainer`].

Using [`HfArgumentParser`] we can turn this class into
Using [`~transformers.HfArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.

Parameters:
max_length (`int`, *optional*, defaults to `None`):
The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
max_prompt_length (`int`, *optional*, defaults to `None`):
The maximum length of the prompt. This argument is required if you want to use the default data collator.
max_completion_length (`int`, *optional*, defaults to `None`):
The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
beta (`float`, defaults to 0.1):
The beta factor in BCO loss. Higher beta means less divergence from the initial policy.
label_pad_token_id (`int`, defaults to `-100`):
The label pad token id. This argument is required if you want to use the default data collator.
padding_value (`int`, defaults to `0`):
The padding value if it is different to the tokenizer's pad_token_id.
truncation_mode (`str`, defaults to `keep_end`):
The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
generate_during_eval (`bool`, defaults to `False`):
max_length (`Optional[int]`, *optional*, defaults to `None`):
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
to use the default data collator.
max_prompt_length (`Optional[int]`, *optional*, defaults to `None`):
Maximum length of the prompt. This argument is required if you want to use the default data collator.
max_completion_length (`Optional[int]`, *optional*, defaults to `None`):
Maximum length of the target. This argument is required if you want to use the default data collator
and your model is an encoder-decoder.
beta (`float`, *optional*, defaults to `0.1`):
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
reference model.
label_pad_token_id (`int`, *optional*, defaults to `-100`):
Label pad token id. This argument is required if you want to use the default data collator.
padding_value (`Optional[int]`, *optional*, defaults to `None`):
Padding value if it is different to the tokenizer's pad_token_id.
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
Truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use
the default data collator.
generate_during_eval (`bool`, *optional*, defaults to `False`):
Whether to sample and log generations during evaluation step.
is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
If no model is provided, we need to know if the model_init returns an encoder-decoder.
precompute_ref_log_probs (`bool`, defaults to `False`):
Flag to precompute reference model log probabilities for training and evaluation datasets. This is useful if you want to train
without the reference model and reduce the total GPU memory needed.
model_init_kwargs: (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when instantiating the model from a string.
ref_model_init_kwargs: (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when instantiating the ref model from a string.
dataset_num_proc: (`Optional[int]`, *optional*, defaults to `None`):
Number of processes to use for processing the datasets.
prompt_sample_size: (`int`, defaults to 1024):
is_encoder_decoder (`Optional[bool]`, *optional*, defaults to `None`):
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
you need to specify if the model returned by the callable is an encoder-decoder model.
precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
Flag to precompute reference model log probabilities for training and evaluation datasets. This is useful
if you want to train without the reference model and reduce the total GPU memory needed.
model_init_kwargs (`Optional[Dict[str, Any]]`, *optional*, defaults to `None`):
Dict of optional kwargs to pass when instantiating the model from a string.
ref_model_init_kwargs (`Optional[Dict[str, Any]]`, *optional*, defaults to `None`):
Dict of optional kwargs to pass when instantiating the reference model from a string.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
prompt_sample_size (`int`, *optional*, defaults to `1024`):
Number of prompts that are fed to density ratio classifier.
min_density_ratio: (`float`, defaults to 0.5):
The minimum value of the density ratio. The estimated density ratio is clamped to this value.
max_density_ratio: (`float`, defaults to 10.0):
The maximum value of the density ratio. The estimated density ratio is clamped to this value.
min_density_ratio (`float`, *optional*, defaults to `0.5`):
Minimum value of the density ratio. The estimated density ratio is clamped to this value.
max_density_ratio (`float`, *optional*, defaults to `10.0`):
Maximum value of the density ratio. The estimated density ratio is clamped to this value.
"""

max_length: Optional[int] = None
"""The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."""
max_prompt_length: Optional[int] = None
"""The maximum length of the prompt. This argument is required if you want to use the default data collator."""
max_completion_length: Optional[int] = None
"""The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder."""
beta: float = 0.1
"""The beta factor in BCO loss. Higher beta means less divergence from the initial policy."""

label_pad_token_id: int = -100
padding_value: int = None
padding_value: Optional[int] = None
truncation_mode: str = "keep_end"
generate_during_eval: bool = False
is_encoder_decoder: Optional[bool] = None
precompute_ref_log_probs: bool = False
model_init_kwargs: Optional[Dict] = None
ref_model_init_kwargs: Optional[Dict] = None
model_init_kwargs: Optional[Dict[str, Any]] = None
ref_model_init_kwargs: Optional[Dict[str, Any]] = None
dataset_num_proc: Optional[int] = None

# BCO config
prompt_sample_size: int = 1024
min_density_ratio: float = 0.5
max_density_ratio: float = 10.0

def __post_init__(self):
super().__post_init__()

if not is_sklearn_available():
raise ImportError(
"You need to install scikit-learn to use `BCOTrainer` "
"You can install it with `pip install scikit-learn`."
)
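Since `BCOConfig` subclasses `transformers.TrainingArguments`, the cleaned-up fields above sit alongside the usual training arguments. Below is a hedged construction sketch using only parameters named in this diff; all values are illustrative.

```python
# Sketch: constructing a BCOConfig with the argument names settled in this PR.
# Values are illustrative; only fields shown in the diff above are used.
from trl import BCOConfig

config = BCOConfig(
    output_dir="bco-model",   # inherited from transformers.TrainingArguments
    max_length=1024,          # prompt + completion length for the default data collator
    max_prompt_length=512,
    beta=0.1,                 # higher beta -> less deviation from the reference model
    padding_value=None,       # None falls back to the tokenizer's pad_token_id
    dataset_num_proc=4,
)
```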
7 changes: 6 additions & 1 deletion trl/trainer/bco_trainer.py
@@ -327,8 +327,13 @@ def __init__(
embedding_func: Optional[Callable] = None,
embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
):
if not is_sklearn_available():
raise ImportError(
"BCOTrainer requires the scikit-learn library. Please install it with `pip install scikit-learn`."
)

if type(args) is TrainingArguments:
raise ValueError("Please use `BCOConfig` instead TrainingArguments.")
raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.")

if args.model_init_kwargs is None:
model_init_kwargs = {}
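The type check reworded in `BCOTrainer.__init__` rejects a plain `TrainingArguments` while the `BCOConfig` subclass passes. The standalone sketch below mirrors that check; `check_args` is a hypothetical helper, and the trainer's other arguments are elided.

```python
# Standalone sketch of the argument check shown in the diff above.
# `check_args` is a hypothetical helper mirroring the check in BCOTrainer.__init__.
from transformers import TrainingArguments

from trl import BCOConfig


def check_args(args):
    # `type(...) is` deliberately rejects the base class but accepts subclasses like BCOConfig
    if type(args) is TrainingArguments:
        raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.")


check_args(BCOConfig(output_dir="out"))  # passes: BCOConfig subclasses TrainingArguments

try:
    check_args(TrainingArguments(output_dir="out"))
except ValueError as err:
    print(err)  # Please use `BCOConfig` instead `TrainingArguments`.
```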