Commit a8697c6

Refactor code2class model
1 parent 850ff0b commit a8697c6

9 files changed: 108 additions & 94 deletions

code2seq/code2class_wrapper.py

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@ def train_code2class(config: DictConfig):
     print_config(config, fields=["model", "data", "train", "optimizer"])
 
     # Load data module
-    data_module = PathContextDataModule(config.data_folder, config.data)
+    data_module = PathContextDataModule(config.data_folder, config.data, is_class=True)
     data_module.prepare_data()
     data_module.setup()
 

code2seq/data/path_context_data_module.py

Lines changed: 3 additions & 2 deletions

@@ -20,11 +20,12 @@ class PathContextDataModule(LightningDataModule):
 
     _vocabulary: Optional[Vocabulary] = None
 
-    def __init__(self, data_dir: str, config: DictConfig):
+    def __init__(self, data_dir: str, config: DictConfig, is_class: bool = False):
         super().__init__()
         self._config = config
         self._data_dir = data_dir
         self._name = basename(data_dir)
+        self._is_class = is_class
 
     @property
     def vocabulary(self) -> Vocabulary:
@@ -45,7 +46,7 @@ def setup(self, stage: Optional[str] = None):
             print("Can't find vocabulary, collect it from train holdout")
             build_from_scratch(join(self._data_dir, f"{self._train}.c2s"), Vocabulary)
         vocabulary_path = join(self._data_dir, Vocabulary.vocab_filename)
-        self._vocabulary = Vocabulary(vocabulary_path, self._config.max_labels, self._config.max_tokens)
+        self._vocabulary = Vocabulary(vocabulary_path, self._config.max_labels, self._config.max_tokens, self._is_class)
 
     @staticmethod
     def collate_wrapper(batch: List[Optional[LabeledPathContext]]) -> BatchedLabeledPathContext:
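
With the new flag, the data module can be built in classification mode. A minimal usage sketch (not part of the commit), assuming the repository is importable and the config/code2class-poj104.yaml file added below:

from omegaconf import OmegaConf
from code2seq.data.path_context_data_module import PathContextDataModule

config = OmegaConf.load("config/code2class-poj104.yaml")
data_module = PathContextDataModule(config.data_folder, config.data, is_class=True)
data_module.prepare_data()  # presumably downloads the dataset from config.data.url if it is missing locally
data_module.setup()         # builds or loads the vocabulary, forwarding is_class to Vocabulary
num_classes = len(data_module.vocabulary.label_to_id)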

code2seq/data/path_context_dataset.py

Lines changed: 30 additions & 19 deletions

@@ -23,8 +23,6 @@ def __init__(self, data_file: str, config: DictConfig, vocabulary: Vocabulary, r
         self._vocab = vocabulary
         self._random_context = random_context
 
-        self._label_unk = vocabulary.label_to_id[vocabulary.UNK]
-
         self._line_offsets = get_lines_offsets(data_file)
         self._n_samples = len(self._line_offsets)
 
@@ -49,7 +47,10 @@ def __getitem__(self, index) -> Optional[LabeledPathContext]:
         raw_path_contexts = raw_path_contexts[:n_contexts]
 
         # Tokenize label
-        label = self._tokenize_label(raw_label)
+        if self._config.max_label_parts == 1:
+            label = self.tokenize_class(raw_label, self._vocab.label_to_id)
+        else:
+            label = self.tokenize_label(raw_label, self._vocab.label_to_id, self._config.max_label_parts)
 
         # Tokenize paths
         try:
@@ -61,30 +62,40 @@ def __getitem__(self, index) -> Optional[LabeledPathContext]:
 
         return LabeledPathContext(label, paths)
 
-    def _tokenize_label(self, raw_label: str) -> torch.Tensor:
-        label = torch.full((self._config.max_label_parts + 1,), self._vocab.label_to_id[self._vocab.PAD])
-        label[0] = self._vocab.label_to_id[self._vocab.SOS]
-        sublabels = raw_label.split(self._separator)[: self._config.max_label_parts]
-        label[1 : len(sublabels) + 1] = torch.tensor(
-            [self._vocab.label_to_id.get(sl, self._label_unk) for sl in sublabels]
-        )
-        if len(sublabels) < self._config.max_label_parts:
-            label[len(sublabels) + 1] = self._vocab.label_to_id[self._vocab.EOS]
+    @staticmethod
+    def tokenize_class(raw_class: str, vocab: Dict[str, int]) -> torch.Tensor:
+        return torch.tensor([vocab[raw_class]], dtype=torch.long)
+
+    @staticmethod
+    def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> torch.Tensor:
+        sublabels = raw_label.split(PathContextDataset._separator)
+        max_parts = max_parts or len(sublabels)
+        label_unk = vocab[Vocabulary.UNK]
+
+        label = torch.full((max_parts + 1,), vocab[Vocabulary.PAD], dtype=torch.long)
+        label[0] = vocab[Vocabulary.SOS]
+        sub_tokens_ids = [vocab.get(st, label_unk) for st in sublabels[:max_parts]]
+        label[1 : len(sub_tokens_ids) + 1] = torch.tensor(sub_tokens_ids)
+
+        if len(sublabels) < max_parts:
+            label[len(sublabels) + 1] = vocab[Vocabulary.EOS]
+
         return label
 
-    def _tokenize_token(self, token: str, vocab: Dict[str, int], max_parts: Optional[int]) -> torch.Tensor:
-        sub_tokens = token.split(self._separator)
+    @staticmethod
+    def tokenize_token(token: str, vocab: Dict[str, int], max_parts: Optional[int]) -> torch.Tensor:
+        sub_tokens = token.split(PathContextDataset._separator)
         max_parts = max_parts or len(sub_tokens)
-        token_unk = vocab[self._vocab.UNK]
+        token_unk = vocab[Vocabulary.UNK]
 
-        result = torch.full((max_parts,), vocab[self._vocab.PAD], dtype=torch.long)
+        result = torch.full((max_parts,), vocab[Vocabulary.PAD], dtype=torch.long)
         sub_tokens_ids = [vocab.get(st, token_unk) for st in sub_tokens[:max_parts]]
         result[: len(sub_tokens_ids)] = torch.tensor(sub_tokens_ids)
         return result
 
     def _get_path(self, raw_path: List[str]) -> Path:
         return Path(
-            from_token=self._tokenize_token(raw_path[0], self._vocab.token_to_id, self._config.max_token_parts),
-            path_node=self._tokenize_token(raw_path[1], self._vocab.node_to_id, self._config.path_length),
-            to_token=self._tokenize_token(raw_path[2], self._vocab.token_to_id, self._config.max_token_parts),
+            from_token=self.tokenize_token(raw_path[0], self._vocab.token_to_id, self._config.max_token_parts),
+            path_node=self.tokenize_token(raw_path[1], self._vocab.node_to_id, self._config.path_length),
+            to_token=self.tokenize_token(raw_path[2], self._vocab.token_to_id, self._config.max_token_parts),
         )
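
The two label paths behave differently: tokenize_class maps the whole raw label to a single class id, while tokenize_label splits it on the sub-token separator ("|" in the .c2s format) and frames the parts with SOS/EOS in a PAD-filled tensor. A toy illustration with an assumed label vocabulary:

from code2seq.data.path_context_dataset import PathContextDataset
from code2seq.data.vocabulary import Vocabulary

# Assumed toy mapping; real vocabularies are built by Vocabulary and contain these special tokens.
label_to_id = {Vocabulary.PAD: 0, Vocabulary.SOS: 1, Vocabulary.EOS: 2, Vocabulary.UNK: 3, "get": 4, "name": 5}

PathContextDataset.tokenize_class("get", label_to_id)          # tensor([4]): one class id, no framing
PathContextDataset.tokenize_label("get|name", label_to_id, 3)  # tensor([1, 4, 5, 2]): [SOS, get, name, EOS]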

code2seq/data/typed_path_context_dataset.py

Lines changed: 5 additions & 5 deletions

@@ -14,9 +14,9 @@ def __init__(self, data_file: str, config: DictConfig, vocabulary: TypedVocabula
 
     def _get_path(self, raw_path: List[str]) -> TypedPath:
         return TypedPath(
-            from_type=self._tokenize_token(raw_path[0], self._vocab.type_to_id, self._config.max_type_parts),
-            from_token=self._tokenize_token(raw_path[1], self._vocab.token_to_id, self._config.max_token_parts),
-            path_node=self._tokenize_token(raw_path[2], self._vocab.node_to_id, self._config.path_length),
-            to_token=self._tokenize_token(raw_path[3], self._vocab.token_to_id, self._config.max_token_parts),
-            to_type=self._tokenize_token(raw_path[4], self._vocab.type_to_id, self._config.max_type_parts),
+            from_type=self.tokenize_token(raw_path[0], self._vocab.type_to_id, self._config.max_type_parts),
+            from_token=self.tokenize_token(raw_path[1], self._vocab.token_to_id, self._config.max_token_parts),
+            path_node=self.tokenize_token(raw_path[2], self._vocab.node_to_id, self._config.path_length),
+            to_token=self.tokenize_token(raw_path[3], self._vocab.token_to_id, self._config.max_token_parts),
+            to_type=self.tokenize_token(raw_path[4], self._vocab.type_to_id, self._config.max_type_parts),
         )

code2seq/data/vocabulary.py

Lines changed: 13 additions & 0 deletions

@@ -8,6 +8,19 @@
 
 
 class Vocabulary(BaseVocabulary):
+    def __init__(
+        self,
+        vocabulary_file: str,
+        max_labels: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        is_class: bool = False,
+    ):
+        super().__init__(vocabulary_file, max_labels, max_tokens)
+        if is_class:
+            self._label_to_id = {
+                token[0]: i for i, token in enumerate(self._counters[self.LABEL].most_common(max_labels))
+            }
+
     @staticmethod
     def _process_raw_sample(raw_sample: str, counters: Dict[str, CounterType[str]], context_seq: List[str]):
         label, *path_contexts = raw_sample.split(" ")
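
With is_class=True the label vocabulary is rebuilt as a plain class-name-to-id mapping ordered by frequency, so it holds no SOS/EOS/PAD/UNK entries and its size equals the number of classes. A hedged construction sketch; the vocabulary path is a placeholder:

from code2seq.data.vocabulary import Vocabulary

vocab = Vocabulary("data/poj-104/vocabulary.pkl", max_labels=None, max_tokens=190000, is_class=True)
num_classes = len(vocab.label_to_id)  # one contiguous id per class label seen in the data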

code2seq/model/code2class.py

Lines changed: 7 additions & 6 deletions

@@ -28,15 +28,15 @@ def __init__(self, model_config: DictConfig, optimizer_config: DictConfig, vocab
             vocabulary.node_to_id[Vocabulary.PAD],
         )
 
-        self._classifier = Classifier(model_config, self._num_classes)
+        self._classifier = Classifier(model_config, len(vocabulary.label_to_id))
 
         metrics: Dict[str, Metric] = {
            f"{holdout}_acc": Accuracy(num_classes=len(vocabulary.label_to_id)) for holdout in ["train", "val", "test"]
         }
         self.__metrics = MetricCollection(metrics)
 
     def configure_optimizers(self) -> Tuple[List[Optimizer], List[_LRScheduler]]:
-        return configure_optimizers_alon(self._config.hyper_parameters, self.parameters())
+        return configure_optimizers_alon(self._optim_config, self.parameters())
 
     def forward( # type: ignore
         self,
@@ -45,7 +45,7 @@ def forward( # type: ignore
         to_token: torch.Tensor,
         contexts_per_label: torch.Tensor,
     ) -> torch.Tensor:
-        encoded_paths = self.__encoder(from_token, path_nodes, to_token)
+        encoded_paths = self._encoder(from_token, path_nodes, to_token)
         output_logits = self._classifier(encoded_paths, contexts_per_label)
         return output_logits
 
@@ -54,11 +54,12 @@ def forward( # type: ignore
     def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict:
         # [batch size; num_classes]
         logits = self(batch.from_token, batch.path_nodes, batch.to_token, batch.contexts_per_label)
-        loss = torch.nn.functional.cross_entropy(logits, batch.labels.squeeze(0))
+        labels = batch.labels.squeeze(0)
+        loss = torch.nn.functional.cross_entropy(logits, labels)
 
         with torch.no_grad():
             predictions = logits.argmax(-1)
-            accuracy = self.__metrics[f"{step}_acc"](predictions, batch.labels)
+            accuracy = self.__metrics[f"{step}_acc"](predictions, labels)
 
         return {f"{step}/loss": loss, f"{step}/accuracy": accuracy}
 
@@ -78,7 +79,7 @@ def test_step(self, batch: BatchedLabeledPathContext, batch_idx: int) -> Dict:
 
     def _shared_epoch_end(self, outputs: List[Dict], step: str):
         with torch.no_grad():
-            mean_loss = torch.stack([out["loss"] for out in outputs]).mean()
+            mean_loss = torch.stack([out[f"{step}/loss"] for out in outputs]).mean()
             accuracy = self.__metrics[f"{step}_acc"].compute()
             log = {f"{step}/loss": mean_loss, f"{step}/accuracy": accuracy}
             self.log_dict(log, on_step=False, on_epoch=True)
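
Since tokenize_class yields a single id per sample, batch.labels appears to be stacked with one row per label part ([1; batch size] here), which is why _shared_step squeezes dim 0 before handing the [batch size] targets to cross entropy and the Accuracy metric alongside [batch size; num_classes] logits. A shape sketch with assumed toy sizes:

import torch
import torch.nn.functional as F

batch_size, num_classes = 4, 104                              # toy sizes; POJ-104 has 104 classes
logits = torch.randn(batch_size, num_classes)                 # [batch size; num_classes]
batched_labels = torch.randint(num_classes, (1, batch_size))  # stacked class ids, [1; batch size]

labels = batched_labels.squeeze(0)                            # [batch size]
loss = F.cross_entropy(logits, labels)
predictions = logits.argmax(-1)                               # [batch size]
accuracy = (predictions == labels).float().mean()             # what the torchmetrics Accuracy tracks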

code2seq/model/modules/path_encoder.py

Lines changed: 8 additions & 2 deletions

@@ -32,8 +32,14 @@ def __init__(
 
         concat_size = self._calculate_concat_size(config.embedding_size, config.encoder_rnn_size, self.num_directions)
         self.embedding_dropout = nn.Dropout(config.encoder_dropout)
-        self.linear = nn.Linear(concat_size, config.decoder_size, bias=False)
-        self.norm = nn.LayerNorm(config.decoder_size)
+        if "decoder_size" in config:
+            out_size = config["decoder_size"]
+        elif "classifier_size" in config:
+            out_size = config["classifier_size"]
+        else:
+            raise ValueError("Specify out size of encoder")
+        self.linear = nn.Linear(concat_size, out_size, bias=False)
+        self.norm = nn.LayerNorm(out_size)
 
     @staticmethod
     def _calculate_concat_size(embedding_size: int, rnn_size: int, num_directions: int) -> int:
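
The encoder projection width now follows whichever key the model config defines: decoder_size for the code2seq decoder or classifier_size for the code2class classifier (128 in the config below); a config with neither raises a ValueError. A small sketch of the selection with assumed OmegaConf configs:

from omegaconf import OmegaConf

for cfg in (OmegaConf.create({"decoder_size": 320}), OmegaConf.create({"classifier_size": 128})):
    if "decoder_size" in cfg:
        out_size = cfg["decoder_size"]
    elif "classifier_size" in cfg:
        out_size = cfg["classifier_size"]
    else:
        raise ValueError("Specify out size of encoder")
    print(out_size)  # 320, then 128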

config/code2class-poj104.yaml

Lines changed: 39 additions & 57 deletions

@@ -1,72 +1,54 @@
-hydra:
-  run:
-    dir: .
-  output_subdir: null
-  job_logging: null
-  hydra_logging: null
+data_folder: ../data/poj-104/poj-104-code2seq
 
-name: code2class
+checkpoint: null
 
 seed: 7
-num_workers: 2
-log_offline: false
+log_offline: true
+# Training in notebooks (e.g. Google Colab) may crash with too small value
+progress_bar_refresh_rate: 1
+print_config: true
 
-# data keys
-data_folder: /permanent-data
-vocabulary_name: vocabulary.pkl
-train_holdout: train
-val_holdout: val
-test_holdout: test
+data:
+  url: https://s3.eu-west-1.amazonaws.com/datasets.ml.labs.aws.intellij.net/poj-104/poj-104-code2seq.tar.gz
+  num_workers: 0
 
-save_every_epoch: 1
-val_every_epoch: 1
-log_every_epoch: 10
-progress_bar_refresh_rate: 1
+  max_labels: null
+  max_label_parts: 1
+  max_tokens: 190000
+  max_token_parts: 5
+  path_length: 9
 
-hyper_parameters:
-  n_epochs: 3000
-  patience: 10
-  batch_size: 512
-  test_batch_size: 512
-  clip_norm: 5
   max_context: 200
   random_context: true
-  shuffle_data: true
-
-  optimizer: "Momentum"
-  nesterov: true
-  learning_rate: 0.01
-  weight_decay: 0
-  decay_gamma: 0.95
 
-dataset:
-  name: poj_104
-  target:
-    max_parts: 1
-    is_wrapped: false
-    is_splitted: false
-    vocabulary_size: 27000
-  token:
-    max_parts: 5
-    is_wrapped: false
-    is_splitted: true
-    vocabulary_size: 190000
-  path:
-    max_parts: 9
-    is_wrapped: false
-    is_splitted: true
-    vocabulary_size: null
+  batch_size: 512
+  test_batch_size: 768
 
-encoder:
+model:
+  # Encoder
   embedding_size: 128
-  rnn_size: 128
+  encoder_dropout: 0.25
+  encoder_rnn_size: 128
   use_bi_rnn: true
-  embedding_dropout: 0.25
   rnn_num_layers: 1
-  rnn_dropout: 0.5
 
-classifier:
-  n_hidden_layers: 2
-  hidden_size: 128
-  classifier_input_size: 256
+  # Classifier
+  classifier_layers: 2
+  classifier_size: 128
   activation: relu
+
+optimizer:
+  optimizer: "Momentum"
+  nesterov: true
+  lr: 0.01
+  weight_decay: 0
+  decay_gamma: 0.95
+
+train:
+  n_epochs: 10
+  patience: 10
+  clip_norm: 5
+  teacher_forcing: 1.0
+  val_every_epoch: 1
+  save_every_epoch: 1
+  log_every_n_steps: 10
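
A hedged sketch of consuming the new flat config; the exact CLI entry point may differ, but train_code2class from code2class_wrapper.py above accepts the loaded DictConfig with its data, model, optimizer, and train groups:

from omegaconf import OmegaConf
from code2seq.code2class_wrapper import train_code2class

config = OmegaConf.load("config/code2class-poj104.yaml")
train_code2class(config)  # prints the config groups, builds the data module with is_class=True, and trains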

requirements.txt

Lines changed: 2 additions & 2 deletions

@@ -7,5 +7,5 @@ torchmetrics==0.5.0
 
 tqdm==4.62.1
 wandb==0.12.0
-omegaconf==2.1.0
-commode-utils==0.3.7
+omegaconf==2.1.1
+commode-utils==0.3.8
