Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support more argument settings for scheduler #6435

Merged
merged 1 commit into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,6 +1241,9 @@ def create_scheduler(self, num_training_steps: int):
learning_rate=self.args.learning_rate,
num_warmup_steps=warmup,
num_training_steps=num_training_steps,
num_cycles=self.args.num_cycles,
lr_end=self.args.lr_end,
power=self.args.power,
)

return self.lr_scheduler
Expand Down
29 changes: 29 additions & 0 deletions paddlenlp/trainer/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,9 @@
learning_rate: float,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
num_cycles: Optional[float] = 0.5,
lr_end: Optional[float] = 1e-7,
power: Optional[float] = 1.0,
):
"""
Unified API to get any scheduler from its name.
Expand All @@ -408,6 +411,15 @@
num_training_steps (`int``, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_cycles (``float``, *optional*):
The number of waves in the cosine scheduler (the defaults is to just decrease from the max value to 0
following a half-cosine). This is not required by all schedulers (hence the argument being optional)
lr_end (``float``, *optional*):
The end LR in the polynomial scheduler. This is not required by all schedulers (hence the argument
being optional).
power (``float``, *optional*):
The power factor in the polynomial scheduler. This is not required by all schedulers (hence the argument
being optional).
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
Expand All @@ -425,6 +437,23 @@
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

if name == SchedulerType.COSINE:
return schedule_func(

Check warning on line 441 in paddlenlp/trainer/trainer_utils.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer_utils.py#L441

Added line #L441 was not covered by tests
learning_rate,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_cycles=num_cycles,
)

if name == SchedulerType.POLYNOMIAL:
return schedule_func(

Check warning on line 449 in paddlenlp/trainer/trainer_utils.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer_utils.py#L449

Added line #L449 was not covered by tests
learning_rate,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
lr_end=lr_end,
power=power,
)

return schedule_func(learning_rate, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)


Expand Down
10 changes: 10 additions & 0 deletions paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,13 @@ class TrainingArguments:
Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
warmup_steps (`int`, *optional*, defaults to 0):
Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`.
num_cycles (`float`, *optional*, defaults to 0.5):
The number of waves in the cosine scheduler.
lr_end (`float`, *optional*, defaults to 1e-7):
The end LR used in the polynomial scheduler.
power (`float`, *optional*, defaults to 1.0):
The power factor used in the polynomial scheduler.

log_on_each_node (`bool`, *optional*, defaults to `True`):
In multinode distributed training, whether to log using `log_level` once per node, or only on the main
node.
Expand Down Expand Up @@ -363,6 +370,9 @@ class TrainingArguments:
default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}
)
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
num_cycles: float = field(default=0.5, metadata={"help": "The number of waves in the cosine scheduler."})
lr_end: float = field(default=1e-7, metadata={"help": "The end LR in the polynomial scheduler."})
power: float = field(default=1.0, metadata={"help": "The power factor in the polynomial scheduler."})

log_on_each_node: bool = field(
default=True,
Expand Down