Commit e33206e: margin based training / evaluation

Parent: 75101fd
15 files changed, +258 / -209 lines

scripts/train.py (+26 / -27)
```diff
@@ -1,15 +1,15 @@
 import os
 import pickle
 
-from src.Name.neural.train import TrainCfg, Trainer, acc, Logger, ModelCfg, macro_binary_stats
-from src.Name.neural.batching import filter_data, Sampler, Collator
-from src.Name.neural.utils.schedules import make_schedule
+from src.Name.nn.training import TrainCfg, Trainer, Logger, ModelCfg
+from src.Name.nn.batching import filter_data, Sampler, Collator
+from src.Name.nn.utils.schedules import make_schedule
 
 from torch import device
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import LambdaLR
 
-from random import seed
+from random import seed, shuffle
 
 import sys
 
@@ -31,11 +31,11 @@ def train(config: TrainCfg, data_path: str, cast_to: str):
     train_files = [file for file in files if file.file.name in config['train_files']]
     dev_files = [file for file in files if file.file.name in config['dev_files']]
     print(f'Training on {len(train_files)} files with {sum(len(file.hole_asts) for file in train_files)} holes.')
-    print(f'Evaluating on {len(dev_files)} files with {sum(len(file.hole_asts) for file in train_files)} holes.')
+    print(f'Evaluating on {len(dev_files)} files with {sum(len(file.hole_asts) for file in dev_files)} holes.')
 
     train_sampler = Sampler(train_files)
     epoch_size = train_sampler.itersize(config['batch_size_s'] * config['backprop_every'], config['batch_size_h'])
-    collator = Collator(pad_value=-1, device=cast_to, allow_self_loops=False)
+    collator = Collator(pad_value=-1, device=cast_to, allow_self_loops=config['allow_self_loops'])
 
     model = Trainer(config['model_config']).to(device(cast_to))
     optimizer = AdamW(params=model.parameters(), lr=1, weight_decay=1e-02)
@@ -46,7 +46,7 @@ def train(config: TrainCfg, data_path: str, cast_to: str):
                              total_steps=config['num_epochs'] * epoch_size)
     scheduler = LambdaLR(optimizer=optimizer, lr_lambda=schedule, last_epoch=-1)
 
-    best_loss = 1e10
+    best_ap = -1e08
 
     for epoch in range(config['num_epochs']):
         print(f'Epoch {epoch}')
@@ -58,31 +58,29 @@ def train(config: TrainCfg, data_path: str, cast_to: str):
             optimizer=optimizer,
             scheduler=scheduler,
             backprop_every=config['backprop_every'])
-        print(f'Train loss: {sum(train_epoch["loss"])/len(train_epoch["predictions"])}')
-        print(f'Train stats: {macro_binary_stats(train_epoch["predictions"], train_epoch["truths"])}')
+        print(f'Train loss: {sum(train_epoch.loss)/len(train_epoch.loss)}')
+        print(f'Train mAP: {sum(train_epoch.ap)/len(train_epoch.ap)}')
+        print(f'Train R-Precision: {sum(train_epoch.rp) / len(train_epoch.rp)}')
         dev_epoch = model.eval_epoch(map(lambda x: collator([x]), dev_files))
-        print(f'Dev loss: {sum(dev_epoch["loss"])/len(dev_epoch["predictions"])}')
-        print(f'Dev stats: {macro_binary_stats(dev_epoch["predictions"], dev_epoch["truths"])}')
-        print()
-
-        # if sum(dev_epoch['loss']) < best_loss:
-        #     print('Saving...')
-        #     model.save(f'./model.pt')
-        #     best_loss = sum(dev_epoch['loss'])
-        # print('=' * 64 + '\n')
+        print(f'Dev loss: {sum(dev_epoch.loss)/len(dev_epoch.loss)}')
+        print(f'Dev mAP: {sum(dev_epoch.ap) / len(dev_epoch.ap)}')
+        print(f'Dev R-Precision: {sum(dev_epoch.rp) / len(dev_epoch.rp)}')
+        if sum(dev_epoch.ap) > best_ap:
+            print('Saving...')
+            model.save(f'./model.pt')
+            best_ap = sum(dev_epoch.ap)
+        print('=' * 64 + '\n')
 
 
 if __name__ == '__main__':
     seed(42)
-    # todo.
+
     files = [os.path.splitext(file)[0] for file in os.listdir('../data/stdlib/')]
-    # stdlib = [line for line in open('./data/stdlib.contents').read().split('\n')]
-    # unimath = [line for line in open('./data/um.contents').read().split('\n')]
-    # typetopo = [line for line in open('./data/tt.contents').read().split('\n')]
-    # shuffle(stdlib)
+    shuffle(files)
+    train_files, dev_files = files[:(int(0.75 * len(files)))], files[int(0.75 * len(files)):]
 
     model_config: ModelCfg = {
-        'depth': 8,
+        'depth': 6,
         'num_heads': 8,
         'dim': 128,
         'atn_dim': None,
@@ -94,16 +92,17 @@ def train(config: TrainCfg, data_path: str, cast_to: str):
         'num_epochs': 99,
         'warmup_epochs': 3,
         'warmdown_epochs': 90,
-        'batch_size_s': 1,
+        'batch_size_s': 2,
         'batch_size_h': 8,
         'max_lr': 5e-4,
         'min_lr': 1e-7,
         'backprop_every': 1,
-        'train_files': [f for f in files if f != 'Simple'],
-        'dev_files': [],
+        'train_files': train_files,
+        'dev_files': dev_files,
         'test_files': [],
         'max_scope_size': 300,
         'max_ast_len': 100,
+        'allow_self_loops': False
     }
 
     train(train_cfg, '../data/tokenized.p', 'cuda')
```
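The checkpoint criterion flips from lowest dev loss to highest summed dev AP; since the dev set is fixed within a run, comparing sums of per-hole average precisions is equivalent to comparing their means. For orientation, a minimal sketch of how the two ranking metrics logged above are conventionally computed for a single hole (function names here are illustrative; the commit's own implementations presumably live alongside the new `Trainer` in `src/Name/nn/training.py`, which this diff does not show):

```python
def average_precision(scores: list[float], relevant: list[bool]) -> float:
    # Rank scope entries by predicted score, best first, then average the
    # precision@k values at the ranks where a true premise appears.
    ranked = sorted(zip(scores, relevant), key=lambda p: p[0], reverse=True)
    hits, total = 0, 0.0
    for rank, (_, is_premise) in enumerate(ranked, start=1):
        if is_premise:
            hits += 1
            total += hits / rank
    return total / max(hits, 1)


def r_precision(scores: list[float], relevant: list[bool]) -> float:
    # Precision within the top-R ranks, where R is the number of true premises.
    r = sum(relevant)
    ranked = sorted(zip(scores, relevant), key=lambda p: p[0], reverse=True)
    return sum(is_premise for _, is_premise in ranked[:r]) / max(r, 1)
```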

src/Name/data/agda/syntax.py (+2 / -4)
```diff
@@ -22,10 +22,8 @@ class File(_AgdaExpr[Name]):
     scope: list[ScopeEntry[Name]]
 
     def __post_init__(self):
-        if not self.valid_reference_structure():
-            raise AssertionError('Invalid reference structure')
-        if not self.unique_entry_names():
-            raise AssertionError('Duplicate entry names.')
+        assert self.valid_reference_structure(), 'Invalid reference structure.'
+        assert self.unique_entry_names(), 'Duplicate entry names.'
 
     def valid_reference_structure(self) -> bool:
         names = [entry.name for entry in self.scope]
```
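One behavioral note on this refactor, general Python semantics rather than anything repo-specific: `assert` statements are compiled away when the interpreter runs with optimizations enabled, so the shorter form silently skips validation under `python -O`, whereas the explicit `raise AssertionError` it replaces always fired:

```python
import subprocess, sys

code = "assert False, 'unreachable under -O'"
print(subprocess.run([sys.executable, "-c", code]).returncode)        # 1: AssertionError raised
print(subprocess.run([sys.executable, "-O", "-c", code]).returncode)  # 0: assert stripped out
```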

src/Name/inference.py (+4 / -4)
```diff
@@ -4,9 +4,9 @@
 
 from .data.agda.reader import File
 from .data.tokenization import tokenize_file
-from .neural.model import Model, ModelCfg
-from .neural.batching import Collator
-from .neural.train import Logger
+from .nn.model import Model, ModelCfg
+from .nn.batching import Collator
+from .nn.train import Logger
 
 from torch_geometric.utils import to_dense_batch
 
@@ -25,7 +25,7 @@ def select_premises(self, file: File[str], threshold: float = 0.5) -> list[set[str]]:
         with torch.no_grad():
             batch = self.collator([tokenized])
             scope_reprs, hole_reprs = self.encode(batch)
-            lemma_predictions = self.predict_lemmas(scope_reprs, hole_reprs, batch.edge_index)
+            lemma_predictions = self.match(scope_reprs, hole_reprs, batch.edge_index)
             sparse = to_dense_batch(lemma_predictions, batch.edge_index[1], fill_value=-1e8)
             # pdb.set_trace()
             # # todo
```
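For context, `to_dense_batch` (real torch_geometric API) scatters the flat per-edge scores into one row per hole, padding with `fill_value=-1e8` so padded slots can never win a threshold comparison. A small self-contained illustration; the tensors and the final thresholding step are invented for the example, not taken from the repo:

```python
import torch
from torch_geometric.utils import to_dense_batch

scores = torch.tensor([0.9, -2.0, 0.1, 1.2, -0.3])  # 3 edges for hole 0, 2 for hole 1
holes = torch.tensor([0, 0, 0, 1, 1])               # edge_index[1]: target hole of each edge

dense, mask = to_dense_batch(scores, holes, fill_value=-1e8)
# dense: [2, 3] score matrix, short rows padded with -1e8
# mask:  [2, 3] booleans, True at real (non-padded) entries

selected = dense.sigmoid() > 0.5  # padding sigmoids to ~0, so it is never selected
```

Note that this relies on `edge_index[1]` being non-decreasing, which the `Collator` guarantees by appending edges hole by hole (see the batching diff below).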

src/Name/neural/train.py (-151)

This file was deleted.

src/Name/neural/batching.py → src/Name/nn/batching.py (+7 / -5)
```diff
@@ -95,10 +95,10 @@ def __call__(self, files: list[TokenizedFile]) -> Batch:
         src_index, tgt_index, premise_selection = [], [], []
         for batch_id, file in enumerate(files):
             src_offset = sum(scope_lens[:batch_id])
-            hole_offset = sum(num_holes[:batch_id])
+            tgt_offset = sum(num_holes[:batch_id])
             for hole_idx, defined_at in enumerate(file.hole_to_scope):
                 src_index += list(range(src_offset, src_offset + defined_at))
-                tgt_index += [hole_offset + hole_idx] * defined_at
+                tgt_index += [tgt_offset + hole_idx] * defined_at
                 premise_selection += [entry in file.premises[hole_idx] for entry in range(defined_at)]
         edge_index = torch.stack((self.tensor(src_index), self.tensor(tgt_index)))
         premises = self.tensor(premise_selection)
@@ -144,10 +144,12 @@ def filter_data(files: list[TokenizedFile],
                 max_ast_len: int) -> Iterator[TokenizedFile]:
 
     for file in files:
-        if (len(file.hole_asts)
-                and len(file.scope_asts) <= max_scope_size
+        if (
+            len(file.hole_asts)
+            and 1 <= len(file.scope_asts) <= max_scope_size
             and max(len(ast) for ast in file.hole_asts) <= max_ast_len
-                and max(len(ast) for ast in file.scope_asts) <= max_ast_len):
+            and max(len(ast) for ast in file.scope_asts) <= max_ast_len
+        ):
             yield file
```
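The rename from `hole_offset` to `tgt_offset` makes the bipartite bookkeeping read uniformly: `src_index` holds global scope-entry positions, `tgt_index` global hole positions. A worked example with invented sizes, traced directly from the loop above:

```python
# Batch of two files:
#   file 0: scope_lens[0] = 3, num_holes[0] = 2; hole 0 sees 2 entries, hole 1 sees 3
#   file 1: scope_lens[1] = 2, num_holes[1] = 1; hole 0 sees 2 entries
#
# file 0 (src_offset=0, tgt_offset=0):
#   hole 0: src_index += [0, 1]     tgt_index += [0, 0]
#   hole 1: src_index += [0, 1, 2]  tgt_index += [1, 1, 1]
# file 1 (src_offset=3, tgt_offset=2):
#   hole 0: src_index += [3, 4]     tgt_index += [2, 2]
#
# edge_index = [[0, 1, 0, 1, 2, 3, 4],
#               [0, 0, 1, 1, 1, 2, 2]]
```

Because holes are appended in order, the second row is non-decreasing, which is exactly what the `to_dense_batch` call in `src/Name/inference.py` requires.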
2 files renamed without changes.

src/Name/neural/model.py → src/Name/nn/model.py (+1 / -1)
```diff
@@ -35,7 +35,7 @@ def encode(self, batch: Batch) -> tuple[Tensor, Tensor]:
             scope_positions=batch.scope_positions,
             hole_positions=batch.hole_positions)
 
-    def predict_lemmas(self, scope_reprs: Tensor, hole_reprs: Tensor, edge_index: Tensor) -> Tensor:
+    def match(self, scope_reprs: Tensor, hole_reprs: Tensor, edge_index: Tensor) -> Tensor:
         source_index, target_index = edge_index
         sources = scope_reprs[source_index]
         targets = hole_reprs[target_index]
```
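The diff cuts off after the gather, so everything past these three lines is unknown; still, the rename to `match` together with the commit message suggests a pairwise scoring head trained with a margin objective. A speculative sketch of that shape, where the inner-product score and the hinge loss are assumptions, not the repo's actual code:

```python
import torch
from torch import Tensor

def match(scope_reprs: Tensor, hole_reprs: Tensor, edge_index: Tensor) -> Tensor:
    source_index, target_index = edge_index
    sources = scope_reprs[source_index]     # [num_edges, dim]
    targets = hole_reprs[target_index]      # [num_edges, dim]
    return (sources * targets).sum(dim=-1)  # one raw compatibility score per edge

def margin_loss(scores: Tensor, is_premise: Tensor, margin: float = 1.0) -> Tensor:
    # Hinge-style ranking objective: every true premise should outscore every
    # non-premise by at least `margin`. All pairs within the batch, for brevity.
    pos, neg = scores[is_premise], scores[~is_premise]
    return torch.relu(margin - pos[:, None] + neg[None, :]).mean()
```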
