Curriculum learning #1307

Merged: 5 commits merged on Aug 16, 2021

Changes from 1 commit
[squash] CL staging squash (#1305)
Co-authored-by: Conglong Li <conglong.li@gmail.com>
jeffra and conglongli committed Aug 12, 2021
commit 47db735f8bcf10d8c16561784457593e81b0478a
21 changes: 21 additions & 0 deletions deepspeed/runtime/config.py
@@ -55,6 +55,24 @@ class DeepSpeedConfigError(Exception):
    pass


def get_curriculum_enabled(param_dict):
    if CURRICULUM_LEARNING in param_dict.keys():
        return get_scalar_param(param_dict[CURRICULUM_LEARNING],
                                CURRICULUM_ENABLED,
                                CURRICULUM_ENABLED_DEFAULT)
    else:
        return False


def get_curriculum_params(param_dict):
    if CURRICULUM_LEARNING in param_dict.keys():
        curriculum_params = copy.copy(param_dict[CURRICULUM_LEARNING])
        curriculum_params.pop(CURRICULUM_ENABLED)
        return curriculum_params
    else:
        return False


def get_pld_enabled(param_dict):
    if PROGRESSIVE_LAYER_DROP in param_dict.keys():
        return get_scalar_param(param_dict[PROGRESSIVE_LAYER_DROP],
@@ -812,6 +830,9 @@ def _initialize_params(self, param_dict):
        self.pld_enabled = get_pld_enabled(param_dict)
        self.pld_params = get_pld_params(param_dict)

        self.curriculum_enabled = get_curriculum_enabled(param_dict)
        self.curriculum_params = get_curriculum_params(param_dict)

        checkpoint_params = get_checkpoint_params(param_dict)
        validation_mode = get_checkpoint_tag_validation_mode(checkpoint_params)
        self.checkpoint_tag_validation_enabled = validation_mode != ValidationMode.IGNORE
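The two helpers above read the new curriculum_learning section of the DeepSpeed JSON config. As a rough illustration, the parsed param_dict they receive would look like the sketch below (values are hypothetical, not taken from the PR), and get_curriculum_params returns the section with the enabled flag stripped:

# Hypothetical values for illustration only; the shape mirrors the
# "curriculum_learning" sections used in the unit tests later in this PR.
param_dict = {
    "curriculum_learning": {
        "enabled": True,
        "curriculum_type": "seqlen",
        "min_difficulty": 8,
        "max_difficulty": 1024,
        "schedule_type": "fixed_linear",
        "schedule_config": {
            "total_step": 10000,
            "difficulty_step": 8
        }
    }
}
# get_curriculum_enabled(param_dict) -> True
# get_curriculum_params(param_dict)  -> the inner dict without the "enabled" key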
8 changes: 8 additions & 0 deletions deepspeed/runtime/constants.py
@@ -344,6 +344,14 @@
PLD_GAMMA = "gamma"
PLD_GAMMA_DEFAULT = 0.001

#########################################
# Curriculum Learning
#########################################
CURRICULUM_LEARNING = "curriculum_learning"

CURRICULUM_ENABLED = "enabled"
CURRICULUM_ENABLED_DEFAULT = False


#########################################
# Validation modes
Empty file.
109 changes: 109 additions & 0 deletions deepspeed/runtime/data_pipeline/curriculum_scheduler.py
@@ -0,0 +1,109 @@
'''
Copyright 2021 The Microsoft DeepSpeed Team
'''
import math


class CurriculumScheduler(object):
    def __init__(self, config):
        super().__init__()
        self.state = {}
        self.state['min_difficulty'] = config['min_difficulty']
        self.state['max_difficulty'] = config['max_difficulty']
        self.state['current_difficulty'] = config['min_difficulty']
        self.state['schedule_type'] = config['schedule_type']
        if config['schedule_type'] == 'fixed_discrete':
            """
            The schedule_config is a list of difficulties and a list of the max
            step up to which each difficulty applies. Example json config:
            "schedule_config": {
              "difficulty": [1,2,3],
              "max_step": [5,10]
            }
            "max_step" has one less element than "difficulty" because the last
            difficulty is used for all remaining steps.
            self.state['schedule'] is a dictionary mapping
            difficulty -> [max step for this difficulty, next difficulty].
            """
            self.state['schedule'] = {}
            for i in range(len(config['schedule_config']['max_step'])):
                self.state['schedule'][config['schedule_config']['difficulty'][i]] = \
                    [config['schedule_config']['max_step'][i],
                     config['schedule_config']['difficulty'][i + 1]]
        elif config['schedule_type'] == 'fixed_root':
            """
            The schedule_config includes:
            total_step: how many steps the curriculum learning takes to go
            from min difficulty to max difficulty.
            difficulty_step: every difficulty level produced by the schedule
            must be a multiple of this value. It controls the granularity of
            difficulty increases and helps preserve NVIDIA Tensor Core
            acceleration (which requires multiples of 8 for FP16 or 16 for INT8).
            root_degree: the degree of the root function. Degree of 2 means
            square root and degree of 3 means cube root. Degree of 1 is
            equivalent to linear.
            "schedule_config": {
              "total_step": 30000,
              "difficulty_step": 8,
              "root_degree": 2
            }
            """
            self.state['schedule'] = config['schedule_config']
        elif config['schedule_type'] == 'fixed_linear':
            """
            The schedule_config is the same as for 'fixed_root' but without
            root_degree.
            "schedule_config": {
              "total_step": 30000,
              "difficulty_step": 8
            }
            """
            self.state['schedule'] = config['schedule_config']
        else:
            raise RuntimeError('Unsupported curriculum schedule type')

    def get_current_difficulty(self):
        return self.state['current_difficulty']

    def set_current_difficulty(self, difficulty):
        self.state['current_difficulty'] = difficulty

    def get_state(self):
        return self.state

    def set_state(self, state):
        self.state = state

    def __fixed_discrete_update_difficulty(self, global_steps):
        s_state = self.state['schedule'][self.state['current_difficulty']]
        if global_steps > s_state[0]:
            self.state['current_difficulty'] = s_state[1]
        return self.state['current_difficulty']

    def __fixed_root_update_difficulty(self, global_steps, root_degree=None):
        s_state = self.state['schedule']
        if root_degree is None:
            root_degree = s_state['root_degree']
        next_difficulty = (float(global_steps) / s_state['total_step'])**(1.0 / root_degree)
        next_difficulty = math.floor(
            next_difficulty *
            (self.state['max_difficulty'] - self.state['min_difficulty']) +
            self.state['min_difficulty'])
        next_difficulty -= (next_difficulty % s_state['difficulty_step'])
        self.state['current_difficulty'] = min(next_difficulty,
                                               self.state['max_difficulty'])
        return self.state['current_difficulty']

    def update_difficulty(self, global_steps):
        if self.state['current_difficulty'] >= self.state['max_difficulty']:
            return self.state['current_difficulty']
        if self.state['schedule_type'] == 'fixed_discrete':
            return self.__fixed_discrete_update_difficulty(global_steps)
        elif self.state['schedule_type'] == 'fixed_linear':
            return self.__fixed_root_update_difficulty(global_steps, 1)
        elif self.state['schedule_type'] == 'fixed_root':
            return self.__fixed_root_update_difficulty(global_steps)
        else:
            raise RuntimeError('Unsupported curriculum schedule type')
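A minimal usage sketch (not part of the PR): driving the scheduler directly with the fixed_linear config used in the unit test below. In real training the DeepSpeed engine constructs the scheduler from the config and calls update_difficulty itself.

from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler

scheduler = CurriculumScheduler({
    "curriculum_type": "seqlen",
    "min_difficulty": 2,
    "max_difficulty": 10,
    "schedule_type": "fixed_linear",
    "schedule_config": {
        "total_step": 8,
        "difficulty_step": 2
    }
})
for step in range(1, 11):
    # Difficulty grows linearly from 2 to 10 in multiples of 2, then stays at 10:
    # 2, 4, 4, 6, 6, 8, 8, 10, 10, 10
    print(step, scheduler.update_difficulty(step))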
20 changes: 20 additions & 0 deletions deepspeed/runtime/engine.py
@@ -41,6 +41,7 @@
from deepspeed.utils.debug import debug_extract_module_and_param_names
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
from deepspeed.runtime.eigenvalue import Eigenvalue
from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler

from .pipe.module import PipelineModule
from .utils import ensure_directory_exists
@@ -214,6 +215,9 @@ def __init__(self,
        if self.pld_enabled():
            self.progressive_layer_drop = self._configure_progressive_layer_drop()

        if self.curriculum_enabled():
            self.curriculum_scheduler = self._configure_curriculum_scheduler()

        if self.global_rank == 0:
            self._config.print('DeepSpeedEngine configuration')
            if self.dump_state():
@@ -281,6 +285,12 @@ def eigenvalue_layer_name(self):

    def eigenvalue_layer_num(self):
        return self._config.eigenvalue_layer_num

    def curriculum_enabled(self):
        return self._config.curriculum_enabled

    def curriculum_params(self):
        return self._config.curriculum_params

    def tensorboard_enabled(self):
        return self._config.tensorboard_enabled
@@ -998,6 +1008,10 @@ def _configure_progressive_layer_drop(self):

        return pld

    def _configure_curriculum_scheduler(self):
        scheduler = CurriculumScheduler(self.curriculum_params())
        return scheduler

    @staticmethod
    def is_map_style_dataset(obj):
        return hasattr(obj, "__getitem__") and hasattr(obj, "__len__")
@@ -1101,6 +1115,12 @@ def forward(self, *inputs, **kwargs):
        if self.module.training and self.progressive_layer_drop:
            kwargs.update(self.progressive_layer_drop.get_state())

        if self.module.training and self.curriculum_enabled():
            self.curriculum_scheduler.update_difficulty(self.global_steps + 1)
            if self.curriculum_params()["curriculum_type"] == "seqlen":
                kwargs.update(
                    {"seqlen": self.curriculum_scheduler.get_current_difficulty()})

        if self.zero_optimization_partition_weights():
            # Enable automated discovery of external parameters by indicating that
            # we are in a forward pass.
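The engine only injects the seqlen keyword during training; what the model does with it is up to the user. A minimal sketch of a model that consumes the injected kwarg follows; the truncation logic is illustrative only and is not part of this PR.

import torch

class SeqlenAwareModel(torch.nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, **kwargs):
        # DeepSpeed passes the current curriculum difficulty as `seqlen`
        # during training when curriculum_type is "seqlen".
        seqlen = kwargs.get("seqlen", None)
        if self.training and seqlen is not None:
            # Illustrative use: shorten (batch, seq, hidden) inputs to the
            # current difficulty before running the rest of the model.
            x = x[:, :seqlen, :]
        return self.linear(x)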
10 changes: 10 additions & 0 deletions tests/unit/simple_model.py
@@ -21,6 +21,16 @@ def forward(self, x, y):
        return self.cross_entropy_loss(hidden_dim, y)


class Curriculum_SimpleModel(SimpleModel):
    def __init__(self, hidden_dim, empty_grad=False):
        super(Curriculum_SimpleModel, self).__init__(hidden_dim, empty_grad)

    def forward(self, x, y, **kwargs):
        seqlen = kwargs.get('seqlen', None)
        loss = super(Curriculum_SimpleModel, self).forward(x, y)
        return loss, seqlen


class UnusedParametersModel(SimpleModel):
    def __init__(self, hidden_dim, empty_grad=False):
        super().__init__(hidden_dim, empty_grad)
133 changes: 133 additions & 0 deletions tests/unit/test_curriculum_learning.py
@@ -0,0 +1,133 @@
import torch
import torch.distributed as dist
import deepspeed
import argparse
import pytest
import json
import os
import numpy as np
import time
from common import distributed_test
from simple_model import Curriculum_SimpleModel, random_dataloader, args_from_dict


def test_curriculum_scheduler_fixed_discrete(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015,
                "weight_decay": 0.01
            }
        },
        "gradient_clipping": 1.0,
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 16
        },
        "curriculum_learning": {
            "enabled": True,
            "curriculum_type": "seqlen",
            "min_difficulty": 1,
            "max_difficulty": 5,
            "schedule_type": "fixed_discrete",
            "schedule_config": {
                "difficulty": [1, 2, 3, 4, 5],
                "max_step": [2, 4, 6, 8]
            }
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    ground_truths = {1: 1, 2: 1, 3: 2, 4: 2, 5: 3, 6: 3, 7: 4, 8: 4}
    model = Curriculum_SimpleModel(hidden_dim)

    @distributed_test(world_size=[1, 2])
    def _test_curriculum_scheduler_fixed_discrete(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=20,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss, seqlen = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
            true_seqlen = 5
            if n + 1 in ground_truths:
                true_seqlen = ground_truths[n + 1]
            print('at step {} the seqlen is {}'.format(n + 1, seqlen))
            assert seqlen == true_seqlen, \
                f"Incorrect curriculum schedule: expected seqlen {true_seqlen} at step {n + 1}, got {seqlen}"

    _test_curriculum_scheduler_fixed_discrete(args=args,
                                              model=model,
                                              hidden_dim=hidden_dim)


def test_curriculum_scheduler_fixed_linear(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015,
                "weight_decay": 0.01
            }
        },
        "gradient_clipping": 1.0,
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 16
        },
        "curriculum_learning": {
            "enabled": True,
            "curriculum_type": "seqlen",
            "min_difficulty": 2,
            "max_difficulty": 10,
            "schedule_type": "fixed_linear",
            "schedule_config": {
                "total_step": 8,
                "difficulty_step": 2
            }
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    ground_truths = {1: 2, 2: 4, 3: 4, 4: 6, 5: 6, 6: 8, 7: 8, 8: 10, 9: 10, 10: 10}
    model = Curriculum_SimpleModel(hidden_dim)

    @distributed_test(world_size=[1, 2])
    def _test_curriculum_scheduler_fixed_linear(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=20,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss, seqlen = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
            if n + 1 in ground_truths:
                true_seqlen = ground_truths[n + 1]
                print('at step {} the seqlen is {}'.format(n + 1, seqlen))
                assert seqlen == true_seqlen, \
                    f"Incorrect curriculum schedule: expected seqlen {true_seqlen} at step {n + 1}, got {seqlen}"

    _test_curriculum_scheduler_fixed_linear(args=args,
                                            model=model,
                                            hidden_dim=hidden_dim)
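As a sanity check (not part of the PR), the fixed_linear ground_truths above can be reproduced from the formula in CurriculumScheduler.__fixed_root_update_difficulty with root_degree = 1: take floor((step / total_step) * (max_difficulty - min_difficulty) + min_difficulty), round down to a multiple of difficulty_step, and cap at max_difficulty.

import math

min_d, max_d, total_step, diff_step = 2, 10, 8, 2
for step, expected in {1: 2, 2: 4, 3: 4, 4: 6, 5: 6, 6: 8, 7: 8, 8: 10, 9: 10, 10: 10}.items():
    d = math.floor((step / total_step) * (max_d - min_d) + min_d)
    d -= d % diff_step           # keep difficulty a multiple of difficulty_step
    assert min(d, max_d) == expected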