import unittest

import torch

from code2seq.data.path_context_dataset import PathContextDataset
from code2seq.data.vocabulary import Vocabulary


class TestDatasetTokenization(unittest.TestCase):
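    # Toy vocabulary: ids 0-3 are the special tokens, "my" and "super" are the only known subtokens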
    vocab = {Vocabulary.PAD: 0, Vocabulary.UNK: 1, Vocabulary.SOS: 2, Vocabulary.EOS: 3, "my": 4, "super": 5}

    def test_tokenize_label(self):
        raw_label = "my|super|label"
        tokenized = PathContextDataset.tokenize_label(raw_label, self.vocab, 5)
        # <SOS> my super <UNK> <EOS> <PAD>
        correct = torch.tensor([2, 4, 5, 1, 3, 0], dtype=torch.long)

        torch.testing.assert_close(tokenized, correct)

    def test_tokenize_class(self):
        raw_class = "super"
        tokenized = PathContextDataset.tokenize_class(raw_class, self.vocab)
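        # "super" is a known class name, so it maps straight to its id (no special tokens, no padding)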
        correct = torch.tensor([5], dtype=torch.long)

        torch.testing.assert_close(tokenized, correct)

    def test_tokenize_token(self):
        raw_token = "my|super|token"
        tokenized = PathContextDataset.tokenize_token(raw_token, self.vocab, 5)
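        # my super <UNK> <PAD> <PAD> (no <SOS>/<EOS> for tokens; out-of-vocabulary parts map to <UNK>)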
        correct = torch.tensor([4, 5, 1, 0, 0], dtype=torch.long)

        torch.testing.assert_close(tokenized, correct)


if __name__ == "__main__":
    unittest.main()