1
import lightning as L
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn.functional as F
import torchmetrics
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchvision import datasets, transforms
10
+
11
+
12
class LightningModel(L.LightningModule):
    """Generic classification wrapper around an arbitrary PyTorch model.

    Logs cross-entropy loss plus multiclass accuracy for the train,
    validation, and test stages.
    """

    def __init__(self, model, learning_rate, num_classes=10):
        """
        Args:
            model: the ``torch.nn.Module`` to train; invoked as ``model(x)``.
            learning_rate: learning rate for the SGD optimizer created in
                ``configure_optimizers``.
            num_classes: number of target classes for the accuracy metrics.
                Defaults to 10, matching the previously hard-coded value.
        """
        super().__init__()

        self.learning_rate = learning_rate
        self.model = model

        # Don't serialize the full model into hparams; only the scalars.
        self.save_hyperparameters(ignore=["model"])

        # One metric object per stage so Lightning can aggregate and reset
        # each stage independently.
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    def _shared_step(self, batch):
        """Compute loss and hard predictions for one (features, labels) batch."""
        features, true_labels = batch
        logits = self(features)

        loss = F.cross_entropy(logits, true_labels)
        predicted_labels = torch.argmax(logits, dim=1)
        return loss, true_labels, predicted_labels

    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.log("train_loss", loss)
        # Passing the metric object itself to self.log lets Lightning handle
        # epoch-level accumulation and reset.
        self.train_acc(predicted_labels, true_labels)
        self.log(
            "train_acc", self.train_acc, prog_bar=True, on_epoch=True, on_step=False
        )
        return loss

    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.log("val_loss", loss, prog_bar=True)
        self.val_acc(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc)

    def configure_optimizers(self):
        # Plain SGD, no momentum or schedule (matches original behavior).
        optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
        return optimizer
61
+
62
+
63
class Cifar10DataModule(L.LightningDataModule):
    """LightningDataModule for CIFAR-10: download, 45k/5k split, dataloaders."""

    def __init__(
        self, data_path="./", batch_size=64, num_workers=0, height_width=(32, 32),
        train_transform=None, test_transform=None
    ):
        """
        Args:
            data_path: root directory for the CIFAR-10 download.
            batch_size: batch size used by all three dataloaders.
            num_workers: worker processes per dataloader.
            height_width: (H, W) the images are resized to by the default
                transforms; ignored if explicit transforms are given.
            train_transform: transform for the train/validation data; a
                Resize+ToTensor pipeline is used when None.
            test_transform: transform for the test data; same default.
        """
        super().__init__()
        self.batch_size = batch_size
        self.data_path = data_path
        self.num_workers = num_workers
        self.height_width = height_width

        # Resolve default transforms HERE rather than in prepare_data():
        # Lightning calls prepare_data() on a single process only, so
        # attribute state set there would not reach the other ranks in
        # distributed training.
        default_transform = transforms.Compose(
            [
                transforms.Resize(self.height_width),
                transforms.ToTensor(),
            ]
        )
        self.train_transform = train_transform if train_transform is not None else default_transform
        self.test_transform = test_transform if test_transform is not None else default_transform

    def prepare_data(self):
        # Download only; no state mutation (see note in __init__).
        datasets.CIFAR10(root=self.data_path, download=True)

    def setup(self, stage=None):
        train = datasets.CIFAR10(
            root=self.data_path,
            train=True,
            transform=self.train_transform,
            download=False,
        )

        self.test = datasets.CIFAR10(
            root=self.data_path,
            train=False,
            transform=self.test_transform,
            download=False,
        )

        # NOTE(review): the split is unseeded, so train/valid membership
        # differs between runs; pass generator=torch.Generator().manual_seed(...)
        # if reproducible splits are needed.
        self.train, self.valid = random_split(train, lengths=[45000, 5000])

    def train_dataloader(self):
        # drop_last avoids a small trailing batch that would skew BatchNorm
        # statistics and per-step metrics.
        return DataLoader(
            dataset=self.train,
            batch_size=self.batch_size,
            drop_last=True,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.valid,
            batch_size=self.batch_size,
            drop_last=False,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.test,
            batch_size=self.batch_size,
            drop_last=False,
            shuffle=False,
            num_workers=self.num_workers,
        )
142
+
143
+
144
def plot_val_acc(
        log_dir, acc_ylim=(0.5, 1.0), save_loss=None, save_acc=None):
    """Plot per-epoch validation accuracy from a Lightning CSVLogger run.

    Args:
        log_dir: directory containing the logger's ``metrics.csv``.
        acc_ylim: (low, high) y-axis limits for the accuracy plot.
        save_loss: unused; kept only for signature compatibility with
            ``plot_loss_and_acc``.
        save_acc: if not None, path the accuracy figure is saved to.
    """
    metrics = pd.read_csv(f"{log_dir}/metrics.csv")

    # CSVLogger writes one row per logging step; collapse to one row per
    # epoch by averaging. mean() skips the NaNs left by steps that did not
    # log a given metric.
    df_metrics = metrics.groupby("epoch", as_index=False).mean()

    df_metrics[["val_acc"]].plot(
        grid=True, legend=True, xlabel="Epoch", ylabel="ACC"
    )

    plt.ylim(acc_ylim)
    if save_acc is not None:
        plt.savefig(save_acc)
165
+
166
+
167
def plot_loss_and_acc(
    log_dir, loss_ylim=(0.0, 0.9), acc_ylim=(0.3, 1.0), save_loss=None, save_acc=None
):
    """Plot per-epoch training loss and train/val accuracy from CSVLogger.

    Args:
        log_dir: directory containing the logger's ``metrics.csv``.
        loss_ylim: (low, high) y-axis limits for the loss plot.
        acc_ylim: (low, high) y-axis limits for the accuracy plot.
        save_loss: if not None, path the loss figure is saved to.
        save_acc: if not None, path the accuracy figure is saved to.
    """
    metrics = pd.read_csv(f"{log_dir}/metrics.csv")

    # One row per epoch: average the per-step rows; NaNs (metrics not
    # logged at a given step) are skipped by mean().
    df_metrics = metrics.groupby("epoch", as_index=False).mean()

    df_metrics[["train_loss"]].plot(
        grid=True, legend=True, xlabel="Epoch", ylabel="Loss"
    )

    plt.ylim(loss_ylim)
    if save_loss is not None:
        plt.savefig(save_loss)

    df_metrics[["train_acc", "val_acc"]].plot(
        grid=True, legend=True, xlabel="Epoch", ylabel="ACC"
    )

    plt.ylim(acc_ylim)
    if save_acc is not None:
        plt.savefig(save_acc)