179 changes: 107 additions & 72 deletions examples/gsm8k_geo3k/train_colocate.py
@@ -45,15 +45,16 @@

ensure_video_input_available()

from lightrft.datasets import PromptDatasetVL, SFTDatasetVL
from lightrft.datasets import DatasetConfig, DatasetLoader
from lightrft.models.actor_language import ActorLanguage
from lightrft.models.actor_vl import ActorVL
from lightrft.strategy import get_strategy
from lightrft.trainer.spmd_ppo_trainer import SPMDPPOTrainerVL
from lightrft.utils import blending_datasets, get_tokenizer_processor_vl
from lightrft.utils import get_tokenizer_processor_vl

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from reward_models_utils import RECIPE, load_reward_models, reward_fn
from reward_models_utils import RECIPE
from lightrft.reward import RewardManager


def train(args):
@@ -155,16 +156,6 @@ def train(args):
else:
critic = None

# Load reward models (multiple types: value, safety, knowledge, etc.)
strategy.report_memory(f"before loaded reward models in main entry")
reward_models, reward_tokenizers, label_map = load_reward_models(
raw_reward_pretrain=args.reward_pretrain,
strategy=strategy,
use_engine=args.rm_use_engine,
)
strategy.print(f"label_map: {label_map}")
strategy.report_memory(f"after loaded reward models in main entry")

strategy.print(actor)
strategy.print(critic)

@@ -203,70 +194,129 @@
)
assert processor is not None, "processor is None"

# ==================== Data Loading Optimization ====================
# The following sections now rely on the robust `blending_datasets` function.
# We add more logging for clarity.
# Initialize reward manager (using rule-based rewards for gsm8k/geo3k)
strategy.report_memory(f"before loaded reward models in main entry")

# For gsm8k/geo3k, we use rule-based rewards, so no neural models are needed
# Create a wrapper function for compatibility with the trainer
def reward_fn(
model_reward_list,
labels,
queries,
refs,
label_map,
):
"""
Wrapper function for RewardManager to match trainer's expected interface.

For rule-based rewards, model_reward_list will be empty.
The reward manager will compute rewards based on labels and queries.
"""
# Determine rule type from labels (geo3k or gsm8k)
# Use the first label to determine rule type, or default to geo3k_combined
if labels:
first_label = labels[0]
if "gsm8k" in first_label.lower():
rule_type = "gsm8k_combined"
elif "geo3k" in first_label.lower():
rule_type = "geo3k_combined"
else:
rule_type = "geo3k_combined" # Default
else:
rule_type = "geo3k_combined"

# Create a temporary reward manager with the correct rule type
# Note: We could optimize this by caching managers per rule type
# For rule-based rewards, tokenizer and strategy are not required
temp_reward_manager = RewardManager(
reward_type="rule",
rule_type=rule_type,
)

# Compute rewards using the reward manager
rewards, metrics = temp_reward_manager.compute(
queries=queries,
references=refs,
labels=labels,
)

return rewards, metrics
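
# Illustrative call shape for the wrapper above (a sketch with hypothetical
# values; at rollout time the trainer supplies the real arguments):
# rewards, metrics = reward_fn(
#     model_reward_list=[],   # empty for rule-based rewards
#     labels=["gsm8k"],
#     queries=["<prompt + model response>"],
#     refs=["42"],
#     label_map={},           # unused for rule-based rewards
# )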

label_map = {} # Empty for rule-based rewards
reward_models = [] # Empty for rule-based rewards
reward_tokenizers = [] # Empty for rule-based rewards

strategy.print(f"Initialized rule-based reward manager for gsm8k/geo3k")
strategy.report_memory(f"after loaded reward models in main entry")

# ==================== Data Loading with New API ====================
# Use DatasetConfig and DatasetLoader for unified dataset loading

# Initialize dataset loader
dataset_loader = DatasetLoader(
tokenizer=tokenizer,
processor=processor,
strategy=strategy,
)

# Prepare prompts dataset
strategy.print(f"Loading prompts dataset from: {args.prompt_data} with split: {args.prompt_split}")
prompts_data = blending_datasets(
args.prompt_data,
args.prompt_data_probs,
strategy,
args.seed,
return_eval=False,
train_split=args.prompt_split,
train_config = DatasetConfig.for_train(
data_path=args.prompt_data,
data_probs=args.prompt_data_probs,
split=args.prompt_split,
max_samples=args.max_samples,
seed=args.seed,
)
prompts_dataset = dataset_loader.load_train_dataset(
config=train_config,
prompt_max_len=args.prompt_max_len,
input_template=args.input_template,
)

prompts_data = prompts_data.select(range(min(args.max_samples, len(prompts_data))))
prompts_dataset = PromptDatasetVL(prompts_data, tokenizer, processor, args.prompt_max_len, strategy, input_template=args.input_template)
strategy.print(f"Loaded {len(prompts_dataset)} samples for prompts.")

# Prepare evaluation dataset
eval_dataloader = None
if args.eval_data or args.eval_split:
eval_data_path = args.eval_data if args.eval_data else args.prompt_data
if eval_data_path:
strategy.print(f"Loading evaluation dataset from {eval_data_path}, split='{args.eval_split}'")
eval_data = blending_datasets(
eval_data_path, "1.0", strategy, args.seed, return_eval=False,
# Note: `train_split` parameter is used to specify the desired split name for evaluation data.
train_split=args.eval_split,
eval_config = DatasetConfig.for_eval(
data_path=eval_data_path,
data_probs="1.0",
split=args.eval_split,
max_samples=args.max_eval_samples,
seed=args.seed,
)
if len(eval_data) == 0:
strategy.print(f"Warning: Evaluation dataset at {eval_data_path} with split '{args.eval_split}' is empty. Skipping evaluation.")
else:
eval_data = eval_data.select(range(min(args.max_eval_samples, len(eval_data))))
eval_dataset = PromptDatasetVL(eval_data, tokenizer, processor, args.prompt_max_len, strategy, input_template=args.input_template)
eval_dataset = dataset_loader.load_eval_dataset(
config=eval_config,
prompt_max_len=args.prompt_max_len,
input_template=args.input_template,
)
if eval_dataset is not None:
eval_dataloader = strategy.setup_dataloader(
eval_dataset, args.rollout_batch_size // strategy.world_size, False, False, collate_fn=eval_dataset.collate_fn
)
strategy.print(f"Evaluation dataset loaded: {len(eval_dataset)} samples")
else:
strategy.print("Warning: eval_split specified but no data path available for evaluation.")

# Prepare pretrain dataset
pretrain_dataloader = None
if args.pretrain_data:
strategy.print(f"Loading pretrain dataset from: {args.pretrain_data} with split: {args.pretrain_split}")
pretrain_data = blending_datasets(
args.pretrain_data, args.pretrain_data_probs, strategy, args.seed,
return_eval=False, train_split=args.pretrain_split,
pretrain_max_len = args.max_len if args.max_len else args.prompt_max_len + args.generate_max_len
# Calculate total samples needed for pretraining
total_pretrain_samples = args.max_epochs * len(prompts_dataset) * args.n_samples_per_prompt

pretrain_config = DatasetConfig.for_pretrain(
data_path=args.pretrain_data,
data_probs=args.pretrain_data_probs,
split=args.pretrain_split,
max_samples=total_pretrain_samples,
seed=args.seed,
)
if len(pretrain_data) == 0:
strategy.print(f"Warning: Pretrain dataset at {args.pretrain_data} is empty. PTX loss will not be applied.")
pretrain_dataloader = None
else:
pretrain_max_len = args.max_len if args.max_len else args.prompt_max_len + args.generate_max_len
# Calculate total samples needed for pretraining
total_pretrain_samples = args.max_epochs * len(prompts_dataset) * args.n_samples_per_prompt
pretrain_data_subset = pretrain_data.select(range(min(len(pretrain_data), total_pretrain_samples)))

pretrain_dataset = SFTDatasetVL(
pretrain_data_subset, tokenizer, pretrain_max_len, strategy, pretrain_mode=True,
)
strategy.print(f"Loaded {len(pretrain_dataset)} samples for pretraining.")
pretrain_dataset = dataset_loader.load_pretrain_dataset(
config=pretrain_config,
pretrain_max_len=pretrain_max_len,
)

if pretrain_dataset is not None:
pretrain_dataloader = itertools.cycle(
iter(
strategy.setup_dataloader(
@@ -282,21 +332,6 @@ def train(args):
prompts_dataset, args.rollout_batch_size // strategy.world_size, True, True, collate_fn=prompts_dataset.collate_fn
)

if args.pretrain_data:
pretrain_dataloader = itertools.cycle(
iter(
strategy.setup_dataloader(
pretrain_dataset,
args.micro_train_batch_size,
True,
True,
pretrain_dataset.collate_fn,
)
)
)
else:
pretrain_dataloader = None

# for scheduler
num_update_steps_per_episodes = (
len(prompts_dataset) * args.n_samples_per_prompt // args.train_batch_size * args.max_epochs
47 changes: 40 additions & 7 deletions lightrft/datasets/__init__.py
@@ -1,3 +1,32 @@
"""
Dataset Module for LightRFT

This module provides unified interfaces for loading datasets for training,
evaluation, and pretraining in RLHF workflows.

Main Features:
- Unified dataset configuration via DatasetConfig
- Consistent loading interface via DatasetLoader
- Support for train, eval, and pretrain datasets
- Automatic handling of blending_datasets parameters

Classes:
DatasetConfig: Configuration class for dataset loading
DatasetLoader: Unified loader for all dataset types
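
Example (an illustrative sketch; the paths and values below are
placeholders, mirroring the usage in examples/gsm8k_geo3k/train_colocate.py):

    loader = DatasetLoader(tokenizer=tokenizer, processor=processor, strategy=strategy)
    train_config = DatasetConfig.for_train(
        data_path="path/to/prompts",
        data_probs="1.0",
        split="train",
        max_samples=100000,
        seed=42,
    )
    prompts_dataset = loader.load_train_dataset(
        config=train_config,
        prompt_max_len=1024,
        input_template=None,
    )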
"""

# Import new unified interfaces first
from .config import DatasetConfig
from .loader import DatasetLoader

# Import existing dataset classes
from .process_reward_dataset import ProcessRewardDataset
from .prompts_dataset import PromptDataset
from .prompts_dataset_vl import PromptDatasetVL
from .sft_dataset import SFTDataset
from .sft_dataset_vl import SFTDatasetVL

# Import other dataset classes
from .grm_dataset import GRMDataset
from .srm_dataset import RankDatasetVL, RankDatasetAL
from .omnirewardbench import *
@@ -8,18 +37,22 @@
from .videodpo import *
from .videogen_rewardbench import *
from .genai_bench import *
from .rft_dataset import RFTDatasetVL
from .utils import (
extract_answer,
zero_pad_sequences,
find_subsequence,
load_multimodal_content,
BaseDataHandler,
)
from .process_reward_dataset import ProcessRewardDataset
from .prompts_dataset import PromptDataset
from .prompts_dataset_vl import PromptDatasetVL
from .sft_dataset import SFTDataset
from .sft_dataset_vl import SFTDatasetVL
from .rft_dataset import RFTDatasetVL

__all__ = ["ProcessRewardDataset", "PromptDataset", "PromptDatasetVL", "SFTDataset", "SFTDatasetVL", "RFTDatasetVL"]
__all__ = [
"DatasetConfig",
"DatasetLoader",
"ProcessRewardDataset",
"PromptDataset",
"PromptDatasetVL",
"SFTDataset",
"SFTDatasetVL",
"RFTDatasetVL",
]