Skip to content

Commit b2af102

Browse files
committed
train
1 parent 52bf26d commit b2af102

File tree

1 file changed

+35
-21
lines changed

1 file changed

+35
-21
lines changed

train.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,57 +9,71 @@
99
from utils import batch_variable
1010
from seqeval.metrics import accuracy_score, classification_report, f1_score
1111

12+
1213
def train(model, train_loader, dev_loader, config, vocab):
    """Train a sequence-labelling model, evaluating on dev data each epoch
    and checkpointing the weights whenever dev F1 improves.

    Args:
        model: torch.nn.Module whose forward returns
            ``(loss, predicted_label_ids)`` given
            ``(word_ids, label_ids, label_mask)``.
        train_loader: iterable of raw training batches consumed by
            ``batch_variable``.
        dev_loader: iterable of raw dev batches, forwarded to ``evaluate``.
        config: carries ``lr``, ``epochs`` and ``save_path`` (and, through
            ``evaluate``, ``id2label``).
        vocab: vocabulary object passed through to ``batch_variable``.
    """
    dev_best_f1 = float('-inf')
    optimizer = optim.AdamW(params=model.parameters(), lr=config.lr)

    for epoch in range(config.epochs):
        # Per-epoch statistics. Resetting the loss list here makes the
        # printed train_loss describe the current epoch only, matching the
        # convention train_right/train_total already followed.
        epoch_losses = []
        train_right, train_total = 0, 0

        for batch_idx, batch_data in enumerate(train_loader):
            model.train()  # re-enable training mode; evaluate() flips to eval
            word_ids, label_ids, label_mask = batch_variable(batch_data, vocab, config)
            loss, predicts = model(word_ids, label_ids, label_mask)

            epoch_losses.append(loss.item())  # .item(): modern, detached float

            # Token-level accuracy over real (unmasked) tokens only.
            batch_right = ((predicts == label_ids) * label_mask).sum().item()
            batch_total = label_mask.sum().item()
            train_right += batch_right
            train_total += batch_total

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:
                # sum/len avoids rebuilding a numpy array on every report.
                print("Epoch:{}--------Iter:{}--------train_loss:{:.3f}--------train_acc:{:.3f}".format(
                    epoch + 1, batch_idx + 1,
                    sum(epoch_losses) / len(epoch_losses),
                    train_right / train_total))

        dev_loss, dev_acc, dev_f1, dev_report = evaluate(model, dev_loader, config, vocab)
        msg = "Dev Loss:{:.3f}--------Dev Acc:{:.3f}--------Dev F1:{:.3f}"
        print(msg.format(dev_loss, dev_acc, dev_f1))
        print(dev_report)

        # Keep only the checkpoint with the best dev F1 seen so far.
        if dev_best_f1 < dev_f1:
            dev_best_f1 = dev_f1
            torch.save(model.state_dict(), config.save_path)
            print("***************************** Save Model *****************************")
4748

def evaluate(model, one_loader, config, vocab, output_dict=False):
    """Evaluate the model over a data loader with seqeval metrics.

    Args:
        model: torch.nn.Module whose forward returns
            ``(loss, predicted_label_ids)`` given
            ``(word_ids, label_ids, label_mask)``.
        one_loader: iterable of raw batches consumed by ``batch_variable``.
        config: provides ``id2label`` mapping label ids to tag strings.
        vocab: vocabulary object passed through to ``batch_variable``.
        output_dict: forwarded to seqeval's ``classification_report``.

    Returns:
        Tuple ``(mean_loss, accuracy, micro_f1, report)`` where
        ``mean_loss`` is a plain Python float.
    """
    model.eval()  # evaluation mode: disable dropout etc.
    loss_total = 0.0
    predict_all = []
    label_all = []

    with torch.no_grad():
        for batch_data in one_loader:
            word_ids, label_ids, label_mask = batch_variable(batch_data, vocab, config)
            loss, predicts = model(word_ids, label_ids, label_mask)

            # Accumulate a detached float so the returned mean loss is not
            # a tensor still living on the model's device.
            loss_total += loss.item()

            # Zero predictions at padding positions in one vectorized step
            # (replaces the original O(batch*seq) per-element Python loop).
            predicts = predicts.masked_fill(label_mask == 0, 0)

            # Convert id sequences to tag strings for seqeval.
            # NOTE(review): padded positions remain in both lists; this is
            # only harmless if id 0 maps to the padding/'O' tag so label
            # and prediction agree there — confirm against config.id2label.
            label_all += [[config.id2label[i.item()] for i in ids.cpu()]
                          for ids in label_ids]
            predict_all += [[config.id2label[p.item()] for p in pres.cpu()]
                            for pres in predicts]

    acc = accuracy_score(label_all, predict_all)
    f1 = f1_score(label_all, predict_all, average='micro')
    report = classification_report(label_all, predict_all, digits=3, output_dict=output_dict)

    return loss_total / len(one_loader), acc, f1, report

0 commit comments

Comments
 (0)