Skip to content

Commit 8b5cab7

Browse files
authored
Create test_model.py
1 parent 39486c4 commit 8b5cab7

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

tests/test_model.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import torch
2+
from NLPModel import create_model, ModelConfig
3+
4+
def test_model_initialization():
5+
vocab_size = 1000
6+
model, trainer = create_model(vocab_size)
7+
assert model is not None
8+
assert trainer is not None
9+
10+
def test_forward_pass():
11+
vocab_size = 1000
12+
model, _ = create_model(vocab_size)
13+
batch_size = 2
14+
seq_len = 10
15+
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
16+
output = model(input_ids)
17+
assert output.shape == (batch_size, seq_len, vocab_size)

0 commit comments

Comments
 (0)