🛠️ Update tests and fix PPO #2463

Merged
merged 13 commits on Dec 12, 2024
test both policy and critic
kashif committed Dec 12, 2024
commit bcc47faa5452adc732fe141dbcb828d9b98c23c5
108 changes: 76 additions & 32 deletions tests/test_ppo_trainer.py
@@ -167,20 +167,23 @@ def tokenize(element):
self.eval_dataset = prepare_dataset(eval_dataset, self.tokenizer)

def test_basic_training(self):
"""Test basic PPO training configuration from the example script."""
"""Test basic PPO training configuration and verify model updates."""
with tempfile.TemporaryDirectory() as tmp_dir:
# Capture initial critic weights
# Capture initial weights
initial_critic_weights = {}
initial_policy_weights = {}
for name, param in self.value_model.named_parameters():
initial_critic_weights[name] = param.clone().detach()
for name, param in self.model.named_parameters():
initial_policy_weights[name] = param.clone().detach()

# Configure training args similar to example script
training_args = PPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=4, # Reduced from 64 for testing
per_device_train_batch_size=4,
gradient_accumulation_steps=1,
learning_rate=3e-6,
total_episodes=10, # Reduced for testing
total_episodes=10,
save_strategy="no",
report_to="none",
missing_eos_penalty=1.0,
@@ -198,28 +201,39 @@ def test_basic_training(self):
eval_dataset=self.eval_dataset,
)

# Train and verify no exceptions are raised
# Train
trainer.train()

# Check if critic weights have been updated
weights_updated = False
for name, param in self.value_model.named_parameters():
critic_weights_updated = False
for name, param in trainer.policy_and_value.value_model.named_parameters():
if not torch.allclose(initial_critic_weights[name], param.to("cpu")):
weights_updated = True
critic_weights_updated = True
break

self.assertTrue(weights_updated, "Critic weights were not updated during training")
# Check if policy weights have been updated
policy_weights_updated = False
for name, param in trainer.policy_and_value.policy.named_parameters():
if not torch.allclose(initial_policy_weights[name], param.to("cpu")):
policy_weights_updated = True
break

self.assertTrue(critic_weights_updated, "Critic weights were not updated during training")
self.assertTrue(policy_weights_updated, "Policy weights were not updated during training")

@require_peft
def test_peft_training(self):
"""Test PPO training with PEFT configuration."""
"""Test PPO training with PEFT configuration and verify model updates."""
from peft import LoraConfig

with tempfile.TemporaryDirectory() as tmp_dir:
# Capture initial critic weights
# Capture initial weights
initial_critic_weights = {}
initial_policy_weights = {}
for name, param in self.value_model.named_parameters():
initial_critic_weights[name] = param.clone().detach()
for name, param in self.model.named_parameters():
initial_policy_weights[name] = param.clone().detach()

# Configure training args
training_args = PPOConfig(
@@ -247,37 +261,48 @@ def test_peft_training(self):
args=training_args,
processing_class=self.tokenizer,
model=self.model,
ref_model=None, # No ref_model needed with PEFT
ref_model=None,
reward_model=self.reward_model,
value_model=self.value_model,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
peft_config=peft_config,
)

# Train and verify no exceptions are raised
# Train
trainer.train()

# Check if critic weights have been updated
weights_updated = False
for name, param in self.value_model.named_parameters():
if not torch.allclose(initial_critic_weights[name], param.to("cpu")):
weights_updated = True
critic_weights_updated = False
for name, param in trainer.policy_and_value.value_model.named_parameters():
if name in initial_critic_weights and not torch.allclose(
initial_critic_weights[name], param.to("cpu")
):
critic_weights_updated = True
break

self.assertTrue(weights_updated, "Critic weights were not updated during training")
# Check if policy weights have been updated - for PEFT we check the LoRA weights
policy_weights_updated = False
for name, param in trainer.policy_and_value.policy.named_parameters():
if "lora" in name.lower() and param.requires_grad: # Only check LoRA weights
# New weights should be non-zero if they've been updated
if not torch.allclose(param, torch.zeros_like(param)):
policy_weights_updated = True
break

self.assertTrue(critic_weights_updated, "Critic weights were not updated during training")
self.assertTrue(policy_weights_updated, "Policy LoRA weights were not updated during training")

def test_deepspeed_config(self):
"""Test PPO training with DeepSpeed-like configuration."""
if platform.system() == "Windows":
# Skip on Windows as noted in original tests
return

with tempfile.TemporaryDirectory() as tmp_dir:
# Capture initial critic weights
# Capture initial weights
initial_critic_weights = {}
initial_policy_weights = {}
for name, param in self.value_model.named_parameters():
initial_critic_weights[name] = param.clone().detach()
for name, param in self.model.named_parameters():
initial_policy_weights[name] = param.clone().detach()

# Configure training args similar to deepspeed example
training_args = PPOConfig(
@@ -310,21 +335,32 @@ def test_deepspeed_config(self):
trainer.train()

# Check if critic weights have been updated
weights_updated = False
for name, param in self.value_model.named_parameters():
critic_weights_updated = False
for name, param in trainer.policy_and_value.value_model.named_parameters():
if not torch.allclose(initial_critic_weights[name], param.to("cpu")):
weights_updated = True
critic_weights_updated = True
break

self.assertTrue(weights_updated, "Critic weights were not updated during training")
# Check if policy weights have been updated
policy_weights_updated = False
for name, param in trainer.policy_and_value.policy.named_parameters():
if not torch.allclose(initial_policy_weights[name], param.to("cpu")):
policy_weights_updated = True
break

self.assertTrue(critic_weights_updated, "Critic weights were not updated during training")
self.assertTrue(policy_weights_updated, "Policy weights were not updated during training")

def test_with_num_train_epochs(self):
"""Test PPO training with num_train_epochs configuration."""
with tempfile.TemporaryDirectory() as tmp_dir:
# Capture initial critic weights
# Capture initial weights
initial_critic_weights = {}
initial_policy_weights = {}
for name, param in self.value_model.named_parameters():
initial_critic_weights[name] = param.clone().detach()
for name, param in self.model.named_parameters():
initial_policy_weights[name] = param.clone().detach()

# Configure training args
training_args = PPOConfig(
@@ -354,10 +390,18 @@ def test_with_num_train_epochs(self):
trainer.train()

# Check if critic weights have been updated
weights_updated = False
for name, param in self.value_model.named_parameters():
critic_weights_updated = False
for name, param in trainer.policy_and_value.value_model.named_parameters():
if not torch.allclose(initial_critic_weights[name], param.to("cpu")):
weights_updated = True
critic_weights_updated = True
break

# Check if policy weights have been updated
policy_weights_updated = False
for name, param in trainer.policy_and_value.policy.named_parameters():
if not torch.allclose(initial_policy_weights[name], param.to("cpu")):
policy_weights_updated = True
break

self.assertTrue(weights_updated, "Critic weights were not updated during training")
self.assertTrue(critic_weights_updated, "Critic weights were not updated during training")
self.assertTrue(policy_weights_updated, "Policy weights were not updated during training")
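
Each test in this diff repeats the same snapshot-and-compare pattern: clone the parameters before training, then check with torch.allclose that at least one of them changed afterwards. As a purely illustrative sketch (not part of this PR; the helper names are hypothetical), that pattern could be factored into two small functions:

import torch

def snapshot_params(model):
    # Clone and detach every parameter so later optimizer steps cannot mutate the copies.
    return {name: p.clone().detach().cpu() for name, p in model.named_parameters()}

def any_param_updated(model, initial):
    # True once at least one parameter differs from its pre-training snapshot.
    for name, p in model.named_parameters():
        if name in initial and not torch.allclose(initial[name], p.detach().cpu()):
            return True
    return False

With helpers like these, each test would reduce to taking snapshots of the policy and value model before trainer.train() and asserting any_param_updated(...) for both afterwards.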