Added pajama data generation #1

Merged: 7 commits, Jan 12, 2024

Changes from 6 commits
Binary file removed data/refined_web_n=128.pth
main.py: 2 changes (1 addition, 1 deletion)
@@ -497,7 +497,7 @@ def update_outs_parallel(
parser.add_argument(
"dataset",
type=str,
help="Dataset name [c4, pajama, refinedweb] or path to data where to extract calibration data from.",
help="Dataset name [c4, pajama] or path to data where to extract calibration data from.",
)
parser.add_argument(
"--new_eval",
src/datautils.py: 39 changes (29 additions, 10 deletions)
@@ -7,6 +7,7 @@
import torch
from datasets import load_dataset
from packaging import version
from tqdm import trange
from transformers import AutoTokenizer, LlamaTokenizer


@@ -16,6 +17,26 @@ def set_seed(seed: Optional[int]):
torch.random.manual_seed(seed)


def get_red_pajama(nsamples, seqlen, tokenizer):
print("Loading red_pajama from togethercomputer/RedPajama-Data-1T-Sample")
traindata = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", split="train")
tokenizer.bos_token_id = 1
tokenizer.eos_token_id = 2
trainloader = []
for _ in trange(nsamples, desc="Making red_pajama calibration set", leave=False):
while True:
i = random.randint(0, len(traindata) - 1)
trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
if trainenc.input_ids.shape[1] > seqlen:
break
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
assert inp.shape[1] == seqlen
trainloader.append(inp)
return trainloader
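For context (not part of the diff): a minimal sketch of how this helper could be exercised on its own. The checkpoint name and the import path are assumptions, not values taken from this PR.

from transformers import AutoTokenizer
from src.datautils import get_red_pajama  # assumes the repository root is on PYTHONPATH

# Hypothetical LLaMA-style checkpoint; any tokenizer matching the target model would do.
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b", use_fast=False)

# 128 calibration sequences, each exactly 2048 tokens, sampled from RedPajama-1T-Sample.
calib_data = get_red_pajama(nsamples=128, seqlen=2048, tokenizer=tokenizer)
print(len(calib_data), calib_data[0].shape)  # 128, torch.Size([1, 2048])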


def get_wikitext2(nsamples, seqlen, tokenizer, eval_mode=False):
if not eval_mode:
traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
@@ -165,9 +186,8 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, eval_mode=False, model_
Loads and prepares data for a Transformers model.
Args:
name (str): The name of the dataset to load.
This can be one of 'wikitext2', 'c4', 'ptb' for datasets loaded from Huggingface datasets,
'pajama' or 'refinedweb' for pre-tokenized datasets in folder `data` or 'none' for cases
where a dataset is not needed, like RTN. It can also accept data path to custom file.
This can be one of 'wikitext2', 'c4', 'ptb', 'pajama' for datasets loaded from Huggingface datasets,
or 'none' for cases where a dataset is not needed, like RTN. It can also accept a path to a custom file.
nsamples (int, optional): The number of samples to load from the dataset. Defaults to 128.
seed (int, optional): The random seed value for data shuffling and splitting. Defaults to 0.
seqlen (int, optional): The maximum sequence length for input tokenization. Defaults to 2048.
@@ -186,11 +206,8 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, eval_mode=False, model_
set_seed(seed)

# for pre-tokenized datasets
if name.lower() == "pajama":
data = torch.load(f"./data/red_pajama_n=1024_{seqlen}_context_length.pth")[:nsamples]
elif name.lower() == "refinedweb":
data = torch.load("./data/refined_web_n=128.pth")[:nsamples]
elif name.lower() == "none":

if name.lower() == "none":
print("Not loading any dataset. (OK if you use no compression or methods like RTN.)")
return None
elif os.path.isfile(name):
@@ -199,7 +216,7 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, eval_mode=False, model_
except FileNotFoundError:
raise FileNotFoundError(
f"Failed to load custom data from {name}.",
"Check data path or use one of [c4, wikitext2, ptb, pajama, refinedweb, none]",
"Check data path or use one of [c4, wikitext2, ptb, pajama, none]",
)
else:
# for datasets requiring tokenization
@@ -220,6 +237,8 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, eval_mode=False, model_

if name.lower() == "wikitext2":
data = get_wikitext2(nsamples, seqlen, tokenizer, eval_mode=eval_mode)
elif name.lower() == "pajama":
data = get_red_pajama(nsamples, seqlen, tokenizer)
Collaborator: please assert eval_mode here

Owner Author: Done, thank you
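The follow-up itself is outside this view ("Changes from 6 commits"), but the requested guard would plausibly look like the sketch below; the assertion message is an assumption, not the committed wording.

elif name.lower() == "pajama":
    assert not eval_mode, "pajama is a calibration-only dataset and has no eval split"  # assumed message
    data = get_red_pajama(nsamples, seqlen, tokenizer)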

elif name.lower() == "ptb":
data = get_ptb(nsamples, seqlen, tokenizer, eval_mode=eval_mode)
elif name.lower() == "ptb_new":
@@ -231,7 +250,7 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, eval_mode=False, model_
else:
raise ValueError(
f"Failed to load data from {name}.",
"Check dataset name or path or use one of [c4, wikitext2, ptb, pajama, refinedweb, none]",
"Check dataset name or path or use one of [c4, wikitext2, ptb, pajama, none]",
)

if hasattr(data, "input_ids"):
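Taken together with the main.py change, this branch lets a calibration set be requested by name through get_loaders. A minimal sketch of that path, assuming the truncated keyword in the signature is model_path and that get_loaders builds the tokenizer from it internally:

from src.datautils import get_loaders

# "pajama" now routes to get_red_pajama(); model_path and the checkpoint id are assumptions.
calib_data = get_loaders("pajama", nsamples=128, seqlen=2048, model_path="huggyllama/llama-7b")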