Add API for updating ZeRO gradients #6590

Merged (8 commits) on Oct 14, 2024
8 changes: 4 additions & 4 deletions deepspeed/runtime/engine.py
@@ -2145,8 +2145,6 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
else:
self.zero_grad()

report_progress = self.global_rank == 0 if self.global_rank else True

# Check overflow here since in DS fp16 optimizer, the overflow is updated in above step() function.
overflow = False
if hasattr(self.optimizer, "overflow"):
@@ -2166,8 +2164,10 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
# pipe_engine.train_batch()
self.lr_scheduler.step(self.train_batch_size())

if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0:
self._report_progress(self.global_steps + 1)
if self.steps_per_print() is not None:
report_progress = self.global_rank == 0 if self.global_rank else True
if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0:
self._report_progress(self.global_steps + 1)

self.losses = None
self.global_steps += 1
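
Since `steps_per_print()` may now return None, the progress report is wrapped in an explicit guard. A minimal standalone sketch of the resulting behavior (the helper name below is hypothetical, not part of the engine):

```python
# Hypothetical helper mirroring the guarded check above: a None cadence disables
# progress reporting instead of raising a TypeError on the modulo.
def should_report_progress(global_steps: int, steps_per_print) -> bool:
    if steps_per_print is None:
        return False  # reporting disabled
    return (global_steps + 1) % steps_per_print == 0

assert should_report_progress(global_steps=99, steps_per_print=100)
assert not should_report_progress(global_steps=99, steps_per_print=None)
```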
50 changes: 43 additions & 7 deletions deepspeed/runtime/zero/stage3.py
@@ -2298,6 +2298,24 @@ def get_fp32_grad_for_param(self, param) -> Tensor:

return self._fp32_state_allgather(param, fp32_grad)

def set_fp32_grad_for_param(self, value, param):
if not param.requires_grad:
return

if not get_accelerator().resolves_data_dependency():
self.reduce_and_partition_stream.synchronize()

if self.offload_optimizer:
group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]
fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements)
else:
fp32_grad = self.__param_id_to_grad_partition[param.ds_id]

my_rank = dist.get_rank(group=self.dp_process_group)
value_partition = value.flatten().narrow(0, fp32_grad.numel() * my_rank, fp32_grad.numel())

fp32_grad.data.copy_(value_partition.data)

def _get_fp32_opt_state_partition(self, param, optim_state_key=None):
if not get_accelerator().resolves_data_dependency():
self.reduce_and_partition_stream.synchronize()
@@ -2346,12 +2364,6 @@ def set_full_hp_param(self, value, param, optim_state_key=None):

### Local API START ###

def get_local_fp32_param(self, param, optim_state_key=None) -> Tensor:
if not param.requires_grad:
return None
fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key)
return fp32_opt_state

def get_local_fp32_grad_for_param(self, param) -> Tensor:
if not param.requires_grad:
return None
@@ -2366,6 +2378,30 @@ def get_local_fp32_grad_for_param(self, param) -> Tensor:
fp32_grad = self.__param_id_to_grad_partition[param.ds_id].float()
return fp32_grad

def set_local_grad_for_param(self, value, param):
if not param.requires_grad:
return

assert value.numel() == param.ds_tensor.numel(
), f"Element counts do not match: {value.numel()} != {param.ds_tensor.numel()}"

if not get_accelerator().resolves_data_dependency():
self.reduce_and_partition_stream.synchronize()

if self.offload_optimizer:
group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]
fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements)
else:
fp32_grad = self.__param_id_to_grad_partition[param.ds_id]

fp32_grad.data.copy_(value.flatten().data)

def get_local_fp32_param(self, param, optim_state_key=None) -> Tensor:
if not param.requires_grad:
return None
fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key)
return fp32_opt_state

def set_local_hp_param(self, value, param, optim_state_key=None):
if not param.requires_grad:
return
@@ -2380,7 +2416,7 @@ def set_local_hp_param(self, value, param, optim_state_key=None):

if self._swappable_optimizer_subgroup(group_idx):
self._optimizer_states_and_gradient_swap_out(group_idx)
logger.info(f"[set_local_hp_param][update the params' value successfully]")
# logger.info(f"[set_local_hp_param][update the params' value successfully]")

### Local API END ###
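
The two setters above interpret `value` differently: `set_fp32_grad_for_param` receives the full, un-partitioned gradient and each rank copies only its own contiguous slice, while `set_local_grad_for_param` receives a tensor already sized to the local partition. A plain-tensor sketch of the slicing in the full-gradient path, assuming equal-sized partitions and ignoring padding and optimizer offload:

```python
import torch

# Standalone illustration of the rank-based narrow() used in set_fp32_grad_for_param.
world_size, partition_numel = 4, 6
full_grad = torch.arange(world_size * partition_numel, dtype=torch.float32)

for my_rank in range(world_size):  # in DeepSpeed: dist.get_rank(group=dp_process_group)
    fp32_grad = torch.empty(partition_numel)  # stands in for this rank's gradient partition
    value_partition = full_grad.flatten().narrow(0, fp32_grad.numel() * my_rank, fp32_grad.numel())
    fp32_grad.data.copy_(value_partition.data)
    print(my_rank, fp32_grad.tolist())  # rank 0 -> [0..5], rank 1 -> [6..11], ...
```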

6 changes: 3 additions & 3 deletions deepspeed/utils/__init__.py
@@ -12,10 +12,10 @@
# TODO: Move tensor fragment and mixed precision to zero utils
from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad, map_to_flat_opt_states
from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
from .tensor_fragment import set_full_hp_param
from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state
from .tensor_fragment import set_full_hp_param, set_full_hp_grad
from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state, safe_set_full_grad
from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_optimizer_state
from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_grad, safe_set_local_optimizer_state
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter
from .mixed_precision_linkage import link_hp_params, lazy_init_hp_params_optimizer_state
from deepspeed.runtime.dataloader import RepeatingLoader
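
With these re-exports, the new gradient setters are available from the package root next to the existing getters; a short sketch of the import surface (usage is shown after tensor_fragment.py below):

```python
# Public names touched by this PR, importable from the top-level utils package.
from deepspeed.utils import safe_set_full_grad, safe_set_local_grad  # new setters
from deepspeed.utils import safe_get_full_grad, safe_get_local_grad  # existing getters
```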
3 changes: 2 additions & 1 deletion deepspeed/utils/mixed_precision_linkage.py
@@ -5,7 +5,7 @@

import types
from deepspeed.utils import get_full_hp_param, get_full_hp_grad, get_hp_fragment_mapping
from deepspeed.utils import set_full_hp_param
from deepspeed.utils import set_full_hp_param, set_full_hp_grad


def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
@@ -35,6 +35,7 @@ def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_gr
lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param)
lp_param.get_full_hp_grad = types.MethodType(get_full_hp_grad, lp_param)
lp_param.set_full_hp_param = types.MethodType(set_full_hp_param, lp_param)
lp_param.set_full_hp_grad = types.MethodType(set_full_hp_grad, lp_param)

# lp_param overlaps with partition if both are true
# 1) current_offset < partition_end,
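
For readers unfamiliar with the binding pattern used in `_init_lp_to_hp_mapping`, `types.MethodType` attaches a free function to a single object so it can be called like a method on that parameter. A self-contained sketch with illustrative names (not DeepSpeed's):

```python
import types

def set_full_hp_grad(self, value):
    # Stand-in for the real deepspeed.utils.set_full_hp_grad; just records the call.
    self.last_update = value

class FakeParam:
    pass

lp_param = FakeParam()
lp_param.set_full_hp_grad = types.MethodType(set_full_hp_grad, lp_param)
lp_param.set_full_hp_grad(0.5)
assert lp_param.last_update == 0.5  # the bound method received lp_param as `self`
```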
127 changes: 80 additions & 47 deletions deepspeed/utils/tensor_fragment.py
@@ -57,6 +57,17 @@ def get_hp_fragment(self, optim_state_key=None):
return self.hp_fragment
return self.get_optim_state_fragment(optim_state_key)

def get_lp_grad_fragment(self, index_in_param_group):
if self.use_offload:
gradient_dict = self.offload_gradient_dict
else:
gradient_dict = self.gradient_dict

if self.param_group_index not in gradient_dict or gradient_dict[self.param_group_index] is None:
raise ValueError("Gradients are only available immediately after backward and before engine step")

return gradient_dict[self.param_group_index][index_in_param_group]


def map_to_flat_opt_states(flat_hp_tensor, lp_tensors, optim_state, opt_keys):
for key in opt_keys:
@@ -95,17 +106,7 @@ def set_full_hp_param(self, value, optim_state_key=None):
def get_full_hp_grad(self):
reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
if self._hp_mapping is not None:
hp_mapping = self._hp_mapping

if hp_mapping.use_offload:
gradient_dict = hp_mapping.offload_gradient_dict
else:
gradient_dict = hp_mapping.gradient_dict

if hp_mapping.param_group_index not in gradient_dict or gradient_dict[hp_mapping.param_group_index] is None:
raise ValueError("Gradients are only available immediately after backward and before engine step")

lp_grad_fragment = gradient_dict[hp_mapping.param_group_index][self._index_in_param_group]
lp_grad_fragment = self._hp_mapping.get_lp_grad_fragment(self._index_in_param_group)
hp_grad_fragment = lp_grad_fragment.to(torch.float32).flatten()

lp_frag_address = self._hp_mapping.lp_fragment_address
@@ -120,6 +121,14 @@ def get_full_hp_grad(self):
return reduce_buffer.reshape_as(self)


def set_full_hp_grad(self, value):
if self._hp_mapping is not None:
lp_grad_fragment = self._hp_mapping.get_lp_grad_fragment(self._index_in_param_group)
lp_frag_address = self._hp_mapping.lp_fragment_address
value_fragment = torch.narrow(value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel)
lp_grad_fragment.data.copy_(value_fragment.data.reshape_as(lp_grad_fragment.data))


def safe_get_full_fp32_param(param):
"""Assemble and return the fp32 parameter of a low-precision (e.g., fp16) parameter.

@@ -188,7 +197,10 @@ def safe_set_full_optimizer_state(param, value, optim_state_key):

# TODO: Figure out the correct return dtype
def safe_get_full_grad(param):
"""Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter.
"""
Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter.
The return data type is that used for gradient accumulation. This is usually the param data type,
but could also be different (e.g., bf16 param training with fp32 gradient accumulation).

Args:
param (``torch.nn.Parameter``): A model parameter
@@ -207,74 +219,95 @@ def safe_get_full_grad(param):
return None


def safe_set_full_grad(param, value):
"""
Update the partitioned gradient of a low-precision (e.g., fp16) parameter.
To avoid precision issues, the update value should have the data type of
gradient accumulation.

Args:
param (``torch.nn.Parameter``): A model parameter
value (``torch.Tensor``): The un-partitioned new gradient value.
"""
if param.grad is not None:
param.grad.copy_(value)
elif hasattr(param, 'ds_id'):
# ZeRO stage 3 param
param._z3_optimizer.set_fp32_grad_for_param(value, param)
elif hasattr(param, '_hp_mapping'):
# ZeRO stage 1, 2, and bf16_optimizer params
param.set_full_hp_grad(value)


### Local API START ###
def safe_get_local_grad(param):
"""Get the fp32 gradient of a partitioned parameter.
"""
Get the local gradient partition of a ZeRO-3 partitioned parameter.
The return data type is that used for gradient accumulation. This is usually the param data type,
but could also be different (e.g., bf16 param training with fp32 gradient accumulation).
Args:
param (``torch.nn.Parameter``): A model parameter
"""
if param.grad is not None:
return param.grad
assert hasattr(param, 'ds_id'), f'This API is only defined for ZeRO-3 partitioned parameters'
return param._z3_optimizer.get_local_fp32_grad_for_param(param)

# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
return param._z3_optimizer.get_local_fp32_grad_for_param(param)

return None
def safe_set_local_grad(param, value):
"""
Update the local gradient partition of a ZeRO-3 partitioned parameter.
To avoid precision issues, the update value should have the data type of
gradient accumulation.

Args:
param (``torch.nn.Parameter``): A model parameter.
value (``torch.Tensor``): New value of local gradient partition.
"""
assert hasattr(param, 'ds_id'), f'This API is only defined for ZeRO-3 partitioned parameters'
param._z3_optimizer.set_local_grad_for_param(value, param)


def safe_get_local_fp32_param(param):
"""Get the fp32 partitioned parameter.
"""Get the local partition of a ZeRO-3 partitioned parameter in fp32 precision.
Args:
param (``torch.nn.Parameter``): A model parameter
param (``torch.nn.Parameter``): A model parameter.
"""
# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
return param._z3_optimizer.get_local_fp32_param(param)

return None
assert hasattr(param, 'ds_id'), f'This API is only defined for ZeRO-3 partitioned parameters'
return param._z3_optimizer.get_local_fp32_param(param)


def safe_get_local_optimizer_state(param, optim_state_key):
"""Get the fp32 optimizer state of a partitioned parameter.
"""Get the local optimizer state partition of ZeRO-3 partitioned parameter in fp32 precision.
Args:
param (``torch.nn.Parameter``): A model parameter
optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)
"""
# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
return param._z3_optimizer.get_local_fp32_param(param, optim_state_key)

return None
assert hasattr(param, 'ds_id'), f'This API is only defined for ZeRO-3 partitioned parameters'
return param._z3_optimizer.get_local_fp32_param(param, optim_state_key)


def safe_set_local_optimizer_state(param, value, optim_state_key):
"""Update the fp32 optimizer state of a partitioned parameter.
"""Update the local optimizer state partition of a ZeRO-3 partitioned parameter.
Args:
param (``torch.nn.Parameter``): A model parameter
value (``torch.Tensor``): New value
optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)
param (``torch.nn.Parameter``): A model parameter.
value (``torch.Tensor``): New value of local optimizer state partition.
optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer).
"""
# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
param._z3_optimizer.set_local_hp_param(value, param, optim_state_key)
assert hasattr(param, 'ds_id'), f'This API is only defined for ZeRO-3 partitioned parameters'
param._z3_optimizer.set_local_hp_param(value, param, optim_state_key)


def safe_set_local_fp32_param(param, value):
"""Update the partitioned fp32 parameter.
"""Update the local partition of ZeRO-3 partitioned parameter.
Args:
param (``torch.nn.Parameter``): A model parameter
value (``torch.Tensor``): New value
param (``torch.nn.Parameter``): A model parameter.
value (``torch.Tensor``): New value of local parameter partition.
"""
# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
param._z3_optimizer.set_local_hp_param(value, param)
assert hasattr(param, 'ds_id'), f'This API is only defined for ZeRO-3 partitioned parameters'
param._z3_optimizer.set_local_hp_param(value, param)


### Local API END ###

# TODO: Implement API for setting ZeRO partitioned gradients


def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
param_group_index, partition_start, partition_size):
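
Putting the new public setters together, a hedged end-to-end sketch: `my_model`, `ds_config`, and `my_dataloader` are assumed from a standard `deepspeed.initialize()` setup, and, per the ValueError raised by `get_lp_grad_fragment`, gradients may only be read or written between `backward()` and `step()`.

```python
import torch
import deepspeed
from deepspeed.utils import safe_get_full_grad, safe_set_full_grad
from deepspeed.utils import safe_get_local_grad, safe_set_local_grad

# Assumed setup (hypothetical model, config, and data loader; not part of this PR):
engine, _, _, _ = deepspeed.initialize(model=my_model, config=ds_config)

for batch in my_dataloader:
    loss = engine(batch)  # forward pass returning a scalar loss
    engine.backward(loss)

    # Gradients are only accessible here, after backward() and before step().
    for p in engine.module.parameters():
        if hasattr(p, 'ds_id'):
            # ZeRO-3: operate on this rank's local gradient partition.
            local_grad = safe_get_local_grad(p)
            if local_grad is not None:
                safe_set_local_grad(p, torch.clamp(local_grad, -1.0, 1.0))
        else:
            # ZeRO-1/2 and bf16_optimizer: operate on the full, un-partitioned gradient.
            full_grad = safe_get_full_grad(p)
            if full_grad is not None:
                safe_set_full_grad(p, torch.clamp(full_grad, -1.0, 1.0))

    engine.step()  # the modified gradients are consumed by this optimizer step
```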
9 changes: 7 additions & 2 deletions deepspeed/utils/timer.py
@@ -198,7 +198,7 @@ def get_mean(self, names, normalizer=1.0, reset=True):

class ThroughputTimer:

def __init__(self, config, batch_size, start_step=2, steps_per_output=50, monitor_memory=False, logging_fn=None):
def __init__(self, config, batch_size, start_step=2, steps_per_output=None, monitor_memory=False, logging_fn=None):
from deepspeed.utils import logger
self.config = config
self.start_time = 0
@@ -238,6 +238,11 @@ def start(self):
get_accelerator().synchronize()
self.start_time = time.time()

def _is_report_boundary(self):
if self.steps_per_output is None:
return False
return self.global_step_count % self.steps_per_output == 0

def stop(self, global_step=False, report_speed=True):
if not self.config.enabled or not self.started:
return
@@ -255,7 +260,7 @@ def stop(self, global_step=False, report_speed=True):
self.step_elapsed_time += duration

if global_step:
if report_speed and self.global_step_count % self.steps_per_output == 0:
if report_speed and self._is_report_boundary():
self.logging(
"epoch={}/micro_step={}/global_step={}, RunningAvgSamplesPerSec={}, CurrSamplesPerSec={}, "
"MemAllocated={}GB, MaxMemAllocated={}GB".format(