diff --git a/src/datautils.py b/src/datautils.py index 57fc058e..081b2b8f 100644 --- a/src/datautils.py +++ b/src/datautils.py @@ -17,8 +17,9 @@ def set_seed(seed: Optional[int]): torch.random.manual_seed(seed) -def get_red_pajama(nsamples, seqlen, tokenizer): +def get_red_pajama(nsamples, seqlen, tokenizer, eval_mode=False): print("Loading red_pajama from togethercomputer/RedPajama-Data-1T-Sample") + assert not eval_mode, "Only train set is supported in RedPajama" traindata = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", split="train") tokenizer.bos_token_id = 1 tokenizer.eos_token_id = 2 @@ -238,7 +239,7 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, eval_mode=False, model_ if name.lower() == "wikitext2": data = get_wikitext2(nsamples, seqlen, tokenizer, eval_mode=eval_mode) elif name.lower() == "pajama": - data = get_red_pajama(nsamples, seqlen, tokenizer) + data = get_red_pajama(nsamples, seqlen, tokenizer, eval_mode=eval_mode) elif name.lower() == "ptb": data = get_ptb(nsamples, seqlen, tokenizer, eval_mode=eval_mode) elif name.lower() == "ptb_new":