Improve overflow handling (#2944)
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
tjruwase and jeffra authored Mar 7, 2023
1 parent 87eaf8f commit 80d8fcb
Showing 3 changed files with 36 additions and 39 deletions.
16 changes: 16 additions & 0 deletions deepspeed/runtime/fp16/loss_scaler.py
@@ -16,6 +16,8 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/master/fp16/loss_scaler.py
 #Commit: 93ab4bea59dc5cbf97c079d313741866af4deac9
 
+import torch
+
 INITIAL_LOSS_SCALE = 'init_scale'
 SCALE_WINDOW = 'scale_window'
 DELAYED_SHIFT = 'delayed_shift'
@@ -35,6 +37,7 @@ class LossScalerBase:
     """
     def __init__(self, cur_scale):
         self.cur_scale = cur_scale
+        self.dynamic = False
 
     @property
     def loss_scale(self):
@@ -117,6 +120,7 @@ def __init__(self,
         self.cur_hysteresis = delayed_shift
         self.consecutive_hysteresis = consecutive_hysteresis
         self.raise_error_at_min_scale = raise_error_at_min_scale
+        self.dynamic = True
 
     # `params` is a list / generator of torch.Variable
     def has_overflow_serial(self, params):
@@ -170,6 +174,18 @@ def update_scale(self, overflow):
         self.cur_iter += 1
 
 
+# Although loss scaling is only defined for fp16, yet for backwards compatibility
+# we still create a scaler for other dtypes (fp32, bf16) which does not perform any scaling.
+def CreateLossScaler(dtype, static_loss_scale, dynamic_scaling, dynamic_loss_args):
+    if dtype == torch.half and dynamic_scaling:
+        if dynamic_loss_args is None:
+            return DynamicLossScaler()
+        return DynamicLossScaler(**dynamic_loss_args)
+
+    loss_scale_value = static_loss_scale if dtype == torch.half else 1.0
+    return LossScaler(scale=loss_scale_value)
+
+
 ##############################################################
 # Example usage below here -- assuming it's in a separate file
 ##############################################################
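For reference, the new factory keys entirely off the dtype: only fp16 combined with dynamic scaling yields a DynamicLossScaler, while bf16 and fp32 fall back to a static, no-op LossScaler pinned to 1.0, and the new `dynamic` attribute lets callers read back which case they got. A minimal usage sketch of that behavior (illustrative values, not part of the commit):

import torch
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler

# fp16 with dynamic scaling -> DynamicLossScaler, scaler.dynamic == True
fp16_scaler = CreateLossScaler(dtype=torch.half,
                               static_loss_scale=128.0,
                               dynamic_scaling=True,
                               dynamic_loss_args=None)
assert fp16_scaler.dynamic

# bf16 (or fp32) -> static LossScaler with scale 1.0, scaler.dynamic == False
bf16_scaler = CreateLossScaler(dtype=torch.bfloat16,
                               static_loss_scale=128.0,
                               dynamic_scaling=True,
                               dynamic_loss_args=None)
assert not bf16_scaler.dynamic and bf16_scaler.loss_scale == 1.0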
28 changes: 10 additions & 18 deletions deepspeed/runtime/zero/stage3.py
@@ -10,7 +10,7 @@
 
 from deepspeed.runtime import ZeROOptimizer
 from deepspeed.utils import logger
-from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
+from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced
 from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter
 from deepspeed.runtime.zero.partition_parameters import *
@@ -332,18 +332,11 @@ def __init__(self,
         #exit(0)
 
         # we may have a way of fusing dynamic scale. Do not support for now
-        if self.dtype == torch.float or not dynamic_loss_scale:
-            loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale
-
-            self.dynamic_loss_scale = False
-            self.loss_scaler = LossScaler(scale=loss_scale_value)
-        else:
-            if dynamic_loss_args is None:
-                self.loss_scaler = DynamicLossScaler()
-            else:
-                self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
-
-            self.dynamic_loss_scale = True
+        self.loss_scaler = CreateLossScaler(dtype=self.dtype,
+                                            static_loss_scale=static_loss_scale,
+                                            dynamic_scaling=dynamic_loss_scale,
+                                            dynamic_loss_args=dynamic_loss_args)
+        self.dynamic_loss_scale = self.loss_scaler.dynamic
 
         self.debug_fp16_grads = [{} for _ in self.fp16_groups]
 
@@ -1844,11 +1837,10 @@ def _overflow_clean_up(self, prev_scale):
         see_memory_usage('After overflow after clearing gradients', force=False)
 
         if dist.get_rank() == 0:
-            logger.info(
-                "[deepspeed] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
-                "reducing to {}".format(dist.get_rank(),
-                                        prev_scale,
-                                        self.loss_scale))
+            overflow_msg = f"[deepspeed] OVERFLOW! Rank {dist.get_rank()} Skipping step."
+            if self.dtype == torch.half:
+                overflow_msg += f" Attempted loss scale: {prev_scale}, reducing to {self.loss_scale}"
+            logger.info(overflow_msg)
 
     @instrument_w_nvtx
     def _overflow_check_and_loss_scale_update(self):
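The reworked log message only appends the loss-scale details when the optimizer dtype is fp16, since that is the only case in which a scale is actually being reduced. A standalone sketch of the same construction, using a hypothetical helper and made-up scale values:

import torch

def build_overflow_msg(rank, dtype, prev_scale, cur_scale):
    # Mirrors the message logic added in _overflow_clean_up()/step(); helper name is illustrative.
    msg = f"[deepspeed] OVERFLOW! Rank {rank} Skipping step."
    if dtype == torch.half:
        msg += f" Attempted loss scale: {prev_scale}, reducing to {cur_scale}"
    return msg

print(build_overflow_msg(0, torch.half, 65536.0, 32768.0))   # fp16: includes scale details
print(build_overflow_msg(0, torch.bfloat16, 1.0, 1.0))       # bf16: plain skip message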
31 changes: 10 additions & 21 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -9,7 +9,7 @@
 from collections import OrderedDict
 
 from deepspeed.runtime import ZeROOptimizer
-from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
+from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank,
                                      get_global_norm,
                                      empty_cache,
@@ -506,21 +506,11 @@ def __init__(self,
         self.external_loss_scale = None
 
         # we may have a way of fusing dynamic scale. Do not support for now
-        if self.dtype == torch.float or self.dtype == torch.bfloat16 or not dynamic_loss_scale:
-            loss_scale_value = 1.0 if (
-                (self.dtype == torch.float) or
-                (self.dtype == torch.bfloat16)) else static_loss_scale
-
-            self.dynamic_loss_scale = False
-            self.loss_scaler = LossScaler(scale=loss_scale_value)
-            cur_iter = 0
-        else:
-            if dynamic_loss_args is None:
-                self.loss_scaler = DynamicLossScaler()
-            else:
-                self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
-
-            self.dynamic_loss_scale = True
+        self.loss_scaler = CreateLossScaler(dtype=self.dtype,
+                                            static_loss_scale=static_loss_scale,
+                                            dynamic_scaling=dynamic_loss_scale,
+                                            dynamic_loss_args=dynamic_loss_args)
+        self.dynamic_loss_scale = self.loss_scaler.dynamic
 
         see_memory_usage("Before initializing optimizer states", force=True)
         self.initialize_optimizer_states()
@@ -1788,11 +1778,10 @@ def step(self, closure=None):
         self._update_scale(self.overflow)
         if self.overflow:
             if dist.get_rank() == 0:
-                logger.info(
-                    "[deepspeed] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
-                    "reducing to {}".format(dist.get_rank(),
-                                            prev_scale,
-                                            self.loss_scale))
+                overflow_msg = f"[deepspeed] OVERFLOW! Rank {dist.get_rank()} Skipping step."
+                if self.dtype == torch.half:
+                    overflow_msg += f" Attempted loss scale: {prev_scale}, reducing to {self.loss_scale}"
+                logger.info(overflow_msg)
 
             see_memory_usage('After overflow before clearing gradients')
             self.zero_grad(set_to_none=True)
