From a3cce4a7ff3cdaab71d3a15f70b5df786f53d841 Mon Sep 17 00:00:00 2001 From: vahe1994 Date: Fri, 12 Jan 2024 20:03:58 +0400 Subject: [PATCH] Added assert on eval_mode Co-authored-by: blackadder --- src/datautils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/datautils.py b/src/datautils.py index 57fc058e..081b2b8f 100644 --- a/src/datautils.py +++ b/src/datautils.py @@ -17,8 +17,9 @@ def set_seed(seed: Optional[int]): torch.random.manual_seed(seed) -def get_red_pajama(nsamples, seqlen, tokenizer): +def get_red_pajama(nsamples, seqlen, tokenizer, eval_mode=False): print("Loading red_pajama from togethercomputer/RedPajama-Data-1T-Sample") + assert not eval_mode, "Only train set is supported in RedPajama" traindata = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", split="train") tokenizer.bos_token_id = 1 tokenizer.eos_token_id = 2 @@ -238,7 +239,7 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, eval_mode=False, model_ if name.lower() == "wikitext2": data = get_wikitext2(nsamples, seqlen, tokenizer, eval_mode=eval_mode) elif name.lower() == "pajama": - data = get_red_pajama(nsamples, seqlen, tokenizer) + data = get_red_pajama(nsamples, seqlen, tokenizer, eval_mode=eval_mode) elif name.lower() == "ptb": data = get_ptb(nsamples, seqlen, tokenizer, eval_mode=eval_mode) elif name.lower() == "ptb_new":