Commit
Add API for updating ZeRO gradients (#6590)
tjruwase authored Oct 14, 2024
1 parent cf41e8c commit 65ab644
Showing 6 changed files with 200 additions and 79 deletions.
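The headline addition is a pair of public setters for ZeRO gradients, mirroring the existing getters. A minimal usage sketch, adapted from the documentation snippet added later in this commit (the engine, loss, and variable names are placeholders, not part of the diff):

import torch
from deepspeed.utils import safe_set_full_grad, safe_set_local_grad

# Gradients can only be written after backward() and before step().
model_engine.backward(loss)
for name, lp in model_engine.module.named_parameters():
    # All ZeRO stages: overwrite the full (un-partitioned) gradient.
    full_shape = lp.ds_shape if hasattr(lp, 'ds_shape') else lp.shape
    safe_set_full_grad(lp, torch.zeros(full_shape))
    # ZeRO stage 3 only: overwrite just this rank's local gradient partition.
    if hasattr(lp, 'ds_tensor'):
        safe_set_local_grad(lp, torch.zeros(lp.ds_tensor.shape))
model_engine.step()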
50 changes: 43 additions & 7 deletions deepspeed/runtime/zero/stage3.py
@@ -2299,6 +2299,24 @@ def get_fp32_grad_for_param(self, param) -> Tensor:

return self._fp32_state_allgather(param, fp32_grad)

def set_fp32_grad_for_param(self, value, param):
if not param.requires_grad:
return

if not get_accelerator().resolves_data_dependency():
self.reduce_and_partition_stream.synchronize()

if self.offload_optimizer:
group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]
fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements)
else:
fp32_grad = self.__param_id_to_grad_partition[param.ds_id]

my_rank = dist.get_rank(group=self.dp_process_group)
value_partition = value.flatten().narrow(0, fp32_grad.numel() * my_rank, fp32_grad.numel())

fp32_grad.data.copy_(value_partition.data)

def _get_fp32_opt_state_partition(self, param, optim_state_key=None):
if not get_accelerator().resolves_data_dependency():
self.reduce_and_partition_stream.synchronize()
@@ -2347,12 +2365,6 @@ def set_full_hp_param(self, value, param, optim_state_key=None):

### Local API START ###

def get_local_fp32_param(self, param, optim_state_key=None) -> Tensor:
if not param.requires_grad:
return None
fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key)
return fp32_opt_state

def get_local_fp32_grad_for_param(self, param) -> Tensor:
if not param.requires_grad:
return None
@@ -2367,6 +2379,30 @@ def get_local_fp32_grad_for_param(self, param) -> Tensor:
fp32_grad = self.__param_id_to_grad_partition[param.ds_id].float()
return fp32_grad

def set_local_grad_for_param(self, value, param):
if not param.requires_grad:
return

assert value.numel() == param.ds_tensor.numel(
), f"Number of elements do not match: {value.numel()} != {param.ds_tensor.numel()}"

if not get_accelerator().resolves_data_dependency():
self.reduce_and_partition_stream.synchronize()

if self.offload_optimizer:
group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]
fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements)
else:
fp32_grad = self.__param_id_to_grad_partition[param.ds_id]

fp32_grad.data.copy_(value.flatten().data)

def get_local_fp32_param(self, param, optim_state_key=None) -> Tensor:
if not param.requires_grad:
return None
fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key)
return fp32_opt_state

def set_local_hp_param(self, value, param, optim_state_key=None):
if not param.requires_grad:
return
@@ -2381,7 +2417,7 @@ def set_local_hp_param(self, value, param, optim_state_key=None):

if self._swappable_optimizer_subgroup(group_idx):
self._optimizer_states_and_gradient_swap_out(group_idx)
logger.info(f"[set_local_hp_param][update the params' value successfully]")
# logger.info(f"[set_local_hp_param][update the params' value successfully]")

### Local API END ###

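For readers tracing the partition arithmetic in the new `set_fp32_grad_for_param` above: each rank selects only its own slice of the flattened full gradient with `narrow` before copying it into the local fp32 gradient buffer. A standalone sketch of that indexing with toy sizes (no DeepSpeed dependency; `world_size`, `my_rank`, and the buffer are stand-ins):

import torch

world_size, my_rank = 4, 2                           # pretend dist.get_rank(...) returned 2
full_grad = torch.arange(16.0)                       # un-partitioned gradient, already flattened
partition_numel = full_grad.numel() // world_size    # plays the role of fp32_grad.numel()

# Mirrors: value.flatten().narrow(0, fp32_grad.numel() * my_rank, fp32_grad.numel())
value_partition = full_grad.narrow(0, partition_numel * my_rank, partition_numel)

local_fp32_grad = torch.zeros(partition_numel)       # stand-in for the rank-local grad partition
local_fp32_grad.data.copy_(value_partition.data)
print(local_fp32_grad)                               # tensor([ 8.,  9., 10., 11.])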
6 changes: 3 additions & 3 deletions deepspeed/utils/__init__.py
@@ -12,10 +12,10 @@
# TODO: Move tensor fragment and mixed precision to zero utils
from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad, map_to_flat_opt_states
from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
from .tensor_fragment import set_full_hp_param
from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state
from .tensor_fragment import set_full_hp_param, set_full_hp_grad
from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state, safe_set_full_grad
from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_optimizer_state
from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_grad, safe_set_local_optimizer_state
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter
from .mixed_precision_linkage import link_hp_params, lazy_init_hp_params_optimizer_state
from deepspeed.runtime.dataloader import RepeatingLoader
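With these re-exports, the new setters sit next to the existing getters in `deepspeed.utils`:

from deepspeed.utils import (
    safe_get_full_grad, safe_set_full_grad,      # full gradient read/write (ZeRO stages 1, 2, 3)
    safe_get_local_grad, safe_set_local_grad,    # local partition read/write (ZeRO stage 3 only)
)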
3 changes: 2 additions & 1 deletion deepspeed/utils/mixed_precision_linkage.py
@@ -5,7 +5,7 @@

import types
from deepspeed.utils import get_full_hp_param, get_full_hp_grad, get_hp_fragment_mapping
from deepspeed.utils import set_full_hp_param
from deepspeed.utils import set_full_hp_param, set_full_hp_grad


def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
@@ -35,6 +35,7 @@ def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_gr
lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param)
lp_param.get_full_hp_grad = types.MethodType(get_full_hp_grad, lp_param)
lp_param.set_full_hp_param = types.MethodType(set_full_hp_param, lp_param)
lp_param.set_full_hp_grad = types.MethodType(set_full_hp_grad, lp_param)

# lp_param overlaps with partition if both are true
# 1) current_offset < partition_end,
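The linkage change above relies on `types.MethodType` to attach `set_full_hp_grad` as a bound method on each low-precision parameter, so callers can later invoke `lp.set_full_hp_grad(value)` directly. A minimal, self-contained sketch of that binding pattern (the class and attribute below are invented for illustration):

import types

def set_full_hp_grad(self, value):
    # In DeepSpeed, 'self' is the lp parameter and the real body copies 'value'
    # into the mapped high-precision gradient fragment; here we just record it.
    self.last_set_grad = value

class FakeParam:
    pass

lp_param = FakeParam()
lp_param.set_full_hp_grad = types.MethodType(set_full_hp_grad, lp_param)
lp_param.set_full_hp_grad(3.14)
print(lp_param.last_set_grad)  # 3.14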
127 changes: 80 additions & 47 deletions deepspeed/utils/tensor_fragment.py
@@ -57,6 +57,17 @@ def get_hp_fragment(self, optim_state_key=None):
return self.hp_fragment
return self.get_optim_state_fragment(optim_state_key)

def get_lp_grad_fragment(self, index_in_param_group):
if self.use_offload:
gradient_dict = self.offload_gradient_dict
else:
gradient_dict = self.gradient_dict

if self.param_group_index not in gradient_dict or gradient_dict[self.param_group_index] is None:
raise ValueError("Gradients are only available immediately after backward and before engine step")

return gradient_dict[self.param_group_index][index_in_param_group]


def map_to_flat_opt_states(flat_hp_tensor, lp_tensors, optim_state, opt_keys):
for key in opt_keys:
@@ -95,17 +106,7 @@ def set_full_hp_param(self, value, optim_state_key=None):
def get_full_hp_grad(self):
reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
if self._hp_mapping is not None:
hp_mapping = self._hp_mapping

if hp_mapping.use_offload:
gradient_dict = hp_mapping.offload_gradient_dict
else:
gradient_dict = hp_mapping.gradient_dict

if hp_mapping.param_group_index not in gradient_dict or gradient_dict[hp_mapping.param_group_index] is None:
raise ValueError("Gradients are only available immediately after backward and before engine step")

lp_grad_fragment = gradient_dict[hp_mapping.param_group_index][self._index_in_param_group]
lp_grad_fragment = self._hp_mapping.get_lp_grad_fragment(self._index_in_param_group)
hp_grad_fragment = lp_grad_fragment.to(torch.float32).flatten()

lp_frag_address = self._hp_mapping.lp_fragment_address
Expand All @@ -120,6 +121,14 @@ def get_full_hp_grad(self):
return reduce_buffer.reshape_as(self)


def set_full_hp_grad(self, value):
if self._hp_mapping is not None:
lp_grad_fragment = self._hp_mapping.get_lp_grad_fragment(self._index_in_param_group)
lp_frag_address = self._hp_mapping.lp_fragment_address
value_fragment = torch.narrow(value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel)
lp_grad_fragment.data.copy_(value_fragment.data.reshape_as(lp_grad_fragment.data))


def safe_get_full_fp32_param(param):
"""Assemble and return the fp32 parameter of a low-precision (e.g., fp16) parameter.
@@ -188,7 +197,10 @@ def safe_set_full_optimizer_state(param, value, optim_state_key):

# TODO: Figure out the correct return dtype
def safe_get_full_grad(param):
"""Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter.
"""
Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter.
The return data type is that used for gradient accumulation. This is usually the param data type,
but could also be different (e.g., bf16 param training with fp32 gradient accumulation).
Args:
param (``torch.nn.Parameter``): A model parameter
@@ -207,74 +219,95 @@ def safe_get_full_grad(param):
return None


def safe_set_full_grad(param, value):
"""
Update the partitioned gradient of a low-precision (e.g., fp16) parameter.
To avoid precision issues, the update value should have the data type of
gradient accumulation.
Args:
param (``torch.nn.Parameter``): A model parameter
value (``torch.Tensor``): The un-partitioned new gradient value.
"""
if param.grad is not None:
param.grad.copy_(value)
elif hasattr(param, 'ds_id'):
# ZeRO stage 3 param
param._z3_optimizer.set_fp32_grad_for_param(value, param)
elif hasattr(param, '_hp_mapping'):
# ZeRO stage 1, 2, and bf16_optimizer params
param.set_full_hp_grad(value)


### Local API START ###
def safe_get_local_grad(param):
"""Get the fp32 gradient of a partitioned parameter.
"""
Get the local gradient partition of a ZeRO-3 partitioned parameter.
The return data type is that used for gradient accumulation. This is usually the param data type,
but could also be different (e.g., bf16 param training with fp32 gradient accumulation).
Args:
param (``torch.nn.Parameter``): A model parameter
"""
if param.grad is not None:
return param.grad
assert hasattr(param, 'ds_id'), f'This API is only defined for ZeRO-3 partitioned parameters'
return param._z3_optimizer.get_local_fp32_grad_for_param(param)

# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
return param._z3_optimizer.get_local_fp32_grad_for_param(param)

return None
def safe_set_local_grad(param, value):
"""
Update the local gradient partition of a ZeRO-3 partitioned parameter.
To avoid precision issues, the update value should have the data type of
gradient accumulation.
Args:
param (``torch.nn.Parameter``): A model parameter.
value (``torch.Tensor``): New value of local gradient partition.
"""
assert hasattr(param, 'ds_id'), f'This API is only defined for ZeRO-3 partitioned parameters'
param._z3_optimizer.set_local_grad_for_param(value, param)


def safe_get_local_fp32_param(param):
"""Get the fp32 partitioned parameter.
"""Get the local partition of a ZeRO-3 partitioned parameter in fp32 precision.
Args:
param (``torch.nn.Parameter``): A model parameter
param (``torch.nn.Parameter``): A model parameter.
"""
# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
return param._z3_optimizer.get_local_fp32_param(param)

return None
assert hasattr(param, 'ds_id'), f'This API is only defined for ZeRO-3 partitioned parameters'
return param._z3_optimizer.get_local_fp32_param(param)


def safe_get_local_optimizer_state(param, optim_state_key):
"""Get the fp32 optimizer state of a partitioned parameter.
"""Get the local optimizer state partition of ZeRO-3 partitioned parameter in fp32 precision.
Args:
param (``torch.nn.Parameter``): A model parameter
optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)
"""
# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
return param._z3_optimizer.get_local_fp32_param(param, optim_state_key)

return None
assert hasattr(param, 'ds_id'), f'This API is only defined for ZeRO-3 partitioned parameters'
return param._z3_optimizer.get_local_fp32_param(param, optim_state_key)


def safe_set_local_optimizer_state(param, value, optim_state_key):
"""Update the fp32 optimizer state of a partitioned parameter.
"""Update the local optimizer state partition of a ZeRO-3 partitioned parameter.
Args:
param (``torch.nn.Parameter``): A model parameter
value (``torch.Tensor``): New value
optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)
param (``torch.nn.Parameter``): A model parameter.
value (``torch.Tensor``): New value of local optimizer state partition.
optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer).
"""
# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
param._z3_optimizer.set_local_hp_param(value, param, optim_state_key)
assert hasattr(param, 'ds_id'), f'This API is only defined for ZeRO-3 partitioned parameters'
param._z3_optimizer.set_local_hp_param(value, param, optim_state_key)


def safe_set_local_fp32_param(param, value):
"""Update the partitioned fp32 parameter.
"""Update the local partition of ZeRO-3 partitioned parameter.
Args:
param (``torch.nn.Parameter``): A model parameter
value (``torch.Tensor``): New value
param (``torch.nn.Parameter``): A model parameter.
value (``torch.Tensor``): New value of local parameter partition.
"""
# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
param._z3_optimizer.set_local_hp_param(value, param)
assert hasattr(param, 'ds_id'), f'This API is only defined for ZeRO-3 partitioned parameters'
param._z3_optimizer.set_local_hp_param(value, param)


### Local API END ###

# TODO: Implement API for setting ZeRO partitioned gradients


def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
param_group_index, partition_start, partition_size):
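The fragment arithmetic in the new `set_full_hp_grad` above is the inverse of `get_full_hp_grad`: the full update tensor is flattened, the slice this rank owns is selected via the stored fragment address (a start offset and element count), and the slice is copied into the low-precision gradient fragment. A dependency-free sketch with invented sizes and a stand-in for the fragment address:

import torch
from collections import namedtuple

FragAddress = namedtuple('FragAddress', ['start', 'numel'])  # stand-in for lp_fragment_address

full_value = torch.arange(12.0).reshape(3, 4)       # un-partitioned new gradient
lp_frag_address = FragAddress(start=4, numel=4)     # this rank owns flattened elements 4..7
lp_grad_fragment = torch.zeros(2, 2)                # local fragment buffer (4 elements)

# Mirrors: torch.narrow(value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel)
value_fragment = torch.narrow(full_value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel)
lp_grad_fragment.data.copy_(value_fragment.data.reshape_as(lp_grad_fragment.data))
print(lp_grad_fragment)  # tensor([[4., 5.], [6., 7.]])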
44 changes: 37 additions & 7 deletions docs/code-docs/source/zero3.rst
@@ -369,13 +369,13 @@ These routines can be used in a training loop as shown in the following snippet.
from deepspeed.utils import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
for n, lp in model.named_parameters():
# 1. Access the full states
# 1) gradient lookup
# 1.1) gradient lookup
# For zero1 and zero2, gradient lookup must be called after `backward` and before `step`
# For zero3, gradient lookup must be called after `backward`
hp_grad = safe_get_full_grad(lp)
# 2) fp32 and optim states can probably be called anywhere in the training loop, but will be updated after `step`
# 1.2) fp32 and optim states can probably be called anywhere in the training loop, but will be updated after `step`
hp = safe_get_full_fp32_param(lp)
exp_avg = safe_get_full_optimizer_state(lp, "exp_avg")
exp_avg_sq = safe_get_full_optimizer_state(lp, "exp_avg_sq")
@@ -396,34 +396,39 @@ These routines can be used in a training loop as shown in the following snippet.
Modifying Partitioned States
----------------------------

Sometimes, a user may want to modify parameters or optimizer states outside of the regular training loop. This is currently difficult in ZeRO training because of partitioning. To overcome that, DeepSpeed provides the following routines for modifying the fp32 master parameters and the fp32 optimizer states.
Sometimes, a user may want to modify parameters, gradients, or optimizer states outside of the regular training loop. This is currently difficult in ZeRO training because of partitioning. To overcome that, DeepSpeed provides the following routines for modifying the fp32 master parameters and the fp32 optimizer states.

.. autofunction:: deepspeed.utils.safe_set_full_fp32_param

.. autofunction:: deepspeed.utils.safe_set_full_optimizer_state

.. autofunction:: deepspeed.utils.safe_set_full_grad

.. autofunction:: deepspeed.utils.safe_set_local_fp32_param

.. autofunction:: deepspeed.utils.safe_set_local_grad

.. autofunction:: deepspeed.utils.safe_set_local_optimizer_state

These routines can be used at any point after initialization of the DeepSpeed engine (i.e., ``deepspeed.initialize()``) as shown in the following snippet.
The routines for modifying parameters and optimizer states can be used at any point after initialization of the DeepSpeed engine (i.e., ``deepspeed.initialize()``) as shown in the following snippet.

.. code-block:: python
[...]
from deepspeed.runtime.zero.utils import is_zero_param
from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state
from deepspeed.utils import safe_set_local_fp32_param, safe_set_local_optimizer_state
# Here is an example of how to zero all the fp32 parameters and optimizer states.
for n, lp in model.named_parameters():
# 1. For zero stage 1 or 2, set the full fp32 and their full optim states
zero_tensor = torch.zeros_like(lp)
# 1. For zero stage 1, 2, or 3, set the full fp32 params and their full optim states
zero_tensor = torch.zeros(lp.ds_shape) if is_zero_param(lp) else torch.zeros(lp.shape)
safe_set_full_fp32_param(lp, zero_tensor)
safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg")
safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg_sq")
# 2. For zero stage 3, each process sets its local fp32 parameters and their local optimizer states individually
zero_tensor_local = torch.zeros_like(lp.ds_tensor.shape)
zero_tensor_local = torch.zeros(lp.ds_tensor.shape)
safe_set_local_fp32_param(lp, zero_tensor_local)
safe_set_local_optimizer_state(lp, zero_tensor_local, "exp_avg")
@@ -432,6 +437,31 @@ These routines can be used at any point after initialization of the DeepSpeed en
[...]
The routines for modifying gradients can be used after ``backward`` but before ``step`` as shown in the following snippet.

.. code-block:: python
backward(loss)
[...]
from deepspeed.runtime.zero.utils import is_zero_param
from deepspeed.utils import safe_set_full_grad, safe_set_local_grad
# Here is an example of how to zero all the gradients.
for n, lp in model.named_parameters():
# 1. For zero stage 1, 2, or 3, set the full gradient.
zero_tensor = torch.zeros(lp.ds_shape) if is_zero_param(lp) else torch.zeros(lp.shape)
safe_set_full_grad(lp, zero_tensor)
# 2. For zero stage 3, each process sets its local gradient partition.
zero_tensor_local = torch.zeros(lp.ds_tensor.shape)
safe_set_local_grad(lp, zero_tensor_local)
[...]
optimizer.step()
GPU Memory Management
---------------------

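As a closing note, the new setters pair naturally with the getters documented in the same file: reading the gradients back is a simple way to confirm an update took effect, and, like the setters, it must happen after backward and before the engine step. A sketch under the same placeholder names as the snippet above (illustrative only, not part of the diff):

import torch
from deepspeed.utils import safe_set_full_grad, safe_get_full_grad
from deepspeed.utils import safe_set_local_grad, safe_get_local_grad

model_engine.backward(loss)
for name, lp in model_engine.module.named_parameters():
    full_shape = lp.ds_shape if hasattr(lp, 'ds_shape') else lp.shape
    safe_set_full_grad(lp, torch.zeros(full_shape))
    # Read back the assembled full gradient; it should now be all zeros.
    assert torch.count_nonzero(safe_get_full_grad(lp)) == 0
    if hasattr(lp, 'ds_tensor'):  # ZeRO stage 3 only: check the local partition as well
        safe_set_local_grad(lp, torch.zeros(lp.ds_tensor.shape))
        assert torch.count_nonzero(safe_get_local_grad(lp)) == 0
model_engine.step()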