Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow modification of zero partitioned parameters #4192

Merged
merged 10 commits into from
Sep 2, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Modify zero parameters
  • Loading branch information
tjruwase committed Aug 22, 2023
commit 3aa58e3dcbdc6540a3fe7f58c4978379ffdc7025
48 changes: 35 additions & 13 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2053,16 +2053,17 @@ def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]:

return grad_dict

def _fp32_state_allgather(self, param, fp32_state):
reduce_buffer = torch.zeros(self.partition_count * fp32_state.numel(),
def _fp32_state_allgather(self, param, fp32_state_partition):
reduce_buffer = torch.zeros(self.partition_count * fp32_state_partition.numel(),
dtype=torch.float32,
device=param.device).flatten()
my_rank = dist.get_rank(group=self.dp_process_group)
partitions = [
reduce_buffer.narrow(0,
fp32_state.numel() * i, fp32_state.numel()) for i in range(self.partition_count)
fp32_state_partition.numel() * i, fp32_state_partition.numel())
for i in range(self.partition_count)
]
partitions[my_rank].data.copy_(fp32_state.data, non_blocking=False)
partitions[my_rank].data.copy_(fp32_state_partition.data, non_blocking=False)

dist.all_gather(partitions, partitions[my_rank], group=self.dp_process_group)

Expand All @@ -2077,36 +2078,57 @@ def get_fp32_grad_for_param(self, param) -> Tensor:

if self.offload_optimizer:
group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]
fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset,
num_elements).to(device=param.device)
fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements)
else:
fp32_grad = self.__param_id_to_grad_partition[param.ds_id].float()

return self._fp32_state_allgather(param, fp32_grad)

def get_full_hp_param(self, param, optim_state_key=None) -> Tensor:
if not param.requires_grad:
return None

def _get_fp32_opt_state_partition(self, param, optim_state_key=None):
if not get_accelerator().is_synchronized_device():
self.reduce_and_partition_stream.synchronize()

group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]

if self._swappable_optimizer_subgroup(group_idx):
self._optimizer_states_and_gradient_swap_in(group_idx)

fp32_param = self.fp32_partitioned_groups_flat[group_idx]
if optim_state_key is None:
fp32_opt_state = fp32_param.narrow(0, dest_offset, num_elements).to(device=param.device)
fp32_opt_state = fp32_param.narrow(0, dest_offset, num_elements)
else:
fp32_opt_state = self.optimizer.state[fp32_param][optim_state_key].narrow(
0, dest_offset, num_elements).to(device=param.device)
fp32_opt_state = self.optimizer.state[fp32_param][optim_state_key].narrow(0, dest_offset, num_elements)

return fp32_opt_state, group_idx

def get_full_hp_param(self, param, optim_state_key=None) -> Tensor:
if not param.requires_grad:
return None

fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key)
hp_param = self._fp32_state_allgather(param, fp32_opt_state)

if self._swappable_optimizer_subgroup(group_idx):
self._optimizer_states_and_gradient_swap_out(group_idx)

return hp_param

def set_full_hp_param(self, value, param, optim_state_key=None):
if not param.requires_grad:
return

assert value.numel() == param.ds_numel, f'{value.numel()=} != {param.ds_numel=}'

fp32_opt_state_partition, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key)
my_rank = dist.get_rank(group=self.dp_process_group)
value_partition = value.flatten().narrow(0,
fp32_opt_state_partition.numel() * my_rank,
fp32_opt_state_partition.numel())
fp32_opt_state_partition.data.copy_(value_partition.data)

if self._swappable_optimizer_subgroup(group_idx):
self._optimizer_states_and_gradient_swap_out(group_idx)

@instrument_w_nvtx
def _partition_all_parameters(self):
self.parameter_offload.partition_all_parameters()
Expand Down
2 changes: 2 additions & 0 deletions deepspeed/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# TODO: Move tensor fragment and mixed precision to zero utils
from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad
from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
from .tensor_fragment import set_full_hp_param
from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state
from .mixed_precision_linkage import link_hp_params
from deepspeed.runtime.dataloader import RepeatingLoader
from .numa import get_numactl_cmd
2 changes: 2 additions & 0 deletions deepspeed/utils/mixed_precision_linkage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import types
from deepspeed.utils import get_full_hp_param, get_full_hp_grad, get_hp_fragment_mapping
from deepspeed.utils import set_full_hp_param


def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
Expand All @@ -27,6 +28,7 @@ def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_gr
lp_param._dp_group = dp_group
lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param)
lp_param.get_full_hp_grad = types.MethodType(get_full_hp_grad, lp_param)
lp_param.set_full_hp_param = types.MethodType(set_full_hp_param, lp_param)

# lp_param overlaps with partition if both are true
# 1) current_offset < partition_end,
Expand Down
49 changes: 44 additions & 5 deletions deepspeed/utils/tensor_fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,31 @@ def get_hp_fragment_address(self):
def get_optim_state_keys(self):
return list(self.optim_fragment.keys())

def get_hp_fragment(self, optim_state_key=None):
if optim_state_key is None:
return self.hp_fragment
return self.get_optim_state_fragment(optim_state_key)


def get_full_hp_param(self, optim_state_key=None):
reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
if self._hp_mapping is not None:
lp_frag_address = self._hp_mapping.lp_fragment_address
reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel)
if optim_state_key is None:
hp_fragment = self._hp_mapping.hp_fragment
else:
hp_fragment = self._hp_mapping.get_optim_state_fragment(optim_state_key)

hp_fragment = self._hp_mapping.get_hp_fragment(optim_state_key)
reduce_fragment.data.copy_(hp_fragment.data)
dist.all_reduce(reduce_buffer, group=self._dp_group)
return reduce_buffer.reshape_as(self)


def set_full_hp_param(self, value, optim_state_key=None):
if self._hp_mapping is not None:
lp_frag_address = self._hp_mapping.lp_fragment_address
value_fragment = torch.narrow(value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel)
hp_fragment = self._hp_mapping.get_hp_fragment(optim_state_key)
hp_fragment.data.copy_(value_fragment.data)


def get_full_hp_grad(self):
reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
if self._hp_mapping is not None:
Expand Down Expand Up @@ -105,6 +114,21 @@ def safe_get_full_fp32_param(param):
return None


def safe_set_full_fp32_param(param, value):
""" Update the partitioned fp32 parameter of a low-precision (e.g., fp16) parameter.

Args:
param (``torch.nn.Parameter``): A model parameter
"""
# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
param._z3_optimizer.set_full_hp_param(value, param)

# ZeRO stage 1, 2, and bf16_optimizer params
if hasattr(param, '_hp_mapping'):
param.set_full_hp_param(value)


def safe_get_full_optimizer_state(param, optim_state_key):
"""Assemble and return the fp32 optimizer state of a low-precision (e.g., fp16) parameter.

Expand All @@ -121,6 +145,21 @@ def safe_get_full_optimizer_state(param, optim_state_key):
return None


def safe_set_full_optimizer_state(param, value, optim_state_key):
""" Update the partitioned fp32 optimizer state of a low-precision (e.g., fp16) parameter.

Args:
param (``torch.nn.Parameter``): A model parameter
"""
# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
param._z3_optimizer.set_full_hp_param(value, param, optim_state_key)

# ZeRO stage 1, 2, and bf16_optimizer params
if hasattr(param, '_hp_mapping'):
param.set_full_hp_param(value, optim_state_key)


# TODO: Figure out the correct return dtype
def safe_get_full_grad(param):
"""Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter.
Expand Down
110 changes: 108 additions & 2 deletions tests/unit/runtime/zero/test_zero_tensor_fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,19 @@
import torch

from unit.common import DistributedTest
from unit.simple_model import random_dataloader
from unit.simple_model import random_dataloader, SimpleModel
from unit.util import bf16_required_version_check

import deepspeed
from deepspeed.utils import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.ops.aio import AsyncIOBuilder

WEIGHT_KEY = 'weight'
FIRST_ORDER_KEY = 'exp_avg'
SECOND_ORDER_KEY = 'exp_avg_sq'


def validate_full_tensors(model):
for _, lp in model.named_parameters():
Expand Down Expand Up @@ -73,7 +78,7 @@ def run_fragmented_model(model, config_dict, hidden_dim, dtype):


@pytest.mark.parametrize('frozen_weights', [True, False])
class TestTensorFragment(DistributedTest):
class TestTensorFragmentGet(DistributedTest):
# Need multiple gpus to test possible hanging
world_size = 2
reuse_dist_env = True
Expand Down Expand Up @@ -150,3 +155,104 @@ def test_bf16_fragments(self, frozen_weights):
hidden_dim = 128
model = MyModel(hidden_dim, frozen_weights)
run_fragmented_model(model, config_dict, hidden_dim, torch.bfloat16)


def create_random_values(model, key_list, group):
param_values = {}
for n, lp in model.named_parameters():
param_shape = lp.ds_shape if hasattr(lp, 'ds_id') else lp.shape
param_values[n] = {}
for key in key_list:
rand_value = torch.rand(param_shape, dtype=torch.float32, device=model.device)
dist.broadcast(rand_value, src=0, group=group)
param_values[n][key] = rand_value
return param_values


def set_param_values_with_dict(model, value_dict):
for n, lp in model.named_parameters():
for key, value_tensor in value_dict[n].items():
if key == WEIGHT_KEY:
safe_set_full_fp32_param(lp, value_tensor)
else:
safe_set_full_optimizer_state(lp, value_tensor, key)


def validate_param_values_with_dict(model, value_dict):
for n, lp in model.named_parameters():
for key, expected_tensor in value_dict[n].items():
if key == WEIGHT_KEY:
actual_tensor = safe_get_full_fp32_param(lp)
else:
actual_tensor = safe_get_full_optimizer_state(lp, key)
assert torch.equal(expected_tensor, actual_tensor)


@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32])
class TestTensorFragmentUpdate(DistributedTest):
# Need multiple gpus to test possible hanging
world_size = 2
reuse_dist_env = True

@pytest.mark.parametrize('zero_stage', [1, 2, 3])
@pytest.mark.parametrize('offload_device', [OffloadDeviceEnum.none, OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme])
def test_zero_fragments(self, tmpdir, zero_stage, offload_device, dtype):

if dtype == torch.bfloat16 and not bf16_required_version_check(accelerator_check=False):
pytest.skip(
" DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
)

if offload_device == OffloadDeviceEnum.nvme:
if zero_stage != 3:
pytest.skip(f"Nvme offload not supported for zero stage {zero_stage}")
if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]:
pytest.skip('Skip tests since async-io is not compatible')

config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-6
}
},
"zero_optimization": {
"stage": zero_stage,
}
}

if offload_device == OffloadDeviceEnum.cpu:
config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device}
elif offload_device == OffloadDeviceEnum.nvme:
config_dict["zero_optimization"]["offload_optimizer"] = {
"device": offload_device,
"nvme_path": str(tmpdir)
}

if dtype == torch.float16:
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
elif dtype == torch.bfloat16:
config_dict["bf16"] = {"enabled": True}

hidden_dim = 128
if zero_stage == 3:
config_dict["zero_optimization"]["param_persistence_threshold"] = hidden_dim
with deepspeed.zero.Init(config_dict_or_path=config_dict):
model = SimpleModel(hidden_dim, nlayers=4)
else:
model = SimpleModel(hidden_dim, nlayers=4)

model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
world = dist.get_world_size()
group = dist.new_group(ranks=list(range(world)))

dist.barrier()
optim_keys = [WEIGHT_KEY, FIRST_ORDER_KEY, SECOND_ORDER_KEY]
optim_state_values = create_random_values(model, optim_keys, group)
set_param_values_with_dict(model, optim_state_values)
validate_param_values_with_dict(model, optim_state_values)

# Needed in ZeRO 3. Not doing so can leak memory.
model.destroy()