Skip to content

Commit 3b5b6e1

Browse files
committed
autoaugment experiment
1 parent 6f3a3d2 commit 3b5b6e1

File tree

4 files changed

+8398
-0
lines changed

4 files changed

+8398
-0
lines changed

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,13 @@ Please note that the following notebooks below provide reference implementations
288288

289289

290290

291+
## Data Augmentation
292+
293+
| Title | Dataset | Description | Notebooks |
294+
| -------------------------- | ------- | ----------- | ------------------------------------------------------------ |
295+
| AutoAugment for Image Data | CIFAR-10 | Trains a ResNet-18 using AutoAugment | [![PyTorch Lightning](https://img.shields.io/badge/PyTorch-Lightning-blueviolet)](pytorch-lightning_ipynb/data-augumentation/autoaugment/) |
296+
297+
291298

292299

293300
## Tips and Tricks
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
import lightning as L
2+
import matplotlib.pyplot as plt
3+
import pandas as pd
4+
import torch
5+
import torch.nn.functional as F
6+
import torchmetrics
7+
from torch.utils.data import DataLoader
8+
from torch.utils.data.dataset import random_split
9+
from torchvision import datasets, transforms
10+
11+
12+
class LightningModel(L.LightningModule):
    """Lightning wrapper that trains an arbitrary 10-class classifier.

    Tracks multiclass accuracy separately for the train/val/test splits and
    logs it through Lightning's logger.
    """

    def __init__(self, model, learning_rate):
        super().__init__()

        self.learning_rate = learning_rate
        self.model = model

        # Persist hyperparameters in checkpoints, but skip the (large) model.
        self.save_hyperparameters(ignore=["model"])

        # One metric object per split so Lightning aggregates them independently.
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        # Delegate straight to the wrapped model.
        return self.model(x)

    def _shared_step(self, batch):
        # Common forward + loss logic shared by all three step hooks.
        inputs, targets = batch
        logits = self(inputs)

        loss = F.cross_entropy(logits, targets)
        preds = torch.argmax(logits, dim=1)
        return loss, targets, preds

    def training_step(self, batch, batch_idx):
        loss, targets, preds = self._shared_step(batch)

        self.log("train_loss", loss)
        self.train_acc(preds, targets)
        self.log(
            "train_acc", self.train_acc, prog_bar=True, on_epoch=True, on_step=False
        )
        return loss

    def validation_step(self, batch, batch_idx):
        loss, targets, preds = self._shared_step(batch)

        self.log("val_loss", loss, prog_bar=True)
        self.val_acc(preds, targets)
        self.log("val_acc", self.val_acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        # Only accuracy is reported at test time; the loss is discarded.
        _, targets, preds = self._shared_step(batch)
        self.test_acc(preds, targets)
        self.log("test_acc", self.test_acc)

    def configure_optimizers(self):
        # Plain SGD, matching the original experiment setup.
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)
61+
62+
63+
class Cifar10DataModule(L.LightningDataModule):
    """LightningDataModule for CIFAR-10 with a 45,000/5,000 train/val split.

    Parameters:
        data_path: root directory where CIFAR-10 is (or will be) downloaded.
        batch_size: batch size used by all three dataloaders.
        num_workers: worker processes per dataloader.
        height_width: target (H, W) for the default ``Resize`` transform.
        train_transform: optional torchvision transform for the training set;
            when ``None``, a default Resize + ToTensor pipeline is used.
        test_transform: optional torchvision transform for the val/test sets;
            when ``None``, the same default pipeline is used.
    """

    def __init__(
        self, data_path="./", batch_size=64, num_workers=0, height_width=(32, 32),
        train_transform=None, test_transform=None
    ):
        super().__init__()
        self.batch_size = batch_size
        self.data_path = data_path
        self.num_workers = num_workers
        self.height_width = height_width

        # BUGFIX: the default transforms used to be assigned inside
        # prepare_data(). Lightning runs prepare_data() on a single process
        # only and documents that state must not be set there, so under
        # DDP every other rank would have kept transform=None. Assign the
        # defaults here instead; prepare_data() now only downloads.
        default_transform = transforms.Compose(
            [
                transforms.Resize(self.height_width),
                transforms.ToTensor(),
            ]
        )
        self.train_transform = (
            train_transform if train_transform is not None else default_transform
        )
        self.test_transform = (
            test_transform if test_transform is not None else default_transform
        )

    def prepare_data(self):
        # Download only — runs on a single process; no state assignment here.
        datasets.CIFAR10(root=self.data_path, download=True)

    def setup(self, stage=None):
        # Runs on every process; builds the actual Dataset objects.
        train = datasets.CIFAR10(
            root=self.data_path,
            train=True,
            transform=self.train_transform,
            download=False,
        )

        self.test = datasets.CIFAR10(
            root=self.data_path,
            train=False,
            transform=self.test_transform,
            download=False,
        )

        # Hold out 5,000 of the 50,000 training images for validation.
        self.train, self.valid = random_split(train, lengths=[45000, 5000])

    def train_dataloader(self):
        train_loader = DataLoader(
            dataset=self.train,
            batch_size=self.batch_size,
            drop_last=True,
            shuffle=True,
            num_workers=self.num_workers,
        )
        return train_loader

    def val_dataloader(self):
        valid_loader = DataLoader(
            dataset=self.valid,
            batch_size=self.batch_size,
            drop_last=False,
            shuffle=False,
            num_workers=self.num_workers,
        )
        return valid_loader

    def test_dataloader(self):
        test_loader = DataLoader(
            dataset=self.test,
            batch_size=self.batch_size,
            drop_last=False,
            shuffle=False,
            num_workers=self.num_workers,
        )
        return test_loader
142+
143+
144+
def plot_val_acc(
        log_dir, acc_ylim=(0.5, 1.0), save_loss=None, save_acc=None):
    """Plot per-epoch validation accuracy from a CSVLogger run.

    Parameters:
        log_dir: directory containing the CSVLogger ``metrics.csv``.
        acc_ylim: y-axis limits for the accuracy plot.
        save_loss: unused; accepted only for signature compatibility with
            ``plot_loss_and_acc`` (this function plots accuracy only).
        save_acc: optional path; when given, the figure is saved there.
    """
    metrics = pd.read_csv(f"{log_dir}/metrics.csv")

    # CSVLogger writes one row per logging step; average the numeric columns
    # within each epoch. Replaces the previous manual loop over
    # groupby(...) with a vectorized mean; numeric_only=True keeps this
    # from raising on any non-numeric column in the CSV.
    df_metrics = metrics.groupby("epoch").mean(numeric_only=True).reset_index()

    df_metrics[["val_acc"]].plot(
        grid=True, legend=True, xlabel="Epoch", ylabel="ACC"
    )

    plt.ylim(acc_ylim)
    if save_acc is not None:
        plt.savefig(save_acc)
165+
166+
167+
def plot_loss_and_acc(
    log_dir, loss_ylim=(0.0, 0.9), acc_ylim=(0.3, 1.0), save_loss=None, save_acc=None
):
    """Plot per-epoch training loss and train/val accuracy from a CSVLogger run.

    Parameters:
        log_dir: directory containing the CSVLogger ``metrics.csv``.
        loss_ylim: y-axis limits for the loss plot.
        acc_ylim: y-axis limits for the accuracy plot.
        save_loss: optional path; when given, the loss figure is saved there.
        save_acc: optional path; when given, the accuracy figure is saved there.
    """
    metrics = pd.read_csv(f"{log_dir}/metrics.csv")

    # Collapse the per-step logging rows to one averaged record per epoch.
    per_epoch = [
        {**dict(group.mean()), "epoch": epoch}
        for epoch, group in metrics.groupby("epoch")
    ]
    df_metrics = pd.DataFrame(per_epoch)

    df_metrics[["train_loss"]].plot(
        grid=True, legend=True, xlabel="Epoch", ylabel="Loss"
    )
    plt.ylim(loss_ylim)
    if save_loss is not None:
        plt.savefig(save_loss)

    df_metrics[["train_acc", "val_acc"]].plot(
        grid=True, legend=True, xlabel="Epoch", ylabel="ACC"
    )
    plt.ylim(acc_ylim)
    if save_acc is not None:
        plt.savefig(save_acc)

0 commit comments

Comments
 (0)