Curriculum learning (#1307)
Co-authored-by: Conglong Li <conglong.li@gmail.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
conglongli and jeffra authored Aug 16, 2021
1 parent 504893a commit b2b34ae
Showing 14 changed files with 530 additions and 2 deletions.
8 changes: 7 additions & 1 deletion README.md
@@ -33,6 +33,7 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale)


# News
* [2021/08/16] [Curriculum learning: a regularization method for stable and 2.6x faster GPT-2 pre-training with 8x/4x larger batch size/learning rate](https://www.deepspeed.ai/tutorials/curriculum-learning/)
* [2021/05/24] [DeepSpeed: Accelerating large-scale model inference and training via system optimizations and compression](https://www.microsoft.com/en-us/research/blog/deepspeed-accelerating-large-scale-model-inference-and-training-via-system-optimizations-and-compression/)
* [2021/04/20] [1-bit LAMB: up to 4.6x less communication and 2.8x faster training, together with LAMB's convergence speed at large batch sizes](https://www.deepspeed.ai/tutorials/onebit-lamb/)
* [2021/04/19] [ZeRO-Infinity unlocks unprecedented model scale for deep learning training](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/)
@@ -148,6 +149,10 @@ overview](https://www.deepspeed.ai/features/) for descriptions and usage.
* Learning Rate Range Test
* 1Cycle Learning Rate Schedule
* [Simplified Data Loader](https://www.deepspeed.ai/features/#simplified-data-loader)
* [Curriculum Learning](https://www.deepspeed.ai/tutorials/curriculum-learning/)
* A curriculum learning-based data pipeline that presents easier or simpler examples earlier during training
* Stable and 2.6x faster GPT-2 pre-training with 8x/4x larger batch size/learning rate while maintaining token-wise convergence speed
* Complementary to many other DeepSpeed features
* [Performance Analysis and Debugging](https://www.deepspeed.ai/features/#performance-analysis-and-debugging)


@@ -198,9 +203,10 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information
2. Jeff Rasley, Samyam Rajbhandari, Olatunji Ruwase, and Yuxiong He. (2020) DeepSpeed: System Optimizations Enable Training Deep Learning Models with Over 100 Billion Parameters. [In Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining (KDD '20, Tutorial)](https://dl.acm.org/doi/10.1145/3394486.3406703).
3. Minjia Zhang, Yuxiong He. (2020) Accelerating Training of Transformer-Based Language Models with Progressive Layer Dropping. [arXiv:2010.13369](https://arxiv.org/abs/2010.13369) and [NeurIPS 2020](https://proceedings.neurips.cc/paper/2020/hash/a1140a3d0df1c81e24ae954d935e8926-Abstract.html).
4. Jie Ren, Samyam Rajbhandari, Reza Yazdani Aminabadi, Olatunji Ruwase, Shuangyan Yang, Minjia Zhang, Dong Li, Yuxiong He. (2021) ZeRO-Offload: Democratizing Billion-Scale Model Training. [arXiv:2101.06840](https://arxiv.org/abs/2101.06840).
5. Hanlin Tang, Shaoduo Gan, Ammar Ahmad Awan, Samyam Rajbhandari, Conglong Li, Xiangru Lian, Ji Liu, Ce Zhang, Yuxiong He. (2021) 1-bit Adam: Communication Efficient Large-Scale Training with Adam's Convergence Speed. [arXiv:2102.02888](https://arxiv.org/abs/2102.02888) and [ICML 2021](http://proceedings.mlr.press/v139/tang21a.html).
6. Samyam Rajbhandari, Olatunji Ruwase, Jeff Rasley, Shaden Smith, Yuxiong He. (2021) ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning. [arXiv:2104.07857](https://arxiv.org/abs/2104.07857).
7. Conglong Li, Ammar Ahmad Awan, Hanlin Tang, Samyam Rajbhandari, Yuxiong He. (2021) 1-bit LAMB: Communication Efficient Large-Scale Large-Batch Training with LAMB's Convergence Speed. [arXiv:2104.06069](https://arxiv.org/abs/2104.06069).
8. Conglong Li, Minjia Zhang, Yuxiong He. (2021) Curriculum Learning: A Regularization Method for Efficient and Stable Billion-Scale GPT Model Pre-Training. [arXiv:2108.06084](https://arxiv.org/abs/2108.06084).

# Videos
1. DeepSpeed KDD 2020 Tutorial
21 changes: 21 additions & 0 deletions deepspeed/runtime/config.py
@@ -55,6 +55,24 @@ class DeepSpeedConfigError(Exception):
pass


def get_curriculum_enabled(param_dict):
if CURRICULUM_LEARNING in param_dict.keys():
return get_scalar_param(param_dict[CURRICULUM_LEARNING],
CURRICULUM_ENABLED,
CURRICULUM_ENABLED_DEFAULT)
else:
return False


def get_curriculum_params(param_dict):
if CURRICULUM_LEARNING in param_dict.keys():
curriculum_params = copy.copy(param_dict[CURRICULUM_LEARNING])
curriculum_params.pop(CURRICULUM_ENABLED)
return curriculum_params
else:
return False


def get_pld_enabled(param_dict):
if PROGRESSIVE_LAYER_DROP in param_dict.keys():
return get_scalar_param(param_dict[PROGRESSIVE_LAYER_DROP],
@@ -812,6 +830,9 @@ def _initialize_params(self, param_dict):
self.pld_enabled = get_pld_enabled(param_dict)
self.pld_params = get_pld_params(param_dict)

self.curriculum_enabled = get_curriculum_enabled(param_dict)
self.curriculum_params = get_curriculum_params(param_dict)

checkpoint_params = get_checkpoint_params(param_dict)
validation_mode = get_checkpoint_tag_validation_mode(checkpoint_params)
self.checkpoint_tag_validation_enabled = validation_mode != ValidationMode.IGNORE
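Since the `get_curriculum_*` helpers added in this file only read and trim the user's JSON, a brief hedged illustration of their expected behavior follows. The example dictionary and print statements are assumptions added for this write-up, not part of the commit.

```python
# Hedged example: what the new config helpers should return for a config dict
# that contains a "curriculum_learning" block (and for one that does not).
from deepspeed.runtime.config import get_curriculum_enabled, get_curriculum_params

param_dict = {
    "train_batch_size": 16,
    "curriculum_learning": {
        "enabled": True,
        "curriculum_type": "seqlen",
        "min_difficulty": 8,
        "max_difficulty": 1024,
        "schedule_type": "fixed_linear",
        "schedule_config": {"total_curriculum_step": 30000, "difficulty_step": 8}
    }
}

print(get_curriculum_enabled(param_dict))  # True
print(get_curriculum_params(param_dict))   # the block above with "enabled" popped
print(get_curriculum_enabled({}))          # False when no curriculum block is given
```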
8 changes: 8 additions & 0 deletions deepspeed/runtime/constants.py
@@ -344,6 +344,14 @@
PLD_GAMMA = "gamma"
PLD_GAMMA_DEFAULT = 0.001

#########################################
# Curriculum Learning
#########################################
CURRICULUM_LEARNING = "curriculum_learning"

CURRICULUM_ENABLED = "enabled"
CURRICULUM_ENABLED_DEFAULT = False


#########################################
# Validation modes
Empty file.
133 changes: 133 additions & 0 deletions deepspeed/runtime/data_pipeline/curriculum_scheduler.py
@@ -0,0 +1,133 @@
'''
Copyright 2021 The Microsoft DeepSpeed Team
'''
import math
from deepspeed.utils import logger


class CurriculumScheduler(object):
def __init__(self, config):
super().__init__()
self.state = {}
assert "curriculum_type" in config, "Curriculum learning requires the config 'curriculum_type'"
assert "min_difficulty" in config, "Curriculum learning requires the config 'min_difficulty'"
assert "max_difficulty" in config, "Curriculum learning requires the config 'max_difficulty'"
assert "schedule_type" in config, "Curriculum learning requires the config 'schedule_type'"
self.state['min_difficulty'] = config['min_difficulty']
self.state['max_difficulty'] = config['max_difficulty']
self.state['current_difficulty'] = config['min_difficulty']
self.state['schedule_type'] = config['schedule_type']
if config['schedule_type'] == 'fixed_discrete':
"""
The schedule_config is a list of difficulty and a list of max
step belonging to each difficulty. Example json config:
"schedule_config": {
"difficulty": [1,2,3],
"max_step": [5,10]
}
The "max_step" has one less element than "difficulty", because
the last difficulty will be used for all following steps.
The self.state['schedule'] is a dictionary of
difficulty : [max step for this difficulty, next difficulty].
"""
assert "difficulty" in config['schedule_config'], "Curriculum learning with fixed_discrete schedule requires the schedule_config 'difficulty'"
assert "max_step" in config['schedule_config'], "Curriculum learning with fixed_discrete schedule requires the schedule_config 'max_step'"
assert len(config['schedule_config']['max_step']) > 0
assert len(config['schedule_config']['difficulty']) > 0
assert len(config['schedule_config']['difficulty']) == len(
config['schedule_config']['max_step']) + 1
self.state['schedule'] = {}
for i in range(len(config['schedule_config']['max_step'])):
self.state['schedule'][config['schedule_config']['difficulty'][i]] = \
[config['schedule_config']['max_step'][i],
config['schedule_config']['difficulty'][i+1]]
elif config['schedule_type'] == 'fixed_root':
"""
The schedule_config includes:
total_curriculum_step: how many steps the curriculum learning takes to go
from min difficulty to max difficulty.
difficulty_step: every difficulty level selected by the schedule must
be a multiple of this difficulty_step. This controls the granularity
of difficulty increases and helps ensure that NVIDIA Tensor Core
acceleration can be used (it requires multiples of 8 for FP16 data
or 16 for INT8 data).
root_degree: the degree of the root function. Degree of 2 means
square root and degree of 3 means cube root. Degree of 1 is
equivalent to linear.
"schedule_config": {
"total_curriculum_step": 30000,
"difficulty_step": 8,
"root_degree": 2
}
"""
assert "total_curriculum_step" in config['schedule_config'], "Curriculum learning with fixed_root schedule requires the schedule_config 'total_curriculum_step'"
assert "difficulty_step" in config['schedule_config'], "Curriculum learning with fixed_root schedule requires the schedule_config 'difficulty_step'"
assert "root_degree" in config['schedule_config'], "Curriculum learning with fixed_root schedule requires the schedule_config 'root_degree'"
if config['schedule_config']['difficulty_step'] % 8 != 0:
logger.warning(
f'The difficulty_step for curriculum learning has to be multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. Disregard this warning if this is unrelated to your hardware.'
)
self.state['schedule'] = config['schedule_config']
elif config['schedule_type'] == 'fixed_linear':
"""
The schedule_config is the same as 'fixed_root' but without the
root_degree.
"schedule_config": {
"total_curriculum_step": 30000,
"difficulty_step": 8
}
"""
assert "total_curriculum_step" in config['schedule_config'], "Curriculum learning with fixed_linear schedule requires the schedule_config 'total_curriculum_step'"
assert "difficulty_step" in config['schedule_config'], "Curriculum learning with fixed_linear schedule requires the schedule_config 'difficulty_step'"
if config['schedule_config']['difficulty_step'] % 8 != 0:
logger.warning(
f'The difficulty_step for curriculum learning has to be multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. Disregard this warning if this is unrelated to your hardware.'
)
self.state['schedule'] = config['schedule_config']
else:
raise RuntimeError('Unsupported curriculum schedule type')

def get_current_difficulty(self):
return self.state['current_difficulty']

def set_current_difficulty(self, difficulty):
self.state['current_difficulty'] = difficulty

def get_state(self):
return self.state

def set_state(self, state):
self.state = state

def __fixed_discrete_update_difficulty(self, global_steps):
s_state = self.state['schedule'][self.state['current_difficulty']]
if global_steps > s_state[0]:
self.state['current_difficulty'] = s_state[1]
return self.state['current_difficulty']

def __fixed_root_update_difficulty(self, global_steps, root_degree=None):
s_state = self.state['schedule']
if root_degree is None:
root_degree = s_state['root_degree']
next_difficulty = (float(global_steps) /
s_state['total_curriculum_step'])**(1.0 / root_degree)
next_difficulty = math.floor(
next_difficulty *
(self.state['max_difficulty'] - self.state['min_difficulty']) +
self.state['min_difficulty'])
next_difficulty -= (next_difficulty % s_state['difficulty_step'])
self.state['current_difficulty'] = min(next_difficulty,
self.state['max_difficulty'])
return self.state['current_difficulty']

def update_difficulty(self, global_steps):
if self.state['current_difficulty'] >= self.state['max_difficulty']:
return self.state['current_difficulty']
if self.state['schedule_type'] == 'fixed_discrete':
return self.__fixed_discrete_update_difficulty(global_steps)
elif self.state['schedule_type'] == 'fixed_linear':
return self.__fixed_root_update_difficulty(global_steps, 1)
elif self.state['schedule_type'] == 'fixed_root':
return self.__fixed_root_update_difficulty(global_steps)
else:
raise RuntimeError('Unsupported curriculum schedule type')
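For readers who want to see the schedule in isolation, here is a minimal sketch (not part of the commit's tests) that drives the new `CurriculumScheduler` directly with the `fixed_linear` settings from the docstring above; the chosen step values and printed progression are illustrative assumptions.

```python
# Minimal sketch: stepping the scheduler by hand to watch the seqlen curriculum grow.
from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler

config = {
    "curriculum_type": "seqlen",
    "min_difficulty": 8,
    "max_difficulty": 1024,
    "schedule_type": "fixed_linear",
    "schedule_config": {
        "total_curriculum_step": 30000,
        "difficulty_step": 8
    }
}

scheduler = CurriculumScheduler(config)
for step in (1, 1500, 15000, 30000):
    # update_difficulty() linearly interpolates between min and max difficulty,
    # rounds down to a multiple of difficulty_step, and clips at max_difficulty.
    print(step, scheduler.update_difficulty(step))
# Expected difficulties: 8, 56, 512, 1024
```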
22 changes: 22 additions & 0 deletions deepspeed/runtime/engine.py
@@ -41,6 +41,7 @@
from deepspeed.utils.debug import debug_extract_module_and_param_names
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
from deepspeed.runtime.eigenvalue import Eigenvalue
from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler

from .pipe.module import PipelineModule
from .utils import ensure_directory_exists
@@ -214,6 +215,9 @@ def __init__(self,
if self.pld_enabled():
self.progressive_layer_drop = self._configure_progressive_layer_drop()

if self.curriculum_enabled():
self.curriculum_scheduler = self._configure_curriculum_scheduler()

if self.global_rank == 0:
self._config.print('DeepSpeedEngine configuration')
if self.dump_state():
@@ -282,6 +286,12 @@ def eigenvalue_layer_name(self):
def eigenvalue_layer_num(self):
return self._config.eigenvalue_layer_num

def curriculum_enabled(self):
return self._config.curriculum_enabled

def curriculum_params(self):
return self._config.curriculum_params

def tensorboard_enabled(self):
return self._config.tensorboard_enabled

@@ -998,6 +1008,10 @@ def _configure_progressive_layer_drop(self):

return pld

def _configure_curriculum_scheduler(self):
scheduler = CurriculumScheduler(self.curriculum_params())
return scheduler

@staticmethod
def is_map_style_dataset(obj):
return hasattr(obj, "__getitem__") and hasattr(obj, "__len__")
@@ -1101,6 +1115,14 @@ def forward(self, *inputs, **kwargs):
if self.module.training and self.progressive_layer_drop:
kwargs.update(self.progressive_layer_drop.get_state())

if self.module.training and self.curriculum_enabled():
self.curriculum_scheduler.update_difficulty(self.global_steps + 1)
if self.curriculum_params()["curriculum_type"] == "seqlen":
kwargs.update({
"curriculum_seqlen":
self.curriculum_scheduler.get_current_difficulty()
})

if self.zero_optimization_partition_weights():
# Enable automated discovery of external parameters by indicating that
# we are in a forward pass.
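The `forward()` hook above only injects a `curriculum_seqlen` keyword; the model is responsible for acting on it. Below is a hedged sketch of one way a user model might consume that keyword by truncating its inputs. The toy module and shapes are assumptions for illustration, not the integration used in the GPT-2 pre-training recipe.

```python
# Hypothetical consumer of the injected curriculum_seqlen keyword argument.
import torch

class ToyLM(torch.nn.Module):
    def __init__(self, vocab_size=50304, hidden=256):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.head = torch.nn.Linear(hidden, vocab_size)

    def forward(self, input_ids, attention_mask=None, curriculum_seqlen=None):
        # When curriculum learning is active, keep only the first curriculum_seqlen
        # tokens of each sample (and the matching slice of the attention mask).
        if curriculum_seqlen is not None and curriculum_seqlen < input_ids.size(1):
            input_ids = input_ids[:, :curriculum_seqlen]
            if attention_mask is not None:
                attention_mask = attention_mask[:, :curriculum_seqlen]
        hidden_states = self.embed(input_ids)
        return self.head(hidden_states)
```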
1 change: 1 addition & 0 deletions docs/_config.yml
@@ -36,6 +36,7 @@ collections:
- bert-finetuning.md
- bert-pretraining.md
- cifar-10.md
- curriculum-learning.md
- flops-profiler.md
- gan.md
- lrrt.md
2 changes: 2 additions & 0 deletions docs/_data/navigation.yml
@@ -70,6 +70,8 @@ lnav:
url: /tutorials/bert-pretraining/
- title: "CIFAR-10"
url: /tutorials/cifar-10/
- title: "Curriculum Learning"
url: /tutorials/curriculum-learning/
- title: "Flops Profiler"
url: /tutorials/flops-profiler/
- title: "GAN"
76 changes: 76 additions & 0 deletions docs/_pages/config-json.md
@@ -716,3 +716,79 @@ Configuring the asynchronous I/O module for offloading parameter and optimizer s
"num_sliding_window_blocks": 3
}
```

### Curriculum Learning
```json
"curriculum_learning": {
"enabled": true,
"curriculum_type": "seqlen",
"min_difficulty": 8,
"max_difficulty": 1024,
"schedule_type": "fixed_linear",
"schedule_config": {
"total_curriculum_step": 40000,
"difficulty_step": 8
}
}
```
<i>**enabled**</i>: [boolean]

| Description | Default |
| ----------------------------------------- | ------- |
| Set to true to enable curriculum learning | `false` |

<i>**curriculum_type**</i>: [string]

| Description | Default |
| ----------------------------------------------------------------- | ------- |
| Type of curriculum difficulty metric. Currently supports `seqlen`. | N/A |


<i>**min_difficulty**</i>: [integer]

| Description | Default |
| ----------------------------- | ------- |
| The starting difficulty level | N/A |

<i>**max_difficulty**</i>: [integer]

| Description | Default |
| --------------------------- | ------- |
| The ending difficulty level | N/A |

<i>**schedule_type**</i>: [string]

| Description | Default |
| -------------------------------------------------------------------------------------------------- | ------- |
| Type of curriculum schedule. Currently supports `fixed_linear`, `fixed_root`, and `fixed_discrete`. | N/A |


<i>**total_curriculum_step**</i>: [integer]

| Description | Default |
| --------------------------------------------------------------- | ------- |
| Total number of steps the curriculum schedule takes to go from the minimum to the maximum difficulty. Part of `schedule_config` when the `fixed_linear` or `fixed_root` schedule_type is used. | N/A |

<i>**difficulty_step**</i>: [integer]

| Description | Default |
| --------------------------------------------------------------- | ------- |
| The difficulty selected at any step must be a multiple of this `difficulty_step`. Set it to a multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. Part of `schedule_config` when the `fixed_linear` or `fixed_root` schedule_type is used. | N/A |

<i>**root_degree**</i>: [integer]

| Description | Default |
| --------------------------------------------------------------- | ------- |
| Root degree of the curriculum schedule function. Part of `schedule_config` when the `fixed_root` schedule_type is used. | N/A |
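
As a hedged, worked example of how `total_curriculum_step`, `difficulty_step`, and `root_degree` interact (the numbers below are illustrative and mirror the JSON example at the top of this section): with `min_difficulty=8`, `max_difficulty=1024`, `total_curriculum_step=40000`, `difficulty_step=8`, and `root_degree=2`, the difficulty at step 10000 works out as follows.

```python
import math

min_d, max_d = 8, 1024
total_step, diff_step, root_degree = 40000, 8, 2

step = 10000
d = (step / total_step) ** (1.0 / root_degree)  # (0.25) ** 0.5 = 0.5
d = math.floor(d * (max_d - min_d) + min_d)     # floor(0.5 * 1016 + 8) = 516
d -= d % diff_step                              # round down to a multiple of 8
print(min(d, max_d))                            # 512
```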

<i>**difficulty**</i>: [list of integer]

| Description | Default |
| --------------------------------------------------------------- | ------- |
| List of difficulty levels to be used during the schedule. Part of `schedule_config` when the `fixed_discrete` schedule_type is used. | N/A |

<i>**max_step**</i>: [list of integer]

| Description | Default |
| --------------------------------------------------------------- | ------- |
| List of steps at which to switch to the next difficulty level. Part of `schedule_config` when the `fixed_discrete` schedule_type is used. | N/A |
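
Putting the entries above together, here is a hedged end-to-end sketch of enabling curriculum learning through the config passed to `deepspeed.initialize`. The placeholder model, optimizer section, and batch size are assumptions for illustration; only the `curriculum_learning` block follows the schema documented in this section, and the script would normally be run with the `deepspeed` launcher.

```python
import torch
import deepspeed

model = torch.nn.Linear(1024, 1024)  # placeholder model for illustration

ds_config = {
    "train_batch_size": 16,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "curriculum_learning": {
        "enabled": True,
        "curriculum_type": "seqlen",
        "min_difficulty": 8,
        "max_difficulty": 1024,
        "schedule_type": "fixed_linear",
        "schedule_config": {"total_curriculum_step": 40000, "difficulty_step": 8}
    }
}

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config_params=ds_config)

# Each training-mode forward pass will then receive a curriculum_seqlen keyword
# whenever curriculum_type is "seqlen"; the model decides how to shorten inputs.
```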
3 changes: 3 additions & 0 deletions docs/_pages/features.md
@@ -241,6 +241,9 @@ DeepSpeed abstracts away data parallelism and model parallelism from the user wh
comes to data loading. Users simply provide a PyTorch dataset, and DeepSpeed data loader
can automatically handle batch creation appropriately.

## Curriculum Learning
Please refer to the [Curriculum Learning](/tutorials/curriculum-learning/) tutorial.

## Performance Analysis and Debugging

DeepSpeed provides a set of tools for performance analysis and debugging.