Skip to content

Commit 0dca977

Browse files
committed
add load pre-trained model
1 parent c2269a3 commit 0dca977

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def get_dataset(csv_data, text_field, label_field):
154154
test_acc.append(tmp)
155155
if tmp > max_acc:
156156
max_acc = tmp
157-
#torch.save(model.state_dict(), './result/'+classifier+'.pt')
157+
torch.save(model.state_dict(), './result/'+classifier+'.pt')
158158
print('Repeat %s times: %s' % (t, test_acc))
159159
print('average test_acc: %.1f%%' % (100*sum(test_acc)/t))
160160
print('max acc: %.1f%%' % (100*max_acc))
@@ -164,4 +164,4 @@ def get_dataset(csv_data, text_field, label_field):
164164
model.load_state_dict(torch.load('./result/'+classifier+'.pt'))
165165
model.eval()
166166
print(eval(test_iter, model, device))
167-
167+

0 commit comments

Comments
 (0)