Last update
PKULiuHui committed Dec 15, 2018
1 parent 6c319f1 commit b6a2252
Showing 17 changed files with 1,471 additions and 217 deletions.
2 changes: 1 addition & 1 deletion baselines/baseline_gcn/RNN_GCN.py
@@ -27,7 +27,7 @@ def __init__(self, args, embed=None):
batch_first=True,
)

# Each layer has its own weight matrix, so they are organized as a list
# Each layer has its own weight matrix
self.graph_w_0 = nn.Parameter(torch.FloatTensor(self.H, self.H).uniform_(-0.1, 0.1))
self.graph_w_1 = nn.Parameter(torch.FloatTensor(self.H, self.H).uniform_(-0.1, 0.1))
self.graph_w_2 = nn.Parameter(torch.FloatTensor(self.H, self.H).uniform_(-0.1, 0.1))
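Note: the three H×H parameters graph_w_0/1/2 above give each GCN layer its own weight matrix. As a rough illustration (an assumption for this note, not code from this repository), a typical propagation step multiplies a normalized adjacency matrix, the node features, and the layer's weight matrix, then applies a nonlinearity:

import torch
import torch.nn.functional as F

def gcn_forward(h, a_hat, layer_weights):
    # h: (num_nodes, H) node features
    # a_hat: (num_nodes, num_nodes) normalized adjacency matrix
    # layer_weights: e.g. [graph_w_0, graph_w_1, graph_w_2]
    for w in layer_weights:
        h = F.relu(a_hat @ h @ w)  # one propagation step per layer weight
    return h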
59 changes: 38 additions & 21 deletions main.py
@@ -1,8 +1,8 @@
# coding:utf-8
import torch
from torch.autograd import Variable
from torch.nn.utils import clip_grad_norm_
from myrouge.rouge import get_rouge_score
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import math
@@ -13,7 +13,7 @@

parser = argparse.ArgumentParser(description='LiveBlogSum')
# model paras
parser.add_argument('-model', type=str, default='Model3')
parser.add_argument('-model', type=str, default='Model5')
parser.add_argument('-embed_frozen', type=bool, default=False)
parser.add_argument('-embed_dim', type=int, default=100)
parser.add_argument('-embed_num', type=int, default=100)
@@ -26,8 +26,11 @@
# train paras
parser.add_argument('-save_dir', type=str, default='checkpoints2/')
parser.add_argument('-lr', type=float, default=1e-3)
parser.add_argument('-lr_decay', type=float, default=0.5)
parser.add_argument('-max_norm', type=float, default=5.0)
parser.add_argument('-epochs', type=int, default=6)
parser.add_argument('-srl_ratio', type=float, default=0.15)
parser.add_argument('-teacher_forcing', type=float, default=0.0)
parser.add_argument('-epochs', type=int, default=8)
parser.add_argument('-seed', type=int, default=1)
parser.add_argument('-sent_trunc', type=int, default=25)
parser.add_argument('-valid_every', type=int, default=500)
@@ -102,26 +105,39 @@ def mmr(sents, scores, ref_len):
return summary.strip()


def adjust_learning_rate(optimizer, epoch):
lr = args.lr * (args.lr_decay ** epoch)
for param_group in optimizer.param_groups:
param_group['lr'] = lr


# Compute loss and ROUGE scores on the validation or test set
def evaluate(net, my_loss, vocab, data_iter, train_next):  # train_next indicates whether training continues afterwards
net.eval()
my_loss.eval()
loss, r1, r2, rl, rsu = .0, .0, .0, .0, .0
blog_num = float(len(data_iter))
f = open('tmp.txt', 'w')
for i, blog in enumerate(tqdm(data_iter)):
sents, sent_targets, doc_lens, doc_targets, events, event_targets, event_tfs, event_lens, event_sent_lens, sents_content, summary = vocab.make_tensors(blog, args)
# sents, sent_targets, doc_targets, events, event_targets, event_tfs = Variable(sents), Variable(sent_targets.float()), Variable(doc_targets.float()), Variable(events), Variable(event_targets.float()), Variable(event_tfs.float())
sents, sent_targets, doc_lens, doc_targets, events, event_targets, event_tfs, event_prs, event_lens, event_sent_lens, sents_content, summary = vocab.make_tensors(
blog, args)
if use_cuda:
sents = sents.cuda()
sent_targets = sent_targets.cuda()
doc_targets = doc_targets.cuda()
events = events.cuda()
event_targets = event_targets.cuda()
event_tfs = event_tfs.cuda()
# sent_probs, doc_probs = net(sents, doc_lens)
sent_probs, doc_probs, event_probs = net(sents, doc_lens, events, event_lens, event_sent_lens, event_tfs)
# loss += my_loss(sent_probs, doc_probs, sent_targets, doc_targets).data.item()
loss += my_loss(sent_probs, doc_probs, event_probs, sent_targets, doc_targets, event_targets).data.item()
# sent_probs = net(sents, doc_lens)
sent_probs, event_probs = net(sents, doc_lens, events, event_lens, event_sent_lens, event_tfs, event_targets, sent_targets, False)

loss += F.mse_loss(event_probs, event_targets).data.item()
for a, b in zip(event_probs, event_targets):
f.write(str(a.data.item()) + '\t' + str(b.data.item()) + '\n')
f.write('\n')

# loss += my_loss(sent_probs, sent_targets).data.item()
# loss += my_loss(sent_probs, event_probs, sent_targets, event_targets).data.item()
# loss += my_loss(sent_probs, doc_probs, event_probs, sent_targets, doc_targets, event_targets).data.item()
probs = sent_probs.tolist()
ref = summary.strip()
ref_len = len(ref.split())
@@ -131,7 +147,7 @@ def evaluate(net, my_loss, vocab, data_iter, train_next):  # train_next indicates whether training continues
r2 += score['ROUGE-2']['r']
rl += score['ROUGE-L']['r']
rsu += score['ROUGE-SU4']['r']

f.close()
loss = loss / blog_num
r1 = r1 / blog_num
r2 = r2 / blog_num
@@ -155,15 +171,15 @@ def train():
train_data = []
fns = os.listdir(args.train_dir)
fns.sort()
for fn in fns:
for fn in tqdm(fns):
f = open(args.train_dir + fn, 'r')
train_data.append(json.load(f))
f.close()

val_data = []
fns = os.listdir(args.valid_dir)
fns.sort()
for fn in fns:
for fn in tqdm(fns):
f = open(args.valid_dir + fn, 'r')
val_data.append(json.load(f))
f.close()
@@ -178,34 +194,35 @@

for epoch in range(1, args.epochs + 1):
for i, blog in enumerate(train_data):
sents, sent_targets, doc_lens, doc_targets, events, event_targets, event_tfs, event_lens, event_sent_lens, _1, _2, = vocab.make_tensors(blog, args)
# sents, sent_targets, doc_targets, events, event_targets, event_tfs = Variable(sents), Variable(sent_targets.float()), Variable(doc_targets.float()), Variable(events), Variable(event_targets.float()), Variable(event_tfs.float())
sents, sent_targets, doc_lens, doc_targets, events, event_targets, event_tfs, event_prs, event_lens, event_sent_lens, _1, _2, = vocab.make_tensors(
blog, args)
if use_cuda:
sents = sents.cuda()
sent_targets = sent_targets.cuda()
doc_targets = doc_targets.cuda()
events = events.cuda()
event_targets = event_targets.cuda()
event_tfs = event_tfs.cuda()
# sent_probs, doc_probs = net(sents, doc_lens)
sent_probs, doc_probs, event_probs = net(sents, doc_lens, events, event_lens, event_sent_lens, event_tfs)
# loss = my_loss(sent_probs, doc_probs, sent_targets, doc_targets)
loss = my_loss(sent_probs, doc_probs, event_probs, sent_targets, doc_targets, event_targets)
# sent_probs = net(sents, doc_lens)
sent_probs, event_probs = net(sents, doc_lens, events, event_lens, event_sent_lens, event_tfs, event_targets, sent_targets, True)
# loss = my_loss(sent_probs, sent_targets)
loss = my_loss(sent_probs, event_probs, sent_targets, event_targets)
# loss = my_loss(sent_probs, doc_probs, event_probs, sent_targets, doc_targets, event_targets)
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(net.parameters(), args.max_norm)
optimizer.step()
print('EPOCH [%d/%d]: BATCH_ID=[%d/%d] loss=%f' % (epoch, args.epochs, i, len(train_data), loss))

cnt = (epoch - 1) * len(train_data) + i
if cnt % args.valid_every == 0:
if cnt % args.valid_every == 0 and cnt / args.valid_every > 0:
print('Begin valid... Epoch %d, Batch %d' % (epoch, i))
cur_loss, r1, r2, rl, rsu = evaluate(net, my_loss, vocab, val_data, True)
save_path = args.save_dir + args.model + '_%d_%.4f_%.4f_%.4f_%.4f_%.4f' % (
cnt / args.valid_every, cur_loss, r1, r2, rl, rsu)
net.save(save_path)
print('Epoch: %2d Cur_Val_Loss: %f Rouge-1: %f Rouge-2: %f Rouge-l: %f Rouge-SU4: %f' %
(epoch, cur_loss, r1, r2, rl, rsu))
adjust_learning_rate(optimizer, epoch)


def test():
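Note: the mmr(sents, scores, ref_len) helper referenced in the hunk above assembles the output summary, but its body is not shown in this diff. A hypothetical sketch of a greedy Maximal Marginal Relevance selection matching that signature (lambda_ and the word-overlap redundancy term are assumptions, not the repository's implementation) could look like:

def mmr_sketch(sents, scores, ref_len, lambda_=0.7):
    # sents: list of sentence strings; scores: per-sentence relevance scores
    # ref_len: target summary length in words; lambda_: relevance/redundancy trade-off
    selected, sel_words, total_words = [], set(), 0
    candidates = list(range(len(sents)))
    while candidates and total_words < ref_len:
        def gain(i):
            words = set(sents[i].split())
            overlap = len(words & sel_words) / (len(words) or 1)  # redundancy penalty
            return lambda_ * scores[i] - (1 - lambda_) * overlap
        best = max(candidates, key=gain)              # pick the highest-gain sentence
        candidates.remove(best)
        selected.append(sents[best])
        sel_words |= set(sents[best].split())
        total_words += len(sents[best].split())
    return ' '.join(selected).strip()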
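Note: the added adjust_learning_rate(optimizer, epoch) decays the learning rate to args.lr * args.lr_decay ** epoch at the end of each epoch. For reference (a sketch under assumed values, not part of this commit), the same schedule can be obtained with PyTorch's built-in ExponentialLR stepped once per epoch:

import torch

model = torch.nn.Linear(4, 1)        # stand-in module, for illustration only
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)

for epoch in range(1, 9):
    # ... one epoch of training ...
    scheduler.step()                 # after epoch e, lr = 1e-3 * 0.5 ** e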
