Skip to content

Commit

Permalink
Support num_train_epochs (#1743)
Browse files Browse the repository at this point in the history
* add a test case for num_train_epochs

* fix ci

* quick change

* disable push to hub

* debug windows ci

* try another fix

* skip subprocess tests on windows
  • Loading branch information
vwxyzjn authored Jun 20, 2024
1 parent 3bf9449 commit 34d273f
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 10 deletions.
3 changes: 2 additions & 1 deletion examples/scripts/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,6 @@ def tokenize(element):
)
trainer.train()
trainer.save_model(config.output_dir)
trainer.push_to_hub()
if config.push_to_hub:
trainer.push_to_hub()
trainer.generate_completions()
3 changes: 2 additions & 1 deletion examples/scripts/ppo/ppo_tldr.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,6 @@ def tokenize(element):
)
trainer.train()
trainer.save_model(config.output_dir)
trainer.push_to_hub()
if config.push_to_hub:
trainer.push_to_hub()
trainer.generate_completions()
3 changes: 2 additions & 1 deletion examples/scripts/rloo/rloo.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,6 @@ def tokenize(element):
)
trainer.train()
trainer.save_model(config.output_dir)
trainer.push_to_hub()
if config.push_to_hub:
trainer.push_to_hub()
trainer.generate_completions()
3 changes: 2 additions & 1 deletion examples/scripts/rloo/rloo_tldr.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,6 @@ def tokenize(element):
)
trainer.train()
trainer.save_model(config.output_dir)
trainer.push_to_hub()
if config.push_to_hub:
trainer.push_to_hub()
trainer.generate_completions()
32 changes: 30 additions & 2 deletions tests/test_ppov2_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,49 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import subprocess


def test():
command = """\
python -i examples/scripts/ppo/ppo.py \
python examples/scripts/ppo/ppo.py \
--learning_rate 3e-6 \
--output_dir models/minimal/ppo \
--per_device_train_batch_size 5 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 1 \
--total_episodes 10 \
--model_name_or_path EleutherAI/pythia-14m \
--non_eos_penalty \
--stop_token eos \
"""
if platform.system() == "Windows":
# windows CI does not work with subprocesses for some reason
# e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743
return
subprocess.run(
command,
shell=True,
check=True,
)


def test_num_train_epochs():
command = """\
python examples/scripts/ppo/ppo.py \
--learning_rate 3e-6 \
--output_dir models/minimal/ppo \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 1 \
--num_train_epochs 0.003 \
--model_name_or_path EleutherAI/pythia-14m \
--non_eos_penalty \
--stop_token eos \
"""
if platform.system() == "Windows":
# windows CI does not work with subprocesses for some reason
# e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743
return
subprocess.run(
command,
shell=True,
Expand Down
9 changes: 7 additions & 2 deletions tests/test_rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import subprocess

import torch


def test():
command = """\
python -i examples/scripts/rloo/rloo.py \
python examples/scripts/rloo/rloo.py \
--learning_rate 3e-6 \
--output_dir models/minimal/rloo \
--per_device_train_batch_size 5 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 1 \
--total_episodes 10 \
--model_name_or_path EleutherAI/pythia-14m \
--non_eos_penalty \
--stop_token eos \
"""
if platform.system() == "Windows":
# windows CI does not work with subprocesses for some reason
# e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743
return
subprocess.run(
command,
shell=True,
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/ppov2_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(
# calculate various batch sizes
#########
if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
args.total_episodes = args.num_train_epochs * self.train_dataset_len
args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
self.accelerator = accelerator
args.world_size = accelerator.num_processes
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
# calculate various batch sizes
#########
if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
args.total_episodes = args.num_train_epochs * self.train_dataset_len
args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
self.accelerator = accelerator
args.world_size = accelerator.num_processes
Expand Down

0 comments on commit 34d273f

Please sign in to comment.