Big science related changes (#1407)
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Shaden Smith <shaden.smith@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
Co-authored-by: eltonzheng <eltonz@microsoft.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
7 people authored Sep 29, 2021
1 parent 39c7744 commit e2fdd25
Showing 19 changed files with 438 additions and 209 deletions.
2 changes: 1 addition & 1 deletion DeepSpeedExamples
20 changes: 14 additions & 6 deletions deepspeed/runtime/activation_checkpointing/checkpointing.py
@@ -24,7 +24,7 @@

from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.utils import logger
from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage
from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers

# DeepSpeed Checkpointing Enabled or Disabled
@@ -213,9 +213,12 @@ def model_parallel_cuda_manual_seed(seed):
model parallel regions.
"""
global mpu

tp_rank = bwc_tensor_model_parallel_rank(mpu)

# 2718 is just for fun and any POSITIVE value will work.
offset = seed + 2718
model_parallel_seed = offset + mpu.get_model_parallel_rank()
model_parallel_seed = offset + tp_rank
# Data parallel gets the original seed.
data_parallel_seed = seed

@@ -225,7 +228,7 @@ def model_parallel_cuda_manual_seed(seed):
'model parallel rank {}, and data parallel rank {} with '
'model parallel seed: {} and data parallel seed: {}'.format(
torch.distributed.get_rank(),
mpu.get_model_parallel_rank(),
tp_rank,
mpu.get_data_parallel_rank(),
model_parallel_seed,
data_parallel_seed),
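For concreteness, a small worked example of the seeding scheme above, with illustrative values (`tp_rank` is what `bwc_tensor_model_parallel_rank(mpu)` would return on this process):

```python
# Illustrative values only: how the two seeds are derived for one process.
seed = 1234
tp_rank = 3                               # assumed tensor-parallel rank of this process

offset = seed + 2718                      # fixed positive offset, as in the code above
model_parallel_seed = offset + tp_rank    # 1234 + 2718 + 3 = 3955, unique per TP rank
data_parallel_seed = seed                 # 1234, shared across data-parallel replicas
```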
@@ -515,9 +518,14 @@ def save_args_for_backward(*all_args):
global data_offsets, size_offsets
if mp_rank is None:
if mpu is not None:
mp_rank = mpu.get_model_parallel_rank()
mp_size = mpu.get_model_parallel_world_size()
mp_group = mpu.get_model_parallel_group()
if hasattr(mpu, 'get_tensor_model_parallel_rank'):
mp_rank = mpu.get_tensor_model_parallel_rank()
mp_size = mpu.get_tensor_model_parallel_world_size()
mp_group = mpu.get_tensor_model_parallel_group()
else:
mp_rank = mpu.get_model_parallel_rank()
mp_size = mpu.get_model_parallel_world_size()
mp_group = mpu.get_model_parallel_group()
else:
mp_rank = 0
mp_size = 1
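The `hasattr` fallback above follows the same backward-compatibility idea as the `bwc_tensor_model_parallel_rank` helper imported earlier: newer Megatron-style `mpu` objects expose `get_tensor_model_parallel_*` methods, while older ones only provide `get_model_parallel_*`. A minimal sketch of that pattern, assuming this is roughly what the helper in `deepspeed.runtime.utils` does:

```python
def bwc_tensor_model_parallel_rank(mpu=None):
    """Sketch: resolve the tensor-parallel rank across old and new mpu APIs."""
    if mpu is None:
        # No model-parallel unit configured; behave like a single-rank group.
        return 0
    if hasattr(mpu, 'get_tensor_model_parallel_rank'):
        # Newer Megatron-LM style interface.
        return mpu.get_tensor_model_parallel_rank()
    # Older interface where "model parallel" meant tensor parallelism.
    return mpu.get_model_parallel_rank()
```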
89 changes: 63 additions & 26 deletions deepspeed/runtime/engine.py
@@ -42,6 +42,7 @@
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules
import deepspeed.utils.groups as groups
from deepspeed.runtime.utils import get_grad_norm
from deepspeed.utils import logger, log_dist, init_distributed
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.utils.debug import debug_extract_module_and_param_names
@@ -140,6 +141,8 @@ def __init__(self,
self.dist_backend = "nccl"
self.has_moe_layers = False
self.num_experts = None
self._step_applied = False
self._global_grad_norm = None

# for debug purposes - can then debug print: debug_get_module_name(module)
debug_extract_module_and_param_names(model)
@@ -259,6 +262,40 @@ def get_batch_info(self):
"""
return self.train_batch_size, self.train_micro_batch_size_per_gpu, self.gradient_accumulation_steps

def set_train_batch_size(self, train_batch_size):
"""Adjust the global batch size by increasing or decreasing the number of
micro-batches (i.e., gradient accumulation steps). The size of each micro-batch
(i.e., ``train_micro_batch_size_per_gpu``) is not changed.
Args:
train_batch_size (int): The new global batch size for training.
Raises:
ValueError: if ``train_batch_size`` is not divisible by the
configured micro-batch size and data parallelism.
"""
if train_batch_size % (self.train_micro_batch_size_per_gpu() *
self.dp_world_size) != 0:
#print(f'{train_batch_size=} {self.train_micro_batch_size_per_gpu()=} {self.dp_world_size=}')
raise ValueError(
'Train batch size must be divisible by micro-batch size * data parallelism')
new_gas = train_batch_size // (self.train_micro_batch_size_per_gpu() *
self.dp_world_size)
# overwrite config
self._config.train_batch_size = train_batch_size
self._config.gradient_accumulation_steps = new_gas
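A hedged usage sketch for the new method; `engine` is assumed to be an initialized `DeepSpeedEngine` with `train_micro_batch_size_per_gpu=4` and a data-parallel world size of 8, so the global batch must be a multiple of 32:

```python
# Assumed setup: micro-batch 4, dp_world_size 8 -> global batch must divide by 32.
engine.set_train_batch_size(256)    # OK: gradient_accumulation_steps becomes 256 // 32 = 8
# engine.set_train_batch_size(100)  # would raise ValueError (100 is not a multiple of 32)
```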

def get_global_grad_norm(self) -> float:
"""Return the 2-norm of all gradients. If there is model parallelism,
the norm will be global.
The computed norm will be cached and reused until the next step() pass.
.. note::
In the presence of model parallelism, this is a collective call
and acts as a barrier among ``mpu.get_model_parallel_group()``.
Returns:
float: norm
"""
return self._global_grad_norm
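A brief usage sketch, assuming `engine` is a `DeepSpeedEngine` whose wrapped optimizer caches `_global_grad_norm` during `step()` as shown further down; `loss` comes from the user's forward pass:

```python
# Illustrative monitoring hook.
engine.backward(loss)
engine.step()
grad_norm = engine.get_global_grad_norm()
if grad_norm is not None:
    print(f"step {engine.global_steps}: global grad norm = {grad_norm:.4f}")
```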

def checkpoint_tag_validation_enabled(self):
return self._config.checkpoint_tag_validation_enabled

@@ -1146,6 +1183,18 @@ def is_iterable_style_dataset(obj):
def dataloader_drop_last(self):
return self._config.dataloader_drop_last

def was_step_applied(self) -> bool:
"""Returns True if the latest ``step()`` produced in parameter updates.
Note that a ``False`` return is not an error condition. Steps are frequently
no-ops, such as between gradient accumulation boundaries or when overflows
occur.
Returns:
bool: Whether the latest ``step()`` modified model parameters.
"""
return self._step_applied
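A sketch of how a caller might use this flag; `data_loader`, `compute_loss`, and `on_real_update` are placeholders, not part of DeepSpeed:

```python
for batch in data_loader:
    loss = compute_loss(engine, batch)    # user-defined forward pass
    engine.backward(loss)
    engine.step()
    if not engine.was_step_applied():
        # Either a gradient-accumulation micro-step or a skipped step after
        # an fp16 overflow; not an error, but don't treat it as a real update.
        continue
    on_real_update(engine)                # e.g. advance a custom schedule or log
```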

def deepspeed_io(self,
dataset,
batch_size=None,
@@ -1432,6 +1481,9 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
mpu=self.mpu)
self.optimizer.step()

if hasattr(self.optimizer, '_global_grad_norm'):
self._global_grad_norm = self.optimizer._global_grad_norm

# Quantize the updated parameter if there is no overflow
if self.quantizer:
self.quantizer.quantize(
@@ -1454,12 +1506,19 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
overflow = False
if hasattr(self.optimizer, 'overflow'):
overflow = self.optimizer.overflow
self._step_applied = not overflow

if overflow:
self.skipped_steps += 1
else:
if self.lr_scheduler is not None:
self.lr_scheduler.step(**(lr_kwargs or {}))
try:
self.lr_scheduler.step(**(lr_kwargs or {}))
except TypeError:
# XXX Hack to work with Megatron 2.0 and DeepSpeed pipelines.
# We don't currently have a way to specify lr_kwargs from
# pipe_engine.train_batch()
self.lr_scheduler.step(increment=self.train_batch_size())

if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0:
self._report_progress(self.global_steps + 1)
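The `TypeError` fallback above assumes a Megatron-style scheduler whose `step()` takes a sample-count `increment` instead of the usual keyword arguments. A toy scheduler with that interface, purely to illustrate why the fallback call succeeds (this is not Megatron's actual class):

```python
class SampleBasedLR:
    """Toy LR schedule keyed on consumed samples (illustrative only)."""
    def __init__(self, optimizer, warmup_samples=10000, base_lr=1e-4):
        self.optimizer = optimizer
        self.warmup_samples = warmup_samples
        self.base_lr = base_lr
        self.num_samples = 0

    def step(self, increment):
        # Matches the fallback call: lr_scheduler.step(increment=train_batch_size)
        self.num_samples += increment
        scale = min(1.0, self.num_samples / self.warmup_samples)
        for group in self.optimizer.param_groups:
            group['lr'] = self.base_lr * scale
```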
@@ -1479,6 +1538,8 @@ def step(self, lr_kwargs=None):
"init in order to use step"
report_progress = self.global_rank == 0 if self.global_rank else True

self._step_applied = False # assume False, will flip to True

# Update the model when we reach gradient accumulation boundaries
if self.is_gradient_accumulation_boundary():
self.gas_boundary_ctr += 1
@@ -2413,7 +2474,7 @@ def _copy_recovery_script(self, save_path):
script = "zero_to_fp32.py"
src = os.path.join(base_dir, "utils", script)
dst = os.path.join(save_path, script)
logger.info(f"creating recovery script {dst}")
#logger.info(f"creating recovery script {dst}")
copyfile(src, dst)
# make executable
os.chmod(dst, os.stat(dst).st_mode | stat.S_IEXEC)
@@ -2530,27 +2591,3 @@ def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"):
os.makedirs(save_dir, exist_ok=True)
logger.info(f"Saving model weights to {path}")
torch.save(state_dict, path)

def set_train_batch_size(self, train_batch_size):
"""Adjust the global batch size by increasing or decreasing the size of
each micro-batch (i.e., ``train_micro_batch_size_per_gpu``). The number of
micro-batches (i.e., gradient accumulation steps) is not changed.
Args:
train_batch_size (int): The new global batch size for training.
Raises:
ValueError: if ``train_batch_size`` is not divisible by the
configured gradient_accumulation_steps and data parallelism.
"""

if train_batch_size % (self.gradient_accumulation_steps() *
self.dp_world_size) != 0:
raise ValueError(
f'Train batch size must be divisible by gradient_accumulation_steps * data parallelism'
)

new_micro_bsz = train_batch_size // (self.gradient_accumulation_steps() *
self.dp_world_size)

# overwrite config
self._config.train_batch_size = train_batch_size
self._config.train_micro_batch_size_per_gpu = new_micro_bsz
20 changes: 11 additions & 9 deletions deepspeed/runtime/fp16/fused_optimizer.py
@@ -9,7 +9,7 @@
import math
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import groups, logger, log_dist
import torch.distributed as dist
@@ -47,6 +47,8 @@ def __init__(self,
self.fp16_groups_flat = []
self.fp32_groups_flat = []

self._global_grad_norm = 0.

# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
# push this group to list before modify
@@ -163,8 +165,11 @@ def step_fused_adam(self, closure=None):
"scale: {}, reducing to {}".format(prev_scale,
self.cur_scale))
return self.overflow

self._global_grad_norm = get_global_norm(norm_list=norm_groups)

combined_scale = self.unscale_and_clip_grads(grads_groups_flat,
norm_groups,
self._global_grad_norm,
apply_scale=False)
# norm is in fact norm*cur_scale
self.optimizer.step(grads=[[g] for g in grads_groups_flat],
@@ -268,8 +273,10 @@ def step(self, closure=None):

self.stop_timers([COMPUTE_NORM])

self._global_grad_norm = get_global_norm(norm_list=[all_groups_norm])

self.start_timers([UNSCALE_AND_CLIP])
self.unscale_and_clip_grads(grads_groups_flat, [all_groups_norm])
self.unscale_and_clip_grads(grads_groups_flat, self._global_grad_norm)
self.stop_timers([UNSCALE_AND_CLIP])

self.start_timers([BASIC_STEP])
@@ -294,12 +301,7 @@ def step(self, closure=None):

return self.overflow

def unscale_and_clip_grads(self, grad_groups_flat, norm_groups, apply_scale=True):
total_norm = 0.0
for norm in norm_groups:
total_norm += norm**2.0
total_norm = math.sqrt(total_norm)

def unscale_and_clip_grads(self, grad_groups_flat, total_norm, apply_scale=True):
# compute combined scale factor for this group
combined_scale = self.cur_scale
if self.clip_grad > 0.:
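The refactor above moves the norm combination out of `unscale_and_clip_grads`: the deleted loop summed squared per-group norms and took a square root, and `get_global_norm` now produces that single value up front. A minimal sketch of the equivalent computation, under the assumption that the helper keeps the same arithmetic:

```python
import math

def get_global_norm_sketch(norm_list):
    # Combine per-group L2 norms into one global L2 norm, matching the
    # arithmetic the deleted inline loop performed.
    total = 0.0
    for norm in norm_list:
        total += norm ** 2.0
    return math.sqrt(total)

# Example: two groups with norms 3.0 and 4.0 combine to a global norm of 5.0.
assert get_global_norm_sketch([3.0, 4.0]) == 5.0
```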
17 changes: 8 additions & 9 deletions deepspeed/runtime/fp16/unfused_optimizer.py
@@ -10,7 +10,7 @@
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import math

from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import logger

@@ -33,6 +33,7 @@ def __init__(self,
fused_lamb_legacy=False):

self.fused_lamb_legacy = fused_lamb_legacy
self._global_grad_norm = 0.

if torch.distributed.get_rank() == 0:
logger.info(f'Fused Lamb Legacy : {self.fused_lamb_legacy} ')
@@ -163,7 +164,9 @@ def step_fused_lamb(self, closure=None):
self.cur_scale))
return self.overflow

combined_scale = self.unscale_and_clip_grads(norm_groups, apply_scale=False)
self._global_grad_norm = get_global_norm(norm_list=norm_groups)
combined_scale = self.unscale_and_clip_grads(self._global_grad_norm,
apply_scale=False)
self.optimizer.step(grads=grads_groups,
output_params=self.fp16_groups,
scale=combined_scale)
@@ -216,7 +219,8 @@ def step(self, closure=None):
else:
fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)

self.unscale_and_clip_grads(norm_groups)
self._global_grad_norm = get_global_norm(norm_list=norm_groups)
self.unscale_and_clip_grads(self._global_grad_norm)

self.optimizer.step()

@@ -231,12 +235,7 @@ def step(self, closure=None):

return self.overflow

def unscale_and_clip_grads(self, norm_groups, apply_scale=True):
total_norm = 0.0
for norm in norm_groups:
total_norm += norm**2.0
total_norm = math.sqrt(total_norm)

def unscale_and_clip_grads(self, total_norm, apply_scale=True):
# compute combined scale factor for this group
combined_scale = self.cur_scale
if self.clip_grad > 0.:
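For context, both optimizers then fold the cached `total_norm` and the current loss scale into a single division factor inside `unscale_and_clip_grads` (the hunks above are truncated before that point). A hedged sketch of that combined-scale arithmetic; the real implementation may add an epsilon or differ in detail:

```python
def combined_scale_sketch(total_norm, cur_scale, clip_grad):
    # Fold loss-scale removal and gradient clipping into one scale factor;
    # gradients are later divided by the returned value.
    combined_scale = cur_scale
    if clip_grad > 0.0:
        # Norm of the unscaled gradients relative to the clipping threshold.
        clip = (total_norm / cur_scale) / clip_grad
        if clip > 1.0:
            combined_scale = clip * cur_scale
    return combined_scale
```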