Commit 7be34a1
backup
1 parent 57f309e commit 7be34a1

9 files changed: +194 -122 lines changed

scripts/preprocess.py (+2 -2)

@@ -4,11 +4,11 @@
 
 
 samples = []
-for i, file in enumerate(parse_dir('../stdlib/json')):
+for i, file in enumerate(parse_dir('../stdlib', version='simplified')):
     anonymous = enum_references(file)
     scope, holes = tokenize_file(anonymous)
     if len(holes) != 0:
         samples.append((scope, holes))
 
-with open('../data/tokenized.p', 'wb') as f:
+with open('../data/tokenized_sim.p', 'wb') as f:
     pickle.dump(samples, f)
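
Note: the script now walks every .json file under ../stdlib (rather than ../stdlib/json), requests the 'simplified' rendering of each type from the reader, and writes to a separate pickle so both dataset variants can coexist. A minimal sketch of producing the two variants side by side; the import path for enum_references and tokenize_file is a guess, since this diff does not show where they live:

import pickle

from src.Name.data.reader import parse_dir
from src.Name.data.tokenization import enum_references, tokenize_file   # hypothetical import path


def build_dataset(version: str, out_path: str) -> None:
    samples = []
    for file in parse_dir('../stdlib', version=version):
        scope, holes = tokenize_file(enum_references(file))
        if holes:                                   # keep only files with at least one hole
            samples.append((scope, holes))
    with open(out_path, 'wb') as f:
        pickle.dump(samples, f)


build_dataset('original', '../data/tokenized.p')
build_dataset('simplified', '../data/tokenized_sim.p')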

scripts/train.py (+25 -25)

@@ -1,12 +1,13 @@
 import pickle
 
-import torch
+import sys
+sys.path.extend(['../'])
 
+import torch
 from src.Name.neural.batching import make_collator, Sampler
 from src.Name.neural.training import TrainWrapper
 from src.Name.neural.utils import make_schedule, binary_stats, macro_binary_stats
 from torch import device as _device
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import LambdaLR
 from math import ceil

@@ -15,11 +16,11 @@
     tokenized = pickle.load(f)
 
 
+
 dim = 128
 num_epochs = 100
-encoder_layers = 3
-num_iters = 4
-batch_size = 4
+num_layers = 8
+batch_size = 2
 backprop_every = 1
 num_holes = 4
 max_scope_size = 150

@@ -38,13 +39,10 @@
 
 epoch_size = train_sampler.itersize(batch_size * backprop_every, num_holes)
 
-model = TrainWrapper(num_layers=encoder_layers, num_iters=num_iters, dim=dim,
-                     max_scope_size=max_scope_size, max_db_index=max_db_index).to(device)
+model = TrainWrapper(num_layers=num_layers, dim=dim, max_db_index=max_db_index).to(device)
 
-lemma_loss_fn = BCEWithLogitsLoss(reduction='sum', pos_weight=torch.tensor(50., device=device))
-lm_loss_fn = CrossEntropyLoss(reduction='sum')
 
-opt = AdamW(model.parameters(), lr=1)
+opt = AdamW(model.parameters(), lr=1, weight_decay=1e-02)
 scheduler = LambdaLR(opt,
                      make_schedule(warmup_steps=3 * epoch_size,
                                    total_steps=100 * epoch_size,

@@ -61,22 +59,24 @@
     train_epoch = train_sampler.iter(batch_size, num_holes)
     model.train()
 
-    for batch_id, batch in enumerate(train_epoch):
-        lemma_preds, gold_labels, (lm_hits, lm_total), lemma_loss, lm_loss = model.compute_losses(collator(batch, 0.1))
-        loss = lemma_loss + lm_loss
-        loss.backward()
+    with torch.autograd.set_detect_anomaly(True):
+        for batch_id, batch in enumerate(train_epoch):
+            collated = collator(batch, 0.1, 0.5)
+            lemma_preds, gold_labels, (lm_hits, lm_total), lemma_loss, lm_loss = model.compute_losses(collated)
+            loss = lemma_loss + lm_loss
+            loss.backward()
 
-        if (batch_id + 1) % backprop_every == 0:
-            opt.step()
-            scheduler.step()
-            opt.zero_grad(set_to_none=True)
+            if (batch_id + 1) % backprop_every == 0:
+                opt.step()
+                scheduler.step()
+                opt.zero_grad(set_to_none=True)
 
-        epoch_lemma_loss += lemma_loss.item()
-        epoch_lm_loss += lm_loss.item()
-        epoch_lemma_preds += lemma_preds
-        epoch_lemma_correct += gold_labels
-        epoch_lm_hits += lm_hits
-        epoch_lm_total += lm_total
+            epoch_lemma_loss += lemma_loss.item()
+            epoch_lm_loss += lm_loss.item()
+            epoch_lemma_preds += lemma_preds
+            epoch_lemma_correct += gold_labels
+            epoch_lm_hits += lm_hits
+            epoch_lm_total += lm_total
 
     print('=' * 64)
     print(f'Epoch {epoch_id}')

@@ -100,7 +100,7 @@
 
     with torch.no_grad():
        for file in dev_sampler.filtered:
-            lemma_preds, gold_labels, _, lemma_loss, _ = model.compute_losses(collator([file], 0.0))
+            lemma_preds, gold_labels, _, lemma_loss, _ = model.compute_losses(collator([file], 0.0, 1))
            epoch_dev_loss += lemma_loss.item()
 
            epoch_lemma_preds += lemma_preds
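
Note: the training loop is now wrapped in torch.autograd.set_detect_anomaly(True), and the collator takes a third argument (the negative-sampling rate: 0.5 during training, 1 at evaluation). A minimal, self-contained sketch of the anomaly-detection wrapper (standard PyTorch API; the linear model and random data are stand-ins, not the project's TrainWrapper):

import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

# Inside the context manager, autograd raises as soon as a NaN/Inf appears during
# backward and reports the forward operation that produced it. Useful for debugging,
# but it slows training, so it is normally removed once the issue is found.
with torch.autograd.set_detect_anomaly(True):
    x = torch.randn(8, 4)
    loss = model(x).pow(2).sum()
    loss.backward()
    opt.step()
    opt.zero_grad(set_to_none=True)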

src/Name/data/reader.py (+28 -20)

@@ -154,58 +154,66 @@ def substitute(self, names: dict[Name, Other]) -> LevelType[Other]:
         return self
 
 
-def parse_dir(directory: str, must_contain: str | None = None) -> Iterator[File[str]]:
+def parse_dir(directory: str, must_contain: str | None = None, version: str = 'original') -> Iterator[File[str]]:
     for file in listdir(directory):
-        if must_contain is None or must_contain in file:
-            yield parse_file(path.join(directory, file))
+        if (must_contain is None or must_contain in file) and file.endswith('.json'):
+            print(f'Parsing {file}')
+            yield parse_file(path.join(directory, file), version)
 
 
-def parse_file(filepath: str) -> File[str]:
+def parse_file(filepath: str, version: str) -> File[str]:
     with open(filepath, 'r') as f:
-        return parse_data(load(f))
+        return parse_data(load(f), version)
 
 
-def parse_data(data_json: dict) -> File[str]:
+def parse_data(data_json: dict, version: str) -> File[str]:
     return File(name=data_json['scope']['name'],
-                scope=[parse_declaration(d) for d in data_json['scope']['item']],
-                holes=[parse_holes(s) for s in data_json['samples']])
+                scope=[parse_declaration(d, version) for d in data_json['scope']['item']],
+                holes=[parse_holes(s, version) for s in data_json['samples']])
 
 
-def parse_holes(hole_json: dict) -> Hole[str]:
+def parse_holes(hole_json: dict, version: str) -> Hole[str]:
     context_json = hole_json['ctx']['thing']
     goal_type_json = hole_json['goal']
     goal_term_json = hole_json['term']
     goal_names_used = hole_json['namesUsed']
-    context = [Declaration(name=c['name'], type=parse_type(c['item'])) for c in context_json]
+    context = [Declaration(name=c['name'], type=parse_type(c['item'], version)) for c in context_json]
 
     return Hole(
         goal_type=reduce(lambda result, argument: PiType(argument, result),
                          reversed(context),
-                         parse_type(goal_type_json['thing'])),  # type: ignore
-        goal_term=parse_type(goal_term_json['thing']),
+                         parse_type(goal_type_json['thing'], version)),  # type: ignore
+        goal_term=parse_type(goal_term_json['thing']['original'], version),
         names_used=[Reference(name) for name in goal_names_used])
 
 
-def parse_declaration(dec_json: dict) -> Declaration[str]:
-    return Declaration(name=dec_json['name'], type=parse_type(dec_json['item']['thing']))
+def parse_declaration(dec_json: dict, version: str) -> Declaration[str]:
+    return Declaration(name=dec_json['name'], type=parse_type(dec_json['item']['thing'], version))
 
 
-def parse_type(type_json: dict) -> AgdaType[str]:
+def parse_type(type_json: dict, which: str) -> AgdaType[str]:
+    def go(_type_json: dict) -> AgdaType[str]: return parse_type(_type_json, which)
+
+    if which in type_json.keys():
+        if (tmp := type_json[which]) is not None:
+            type_json = tmp
+        else:
+            type_json = type_json['original']
+
     match type_json['tag']:
         case 'Pi':
             left, right = type_json['contents']
             name, type_json = left['name'], left['item']
-            return PiType(argument=(Declaration(name=name, type=parse_type(type_json))
-                                    if name != '_' else parse_type(type_json)),
-                          result=parse_type(right))
+            return PiType(argument=(Declaration(name=name, type=go(type_json)) if name != '_' else go(type_json)),
+                          result=go(right))
         case 'App':
             head, args = type_json['contents']
             head_type = parse_head(head)
-            arg_types = [parse_type(arg) for arg in args]
+            arg_types = [go(arg) for arg in args]
             return reduce(AppType, arg_types, head_type)  # type: ignore
         case 'Lam':
            contents = type_json['contents']
-            return LamType(abstraction=contents['name'], body=parse_type(contents['item']))
+            return LamType(abstraction=contents['name'], body=go(contents['item']))
         case 'Sort':
             return SortType(type_json['contents'].replace(' ', '_'))
         case 'Lit':
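
Note: the version parameter now threads from parse_dir all the way down to parse_type, which expects type nodes wrapped in per-version renderings and falls back to 'original' when the requested one is present but None. An illustrative toy node; the exact wrapper shape ('original'/'simplified' keys around the {'tag': ..., 'contents': ...} payload) is inferred from this diff, not from the actual JSON dumps:

# Hypothetical node shape with per-version renderings of the same type.
node = {
    'original':   {'tag': 'Sort', 'contents': 'Set 1'},
    'simplified': None,   # no simplified rendering available for this node
}


def select_version(type_json: dict, which: str) -> dict:
    # Mirrors the fallback added to parse_type: prefer the requested variant,
    # use 'original' if the variant is present but None, pass through otherwise.
    if which in type_json:
        return type_json[which] if type_json[which] is not None else type_json['original']
    return type_json   # already an unwrapped node


assert select_version(node, 'simplified') == {'tag': 'Sort', 'contents': 'Set 1'}
assert select_version({'tag': 'Lit', 'contents': '42'}, 'simplified') == {'tag': 'Lit', 'contents': '42'}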

src/Name/data/tokenization.py (+2 -0)

@@ -1,3 +1,5 @@
+import pdb
+
 from .reader import File
 from .internal import AgdaTree, DontCare, DeBruijn, Reference, OpNames, agda_to_tree
 from .tree import enumerate_nodes, flatten

src/Name/neural/batching.py (+36 -11)

@@ -1,5 +1,3 @@
-import pdb
-
 import torch
 from torch import Tensor, device
 from ..data.tokenization import TokenizedSample, TokenizedFile, TokenizedTree

@@ -8,12 +6,37 @@
 from typing import Iterator, Callable
 from itertools import groupby
 from torch.nn.functional import pad as _pad
+from random import random
+from itertools import takewhile
+
+NineTensors = tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]
+
+
+def filter_unreferenced(file: TokenizedFile, negative_sampling: float) -> TokenizedFile:
+    scope, goals = file
 
-EightTensors = tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]
+    def refers_to(tree: TokenizedTree, excluding: set[int]) -> set[int]:
+        direct = {tv for tt, tv, _, _ in tree if tt == 3 and tv not in excluding}
+        excluding |= direct
+        return {indirect
+                for reference in direct
+                for indirect in refers_to(scope[reference], excluding)} | direct
+
+    def rename(tree: TokenizedTree, using: dict[int, int]) -> TokenizedTree:
+        return [(tt, using[tv] if tt == 3 else tv, np, using[tp]) for tt, tv, np, tp in tree]
+
+    all_references = set.union(*[refers_to(tree, set()) for tree in [*scope, *[goal_type for goal_type, _ in goals]]])
+    all_references |= {ref for _, names_used in goals for ref in names_used}
+    removed = [idx for idx in range(len(scope)) if idx not in all_references or random() > negative_sampling]
+    renames = {kept: kept - sum(map(lambda _: 1, takewhile(lambda r: r < kept, removed))) for kept in range(len(scope))}
+    renames[-1] = -1
+    return ([rename(tree, renames) for idx, tree in enumerate(scope) if idx not in removed],
+            [(rename(goal_type, renames), [renames[ref] for ref in names_used]) for goal_type, names_used in goals])
 
 
 def make_collator(cast_to: device = device('cpu'),
-                  pad_value: int = -1,) -> Callable[[list[TokenizedSample], float], EightTensors]:
+                  pad_value: int = -1,
+                  goal_id: int = -1) -> Callable[[list[TokenizedSample], float, float], NineTensors]:
     def _longt(xs) -> Tensor:
         return torch.tensor(xs, device=cast_to, dtype=torch.long)

@@ -26,7 +49,9 @@ def pad_tree(tree: TokenizedTree, to: int) -> Tensor:
     def pad_seq(file: list[Tensor]) -> Tensor:
         return pad_sequence(file, padding_value=pad_value)
 
-    def collator(samples: list[TokenizedSample], lm_chance: float) -> EightTensors:
+    def collator(samples: list[TokenizedSample], lm_chance: float, negative_sampling: float) -> NineTensors:
+        # samples = [filter_unreferenced(sample, negative_sampling) for sample in samples]
+
         num_scopes = len(samples)
         scope_sizes, goal_sizes = zip(*[(len(scope), len(holes)) for scope, holes in samples])
         most_trees = max(x+y for x, y in zip(scope_sizes, goal_sizes))

@@ -58,15 +83,15 @@ def collator(samples: list[TokenizedSample], lm_chance: float) -> EightTensors:
         dense_batch.masked_scatter_(lm_mask.unsqueeze(-1), masked_refs)
         batch_pointers = torch.arange(0, num_scopes, device=cast_to).view(-1, 1, 1) * torch.ones_like(token_padding_mask)
         batch_pointers = batch_pointers[lm_mask]
-
-        # is_goal = (dense_batch[:, :, :, -1] == goal_id).all(dim=-1) & tree_padding_mask
-        # scope_attention_mask = (~is_goal & tree_padding_mask).unsqueeze(-2).expand(-1, most_trees, -1)
-        # diag_mask = torch.eye(most_trees, dtype=torch.bool, device=cast_to).unsqueeze(0).expand(num_scopes, -1, -1)
-        # tree_attention_mask = scope_attention_mask | (diag_mask & tree_padding_mask.unsqueeze(-1))
+
+        is_goal = (dense_batch[:, :, :, -1] == goal_id).all(dim=-1) & tree_padding_mask
+        scope_attention_mask = (~is_goal & tree_padding_mask).unsqueeze(-2).expand(-1, most_trees, -1)
+        diag_mask = torch.eye(most_trees, dtype=torch.bool, device=cast_to).unsqueeze(0).expand(num_scopes, -1, -1)
+        tree_attention_mask = scope_attention_mask | (diag_mask & tree_padding_mask.unsqueeze(-1))
         return (dense_batch.permute(-1, 0, 1, 2),
                 token_attention_mask,
                 tree_padding_mask,
-                # tree_attention_mask,
+                tree_attention_mask,
                 edge_index,
                 gold_labels,
                 lm_mask,
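
Note: filter_unreferenced (its call site is still commented out in the collator) drops a subset of scope entries and then renumbers the survivors so that references and parent pointers stay consistent. A small worked example of that renumbering step, built with takewhile exactly as above (the concrete values are made up):

from itertools import takewhile

removed = [1, 4]        # scope positions being dropped; sorted, as produced by the range() comprehension
scope_size = 6

# Each kept index shifts down by the number of removed indices that precede it.
renames = {kept: kept - sum(map(lambda _: 1, takewhile(lambda r: r < kept, removed)))
           for kept in range(scope_size)}
renames[-1] = -1        # the padding / "no parent" marker is preserved

surviving = {k: v for k, v in renames.items() if k >= 0 and k not in removed}
assert surviving == {0: 0, 2: 1, 3: 2, 5: 3}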

src/Name/neural/embedding.py (+6 -7)

@@ -21,7 +21,7 @@ def __init__(self, dim: int):
     def embed_positions(self, positions: list[int]) -> Tensor:
         # todo: this can be made much more efficient by reusing subsequence maps
         word_seq = [torch.tensor(self.node_pos_to_path(pos), device=self.primitives.device, dtype=torch.long)
-                    if pos > 0 else torch.tensor([])
+                    if pos > 0 else torch.empty(0, device=self.primitives.device, dtype=torch.long)
                     for pos in positions]
         word_ten = pad_sequence(word_seq, padding_value=2)
         maps = self.identity.repeat(len(positions), 1)

@@ -68,21 +68,20 @@ def __init__(self,
                  num_ops: int,
                  num_leaves: int,
                  dim: int,
-                 max_scope_size: int = 250,
                  max_db_index: int = 50):
         super(TokenEmbedder, self).__init__()
         self.num_leaves = num_leaves
         self.num_ops = num_ops
-        self.max_scope_size = max_scope_size
         self.max_db_size = max_db_index
+        self.dim = dim
         # ops, leaves, [sos], [ref], [oos], [mask]
-        self.fixed_embeddings = Embedding(num_embeddings=num_ops+num_leaves+4, embedding_dim=dim // 2)
+        self.fixed_embeddings = Embedding(num_embeddings=num_ops+num_leaves+3, embedding_dim=dim // 2)
         self.path_encoder = BinaryPathEncoder.orthogonal(dim // 2)
         self.db_encoder = SequentialPositionEncoder(dim // 2, freq=max_db_index)
 
     def forward(self, dense_batch: Tensor) -> tuple[Tensor, Tensor, Tensor]:
         token_types, token_values, node_positions, tree_positions = dense_batch
-        num_scopes, num_entries, _ = token_types.shape
+        num_scopes, num_entries, num_tokens = token_types.shape
 
         sos_mask = token_types == 0
         op_mask = token_types == 1

@@ -98,8 +97,8 @@ def forward(self, dense_batch: Tensor) -> tuple[Tensor, Tensor, Tensor]:
         content_embeddings[sos_mask] = self.fixed_embeddings.weight[0]
         content_embeddings[op_mask] = self.fixed_embeddings.forward(token_values[op_mask] + 1)
         content_embeddings[leaf_mask] = self.fixed_embeddings.forward(token_values[leaf_mask] + self.num_ops + 1)
-        content_embeddings[ref_mask] = self.fixed_embeddings.weight[-3]
-        content_embeddings[oos_mask] = self.fixed_embeddings.weight[-2]
+        content_embeddings[ref_mask] = self.fixed_embeddings.weight[-2]
+        content_embeddings[oos_mask] = self.fixed_embeddings.weight[-1]
         content_embeddings[lm_mask] = self.fixed_embeddings.weight[-1]
         content_embeddings[db_mask] = self.db_encoder.forward(token_values[db_mask])
