Skip to content

Commit

Permalink
Added assert on eval_mode
Browse files (browse the repository at this point in the history)
Co-authored-by: blackadder <scope.denis@mail.ru>
  • Loading branch information
Vahe1994 and Godofnothing committed Jan 12, 2024
1 parent b60ba23 commit a3cce4a
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/datautils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ def set_seed(seed: Optional[int]):
torch.random.manual_seed(seed)


-def get_red_pajama(nsamples, seqlen, tokenizer):
+def get_red_pajama(nsamples, seqlen, tokenizer, eval_mode=False):
print("Loading red_pajama from togethercomputer/RedPajama-Data-1T-Sample")
+assert not eval_mode, "Only train set is supported in RedPajama"
traindata = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", split="train")
tokenizer.bos_token_id = 1
tokenizer.eos_token_id = 2
Expand Down Expand Up @@ -238,7 +239,7 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, eval_mode=False, model_
if name.lower() == "wikitext2":
data = get_wikitext2(nsamples, seqlen, tokenizer, eval_mode=eval_mode)
elif name.lower() == "pajama":
-data = get_red_pajama(nsamples, seqlen, tokenizer)
+data = get_red_pajama(nsamples, seqlen, tokenizer, eval_mode=eval_mode)
elif name.lower() == "ptb":
data = get_ptb(nsamples, seqlen, tokenizer, eval_mode=eval_mode)
elif name.lower() == "ptb_new":
Expand Down

0 comments on commit a3cce4a

Please sign in to comment.