
Commit 84a8d01

use lists for dataset get item
1 parent e678f8b

File tree

4 files changed: +43 -35 lines changed

  code2seq/data/path_context.py
  code2seq/data/path_context_dataset.py
  setup.py
  tests/test_tokenization.py

code2seq/data/path_context.py

Lines changed: 27 additions & 13 deletions

@@ -1,37 +1,47 @@
 from dataclasses import dataclass
-from typing import Iterable, Tuple, Optional, Sequence
+from typing import Iterable, Tuple, Optional, Sequence, List, cast

 import torch


 @dataclass
 class Path:
-    from_token: torch.Tensor  # [max token parts]
-    path_node: torch.Tensor  # [path length]
-    to_token: torch.Tensor  # [max token parts]
+    from_token: List[int]  # [max token parts]
+    path_node: List[int]  # [path length]
+    to_token: List[int]  # [max token parts]


 @dataclass
 class LabeledPathContext:
-    label: torch.Tensor  # [max label parts]
+    label: List[int]  # [max label parts]
     path_contexts: Sequence[Path]


+def transpose(list_of_lists: List[List[int]]) -> List[List[int]]:
+    return [cast(List[int], it) for it in zip(*list_of_lists)]
+
+
 class BatchedLabeledPathContext:
     def __init__(self, all_samples: Sequence[Optional[LabeledPathContext]]):
         samples = [s for s in all_samples if s is not None]

         # [max label parts; batch size]
-        self.labels = torch.cat([s.label.unsqueeze(1) for s in samples], dim=1)
+        self.labels = torch.tensor(transpose([s.label for s in samples]), dtype=torch.long)
         # [batch size]
         self.contexts_per_label = torch.tensor([len(s.path_contexts) for s in samples])

         # [max token parts; n contexts]
-        self.from_token = torch.cat([path.from_token.unsqueeze(1) for s in samples for path in s.path_contexts], dim=1)
+        self.from_token = torch.tensor(
+            transpose([path.from_token for s in samples for path in s.path_contexts]), dtype=torch.long
+        )
         # [path length; n contexts]
-        self.path_nodes = torch.cat([path.path_node.unsqueeze(1) for s in samples for path in s.path_contexts], dim=1)
+        self.path_nodes = torch.tensor(
+            transpose([path.path_node for s in samples for path in s.path_contexts]), dtype=torch.long
+        )
         # [max token parts; n contexts]
-        self.to_token = torch.cat([path.to_token.unsqueeze(1) for s in samples for path in s.path_contexts], dim=1)
+        self.to_token = torch.tensor(
+            transpose([path.to_token for s in samples for path in s.path_contexts]), dtype=torch.long
+        )

     def __len__(self) -> int:
         return len(self.contexts_per_label)

@@ -53,8 +63,8 @@ def move_to_device(self, device: torch.device):

 @dataclass
 class TypedPath(Path):
-    from_type: torch.Tensor  # [max type parts]
-    to_type: torch.Tensor  # [max type parts]
+    from_type: List[int]  # [max type parts]
+    to_type: List[int]  # [max type parts]


 @dataclass

@@ -67,6 +77,10 @@ def __init__(self, all_samples: Sequence[Optional[LabeledTypedPathContext]]):
         super().__init__(all_samples)
         samples = [s for s in all_samples if s is not None]
         # [max type parts; n contexts]
-        self.from_type = torch.cat([path.from_type.unsqueeze(1) for s in samples for path in s.path_contexts], dim=1)
+        self.from_type = torch.tensor(
+            transpose([path.from_type for s in samples for path in s.path_contexts]), dtype=torch.long
+        )
         # [max type parts; n contexts]
-        self.to_type = torch.cat([path.to_type.unsqueeze(1) for s in samples for path in s.path_contexts], dim=1)
+        self.to_type = torch.tensor(
+            transpose([path.to_type for s in samples for path in s.path_contexts]), dtype=torch.long
+        )
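Note: the batching change is layout-preserving. Below is a minimal standalone sketch (toy label ids, not code from this commit) of why transpose followed by torch.tensor yields the same [max label parts; batch size] tensor as the old torch.cat(..., dim=1) over unsqueezed tensors. It assumes all samples are padded to equal length, which the dataset's tokenizers guarantee.

import torch
from typing import List, cast

def transpose(list_of_lists: List[List[int]]) -> List[List[int]]:
    # zip(*...) flips [batch size; max label parts] into [max label parts; batch size]
    return [cast(List[int], it) for it in zip(*list_of_lists)]

# Two toy labels of equal (padded) length.
labels = [[2, 4, 3, 0], [2, 5, 3, 0]]

new_way = torch.tensor(transpose(labels), dtype=torch.long)
old_way = torch.cat([torch.tensor(l, dtype=torch.long).unsqueeze(1) for l in labels], dim=1)
assert torch.equal(new_way, old_way)  # both have shape [4, 2]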

code2seq/data/path_context_dataset.py

Lines changed: 9 additions & 15 deletions

@@ -2,7 +2,6 @@
 from random import shuffle
 from typing import Dict, List, Optional

-import torch
 from commode_utils.filesystem import get_lines_offsets, get_line_by_offset
 from omegaconf import DictConfig
 from torch.utils.data import Dataset

@@ -63,34 +62,29 @@ def __getitem__(self, index) -> Optional[LabeledPathContext]:
         return LabeledPathContext(label, paths)

     @staticmethod
-    def tokenize_class(raw_class: str, vocab: Dict[str, int]) -> torch.Tensor:
-        return torch.tensor([vocab[raw_class]], dtype=torch.long)
+    def tokenize_class(raw_class: str, vocab: Dict[str, int]) -> List[int]:
+        return [vocab[raw_class]]

     @staticmethod
-    def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> torch.Tensor:
+    def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]:
         sublabels = raw_label.split(PathContextDataset._separator)
         max_parts = max_parts or len(sublabels)
         label_unk = vocab[Vocabulary.UNK]

-        label = torch.full((max_parts + 1,), vocab[Vocabulary.PAD], dtype=torch.long)
-        label[0] = vocab[Vocabulary.SOS]
-        sub_tokens_ids = [vocab.get(st, label_unk) for st in sublabels[:max_parts]]
-        label[1 : len(sub_tokens_ids) + 1] = torch.tensor(sub_tokens_ids)
-
+        label = [vocab[Vocabulary.SOS]] + [vocab.get(st, label_unk) for st in sublabels[:max_parts]]
         if len(sublabels) < max_parts:
-            label[len(sublabels) + 1] = vocab[Vocabulary.EOS]
-
+            label.append(vocab[Vocabulary.EOS])
+        label += [vocab[Vocabulary.PAD]] * (max_parts + 1 - len(label))
         return label

     @staticmethod
-    def tokenize_token(token: str, vocab: Dict[str, int], max_parts: Optional[int]) -> torch.Tensor:
+    def tokenize_token(token: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]:
         sub_tokens = token.split(PathContextDataset._separator)
         max_parts = max_parts or len(sub_tokens)
         token_unk = vocab[Vocabulary.UNK]

-        result = torch.full((max_parts,), vocab[Vocabulary.PAD], dtype=torch.long)
-        sub_tokens_ids = [vocab.get(st, token_unk) for st in sub_tokens[:max_parts]]
-        result[: len(sub_tokens_ids)] = torch.tensor(sub_tokens_ids)
+        result = [vocab.get(st, token_unk) for st in sub_tokens[:max_parts]]
+        result += [vocab[Vocabulary.PAD]] * (max_parts - len(result))
         return result

     def _get_path(self, raw_path: List[str]) -> Path:
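Note: the list-based tokenization pads to the same fixed width the old torch.full approach produced. Below is a standalone sketch of the new tokenize_label logic, using an assumed toy vocabulary that mirrors the test fixture (the real special-token keys live in Vocabulary and may differ).

# Assumed ids: PAD=0, UNK=1, SOS=2, EOS=3, "my"=4, "super"=5
vocab = {"<PAD>": 0, "<UNK>": 1, "<SOS>": 2, "<EOS>": 3, "my": 4, "super": 5}

def tokenize_label(raw_label: str, max_parts: int) -> list:
    sublabels = raw_label.split("|")
    label = [vocab["<SOS>"]] + [vocab.get(st, vocab["<UNK>"]) for st in sublabels[:max_parts]]
    if len(sublabels) < max_parts:
        label.append(vocab["<EOS>"])
    # Right-pad so every label has exactly max_parts + 1 entries (SOS plus max_parts slots).
    label += [vocab["<PAD>"]] * (max_parts + 1 - len(label))
    return label

print(tokenize_label("my|super|label", 5))  # [2, 4, 5, 1, 3, 0]: SOS my super UNK EOS PAD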

setup.py

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 from setuptools import setup, find_packages

-VERSION = "1.0.0"
+VERSION = "1.0.1"

 with open("README.md") as readme_file:
     readme = readme_file.read()

tests/test_tokenization.py

Lines changed: 6 additions & 6 deletions

@@ -13,23 +13,23 @@ def test_tokenize_label(self):
         raw_label = "my|super|label"
         tokenized = PathContextDataset.tokenize_label(raw_label, self.vocab, 5)
         # <SOS> my super <UNK> <EOS> <PAD>
-        correct = torch.tensor([2, 4, 5, 1, 3, 0], dtype=torch.long)
+        correct = [2, 4, 5, 1, 3, 0]

-        torch.testing.assert_equal(tokenized, correct)
+        self.assertListEqual(tokenized, correct)

     def test_tokenize_class(self):
         raw_class = "super"
         tokenized = PathContextDataset.tokenize_class(raw_class, self.vocab)
-        correct = torch.tensor([5], dtype=torch.long)
+        correct = [5]

-        torch.testing.assert_equal(tokenized, correct)
+        self.assertListEqual(tokenized, correct)

     def test_tokenize_token(self):
         raw_token = "my|super|token"
         tokenized = PathContextDataset.tokenize_token(raw_token, self.vocab, 5)
-        correct = torch.tensor([4, 5, 1, 0, 0], dtype=torch.long)
+        correct = [4, 5, 1, 0, 0]

-        torch.testing.assert_equal(tokenized, correct)
+        self.assertListEqual(tokenized, correct)


 if __name__ == "__main__":
