Curriculum learning #1307

Merged: 5 commits, Aug 16, 2021
Changes from 1 commit
add warning, fix config name
conglongli committed Aug 16, 2021
commit 1d3acab161af787eb0a65beb472ace0181664a03
23 changes: 16 additions & 7 deletions deepspeed/runtime/data_pipeline/curriculum_scheduler.py
@@ -2,6 +2,7 @@
Copyright 2021 The Microsoft DeepSpeed Team
'''
import math
+from deepspeed.utils import logger


class CurriculumScheduler(object):
@@ -43,7 +44,7 @@ def __init__(self, config):
elif config['schedule_type'] == 'fixed_root':
"""
The schedule_config includes:
-total_step: how many steps the curriculum learning takes to go
+total_curriculum_step: how many steps the curriculum learning takes to go
from min difficulty to max difficulty.
difficulty_step: the difficulty level determined every time must
be a multiple of this difficulty_step. This is used to determine
@@ -54,26 +55,34 @@
square root and degree of 3 means cube root. Degree of 1 is
equivalent to linear.
"schedule_config": {
-"total_step": 30000,
+"total_curriculum_step": 30000,
"difficulty_step": 8,
"root_degree": 2
}
"""
assert "total_step" in config['schedule_config'], "Curriculum learning with fixed_root schedule requires the schedule_config 'total_step'"
assert "total_curriculum_step" in config['schedule_config'], "Curriculum learning with fixed_root schedule requires the schedule_config 'total_curriculum_step'"
assert "difficulty_step" in config['schedule_config'], "Curriculum learning with fixed_root schedule requires the schedule_config 'difficulty_step'"
assert "root_degree" in config['schedule_config'], "Curriculum learning with fixed_root schedule requires the schedule_config 'root_degree'"
+if config['schedule_config']['difficulty_step'] % 8 != 0:
+    logger.warning(
+        f'The difficulty_step for curriculum learning has to be multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. Disregard this warning if this is unrelated to your hardware.'
+    )
self.state['schedule'] = config['schedule_config']
elif config['schedule_type'] == 'fixed_linear':
"""
The schedule_config is the same as 'fixed_root' but without the
root_degree.
"schedule_config": {
-"total_step": 30000,
+"total_curriculum_step": 30000,
"difficulty_step": 8
}
"""
assert "total_step" in config['schedule_config'], "Curriculum learning with fixed_linear schedule requires the schedule_config 'total_step'"
assert "total_curriculum_step" in config['schedule_config'], "Curriculum learning with fixed_linear schedule requires the schedule_config 'total_curriculum_step'"
assert "difficulty_step" in config['schedule_config'], "Curriculum learning with fixed_linear schedule requires the schedule_config 'difficulty_step'"
+if config['schedule_config']['difficulty_step'] % 8 != 0:
+    logger.warning(
+        f'The difficulty_step for curriculum learning has to be multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. Disregard this warning if this is unrelated to your hardware.'
+    )
self.state['schedule'] = config['schedule_config']
else:
raise RuntimeError('Unsupported curriculum schedule type')
@@ -100,8 +109,8 @@ def __fixed_root_update_difficulty(self, global_steps, root_degree=None):
s_state = self.state['schedule']
if root_degree is None:
root_degree = s_state['root_degree']
-next_difficulty = (float(global_steps) / s_state['total_step'])**(1.0 /
-    root_degree)
+next_difficulty = (float(global_steps) /
+    s_state['total_curriculum_step'])**(1.0 / root_degree)
next_difficulty = math.floor(
next_difficulty *
(self.state['max_difficulty'] - self.state['min_difficulty']) +
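
For readers tracing the schedule math in this hunk, here is a minimal standalone sketch. The names are illustrative, and the final snap to a multiple of `difficulty_step` and cap at `max_difficulty` are assumptions based on the docstring, since the tail of the hunk is collapsed above:

```python
import math

def fixed_root_difficulty(global_steps, total_curriculum_step, min_difficulty,
                          max_difficulty, difficulty_step, root_degree):
    # Fraction of the curriculum completed, raised to the 1/root_degree power;
    # root_degree=1 reproduces the fixed_linear schedule.
    next_difficulty = (float(global_steps) /
                       total_curriculum_step)**(1.0 / root_degree)
    next_difficulty = math.floor(
        next_difficulty * (max_difficulty - min_difficulty) + min_difficulty)
    # Assumed: round down to a multiple of difficulty_step, then cap at max_difficulty.
    next_difficulty -= next_difficulty % difficulty_step
    return min(next_difficulty, max_difficulty)

# With the fixed_root config from the docstring above (min/max difficulty assumed to be 8/1024):
print(fixed_root_difficulty(6000, 30000, 8, 1024, 8, 2))  # -> 456
```
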
4 changes: 2 additions & 2 deletions docs/_pages/config-json.md
@@ -726,7 +726,7 @@ Configuring the asynchronous I/O module for offloading parameter and optimizer s
"max_difficulty": 1024,
"schedule_type": "fixed_linear",
"schedule_config": {
-"total_step": 40000,
+"total_curriculum_step": 40000,
"difficulty_step": 8
}
}
@@ -763,7 +763,7 @@ Configuring the asynchronous I/O module for offloading parameter and optimizer s
| Type of curriculum schedule. Currently support `fixed_linear`, `fixed_root`, and `fixed_discrete`. | N/A |


-<i>**total_step**</i>: [integer]
+<i>**total_curriculum_step**</i>: [integer]

| Description | Default |
| --------------------------------------------------------------- | ------- |
12 changes: 6 additions & 6 deletions docs/_tutorials/curriculum-learning.md
@@ -37,7 +37,7 @@ Curriculum learning can be used by setting the DeepSpeed configuration as the fo
"max_difficulty": 1024,
"schedule_type": "fixed_linear",
"schedule_config": {
-"total_step": 15000,
+"total_curriculum_step": 15000,
"difficulty_step": 8
}
}
@@ -60,12 +60,12 @@ For `fixed_linear` schedule there are two configurations:
```json
"schedule_type": "fixed_linear",
"schedule_config": {
-"total_step": 15000,
+"total_curriculum_step": 15000,
"difficulty_step": 8
}
```

-The `total_step` is the total number of steps for the curriculum learning. For `fixed_linear` schedule the difficulty level will linearly increase from `min_difficulty` to `max_difficulty` during the `total_step` duration. This configuration needs to be tuned for each training task. We observe that too small and too large `total_step` are both suboptimal: with too small `total_step` curriculum learning might not be able to provide enough training stability benefit so the training might still diverge; with too large `total_step` the model may overfit too much during curriculum learning on the easier/simpler training data thus hurt the overall convergence. We recommend to first set `total_step` as 20% to 40% of the total training steps (note that if you increase the batch size for the curriculum learning-based training, you also need to reduce the total training steps correspondingly), then increase the `total_step` if the training is not stable, or reduce the `total_step` to test if convergence improves.
+The `total_curriculum_step` is the total number of steps for the curriculum learning. For `fixed_linear` schedule the difficulty level will linearly increase from `min_difficulty` to `max_difficulty` during the `total_curriculum_step` duration. This configuration needs to be tuned for each training task. We observe that too small and too large `total_curriculum_step` are both suboptimal: with too small `total_curriculum_step` curriculum learning might not be able to provide enough training stability benefit so the training might still diverge; with too large `total_curriculum_step` the model may overfit too much during curriculum learning on the easier/simpler training data thus hurt the overall convergence. We recommend to first set `total_curriculum_step` as 20% to 40% of the total training steps (note that if you increase the batch size for the curriculum learning-based training, you also need to reduce the total training steps correspondingly), then increase the `total_curriculum_step` if the training is not stable, or reduce the `total_curriculum_step` to test if convergence improves.

The `difficulty_step` configuration ensures that at anytime the difficulty level must be multiple of `difficulty_step`. We usually set it as 8 (for FP16 data) or 16 (for INT8 data) to enable [NVIDIA GPU's Tensor Core acceleration](https://developer.nvidia.com/blog/optimizing-gpu-performance-tensor-cores/). If this is unrelated to your training experiment, you can set it as 1.
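
As a quick, purely illustrative reading of the 20% to 40% guideline above (the step counts are hypothetical, not taken from this PR):

```python
# Hypothetical curriculum-learning run with 50,000 total training steps.
total_training_steps = 50000

# Start total_curriculum_step inside the recommended window, then tune it:
# increase it if training is unstable, decrease it to test whether convergence improves.
total_curriculum_step_min = int(0.2 * total_training_steps)  # 10000
total_curriculum_step_max = int(0.4 * total_training_steps)  # 20000
```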

@@ -75,13 +75,13 @@ For `fixed_root` schedule there are three configurations:
```json
"schedule_type": "fixed_root",
"schedule_config": {
-"total_step": 15000,
+"total_curriculum_step": 15000,
"difficulty_step": 8,
"root_degree": 2
}
```

-The `total_step` and `difficulty_step` have the same meaning as in the `fixed_linear` schedule case. The `root_degree` determines the root degree of the root function of the schedule. The difficulty level at certain step is determined as ((current step/`total_step`)**(1/`root_degree`)) * (`max_difficulty` - `min_difficulty`) + `min_difficulty`. Thus `fixed_linear` is basically a special case of `fixed_root` with `root_degree` as 1. In our (limited) study, we find the `fixed_root` schedule does not provide any clear advantage over `fixed_linear` schedule, while requiring one additional parameter.
+The `total_curriculum_step` and `difficulty_step` have the same meaning as in the `fixed_linear` schedule case. The `root_degree` determines the root degree of the root function of the schedule. The difficulty level at certain step is determined as ((current step/`total_curriculum_step`)**(1/`root_degree`)) * (`max_difficulty` - `min_difficulty`) + `min_difficulty`. Thus `fixed_linear` is basically a special case of `fixed_root` with `root_degree` as 1. In our (limited) study, we find the `fixed_root` schedule does not provide any clear advantage over `fixed_linear` schedule, while requiring one additional parameter.
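
A quick worked instance of this formula, assuming `min_difficulty` of 8 and `max_difficulty` of 1024 as in the earlier examples:

```python
# fixed_root with root_degree=2, at step 3750 of a 15000-step curriculum:
# (3750/15000)**(1/2) = 0.5, so 0.5 * (1024 - 8) + 8 = 516.
difficulty = (3750 / 15000) ** (1 / 2) * (1024 - 8) + 8  # 516.0
# Snapping down to a multiple of difficulty_step=8 gives 512.
difficulty = int(difficulty) - int(difficulty) % 8       # 512
```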

### 1.3 fixed_discrete schedule
For `fixed_discrete` schedule there are two configurations:
@@ -104,4 +104,4 @@ Besides the additional DeepSpeed configurations, there are some other necessary

Second, since there will be less tokens per step during curriculum learning, for curriculum-learning based training it requires more steps in order to reach the same number of training tokens as baseline. Thus in Megatron-LM we add a `--train-tokens` argument to terminate the training based on number of tokens. Then we usually set a long enough `--train-iters` (e.g., two times of baseline's total training step), and set the `--train-tokens` the same for baseline and curriculum-learning based training.

-Third, again due to the less tokens per step during curriculum learning, we find that for curriculum-learning based training it is beneficial to increase the learning rate decay steps (otherwise the curriculum learning case will have faster token-wise learning rate decay than baseline). For `fixed_linear` schedule because we start from very short sequence length, the total number of tokens during the curriculum learning is roughly halved. Thus we usually just add half of `fixed_linear` schedule's `total_step` to the Megatron-LM's `--lr-decay-iters`.
+Third, again due to the less tokens per step during curriculum learning, we find that for curriculum-learning based training it is beneficial to increase the learning rate decay steps (otherwise the curriculum learning case will have faster token-wise learning rate decay than baseline). For `fixed_linear` schedule because we start from very short sequence length, the total number of tokens during the curriculum learning is roughly halved. Thus we usually just add half of `fixed_linear` schedule's `total_curriculum_step` to the Megatron-LM's `--lr-decay-iters`.
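
A small sketch of this bookkeeping under hypothetical numbers (the baseline values below are placeholders, not taken from this PR):

```python
# Hypothetical baseline Megatron-LM settings.
baseline_train_iters = 300000      # baseline --train-iters
baseline_lr_decay_iters = 280000   # baseline --lr-decay-iters
total_curriculum_step = 15000      # fixed_linear curriculum length

# Curriculum run: keep --train-tokens equal to the baseline's token budget,
# set --train-iters generously so the token target (not the step count) ends training,
# and extend LR decay by roughly half the curriculum length.
curriculum_train_iters = 2 * baseline_train_iters                                 # --train-iters
curriculum_lr_decay_iters = baseline_lr_decay_iters + total_curriculum_step // 2  # 287500, --lr-decay-iters
```
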
2 changes: 1 addition & 1 deletion tests/unit/test_curriculum_learning.py
@@ -100,7 +100,7 @@ def test_curriculum_scheduler_fixed_linear(tmpdir):
"max_difficulty": 10,
"schedule_type": "fixed_linear",
"schedule_config": {
-"total_step": 8,
+"total_curriculum_step": 8,
"difficulty_step": 2
}
}