Commit 6547bf1

Refactor typed code2seq model
1 parent c8a01cc commit 6547bf1

5 files changed: +141 -144 lines changed

code2seq/data/typed_path_context_data_module.py

Lines changed: 14 additions & 6 deletions
@@ -1,13 +1,12 @@
+from os.path import exists, join
 from typing import List, Optional
 
+from commode_utils.vocabulary import build_from_scratch
 from omegaconf import DictConfig
 
-from code2seq.data import (
-    PathContextDataModule,
-    TypedPathContextDataset,
-    BatchedLabeledTypedPathContext,
-    LabeledTypedPathContext,
-)
+from code2seq.data.path_context import LabeledTypedPathContext, BatchedLabeledTypedPathContext
+from code2seq.data.path_context_data_module import PathContextDataModule
+from code2seq.data.typed_path_context_dataset import TypedPathContextDataset
 from code2seq.data.vocabulary import TypedVocabulary
 
 
@@ -24,6 +23,15 @@ def collate_wrapper(batch: List[Optional[LabeledTypedPathContext]]) -> BatchedLa
     def _create_dataset(self, holdout_file: str, random_context: bool) -> TypedPathContextDataset:
         return TypedPathContextDataset(holdout_file, self._config, self._vocabulary, random_context)
 
+    def setup(self, stage: Optional[str] = None):
+        if not exists(join(self._data_dir, TypedVocabulary.vocab_filename)):
+            print("Can't find vocabulary, collect it from train holdout")
+            build_from_scratch(join(self._data_dir, f"{self._train}.c2s"), TypedVocabulary)
+        vocabulary_path = join(self._data_dir, TypedVocabulary.vocab_filename)
+        self._vocabulary = TypedVocabulary(
+            vocabulary_path, self._config.max_labels, self._config.max_tokens, self._config.max_types
+        )
+
     @property
     def vocabulary(self) -> TypedVocabulary:
         if self._vocabulary is None:
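
Note on the new setup override: it builds the typed vocabulary lazily from the train holdout, so no pre-built vocabulary file is required before training. A minimal usage sketch, assuming the module is constructed from a data directory plus the data config (as suggested by self._data_dir and self._config above) and that the dataset already sits in that directory; the constructor signature is an assumption and the field values are illustrative, mirroring the YAML configs further down:

from omegaconf import OmegaConf

from code2seq.data.typed_path_context_data_module import TypedPathContextDataModule

# Illustrative data config; real runs read the "data" section of the YAML files in this commit.
data_config = OmegaConf.create(
    {
        "max_labels": 5332,
        "max_label_parts": 7,
        "max_tokens": 18240,
        "max_token_parts": 5,
        "max_types": None,
        "max_type_parts": 5,
        "path_length": 9,
        "max_context": 200,
        "random_context": True,
        "batch_size": 512,
        "test_batch_size": 768,
        "num_workers": 4,
    }
)

# Assumed constructor: (data directory, data config).
data_module = TypedPathContextDataModule("../data/code2seq/java-small", data_config)
data_module.setup()  # builds TypedVocabulary.vocab_filename from <train>.c2s if it is missing
print(data_module.vocabulary)  # TypedVocabulary with label, token and type mappings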

code2seq/model/code2seq.py

Lines changed: 14 additions & 9 deletions
@@ -1,4 +1,4 @@
-from typing import Tuple, List, Dict
+from typing import Tuple, List, Dict, Optional
 
 import torch
 from commode_utils.losses import SequenceCrossEntropyLoss
@@ -43,9 +43,9 @@ def __init__(
         }
         self.__metrics = MetricCollection(metrics)
 
-        self.__encoder = self._get_encoder(model_config)
+        self._encoder = self._get_encoder(model_config)
         decoder_step = LSTMDecoderStep(model_config, len(vocabulary.label_to_id), self.__pad_idx)
-        self.__decoder = Decoder(
+        self._decoder = Decoder(
             decoder_step, len(vocabulary.label_to_id), vocabulary.label_to_id[vocabulary.SOS], teacher_forcing
         )
 
@@ -78,23 +78,28 @@ def forward( # type: ignore
         output_length: int,
         target_sequence: torch.Tensor = None,
     ) -> torch.Tensor:
-        encoded_paths = self.__encoder(from_token, path_nodes, to_token)
-        output_logits = self.__decoder(encoded_paths, contexts_per_label, output_length, target_sequence)
+        encoded_paths = self._encoder(from_token, path_nodes, to_token)
+        output_logits = self._decoder(encoded_paths, contexts_per_label, output_length, target_sequence)
         return output_logits
 
     # ========== Model step ==========
 
-    def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict:
-        target_sequence = batch.labels if step == "train" else None
-        # [seq length; batch size; vocab size]
-        logits = self(
+    def logits_from_batch(
+        self, batch: BatchedLabeledPathContext, target_sequence: Optional[torch.Tensor]
+    ) -> torch.Tensor:
+        return self(
             batch.from_token,
             batch.path_nodes,
             batch.to_token,
             batch.contexts_per_label,
             batch.labels.shape[0],
            target_sequence,
         )
+
+    def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict:
+        target_sequence = batch.labels if step == "train" else None
+        # [seq length; batch size; vocab size]
+        logits = self.logits_from_batch(batch, target_sequence)
         loss = self.__loss(logits[1:], batch.labels[1:])
 
         with torch.no_grad():
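
Why the rename from self.__encoder / self.__decoder to self._encoder / self._decoder matters: double-underscore attributes are name-mangled per defining class, so a subclass such as TypedCode2Seq (next file) could not reach them, while single-underscore attributes are inherited normally. A small illustrative sketch with toy classes, not code from this repository:

class Base:
    def __init__(self):
        self.__enc = "name-mangled: actually stored as _Base__enc"
        self._enc = "plain single-underscore attribute"


class Child(Base):
    def use(self):
        print(self._enc)                     # fine: visible to subclasses
        print(hasattr(self, "_Child__enc"))  # False: self.__enc written here would look up this name
        print(hasattr(self, "_Base__enc"))   # True: where Base actually stored its __enc


Child().use()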

code2seq/model/typed_code2seq.py

Lines changed: 33 additions & 0 deletions
@@ -1,5 +1,9 @@
+from typing import Optional
+
+import torch
 from omegaconf import DictConfig
 
+from code2seq.data.path_context import BatchedLabeledTypedPathContext
 from code2seq.data.vocabulary import TypedVocabulary
 from code2seq.model import Code2Seq
 from code2seq.model.modules import TypedPathEncoder, PathEncoder
@@ -26,3 +30,32 @@ def _get_encoder(self, config: DictConfig) -> PathEncoder:
             len(self._vocabulary.type_to_id),
             self._vocabulary.type_to_id[TypedVocabulary.PAD],
         )
+
+    def forward( # type: ignore
+        self,
+        from_type: torch.Tensor,
+        from_token: torch.Tensor,
+        path_nodes: torch.Tensor,
+        to_token: torch.Tensor,
+        to_type: torch.Tensor,
+        contexts_per_label: torch.Tensor,
+        output_length: int,
+        target_sequence: torch.Tensor = None,
+    ) -> torch.Tensor:
+        encoded_paths = self._encoder(from_type, from_token, path_nodes, to_token, to_type)
+        output_logits = self._decoder(encoded_paths, contexts_per_label, output_length, target_sequence)
+        return output_logits
+
+    def logits_from_batch(
+        self, batch: BatchedLabeledTypedPathContext, target_sequence: Optional[torch.Tensor]
+    ) -> torch.Tensor:
+        return self(
+            batch.from_type,
+            batch.from_token,
+            batch.path_nodes,
+            batch.to_token,
+            batch.to_type,
+            batch.contexts_per_label,
+            batch.labels.shape[0],
+            target_sequence,
+        )
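
Together with the logits_from_batch hook introduced in Code2Seq above, this override is all the typed model needs: the shared train/validation/test logic stays in the base class, and only the batch-to-forward mapping changes. A stripped-down sketch of that dispatch pattern with toy classes (not the actual model code):

from typing import Optional

import torch


class BaseSeqModel:
    # Plays the role of Code2Seq._shared_step: generic step logic, unaware of the batch layout.
    def shared_step(self, batch: dict, step: str) -> torch.Tensor:
        target = batch["labels"] if step == "train" else None
        return self.logits_from_batch(batch, target)

    # Plays the role of Code2Seq.logits_from_batch: maps a batch onto forward().
    def logits_from_batch(self, batch: dict, target: Optional[torch.Tensor]) -> torch.Tensor:
        return torch.zeros(3)


class TypedSeqModel(BaseSeqModel):
    # Only the mapping is overridden; shared_step is inherited untouched.
    def logits_from_batch(self, batch: dict, target: Optional[torch.Tensor]) -> torch.Tensor:
        return torch.ones(3)


print(TypedSeqModel().shared_step({"labels": torch.tensor([0, 1])}, "train"))  # tensor([1., 1., 1.])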

config/typed-code2seq-java-small.yaml

Lines changed: 42 additions & 65 deletions
@@ -1,81 +1,58 @@
-hydra:
-  run:
-    dir: .
-  output_subdir: null
-  job_logging: null
-  hydra_logging: null
+data_folder: ../data/code2seq/java-small
 
-name: typed-code2seq
+checkpoint: null
 
 seed: 7
-num_workers: 2
 log_offline: false
-
-resume_from_checkpoint: null
-
-# data keys
-data_folder: /permanent-data
-vocabulary_name: vocabulary.pkl
-train_holdout: train
-val_holdout: val
-test_holdout: test
-
-save_every_epoch: 1
-val_every_epoch: 1
-log_every_epoch: 10
+# Training in notebooks (e.g. Google Colab) may crash with too small value
 progress_bar_refresh_rate: 1
+print_config: true
+
+data:
+  url: https://s3.eu-west-1.amazonaws.com/datasets.ml.labs.aws.intellij.net/java-paths-methods/java-small.tar.gz
+  num_workers: 4
+
+  # Each token appears at least 5 times
+  max_labels: 5332
+  max_label_parts: 7
+  # Each token appears at least 5 times
+  max_tokens: 18240
+  max_token_parts: 5
+  max_typed: null
+  max_type_parts: 5
+  path_length: 9
 
-hyper_parameters:
-  n_epochs: 3000
-  patience: 10
-  batch_size: 512
-  test_batch_size: 512
-  clip_norm: 5
   max_context: 200
   random_context: true
-  shuffle_data: true
-
-  optimizer: "Momentum"
-  nesterov: true
-  learning_rate: 0.01
-  weight_decay: 0
-  decay_gamma: 0.95
 
-dataset:
-  name: java-small-psi
-  target:
-    max_parts: 7
-    is_wrapped: true
-    is_splitted: true
-    vocabulary_size: 11316
-  token:
-    max_parts: 5
-    is_wrapped: false
-    is_splitted: true
-    vocabulary_size: 73904
-  path:
-    max_parts: 9
-    is_wrapped: false
-    is_splitted: true
-    vocabulary_size: null
-  type:
-    max_parts: 5
-    is_wrapped: false
-    is_splitted: true
-    vocabulary_size: null
+  batch_size: 512
+  test_batch_size: 768
 
-encoder:
+model:
+  # Encoder
   embedding_size: 128
-  rnn_size: 128
+  encoder_dropout: 0.25
+  encoder_rnn_size: 128
   use_bi_rnn: true
-  embedding_dropout: 0.25
   rnn_num_layers: 1
-  rnn_dropout: 0.5
 
-decoder:
+  # Decoder
   decoder_size: 320
-  embedding_size: 128
-  num_decoder_layers: 1
+  decoder_num_layers: 1
   rnn_dropout: 0.5
-  teacher_forcing: 1
-  beam_width: 0
+
+optimizer:
+  optimizer: "Momentum"
+  nesterov: true
+  lr: 0.01
+  weight_decay: 0
+  decay_gamma: 0.95
+
+train:
+  n_epochs: 10
+  patience: 10
+  clip_norm: 5
+  teacher_forcing: 1.0
+  val_every_epoch: 1
+  save_every_epoch: 1
+  log_every_n_steps: 10
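
The flat Hydra-era keys are regrouped here into data, model, optimizer and train sections. A hedged sketch of reading such a file with OmegaConf; how the sections are consumed by the training entry point is an assumption, not part of this commit:

from omegaconf import DictConfig, OmegaConf

config: DictConfig = OmegaConf.load("config/typed-code2seq-java-small.yaml")

# Top-level keys stay global; grouped keys are reached through their section.
print(config.data_folder)             # ../data/code2seq/java-small
print(config.data.max_context)        # 200
print(config.model.encoder_rnn_size)  # 128
print(config.optimizer.lr)            # 0.01
print(config.train.n_epochs)          # 10

# Each section can then be handed to the matching component, for example the data module
# sketched earlier: TypedPathContextDataModule(config.data_folder, config.data)  (assumed constructor).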

config/typed-code2seq-java-test.yaml

Lines changed: 38 additions & 64 deletions
@@ -1,81 +1,55 @@
-hydra:
-  run:
-    dir: .
-  output_subdir: null
-  job_logging: null
-  hydra_logging: null
+data_folder: ../data/code2seq/java-test-typed
 
-name: typed-code2seq
+checkpoint: null
 
 seed: 7
-num_workers: 0
 log_offline: true
+# Training in notebooks (e.g. Google Colab) may crash with too small value
+progress_bar_refresh_rate: 1
+print_config: true
 
-resume_from_checkpoint: null
-
-# data keys
-data_folder: /permanent-data
-vocabulary_name: vocabulary.pkl
-train_holdout: train
-val_holdout: val
-test_holdout: test
+data:
+  num_workers: 0
 
-save_every_epoch: 1
-val_every_epoch: 1
-log_every_epoch: 10
-progress_bar_refresh_rate: 1
+  max_labels: null
+  max_label_parts: 7
+  max_tokens: null
+  max_token_parts: 5
+  max_types: null
+  max_type_parts: 5
+  path_length: 9
 
-hyper_parameters:
-  n_epochs: 5
-  patience: 10
-  batch_size: 5
-  test_batch_size: 5
-  clip_norm: 5
   max_context: 200
   random_context: true
-  shuffle_data: true
-
-  optimizer: "Momentum"
-  nesterov: true
-  learning_rate: 0.01
-  weight_decay: 0
-  decay_gamma: 0.95
 
-dataset:
-  name: java-test-psi
-  target:
-    max_parts: 7
-    is_wrapped: true
-    is_splitted: true
-    vocabulary_size: 11316
-  token:
-    max_parts: 5
-    is_wrapped: false
-    is_splitted: true
-    vocabulary_size: 73904
-  path:
-    max_parts: 9
-    is_wrapped: false
-    is_splitted: true
-    vocabulary_size: null
-  type:
-    max_parts: 5
-    is_wrapped: false
-    is_splitted: true
-    vocabulary_size: null
+  batch_size: 5
+  test_batch_size: 10
 
-encoder:
+model:
+  # Encoder
   embedding_size: 10
-  rnn_size: 10
+  encoder_dropout: 0.25
+  encoder_rnn_size: 10
   use_bi_rnn: true
-  embedding_dropout: 0.25
   rnn_num_layers: 1
-  rnn_dropout: 0.5
 
-decoder:
+  # Decoder
   decoder_size: 20
-  embedding_size: 10
-  num_decoder_layers: 1
+  decoder_num_layers: 1
   rnn_dropout: 0.5
-  teacher_forcing: 1
-  beam_width: 0
+
+optimizer:
+  optimizer: "Momentum"
+  nesterov: true
+  lr: 0.01
+  weight_decay: 0
+  decay_gamma: 0.95
+
+train:
+  n_epochs: 5
+  patience: 10
+  clip_norm: 10
+  teacher_forcing: 1.0
+  val_every_epoch: 1
+  save_every_epoch: 1
+  log_every_n_steps: 10
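
The java-test variant mirrors the java-small config with tiny sizes so the pipeline can be smoke-tested quickly. Where only a handful of values differ, a similar effect can be obtained by merging overrides onto the base file with OmegaConf; this is a sketch, not something the commit itself does:

from omegaconf import OmegaConf

base = OmegaConf.load("config/typed-code2seq-java-small.yaml")
overrides = OmegaConf.create(
    {
        "data_folder": "../data/code2seq/java-test-typed",
        "data": {"batch_size": 5, "test_batch_size": 10, "num_workers": 0},
        "model": {"embedding_size": 10, "encoder_rnn_size": 10, "decoder_size": 20},
        "train": {"n_epochs": 5, "clip_norm": 10},
    }
)
smoke_config = OmegaConf.merge(base, overrides)
print(smoke_config.model.encoder_rnn_size)  # 10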
