628 changes: 183 additions & 445 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/pretrain_gpt.sh
@@ -9,7 +9,7 @@ DATA_PATH=<Specify path and file prefix>_text_document
CHECKPOINT_PATH=<Specify path>


python pretrain_gpt.py \
deepspeed --num_gpus 1 pretrain_gpt.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
62 changes: 62 additions & 0 deletions examples/pretrain_gpt_single_node.sh
@@ -0,0 +1,62 @@
#!/bin/bash

# Adapted to use deepspeed on a single node
#
# Multi-node will require either a `hostfile` or switching to `torch.distributed.launch`

# adjust to the number of GPUs to use
N_GPUS=1

CHECKPOINT_PATH=checkpoints/gpt2
VOCAB_FILE=data/gpt2-vocab.json
MERGE_FILE=data/gpt2-merges.txt
DATA_PATH=data/meg-gpt2_text_document

GPT_ARGS=" \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 1024 \
--max-position-embeddings 1024 \
--micro-batch-size 4 \
--global-batch-size 8 \
--lr-decay-iters 320000 \
--lr-decay-style cosine \
--lr 0.00015 \
--min-lr 1.0e-5 \
--train-iters 5000 \
--vocab-file $VOCAB_FILE \
--merge-file $MERGE_FILE \
--data-impl mmap \
--split 949,50,1 \
--distributed-backend nccl \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--lr-warmup-fraction .01 \
--fp16 \
"

OUTPUT_ARGS=" \
--log-interval 10 \
--save-interval 500 \
--eval-interval 100 \
--eval-iters 10 \
--checkpoint-activations \
"

DATA_ARGS=" \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
"

ALL_ARGS="$GPT_ARGS $OUTPUT_ARGS $DATA_ARGS"

LAUNCHER="deepspeed --num_gpus $N_GPUS"

CMD="$LAUNCHER pretrain_gpt.py $ALL_ARGS"

echo $CMD

$CMD
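The comment at the top of this script notes that multi-node runs need a `hostfile` (or a switch to `torch.distributed.launch`). A minimal sketch of the hostfile route, assuming passwordless SSH between the nodes; the hostnames and the 8-slots-per-node counts are placeholders:

```bash
# Hypothetical hostfile: one line per node with its available GPU slots
cat > hostfile <<EOF
node1 slots=8
node2 slots=8
EOF

# Reuse the arguments assembled above; the deepspeed launcher fans out over the hostfile
LAUNCHER="deepspeed --hostfile=hostfile"
$LAUNCHER pretrain_gpt.py $ALL_ARGS
```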
45 changes: 45 additions & 0 deletions examples/pretrain_gpt_tiny.sh
@@ -0,0 +1,45 @@
#! /bin/bash

# Runs a tiny GPT model for quick testing (scaled down from the "345M" parameter example)

RANK=0
WORLD_SIZE=1

DATA_PATH=GPT2/c4_en_partial_gpt2_text_document
CHECKPOINT_PATH=GPT2


deepspeed --num_gpus 1 pretrain_gpt.py \
--num-layers 2 \
--hidden-size 128 \
--num-attention-heads 4 \
--micro-batch-size 4 \
--global-batch-size 8 \
--seq-length 256 \
--max-position-embeddings 256 \
--train-iters 10000 \
--lr-decay-iters 5000 \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path t5-small \
--data-impl mmap \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.00015 \
--min-lr 1.0e-5 \
--lr-decay-style cosine \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--lr-warmup-fraction .01 \
--checkpoint-activations \
--log-interval 100 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--fp16 \
--tensorboard-dir GPT2

# --vocab-file GPT2/gpt2-vocab.json \
# --merge-file GPT2/gpt2-merges.txt \
5 changes: 4 additions & 1 deletion megatron/arguments.py
@@ -665,8 +665,11 @@ def _add_data_args(parser):
default=None,
choices=['BertWordPieceLowerCase',
'BertWordPieceCase',
'GPT2BPETokenizer'],
'GPT2BPETokenizer',
'PretrainedFromHF'],
help='What type of tokenizer to use.')
group.add_argument("--tokenizer-name-or-path", type=str, default=None,
help="Name or path of the huggingface tokenizer.")
group.add_argument('--data-impl', type=str, default='infer',
choices=['lazy', 'cached', 'mmap', 'infer'],
help='Implementation of indexed datasets.')
2 changes: 1 addition & 1 deletion megatron/global_vars.py
@@ -83,7 +83,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
_build_num_microbatches_calculator(args)
if args.vocab_file:
if args.vocab_file or args.tokenizer_name_or_path:
_ = _build_tokenizer(args)
_set_tensorboard_writer(args)
_set_adlr_autoresume(args)
44 changes: 42 additions & 2 deletions megatron/tokenizer/tokenizer.py
@@ -20,7 +20,7 @@

from .bert_tokenization import FullTokenizer as FullBertTokenizer
from .gpt2_tokenization import GPT2Tokenizer

from transformers import AutoTokenizer

def build_tokenizer(args):
"""Initialize tokenizer."""
@@ -29,7 +29,7 @@ def build_tokenizer(args):
flush=True)

# Select and instantiate the tokenizer.
assert args.vocab_file is not None
assert args.vocab_file is not None or args.tokenizer_type == "PretrainedFromHF"
if args.tokenizer_type == 'BertWordPieceLowerCase':
tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
lower_case=True,
@@ -41,6 +41,13 @@ def build_tokenizer(args):
elif args.tokenizer_type == 'GPT2BPETokenizer':
assert args.merge_file is not None
tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
elif args.tokenizer_type == "PretrainedFromHF":
assert args.tokenizer_name_or_path is not None
print(
" vocab file is unused, loading tokenizer from pre-trained model",
flush=True,
)
tokenizer = _AutoTokenizer(args.tokenizer_name_or_path)
else:
raise NotImplementedError('{} tokenizer is not '
'implemented.'.format(args.tokenizer_type))
@@ -289,3 +296,36 @@ def detokenize(self, token_ids):
@property
def eod(self):
return self.eod_id


class _AutoTokenizer(AbstractTokenizer):
"""AutoTokenizer for loading a pretrained HF tokenizer."""

def __init__(self, tokenizer_name_or_path):
name = tokenizer_name_or_path
super().__init__(name)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
self.encoder = self.tokenizer.get_vocab()
self.decoder = {v: k for k, v in self.encoder.items()}

@property
def vocab_size(self):
return self.tokenizer.vocab_size

@property
def vocab(self):
return self.encoder

@property
def inv_vocab(self):
return self.decoder

def tokenize(self, text):
return self.tokenizer.encode(text)

def detokenize(self, token_ids):
return self.tokenizer.decode(token_ids)

@property
def eod(self):
return self.tokenizer.eos_token_id
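For reference, a minimal sketch of what the new `PretrainedFromHF` path boils down to at the `transformers` level (assumes `transformers` is installed; `t5-small` is simply the checkpoint used as the example elsewhere in this PR):

```python
from transformers import AutoTokenizer

# Same construction as _AutoTokenizer.__init__
tok = AutoTokenizer.from_pretrained("t5-small")

ids = tok.encode("Hello world")   # what tokenize() returns: a list of token ids
text = tok.decode(ids)            # what detokenize() returns: the decoded string
eod = tok.eos_token_id            # exposed as the end-of-document (eod) id
print(ids, text, eod)
```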
5 changes: 4 additions & 1 deletion megatron/training.py
@@ -838,7 +838,10 @@ def build_train_valid_test_data_iterators(
assert args.train_samples is None, \
'only backward compatibility support for iteration-based training'
args.consumed_train_samples = args.iteration * args.global_batch_size
if args.iteration > 0 and args.consumed_valid_samples == 0:
# it's possible that train was run, but not eval and it's valid if
# args.consumed_valid_samples == 0
# TODO: eval_interval could have changed between runs, so this might still be wrong
if args.iteration // args.eval_interval > 0:
assert args.train_samples is None, \
'only backward compatibility support for iteration-based training'
args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
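The comment added in this hunk flags that the recomputed value can still be off if `--eval-interval` changed between runs. A small worked sketch of the recomputation; the trailing factor (`eval_iters * global_batch_size`) is an assumption, since the diff is truncated at the continuation line:

```python
# Hypothetical numbers, only to illustrate the recomputation above
iteration, eval_interval = 1000, 100   # 10 evaluation passes have happened so far
eval_iters, global_batch_size = 10, 8  # assumed continuation of the truncated line

consumed_valid_samples = (iteration // eval_interval) * eval_iters * global_batch_size
print(consumed_valid_samples)  # 10 * 10 * 8 = 800
```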
22 changes: 22 additions & 0 deletions tools/README.md
@@ -0,0 +1,22 @@
## A few notes on how we created the datasets:

### Creating the JSON Lines text file

First you need to create a JSON Lines (jsonl) file containing your dataset. For this we exported from the HF `datasets` format. For example, for C4:

```
from datasets import load_dataset
c4 = load_dataset("c4", "en")
c4["train"].to_json("c4_en_train.jsonl")
c4["validation"].to_json("c4_en_valid.jsonl")
```

This creates quite a large file compared to the size of the HF dataset on disk (810GB vs. 305GB for C4, for example).

### Megatron pre-processing

Then you need to pass that text file to the `preprocess_data.py` script for tokenization and memory-mapping. This creates two files: one storing the token indices and one storing the document starts and ends. The result is slightly bigger than the text dataset (360GB vs. 305GB for C4, for example). You can choose one of the default Megatron tokenizers (but then you have to pass the merges and vocab files) or one from HF tokenizers. For example, for our GPT-like models we reuse a T5 SentencePiece-BPE tokenizer:

`python tools/preprocess_data.py --input ~/c4_en_train.jsonl --output-prefix c4_en_train --dataset-impl mmap --tokenizer-type PretrainedFromHF --tokenizer-name-or-path t5-small --workers 30 --append-eod`

Do note that adding too many workers can be counterproductive for very large datasets: as disk writing becomes the bottleneck, the intermediate per-process results pile up and can flood the RAM. In our experiments on GCP machines, running with 60 workers on C4 inevitably led the program to fail.
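The validation split exported above can be processed the same way; the 16-worker count here is just an illustrative, more conservative choice given the note above:

`python tools/preprocess_data.py --input ~/c4_en_valid.jsonl --output-prefix c4_en_valid --dataset-impl mmap --tokenizer-type PretrainedFromHF --tokenizer-name-or-path t5-small --workers 16 --append-eod`

With the default `text` json key, each run should produce an `<output-prefix>_text_document.bin`/`.idx` pair, which matches the `_text_document` prefix the example scripts pass to `--data-path`.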
5 changes: 3 additions & 2 deletions tools/preprocess_data.py
@@ -105,15 +105,16 @@ def get_args():
group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase','BertWordPieceCase',
'GPT2BPETokenizer'],
'GPT2BPETokenizer', 'PretrainedFromHF'],
help='What type of tokenizer to use.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file (if necessary).')
group.add_argument('--append-eod', action='store_true',
help='Append an <eod> token to the end of a document.')

group.add_argument("--tokenizer-name-or-path", type=str, default=None,
help="Name or path of the huggingface tokenizer.")

group = parser.add_argument_group(title='output data')
group.add_argument('--output-prefix', type=str, required=True,