-
Notifications
You must be signed in to change notification settings - Fork 1
/
draw.py
38 lines (33 loc) · 1.48 KB
/
draw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""
Plot on Windows using Matplotlib
"""
import re
import os
import numpy as np
import matplotlib.pyplot as plt
# Draw train/valid loss curves from the training logs in ``results/``.
#
# For every tag, ``results/<tag>.out`` is scanned for lines that contain both
# "Train Loss=<number>" and "Valid Loss=<number>"; the two series are plotted
# against the index of the matching line and written to
# ``results/<tag>_train_valid_loss.svg``.

TAGS = ('HGG_WT', 'LGG_WT')
RESULTS_DIR = 'results'
FONT_FAMILY = 'Times New Roman'

# A plain or scientific-notation float. It deliberately cannot end in '.', so a
# period written right after the value ("Loss=0.1234.") is excluded from the
# match. Trimming the last character of a looser match instead would drop a
# real digit whenever no period follows the value.
_NUMBER = r'(\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)'
TRAIN_LOSS_RE = re.compile(r'Train Loss=' + _NUMBER)
VALID_LOSS_RE = re.compile(r'Valid Loss=' + _NUMBER)


def parse_losses(text):
    """Extract paired train/valid losses from the text of a training log.

    Only lines containing BOTH a ``Train Loss=`` and a ``Valid Loss=`` value
    contribute, so the two returned lists always have the same length and
    entry ``i`` of each comes from the same log line.

    Args:
        text: full contents of a ``.out`` log file.

    Returns:
        ``(train_losses, valid_losses)``: two lists of floats (empty when no
        line matches).
    """
    train_losses = []
    valid_losses = []
    for line in text.splitlines():
        train_match = TRAIN_LOSS_RE.search(line)
        valid_match = VALID_LOSS_RE.search(line)
        if train_match and valid_match:
            train_losses.append(float(train_match.group(1)))
            valid_losses.append(float(valid_match.group(1)))
    return train_losses, valid_losses


def _font(size):
    """Return a matplotlib font dict using FONT_FAMILY at *size* points."""
    return {'family': FONT_FAMILY, 'size': size}


def plot_losses(train_losses, valid_losses, svg_path):
    """Plot the two loss series against their index and save an SVG.

    Args:
        train_losses: training losses, one per parsed log line.
        valid_losses: validation losses, same length as ``train_losses``.
        svg_path: destination file for the figure.
    """
    x = np.arange(len(train_losses))
    train_y = np.array(train_losses, dtype=np.float32)
    valid_y = np.array(valid_losses, dtype=np.float32)

    fig, ax = plt.subplots()
    try:
        ax.set_title('Train and Valid Losses', fontdict=_font(20))
        ax.set_xlabel('Epoch', fontdict=_font(16))
        ax.set_ylabel('Loss', fontdict=_font(16))
        ax.plot(x, train_y, label='Train Loss')
        ax.plot(x, valid_y, label='Valid Loss')
        ax.legend(loc='upper right', prop=_font(16))
        fig.savefig(svg_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if saving fails, so repeated
        # iterations do not accumulate open figures.
        plt.close(fig)


def main(tags=TAGS, results_dir=RESULTS_DIR):
    """Parse ``<results_dir>/<tag>.out`` for each tag and save its loss plot.

    Raises:
        FileNotFoundError: if a tag's ``.out`` log does not exist.
    """
    for tag in tags:
        output_path = os.path.join(results_dir, f'{tag}.out')
        try:
            # Only ASCII "... Loss=<number>" fragments are needed, so undecodable
            # bytes elsewhere in the log are replaced instead of aborting; an
            # explicit encoding avoids depending on the OS locale default.
            with open(output_path, 'r', encoding='utf-8', errors='replace') as file:
                log_text = file.read()
        except FileNotFoundError as err:
            raise FileNotFoundError(f'Error: file {output_path} not found.') from err

        train_losses, valid_losses = parse_losses(log_text)
        plot_losses(
            train_losses,
            valid_losses,
            os.path.join(results_dir, f'{tag}_train_valid_loss.svg'),
        )


if __name__ == '__main__':
    main()