Skip to content

Commit c218531

Browse files
committed
Add tests for tokenization
1 parent 6547bf1 commit c218531

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

tests/test_tokenization.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import unittest
2+
3+
import torch
4+
5+
from code2seq.data.path_context_dataset import PathContextDataset
6+
from code2seq.data.vocabulary import Vocabulary
7+
8+
9+
class TestDatasetTokenization(unittest.TestCase):
    """Check the tokenization helpers exposed by ``PathContextDataset``.

    Each test feeds a raw, ``|``-separated string to one of the helpers and
    compares the returned tensor with the ids expected from ``vocab``.
    """

    # Small fixed vocabulary: the four special tokens plus two known
    # sub-tokens. The tests expect every other sub-token ("label", "token")
    # to come back as the UNK id.
    vocab = {Vocabulary.PAD: 0, Vocabulary.UNK: 1, Vocabulary.SOS: 2, Vocabulary.EOS: 3, "my": 4, "super": 5}

    @staticmethod
    def _assert_tensors_equal(actual, expected):
        """Assert that two tensors match exactly.

        ``torch.testing.assert_equal`` was a prototype helper that is not
        available in current PyTorch releases; ``assert_close`` with zero
        tolerances performs the same exact comparison.
        """
        torch.testing.assert_close(actual, expected, rtol=0, atol=0)

    def test_tokenize_label(self):
        """A label is expected to be wrapped in SOS/EOS and padded."""
        raw_label = "my|super|label"
        tokenized = PathContextDataset.tokenize_label(raw_label, self.vocab, 5)
        # <SOS> my super <UNK> <EOS> <PAD>
        correct = torch.tensor([2, 4, 5, 1, 3, 0], dtype=torch.long)

        self._assert_tensors_equal(tokenized, correct)

    def test_tokenize_class(self):
        """A class name is expected to map to a single vocabulary id."""
        raw_class = "super"
        tokenized = PathContextDataset.tokenize_class(raw_class, self.vocab)
        correct = torch.tensor([5], dtype=torch.long)

        self._assert_tensors_equal(tokenized, correct)

    def test_tokenize_token(self):
        """A token is expected to be split into sub-token ids and padded."""
        raw_token = "my|super|token"
        tokenized = PathContextDataset.tokenize_token(raw_token, self.vocab, 5)
        # my super <UNK> <PAD> <PAD>
        correct = torch.tensor([4, 5, 1, 0, 0], dtype=torch.long)

        self._assert_tensors_equal(tokenized, correct)
33+
34+
35+
# Allow running this test module directly, outside a test runner.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)