
[DistDataloader] Update implementation, add nested.py #8380

Merged · 7 commits · May 9, 2024

Changes from all commits
195 changes: 70 additions & 125 deletions paddlenlp/data/dist_dataloader.py
@@ -12,13 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import numpy as np
 import paddle
 from paddle.distributed import fleet

 from paddlenlp.utils.log import logger
-
-_MAX_DATA_DIM = 64
+from paddlenlp.utils.nested import (
+    nested_broadcast_tensor,
+    nested_copy_place,
+    nested_empty_tensor,
+    nested_reduce_tensor,
+)


 class DummyDataset(paddle.io.Dataset):

@@ -53,6 +56,7 @@
         timeout=0,
         worker_init_fn=None,
         persistent_workers=False,
+        eval=False,
     ):

         if dataset is None:

@@ -62,12 +66,15 @@
         super().__init__(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=num_workers)

         self._hcg = fleet.get_hybrid_communicate_group()
+        self.eval = eval

         # Init pp data comm group.
         if self._hcg.get_pipe_parallel_world_size() > 1:
             self._pp_data_group = self._init_dataloader_comm_group()
+            self._pp_group = self._hcg.get_pipe_parallel_group()
         else:
             self._pp_data_group = None
+            self._pp_group = None

         self.mp_group = self._hcg.get_model_parallel_group()
         self.mp_rank = self._hcg.get_model_parallel_rank()

@@ -78,10 +85,6 @@
         sharding_rank = self._hcg.get_sharding_parallel_rank()
         self._need_data = (self.mp_rank == 0) and (self.pp_rank == 0)

-        # When needed other data types, we can modify dtype_list.
-        self.dtype_list = [paddle.int64, paddle.float32, paddle.int32]
-        self._data_keys_list, self._data_keys_size = None, None
-
         if self._need_data:
             self._dataloader = paddle.io.DataLoader(
                 dataset,

@@ -127,7 +130,6 @@
         parallel_groups = topo.get_comm_list("pipe")

         for group in parallel_groups:
-            # only first rank and last rank
             ranks = [group[0], group[-1]]
             comm_group = paddle.distributed.new_group(ranks=ranks)
             if paddle.distributed.get_rank() in ranks:

@@ -137,127 +139,70 @@
     def __iter__(self):
         return self

-    def __next__(self):
-        data_keys_size = [0 for i in range(len(self.dtype_list))]
-        if self._need_data:
-            data = next(self._dataloader_iter)
-            data_keys = list(data.keys())
-
-            for key in data_keys:
-                if data[key].dtype not in self.dtype_list:
-                    raise ValueError(
-                        f"Dist dataloader requires dtype as `int64`, `float32` or `int32` currently, but got: {data[key].dtype}"
-                    )
-
-            data_list, data_keys_list = [], []
-            for i, dtype in enumerate(self.dtype_list):
-                data_list.append([data[key] for key in data_keys if data[key].dtype == dtype])
-                data_keys_list.append([key for key in data_keys if data[key].dtype == dtype])
-            data_keys_size = [len(keys) for keys in data_keys_list]
-
-        # Broadcast data keys size.
-        if self._data_keys_size is None:
-            if self.mp_group.nranks > 1 and self.pp_rank == 0:
-                paddle.distributed.broadcast_object_list(data_keys_size, src=self.mp_src_rank, group=self.mp_group)
-            if self._pp_data_group is not None:
-                paddle.distributed.broadcast_object_list(
-                    data_keys_size, src=self._pp_data_group.ranks[0], group=self._pp_data_group
-                )
-            self._data_keys_size = data_keys_size
-
-        if not self._need_data:
-            data_keys_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]
-
-        # Broadcast data keys name.
-        if self._data_keys_list is None:
-            if self.mp_group.nranks > 1 and self.pp_rank == 0:
-                paddle.distributed.broadcast_object_list(data_keys_list, src=self.mp_src_rank, group=self.mp_group)
-            if self._pp_data_group is not None:
-                paddle.distributed.broadcast_object_list(
-                    data_keys_list, src=self._pp_data_group.ranks[0], group=self._pp_data_group
-                )
-            self._data_keys_list = data_keys_list
-
-        # Broadcast data.
-        if not self._need_data:
-            data_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]
-
-        if self.mp_group.nranks > 1 and self.pp_rank == 0:
-            for i, dtype in enumerate(self.dtype_list):
-                if self._data_keys_size[i] > 0:
-                    data_list[i] = broadcast_data_list(
-                        data_list[i], dtype, self.mp_rank, self.mp_group, self.mp_src_rank
-                    )
-
-        if self._pp_data_group is not None:
-            # Note(daisimng): In last stage of pp, we don't need input_ids.
-            # It will be removed in future.
-            for i, dtype in enumerate(self.dtype_list):
-                if self._data_keys_size[i] > 0:
-                    data_list[i] = broadcast_data_list(
-                        data_list[i],
-                        dtype,
-                        self.pp_rank,
-                        self._pp_data_group,
-                        self._pp_data_group.ranks[0],
-                    )
-
-        out_data = {}
-        for keys, datas in zip(self._data_keys_list, data_list):
-            out_data.update([(k, d) for k, d in zip(keys, datas)])
-
-        return out_data
-
-
-def broadcast_data_list(data_list, datatype, comm_rank=0, comm_group=None, src_rank=0):
-    """
-    Broadcast data from src_rank to all ranks in comm_group.
-    """
-    # Move to GPU and broadcast.
-    size_cpu = []
-    if comm_rank == 0:
-        for data in data_list:
-            size_cpu.append(len(data.shape))
-            size_cpu += data.shape
-    size_cpu = size_cpu + [0] * (_MAX_DATA_DIM - len(size_cpu))
-    size_cuda = paddle.to_tensor(size_cpu)
-    paddle.distributed.broadcast(size_cuda, src_rank, group=comm_group).wait()
-
-    size_cpu = size_cuda.tolist()
-    i = 0
-    numel = 0
-    sizes = []
-    while size_cpu[i] > 0:
-        rank = size_cpu[i]
-        this_size = size_cpu[i + 1 : i + 1 + rank]
-        numel += int(np.prod(this_size))
-        sizes.append(this_size)
-        i += rank + 1
-
-    if comm_rank == 0:
-        assert data.dtype == datatype, "input has data type {} which is different than {}".format(
-            data.dtype, datatype
-        )
-        if paddle.is_compiled_with_cuda():
-            data_b = paddle.concat([d.cuda().reshape([-1]) for d in data_list], 0)
-        else:
-            data_b = paddle.concat([d.reshape([-1]) for d in data_list], 0)
-
-        assert numel == sum([d.numel().item() for d in data_list]), (numel, [d.numel().item() for d in data_list])
-    else:
-        if paddle.is_compiled_with_cuda():
-            data_b = paddle.empty([numel], dtype=datatype).cuda()
-        else:
-            data_b = paddle.empty([numel], dtype=datatype)
-
-    # Broadcast
-    paddle.distributed.broadcast(data_b, src_rank, group=comm_group).wait()
-
-    ret = []
-    offset = 0
-    for size in sizes:
-        numel = int(np.prod(size))
-        ret.append(data_b[offset : offset + numel].reshape(size))
-        offset += numel
-
-    return ret
+    def _broadcast_data(self, data):
+        process_rank = paddle.distributed.get_rank()
+        if self.mp_group.nranks > 1:
+            if process_rank == self.mp_src_rank:
+                fake_data = [nested_reduce_tensor(data)]
+            else:
+                if data is not None:
+                    logger.warning(
+                        f"Your local rank {paddle.distributed.get_rank()} is forbidden to have a state_dict."
+                    )
+                fake_data = [None]
+        if self._pp_group is not None:
+            if process_rank == self._pp_group.ranks[0]:
+                fake_data = [nested_reduce_tensor(data)]
+            else:
+                if data is not None:
+                    logger.warning(
+                        f"Your local rank {paddle.distributed.get_rank()} is forbidden to have a state_dict."
+                    )
+                fake_data = [None]
+        if self.mp_group.nranks > 1 and self.pp_rank == 0:
+            paddle.distributed.broadcast_object_list(
+                fake_data,
+                src=self.mp_src_rank,
+                group=self.mp_group,
+            )
+        if self._pp_group is not None:
+            paddle.distributed.broadcast_object_list(
+                fake_data,
+                src=self._pp_group.ranks[0],
+                group=self._pp_group,
+            )
+        else:
+            fake_data = [None]
+
+        fake_data = fake_data[0]
+        if fake_data is None:
+            raise StopIteration
+
+        dst_pp_group = self._pp_group if self.eval else self._pp_data_group
+        if self.mp_group.nranks > 1:
+            if process_rank != self.mp_src_rank:
+                data = nested_empty_tensor(fake_data)
+        if dst_pp_group is not None:
+            if process_rank != dst_pp_group.ranks[0]:
+                data = nested_empty_tensor(fake_data)
+
+        if self.mp_group.nranks > 1 and self.pp_rank == 0:
+            data = nested_broadcast_tensor(data, src=self.mp_src_rank, group=self.mp_group)
+        if dst_pp_group is not None:
+            data = nested_broadcast_tensor(data, src=dst_pp_group.ranks[0], group=dst_pp_group)
+        # For pp1 - pp_{n-1}, Paddle needs to receive an empty dict for pipeline parallel.
+        if data is None:
+            data = {}
+
+        return data
+
+    def __next__(self):
+        data = None
+        if self._need_data:
+            try:
+                data = next(self._dataloader_iter)
+                data = nested_copy_place(data, place=paddle.framework._current_expected_place())
+            except:
+                pass
+        data = self._broadcast_data(data)
+        return data

Collaborator review comment on the line `data = next(self._dataloader_iter)`:

> The copy can be placed here.

Suggested change:

-                data = next(self._dataloader_iter)
+                data = next(self._dataloader_iter)
+                data = nested_copy_place(data, place=paddle.framework._current_expected_place())
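
The helpers used above live in the new paddlenlp/utils/nested.py, which is not part of this excerpt. For orientation, here is a minimal sketch of what nested_reduce_tensor, nested_empty_tensor, and nested_broadcast_tensor plausibly do, inferred only from their call sites in _broadcast_data; the metadata format and traversal details are assumptions, not the merged implementation.

    import paddle


    def nested_reduce_tensor(tensor):
        # Replace each tensor leaf with cheap metadata (shape + dtype) so the
        # structure can be sent via broadcast_object_list without tensor payloads.
        if isinstance(tensor, dict):
            return {k: nested_reduce_tensor(v) for k, v in tensor.items()}
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(nested_reduce_tensor(t) for t in tensor)
        if isinstance(tensor, paddle.Tensor):
            return {"shape": tensor.shape, "dtype": tensor.dtype}  # assumed format
        return tensor


    def nested_empty_tensor(tensor):
        # On receiving ranks, rebuild empty tensors from the metadata skeleton
        # so nested_broadcast_tensor has buffers to fill.
        if isinstance(tensor, dict) and set(tensor.keys()) == {"shape", "dtype"}:
            return paddle.empty(tensor["shape"], dtype=tensor["dtype"])
        if isinstance(tensor, dict):
            return {k: nested_empty_tensor(v) for k, v in tensor.items()}
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(nested_empty_tensor(t) for t in tensor)
        return tensor


    def nested_broadcast_tensor(tensor, src=0, group=None):
        # Broadcast every tensor leaf in place from src to all ranks in group.
        if isinstance(tensor, dict):
            return {k: nested_broadcast_tensor(v, src=src, group=group) for k, v in tensor.items()}
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(nested_broadcast_tensor(t, src=src, group=group) for t in tensor)
        if isinstance(tensor, paddle.Tensor):
            paddle.distributed.broadcast(tensor, src=src, group=group)
        return tensor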

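For context, a hypothetical usage sketch (not taken from this PR): only ranks with mp_rank == 0 and pp_rank == 0 pull batches from the inner paddle.io.DataLoader; every other rank obtains its batch through the _broadcast_data path above. The dataset, sampler, and collator names below are placeholders.

    from paddlenlp.data.dist_dataloader import DistDataLoader

    loader = DistDataLoader(
        dataset=train_dataset,        # placeholder dataset
        batch_sampler=batch_sampler,  # placeholder sampler
        collate_fn=data_collator,     # placeholder collator
        num_workers=2,
        eval=False,  # new in this PR: eval=True broadcasts over _pp_group instead of _pp_data_group
    )

    # All ranks iterate in lockstep: once the source iterator is exhausted,
    # None is broadcast and __next__ raises StopIteration everywhere.
    for batch in loader:
        pass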
24 changes: 1 addition & 23 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -63,6 +63,7 @@
     SAFE_WEIGHTS_NAME,
 )
 from paddlenlp.utils.log import logger
+from paddlenlp.utils.nested import nested_copy, nested_copy_place

 if is_safetensors_available():
     # from safetensors import safe_open

@@ -1880,29 +1881,6 @@ def mapping_optimizer_tp_actions(tp_actions, optimizer_loaded_keys):
     return new_actions


-def nested_copy(inputs):
-    if isinstance(inputs, dict):
-        outputs = {}
-        for key in list(inputs.keys()):
-            outputs[key] = nested_copy(inputs[key])
-        return outputs
-    return inputs
-
-
-def nested_copy_place(inputs, place=None, blocking=False):
-    if isinstance(inputs, dict):
-        outputs = {}
-        for key in list(inputs.keys()):
-            outputs[key] = nested_copy_place(inputs[key], place, blocking)
-        return outputs
-    if isinstance(inputs, paddle.Tensor):
-        if inputs.place._equals(place):
-            return inputs
-        else:
-            return inputs._copy_to(place, blocking)
-    return inputs
-
-
 def flatten_list(nested_list):
     flattened_list = []
     for item in nested_list:
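
With this change, nested_copy and nested_copy_place are imported from paddlenlp.utils.nested instead of being defined locally; their bodies (shown in the removed lines above) are unchanged. A small illustrative example of the shared helper, assuming the same semantics after the move:

    import paddle
    from paddlenlp.utils.nested import nested_copy_place

    # Recursively copies every tensor in a (possibly nested) dict to the target
    # place; tensors already on that place are returned untouched.
    state = {"w": paddle.ones([2, 2]), "opt": {"m": paddle.zeros([2])}}
    state_cpu = nested_copy_place(state, place=paddle.CPUPlace(), blocking=True)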