Commit a1378a3

[DistDataloader] Update implementation, add nested.py

DesmonDay committed May 7, 2024
1 parent edc04f3 commit a1378a3
Showing 4 changed files with 148 additions and 190 deletions.
181 changes: 59 additions & 122 deletions paddlenlp/data/dist_dataloader.py
@@ -12,11 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
from paddle.distributed import fleet

from paddlenlp.utils.log import logger
from paddlenlp.utils.nested import (
nested_broadcast_tensor,
nested_copy_place,
nested_empty_tensor,
nested_reduce_tensor,
)

_MAX_DATA_DIM = 64

@@ -78,10 +83,6 @@ def __init__(
sharding_rank = self._hcg.get_sharding_parallel_rank()
self._need_data = (self.mp_rank == 0) and (self.pp_rank == 0)

# When needed other data types, we can modify dtype_list.
self.dtype_list = [paddle.int64, paddle.float32, paddle.int32]
self._data_keys_list, self._data_keys_size = None, None

if self._need_data:
self._dataloader = paddle.io.DataLoader(
dataset,
@@ -137,127 +138,63 @@ def _init_dataloader_comm_group(self):
def __iter__(self):
return self

-    def __next__(self):
-        data_keys_size = [0 for i in range(len(self.dtype_list))]
-        if self._need_data:
-            data = next(self._dataloader_iter)
-            data_keys = list(data.keys())
-
-            for key in data_keys:
-                if data[key].dtype not in self.dtype_list:
-                    raise ValueError(
-                        f"Dist dataloader requires dtype as `int64`, `float32` or `int32` currently, but got: {data[key].dtype}"
-                    )
-
-            data_list, data_keys_list = [], []
-            for i, dtype in enumerate(self.dtype_list):
-                data_list.append([data[key] for key in data_keys if data[key].dtype == dtype])
-                data_keys_list.append([key for key in data_keys if data[key].dtype == dtype])
-            data_keys_size = [len(keys) for keys in data_keys_list]
-
-        # Broadcast data keys size.
-        if self._data_keys_size is None:
-            if self.mp_group.nranks > 1 and self.pp_rank == 0:
-                paddle.distributed.broadcast_object_list(data_keys_size, src=self.mp_src_rank, group=self.mp_group)
-            if self._pp_data_group is not None:
-                paddle.distributed.broadcast_object_list(
-                    data_keys_size, src=self._pp_data_group.ranks[0], group=self._pp_data_group
-                )
-            self._data_keys_size = data_keys_size
-
-        if not self._need_data:
-            data_keys_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]
-
-        # Broadcast data keys name.
-        if self._data_keys_list is None:
-            if self.mp_group.nranks > 1 and self.pp_rank == 0:
-                paddle.distributed.broadcast_object_list(data_keys_list, src=self.mp_src_rank, group=self.mp_group)
-            if self._pp_data_group is not None:
-                paddle.distributed.broadcast_object_list(
-                    data_keys_list, src=self._pp_data_group.ranks[0], group=self._pp_data_group
-                )
-            self._data_keys_list = data_keys_list
-
-        # Broadcast data.
-        if not self._need_data:
-            data_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]
-
-        if self.mp_group.nranks > 1 and self.pp_rank == 0:
-            for i, dtype in enumerate(self.dtype_list):
-                if self._data_keys_size[i] > 0:
-                    data_list[i] = broadcast_data_list(
-                        data_list[i], dtype, self.mp_rank, self.mp_group, self.mp_src_rank
-                    )
-
-        if self._pp_data_group is not None:
-            # Note(daisimng): In last stage of pp, we don't need input_ids.
-            # It will be removed in future.
-            for i, dtype in enumerate(self.dtype_list):
-                if self._data_keys_size[i] > 0:
-                    data_list[i] = broadcast_data_list(
-                        data_list[i],
-                        dtype,
-                        self.pp_rank,
-                        self._pp_data_group,
-                        self._pp_data_group.ranks[0],
-                    )
-
-        out_data = {}
-        for keys, datas in zip(self._data_keys_list, data_list):
-            out_data.update([(k, d) for k, d in zip(keys, datas)])
-
-        return out_data
-
-
-def broadcast_data_list(data_list, datatype, comm_rank=0, comm_group=None, src_rank=0):
-    """
-    Broadcast data from src_rank to all ranks in comm_group.
-    """
-    # Move to GPU and broadcast.
-    size_cpu = []
-    if comm_rank == 0:
-        for data in data_list:
-            size_cpu.append(len(data.shape))
-            size_cpu += data.shape
-    size_cpu = size_cpu + [0] * (_MAX_DATA_DIM - len(size_cpu))
-    size_cuda = paddle.to_tensor(size_cpu)
-    paddle.distributed.broadcast(size_cuda, src_rank, group=comm_group).wait()
-
-    size_cpu = size_cuda.tolist()
-    i = 0
-    numel = 0
-    sizes = []
-    while size_cpu[i] > 0:
-        rank = size_cpu[i]
-        this_size = size_cpu[i + 1 : i + 1 + rank]
-        numel += int(np.prod(this_size))
-        sizes.append(this_size)
-        i += rank + 1
-
-    if comm_rank == 0:
-        assert data.dtype == datatype, "input has data type {} which " "is different than {}".format(
-            data.dtype, datatype
-        )
-        if paddle.is_compiled_with_cuda():
-            data_b = paddle.concat([d.cuda().reshape([-1]) for d in data_list], 0)
-        else:
-            data_b = paddle.concat([d.reshape([-1]) for d in data_list], 0)
-
-        assert numel == sum([d.numel().item() for d in data_list]), (numel, [d.numel().item() for d in data_list])
-    else:
-        if paddle.is_compiled_with_cuda():
-            data_b = paddle.empty([numel], dtype=datatype).cuda()
-        else:
-            data_b = paddle.empty([numel], dtype=datatype)
-
-    # Broadcast
-    paddle.distributed.broadcast(data_b, src_rank, group=comm_group).wait()
-
-    ret = []
-    offset = 0
-    for size in sizes:
-        numel = int(np.prod(size))
-        ret.append(data_b[offset : offset + numel].reshape(size))
-        offset += numel
-
-    return ret
+    def _broadcast_data(self, data):
+        process_rank = paddle.distributed.get_rank()
+        if self.mp_group.nranks > 1:
+            if process_rank == self.mp_src_rank:
+                fake_data = [nested_reduce_tensor(data)]
+            else:
+                if data is not None:
+                    logger.warning(
+                        f"Your local rank {paddle.distributed.get_rank()} is forbidden to have a state_dict."
+                    )
+                fake_data = [None]
+        if self._pp_data_group is not None:
+            if process_rank == self._pp_data_group.ranks[0]:
+                fake_data = [nested_reduce_tensor(data)]
+            else:
+                if data is not None:
+                    logger.warning(
+                        f"Your local rank {paddle.distributed.get_rank()} is forbidden to have a state_dict."
+                    )
+                fake_data = [None]
+        if self.mp_group.nranks > 1 and self.pp_rank == 0:
+            paddle.distributed.broadcast_object_list(
+                fake_data,
+                src=self.mp_src_rank,
+                group=self.mp_group,
+            )
+        if self._pp_data_group is not None:
+            paddle.distributed.broadcast_object_list(
+                fake_data,
+                src=self._pp_data_group.ranks[0],
+                group=self._pp_data_group,
+            )
+        fake_data = fake_data[0]
+
+        if self.mp_group.nranks > 1:
+            if process_rank != self.mp_src_rank:
+                data = nested_empty_tensor(fake_data)
+        if self._pp_data_group is not None:
+            if process_rank != self._pp_data_group.ranks[0]:
+                data = nested_empty_tensor(fake_data)
+        data = nested_copy_place(data, place=paddle.framework._current_expected_place())
+        if self.mp_group.nranks > 1 and self.pp_rank == 0:
+            data = nested_broadcast_tensor(data, src=self.mp_src_rank, group=self.mp_group)
+        if self._pp_data_group is not None:
+            data = nested_broadcast_tensor(data, src=self._pp_data_group.ranks[0], group=self._pp_data_group)
+
+        if data is None:
+            raise StopIteration
+
+        return data
+
+    def __next__(self):
+        data = None
+        if self._need_data:
+            try:
+                data = next(self._dataloader_iter)
+            except StopIteration:
+                # Leave data as None: _broadcast_data propagates the None to
+                # every rank and raises StopIteration everywhere in step.
+                pass
+        data = self._broadcast_data(data)
+        return data
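The rewrite above replaces the dtype-bucketed `broadcast_data_list` path with a two-phase scheme: the source rank first reduces the batch to a picklable skeleton of `TensorHolder(shape, dtype, name)` entries and ships it via `broadcast_object_list`; every other rank then allocates empty tensors from that skeleton and has them filled in place by `nested_broadcast_tensor`. A `None` batch survives both phases unchanged, which is how the final `StopIteration` reaches every rank at the end of an epoch. Below is a minimal single-process sketch of the skeleton round trip, using the helpers this commit moves into `paddlenlp.utils.nested` (no communication group is needed to see the shape/dtype plumbing):

```python
import copy

import paddle

# Helpers introduced/moved by this commit.
from paddlenlp.utils.nested import nested_empty_tensor, nested_reduce_tensor

batch = {
    "input_ids": paddle.ones([2, 8], dtype="int64"),
    "labels": paddle.zeros([2], dtype="int64"),
}

# Phase 1: what the source rank hands to broadcast_object_list -- every
# tensor replaced by a lightweight TensorHolder(shape, dtype, name).
skeleton = nested_reduce_tensor(batch)

# Phase 2: what a non-source rank allocates before the tensor broadcast
# fills it in place (copy.copy because nested_empty_tensor mutates dicts).
rebuilt = nested_empty_tensor(copy.copy(skeleton))

assert rebuilt["input_ids"].shape == [2, 8]
assert rebuilt["labels"].dtype == paddle.int64
```

Shipping only metadata first is what lets the new code drop the old `int64`/`float32`/`int32` whitelist and the `_MAX_DATA_DIM`-padded shape tensor: arbitrary dtypes and arbitrary nesting ride along for free.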
21 changes: 1 addition & 20 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -62,6 +62,7 @@
SAFE_WEIGHTS_NAME,
)
from paddlenlp.utils.log import logger
from paddlenlp.utils.nested import nested_copy, nested_copy_place

if is_safetensors_available():
from safetensors import safe_open
@@ -1876,26 +1877,6 @@ def mapping_optimizer_tp_actions(tp_actions, optimizer_loaded_keys):
return new_actions


def nested_copy(inputs):
if isinstance(inputs, dict):
outputs = {}
for key in list(inputs.keys()):
outputs[key] = nested_copy(inputs[key])
return outputs
return inputs


def nested_copy_place(inputs, place=None, blocking=False):
if isinstance(inputs, dict):
outputs = {}
for key in list(inputs.keys()):
outputs[key] = nested_copy_place(inputs[key], place, blocking)
return outputs
if isinstance(inputs, paddle.Tensor):
inputs = inputs if inputs.place == place else inputs._copy_to(place, blocking)
return inputs


def flatten_list(nested_list):
flattened_list = []
for item in nested_list:
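`nested_copy` and `nested_copy_place` now come from the shared `paddlenlp.utils.nested` module instead of being defined here; the dataloader above uses `nested_copy_place` to land each reconstructed batch on the current device. A small sketch of the same helper going the other way, offloading a nested state dict to CPU (import path as added by this commit; the `place` comparison is conservative, so tensors already on the target may still be copied):

```python
import paddle

from paddlenlp.utils.nested import nested_copy_place

state = {"linear": {"weight": paddle.randn([4, 4]), "bias": paddle.zeros([4])}}

# Recursively _copy_to the target place; non-tensor leaves pass through.
cpu_state = nested_copy_place(state, place=paddle.CPUPlace(), blocking=True)

print(cpu_state["linear"]["weight"].place)  # e.g. Place(cpu)
```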
53 changes: 5 additions & 48 deletions paddlenlp/trainer/utils/helper.py
@@ -16,8 +16,6 @@
# This file is modified from
# https://github.com/huggingface/transformers/blob/main/src/transformers

import collections
import copy
import os
from typing import Any, Optional

@@ -27,6 +25,11 @@
from paddle.distributed import fleet

from paddlenlp.utils.log import logger
from paddlenlp.utils.nested import (
nested_broadcast_tensor,
nested_empty_tensor,
nested_reduce_tensor,
)

__all__ = [
"distributed_concat",
@@ -180,52 +183,6 @@ def distributed_file(filename):
return filename


TensorHolder = collections.namedtuple("TensorHolder", ["shape", "dtype", "name"])


def nested_reduce_tensor(tensor):
if isinstance(tensor, dict):
# copy tensor since it will be inplace modified dict
tensor = copy.copy(tensor)
for key in list(tensor.keys()):
tensor[key] = nested_reduce_tensor(tensor[key])
if isinstance(tensor, (tuple, list)):
return type(tensor)(nested_reduce_tensor(t) for t in tensor)

if isinstance(tensor, paddle.Tensor):
return TensorHolder(tensor.shape, tensor.dtype, tensor.name)

return tensor


def nested_empty_tensor(tensor):
if isinstance(tensor, dict):
for key in list(tensor.keys()):
tensor[key] = nested_empty_tensor(tensor[key])
if isinstance(tensor, list):
return type(tensor)(nested_empty_tensor(t) for t in tensor)

# TensorHolder is tuple
if isinstance(tensor, TensorHolder):
t = paddle.empty(tensor.shape, dtype=tensor.dtype, name=tensor.name)
t.name = tensor.name
return t

return tensor


def nested_broadcast_tensor(tensor, src=0, group=None):
if isinstance(tensor, dict):
for key in list(tensor.keys()):
tensor[key] = nested_broadcast_tensor(tensor[key], src=src, group=group)
if isinstance(tensor, list):
return type(tensor)(nested_broadcast_tensor(t, src=src, group=group) for t in tensor)

if isinstance(tensor, paddle.Tensor):
paddle.distributed.broadcast(tensor, src=src, group=group, sync_op=True)
return tensor


def broadcast_dp_optimizer(state_dict):
if paddle.distributed.get_world_size() <= 1:
return state_dict
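The fourth changed file never rendered in this capture. By the page's own arithmetic it accounts for 83 of the 148 additions (148 − 59 − 1 − 5) with no deletions, which fits the new `paddlenlp/utils/nested.py` named in the commit title. A plausible reconstruction, assembled from the helpers deleted from `unified_checkpoint.py` and `helper.py` above (license header omitted; the verbatim file may differ):

```python
import collections
import copy

import paddle

TensorHolder = collections.namedtuple("TensorHolder", ["shape", "dtype", "name"])


def nested_copy(inputs):
    if isinstance(inputs, dict):
        outputs = {}
        for key in list(inputs.keys()):
            outputs[key] = nested_copy(inputs[key])
        return outputs
    return inputs


def nested_copy_place(inputs, place=None, blocking=False):
    if isinstance(inputs, dict):
        outputs = {}
        for key in list(inputs.keys()):
            outputs[key] = nested_copy_place(inputs[key], place, blocking)
        return outputs
    if isinstance(inputs, paddle.Tensor):
        inputs = inputs if inputs.place == place else inputs._copy_to(place, blocking)
    return inputs


def nested_reduce_tensor(tensor):
    if isinstance(tensor, dict):
        # Shallow-copy first: the dict is modified in place below.
        tensor = copy.copy(tensor)
        for key in list(tensor.keys()):
            tensor[key] = nested_reduce_tensor(tensor[key])
    if isinstance(tensor, (tuple, list)):
        return type(tensor)(nested_reduce_tensor(t) for t in tensor)

    if isinstance(tensor, paddle.Tensor):
        return TensorHolder(tensor.shape, tensor.dtype, tensor.name)

    return tensor


def nested_empty_tensor(tensor):
    if isinstance(tensor, dict):
        for key in list(tensor.keys()):
            tensor[key] = nested_empty_tensor(tensor[key])
    if isinstance(tensor, list):
        return type(tensor)(nested_empty_tensor(t) for t in tensor)

    # TensorHolder is a (named) tuple, so it is matched here, not above.
    if isinstance(tensor, TensorHolder):
        t = paddle.empty(tensor.shape, dtype=tensor.dtype, name=tensor.name)
        t.name = tensor.name
        return t

    return tensor


def nested_broadcast_tensor(tensor, src=0, group=None):
    if isinstance(tensor, dict):
        for key in list(tensor.keys()):
            tensor[key] = nested_broadcast_tensor(tensor[key], src=src, group=group)
    if isinstance(tensor, list):
        return type(tensor)(nested_broadcast_tensor(t, src=src, group=group) for t in tensor)

    if isinstance(tensor, paddle.Tensor):
        paddle.distributed.broadcast(tensor, src=src, group=group, sync_op=True)
    return tensor
```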
