module.py
import pytorch_lightning as pl
import torch

# `pytorch_lightning.metrics` existed in older Lightning releases; in current
# releases the same Accuracy metric lives in the separate `torchmetrics` package.
from pytorch_lightning.metrics import Accuracy

from cifar10_models.densenet import densenet121, densenet161, densenet169
from cifar10_models.googlenet import googlenet
from cifar10_models.inception import inception_v3
from cifar10_models.mobilenetv2 import mobilenet_v2
from cifar10_models.resnet import resnet18, resnet34, resnet50
from cifar10_models.vgg import vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn

# The module is spelled "schduler" in this repository; the import must match
# the actual filename.
from schduler import WarmupCosineLR
# Registry mapping the `classifier` hyperparameter to a ready-made model.
all_classifiers = {
    "vgg11_bn": vgg11_bn(),
    "vgg13_bn": vgg13_bn(),
    "vgg16_bn": vgg16_bn(),
    "vgg19_bn": vgg19_bn(),
    "resnet18": resnet18(),
    "resnet34": resnet34(),
    "resnet50": resnet50(),
    "densenet121": densenet121(),
    "densenet161": densenet161(),
    "densenet169": densenet169(),
    "mobilenet_v2": mobilenet_v2(),
    "googlenet": googlenet(),
    "inception_v3": inception_v3(),
}
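
# Design note (added): building every architecture eagerly keeps the lookup
# above simple, but it allocates all thirteen networks when only one is used.
# A lazy variant would store the constructors instead and instantiate on
# demand, e.g. {"resnet18": resnet18, ...} followed by registry[name]() in
# __init__. The eager dict is kept as-is to match the original file.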


class CIFAR10Module(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        # Direct assignment works on the Lightning version this file targets;
        # newer releases expect self.save_hyperparameters(hparams) instead.
        self.hparams = hparams
        self.criterion = torch.nn.CrossEntropyLoss()
        self.accuracy = Accuracy()
        # Pick the (already constructed) network named by the hyperparameters.
        self.model = all_classifiers[self.hparams.classifier]

    def forward(self, batch):
        # Unconventionally, forward takes a whole (images, labels) batch and
        # returns loss and accuracy rather than raw logits.
        images, labels = batch
        predictions = self.model(images)
        loss = self.criterion(predictions, labels)
        accuracy = self.accuracy(predictions, labels)
        return loss, accuracy * 100  # accuracy reported as a percentage

    def training_step(self, batch, batch_nb):
        loss, accuracy = self.forward(batch)
        self.log("loss/train", loss)
        self.log("acc/train", accuracy)
        return loss

    def validation_step(self, batch, batch_nb):
        loss, accuracy = self.forward(batch)
        self.log("loss/val", loss)
        self.log("acc/val", accuracy)

    def test_step(self, batch, batch_nb):
        # Only accuracy is reported at test time; the loss is discarded.
        _, accuracy = self.forward(batch)
        self.log("acc/test", accuracy)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
            momentum=0.9,
            nesterov=True,
        )
        # self.train_dataloader() is resolved by the Trainer once a dataloader
        # or datamodule is attached, so this counts steps over the full run.
        total_steps = self.hparams.max_epochs * len(self.train_dataloader())
        scheduler = {
            # Warm up over the first 30% of steps, then cosine-anneal. Despite
            # the "epochs" parameter names, the schedule advances once per
            # optimizer step because of "interval": "step".
            "scheduler": WarmupCosineLR(
                optimizer, warmup_epochs=total_steps * 0.3, max_epochs=total_steps
            ),
            "interval": "step",
            "name": "learning_rate",
        }
        return [optimizer], [scheduler]
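

# --- Usage sketch (added; not part of the original file) ---
# A minimal, hedged example of how this module could be trained. It assumes
# hyperparameters arrive as an argparse.Namespace carrying the fields read
# above (classifier, learning_rate, weight_decay, max_epochs), and it targets
# the older pytorch_lightning API implied by the imports (e.g. `gpus=`).
# `CIFAR10Data` stands in for whatever LightningDataModule supplies the
# CIFAR-10 dataloaders; the name is illustrative, not this repo's API.
if __name__ == "__main__":
    from argparse import Namespace

    hparams = Namespace(
        classifier="resnet18",
        learning_rate=1e-2,
        weight_decay=1e-2,
        max_epochs=100,
    )
    model = CIFAR10Module(hparams)
    trainer = pl.Trainer(max_epochs=hparams.max_epochs, gpus=1)
    # trainer.fit(model, CIFAR10Data(hparams))  # attach a datamodule and train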