Skip to content

Commit

Permalink
Add loss visualization and update log parser
Browse files Browse the repository at this point in the history
  • Loading branch information
ChesterHuynh committed May 6, 2021
1 parent 726fca4 commit 31bca34
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 38 deletions.
76 changes: 38 additions & 38 deletions src/logparser.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,13 @@
import re
import numpy as np
from pathlib import Path
from matplotlib import pyplot as plt
from matplotlib import cycler

# Module-wide matplotlib styling, applied as an import-time side effect:
# hide the top/right axes spines, set label/title/font sizes, and install a
# custom colour cycle for successive plotted lines.
plt.rcParams.update(
    {
        "axes.spines.right": False,
        "axes.spines.top": False,
        "axes.labelsize": "medium",
        "axes.titlesize": "x-large",
        "font.size": 10,
        "axes.prop_cycle": cycler(
            "color",
            [
                "#348ABD",
                "#A60628",
                "#7A68A6",
                "#467821",
                "#CF4457",
                "#188487",
                "#E24A33",
            ],
        ),
    }
)


def parse_epoch_loss(fpath):
fpath = Path("/Users/ChesterHuynh/OneDrive - Johns Hopkins/classes/dl482/Wavenet-CPC-Music-Translation/checkpoints/umtcpc-pretrained/main_0.log")
train_losses = []
test_losses = []
with open(fpath,'r') as f:
for line in f:
if not (line.startswith("INFO") and "Epoch" in line and "loss" in line):
continue
# _, lead_end = re.match("^[^Epoch]*", line).span()
# line_ = line[lead_end:]
# _, train_loss, test_loss = re.split(' loss', line_)

s = line.strip()

Expand All @@ -49,20 +24,45 @@ def parse_epoch_loss(fpath):
return train_losses, test_losses


# TODO: Batch losses are not present in current data log
def parse_batch_loss(fpath):
return
def _loss_after(text, marker):
    """Return the float found between ``marker`` and the next ``)`` in ``text``."""
    start = text.find(marker)
    if start < 0:
        raise ValueError('Expected {!r} in log line: {!r}'.format(marker, text))
    start += len(marker)
    # Search for the closing parenthesis *after* the marker so that any
    # parenthesised text earlier in the line cannot truncate the value.
    end = text.find(")", start)
    if end < 0:
        raise ValueError('Unterminated loss value in log line: {!r}'.format(text))
    return float(text[start:end])


def parse_batch_loss(fpath, train_epoch_len=1000, test_epoch_len=100):
    """Parse per-batch train and test losses from a training log, by epoch.

    Only lines that start with ``INFO`` and contain both ``epoch`` and
    ``(loss: `` are considered. Each such line must end with
    ``epoch <int>`` and contain either ``Train (loss: <float>)`` or
    ``Test (loss: <float>)``.

    Parameters
    ----------
    fpath : str or path-like
        Path to the log file.
    train_epoch_len : int, default 1000
        Number of train batch losses kept per epoch (the last
        ``train_epoch_len`` logged values of the epoch).
    test_epoch_len : int, default 100
        Number of test batch losses kept per epoch (the last
        ``test_epoch_len`` logged values of the epoch).

    Returns
    -------
    train_losses : numpy.ndarray
        One row per epoch, ``train_epoch_len`` columns.
    test_losses : numpy.ndarray
        One row per epoch, ``test_epoch_len`` columns.

    Raises
    ------
    ValueError
        If either length is not positive, or a matching line has no
        ``epoch <int>`` suffix or no parsable loss value.
    TypeError
        If either length is not an ``int``.
    AssertionError
        If an epoch that is followed by another epoch logged fewer
        values than requested.
    NotImplementedError
        If a matching line mentions neither ``Train`` nor ``Test``.

    Notes
    -----
    The final epoch in the log is included when it has logged at least the
    requested number of train and test values; a trailing incomplete epoch
    (e.g. a run that is still in progress) is left out so the returned
    arrays stay rectangular.
    """
    if not (train_epoch_len > 0 and test_epoch_len > 0):
        raise ValueError('train_epoch_len and test_epoch_len must be positive')
    if not (isinstance(train_epoch_len, int) and isinstance(test_epoch_len, int)):
        raise TypeError('train_epoch_len and test_epoch_len must be int')

    train_losses = []
    test_losses = []
    batch_train_losses = []
    batch_test_losses = []
    # ``None`` (not 0) marks "no epoch seen yet", so logs whose first epoch
    # is numbered 0 are split correctly.
    cur_epoch = None

    def epoch_is_complete():
        return (len(batch_train_losses) >= train_epoch_len
                and len(batch_test_losses) >= test_epoch_len)

    def flush_epoch():
        # Slicing copies the tail, so clearing the buffers afterwards is safe.
        train_losses.append(batch_train_losses[-train_epoch_len:])
        test_losses.append(batch_test_losses[-test_epoch_len:])
        batch_train_losses.clear()
        batch_test_losses.clear()

    with open(fpath, 'r') as f:
        for line in f:
            if not (line.startswith("INFO") and "epoch" in line and "(loss: " in line):
                continue

            s = line.strip()
            _, sep, tail = s.partition("epoch ")
            if not sep:
                raise ValueError('Expected "epoch <int>" in log line: {!r}'.format(s))
            epoch = int(tail)

            if cur_epoch is None:
                cur_epoch = epoch
            elif epoch != cur_epoch:
                # Raised explicitly (rather than via ``assert``) so the check
                # still runs under ``python -O``; the exception type is kept.
                if not epoch_is_complete():
                    raise AssertionError(
                        'epoch {} logged {} train / {} test losses, expected at '
                        'least {} / {}'.format(
                            cur_epoch, len(batch_train_losses),
                            len(batch_test_losses), train_epoch_len,
                            test_epoch_len))
                flush_epoch()
                cur_epoch = epoch

            if "Train" in line:
                batch_train_losses.append(_loss_after(s, "Train (loss: "))
            elif "Test" in line:
                batch_test_losses.append(_loss_after(s, "Test (loss: "))
            else:
                raise NotImplementedError('Reached an unexpected line')

    # The loop only flushes when the epoch number changes, so the last epoch
    # has to be flushed here; skip it if it is incomplete.
    if cur_epoch is not None and epoch_is_complete():
        flush_epoch()

    train_losses = np.array(train_losses)
    test_losses = np.array(test_losses)
    return train_losses, test_losses
89 changes: 89 additions & 0 deletions src/visualization/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import argparse
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import seaborn as sns
from pathlib import Path
from matplotlib import cycler

from src.logparser import parse_epoch_loss, parse_batch_loss

# Module-wide matplotlib styling, applied as an import-time side effect:
# hide the top/right axes spines, set label/title/font sizes, and install a
# custom colour cycle for successive plotted lines.
plt.rcParams.update(
    {
        "axes.spines.right": False,
        "axes.spines.top": False,
        "axes.labelsize": "medium",
        "axes.titlesize": "x-large",
        "font.size": 10,
        "axes.prop_cycle": cycler(
            "color",
            [
                "#348ABD",
                "#A60628",
                "#7A68A6",
                "#467821",
                "#CF4457",
                "#188487",
                "#E24A33",
            ],
        ),
    }
)

def plot_batch_loss(fpath: Path, dst: Path):
    """Plot batch-level train/test loss against epoch and save it as a PNG.

    The arrays returned by ``parse_batch_loss`` (one row per epoch) are
    transposed so each epoch becomes a column, melted to long form, and
    drawn with ``seaborn.lineplot`` using ``Epoch`` on the x-axis. Every
    epoch therefore contributes several loss values at the same x position.
    NOTE(review): seaborn is expected to aggregate those into a central
    line with an error band - confirm against the seaborn version in use.

    Parameters
    ----------
    fpath : Path
        Log file to read; passed straight to ``parse_batch_loss``.
    dst : Path
        Output directory. Created (including parents) if missing. A ``str``
        is accepted as well.

    Returns
    -------
    None
        The figure is written to ``dst / "batch_loss.png"``.
    """
    train_losses, test_losses = parse_batch_loss(fpath)
    n_epochs = train_losses.shape[0]
    epoch_columns = list(np.arange(n_epochs))

    def to_long(losses):
        # Wide (rows = batch index, columns = epoch) -> long form with one
        # (index, Epoch, Loss) record per logged value.
        wide = pd.DataFrame(losses.T).reset_index()
        return pd.melt(wide, id_vars='index',
                       value_vars=epoch_columns,
                       var_name='Epoch', value_name='Loss')

    train_losses_long = to_long(train_losses)
    test_losses_long = to_long(test_losses)

    fig, ax = plt.subplots(dpi=100)
    sns.lineplot(data=train_losses_long, x='Epoch', y='Loss', label='Train loss', ax=ax)
    sns.lineplot(data=test_losses_long, x='Epoch', y='Loss', label='Test loss', ax=ax)
    ax.legend(frameon=False)
    fig.tight_layout()

    # mkdir(exist_ok=True) avoids the check-then-create race of
    # os.path.exists + os.makedirs; Path() also lets callers pass a str.
    dst = Path(dst)
    dst.mkdir(parents=True, exist_ok=True)
    # Save this figure explicitly instead of whatever pyplot considers current.
    fig.savefig(dst / "batch_loss.png")


def plot_epoch_loss(fpath: Path, dst: Path):
    """Plot epoch-level train/test loss and save it as a PNG.

    The two arrays returned by ``parse_epoch_loss`` are each summed along
    axis 1 and the resulting series are plotted against their position,
    labelled "Epoch".
    NOTE(review): this presumes ``parse_epoch_loss`` returns 2-D arrays with
    one row per epoch - confirm against ``src.logparser``.

    Parameters
    ----------
    fpath : Path
        Log file to read; passed straight to ``parse_epoch_loss``.
    dst : Path
        Output directory. Created (including parents) if missing. A ``str``
        is accepted as well.

    Returns
    -------
    None
        The figure is written to ``dst / "epoch_loss.png"``.
    """
    train_losses, test_losses = parse_epoch_loss(fpath)

    fig, ax = plt.subplots(dpi=100)
    ax.plot(train_losses.sum(axis=1), label="Train Loss")
    ax.plot(test_losses.sum(axis=1), label="Test Loss")
    ax.set(xlabel="Epoch", ylabel="Loss")
    ax.legend(frameon=False)
    fig.tight_layout()

    # mkdir(exist_ok=True) avoids the check-then-create race of
    # os.path.exists + os.makedirs; Path() also lets callers pass a str.
    dst = Path(dst)
    dst.mkdir(parents=True, exist_ok=True)
    # Save this figure explicitly instead of whatever pyplot considers current.
    fig.savefig(dst / "epoch_loss.png")


if __name__ == "__main__":
    # Command-line entry point: read a log file and write the requested
    # loss plot (batch-level by default) into the destination directory.
    cli = argparse.ArgumentParser(
        description='Visualize train and test loss from log file',
    )
    cli.add_argument('--file', type=Path, required=True, help="Path to log file")
    cli.add_argument('--dst', type=Path, required=True, help="Path to save plots")
    cli.add_argument(
        '--how',
        type=str,
        choices=["batch", "epoch"],
        default="batch",
        help="Loss values to use",
    )
    opts = cli.parse_args()

    # argparse restricts --how to these keys, so the lookup cannot miss.
    plotters = {
        "batch": plot_batch_loss,
        "epoch": plot_epoch_loss,
    }
    plotters[opts.how](opts.file, opts.dst)

0 comments on commit 31bca34

Please sign in to comment.