DPO Train Issue: max steps from 1000 to 996349 #2355

Open

seTalent opened this issue Nov 14, 2024 · 1 comment

Labels: 🐛 bug Something isn't working · 🏋 DPO Related to DPO

System Info

  • Platform: Linux-5.15.0-56-generic-x86_64-with-glibc2.31
  • Python version: 3.10.15
  • PyTorch version: 2.0.1
  • CUDA device(s): NVIDIA A100-PCIE-40GB, NVIDIA A100-PCIE-40GB, NVIDIA A100-PCIE-40GB, NVIDIA A100-PCIE-40GB, NVIDIA A100-PCIE-40GB, NVIDIA A100-PCIE-40GB, NVIDIA A100-PCIE-40GB, NVIDIA A100-PCIE-40GB
  • Transformers version: 4.46.2
  • Accelerate version: 1.1.1
  • Accelerate config:
    • compute_environment: LOCAL_MACHINE
    • distributed_type: NO
    • mixed_precision: no
    • use_cpu: False
    • debug: False
    • num_processes: 1
    • machine_rank: 0
    • num_machines: 1
    • gpu_ids: 0
    • rdzv_backend: static
    • same_network: True
    • main_training_function: main
    • enable_cpu_affinity: False
    • downcast_bf16: no
    • tpu_use_cluster: False
    • tpu_use_sudo: False
    • tpu_env: []
  • Datasets version: 3.1.0
  • HF Hub version: 0.26.2
  • TRL version: 0.12.0
  • bitsandbytes version: 0.44.1
  • DeepSpeed version: not installed
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: not installed
  • PEFT version: 0.13.2

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

Shell script (dpo.sh):

accelerate launch dpo_llama2.py \
    --model_name_or_path="sft/final_checkpoint" \
    --output_dir="dpo" \
    --beta=0.1 \
    --base_model_path '/data/share_weight/llama2-7b-hf' \
    --dataset_path './dataset/stack-exchange-paired' \
    --num_proc 24 \
    --bf16=False 
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 0. imports
import os
os.environ["CUDA_VISIBLE_DEVICES"]='7'
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from accelerate import Accelerator
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed

from trl import DPOConfig, DPOTrainer


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the DPO training script.
    """

    # data parameters
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

    # training parameters
    model_name_or_path: Optional[str] = field(
        default="../sft/results/final_checkpoint",
        metadata={"help": "the location of the SFT model name or path"},
    )
    learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
    warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
    weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
    optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})

    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "train batch size per device"})
    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=4, metadata={"help": "the number of gradient accumulation steps"}
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True, metadata={"help": "whether to use gradient checkpointing"}
    )

    gradient_checkpointing_use_reentrant: Optional[bool] = field(
        default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
    )

    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})

    max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
    max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
    max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
    logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
    save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
    eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})

    output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
    log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
    load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 4bit"})
    model_dtype: Optional[str] = field(
        default="float16", metadata={"help": "model_dtype[float16, bfloat16, float] for loading."}
    )

    # instrumentation
    report_to: Optional[str] = field(
        default="wandb",
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    # debug argument for distributed training
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    seed: Optional[int] = field(
        default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
    )

    # Load the tokenizer separately from the base model
    base_model_path: Optional[str] = field(
        default='meta-llama/Llama-2-7b-hf',
        metadata={"help": "the path of the base model whose tokenizer is loaded"}
    )

    dataset_path: Optional[str] = field(
        default='',
        metadata={"help": "the path or name of dataset"}
    )

    num_proc: Optional[int] = field(
        default=4,
        metadata={"help": "the number of processes for dataset preprocessing"}
    )
    bf16: Optional[bool] = field(
        default=False,
        metadata={"help": "whether to train with bf16 mixed precision"}
    )

    


def get_stack_exchange_paired(
    data_dir: str = "data/rl",
    cache_dir: Optional[str] = None,
    num_proc: int = 24,
    dataset_path: Optional[str] = None,
) -> Dataset:
    """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts are structured as follows:
      "Question: " + <prompt> + "\n\nAnswer: "
    """
    dataset = load_dataset(
        dataset_path,
        split="train",
        cache_dir=cache_dir,
        data_dir=data_dir,
        verification_mode="no_checks",
    )
    original_columns = dataset.column_names

    def return_prompt_and_responses(samples) -> Dict[str, str]:
        return {
            "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
            "chosen": samples["response_j"],
            "rejected": samples["response_k"],
        }

    return dataset.map(
        return_prompt_and_responses,
        batched=True,
        num_proc=num_proc,
        remove_columns=original_columns,
    )


if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    set_seed(script_args.seed)

    # 1. load a pretrained model
    torch_dtype = torch.float
    if script_args.model_dtype == "float16":
        torch_dtype = torch.float16
    elif script_args.model_dtype == "bfloat16":
        torch_dtype = torch.bfloat16

    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch_dtype,
        load_in_4bit=script_args.load_in_4bit,
        device_map={"": Accelerator().local_process_index},
    )
    model.config.use_cache = False

    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_path)
    tokenizer.pad_token = tokenizer.eos_token

    # 2. Load the Stack-exchange paired dataset
    train_dataset = get_stack_exchange_paired(data_dir="data/rl", dataset_path=script_args.dataset_path)
    train_dataset = train_dataset.filter(
        lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
        num_proc=script_args.num_proc,
    )

    # 3. Load evaluation dataset
    eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation", dataset_path=script_args.dataset_path)
    eval_dataset = eval_dataset.filter(
        lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
        num_proc=script_args.num_proc,
    )

    # 4. initialize training arguments:
    training_args = DPOConfig(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        per_device_eval_batch_size=script_args.per_device_eval_batch_size,
        max_steps=script_args.max_steps,
        logging_steps=script_args.logging_steps,
        save_steps=script_args.save_steps,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        gradient_checkpointing=script_args.gradient_checkpointing,
        learning_rate=script_args.learning_rate,
        eval_strategy="steps",
        eval_steps=script_args.eval_steps,
        output_dir=script_args.output_dir,
        report_to=script_args.report_to,
        lr_scheduler_type=script_args.lr_scheduler_type,
        warmup_steps=script_args.warmup_steps,
        optim=script_args.optimizer_type,
        bf16=script_args.bf16,
        remove_unused_columns=False,
        run_name="dpo_llama2",
        gradient_checkpointing_kwargs=dict(use_reentrant=script_args.gradient_checkpointing_use_reentrant),
        seed=script_args.seed,
    )

    peft_config = LoraConfig(
        r=script_args.lora_r,
        lora_alpha=script_args.lora_alpha,
        lora_dropout=script_args.lora_dropout,
        target_modules=[
            "q_proj",
            "v_proj",
            "k_proj",
            "out_proj",
            "fc_in",
            "fc_out",
            "wte",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )

    # 5. initialize the DPO trainer
    dpo_trainer = DPOTrainer(
        model,
        ref_model=None,
        args=training_args,
        beta=script_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,
    )

    # 6. train
    dpo_trainer.train()
    dpo_trainer.save_model(script_args.output_dir)

    # 7. save
    output_dir = os.path.join(script_args.output_dir, "final_checkpoint")
    dpo_trainer.model.save_pretrained(output_dir)
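
Note: passing `load_in_4bit` directly to `from_pretrained` is deprecated in recent transformers versions (see the first warning in the output below). A minimal sketch of the `quantization_config` alternative, reusing the paths and dtype from the script above:

# Sketch (assumption): replace the deprecated load_in_4bit flag with a
# BitsAndBytesConfig; requires bitsandbytes (0.44.1 per the system info above).
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # matches model_dtype="float16"
)
model = AutoModelForCausalLM.from_pretrained(
    "sft/final_checkpoint",  # model_name_or_path from the launch command above
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
    device_map={"": 0},
)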

Output:

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.93s/it]
Resolving data files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 31787.07it/s]
Loading dataset shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 42/42 [00:00<00:00, 949.31it/s]
Filter (num_proc=24): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7435908/7435908 [00:03<00:00, 2410507.23 examples/s]
Loading dataset shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:00<00:00, 1158.01it/s]
Filter (num_proc=24): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4483004/4483004 [00:02<00:00, 2115333.57 examples/s]
/data/zky_1/anaconda3/envs/py310/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': max_prompt_length, max_length. Will not be supported from version '0.13.0'.

Deprecated positional argument(s) used in DPOTrainer, please use the DPOConfig to set these arguments instead.
  warnings.warn(message, FutureWarning)
/data/zky_1/anaconda3/envs/py310/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py:469: UserWarning: You passed `max_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`.
  warnings.warn(
/data/zky_1/anaconda3/envs/py310/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py:475: UserWarning: You passed `max_prompt_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`.
  warnings.warn(
Extracting prompt from train dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1652614/1652614 [02:07<00:00, 12986.35 examples/s]
Applying chat template to train dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1652614/1652614 [01:21<00:00, 20216.14 examples/s]
Extracting prompt from eval dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 996349/996349 [01:18<00:00, 12757.64 examples/s]
Applying chat template to eval dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 996349/996349 [00:44<00:00, 22368.85 examples/s]
Tokenizing train dataset: 100%|██████████| 1652614/1652614 [21:15<00:00, 1295.86 examples/s]
Tokenizing eval dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 996349/996349 [11:28<00:00, 1446.50 examples/s]
max_steps is given, it will override any value given in num_train_epochs
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 3
wandb: WARNING Invalid choice
wandb: Enter your choice: 3
wandb: You chose "Don't visualize my results"
wandb: Tracking run with wandb version 0.18.6
wandb: W&B syncing is set to `offline` in this directory.  
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
  0%|          | 0/1000 [00:00<?, ?it/s]
Could not estimate the number of tokens of the input, floating-point operations will not be computed
  0%|▍                                                                                                                                                                                                                             | 2/1000 [01:11<9:49:10, 35.42s/it]



{'loss': 0.6929, 'grad_norm': 4.2050275802612305, 'learning_rate': 5e-05, 'rewards/chosen': -0.004248094744980335, 'rewards/rejected': -0.004661589860916138, 'rewards/accuracies': 0.4375, 'rewards/margins': 0.00041349526145495474, 'logps/chosen': -174.2248992919922, 'logps/rejected': -171.5912628173828, 'logits/chosen': -0.42958006262779236, 'logits/rejected': -0.47797107696533203, 'epoch': 0.0}
{'loss': 0.6924, 'grad_norm': 4.090967178344727, 'learning_rate': 0.0001, 'rewards/chosen': -0.01845647022128105, 'rewards/rejected': -0.020029572769999504, 'rewards/accuracies': 0.550000011920929, 'rewards/margins': 0.0015731047606095672, 'logps/chosen': -178.96826171875, 'logps/rejected': -170.4486846923828, 'logits/chosen': -0.3984534740447998, 'logits/rejected': -0.42798668146133423, 'epoch': 0.0}
{'loss': 0.6877, 'grad_norm': 4.2013468742370605, 'learning_rate': 0.00015, 'rewards/chosen': -0.03504800423979759, 'rewards/rejected': -0.046783603727817535, 'rewards/accuracies': 0.606249988079071, 'rewards/margins': 0.011735597625374794, 'logps/chosen': -171.3468780517578, 'logps/rejected': -168.1322021484375, 'logits/chosen': -0.27878862619400024, 'logits/rejected': -0.3488096296787262, 'epoch': 0.0}
{'loss': 0.6716, 'grad_norm': 6.470763683319092, 'learning_rate': 0.0002, 'rewards/chosen': 0.024564553052186966, 'rewards/rejected': -0.0266877468675375, 'rewards/accuracies': 0.65625, 'rewards/margins': 0.051252298057079315, 'logps/chosen': -186.5611572265625, 'logps/rejected': -165.66610717773438, 'logits/chosen': -0.19978216290473938, 'logits/rejected': -0.34658709168434143, 'epoch': 0.0}
{'loss': 0.6739, 'grad_norm': 15.604384422302246, 'learning_rate': 0.00025, 'rewards/chosen': 0.13673622906208038, 'rewards/rejected': 0.05065234377980232, 'rewards/accuracies': 0.6187499761581421, 'rewards/margins': 0.08608388900756836, 'logps/chosen': -181.21852111816406, 'logps/rejected': -160.2905731201172, 'logits/chosen': -0.26764318346977234, 'logits/rejected': -0.27629831433296204, 'epoch': 0.0}
{'loss': 0.6384, 'grad_norm': 7.398013114929199, 'learning_rate': 0.0003, 'rewards/chosen': 0.2138195037841797, 'rewards/rejected': 0.01574927195906639, 'rewards/accuracies': 0.59375, 'rewards/margins': 0.19807025790214539, 'logps/chosen': -168.34925842285156, 'logps/rejected': -173.3873291015625, 'logits/chosen': -0.19198311865329742, 'logits/rejected': -0.24480994045734406, 'epoch': 0.0}
{'loss': 0.6929, 'grad_norm': 6.878015518188477, 'learning_rate': 0.00035, 'rewards/chosen': -0.008969360031187534, 'rewards/rejected': -0.1367751955986023, 'rewards/accuracies': 0.5687500238418579, 'rewards/margins': 0.12780582904815674, 'logps/chosen': -184.54605102539062, 'logps/rejected': -173.0166015625, 'logits/chosen': -0.18250033259391785, 'logits/rejected': -0.19687585532665253, 'epoch': 0.0}
{'loss': 0.6736, 'grad_norm': 5.93695592880249, 'learning_rate': 0.0004, 'rewards/chosen': -0.054489076137542725, 'rewards/rejected': -0.17297396063804626, 'rewards/accuracies': 0.574999988079071, 'rewards/margins': 0.11848487704992294, 'logps/chosen': -170.2427978515625, 'logps/rejected': -175.39981079101562, 'logits/chosen': -0.09635505080223083, 'logits/rejected': -0.21266642212867737, 'epoch': 0.0}
{'loss': 0.6341, 'grad_norm': 7.539218902587891, 'learning_rate': 0.00045000000000000004, 'rewards/chosen': 0.1032119020819664, 'rewards/rejected': -0.0954817458987236, 'rewards/accuracies': 0.612500011920929, 'rewards/margins': 0.1986936628818512, 'logps/chosen': -171.16197204589844, 'logps/rejected': -169.14718627929688, 'logits/chosen': -0.11813618987798691, 'logits/rejected': -0.19252558052539825, 'epoch': 0.0}
{'loss': 0.6486, 'grad_norm': 10.038784980773926, 'learning_rate': 0.0005, 'rewards/chosen': -0.2863594889640808, 'rewards/rejected': -0.5529825687408447, 'rewards/accuracies': 0.6312500238418579, 'rewards/margins': 0.2666230797767639, 'logps/chosen': -165.84481811523438, 'logps/rejected': -173.65853881835938, 'logits/chosen': -0.1978793740272522, 'logits/rejected': -0.2690560817718506, 'epoch': 0.0}
 10%|██        | 100/1000 [59:19<9:03:31, 36.23s/it]
  0%|          | 957/996349 [16:43<342:15:23,  1.24s/it]
  0%|          | 958/996349 [16:44<359:04:39,  1.30s/it]
  0%|          | 1674/996349 [29:35<333:57:52,  1.21s/it]
  0%|          | 2282/996349 [40:14<159:55:05,  1.73it/s]
  5%|▌         | 45124/996349 [13:07:05<221:19:22,  1.19it/s]
  5%|▌         | 45718/996349 [13:17:37<259:52:56,  1.02it/s]
  5%|▌         | 46304/996349 [13:27:29<168:08:47,  1.57it/s]
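
The FutureWarning above says `max_prompt_length` and `max_length` will no longer be accepted as `DPOTrainer` arguments from TRL 0.13.0; a minimal sketch of moving them (and `beta`, which is also a `DPOConfig` field) into the config, reusing the values from the script above:

# Sketch (assumption): set the deprecated DPOTrainer kwargs on DPOConfig instead.
from trl import DPOConfig, DPOTrainer

training_args = DPOConfig(
    # ... all other arguments as in the script above ...
    beta=script_args.beta,
    max_prompt_length=script_args.max_prompt_length,
    max_length=script_args.max_length,
)

dpo_trainer = DPOTrainer(
    model,
    ref_model=None,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
    peft_config=peft_config,
)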

Expected behavior

When I use dpo.sh to optimize the llama2-se model, the progress bar's total jumps from 1,000 to 996,349 once training reaches step 100/1000. What happened?

I train it on a single NVIDIA A100 40GB GPU.

Looking forward to your reply, thank you.

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
  • Any traceback provided is complete
@qgallouedec qgallouedec added 🐛 bug Something isn't working 🏋 DPO Related to DPO labels Nov 20, 2024
@qgallouedec qgallouedec self-assigned this Nov 20, 2024
kashif (Collaborator) commented Nov 21, 2024

Just to debug: could you kindly check whether you see the same issue when you do not pass a validation dataset?

Also, can you check what happens when you explicitly pass num_train_epochs=1 as an option to the DPOConfig? Thanks!
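
For reference, 996,349 matches the eval-dataset example count in the logs above, so the second progress bar is plausibly the evaluation loop (with per_device_eval_batch_size defaulting to 1) rather than a changed max_steps. A minimal sketch of the two suggested debug changes, as hypothetical edits to the script above:

# Sketch (assumption): kashif's two debug suggestions applied to the script.
training_args = DPOConfig(
    # ... all other arguments as in the script above ...
    num_train_epochs=1,   # (2) pass the epoch count explicitly
    eval_strategy="no",   # (1) disable evaluation entirely
)

dpo_trainer = DPOTrainer(
    model,
    ref_model=None,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=None,    # (1) do not pass a validation dataset
    processing_class=tokenizer,
    peft_config=peft_config,
)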
