support sft mapdataset (#8840)
* support sft mapdataset

* fix __len__ and __getitem__

* fix judge impl method

* fix mmap

* change variable name
greycooker authored Aug 5, 2024
1 parent a6a7870 commit c4d1abf
Showing 1 changed file with 272 additions and 0 deletions.
paddlenlp/data/indexed_dataset.py
@@ -29,6 +29,7 @@
import shutil
import struct
import time
from dataclasses import fields
from functools import lru_cache
from itertools import accumulate

@@ -68,6 +69,19 @@ def make_dataset(path, impl, skip_warmup=False):
    return None


def make_sft_dataset(path, dataclass, skip_warmup=False, impl="mmap"):
    if impl != "mmap":
        raise ValueError("SFT indexed dataset currently only supports the mmap memory-mapped method")

    print_rank_0(" > building dataset index ...")
    start_time = time.time()
    sft_indexed_dataset = SFTMMapIndexedDataset(path, dataclass, skip_warmup)
    print_rank_0(" > finished creating SFT indexed dataset in {:4f} seconds".format(time.time() - start_time))
    print_rank_0(" number of samples: {}".format(len(sft_indexed_dataset.doc_idx) - 1))

    return sft_indexed_dataset
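
A minimal reader-side usage sketch, assuming the import path below and a hypothetical `Example` dataclass; the dataset directory is expected to hold an `index.idx` plus one `<field>.bin` per dataclass field, e.g. as written by the `SFTMMapIndexedDatasetBuilder` added further down in this diff:

```python
from dataclasses import dataclass
from typing import List

from paddlenlp.data.indexed_dataset import make_sft_dataset


@dataclass
class Example:  # hypothetical record layout; non-int fields are read back as token sequences
    input_ids: List[int]
    labels: List[int]


# "/path/to/sft_dataset" is a hypothetical, already-built dataset directory.
dataset = make_sft_dataset("/path/to/sft_dataset", Example)
print(len(dataset))     # number of documents
first_doc = dataset[0]  # a list of Example instances, one per sequence in the first document
```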


def dataset_exists(path, impl):
    if impl == "mmap":
        return MMapIndexedDataset.exists(path)
@@ -120,6 +134,18 @@ def index_file_path(prefix_path):
    return prefix_path + ".idx"


def sft_index_file_path(prefix_path):
    return os.path.join(prefix_path, "index.idx")


def sft_data_file_path(prefix_path, dataclass):
    file_path_list = []
    for field in fields(dataclass):
        file_path = os.path.join(prefix_path, f"{field.name}.bin")
        file_path_list.append(file_path)
    return file_path_list
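
For a hypothetical dataclass, these helpers resolve to one index file plus one data file per field; a small sketch of the expected layout, assuming the helpers are importable from this module:

```python
from dataclasses import dataclass
from typing import List

from paddlenlp.data.indexed_dataset import sft_data_file_path, sft_index_file_path


@dataclass
class Example:  # hypothetical fields
    input_ids: List[int]
    labels: List[int]


assert sft_index_file_path("/data/sft") == "/data/sft/index.idx"
assert sft_data_file_path("/data/sft", Example) == ["/data/sft/input_ids.bin", "/data/sft/labels.bin"]
```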


def data_file_path(prefix_path):
    return prefix_path + ".bin"

@@ -548,13 +574,259 @@ def exists(path):
        return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))


class SFTMMapIndexedDataset(paddle.io.Dataset):
    class Index(object):
        _HDR_MAGIC = b"MMIDIDX\x00\x00"

        @classmethod
        def writer(cls, path, dtype):
            class _Writer(object):
                def __enter__(self):
                    self._file = open(path, "wb")
                    self._file.write(cls._HDR_MAGIC)
                    self._file.write(struct.pack("<Q", 1))
                    self._file.write(struct.pack("<B", code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes):
                    dtype_size = dtype().itemsize
                    address = 0
                    pointers = []
                    for size in sizes:
                        pointers.append(address)
                        address += size * dtype_size
                    return pointers

                def write(self, sizes, doc_idx):
                    pointers = self._get_pointers(sizes)
                    self._file.write(struct.pack("<Q", len(sizes)))
                    self._file.write(struct.pack("<Q", len(doc_idx)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order="C"))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order="C"))
                    del pointers

                    doc_idx = np.array(doc_idx, dtype=np.int64)
                    self._file.write(doc_idx.tobytes(order="C"))

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()

        def __init__(self, path, skip_warmup=False):
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. "
                    "Make sure that --dataset-impl is configured properly."
                )
                version = struct.unpack("<Q", stream.read(8))
                assert (1,) == version

                (dtype_code,) = struct.unpack("<B", stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack("<Q", stream.read(8))[0]
                self._doc_count = struct.unpack("<Q", stream.read(8))[0]
                offset = stream.tell()

            if not skip_warmup:
                print_rank_0(" warming up index mmap file...")
                _warmup_mmap_file(path)

            self._buffer_mmap = np.memmap(path, mode="r", order="C")
            self._buffer = memoryview(self._buffer_mmap)
            print_rank_0(" reading sizes...")
            self._sizes = np.frombuffer(self._buffer, dtype=np.int32, count=self._len, offset=offset)
            print_rank_0(" reading pointers...")
            self._pointers = np.frombuffer(
                self._buffer, dtype=np.int64, count=self._len, offset=offset + self._sizes.nbytes
            )
            print_rank_0(" reading document index...")
            self._doc_idx = np.frombuffer(
                self._buffer,
                dtype=np.int64,
                count=self._doc_count,
                offset=offset + self._sizes.nbytes + self._pointers.nbytes,
            )

        def __del__(self):
            self._buffer_mmap._mmap.close()
            del self._buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._doc_count - 1

    def __init__(self, path, dataclass, skip_warmup=False):
        super().__init__()
        self._dataclass = dataclass
        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path, skip_warmup)

    def __getstate__(self):
        return self._path

    def __setstate__(self, state):
        self._do_init(state, skip_warmup=True)

    def _do_init(self, path, skip_warmup):
        self._path = path
        if not self.exists(path, self._dataclass):
            raise ValueError("Missing file, %s" % (path))

        self._index = self.Index(sft_index_file_path(self._path), skip_warmup)
        if not skip_warmup:
            print_rank_0(" warming up data mmap file...")
            for data_file in sft_data_file_path(self._path, self._dataclass):
                _warmup_mmap_file(data_file)
        print_rank_0(" creating numpy buffer of mmap...")

        self._bin_buffer_mmap_dict = {}
        self._bin_buffer_dict = {}
        for data_file in sft_data_file_path(self._path, self._dataclass):
            self._bin_buffer_mmap_dict[data_file] = np.memmap(data_file, mode="r", order="C")
            self._bin_buffer_dict[data_file] = memoryview(self._bin_buffer_mmap_dict[data_file])
        print_rank_0(" creating memory view of numpy buffer...")

    def __del__(self):
        for key, value in self._bin_buffer_mmap_dict.items():
            value._mmap.close()
        for key, value in self._bin_buffer_dict.items():
            del value
        del self._index

    def __len__(self):
        return len(self._index)

    def __getitem__(self, idx):
        def get_index(idx):
            # Document `idx` covers sentences doc_idx[idx] .. doc_idx[idx + 1].
            doc_idx = self._index.doc_idx
            start_sentence, end_sentence = doc_idx[idx], doc_idx[idx + 1]
            start_pointers, _ = self._index[start_sentence]
            length_list = self._index._sizes[start_sentence:end_sentence]

            dataclass_fields = fields(self._dataclass)
            dataclass_list = []
            sequence_offset = start_pointers
            scalar_offset = doc_idx[idx] * np.dtype(self._index.dtype).itemsize

            for length in length_list:
                field_data = {field.name: [] for field in dataclass_fields}
                for field in dataclass_fields:
                    bin_buffer = self._bin_buffer_dict[os.path.join(self._path, f"{field.name}.bin")]
                    if field.type != int:
                        # Sequence field: read `length` values starting at this sentence's byte offset.
                        data = np.frombuffer(bin_buffer, dtype=self._index.dtype, count=length, offset=sequence_offset)
                        field_data[field.name] = data.tolist()
                    else:
                        # Scalar int field: one value per sentence.
                        data = np.frombuffer(bin_buffer, dtype=self._index.dtype, count=1, offset=scalar_offset)
                        field_data[field.name] = int(data[0])

                dataclass_list.append(self._dataclass(**field_data))

                sequence_offset += length * np.dtype(self._index.dtype).itemsize
                scalar_offset += np.dtype(self._index.dtype).itemsize
            return dataclass_list

        if isinstance(idx, (int, np.integer)):
            return get_index(idx)
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            return [get_index(idx) for idx in range(start, stop)]

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def doc_idx(self):
        return self._index.doc_idx

    def get_doc_idx(self):
        return self._index._doc_idx

    def set_doc_idx(self, doc_idx_):
        self._index._doc_idx = doc_idx_

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path, dataclass):
        file_path_list = sft_data_file_path(path, dataclass)
        file_path_list.append(sft_index_file_path(path))
        for file_path in file_path_list:
            if not os.path.exists(file_path):
                return False
        return True


def make_builder(out_file, impl, save_dtype, loss_mask_file=None):
    if impl == "mmap":
        return MMapIndexedDatasetBuilder(out_file, dtype=save_dtype, loss_mask_file=loss_mask_file)
    else:
        return IndexedDatasetBuilder(out_file, dtype=save_dtype)


class SFTMMapIndexedDatasetBuilder(object):
    def __init__(self, output_file_dict, dtype):
        self._data_file_dict = {}
        for key, filename in output_file_dict.items():
            self._data_file_dict[key] = open(filename, "wb")
        self.output_file_dict = output_file_dict
        self._dtype = dtype
        self._sizes = []
        self._doc_idx = [0]

    def add_item(self, sequence):
        # Record the sentence length once, from the first sequence-valued (size > 1) field;
        # every field's data is appended to its own .bin file.
        add_sequence_len = False
        for key in self._data_file_dict.keys():
            tensor = np.array(getattr(sequence, key), dtype=self._dtype)
            if tensor.size > 1 and not add_sequence_len:
                self._sizes.append(tensor.size)
                add_sequence_len = True
            self._data_file_dict[key].write(tensor.tobytes(order="C"))

    def end_document(self):
        self._doc_idx.append(len(self._sizes))

    def finalize(self, index_file):
        for key, data_file in self._data_file_dict.items():
            data_file.close()
        with SFTMMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
            index.write(self._sizes, self._doc_idx)
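
A hedged writer-side sketch of how the builder, the path helpers, and the index writer might fit together; the `Example` dataclass, the output directory, and the toy documents are hypothetical, and the import path is an assumption:

```python
import os
from dataclasses import dataclass, fields
from typing import List

import numpy as np

from paddlenlp.data.indexed_dataset import (
    SFTMMapIndexedDatasetBuilder,
    sft_data_file_path,
    sft_index_file_path,
)


@dataclass
class Example:  # hypothetical record layout; both fields are token sequences of equal length
    input_ids: List[int]
    labels: List[int]


output_dir = "/path/to/sft_dataset"  # hypothetical output directory
os.makedirs(output_dir, exist_ok=True)

# One .bin file per dataclass field, keyed by field name.
field_names = [field.name for field in fields(Example)]
output_files = dict(zip(field_names, sft_data_file_path(output_dir, Example)))
builder = SFTMMapIndexedDatasetBuilder(output_files, dtype=np.int32)

# Toy data: two documents, each a list of tokenized records.
documents = [
    [Example(input_ids=[1, 2, 3], labels=[1, 2, 3])],
    [Example(input_ids=[4, 5], labels=[4, 5])],
]
for document in documents:
    for record in document:
        builder.add_item(record)
    builder.end_document()

builder.finalize(sft_index_file_path(output_dir))
```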


class MMapIndexedDatasetBuilder(object):
    def __init__(self, out_file, dtype, loss_mask_file=None):
        self._data_file = open(out_file, "wb")
