-
Notifications
You must be signed in to change notification settings - Fork 16
Implement pass rate-based curriculum learning with weighted sampling #153
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: verl-latest-cispo
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -43,7 +43,8 @@ | |||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| from verl.utils.profiler import marked_timer | ||||||||||||||||||||||||||||||
| from verl.utils.rollout_skip import RolloutSkip | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| from verl.utils.pass_rate_tracker import PassRateTracker | ||||||||||||||||||||||||||||||
| from verl.utils.pass_rate_weighted_sampler import PassRateWeightedSampler | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| class RayDAPOTrainer(RayPPOTrainer): | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
|
|
@@ -68,12 +69,19 @@ def fit(self): | |||||||||||||||||||||||||||||
| config=OmegaConf.to_container(self.config, resolve=True), | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| self.global_steps = 0 | ||||||||||||||||||||||||||||||
| self.global_steps = 0 | ||||||||||||||||||||||||||||||
| self.gen_steps = 0 | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # load checkpoint before doing anything | ||||||||||||||||||||||||||||||
| self._load_checkpoint() | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Extract pass rate tracker from sampler if using curriculum learning | ||||||||||||||||||||||||||||||
| # The PassRateWeightedSampler owns the tracker internally but we need to manually update it during training | ||||||||||||||||||||||||||||||
| # Currently, we only support PassRateWeightedSampler for curriculum learning | ||||||||||||||||||||||||||||||
| self.pass_rate_tracker = None | ||||||||||||||||||||||||||||||
| self.data_sampler = self.train_dataloader.sampler # train_dataloader is created in `RayPPOTrainer._create_dataloader()` and always has a sampler | ||||||||||||||||||||||||||||||
| if isinstance(self.data_sampler, PassRateWeightedSampler): | ||||||||||||||||||||||||||||||
| self.pass_rate_tracker = self.data_sampler.pass_rate_tracker | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # perform validation before training | ||||||||||||||||||||||||||||||
| # currently, we only support validation using the reward_function. | ||||||||||||||||||||||||||||||
| if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): | ||||||||||||||||||||||||||||||
|
|
@@ -135,7 +143,6 @@ def fit(self): | |||||||||||||||||||||||||||||
| non_tensor_batch_keys=["raw_prompt_ids"], | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| is_last_step = self.global_steps >= self.total_training_steps | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| with marked_timer("step", timing_raw): | ||||||||||||||||||||||||||||||
|
|
@@ -189,7 +196,6 @@ def fit(self): | |||||||||||||||||||||||||||||
| reward_extra_infos_dict = {} | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| new_batch.batch["token_level_scores"] = reward_tensor | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if reward_extra_infos_dict: | ||||||||||||||||||||||||||||||
| new_batch.non_tensor_batch.update( | ||||||||||||||||||||||||||||||
| {k: np.array(v) for k, v in reward_extra_infos_dict.items()} | ||||||||||||||||||||||||||||||
|
|
@@ -206,6 +212,46 @@ def fit(self): | |||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||
| new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # === Curriculum Learning: Update pass rate tracker for weighted resampling === | ||||||||||||||||||||||||||||||
| # When using PassRateWeightedSampler, track per-sample success rates to enable dynamic curriculum learning. | ||||||||||||||||||||||||||||||
| # The sampler uses these pass rates to adjust sampling probabilities in the next epoch. | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Note: refactor the pass rate tracker update into a utility function later | ||||||||||||||||||||||||||||||
| # 1. if sampler is an instance of PassRateWeightedSampler, self.pass_rate_tracker is not None | ||||||||||||||||||||||||||||||
| # 2. `dataset_index` field is added to the RL dataset to identify samples | ||||||||||||||||||||||||||||||
| if "dataset_index" in new_batch.non_tensor_batch and self.pass_rate_tracker is not None: | ||||||||||||||||||||||||||||||
| dataset_indices = new_batch.non_tensor_batch["dataset_index"] | ||||||||||||||||||||||||||||||
| # Sum token-level rewards to get sequence-level reward | ||||||||||||||||||||||||||||||
| seq_rewards = new_batch.batch["token_level_rewards"].sum(dim=-1).cpu().numpy() | ||||||||||||||||||||||||||||||
| # Success is 1 if sequence reward > 0, else 0 | ||||||||||||||||||||||||||||||
| successes = (seq_rewards > 0).astype(float) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Deduplicate: batch was repeated n times (interleaved), so we need to aggregate | ||||||||||||||||||||||||||||||
| unique_indices, inverse_indices = np.unique(dataset_indices, return_inverse=True) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Aggregate successes: take mean across rollouts for each sample | ||||||||||||||||||||||||||||||
| aggregated_successes = np.zeros(len(unique_indices), dtype=float) | ||||||||||||||||||||||||||||||
| for i, _ in enumerate(unique_indices): | ||||||||||||||||||||||||||||||
| mask = inverse_indices == i # boolean array to indicate positions of unique index i | ||||||||||||||||||||||||||||||
| aggregated_successes[i] = np.mean(successes[mask]) # take average success across rollouts for sample i | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| pass_rates = self.pass_rate_tracker.get_pass_rates() | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Log curriculum metrics BEFORE updating tracker | ||||||||||||||||||||||||||||||
| # Track improvement of hardest samples (across all samples, not just attempted) | ||||||||||||||||||||||||||||||
| metrics['curriculum/hardest_10pct_pass_rate'] = float(np.percentile(pass_rates, 10)) | ||||||||||||||||||||||||||||||
| metrics['curriculum/hardest_25pct_pass_rate'] = float(np.percentile(pass_rates, 25)) | ||||||||||||||||||||||||||||||
| metrics['curriculum/hardest_50pct_pass_rate'] = float(np.percentile(pass_rates, 50)) | ||||||||||||||||||||||||||||||
| metrics['curriculum/hardest_75pct_pass_rate'] = float(np.percentile(pass_rates, 75)) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Batch-level statistics | ||||||||||||||||||||||||||||||
| metrics['curriculum/min_batch_pass_rate'] = float(np.min(aggregated_successes)) | ||||||||||||||||||||||||||||||
| metrics['curriculum/mean_batch_pass_rate'] = float(np.mean(aggregated_successes)) | ||||||||||||||||||||||||||||||
| metrics['curriculum/effective_batch_size'] = np.sum(aggregated_successes > 0)/len(unique_indices) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Update tracker with current batch results | ||||||||||||||||||||||||||||||
| self.pass_rate_tracker.update(sample_indices=unique_indices.astype(int), current_pass_rate=aggregated_successes) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
| self.pass_rate_tracker.update(sample_indices=unique_indices.astype(int), current_pass_rate=aggregated_successes) | |
| # In distributed setups, ensure only rank 0 updates the tracker to avoid | |
| # each rank maintaining inconsistent local pass-rate statistics. | |
| if torch.distributed.is_available() and torch.distributed.is_initialized(): | |
| if torch.distributed.get_rank() == 0: | |
| self.pass_rate_tracker.update( | |
| sample_indices=unique_indices.astype(int), | |
| current_pass_rate=aggregated_successes, | |
| ) | |
| else: | |
| self.pass_rate_tracker.update( | |
| sample_indices=unique_indices.astype(int), | |
| current_pass_rate=aggregated_successes, | |
| ) |
jb3618columbia marked this conversation as resolved.
Show resolved
Hide resolved
Copilot
AI
Jan 24, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The condition `if self.data_sampler is not None` is insufficient. It should check whether the sampler is specifically an instance of `PassRateWeightedSampler`, since `get_wandb_3d_plot_data` is only available on that class; any other sampler will raise an `AttributeError`. Change to: `if isinstance(self.data_sampler, PassRateWeightedSampler)`.
| if self.data_sampler is not None: | |
| if isinstance(self.data_sampler, PassRateWeightedSampler): |
Copilot
AI
Jan 24, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The inline imports of wandb and pandas (using `__import__`) inside the metrics logging code are an anti-pattern. These should be imported at the module level or handled more cleanly. This dynamic import pattern can cause issues with IDE autocomplete and type checking, and it makes dependencies less clear. If wandb/pandas are optional dependencies, consider using a try-except block at the module level and setting a flag.
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -10,25 +10,42 @@ WANDB_PROJECT="Reasoning360" # Your wandb project name | |||||||||
|
|
||||||||||
| # --- External Services --- | ||||||||||
| export STEM_LLM_JUDGE_URL="<STEM_LLM_JUDGE_URL>" # Optional: Fill in the llm-as-judge hosted URL for 'STEM' domain evaluation | ||||||||||
| export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-853:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty | ||||||||||
|
||||||||||
| export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-853:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty | |
| export MATH_LLM_JUDGE_URL="<MATH_LLM_JUDGE_URL>" # Optional: Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty |
Copilot
AI
Jan 24, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The MATH_LLM_JUDGE_URL is configured to use plain http://, so all OmniMATH scoring requests (including prompts, model outputs, and scores) will be transmitted in cleartext over the cluster network. An attacker or malicious tenant with network access could sniff or tamper with this traffic, corrupting evaluation results or exfiltrating potentially sensitive data. Use an HTTPS endpoint for the math judge service and ensure TLS certificate validation is enabled so these requests are encrypted and integrity-protected.
| export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-853:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty | |
| export MATH_LLM_JUDGE_URL="https://azure-uk-hpc-H200-instance-853:8000" # Fill in the OmniMATH llm-as-judge HTTPS URL, only used to score OmniMATH dataset if not empty |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Import of 'PassRateTracker' is not used.