Skip to content

Commit ab5059c

Browse files
committed
remove plotting code
1 parent 8ea4d1a commit ab5059c

File tree

2 files changed

+1
-55
lines changed

2 files changed

+1
-55
lines changed

1_pytorch-distilbert.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import time
44

55
from datasets import load_dataset
6-
import matplotlib.pyplot as plt
7-
import pandas as pd
86
import torch
97
from torch.utils.data import DataLoader
108
import torchmetrics
@@ -24,28 +22,6 @@ def tokenize_text(batch):
2422
return tokenizer(batch["text"], truncation=True, padding=True)
2523

2624

27-
def plot_logs(log_dir):
28-
metrics = pd.read_csv(op.join(log_dir, "metrics.csv"))
29-
30-
aggreg_metrics = []
31-
agg_col = "epoch"
32-
for i, dfg in metrics.groupby(agg_col):
33-
agg = dict(dfg.mean())
34-
agg[agg_col] = i
35-
aggreg_metrics.append(agg)
36-
37-
df_metrics = pd.DataFrame(aggreg_metrics)
38-
df_metrics[["train_loss", "val_loss"]].plot(
39-
grid=True, legend=True, xlabel="Epoch", ylabel="Loss"
40-
)
41-
plt.savefig(op.join(log_dir, "loss.pdf"))
42-
43-
df_metrics[["train_acc", "val_acc"]].plot(
44-
grid=True, legend=True, xlabel="Epoch", ylabel="Accuracy"
45-
)
46-
plt.savefig(op.join(log_dir, "acc.pdf"))
47-
48-
4925
def train(num_epochs, model, optimizer, train_loader, val_loader, device):
5026
for epoch in range(num_epochs):
5127
train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)

2_pytorch-with-trainer.py

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
import lightning as L
77
from lightning.pytorch.callbacks import ModelCheckpoint
88
from lightning.pytorch.loggers import CSVLogger
9-
import matplotlib.pyplot as plt
10-
import pandas as pd
119
import torch
1210
from torch.utils.data import DataLoader
1311
import torchmetrics
@@ -27,28 +25,6 @@ def tokenize_text(batch):
2725
return tokenizer(batch["text"], truncation=True, padding=True)
2826

2927

30-
def plot_logs(log_dir):
31-
metrics = pd.read_csv(op.join(log_dir, "metrics.csv"))
32-
33-
aggreg_metrics = []
34-
agg_col = "epoch"
35-
for i, dfg in metrics.groupby(agg_col):
36-
agg = dict(dfg.mean())
37-
agg[agg_col] = i
38-
aggreg_metrics.append(agg)
39-
40-
df_metrics = pd.DataFrame(aggreg_metrics)
41-
df_metrics[["train_loss", "val_loss"]].plot(
42-
grid=True, legend=True, xlabel="Epoch", ylabel="Loss"
43-
)
44-
plt.savefig(op.join(log_dir, "loss.pdf"))
45-
46-
df_metrics[["train_acc", "val_acc"]].plot(
47-
grid=True, legend=True, xlabel="Epoch", ylabel="Accuracy"
48-
)
49-
plt.savefig(op.join(log_dir, "acc.pdf"))
50-
51-
5228
class LightningModel(L.LightningModule):
5329
def __init__(self, model, learning_rate=5e-5):
5430
super().__init__()
@@ -207,10 +183,4 @@ def configure_optimizers(self):
207183

208184
with open(op.join(trainer.logger.log_dir, "outputs.txt"), "w") as f:
209185
f.write((f"Time elapsed {elapsed/60:.2f} min\n"))
210-
f.write(f"Test acc: {test_acc}")
211-
212-
#########################################
213-
### 6 Plot logs
214-
#########################################
215-
216-
plot_logs(trainer.logger.log_dir)
186+
f.write(f"Test acc: {test_acc}")

0 commit comments

Comments
 (0)