[Dygraph]opt sharding stage3 #39334

Merged
merged 1 commit on Feb 8, 2022
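For context, a minimal usage sketch of the new segment_size knob introduced by this PR. Apart from segment_size itself and the constructor arguments visible in the diff, the import path and class name are assumptions based on the file being modified, not confirmed by this page:

# Hypothetical usage sketch; only `segment_size` and the wrapper arguments come from this diff.
import paddle
from paddle.distributed import fleet
# Assumed import path for the class defined in the modified file.
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3

fleet.init(is_collective=True)
model = paddle.nn.Linear(1024, 1024)
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())

# Parameters with numel <= segment_size (default 2**15) are kept whole
# ("unslice" params); larger parameters keep the per-rank slicing treatment.
model = ShardingStage3(model, optimizer=optimizer, segment_size=2**15)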
@@ -35,6 +35,7 @@

from .sharding_utils import Type, ShardingClipGrad, device_guard
from ..pp_utils.utils import _all_gather
from ...utils.internal_storage import GradStorage

# CUDA alignment 256 bytes
alignment = {"gpu": 256, }
@@ -69,6 +70,7 @@ def __init__(self,
group=None,
sync_buffers=False,
device="gpu",
segment_size=2**15,
pertrain_sync_models=True,
accumulate_grads=False,
offload=False,
@@ -83,6 +85,8 @@ def __init__(self,
self._accumulate_grads = accumulate_grads
self._offload = offload
self._sync_comm = sync_comm
# segmentation size
self._segment_size = segment_size if not offload else 0

global DEV
DEV = "cpu" if paddle.get_device() == "cpu" else paddle.get_device(
@@ -107,7 +111,10 @@ def __init__(self,
self._param2buffer_size = dict() # {param.name: size}
self._param2buffer = dict(
) # {param.name: [(start0, end0),(start1, end1), ...]}
self._trainable_params = dict() # {layer.name: [trainable_params]}
self._trainable_params = dict() # {id(layer): [trainable_params]}
self._unslice_params = set() # param's numel <= segment_size
self._unslice_params2align = dict() # {param.name: param's align}
self._grad_storages = dict() # {param.dtype: GradStorage}

assert not isinstance(
optimizer, list), "Multiple optimizers are not supported now."
@@ -131,10 +138,13 @@ def __init__(self,

self._segment_rank_params(self._layer)

# Add unslice params to master_weight in fp16
self._handle_unslice_params()

# In the first step, record the execution order of the layer
self._order_tracer = OrderedDict()
self._order_tracer["order"] = 0
self._order_tracer["layer"] = []
self._order_tracer["layer"] = list()

# Register task flow
self._task_flow = TaskFlow()
@@ -168,8 +178,10 @@ def _sync_params_and_buffers(self):
def _clear_gradients(self):
assert len(self._trainable_params.keys()) > 0
current_layer_params = self._layer.parameters(include_sublayers=True)
# 1.Handle param's slice
trainable_params = list(
filter(lambda x: x.trainable, current_layer_params))
filter(lambda p: p.trainable and p not in self._unslice_params,
current_layer_params))
for param in trainable_params:
assert hasattr(
param, "fw_storage"
@@ -178,27 +190,35 @@ def _clear_gradients(self):
param.fw_storage.clear_gradient(False)
param.fw_storage._gradient_set_empty(False)
param.bw_storage._clear()
# 2.Handle unslice param
for grad_storage in self._grad_storages.values():
grad_storage.buffer.zero_()

# Update param memory slice
def _update_params_slice(self):
update_list = self._update_params()

if not isinstance(self._optim._param_groups[0], dict):
slice_params = [param.fw_storage for param in update_list]
self._optim._parameter_list = slice_params
self._optim._param_groups = slice_params
self._optim._parameter_list = slice_params + list(
self._unslice_params)
self._optim._param_groups = slice_params + list(
self._unslice_params)
else:
params_name_list = list(map(lambda p: p.name, update_list))
fw_storage_name_list = list(
map(lambda p: p.fw_storage.name, update_list))
for param_group in self._optim._param_groups:
slice_p = []
p_group = []
for p in param_group['params']:
if p.name in params_name_list:
assert hasattr(
p, "fw_storage"
), "Find {} don't have fw_storage attribute.".format(
p.name)
slice_p.append(p.fw_storage)
param_group['params'] = slice_p
p_group.append(p.fw_storage)
elif p.name in fw_storage_name_list:
p_group.append(update_list[fw_storage_name_list.index(
p.name)].fw_storage)
elif p in self._unslice_params:
p_group.append(p)
param_group['params'] = p_group

def forward(self, *inputs, **kwargs):
"""
@@ -213,6 +233,32 @@ def forward(self, *inputs, **kwargs):

return fw

def _handle_unslice_params(self):
buffer_size = dict()
buffer_size[Type.fp32.value] = 0
buffer_size[Type.fp16.value] = 0
for param in self._unslice_params:
# Update optimizer master weights
if param.dtype == Type.fp16.value and not self._offload:
self._optim._master_weights[param.name] = paddle.cast(
param, Type.fp32.value)
param2dtype[param.name] = param.dtype
p_align = self._param2align(param)
self._unslice_params2align[param.name] = p_align
buffer_size[param.dtype] += param._numel() + p_align

# Create unslice_params' grad storages
for param in sorted(list(self._unslice_params), key=lambda p: p.name):
if param.dtype not in self._grad_storages.keys():
self._grad_storages[param.dtype] = GradStorage(
buffer_size[param.dtype],
dtype=param.dtype,
device=self._default_device,
destination=self._rank,
parm2align=self._unslice_params2align)
self._grad_storages[param.dtype].add_grad(
param, self._unslice_params2align[param.name])
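Roughly, the new segment_size threshold decides which parameters take this unslice path: anything whose numel() is at most segment_size stays whole, is added to self._unslice_params, and shares a flat per-dtype GradStorage buffer sized as above (each param's element count plus its alignment padding); larger parameters keep the slicing treatment in _flatten_layer_params below. A toy, framework-free sketch of that split and of the buffer sizing, with invented parameter names and shapes and the element sizes implied by the align/alignment tables in this file:

# Toy sketch only; parameter names, shapes, and dtypes are invented.
SEGMENT_SIZE = 2 ** 15          # default added by this PR
ALIGNMENT = 256                 # CUDA alignment in bytes
ELEM_BYTES = {"fp16": 2, "fp32": 4}

def pad_elems(numel, dtype):
    # Padding (in elements) so the param ends on a 256-byte boundary.
    rem = (numel * ELEM_BYTES[dtype]) % ALIGNMENT
    return 0 if rem == 0 else (ALIGNMENT - rem) // ELEM_BYTES[dtype]

params = {"linear.weight": (512 * 512, "fp16"),   # sliced: numel > SEGMENT_SIZE
          "layer_norm.bias": (512, "fp16"),       # unslice: kept whole
          "scale": (300, "fp32")}                 # unslice: kept whole

buffer_size = {"fp16": 0, "fp32": 0}
for name, (numel, dtype) in params.items():
    if numel <= SEGMENT_SIZE:                     # the unslice path handled above
        buffer_size[dtype] += numel + pad_elems(numel, dtype)

print(buffer_size)  # {'fp16': 512, 'fp32': 320}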

def _segment_rank_params(self, layer, name="last_layer"):
"""
Flatten parameters according to layer.
@@ -233,24 +279,22 @@ def _flatten_layer_params(self, layer, current_layer_params):
def _add_manage_info(trainable_param):
return _PartitionParam(trainable_param)

trainable_params = list(
filter(lambda x: x.trainable, current_layer_params))
current_params = list()
for p in current_layer_params:
if p.trainable and p._numel() > self._segment_size:
current_params.append(_add_manage_info(p))
elif p.trainable:
self._unslice_params.add(_UnsliceParam(p))

assert id(layer) not in self._trainable_params.keys()
self._trainable_params[id(layer)] = list(
map(_add_manage_info, trainable_params))
self._trainable_params[id(layer)] = current_params

for param in self._trainable_params[id(layer)]:
if param.name in self._param2buffer.keys():
continue
self._param2buffer[param.name] = []
# 1.Params alignment
offset = 0
# CUDA alignment 256 bytes
size = param._numel() * align[param.dtype]
remaining = size % alignment[self._default_device]
ali = 0 if remaining == 0 else alignment[
self._default_device] - remaining
align_ = ali // align[param.dtype]
align_ = self._param2align(param)

offset = align_ + param._numel()
buffer_size = offset if offset % self._group.nranks == 0 else offset + self._group.nranks - (
@@ -379,7 +423,9 @@ def _update_params(self):
assert len(self._trainable_params.keys()) > 0
current_layer_params = self._layer.parameters(include_sublayers=True)
trainable_params = list(
filter(lambda x: x.trainable, current_layer_params))
filter(lambda p: p.trainable and p not in self._unslice_params,
current_layer_params))
# 1.Handle param's slice
for param in trainable_params:
assert hasattr(
param,
@@ -396,6 +442,19 @@ def _update_params(self):
assert param.fw_storage.grad is None
param.fw_storage._copy_gradient_from(param.bw_storage)
update_list.append(param)

# 2.Handle unslice param
for grad_storage in self._grad_storages.values():
grad_storage.buffer.scale_(scale=self._world_size_scaling)
dist.all_reduce(
tensor=grad_storage.buffer,
group=self._group,
use_calc_stream=True)
dist.wait(
tensor=grad_storage.buffer,
group=self._group,
use_calc_stream=True)

return update_list
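The unslice-gradient handling added above is scale-then-sum averaging: each rank scales its flat grad buffer by self._world_size_scaling and the following all_reduce sums across ranks; with a scaling factor of 1/nranks this leaves every rank holding the element-wise mean gradient. A single-process sketch of the same arithmetic, simulating the ranks with a plain list instead of a real collective:

# Simulated scale + all_reduce(sum) == element-wise mean across ranks.
world_size = 4
rank_grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]  # invented per-rank buffers

scaled = [[g / world_size for g in buf] for buf in rank_grads]  # grad_storage.buffer.scale_
reduced = [sum(vals) for vals in zip(*scaled)]                  # dist.all_reduce (sum)
print(reduced)  # [4.0, 5.0] == mean of the per-rank buffers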

def get_all_parameters(self, convert2cpu=False):
@@ -405,7 +464,8 @@
assert len(self._trainable_params.keys()) > 0
current_layer_params = self._layer.parameters(include_sublayers=True)
trainable_params = list(
filter(lambda x: x.trainable, current_layer_params))
filter(lambda p: p.trainable and p not in self._unslice_params,
current_layer_params))
t_flow = _allgather_buffer(
trainable_params,
self._group,
@@ -415,7 +475,7 @@
offload=self._offload,
convert2cpu=convert2cpu)
if convert2cpu:
for param in current_layer_params:
for param in trainable_params:
t_flow.full_param[param.name]._share_buffer_to(param)

self._optim._parameter_list = self._ori_parameter_list
@@ -424,7 +484,8 @@
def _register_backward_hooks(self):
current_layer_params = self._layer.parameters(include_sublayers=True)
trainable_params = list(
filter(lambda x: x.trainable, current_layer_params))
filter(lambda p: p.trainable and p not in self._unslice_params,
current_layer_params))

for param in trainable_params:
allreduce_function = self._get_allreduce_fn(param)
@@ -435,42 +496,36 @@ def _get_allreduce_fn(self, param):
def reduce(*_):
if param.name in self._task_flow.full_grad.keys():
full_grad = self._task_flow.full_grad[param.name]
with paddle.amp.auto_cast(enable=False):
if not self._accumulate_grads:
full_grad.scale_(scale=self._world_size_scaling)
# Only support sync allreduce current rank's layer now
dist.all_reduce(
tensor=full_grad,
group=self._group,
use_calc_stream=True)
dist.wait(
tensor=full_grad,
group=self._group,
use_calc_stream=True)
if not self._accumulate_grads:
full_grad.scale_(scale=self._world_size_scaling)
# Only support sync allreduce current rank's layer now
dist.all_reduce(
tensor=full_grad, group=self._group, use_calc_stream=True)
dist.wait(
tensor=full_grad, group=self._group, use_calc_stream=True)

start, end = self._param2buffer[param.name][self._rank]
if not self._accumulate_grads or param.bw_storage is None or not param.bw_storage.value(
).get_tensor()._is_initialized():
param.bw_storage = core.VarBase(
full_grad._slice(start, end)).detach().clone()
if self._offload:
param.bw_storage = _device2cpu(param.bw_storage,
True)
start, end = self._param2buffer[param.name][self._rank]
if not self._accumulate_grads or param.bw_storage is None or not param.bw_storage.value(
).get_tensor()._is_initialized():
param.bw_storage = core.VarBase(
full_grad._slice(start, end)).detach().clone()
if self._offload:
param.bw_storage = _device2cpu(param.bw_storage, True)
else:
if self._offload:
cpu_grad = _device2cpu(
core.VarBase(full_grad._slice(start, end))
.detach().clone(), True)
param.bw_storage = paddle.add(param.bw_storage,
cpu_grad)
else:
if self._offload:
cpu_grad = _device2cpu(
core.VarBase(full_grad._slice(start, end))
.detach().clone(), True)
param.bw_storage = paddle.add(param.bw_storage,
cpu_grad)
else:
# param.bw_storage.add_(
# core.VarBase(full_grad._slice(start, end))
# .detach().clone())
param.bw_storage = paddle.add(
param.bw_storage,
core.VarBase(full_grad._slice(
start, end)).detach().clone())
# param.bw_storage.add_(
# core.VarBase(full_grad._slice(start, end))
# .detach().clone())
param.bw_storage = paddle.add(
param.bw_storage,
core.VarBase(full_grad._slice(start, end)).detach(
).clone())
param.clear_gradient(False)
param._gradient_set_empty(False)
tmp_var = self._task_flow.full_grad.pop(param.name)
@@ -493,6 +548,15 @@ def reduce(*_):

return reduce
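After the all-reduce in the hook above, each rank keeps only its own [start, end) window of the full gradient as bw_storage, with the window taken from self._param2buffer[param.name][self._rank]. A toy sketch of that bookkeeping, assuming (as the sizing code above suggests) a buffer padded to a multiple of nranks and split evenly across ranks:

# Toy sketch: rank r keeps slice [start, end) of the padded full gradient.
nranks = 4
padded_numel = 1024                      # buffer size after padding to a multiple of nranks
per_rank = padded_numel // nranks

param2buffer = [(r * per_rank, (r + 1) * per_rank) for r in range(nranks)]
full_grad = list(range(padded_numel))    # stands in for the all-reduced gradient

rank = 2
start, end = param2buffer[rank]
bw_storage = full_grad[start:end]        # this rank's shard of the gradient
print(start, end)                        # 512 768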

def _param2align(self, param):
# CUDA alignment 256 bytes
size = param._numel() * align[param.dtype]
remaining = size % alignment[self._default_device]
ali = 0 if remaining == 0 else alignment[
self._default_device] - remaining
align_ = ali // align[param.dtype]
return align_
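A worked example of the arithmetic above, assuming the GPU default device (256-byte alignment) and a hypothetical fp16 parameter of 1000 elements (2 bytes each):

# Hypothetical fp16 param, 1000 elements, on the default "gpu" device.
numel, elem_bytes, boundary = 1000, 2, 256
size = numel * elem_bytes                              # 2000 bytes
remaining = size % boundary                            # 208 bytes
ali = 0 if remaining == 0 else boundary - remaining    # 48 bytes of padding
align_ = ali // elem_bytes                             # 24 extra elements reserved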

def _redefine_opt_step(self):
params_slice_func = self._update_params_slice
opt_step = self._optim.step
@@ -679,14 +743,13 @@ def _wait_layer(trainable_params,
group,
use_calc_stream,
offload=False):
paddle.device.cuda.synchronize()
for param in trainable_params:
if param.status == "all":
param.use_count += 1
continue
if param.name in task_flow.full_param.keys():
full_param = task_flow.full_param[param.name]
with paddle.amp.auto_cast(enable=False):
paddle.device.cuda.synchronize()
core.VarBase(full_param._slice(0, param._numel()))._share_buffer_to(
param)
param.fw_storage._clear()
@@ -725,7 +788,7 @@ def _allgather_buffer(trainable_params,
full_param = _all_gather(
param.fw_storage, group, use_calc_stream=use_calc_stream)

# Allgather current layer in the 1st step
# Allgather current layer in the 1st step synchronously
if sync_wait:
with paddle.amp.auto_cast(enable=False):
dist.wait(
@@ -774,6 +837,12 @@ def _PartitionParam(param):
return param


def _UnsliceParam(param):
if not hasattr(param, "unslice"):
setattr(param, "unslice", True)
return param


def _VarBaseWrapper(param):
varbase = param.fw_storage
tmp_param = ParamBase(