Add API for updating ZeRO gradients #6590

Merged 8 commits on Oct 14, 2024
Changes from 4 commits
50 changes: 43 additions & 7 deletions deepspeed/runtime/zero/stage3.py
@@ -2298,6 +2298,24 @@ def get_fp32_grad_for_param(self, param) -> Tensor:

return self._fp32_state_allgather(param, fp32_grad)

def set_fp32_grad_for_param(self, value, param):
if not param.requires_grad:
return

if not get_accelerator().resolves_data_dependency():
self.reduce_and_partition_stream.synchronize()

if self.offload_optimizer:
group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]
fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements)
else:
fp32_grad = self.__param_id_to_grad_partition[param.ds_id]

my_rank = dist.get_rank(group=self.dp_process_group)
value_partition = value.flatten().narrow(0, fp32_grad.numel() * my_rank, fp32_grad.numel())

fp32_grad.data.copy_(value_partition.data)

def _get_fp32_opt_state_partition(self, param, optim_state_key=None):
if not get_accelerator().resolves_data_dependency():
self.reduce_and_partition_stream.synchronize()
@@ -2346,12 +2364,6 @@ def set_full_hp_param(self, value, param, optim_state_key=None):

### Local API START ###

def get_local_fp32_param(self, param, optim_state_key=None) -> Tensor:
if not param.requires_grad:
return None
fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key)
return fp32_opt_state

def get_local_fp32_grad_for_param(self, param) -> Tensor:
if not param.requires_grad:
return None
@@ -2366,6 +2378,30 @@ def get_local_fp32_grad_for_param(self, param) -> Tensor:
fp32_grad = self.__param_id_to_grad_partition[param.ds_id].float()
return fp32_grad

def set_local_grad_for_param(self, value, param):
if not param.requires_grad:
return

assert value.numel() == param.ds_tensor.numel(
), f"Number of elements does not match: {value.numel()} != {param.ds_tensor.numel()}"

if not get_accelerator().resolves_data_dependency():
self.reduce_and_partition_stream.synchronize()

if self.offload_optimizer:
group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]
fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements)
else:
fp32_grad = self.__param_id_to_grad_partition[param.ds_id]

fp32_grad.data.copy_(value.flatten().data)

def get_local_fp32_param(self, param, optim_state_key=None) -> Tensor:
if not param.requires_grad:
return None
fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key)
return fp32_opt_state

def set_local_hp_param(self, value, param, optim_state_key=None):
if not param.requires_grad:
return
@@ -2380,7 +2416,7 @@ def set_local_hp_param(self, value, param, optim_state_key=None):

if self._swappable_optimizer_subgroup(group_idx):
self._optimizer_states_and_gradient_swap_out(group_idx)
logger.info(f"[set_local_hp_param][update the params' value successfully]")
# logger.info(f"[set_local_hp_param][update the params' value successfully]")

### Local API END ###

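As a reader aid (not part of the diff), here is a minimal, self-contained sketch of the rank-offset slicing that the new set_fp32_grad_for_param performs: each data-parallel rank copies only its contiguous slice of the caller's full gradient into its fp32 gradient partition. The sizes, rank, and tensor names below are illustrative assumptions, not values taken from DeepSpeed.

import torch

world_size = 4                      # assumed data-parallel world size
partition_numel = 3                 # assumed numel of this rank's fp32 gradient partition
my_rank = 1                         # assumed rank in the data-parallel group

# Un-partitioned gradient supplied by the caller (flat padded length = world_size * partition_numel).
full_grad = torch.arange(world_size * partition_numel, dtype=torch.float32)

# Mirrors: value.flatten().narrow(0, fp32_grad.numel() * my_rank, fp32_grad.numel())
value_partition = full_grad.flatten().narrow(0, partition_numel * my_rank, partition_numel)

# Mirrors: fp32_grad.data.copy_(value_partition.data)
fp32_grad_partition = torch.zeros(partition_numel)
fp32_grad_partition.copy_(value_partition)
print(fp32_grad_partition)          # tensor([3., 4., 5.]) for rank 1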
6 changes: 3 additions & 3 deletions deepspeed/utils/__init__.py
@@ -12,10 +12,10 @@
# TODO: Move tensor fragment and mixed precision to zero utils
from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad, map_to_flat_opt_states
from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
from .tensor_fragment import set_full_hp_param
from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state
from .tensor_fragment import set_full_hp_param, set_full_hp_grad
from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state, safe_set_full_grad
from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_optimizer_state
from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_grad, safe_set_local_optimizer_state
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter
from .mixed_precision_linkage import link_hp_params, lazy_init_hp_params_optimizer_state
from deepspeed.runtime.dataloader import RepeatingLoader
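For reference, with this change the gradient setters become importable from deepspeed.utils alongside the existing getters; a minimal import sketch (assuming a build that includes this PR):

from deepspeed.utils import safe_get_full_grad, safe_set_full_grad
from deepspeed.utils import safe_get_local_grad, safe_set_local_grad
from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state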
3 changes: 2 additions & 1 deletion deepspeed/utils/mixed_precision_linkage.py
@@ -5,7 +5,7 @@

import types
from deepspeed.utils import get_full_hp_param, get_full_hp_grad, get_hp_fragment_mapping
from deepspeed.utils import set_full_hp_param
from deepspeed.utils import set_full_hp_param, set_full_hp_grad


def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
@@ -35,6 +35,7 @@ def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_gr
lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param)
lp_param.get_full_hp_grad = types.MethodType(get_full_hp_grad, lp_param)
lp_param.set_full_hp_param = types.MethodType(set_full_hp_param, lp_param)
lp_param.set_full_hp_grad = types.MethodType(set_full_hp_grad, lp_param)

# lp_param overlaps with partition if both are true
# 1) current_offset < partition_end,
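For readers unfamiliar with the binding pattern used in _init_lp_to_hp_mapping, types.MethodType attaches a free function as a bound method of one specific object, so lp_param.set_full_hp_grad(value) ends up calling set_full_hp_grad(lp_param, value). A standalone sketch with purely illustrative names:

import types

def set_value(self, value):
    # Stand-in for set_full_hp_grad; `self` is the object the function was bound to.
    self.last_value = value

class FakeParam:
    pass

p = FakeParam()
p.set_value = types.MethodType(set_value, p)
p.set_value(42)                 # equivalent to set_value(p, 42)
assert p.last_value == 42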
67 changes: 56 additions & 11 deletions deepspeed/utils/tensor_fragment.py
@@ -57,6 +57,17 @@ def get_hp_fragment(self, optim_state_key=None):
return self.hp_fragment
return self.get_optim_state_fragment(optim_state_key)

def get_lp_grad_fragment(self, index_in_param_group):
if self.use_offload:
gradient_dict = self.offload_gradient_dict
else:
gradient_dict = self.gradient_dict

if self.param_group_index not in gradient_dict or gradient_dict[self.param_group_index] is None:
raise ValueError("Gradients are only available immediately after backward and before engine step")

return gradient_dict[self.param_group_index][index_in_param_group]


def map_to_flat_opt_states(flat_hp_tensor, lp_tensors, optim_state, opt_keys):
for key in opt_keys:
@@ -95,17 +106,7 @@ def set_full_hp_param(self, value, optim_state_key=None):
def get_full_hp_grad(self):
reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
if self._hp_mapping is not None:
hp_mapping = self._hp_mapping

if hp_mapping.use_offload:
gradient_dict = hp_mapping.offload_gradient_dict
else:
gradient_dict = hp_mapping.gradient_dict

if hp_mapping.param_group_index not in gradient_dict or gradient_dict[hp_mapping.param_group_index] is None:
raise ValueError("Gradients are only available immediately after backward and before engine step")

lp_grad_fragment = gradient_dict[hp_mapping.param_group_index][self._index_in_param_group]
lp_grad_fragment = self._hp_mapping.get_lp_grad_fragment(self._index_in_param_group)
hp_grad_fragment = lp_grad_fragment.to(torch.float32).flatten()

lp_frag_address = self._hp_mapping.lp_fragment_address
@@ -120,6 +121,14 @@ def get_full_hp_grad():
return reduce_buffer.reshape_as(self)


def set_full_hp_grad(self, value):
if self._hp_mapping is not None:
lp_grad_fragment = self._hp_mapping.get_lp_grad_fragment(self._index_in_param_group)
lp_frag_address = self._hp_mapping.lp_fragment_address
value_fragment = torch.narrow(value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel)
lp_grad_fragment.data.copy_(value_fragment.data.reshape_as(lp_grad_fragment.data))


def safe_get_full_fp32_param(param):
"""Assemble and return the fp32 parameter of a low-precision (e.g., fp16) parameter.

@@ -207,6 +216,26 @@ def safe_get_full_grad(param):
return None


def safe_set_full_grad(param, value):
"""Update the partitioned gradient of a low-precision (e.g., fp16) parameter.

Args:
param (``torch.nn.Parameter``): A model parameter
value (``torch.Tensor``): New value
tjruwase marked this conversation as resolved.
Show resolved Hide resolved
"""
if param.grad is not None:
param.grad.copy_(value)
return

# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
param._z3_optimizer.set_fp32_grad_for_param(value, param)

# ZeRO stage 1, 2, and bf16_optimizer params
if hasattr(param, '_hp_mapping'):
param.set_full_hp_grad(value)


### Local API START ###
def safe_get_local_grad(param):
"""Get the fp32 gradient of a partitioned parameter.
@@ -223,6 +252,22 @@ def safe_get_local_grad(param):
return None


def safe_set_local_grad(param, value):
"""Update the gradient of a partitioned parameter.
Args:
param (``torch.nn.Parameter``): A model parameter
value (``torch.Tensor``): New value
"""
if param.grad is not None:
return param.grad.copy_(value)

# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
return param._z3_optimizer.set_local_grad_for_param(value, param)

return None


def safe_get_local_fp32_param(param):
"""Get the fp32 partitioned parameter.
Args:
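As a usage illustration (not part of the diff), safe_set_full_grad is intended to be called after backward and before the engine step, here paired with the existing safe_get_full_grad; model and scale are hypothetical placeholders and the rescaling is only an example.

from deepspeed.utils import safe_get_full_grad, safe_set_full_grad

scale = 0.5                                        # assumed rescaling factor
for name, lp in model.named_parameters():          # `model` is an assumed DeepSpeed-wrapped module
    full_grad = safe_get_full_grad(lp)             # assembled fp32 gradient, or None if unavailable
    if full_grad is not None:
        safe_set_full_grad(lp, full_grad * scale)  # write back a rescaled full gradient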
44 changes: 37 additions & 7 deletions docs/code-docs/source/zero3.rst
@@ -369,13 +369,13 @@ These routines can be used in a training loop as shown in the following snippet.
from deepspeed.utils import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
for n, lp in model.named_parameters():
# 1. Access the full states
# 1) gradient lookup
# 1.1) gradient lookup
# For zero1 and zero2, gradient lookup must be called after `backward` and before `step`
# For zero3, gradient lookup must be called after `backward`
hp_grad = safe_get_full_grad(lp)


# 2) fp32 and optim states can probably be called anywhere in the training loop, but will be updated after `step`
# 1.2) fp32 and optim states can probably be called anywhere in the training loop, but will be updated after `step`
hp = safe_get_full_fp32_param(lp)
exp_avg = safe_get_full_optimizer_state(lp, "exp_avg")
exp_avg_sq = safe_get_full_optimizer_state(lp, "exp_avg_sq")
@@ -396,34 +396,39 @@ These routines can be used in a training loop as shown in the following snippet.
Modifying Partitioned States
----------------------------

Sometimes, a user may want to modify parameters or optimizer states outside of the regular training loop. This is currently difficult in ZeRO training because of partitioning. To overcome that, DeepSpeed provides the following routines for modifying the fp32 master parameters and the fp32 optimizer states.
Sometimes, a user may want to modify parameters, gradients, or optimizer states outside of the regular training loop. This is currently difficult in ZeRO training because of partitioning. To overcome that, DeepSpeed provides the following routines for modifying the gradients, the fp32 master parameters, and the fp32 optimizer states.

.. autofunction:: deepspeed.utils.safe_set_full_fp32_param

.. autofunction:: deepspeed.utils.safe_set_full_optimizer_state

.. autofunction:: deepspeed.utils.safe_set_full_grad

.. autofunction:: deepspeed.utils.safe_set_local_fp32_param

.. autofunction:: deepspeed.utils.safe_set_local_grad

.. autofunction:: deepspeed.utils.safe_set_local_optimizer_state

These routines can be used at any point after initialization of the DeepSpeed engine (i.e., ``deepspeed.initialize()``) as shown in the following snippet.
The routines for modifying parameters and optimizer states can be used at any point after initialization of the DeepSpeed engine (i.e., ``deepspeed.initialize()``) as shown in the following snippet.

.. code-block:: python

[...]
from deepspeed.runtime.zero.utils import is_zero_param
from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state
from deepspeed.utils import safe_set_local_fp32_param, safe_set_local_optimizer_state
# Here is an example of how to zero all the fp32 parameters and optimizer states.
for n, lp in model.named_parameters():
# 1. For zero stage 1 or 2, set the full fp32 and their full optim states
zero_tensor = torch.zeros_like(lp)
# 1. For zero stage 1, 2, or 3, set the full fp32 parameters and their full optimizer states
zero_tensor = torch.zeros(lp.ds_shape) if is_zero_param(lp) else torch.zeros(lp.shape)

safe_set_full_fp32_param(lp, zero_tensor)
safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg")
safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg_sq")

# 2. For zero stage 3, each process sets its local fp32 parameters and their local optimizer states individually
zero_tensor_local = torch.zeros_like(lp.ds_tensor.shape)
zero_tensor_local = torch.zeros(lp.ds_tensor.shape)

safe_set_local_fp32_param(lp, zero_tensor_local)
safe_set_local_optimizer_state(lp, zero_tensor_local, "exp_avg")
@@ -432,6 +437,31 @@ These routines can be used at any point after initialization of the DeepSpeed en
[...]


The routines for modifying gradients can be used after ``backward`` but before ``step`` as shown in the following snippet.

.. code-block:: python

backward(loss)
[...]
from deepspeed.runtime.zero.utils import is_zero_param
from deepspeed.utils import safe_set_full_grad, safe_set_local_grad
# Here is an example of how to zero all the gradients.
for n, lp in model.named_parameters():
# 1. For zero stage 1, 2, or 3, set the full gradient.
zero_tensor = torch.zeros(lp.ds_shape) if is_zero_param(lp) else torch.zeros(lp.shape)

safe_set_full_grad(lp, zero_tensor)

# 2. For zero stage 3, each process sets its local gradient partition.
zero_tensor_local = torch.zeros(lp.ds_tensor.shape)

safe_set_local_grad(lp, zero_tensor_local)

[...]
optimizer.step()



GPU Memory Management
---------------------
