Skip to content

Commit

Permalink
Set figsize and increase dpi for loss plots
Browse files Browse the repository at this point in the history
  • Loading branch information
ChesterHuynh committed May 7, 2021
1 parent 65e062a commit 95936e4
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/visualization/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def plot_batch_loss(fpath: Path, dst: Path):
test_losses_long = pd.melt(test_losses, id_vars='index',
value_vars=list(np.arange(n_epochs)),
var_name='Epoch', value_name='Loss')
fig, ax = plt.subplots(dpi=100)

fig, ax = plt.subplots(dpi=200, figsize=(6.4, 4.8))
sns.lineplot(data=train_losses_long, x='Epoch', y='Loss', label='Train loss', ax=ax)
sns.lineplot(data=test_losses_long, x='Epoch', y='Loss', label='Test loss', ax=ax)
ax.legend(frameon=False)
Expand All @@ -55,7 +56,7 @@ def plot_batch_loss(fpath: Path, dst: Path):
def plot_epoch_loss(fpath: Path, dst: Path):
train_losses, test_losses = parse_epoch_loss(fpath)

fig, ax = plt.subplots(dpi=100)
fig, ax = plt.subplots(dpi=200, figsize=(6.4, 4.8))
ax.plot(train_losses.sum(axis=1), label = "Train Loss")
ax.plot(test_losses.sum(axis=1), label="Test Loss")
ax.set(xlabel="Epoch", ylabel="Loss")
Expand Down

0 comments on commit 95936e4

Please sign in to comment.