Commit
Merge pull request #1 from Vahe1994/pajama_data_generation
Added pajama data generation
Vahe1994 authored Jan 12, 2024
2 parents a4283b0 + a3cce4a commit 0671bc7
Showing 3 changed files with 31 additions and 11 deletions.
Binary file removed data/refined_web_n=128.pth
Binary file not shown.
2 changes: 1 addition & 1 deletion main.py
@@ -497,7 +497,7 @@ def update_outs_parallel(
     parser.add_argument(
         "dataset",
         type=str,
-        help="Dataset name [c4, pajama, refinedweb] or path to data where to extract calibration data from.",
+        help="Dataset name [c4, pajama] or path to data where to extract calibration data from.",
     )
     parser.add_argument(
         "--new_eval",
40 changes: 30 additions & 10 deletions src/datautils.py
@@ -7,6 +7,7 @@
 import torch
 from datasets import load_dataset
 from packaging import version
+from tqdm import trange
 from transformers import AutoTokenizer, LlamaTokenizer


@@ -16,6 +17,27 @@ def set_seed(seed: Optional[int]):
         torch.random.manual_seed(seed)
 
 
+def get_red_pajama(nsamples, seqlen, tokenizer, eval_mode=False):
+    print("Loading red_pajama from togethercomputer/RedPajama-Data-1T-Sample")
+    assert not eval_mode, "Only train set is supported in RedPajama"
+    traindata = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", split="train")
+    tokenizer.bos_token_id = 1
+    tokenizer.eos_token_id = 2
+    trainloader = []
+    for _ in trange(nsamples, desc="Making red_pajama calibration set", leave=False):
+        while True:
+            i = random.randint(0, len(traindata) - 1)
+            trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
+            if trainenc.input_ids.shape[1] > seqlen:
+                break
+        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+        j = i + seqlen
+        inp = trainenc.input_ids[:, i:j]
+        assert inp.shape[1] == seqlen
+        trainloader.append(inp)
+    return trainloader
+
+
 def get_wikitext2(nsamples, seqlen, tokenizer, eval_mode=False):
     if not eval_mode:
         traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
@@ -165,9 +187,8 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, eval_mode=False, model_
     Loads and prepares data for a Transformers model.
     Args:
         name (str): The name of the dataset to load.
-            This can be one of 'wikitext2', 'c4', 'ptb' for datasets loaded from Huggingface datasets,
-            'pajama' or 'refinedweb' for pre-tokenized datasets in folder `data` or 'none' for cases
-            where a dataset is not needed, like RTN. It can also accept data path to custom file.
+            This can be one of 'wikitext2', 'c4', 'ptb','pajama' for datasets loaded from Huggingface datasets,
+            or 'none' for cases where a dataset is not needed, like RTN. It can also accept data path to custom file.
         nsamples (int, optional): The number of samples to load from the dataset. Defaults to 128.
         seed (int, optional): The random seed value for data shuffling and splitting. Defaults to 0.
         seqlen (int, optional): The maximum sequence length for input tokenization. Defaults to 2048.
@@ -186,11 +207,8 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, eval_mode=False, model_
     set_seed(seed)
 
     # for pre-tokenized datasets
-    if name.lower() == "pajama":
-        data = torch.load(f"./data/red_pajama_n=1024_{seqlen}_context_length.pth")[:nsamples]
-    elif name.lower() == "refinedweb":
-        data = torch.load("./data/refined_web_n=128.pth")[:nsamples]
-    elif name.lower() == "none":
+
+    if name.lower() == "none":
         print("Not loading any dataset. (OK if you use no compression or methods like RTN.)")
         return None
     elif os.path.isfile(name):
@@ -199,7 +217,7 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, eval_mode=False, model_
         except FileNotFoundError:
             raise FileNotFoundError(
                 f"Failed to load custom data from {name}.",
-                "Check data path or use one of [c4, wikitext2, ptb, pajama, refinedweb, none]",
+                "Check data path or use one of [c4, wikitext2, ptb, pajama, none]",
             )
     else:
         # for datasets requiring tokenization
@@ -220,6 +238,8 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, eval_mode=False, model_

         if name.lower() == "wikitext2":
             data = get_wikitext2(nsamples, seqlen, tokenizer, eval_mode=eval_mode)
+        elif name.lower() == "pajama":
+            data = get_red_pajama(nsamples, seqlen, tokenizer, eval_mode=eval_mode)
         elif name.lower() == "ptb":
             data = get_ptb(nsamples, seqlen, tokenizer, eval_mode=eval_mode)
         elif name.lower() == "ptb_new":
@@ -231,7 +251,7 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, eval_mode=False, model_
         else:
             raise ValueError(
                 f"Failed to load data from {name}.",
-                "Check dataset name or path or use one of [c4, wikitext2, ptb, pajama, refinedweb, none]",
+                "Check dataset name or path or use one of [c4, wikitext2, ptb, pajama, none]",
             )
 
     if hasattr(data, "input_ids"):
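
For context, a minimal usage sketch (not part of this commit) of the new "pajama" calibration path through get_loaders. It assumes the snippet runs from the repository root so that src is importable, and that model_path is used only to resolve the tokenizer, as the visible signature suggests; the checkpoint name below is a placeholder.

# Minimal sketch, not part of this commit: build a RedPajama calibration set
# via the new "pajama" branch of get_loaders.
from src.datautils import get_loaders

calibration_data = get_loaders(
    "pajama",          # routed to get_red_pajama, which downloads togethercomputer/RedPajama-Data-1T-Sample
    nsamples=128,      # number of calibration sequences to draw
    seed=0,
    seqlen=2048,       # each returned sequence is cropped to exactly this length
    eval_mode=False,   # the RedPajama loader supports only the train split
    model_path="meta-llama/Llama-2-7b-hf",  # placeholder checkpoint; used to load the tokenizer
)

# Each element should be a token tensor of shape [1, seqlen].
print(len(calibration_data), calibration_data[0].shape)

Compared to the previous behaviour, calibration samples are now tokenized on the fly from the Hugging Face dataset instead of being read from the pre-tokenized data/*.pth files removed by this commit.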
