 import lightning as L
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.loggers import CSVLogger
-import matplotlib.pyplot as plt
-import pandas as pd
 import torch
 from torch.utils.data import DataLoader
 import torchmetrics
@@ -27,28 +25,6 @@ def tokenize_text(batch):
     return tokenizer(batch["text"], truncation=True, padding=True)
 
 
-def plot_logs(log_dir):
-    metrics = pd.read_csv(op.join(log_dir, "metrics.csv"))
-
-    aggreg_metrics = []
-    agg_col = "epoch"
-    for i, dfg in metrics.groupby(agg_col):
-        agg = dict(dfg.mean())
-        agg[agg_col] = i
-        aggreg_metrics.append(agg)
-
-    df_metrics = pd.DataFrame(aggreg_metrics)
-    df_metrics[["train_loss", "val_loss"]].plot(
-        grid=True, legend=True, xlabel="Epoch", ylabel="Loss"
-    )
-    plt.savefig(op.join(log_dir, "loss.pdf"))
-
-    df_metrics[["train_acc", "val_acc"]].plot(
-        grid=True, legend=True, xlabel="Epoch", ylabel="Accuracy"
-    )
-    plt.savefig(op.join(log_dir, "acc.pdf"))
-
-
 class LightningModel(L.LightningModule):
     def __init__(self, model, learning_rate=5e-5):
         super().__init__()
@@ -207,10 +183,4 @@ def configure_optimizers(self):
 
     with open(op.join(trainer.logger.log_dir, "outputs.txt"), "w") as f:
         f.write((f"Time elapsed {elapsed/60:.2f} min\n"))
-        f.write(f"Test acc: {test_acc}")
-
-    #########################################
-    ### 6 Plot logs
-    #########################################
-
-    plot_logs(trainer.logger.log_dir)
+        f.write(f"Test acc: {test_acc}")