From b2b34ae342d6f851226e995f2e1021d12e761093 Mon Sep 17 00:00:00 2001 From: Conglong Li Date: Mon, 16 Aug 2021 11:57:55 -0700 Subject: [PATCH] Curriculum learning (#1307) Co-authored-by: Conglong Li Co-authored-by: Jeff Rasley --- README.md | 8 +- deepspeed/runtime/config.py | 21 +++ deepspeed/runtime/constants.py | 8 ++ deepspeed/runtime/data_pipeline/__init__.py | 0 .../data_pipeline/curriculum_scheduler.py | 133 ++++++++++++++++++ deepspeed/runtime/engine.py | 22 +++ docs/_config.yml | 1 + docs/_data/navigation.yml | 2 + docs/_pages/config-json.md | 76 ++++++++++ docs/_pages/features.md | 3 + docs/_tutorials/curriculum-learning.md | 107 ++++++++++++++ docs/index.md | 8 +- tests/unit/simple_model.py | 10 ++ tests/unit/test_curriculum_learning.py | 133 ++++++++++++++++++ 14 files changed, 530 insertions(+), 2 deletions(-) create mode 100644 deepspeed/runtime/data_pipeline/__init__.py create mode 100644 deepspeed/runtime/data_pipeline/curriculum_scheduler.py create mode 100644 docs/_tutorials/curriculum-learning.md create mode 100644 tests/unit/test_curriculum_learning.py diff --git a/README.md b/README.md index 4b752f021787..9e46cb022e4f 100755 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale) # News +* [2021/08/16] [Curriculum learning: a regularization method for stable and 2.6x faster GPT-2 pre-training with 8x/4x larger batch size/learning rate](https://www.deepspeed.ai/tutorials/curriculum-learning/) * [2021/05/24] [DeepSpeed: Accelerating large-scale model inference and training via system optimizations and compression](https://www.microsoft.com/en-us/research/blog/deepspeed-accelerating-large-scale-model-inference-and-training-via-system-optimizations-and-compression/) * [2021/04/20] [1-bit LAMB: up to 4.6x less communication and 2.8x faster training, together with LAMB's convergence speed at large batch sizes](https://www.deepspeed.ai/tutorials/onebit-lamb/) * [2021/04/19] [ZeRO-Infinity unlocks unprecedented model scale for deep learning training](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/) @@ -148,6 +149,10 @@ overview](https://www.deepspeed.ai/features/) for descriptions and usage. * Learning Rate Range Test * 1Cycle Learning Rate Schedule * [Simplified Data Loader](https://www.deepspeed.ai/features/#simplified-data-loader) +* [Curriculum Learning](https://www.deepspeed.ai/tutorials/curriculum-learning/) + * A curriculum learning-based data pipeline that presents easier or simpler examples earlier during training + * Stable and 2.6x faster GPT-2 pre-training with 8x/4x larger batch size/learning rate while maintaining token-wise convergence speed + * Complementary to many other DeepSpeed features * [Performance Analysis and Debugging](https://www.deepspeed.ai/features/#performance-analysis-and-debugging) @@ -198,9 +203,10 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information 2. Jeff Rasley, Samyam Rajbhandari, Olatunji Ruwase, and Yuxiong He. (2020) DeepSpeed: System Optimizations Enable Training Deep Learning Models with Over 100 Billion Parameters. [In Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining (KDD '20, Tutorial)](https://dl.acm.org/doi/10.1145/3394486.3406703). 3. Minjia Zhang, Yuxiong He. (2020) Accelerating Training of Transformer-Based Language Models with Progressive Layer Dropping. [arXiv:2010.13369](https://arxiv.org/abs/2010.13369) and [NeurIPS 2020](https://proceedings.neurips.cc/paper/2020/hash/a1140a3d0df1c81e24ae954d935e8926-Abstract.html). 4. Jie Ren, Samyam Rajbhandari, Reza Yazdani Aminabadi, Olatunji Ruwase, Shuangyan Yang, Minjia Zhang, Dong Li, Yuxiong He. (2021) ZeRO-Offload: Democratizing Billion-Scale Model Training. [arXiv:2101.06840](https://arxiv.org/abs/2101.06840). -5. Hanlin Tang, Shaoduo Gan, Ammar Ahmad Awan, Samyam Rajbhandari, Conglong Li, Xiangru Lian, Ji Liu, Ce Zhang, Yuxiong He. (2021) 1-bit Adam: Communication Efficient Large-Scale Training with Adam's Convergence Speed. [arXiv:2102.02888](https://arxiv.org/abs/2102.02888). +5. Hanlin Tang, Shaoduo Gan, Ammar Ahmad Awan, Samyam Rajbhandari, Conglong Li, Xiangru Lian, Ji Liu, Ce Zhang, Yuxiong He. (2021) 1-bit Adam: Communication Efficient Large-Scale Training with Adam's Convergence Speed. [arXiv:2102.02888](https://arxiv.org/abs/2102.02888) and [ICML 2021](http://proceedings.mlr.press/v139/tang21a.html). 6. Samyam Rajbhandari, Olatunji Ruwase, Jeff Rasley, Shaden Smith, Yuxiong He. (2021) ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning. [arXiv:2104.07857](https://arxiv.org/abs/2104.07857). 7. Conglong Li, Ammar Ahmad Awan, Hanlin Tang, Samyam Rajbhandari, Yuxiong He. (2021) 1-bit LAMB: Communication Efficient Large-Scale Large-Batch Training with LAMB's Convergence Speed. [arXiv:2104.06069](https://arxiv.org/abs/2104.06069). +8. Conglong Li, Minjia Zhang, Yuxiong He. (2021) Curriculum Learning: A Regularization Method for Efficient and Stable Billion-Scale GPT Model Pre-Training. [arXiv:2108.06084](https://arxiv.org/abs/2108.06084). # Videos 1. DeepSpeed KDD 2020 Tutorial diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 9b9b0f082917..a1f6b12535d6 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -55,6 +55,24 @@ class DeepSpeedConfigError(Exception): pass +def get_curriculum_enabled(param_dict): + if CURRICULUM_LEARNING in param_dict.keys(): + return get_scalar_param(param_dict[CURRICULUM_LEARNING], + CURRICULUM_ENABLED, + CURRICULUM_ENABLED_DEFAULT) + else: + return False + + +def get_curriculum_params(param_dict): + if CURRICULUM_LEARNING in param_dict.keys(): + curriculum_params = copy.copy(param_dict[CURRICULUM_LEARNING]) + curriculum_params.pop(CURRICULUM_ENABLED) + return curriculum_params + else: + return False + + def get_pld_enabled(param_dict): if PROGRESSIVE_LAYER_DROP in param_dict.keys(): return get_scalar_param(param_dict[PROGRESSIVE_LAYER_DROP], @@ -812,6 +830,9 @@ def _initialize_params(self, param_dict): self.pld_enabled = get_pld_enabled(param_dict) self.pld_params = get_pld_params(param_dict) + self.curriculum_enabled = get_curriculum_enabled(param_dict) + self.curriculum_params = get_curriculum_params(param_dict) + checkpoint_params = get_checkpoint_params(param_dict) validation_mode = get_checkpoint_tag_validation_mode(checkpoint_params) self.checkpoint_tag_validation_enabled = validation_mode != ValidationMode.IGNORE diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index cee2e3712438..420650c0135e 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -344,6 +344,14 @@ PLD_GAMMA = "gamma" PLD_GAMMA_DEFAULT = 0.001 +######################################### +# Curriculum Learning +######################################### +CURRICULUM_LEARNING = "curriculum_learning" + +CURRICULUM_ENABLED = "enabled" +CURRICULUM_ENABLED_DEFAULT = False + ######################################### # Validation modes diff --git a/deepspeed/runtime/data_pipeline/__init__.py b/deepspeed/runtime/data_pipeline/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/deepspeed/runtime/data_pipeline/curriculum_scheduler.py b/deepspeed/runtime/data_pipeline/curriculum_scheduler.py new file mode 100644 index 000000000000..5f676ab7905a --- /dev/null +++ b/deepspeed/runtime/data_pipeline/curriculum_scheduler.py @@ -0,0 +1,133 @@ +''' +Copyright 2021 The Microsoft DeepSpeed Team +''' +import math +from deepspeed.utils import logger + + +class CurriculumScheduler(object): + def __init__(self, config): + super().__init__() + self.state = {} + assert "curriculum_type" in config, "Curriculum learning requires the config 'curriculum_type'" + assert "min_difficulty" in config, "Curriculum learning requires the config 'min_difficulty'" + assert "max_difficulty" in config, "Curriculum learning requires the config 'max_difficulty'" + assert "schedule_type" in config, "Curriculum learning requires the config 'schedule_type'" + self.state['min_difficulty'] = config['min_difficulty'] + self.state['max_difficulty'] = config['max_difficulty'] + self.state['current_difficulty'] = config['min_difficulty'] + self.state['schedule_type'] = config['schedule_type'] + if config['schedule_type'] == 'fixed_discrete': + """ + The schedule_config is a list of difficulty and a list of max + step belonging to each difficulty. Example json config: + "schedule_config": { + "difficulty": [1,2,3], + "max_step": [5,10] + } + The "max_step" has one less element than "difficulty", because + the last difficulty will be used for all following steps. + The self.state['schedule'] is a dictionary of + difficulty : [max step for this difficulty, next difficulty]. + """ + assert "difficulty" in config['schedule_config'], "Curriculum learning with fixed_discrete schedule requires the schedule_config 'difficulty'" + assert "max_step" in config['schedule_config'], "Curriculum learning with fixed_discrete schedule requires the schedule_config 'max_step'" + assert len(config['schedule_config']['max_step']) > 0 + assert len(config['schedule_config']['difficulty']) > 0 + assert len(config['schedule_config']['difficulty']) == len( + config['schedule_config']['max_step']) + 1 + self.state['schedule'] = {} + for i in range(len(config['schedule_config']['max_step'])): + self.state['schedule'][config['schedule_config']['difficulty'][i]] = \ + [config['schedule_config']['max_step'][i], + config['schedule_config']['difficulty'][i+1]] + elif config['schedule_type'] == 'fixed_root': + """ + The schedule_config includes: + total_curriculum_step: how many steps the curriculum learning takes to go + from min difficulty to max difficulty. + difficulty_step: the difficulty level determined every time must + be a multiple of this difficulty_step. This is used to determine + the step of difficulty increase, and to ensure the use of NVIDIA + Tensor Core acceleration (requires multiple of 8 (FP16) or + 16 (INT8)). + root_degree: the degree of the root function. Degree of 2 means + square root and degree of 3 means cube root. Degree of 1 is + equivalent to linear. + "schedule_config": { + "total_curriculum_step": 30000, + "difficulty_step": 8, + "root_degree": 2 + } + """ + assert "total_curriculum_step" in config['schedule_config'], "Curriculum learning with fixed_root schedule requires the schedule_config 'total_curriculum_step'" + assert "difficulty_step" in config['schedule_config'], "Curriculum learning with fixed_root schedule requires the schedule_config 'difficulty_step'" + assert "root_degree" in config['schedule_config'], "Curriculum learning with fixed_root schedule requires the schedule_config 'root_degree'" + if config['schedule_config']['difficulty_step'] % 8 != 0: + logger.warning( + f'The difficulty_step for curriculum learning has to be multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. Disregard this warning if this is unrelated to your hardware.' + ) + self.state['schedule'] = config['schedule_config'] + elif config['schedule_type'] == 'fixed_linear': + """ + The schedule_config is the same as 'fixed_root' but without the + root_degree. + "schedule_config": { + "total_curriculum_step": 30000, + "difficulty_step": 8 + } + """ + assert "total_curriculum_step" in config['schedule_config'], "Curriculum learning with fixed_linear schedule requires the schedule_config 'total_curriculum_step'" + assert "difficulty_step" in config['schedule_config'], "Curriculum learning with fixed_linear schedule requires the schedule_config 'difficulty_step'" + if config['schedule_config']['difficulty_step'] % 8 != 0: + logger.warning( + f'The difficulty_step for curriculum learning has to be multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. Disregard this warning if this is unrelated to your hardware.' + ) + self.state['schedule'] = config['schedule_config'] + else: + raise RuntimeError('Unsupported curriculum schedule type') + + def get_current_difficulty(self): + return self.state['current_difficulty'] + + def set_current_difficulty(self, difficulty): + self.state['current_difficulty'] = difficulty + + def get_state(self): + return self.state + + def set_state(self, state): + self.state = state + + def __fixed_discrete_update_difficulty(self, global_steps): + s_state = self.state['schedule'][self.state['current_difficulty']] + if global_steps > s_state[0]: + self.state['current_difficulty'] = s_state[1] + return self.state['current_difficulty'] + + def __fixed_root_update_difficulty(self, global_steps, root_degree=None): + s_state = self.state['schedule'] + if root_degree is None: + root_degree = s_state['root_degree'] + next_difficulty = (float(global_steps) / + s_state['total_curriculum_step'])**(1.0 / root_degree) + next_difficulty = math.floor( + next_difficulty * + (self.state['max_difficulty'] - self.state['min_difficulty']) + + self.state['min_difficulty']) + next_difficulty -= (next_difficulty % s_state['difficulty_step']) + self.state['current_difficulty'] = min(next_difficulty, + self.state['max_difficulty']) + return self.state['current_difficulty'] + + def update_difficulty(self, global_steps): + if self.state['current_difficulty'] >= self.state['max_difficulty']: + return self.state['current_difficulty'] + if self.state['schedule_type'] == 'fixed_discrete': + return self.__fixed_discrete_update_difficulty(global_steps) + elif self.state['schedule_type'] == 'fixed_linear': + return self.__fixed_root_update_difficulty(global_steps, 1) + elif self.state['schedule_type'] == 'fixed_root': + return self.__fixed_root_update_difficulty(global_steps) + else: + raise RuntimeError('Unsupported curriculum schedule type') diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 2c48d0ed9bfb..dab48aaca7da 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -41,6 +41,7 @@ from deepspeed.utils.debug import debug_extract_module_and_param_names from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop from deepspeed.runtime.eigenvalue import Eigenvalue +from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler from .pipe.module import PipelineModule from .utils import ensure_directory_exists @@ -214,6 +215,9 @@ def __init__(self, if self.pld_enabled(): self.progressive_layer_drop = self._configure_progressive_layer_drop() + if self.curriculum_enabled(): + self.curriculum_scheduler = self._configure_curriculum_scheduler() + if self.global_rank == 0: self._config.print('DeepSpeedEngine configuration') if self.dump_state(): @@ -282,6 +286,12 @@ def eigenvalue_layer_name(self): def eigenvalue_layer_num(self): return self._config.eigenvalue_layer_num + def curriculum_enabled(self): + return self._config.curriculum_enabled + + def curriculum_params(self): + return self._config.curriculum_params + def tensorboard_enabled(self): return self._config.tensorboard_enabled @@ -998,6 +1008,10 @@ def _configure_progressive_layer_drop(self): return pld + def _configure_curriculum_scheduler(self): + scheduler = CurriculumScheduler(self.curriculum_params()) + return scheduler + @staticmethod def is_map_style_dataset(obj): return hasattr(obj, "__getitem__") and hasattr(obj, "__len__") @@ -1101,6 +1115,14 @@ def forward(self, *inputs, **kwargs): if self.module.training and self.progressive_layer_drop: kwargs.update(self.progressive_layer_drop.get_state()) + if self.module.training and self.curriculum_enabled(): + self.curriculum_scheduler.update_difficulty(self.global_steps + 1) + if self.curriculum_params()["curriculum_type"] == "seqlen": + kwargs.update({ + "curriculum_seqlen": + self.curriculum_scheduler.get_current_difficulty() + }) + if self.zero_optimization_partition_weights(): # Enable automated discovery of external parameters by indicating that # we are in a forward pass. diff --git a/docs/_config.yml b/docs/_config.yml index a39298be04f9..b57c56629885 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -36,6 +36,7 @@ collections: - bert-finetuning.md - bert-pretraining.md - cifar-10.md + - curriculum-learning.md - flops-profiler.md - gan.md - lrrt.md diff --git a/docs/_data/navigation.yml b/docs/_data/navigation.yml index 78b6d1d80826..21b41c8a6fd3 100755 --- a/docs/_data/navigation.yml +++ b/docs/_data/navigation.yml @@ -70,6 +70,8 @@ lnav: url: /tutorials/bert-pretraining/ - title: "CIFAR-10" url: /tutorials/cifar-10/ + - title: "Curriculum Learning" + url: /tutorials/curriculum-learning/ - title: "Flops Profiler" url: /tutorials/flops-profiler/ - title: "GAN" diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index b8512b4621af..fda6ba72b2df 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -716,3 +716,79 @@ Configuring the asynchronous I/O module for offloading parameter and optimizer s "num_sliding_window_blocks": 3 } ``` + +### Curriculum Learning +```json + "curriculum_learning": { + "enabled": true, + "curriculum_type": "seqlen", + "min_difficulty": 8, + "max_difficulty": 1024, + "schedule_type": "fixed_linear", + "schedule_config": { + "total_curriculum_step": 40000, + "difficulty_step": 8 + } + } +``` +**enabled**: [boolean] + +| Description | Default | +| ----------------------------------------- | ------- | +| Set to true to enable curriculum learning | `false` | + +**curriculum_type**: [string] + +| Description | Default | +| ----------------------------------------------------------------- | ------- | +| Type of curriculum difficulty metric. Currently support `seqlen`. | N/A | + + +**min_difficulty**: [integer] + +| Description | Default | +| ----------------------------- | ------- | +| The starting difficulty level | N/A | + +**max_difficulty**: [integer] + +| Description | Default | +| --------------------------- | ------- | +| The ending difficulty level | N/A | + +**schedule_type**: [string] + +| Description | Default | +| -------------------------------------------------------------------------------------------------- | ------- | +| Type of curriculum schedule. Currently support `fixed_linear`, `fixed_root`, and `fixed_discrete`. | N/A | + + +**total_curriculum_step**: [integer] + +| Description | Default | +| --------------------------------------------------------------- | ------- | +| Total number of steps for the curriculum learning. One of the `schedule_config` when the `fixed_linear` and `fixed_root` schedule_type are used. | N/A | + +**difficulty_step**: [integer] + +| Description | Default | +| --------------------------------------------------------------- | ------- | +| At any time, the curriculum learning difficulty must be multiple of this `difficulty_step`. Set this to multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. One of the `schedule_config` when the `fixed_linear` and `fixed_root` schedule_type are used. | N/A | + +**root_degree**: [integer] + +| Description | Default | +| --------------------------------------------------------------- | ------- | +| Root degree of the curriculum schedule function. One of the `schedule_config` when the `fixed_root` schedule_type is used. | N/A | + +**difficulty**: [list of integer] + +| Description | Default | +| --------------------------------------------------------------- | ------- | +| List of difficulty levels to be used during schedule. One of the `schedule_config` when the `fixed_discrete` schedule_type is used. | N/A | + +**max_step**: [list of integer] + +| Description | Default | +| --------------------------------------------------------------- | ------- | +| List of which step to change difficulty level. One of the `schedule_config` when the `fixed_discrete` schedule_type is used. | N/A | diff --git a/docs/_pages/features.md b/docs/_pages/features.md index 9b0b89d0a64b..0b5e6e861bdd 100755 --- a/docs/_pages/features.md +++ b/docs/_pages/features.md @@ -241,6 +241,9 @@ DeepSpeed abstracts away data parallelism and model parallelism from the user wh comes to data loading. Users simply provide a PyTorch dataset, and DeepSpeed data loader can automatically handle batch creation appropriately. +## Curriculum Learning +Please refer to the [Curriculum Learning](/tutorials/curriculum-learning/) tutorial. + ## Performance Analysis and Debugging DeepSpeed provides a set of tools for performance analysis and debugging. diff --git a/docs/_tutorials/curriculum-learning.md b/docs/_tutorials/curriculum-learning.md new file mode 100644 index 000000000000..cc9352f8ae23 --- /dev/null +++ b/docs/_tutorials/curriculum-learning.md @@ -0,0 +1,107 @@ +--- +title: "Curriculum Learning: A Regularization Method for Efficient and Stable Billion-Scale GPT Model Pre-Training" +--- + +In this tutorial, we introduce DeepSpeed's curriculum learning-based data pipeline, which presents easier or simpler examples earlier during training. By enabling stable training with 8x/4x larger batch size/learning rate (whereas the baseline approach struggles with training divergence), we observe that curriculum learning (based on sequence length) provides stable and 2.6x faster GPT-2 pre-training (tested on 117M and 1.5B parameters), together with better token-wise convergence speed and zero-shot WikiText-103/LAMBADA evaluation results. In addition, since curriculum learning only affect the data pipeline, its benefit is complementary to many DeepSpeed features and other system optimization techniques. For example, curriculum learning is compatible with DeepSpeed's [ZeRO Redundancy Optimizer](/tutorials/zero/) and [ZeRO-Offload](/tutorials/zero-offload/), and Megatron-LM's Model Parallelism. + +To illustrate the benefits and usage of curriculum learning, we use the Megatron-LM GPT-2 pre-training task as example. For more details on this task, please refer to the [tutorial](/tutorials/megatron/). In addition, we also have a [paper](https://arxiv.org/abs/2108.06084) which provides the technical details including implementation and evaluations. + +## 1. Configurations and tuning strategy +Curriculum learning can be used by setting the DeepSpeed configuration as the following example json config file: + +```json +{ + "train_batch_size": 4096, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + "max_grad_norm": 1.0, + "betas": [0.9, 0.95] + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "curriculum_learning": { + "enabled": true, + "curriculum_type": "seqlen", + "min_difficulty": 8, + "max_difficulty": 1024, + "schedule_type": "fixed_linear", + "schedule_config": { + "total_curriculum_step": 15000, + "difficulty_step": 8 + } + } +} +``` +To support curriculum learning, we add the following new parameters: + +`curriculum_type` is the type of curriculum difficulty metric. Currently we support the `seqlen` metric which presents shorter sequences earlier in training. We implement this type of curriculum learning by passing an additional `curriculum_seqlen` argument to the model's forward function, and performing training data sequence truncation before the actual forward pass. We will describe how to implement this in the Megatron-LM GPT-2 pre-training example below. + +`min_difficulty` is the starting difficulty level. For `seqlen` metric it means we start with sequence length as `min_difficulty`. We observe that lower `min_difficulty` usually provides better convergence speedup but with two caveats: First, sometimes (especially for large models) starting with too small difficulty level may lead to severe overfitting (e.g., training loss divergence or validation loss keeps jumping up and down) thus hurt the convergence. In such case it is recommended to try increasing the `min_difficulty`. Second, for `seqlen` metric it is recommended to set `min_difficulty` as multiple of 8 (for FP16 data) or 16 (for INT8 data) in order to enable [NVIDIA GPU's Tensor Core acceleration](https://developer.nvidia.com/blog/optimizing-gpu-performance-tensor-cores/). + +`max_difficulty` is the ending difficulty level. For `seqlen` metric it should be set as the full sequence length (e.g., 1024 for Megatron-LM GPT-2 pre-training). + +`schedule_type` is the scheduling policy for curriculum learning (i.e., which difficulty level to use at certain step). Currently we support three schedules: `fixed_linear`, `fixed_root`, and `fixed_discrete`. We recommend to first try the `fixed_linear` schedule, which is easier to tune and provides great training stability/efficiency gain in our tests. Each schedule has its own configurations: + + +### 1.1 fixed_linear schedule +For `fixed_linear` schedule there are two configurations: + +```json +"schedule_type": "fixed_linear", +"schedule_config": { + "total_curriculum_step": 15000, + "difficulty_step": 8 +} +``` + +The `total_curriculum_step` is the total number of steps for the curriculum learning. For `fixed_linear` schedule the difficulty level will linearly increase from `min_difficulty` to `max_difficulty` during the `total_curriculum_step` duration. This configuration needs to be tuned for each training task. We observe that too small and too large `total_curriculum_step` are both suboptimal: with too small `total_curriculum_step` curriculum learning might not be able to provide enough training stability benefit so the training might still diverge; with too large `total_curriculum_step` the model may overfit too much during curriculum learning on the easier/simpler training data thus hurt the overall convergence. We recommend to first set `total_curriculum_step` as 20% to 40% of the total training steps (note that if you increase the batch size for the curriculum learning-based training, you also need to reduce the total training steps correspondingly), then increase the `total_curriculum_step` if the training is not stable, or reduce the `total_curriculum_step` to test if convergence improves. + +The `difficulty_step` configuration ensures that at anytime the difficulty level must be multiple of `difficulty_step`. We usually set it as 8 (for FP16 data) or 16 (for INT8 data) to enable [NVIDIA GPU's Tensor Core acceleration](https://developer.nvidia.com/blog/optimizing-gpu-performance-tensor-cores/). If this is unrelated to your training experiment, you can set it as 1. + +### 1.2 fixed_root schedule +For `fixed_root` schedule there are three configurations: + +```json +"schedule_type": "fixed_root", +"schedule_config": { + "total_curriculum_step": 15000, + "difficulty_step": 8, + "root_degree": 2 +} +``` + +The `total_curriculum_step` and `difficulty_step` have the same meaning as in the `fixed_linear` schedule case. The `root_degree` determines the root degree of the root function of the schedule. The difficulty level at certain step is determined as ((current step/`total_curriculum_step`)**(1/`root_degree`)) * (`max_difficulty` - `min_difficulty`) + `min_difficulty`. Thus `fixed_linear` is basically a special case of `fixed_root` with `root_degree` as 1. In our (limited) study, we find the `fixed_root` schedule does not provide any clear advantage over `fixed_linear` schedule, while requiring one additional parameter. + +### 1.3 fixed_discrete schedule +For `fixed_discrete` schedule there are two configurations: + +```json +"schedule_type": "fixed_discrete", +"schedule_config": { + "difficulty": [1,2,3], + "max_step": [5,10] +} +``` + +The `difficulty` is a list of difficulty levels to be used during schedule. The `max_step` is a list of step timestamp to determine when to switch to next difficulty level. For example, the json config above means that at step 1-5 difficulty 1 is used, at step 6-10 difficulty 2 is used, from step 11 difficulty 3 is used. This `fixed_discrete` schedule provides the most flexible curriculum learning scheduling. However, we find that one risk of this kind of schedule is that if the model stays at certain difficulty level for too long, training divergence may happen when switching to next difficulty due to severe overfitting. + +## 2. Curriculum learning for Megatron-LM GPT-2 pre-training + +We provide example scripts under [DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning). The `ds_train.sh` is the training script to run and it also includes the actual configurations we used for the experiments in our [paper](https://arxiv.org/abs/2108.06084). + +Besides the additional DeepSpeed configurations, there are some other necessary changes on the user side to enable curriculum learning. First, it is necessary to add a `curriculum_seqlen` argument in the model's forward pass and use it to perform training data sequence length truncation. For Megatron-LM GPT-2 pre-training, we implement this in `forward()` in [DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py) and in `forward_step()` in [DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py). + +Second, since there will be less tokens per step during curriculum learning, for curriculum-learning based training it requires more steps in order to reach the same number of training tokens as baseline. Thus in Megatron-LM we add a `--train-tokens` argument to terminate the training based on number of tokens. Then we usually set a long enough `--train-iters` (e.g., two times of baseline's total training step), and set the `--train-tokens` the same for baseline and curriculum-learning based training. + +Third, again due to the less tokens per step during curriculum learning, we find that for curriculum-learning based training it is beneficial to increase the learning rate decay steps (otherwise the curriculum learning case will have faster token-wise learning rate decay than baseline). For `fixed_linear` schedule because we start from very short sequence length, the total number of tokens during the curriculum learning is roughly halved. Thus we usually just add half of `fixed_linear` schedule's `total_curriculum_step` to the Megatron-LM's `--lr-decay-iters`. diff --git a/docs/index.md b/docs/index.md index dd607aa1466f..78bc0dbbbf52 100755 --- a/docs/index.md +++ b/docs/index.md @@ -30,6 +30,7 @@ initiative to enable next-generation AI capabilities at scale, where you can fin information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale). # What's New? +* [2021/08/16] [Curriculum learning: a regularization method for stable and 2.6x faster GPT-2 pre-training with 8x/4x larger batch size/learning rate](https://www.deepspeed.ai/tutorials/curriculum-learning/) * [2021/05/24] [DeepSpeed: Accelerating large-scale model inference and training via system optimizations and compression](https://www.microsoft.com/en-us/research/blog/deepspeed-accelerating-large-scale-model-inference-and-training-via-system-optimizations-and-compression/) * [2021/04/20] [1-bit LAMB: up to 4.6x less communication and 2.8x faster training, together with LAMB's convergence speed at large batch sizes](https://www.deepspeed.ai/tutorials/onebit-lamb/) * [2021/04/19] [ZeRO-Infinity unlocks unprecedented model scale for deep learning training](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/) @@ -201,6 +202,10 @@ Below we provide a brief feature list, see our detailed [feature overview](https * Learning Rate Range Test * 1Cycle Learning Rate Schedule * [Simplified Data Loader](https://www.deepspeed.ai/features/#simplified-data-loader) +* [Curriculum Learning](https://www.deepspeed.ai/tutorials/curriculum-learning/) + * A curriculum learning-based data pipeline that presents easier or simpler examples earlier during training + * Stable and 2.6x faster GPT-2 pre-training with 8x/4x larger batch size/learning rate while maintaining token-wise convergence speed + * Complementary to many other DeepSpeed features * [Progressive Layer Dropping](https://www.deepspeed.ai/news/2020/10/28/progressive-layer-dropping-news.html) * Efficient and robust compressed training * Up to 2.5x convergence speedup for pre-training @@ -235,9 +240,10 @@ comments. 2. Jeff Rasley, Samyam Rajbhandari, Olatunji Ruwase, and Yuxiong He. (2020) DeepSpeed: System Optimizations Enable Training Deep Learning Models with Over 100 Billion Parameters. [In Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining (KDD '20, Tutorial)](https://dl.acm.org/doi/10.1145/3394486.3406703). 3. Minjia Zhang, Yuxiong He. (2020) Accelerating Training of Transformer-Based Language Models with Progressive Layer Dropping. [arXiv:2010.13369](https://arxiv.org/abs/2010.13369) and [NeurIPS 2020](https://proceedings.neurips.cc/paper/2020/hash/a1140a3d0df1c81e24ae954d935e8926-Abstract.html). 4. Jie Ren, Samyam Rajbhandari, Reza Yazdani Aminabadi, Olatunji Ruwase, Shuangyan Yang, Minjia Zhang, Dong Li, Yuxiong He. (2021) ZeRO-Offload: Democratizing Billion-Scale Model Training. [arXiv:2101.06840](https://arxiv.org/abs/2101.06840). -5. Hanlin Tang, Shaoduo Gan, Ammar Ahmad Awan, Samyam Rajbhandari, Conglong Li, Xiangru Lian, Ji Liu, Ce Zhang, Yuxiong He. (2021) 1-bit Adam: Communication Efficient Large-Scale Training with Adam's Convergence Speed. [arXiv:2102.02888](https://arxiv.org/abs/2102.02888). +5. Hanlin Tang, Shaoduo Gan, Ammar Ahmad Awan, Samyam Rajbhandari, Conglong Li, Xiangru Lian, Ji Liu, Ce Zhang, Yuxiong He. (2021) 1-bit Adam: Communication Efficient Large-Scale Training with Adam's Convergence Speed. [arXiv:2102.02888](https://arxiv.org/abs/2102.02888) and [ICML 2021](http://proceedings.mlr.press/v139/tang21a.html). 6. Samyam Rajbhandari, Olatunji Ruwase, Jeff Rasley, Shaden Smith, Yuxiong He. (2021) ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning. [arXiv:2104.07857](https://arxiv.org/abs/2104.07857). 7. Conglong Li, Ammar Ahmad Awan, Hanlin Tang, Samyam Rajbhandari, Yuxiong He. (2021) 1-bit LAMB: Communication Efficient Large-Scale Large-Batch Training with LAMB's Convergence Speed. [arXiv:2104.06069](https://arxiv.org/abs/2104.06069). +8. Conglong Li, Minjia Zhang, Yuxiong He. (2021) Curriculum Learning: A Regularization Method for Efficient and Stable Billion-Scale GPT Model Pre-Training. [arXiv:2108.06084](https://arxiv.org/abs/2108.06084). # Videos 1. DeepSpeed KDD 2020 Tutorial diff --git a/tests/unit/simple_model.py b/tests/unit/simple_model.py index 15c40976b6a1..aa430fb35ce7 100755 --- a/tests/unit/simple_model.py +++ b/tests/unit/simple_model.py @@ -21,6 +21,16 @@ def forward(self, x, y): return self.cross_entropy_loss(hidden_dim, y) +class Curriculum_SimpleModel(SimpleModel): + def __init__(self, hidden_dim, empty_grad=False): + super(Curriculum_SimpleModel, self).__init__(hidden_dim, empty_grad) + + def forward(self, x, y, **kwargs): + seqlen = kwargs.get('curriculum_seqlen', None) + loss = super(Curriculum_SimpleModel, self).forward(x, y) + return loss, seqlen + + class UnusedParametersModel(SimpleModel): def __init__(self, hidden_dim, empty_grad=False): super().__init__(hidden_dim, empty_grad) diff --git a/tests/unit/test_curriculum_learning.py b/tests/unit/test_curriculum_learning.py new file mode 100644 index 000000000000..cf0562ab61e1 --- /dev/null +++ b/tests/unit/test_curriculum_learning.py @@ -0,0 +1,133 @@ +import torch +import torch.distributed as dist +import deepspeed +import argparse +import pytest +import json +import os +import numpy as np +import time +from common import distributed_test +from simple_model import Curriculum_SimpleModel, random_dataloader, args_from_dict + + +def test_curriculum_scheduler_fixed_discrete(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + "weight_decay": 0.01 + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + }, + "curriculum_learning": { + "enabled": True, + "curriculum_type": "seqlen", + "min_difficulty": 1, + "max_difficulty": 5, + "schedule_type": "fixed_discrete", + "schedule_config": { + "difficulty": [1, + 2, + 3, + 4, + 5], + "max_step": [2, + 4, + 6, + 8] + } + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + ground_truths = {1: 1, 2: 1, 3: 2, 4: 2, 5: 3, 6: 3, 7: 4, 8: 4} + model = Curriculum_SimpleModel(hidden_dim) + + @distributed_test(world_size=[1, 2]) + def _test_curriculum_scheduler_fixed_discrete(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=20, + hidden_dim=hidden_dim, + device=model.device) + for n, batch in enumerate(data_loader): + loss, seqlen = model(batch[0], batch[1]) + model.backward(loss) + model.step() + true_seqlen = 5 + if n + 1 in ground_truths: + true_seqlen = ground_truths[n + 1] + print('at step {} the seqlen is {}'.format(n + 1, seqlen)) + assert seqlen == true_seqlen, f"Incorrect curriculum schedule" + + _test_curriculum_scheduler_fixed_discrete(args=args, + model=model, + hidden_dim=hidden_dim) + + +def test_curriculum_scheduler_fixed_linear(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + "weight_decay": 0.01 + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + }, + "curriculum_learning": { + "enabled": True, + "curriculum_type": "seqlen", + "min_difficulty": 2, + "max_difficulty": 10, + "schedule_type": "fixed_linear", + "schedule_config": { + "total_curriculum_step": 8, + "difficulty_step": 2 + } + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + ground_truths = {1: 2, 2: 4, 3: 4, 4: 6, 5: 6, 6: 8, 7: 8, 8: 10, 9: 10, 10: 10} + model = Curriculum_SimpleModel(hidden_dim) + + @distributed_test(world_size=[1, 2]) + def _test_curriculum_scheduler_fixed_linear(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=20, + hidden_dim=hidden_dim, + device=model.device) + for n, batch in enumerate(data_loader): + loss, seqlen = model(batch[0], batch[1]) + model.backward(loss) + model.step() + if n + 1 in ground_truths: + true_seqlen = ground_truths[n + 1] + print('at step {} the seqlen is {}'.format(n + 1, seqlen)) + assert seqlen == true_seqlen, f"Incorrect curriculum schedule" + + _test_curriculum_scheduler_fixed_linear(args=args, + model=model, + hidden_dim=hidden_dim)