
Commit d9157fc

Authored by cactopi
Create bert-dep.py (#77)
Python version of notebooks/BertDependencies.ipynb
1 parent feecb88 commit d9157fc

File tree

1 file changed: +126 -0 lines changed

examples/bert-dep.py

Lines changed: 126 additions & 0 deletions

#!/usr/bin/env python3

import torchtext
import torch

from torch import nn
from torch_struct import DependencyCRF
from torch_struct.data import SubTokenizedField, ConllXDataset, TokenBucket
from torchtext.data import RawField, BucketIterator

from pytorch_transformers import BertModel, BertTokenizer, AdamW, WarmupLinearSchedule

config = {'bert': 'bert-base-cased', 'H': 768, 'dropout': 0.2}

# parse conll dependency data
model_class, tokenizer_class, pretrained_weights = BertModel, BertTokenizer, config['bert']
tokenizer = tokenizer_class.from_pretrained(pretrained_weights)

def batch_num(nums):
    # Collate a list of head-index sequences into a padded (batch, max_len) tensor plus lengths.
    lengths = torch.tensor([len(n) for n in nums]).long()
    n = lengths.max()
    out = torch.zeros(len(nums), n).long()
    for b, n in enumerate(nums):
        out[b, :len(n)] = torch.tensor(n)
    return out, lengths

HEAD = RawField(preprocessing=lambda x: [int(i) for i in x],
                postprocessing=batch_num)
HEAD.is_target = True
WORD = SubTokenizedField(tokenizer)

def len_filt(x): return 5 < len(x.word[0]) < 40

train = ConllXDataset('wsj.train.conllx', (('word', WORD), ('head', HEAD)),
                      filter_pred=len_filt)
train_iter = TokenBucket(train, 750)
val = ConllXDataset('wsj.dev.conllx', (('word', WORD), ('head', HEAD)),
                    filter_pred=len_filt)
val_iter = BucketIterator(val, batch_size=20, device='cuda:0')

# make bert model to compute potentials
H = config['H']
class Model(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.base_model = model_class.from_pretrained(pretrained_weights)
        self.linear = nn.Linear(H, H)
        self.bilinear = nn.Linear(H, H)
        self.root = nn.Parameter(torch.rand(H))
        self.dropout = nn.Dropout(config['dropout'])

    def forward(self, words, mapper):
        # BERT sub-token states, pooled back to word positions via the mapper.
        out = self.dropout(self.base_model(words)[0])
        out = torch.matmul(mapper.float().cuda().transpose(1, 2), out)
        # Bilinear arc scores between every pair of word positions.
        final1 = torch.matmul(out, self.linear.weight)
        final2 = torch.einsum('bnh,hg,bmg->bnm', out, self.bilinear.weight, final1)
        root_score = torch.matmul(out, self.root)
        # Drop the special-token positions and add root scores on the diagonal.
        final2 = final2[:, 1:-1, 1:-1]
        N = final2.shape[1]
        final2[:, torch.arange(N), torch.arange(N)] += root_score[:, 1:-1]
        return final2

model = Model(H)
model.cuda()

# validation and train loops
def validate(val_iter):
    incorrect_edges = 0
    total_edges = 0
    model.eval()
    for i, ex in enumerate(val_iter):
        words, mapper, _ = ex.word
        label, lengths = ex.head
        batch, _ = label.shape

        # Zero out potentials beyond each sentence length before decoding.
        final = model(words.cuda(), mapper)
        for b in range(batch):
            final[b, lengths[b]-1:, :] = 0
            final[b, :, lengths[b]-1:] = 0
        dist = DependencyCRF(final, lengths=lengths)
        gold = dist.struct.to_parts(label, lengths=lengths).type_as(dist.argmax)
        # Each wrong head differs from gold in two arc cells, hence the division by 2.
        incorrect_edges += (dist.argmax[:, :].cpu() - gold[:, :].cpu()).abs().sum() / 2.0
        total_edges += gold.sum()

    print(total_edges, incorrect_edges)
    model.train()

def train(train_iter, val_iter, model):
    opt = AdamW(model.parameters(), lr=1e-4, eps=1e-8)
    scheduler = WarmupLinearSchedule(opt, warmup_steps=20, t_total=2500)
    model.train()
    losses = []
    for i, ex in enumerate(train_iter):
        opt.zero_grad()
        words, mapper, _ = ex.word
        label, lengths = ex.head
        batch, _ = label.shape

        # Model
        final = model(words.cuda(), mapper)
        for b in range(batch):
            final[b, lengths[b]-1:, :] = 0
            final[b, :, lengths[b]-1:] = 0

        if not lengths.max() <= final.shape[1] + 1:
            print("fail")
            continue
        dist = DependencyCRF(final, lengths=lengths)

        # Negative log-likelihood of the gold trees under the CRF.
        labels = dist.struct.to_parts(label, lengths=lengths).type_as(final)
        log_prob = dist.log_prob(labels)

        loss = log_prob.sum()
        (-loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        opt.step()
        scheduler.step()
        losses.append(loss.detach())
        if i % 50 == 1:
            print(-torch.tensor(losses).mean(), words.shape)
            losses = []
        if i % 600 == 500:
            validate(val_iter)

train(train_iter, val_iter, model)
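
Editor's note, not part of the committed file: a minimal sketch of the torch_struct DependencyCRF interface the script relies on, assuming torch_struct is installed. The potentials tensor mirrors what Model.forward returns, a (batch, N, N) arc-score matrix with root scores on the diagonal; the random values here are purely illustrative.

# Sketch only: illustrates the DependencyCRF calls used above with dummy potentials.
import torch
from torch_struct import DependencyCRF

potentials = torch.randn(2, 6, 6)   # (batch, N, N) arc scores; diagonal entries act as root scores
lengths = torch.tensor([6, 4])      # true sentence lengths, shorter sentences are masked internally
dist = DependencyCRF(potentials, lengths=lengths)
print(dist.argmax.shape)            # (2, 6, 6) 0/1 arc indicators of the highest-scoring trees
print(dist.partition)               # per-sentence log normalizer
print(dist.marginals.shape)         # (2, 6, 6) edge marginals under the CRF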
