[Megatron dataset] Support loading megatron dataset (#6489)
* Adapt to Megatron

* fix_print_dataset

* fix BlendableDataset

* fix BlendableDataset

* fix skip_warmup

* fix

* fix

* fix

* fix

* fix

* fix

* cache fix

* make new dataset

* fix loss mask

* fix model_zoo/gpt

* fix model_zoo/gpt

* fix model_zoo/gpt

* fix gpt test

* fix legacy

* fix legacy

* hf_model

* remove legacy

* merge develop gpt

* fix model_zoo/gpt for megatron

* merge develop

* resolve conflict

* fix check_rank_flag for data_cache_path

* fix check_rank_flag for data_cache_path

* remove hcg

* fix model_zoo/gpt eval
KB-Ding authored Aug 25, 2023
1 parent d012c87 commit 1a69081
Showing 15 changed files with 2,728 additions and 1,060 deletions.
444 changes: 0 additions & 444 deletions llm/gpt-3/dataset.py

This file was deleted.

105 changes: 44 additions & 61 deletions llm/gpt-3/run_pretrain.py
@@ -49,7 +49,7 @@
),
}

from dataset import GPTDataset, get_train_valid_test_split_
from paddlenlp.data.causal_dataset import build_train_valid_test_datasets, print_rank_0


def add_start_docstrings(*docstr):
@@ -86,7 +86,6 @@ class DataArguments:
input_dir: str = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
cache_prefix: str = field(default=None, metadata={"help": "The prefix of the cached dataset."})
split: str = field(default="949,50,1", metadata={"help": "Train/valid/test data split."})

max_seq_length: int = field(
@@ -101,6 +100,13 @@ class DataArguments:
metadata={"help": "Use share folder for data dir and output dir on multi machine."},
)

data_impl: str = field(default="mmap", metadata={"help": "The format of the preprocessed data."})
skip_warmup: bool = field(
default=True,
metadata={"help": "Whether to skip the warmup process of mmap files."},
)
data_cache: str = field(default=None, metadata={"help": "The path of the cached dataset."})


@dataclass
class ModelArguments:
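The three fields added to DataArguments above (data_impl, skip_warmup, data_cache) describe Megatron-style preprocessed data, where each dataset prefix pairs a token binary with an index file. A minimal sketch of that pairing, with a made-up prefix:

```python
# Sketch only: shows the <prefix>.bin / <prefix>.idx pairing that the "mmap"
# data_impl expects on disk. The prefix below is a placeholder, not from this PR.
import os

def has_mmap_pair(prefix: str) -> bool:
    # Token ids live in <prefix>.bin, the document index in <prefix>.idx.
    return os.path.isfile(prefix + ".bin") and os.path.isfile(prefix + ".idx")

print(has_mmap_pair("./data/my_corpus_text_document"))
```

When data_cache is set, main() below creates the directory and forwards it as data_cache_path, so cached index files can be written somewhere other than the input directory.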
@@ -143,7 +149,7 @@ def create_pretrained_dataset(
tokenizer,
):

train_valid_test_num_samples = [
train_val_test_num_samples = [
training_args.per_device_train_batch_size
* training_args.dataset_world_size
* training_args.max_steps
@@ -155,72 +161,50 @@
training_args.per_device_eval_batch_size * training_args.dataset_world_size * training_args.test_iters,
]

input_prefix = data_file[0]

for suffix in ["_ids.npy", "_idx.npz"]:
if not os.path.isfile(input_prefix + suffix):
raise ValueError("File Not found, %s" % (input_prefix + suffix))

sample_ids = np.load(input_prefix + "_ids.npy", mmap_mode="r", allow_pickle=True)
# All document ids, extended as a 1-D array.

process_data = np.load(input_prefix + "_idx.npz")
# The len(sample_lens) num of docs
# The sum(sample_lens) should equal len(sample_ids)
sample_lens = process_data["lens"]

splits = get_train_valid_test_split_(data_args.split, len(sample_lens))
assert len(sample_lens) >= splits[-1], "The document nums should larger than max of splits, but %s < %s" % (
len(sample_lens),
splits[-1],
print_rank_0(" > datasets target sizes (minimum size):")
print_rank_0(" train: {}".format(train_val_test_num_samples[0]))
print_rank_0(" validation: {}".format(train_val_test_num_samples[1]))
print_rank_0(" test: {}".format(train_val_test_num_samples[2]))

# Build the datasets.
train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets(
data_prefix=data_file,
data_impl=data_args.data_impl,
splits_string=data_args.split,
train_val_test_num_samples=train_val_test_num_samples,
seq_length=data_args.max_seq_length,
seed=training_args.seed,
skip_warmup=data_args.skip_warmup,
data_cache_path=data_args.data_cache,
)

def print_dataset(data, mode="train"):
logger.info(f"Sample data for {mode} mode")
input_ids, loss_mask, attention_mask, position_ids, labels = data
# input_ids, loss_mask, attention_mask, position_ids, labels = data
input_ids = data["text"]

logger.info(tokenizer._decode(input_ids))
# logger.info(tokenizer._decode(labels))
# logger.info(tokenizer.convert_ids_to_tokens(input_ids))

def build_dataset(index, name):
dataset = GPTDataset(
file_prefix=os.path.join(data_args.cache_prefix, os.path.basename(input_prefix)),
build_data_file=training_args.local_process_index == 0,
micro_batch_size=training_args.per_device_train_batch_size
if name == "train"
else training_args.per_device_eval_batch_size,
name="gpt_" + name,
max_seq_len=data_args.max_seq_length,
num_samples=train_valid_test_num_samples[index],
documents=np.arange(splits[index], splits[index + 1]),
sample_ids=sample_ids,
sample_lens=sample_lens,
eos_id=tokenizer.eos_token_id,
seed=training_args.seed,
)
print_dataset(dataset[0], name)
return dataset

from paddlenlp.data import Stack

def _collate_data(data, stack_fn=Stack()):
num_fields = len(data[0])
out = [None] * num_fields
# 0:input_ids, 1:loss_mask, 2:attention_mask, 3:position_ids, 4:labels
for i in (0, 1, 2, 3, 4):
out[i] = stack_fn([x[i] for x in data])
tokens_ = stack_fn(x["text"] for x in data)

labels = tokens_[:, 1:]
tokens = tokens_[:, :-1]

# Attention mask.
attention_mask = paddle.ones(tokens.shape, dtype=paddle.int64)

return {
"input_ids": out[0],
"attention_mask": out[2],
"labels": out[4],
"input_ids": tokens,
"attention_mask": attention_mask,
"labels": labels,
}

# Note, data should be broadcast to all devices.
# for train, valid, test, the distinct data num is data_world_size
train_dataset = build_dataset(0, "train")
valid_dataset = build_dataset(1, "valid")
test_dataset = build_dataset(2, "test")
print_dataset(train_dataset[0])
print_dataset(valid_dataset[0])
print_dataset(test_dataset[0])

return train_dataset, valid_dataset, test_dataset, _collate_data
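
With the Megatron-style datasets, each sample is a dict whose "text" entry is a 1-D array of token ids, so the new collate function stacks that field and derives inputs and labels by a one-position shift. A toy NumPy sketch of the same idea (illustrative values, not the repo's Stack-based collator):

```python
# Shift-by-one collation: tokens[:, :-1] become input_ids, tokens[:, 1:] become labels.
import numpy as np

batch = [{"text": np.array([10, 11, 12, 13, 14])},
         {"text": np.array([20, 21, 22, 23, 24])}]

tokens_ = np.stack([sample["text"] for sample in batch])   # shape (2, 5)
input_ids = tokens_[:, :-1]                                # (2, 4)
labels = tokens_[:, 1:]                                    # (2, 4)
attention_mask = np.ones_like(input_ids)                   # the gpt-3 script uses paddle.ones here

print(input_ids)
print(labels)
```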

@@ -233,9 +217,10 @@ def get_train_data_file(args):
files = [
os.path.join(args.input_dir, f)
for f in os.listdir(args.input_dir)
if (os.path.isfile(os.path.join(args.input_dir, f)) and "_idx.npz" in str(f))
if (os.path.isfile(os.path.join(args.input_dir, f)) and ("_idx.npz" in str(f) or ".idx" in str(f)))
]
files = [x.replace("_idx.npz", "") for x in files]
files = [x.replace(".idx", "") for x in files]

if len(files) > 1:
ret = []
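
The filter above now recognizes both the legacy _idx.npz index files and Megatron-style .idx files and strips the suffix back to a data prefix. A standalone sketch of that discovery step (the helper name and directory are hypothetical):

```python
# Sketch only: mirrors the suffix handling in get_train_data_file above.
# find_data_prefixes is a made-up helper, not part of the PR.
import os

def find_data_prefixes(input_dir: str):
    prefixes = []
    for name in os.listdir(input_dir):
        path = os.path.join(input_dir, name)
        if not os.path.isfile(path):
            continue
        if name.endswith("_idx.npz"):      # legacy preprocessed format
            prefixes.append(path[: -len("_idx.npz")])
        elif name.endswith(".idx"):        # Megatron mmap index file
            prefixes.append(path[: -len(".idx")])
    return prefixes

print(find_data_prefixes("./data"))        # e.g. ['./data/my_corpus_text_document']
```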
@@ -333,10 +318,8 @@ def main():
if model_args.tokenizer_name_or_path is None:
model_args.tokenizer_name_or_path = model_args.model_name_or_path

if data_args.cache_prefix is None:
data_args.cache_prefix = data_args.input_dir
else:
os.makedirs(data_args.cache_prefix, exist_ok=True)
if data_args.data_cache is not None:
os.makedirs(data_args.data_cache, exist_ok=True)

set_seed(training_args)
paddle.set_device(training_args.device)
1 change: 0 additions & 1 deletion llm/llama/dataset.py

This file was deleted.

104 changes: 41 additions & 63 deletions llm/llama/run_pretrain.py
@@ -50,10 +50,11 @@
),
}

from dataset import GPTDataset, get_train_valid_test_split_
from fused_layers import mock_layers
from modeling_pp import LlamaForCausalLMPipe

from paddlenlp.data.causal_dataset import build_train_valid_test_datasets, print_rank_0


def add_start_docstrings(*docstr):
def docstring_decorator(fn):
@@ -95,7 +96,6 @@ class DataArguments:
input_dir: str = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
cache_prefix: str = field(default=None, metadata={"help": "The prefix of the cached dataset."})
split: str = field(default="949,50,1", metadata={"help": "Train/valid/test data split."})

max_seq_length: int = field(
@@ -111,6 +111,13 @@ class DataArguments:
)
train_data_size: int = field(default=-1, metadata={"help": "Number of dataset for training"})

data_impl: str = field(default="mmap", metadata={"help": "The format of the preprocessed data."})
skip_warmup: bool = field(
default=True,
metadata={"help": "Whether to skip the warmup process of mmap files."},
)
data_cache: str = field(default=None, metadata={"help": "The path of the cached dataset."})


@dataclass
class ModelArguments:
@@ -200,7 +207,7 @@ def create_pretrained_dataset(
tokenizer,
):

train_valid_test_num_samples = [
train_val_test_num_samples = [
training_args.per_device_train_batch_size
* training_args.dataset_world_size
* training_args.max_steps
@@ -212,74 +219,46 @@
training_args.per_device_eval_batch_size * training_args.dataset_world_size * training_args.test_iters,
]

input_prefix = data_file[0]

for suffix in ["_ids.npy", "_idx.npz"]:
if not os.path.isfile(input_prefix + suffix):
raise ValueError("File Not found, %s" % (input_prefix + suffix))

sample_ids = np.load(input_prefix + "_ids.npy", mmap_mode="r", allow_pickle=True)
# All document ids, extended as a 1-D array.

process_data = np.load(input_prefix + "_idx.npz")
# The len(sample_lens) num of docs
# The sum(sample_lens) should equal len(sample_ids)
sample_lens = process_data["lens"]

splits = get_train_valid_test_split_(data_args.split, len(sample_lens))
assert len(sample_lens) >= splits[-1], "The document nums should larger than max of splits, but %s < %s" % (
len(sample_lens),
splits[-1],
print_rank_0(" > datasets target sizes (minimum size):")
print_rank_0(" train: {}".format(train_val_test_num_samples[0]))
print_rank_0(" validation: {}".format(train_val_test_num_samples[1]))
print_rank_0(" test: {}".format(train_val_test_num_samples[2]))

# Build the datasets.
train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets(
data_prefix=data_file,
data_impl=data_args.data_impl,
splits_string=data_args.split,
train_val_test_num_samples=train_val_test_num_samples,
seq_length=data_args.max_seq_length,
seed=training_args.seed,
skip_warmup=data_args.skip_warmup,
data_cache_path=data_args.data_cache,
)

def print_dataset(data, mode="train"):
logger.info(f"Sample data for {mode} mode")
input_ids, loss_mask, attention_mask, position_ids, labels = data
# input_ids, loss_mask, attention_mask, position_ids, labels = data
input_ids = data["text"]

logger.info(tokenizer._decode(input_ids))
# logger.info(tokenizer._decode(labels))
# logger.info(tokenizer.convert_ids_to_tokens(input_ids))

def build_dataset(index, name):
dataset = GPTDataset(
file_prefix=os.path.join(data_args.cache_prefix, os.path.basename(input_prefix)),
build_data_file=training_args.local_process_index == 0,
micro_batch_size=training_args.per_device_train_batch_size
if name == "train"
else training_args.per_device_eval_batch_size,
name="gpt_" + name,
max_seq_len=data_args.max_seq_length,
num_samples=train_valid_test_num_samples[index],
documents=np.arange(splits[index], splits[index + 1]),
sample_ids=sample_ids,
sample_lens=sample_lens,
eos_id=tokenizer.eos_token_id,
seed=training_args.seed,
)
print_dataset(dataset[0], name)
return dataset

from paddlenlp.data import Stack

def _collate_data(data, stack_fn=Stack()):
num_fields = len(data[0])
out = [None] * num_fields
# 0:input_ids, 1:loss_mask, 2:attention_mask, 3:position_ids, 4:labels
for i in (0, 1, 2, 3, 4):
out[i] = stack_fn([x[i] for x in data])
tokens_ = stack_fn(x["text"] for x in data)

labels = tokens_[:, 1:]
tokens = tokens_[:, :-1]

return {
"input_ids": out[0],
# "token_type_ids": out[1],
# "attention_mask": out[2],
# "loss_mask": out[3],
"labels": out[4],
"input_ids": tokens,
"labels": labels,
}

# Note, data should be broadcast to all devices.
# for train, valid, test, the distinct data num is data_world_size
train_dataset = build_dataset(0, "train")
valid_dataset = build_dataset(1, "valid")
test_dataset = build_dataset(2, "test")
print_dataset(train_dataset[0], "train")
print_dataset(valid_dataset[0], "valid")
print_dataset(test_dataset[0], "test")

return train_dataset, valid_dataset, test_dataset, _collate_data
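
The split string (default "949,50,1") is forwarded as splits_string, which Megatron-style loaders normalize into train/validation/test fractions over the document pool. The sketch below only illustrates that convention; it is not the implementation inside paddlenlp.data.causal_dataset, and boundary rounding may differ:

```python
# Illustration of how a "949,50,1" splits string maps to document index ranges.
def split_documents(splits_string: str, num_docs: int):
    weights = [float(w) for w in splits_string.split(",")]
    total = sum(weights)
    bounds = [0]
    for w in weights:
        bounds.append(bounds[-1] + int(round(w / total * num_docs)))
    bounds[-1] = num_docs                  # absorb rounding drift into the last split
    return list(zip(bounds[:-1], bounds[1:]))

print(split_documents("949,50,1", 10000))
# [(0, 9490), (9490, 9990), (9990, 10000)] -> train / validation / test
```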

@@ -292,9 +271,10 @@ def get_train_data_file(args):
files = [
os.path.join(args.input_dir, f)
for f in os.listdir(args.input_dir)
if (os.path.isfile(os.path.join(args.input_dir, f)) and "_idx.npz" in str(f))
if (os.path.isfile(os.path.join(args.input_dir, f)) and ("_idx.npz" in str(f) or ".idx" in str(f)))
]
files = [x.replace("_idx.npz", "") for x in files]
files = [x.replace(".idx", "") for x in files] # add

if len(files) > 1:
ret = []
@@ -396,10 +376,8 @@ def main():
if model_args.tokenizer_name_or_path is None:
model_args.tokenizer_name_or_path = model_args.model_name_or_path

if data_args.cache_prefix is None:
data_args.cache_prefix = data_args.input_dir
else:
os.makedirs(data_args.cache_prefix, exist_ok=True)
if data_args.data_cache is not None:
os.makedirs(data_args.data_cache, exist_ok=True)

set_seed(training_args)
paddle.set_device(training_args.device)
4 changes: 2 additions & 2 deletions llm/llama/run_trainer.sh
@@ -14,7 +14,6 @@

set -x
unset CUDA_VISIBLE_DEVICES

task_name="llama_hybid"
rm -rf output/$task_name/
rm -rf "output/$task_name""_log"
@@ -56,4 +55,5 @@ python -u -m paddle.distributed.launch \
--recompute 1 \
--do_train \
--do_eval \
--device "gpu"
--device "gpu" \
--data_impl "mmap"
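
The launcher now also passes --data_impl "mmap". To sanity-check a preprocessed directory outside a full launch, the builder can be called directly with the same keyword arguments the training scripts pass; the prefix, sample counts, and sequence length below are placeholders:

```python
# Smoke-test sketch using the call shape shown in the run_pretrain.py diffs above.
# "./data/my_corpus_text_document" is a placeholder prefix; adjust to your files.
from paddlenlp.data.causal_dataset import build_train_valid_test_datasets

train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    data_prefix=["./data/my_corpus_text_document"],
    data_impl="mmap",
    splits_string="949,50,1",
    train_val_test_num_samples=[1000, 100, 10],
    seq_length=1024,
    seed=42,
    skip_warmup=True,
    data_cache_path=None,
)
# Each sample is a dict with a "text" array of token ids, as the collate functions above expect.
print(len(train_ds), train_ds[0]["text"][:8])
```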