Allow modification of zero partitioned parameters (#4192)
* Modify zero parameters

* Docs

* py3.6 compatibility

* Update docs

* Update deepspeed/runtime/zero/stage3.py

Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>

* Add TODO

* Formatting

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
3 people authored Sep 1, 2023
1 parent f96c1c0 commit a23cda6
Showing 6 changed files with 228 additions and 20 deletions.
49 changes: 36 additions & 13 deletions deepspeed/runtime/zero/stage3.py
@@ -2101,16 +2101,17 @@ def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]:

return grad_dict

- def _fp32_state_allgather(self, param, fp32_state):
- reduce_buffer = torch.zeros(self.partition_count * fp32_state.numel(),
+ def _fp32_state_allgather(self, param, fp32_state_partition):
+ reduce_buffer = torch.zeros(self.partition_count * fp32_state_partition.numel(),
dtype=torch.float32,
device=param.device).flatten()
my_rank = dist.get_rank(group=self.dp_process_group)
partitions = [
reduce_buffer.narrow(0,
- fp32_state.numel() * i, fp32_state.numel()) for i in range(self.partition_count)
+ fp32_state_partition.numel() * i, fp32_state_partition.numel())
+ for i in range(self.partition_count)
]
- partitions[my_rank].data.copy_(fp32_state.data, non_blocking=False)
+ partitions[my_rank].data.copy_(fp32_state_partition.data, non_blocking=False)

dist.all_gather(partitions, partitions[my_rank], group=self.dp_process_group)

@@ -2125,36 +2126,58 @@ def get_fp32_grad_for_param(self, param) -> Tensor:

if self.offload_optimizer:
group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]
- fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset,
- num_elements).to(device=param.device)
+ fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements)
else:
fp32_grad = self.__param_id_to_grad_partition[param.ds_id].float()

return self._fp32_state_allgather(param, fp32_grad)

- def get_full_hp_param(self, param, optim_state_key=None) -> Tensor:
- if not param.requires_grad:
- return None

+ def _get_fp32_opt_state_partition(self, param, optim_state_key=None):
if not get_accelerator().is_synchronized_device():
self.reduce_and_partition_stream.synchronize()

group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]

if self._swappable_optimizer_subgroup(group_idx):
self._optimizer_states_and_gradient_swap_in(group_idx)

fp32_param = self.fp32_partitioned_groups_flat[group_idx]
if optim_state_key is None:
- fp32_opt_state = fp32_param.narrow(0, dest_offset, num_elements).to(device=param.device)
+ fp32_opt_state = fp32_param.narrow(0, dest_offset, num_elements)
else:
- fp32_opt_state = self.optimizer.state[fp32_param][optim_state_key].narrow(
- 0, dest_offset, num_elements).to(device=param.device)
+ fp32_opt_state = self.optimizer.state[fp32_param][optim_state_key].narrow(0, dest_offset, num_elements)

return fp32_opt_state, group_idx

def get_full_hp_param(self, param, optim_state_key=None) -> Tensor:
if not param.requires_grad:
return None

fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key)
hp_param = self._fp32_state_allgather(param, fp32_opt_state)

if self._swappable_optimizer_subgroup(group_idx):
self._optimizer_states_and_gradient_swap_out(group_idx)

return hp_param

def set_full_hp_param(self, value, param, optim_state_key=None):
if not param.requires_grad:
return

assert value.numel(
) == param.ds_numel, f" Number of elements do not match: {value.numel()} != {param.ds_numel}"

fp32_opt_state_partition, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key)
my_rank = dist.get_rank(group=self.dp_process_group)
value_partition = value.flatten().narrow(0,
fp32_opt_state_partition.numel() * my_rank,
fp32_opt_state_partition.numel())
fp32_opt_state_partition.data.copy_(value_partition.data)

if self._swappable_optimizer_subgroup(group_idx):
self._optimizer_states_and_gradient_swap_out(group_idx)

@instrument_w_nvtx
def _partition_all_parameters(self):
self.parameter_offload.partition_all_parameters()
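The new set path mirrors the existing all-gather read path in reverse: each rank owns a contiguous 1/partition_count slice of the flat fp32 buffer, so set_full_hp_param narrows the incoming full-sized value down to the caller's slice before copying it in, and no communication is needed for the write. Below is a minimal, self-contained sketch of that slicing logic in plain PyTorch (illustrative names, ranks simulated in a loop, not the DeepSpeed implementation):

import torch

world_size = 4
full_numel = 16
partition_numel = full_numel // world_size

flat_fp32_state = torch.randn(full_numel)  # stands in for the sharded flat fp32 buffer
# rank r owns elements [r * partition_numel, (r + 1) * partition_numel)
partitions = [flat_fp32_state.narrow(0, r * partition_numel, partition_numel) for r in range(world_size)]

new_value = torch.zeros(full_numel)  # the full-sized tensor a caller passes in

for my_rank in range(world_size):  # in real training each rank performs only its own copy
    value_partition = new_value.flatten().narrow(0, partition_numel * my_rank, partition_numel)
    partitions[my_rank].data.copy_(value_partition.data)

# narrow() returns views, so the flat buffer now holds the new value everywhere
assert torch.equal(flat_fp32_state, new_value)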
2 changes: 2 additions & 0 deletions deepspeed/utils/__init__.py
@@ -12,6 +12,8 @@
# TODO: Move tensor fragment and mixed precision to zero utils
from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad
from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
from .tensor_fragment import set_full_hp_param
from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state
from .mixed_precision_linkage import link_hp_params
from deepspeed.runtime.dataloader import RepeatingLoader
from .numa import get_numactl_cmd
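With these re-exports, the new setters sit alongside the existing getters on the public deepspeed.utils surface. A quick import check (a sketch, assuming a DeepSpeed build that includes this change):

from deepspeed.utils import (
    safe_get_full_fp32_param,
    safe_get_full_optimizer_state,
    safe_set_full_fp32_param,
    safe_set_full_optimizer_state,
)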
2 changes: 2 additions & 0 deletions deepspeed/utils/mixed_precision_linkage.py
@@ -5,6 +5,7 @@

import types
from deepspeed.utils import get_full_hp_param, get_full_hp_grad, get_hp_fragment_mapping
from deepspeed.utils import set_full_hp_param


def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
@@ -27,6 +28,7 @@ def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_gr
lp_param._dp_group = dp_group
lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param)
lp_param.get_full_hp_grad = types.MethodType(get_full_hp_grad, lp_param)
lp_param.set_full_hp_param = types.MethodType(set_full_hp_param, lp_param)

# lp_param overlaps with partition if both are true
# 1) current_offset < partition_end,
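The types.MethodType binding above is what lets call sites invoke lp_param.set_full_hp_param(...) directly on a low-precision parameter. A tiny self-contained sketch of that binding pattern (the function body is illustrative only, not the DeepSpeed implementation):

import types
import torch

def set_full_hp_param(self, value):
    # `self` is the tensor the function was bound to; illustrative body only
    self.data.copy_(value)

lp_param = torch.nn.Parameter(torch.zeros(4, dtype=torch.float16))
# bind the free function to this specific parameter instance
lp_param.set_full_hp_param = types.MethodType(set_full_hp_param, lp_param)

lp_param.set_full_hp_param(torch.ones(4))  # dispatches with lp_param as `self`
print(lp_param)  # the fp16 parameter now holds ones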
56 changes: 51 additions & 5 deletions deepspeed/utils/tensor_fragment.py
@@ -45,22 +45,31 @@ def get_hp_fragment_address(self):
def get_optim_state_keys(self):
return list(self.optim_fragment.keys())

def get_hp_fragment(self, optim_state_key=None):
if optim_state_key is None:
return self.hp_fragment
return self.get_optim_state_fragment(optim_state_key)


def get_full_hp_param(self, optim_state_key=None):
reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
if self._hp_mapping is not None:
lp_frag_address = self._hp_mapping.lp_fragment_address
reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel)
- if optim_state_key is None:
- hp_fragment = self._hp_mapping.hp_fragment
- else:
- hp_fragment = self._hp_mapping.get_optim_state_fragment(optim_state_key)

+ hp_fragment = self._hp_mapping.get_hp_fragment(optim_state_key)
reduce_fragment.data.copy_(hp_fragment.data)
dist.all_reduce(reduce_buffer, group=self._dp_group)
return reduce_buffer.reshape_as(self)


def set_full_hp_param(self, value, optim_state_key=None):
if self._hp_mapping is not None:
lp_frag_address = self._hp_mapping.lp_fragment_address
value_fragment = torch.narrow(value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel)
hp_fragment = self._hp_mapping.get_hp_fragment(optim_state_key)
hp_fragment.data.copy_(value_fragment.data)


def get_full_hp_grad(self):
reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
if self._hp_mapping is not None:
@@ -105,11 +114,28 @@ def safe_get_full_fp32_param(param):
return None


def safe_set_full_fp32_param(param, value):
"""Update the partitioned fp32 parameter of a low-precision (e.g., fp16) parameter.
Args:
param (``torch.nn.Parameter``): A model parameter
value (``torch.Tensor``): New value
"""
# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
param._z3_optimizer.set_full_hp_param(value, param)

# ZeRO stage 1, 2, and bf16_optimizer params
if hasattr(param, '_hp_mapping'):
param.set_full_hp_param(value)


def safe_get_full_optimizer_state(param, optim_state_key):
"""Assemble and return the fp32 optimizer state of a low-precision (e.g., fp16) parameter.
Args:
param (``torch.nn.Parameter``): A model parameter
optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)
"""
# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
@@ -121,6 +147,23 @@ def safe_get_full_optimizer_state(param, optim_state_key):
return None


def safe_set_full_optimizer_state(param, value, optim_state_key):
"""Update the partitioned fp32 optimizer state of a low-precision (e.g., fp16) parameter.
Args:
param (``torch.nn.Parameter``): A model parameter
value (``torch.Tensor``): New value
optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)
"""
# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
param._z3_optimizer.set_full_hp_param(value, param, optim_state_key)

# ZeRO stage 1, 2, and bf16_optimizer params
if hasattr(param, '_hp_mapping'):
param.set_full_hp_param(value, optim_state_key)


# TODO: Figure out the correct return dtype
def safe_get_full_grad(param):
"""Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter.
@@ -142,6 +185,9 @@ def safe_get_full_grad(param):
return None


# TODO: Implement API for setting ZeRO partitioned gradients


def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
param_group_index, partition_start, partition_size, optimizer_state_dict):
lp_end = lp_param.numel() + lp_start
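For ZeRO stages 1 and 2 (and the bf16 optimizer), a parameter may map to only a fragment of the rank-local flat fp32 partition, so set_full_hp_param above copies just the matching slice of the incoming full-sized value. A minimal sketch of that fragment update, with lp_start/hp_start/numel standing in for the fragment-address bookkeeping (illustrative values, not DeepSpeed code):

import torch

flat_hp_partition = torch.zeros(12)  # this rank's flat fp32 shard
hp_start, numel = 5, 4  # where the fragment of one lp param lives in the shard
hp_fragment = flat_hp_partition.narrow(0, hp_start, numel)  # stands in for _hp_mapping.hp_fragment

lp_start = 2  # offset of that fragment inside the flattened lp param
value = torch.arange(10, dtype=torch.float32)  # full-sized replacement for the lp param

value_fragment = torch.narrow(value.flatten(), 0, lp_start, numel)
hp_fragment.data.copy_(value_fragment.data)  # only this rank's slice is written

print(flat_hp_partition)  # positions 5..8 now hold 2., 3., 4., 5.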
29 changes: 29 additions & 0 deletions docs/code-docs/source/zero3.rst
@@ -376,6 +376,35 @@ These routines can be used in a training loop as shown in the following snippet.
optimizer.step()
Modifying Partitioned States
----------------------------

Sometimes, a user may want to modify parameters or optimizer states outside of the regular training loop. This is currently difficult in ZeRO training because parameters and optimizer states are partitioned across the data-parallel ranks. To overcome that, DeepSpeed provides the following two routines for modifying the fp32 master parameters and the fp32 optimizer states.

.. autofunction:: deepspeed.utils.safe_set_full_fp32_param

.. autofunction:: deepspeed.utils.safe_set_full_optimizer_state


These routines can be used at any point after initialization of the DeepSpeed engine (i.e., ``deepspeed.initialize()``) as shown in the following snippet.

.. code-block:: python
[...]
from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state
# Here is an example to zero all the fp32 parameters and optimizer states.
for n, lp in model.named_parameters():
# Assume zero stage 1 or 2, since stage 3 requires a gather to assemble lp
zero_tensor = torch.zeros_like(lp)
safe_set_full_fp32_param(lp, zero_tensor)
safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg")
safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg_sq")
[...]
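For ZeRO stage 3 the local ``lp`` is partitioned, so ``torch.zeros_like(lp)`` would have the wrong size; the unit test added in this PR instead sizes the value from the full shape recorded on the parameter. A sketch of that pattern (dtype and device choices are illustrative):

.. code-block:: python

    # stage 3 params expose the full shape as ds_shape; others use .shape
    param_shape = lp.ds_shape if hasattr(lp, 'ds_id') else lp.shape
    zero_tensor = torch.zeros(param_shape, dtype=torch.float32, device=lp.device)
    safe_set_full_fp32_param(lp, zero_tensor)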
GPU Memory Management
---------------------

110 changes: 108 additions & 2 deletions tests/unit/runtime/zero/test_zero_tensor_fragment.py
@@ -8,14 +8,19 @@
import torch

from unit.common import DistributedTest
- from unit.simple_model import random_dataloader
+ from unit.simple_model import random_dataloader, SimpleModel
from unit.util import bf16_required_version_check

import deepspeed
from deepspeed.utils import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.ops.aio import AsyncIOBuilder

WEIGHT_KEY = 'weight'
FIRST_ORDER_KEY = 'exp_avg'
SECOND_ORDER_KEY = 'exp_avg_sq'


def validate_full_tensors(model):
for _, lp in model.named_parameters():
@@ -73,7 +78,7 @@ def run_fragmented_model(model, config_dict, hidden_dim, dtype):


@pytest.mark.parametrize('frozen_weights', [True, False])
- class TestTensorFragment(DistributedTest):
+ class TestTensorFragmentGet(DistributedTest):
# Need multiple gpus to test possible hanging
world_size = 2
reuse_dist_env = True
@@ -150,3 +155,104 @@ def test_bf16_fragments(self, frozen_weights):
hidden_dim = 128
model = MyModel(hidden_dim, frozen_weights)
run_fragmented_model(model, config_dict, hidden_dim, torch.bfloat16)


def create_random_values(model, key_list, group):
param_values = {}
for n, lp in model.named_parameters():
param_shape = lp.ds_shape if hasattr(lp, 'ds_id') else lp.shape
param_values[n] = {}
for key in key_list:
rand_value = torch.rand(param_shape, dtype=torch.float32, device=model.device)
dist.broadcast(rand_value, src=0, group=group)
param_values[n][key] = rand_value
return param_values


def set_param_values_with_dict(model, value_dict):
for n, lp in model.named_parameters():
for key, value_tensor in value_dict[n].items():
if key == WEIGHT_KEY:
safe_set_full_fp32_param(lp, value_tensor)
else:
safe_set_full_optimizer_state(lp, value_tensor, key)


def validate_param_values_with_dict(model, value_dict):
for n, lp in model.named_parameters():
for key, expected_tensor in value_dict[n].items():
if key == WEIGHT_KEY:
actual_tensor = safe_get_full_fp32_param(lp)
else:
actual_tensor = safe_get_full_optimizer_state(lp, key)
assert torch.equal(expected_tensor, actual_tensor)


@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32])
class TestTensorFragmentUpdate(DistributedTest):
# Need multiple gpus to test possible hanging
world_size = 2
reuse_dist_env = True

@pytest.mark.parametrize('zero_stage', [1, 2, 3])
@pytest.mark.parametrize('offload_device', [OffloadDeviceEnum.none, OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme])
def test_zero_fragments(self, tmpdir, zero_stage, offload_device, dtype):

if dtype == torch.bfloat16 and not bf16_required_version_check(accelerator_check=False):
pytest.skip(
" DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
)

if offload_device == OffloadDeviceEnum.nvme:
if zero_stage != 3:
pytest.skip(f"Nvme offload not supported for zero stage {zero_stage}")
if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]:
pytest.skip('Skip tests since async-io is not compatible')

config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-6
}
},
"zero_optimization": {
"stage": zero_stage,
}
}

if offload_device == OffloadDeviceEnum.cpu:
config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device}
elif offload_device == OffloadDeviceEnum.nvme:
config_dict["zero_optimization"]["offload_optimizer"] = {
"device": offload_device,
"nvme_path": str(tmpdir)
}

if dtype == torch.float16:
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
elif dtype == torch.bfloat16:
config_dict["bf16"] = {"enabled": True}

hidden_dim = 128
if zero_stage == 3:
config_dict["zero_optimization"]["param_persistence_threshold"] = hidden_dim
with deepspeed.zero.Init(config_dict_or_path=config_dict):
model = SimpleModel(hidden_dim, nlayers=4)
else:
model = SimpleModel(hidden_dim, nlayers=4)

model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
world = dist.get_world_size()
group = dist.new_group(ranks=list(range(world)))

dist.barrier()
optim_keys = [WEIGHT_KEY, FIRST_ORDER_KEY, SECOND_ORDER_KEY]
optim_state_values = create_random_values(model, optim_keys, group)
set_param_values_with_dict(model, optim_state_values)
validate_param_values_with_dict(model, optim_state_values)

# Needed in ZeRO 3. Not doing so can leak memory.
model.destroy()
