# Overview #

[![](/Images/Download-Button.png)](Image-Classification-with-PyTorch.tar.gz)

This example runs a Python script for image classification with PyTorch. The following sections explain the parts of the script, `fmnist.py`.

# Prerequisites #

To run this script you need PyTorch and its associated packages installed. We recommend setting up a Python virtual environment in which to install them. For in-depth information on how to set up a virtual environment for Python 3.8+, see [our online documentation](https://public.confluence.arizona.edu/display/UAHPC/Using+and+Installing+Python). You will also want to load the CUDA modules before installing PyTorch, so that PyTorch can use GPUs.

From an interactive session on El Gato:

``` console
[netid@gpu70 ~]$ module load python/3.8
[netid@gpu70 ~]$ module load cuda11/11.8 cuda11-dnn/8.9.2 cuda11-sdk/22.11
[netid@gpu70 ~]$ python3 -m venv --system-site-packages pyvenv
[netid@gpu70 ~]$ source pyvenv/bin/activate
(pyvenv) [netid@gpu70 ~]$ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
(pyvenv) [netid@gpu70 ~]$ pip install torcheval
```

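To verify that the GPU build of PyTorch installed correctly, you can run a quick check from the same interactive session. This snippet is only a sanity check and is not part of `fmnist.py`; the device name it prints will depend on the node you landed on.

``` python
import torch

# True if the CUDA runtime and at least one GPU are visible to PyTorch
print(torch.cuda.is_available())

# Name of the first visible GPU (only meaningful when the check above is True)
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
```
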
# Python script header #

``` python
#!/usr/bin/env python3

import time

import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader, random_split
from torcheval.metrics.functional import multiclass_accuracy
```

# The data #

In this example we look at the [Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset. Fashion-MNIST consists of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image associated with a label from 10 classes, each corresponding to an article of clothing.

The following image shows a sample of the data:

![fashion-mnist-data-example](fashion-mnist-sprite.png)

The labels are:

| Label | Description |
|-------|-------------|
| 0     | T-shirt/top |
| 1     | Trouser     |
| 2     | Pullover    |
| 3     | Dress       |
| 4     | Coat        |
| 5     | Sandal      |
| 6     | Shirt       |
| 7     | Sneaker     |
| 8     | Bag         |
| 9     | Ankle boot  |

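torchvision also exposes these label names on the dataset object, which is handy when printing predictions. A small illustration (not part of `fmnist.py`):

``` python
from torchvision import datasets

dset = datasets.FashionMNIST(root="data/", download=True, train=True)
print(dset.classes)     # ['T-shirt/top', 'Trouser', ..., 'Ankle boot']
print(dset.classes[9])  # 'Ankle boot'
```
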
The `datasets` module from PyTorch provides an API to download and transform the data. We will randomly select 20% of the training set to create a validation set. We will also resize the images, convert them to PyTorch tensors, and normalize them.

``` python
def get_dls(bs):
    root = "data/"  # This is where the data will be downloaded
    dsets = datasets.FashionMNIST(root=root, download=True, train=True)

    train_set, valid_set = random_split(dsets, [0.8, 0.2])

    mean = dsets.data[train_set.indices].float().mean()
    std = dsets.data[train_set.indices].float().std()

    tfms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((mean / 255,), (std / 255,))
    ])

    # Training and validation dataloaders which will load the data in batches.
    train_dl = DataLoader(
        DatasetFromSubset(train_set, transform=tfms),
        batch_size=bs,
        shuffle=True,
        **kwargs,
    )
    valid_dl = DataLoader(
        DatasetFromSubset(valid_set, transform=tfms),
        batch_size=2 * bs,
        shuffle=False,
        **kwargs,
    )

    # The test dataloader is created from the original test set. The model will not see
    # this data during training; we will use it only for a final evaluation.
    test_dl = DataLoader(
        datasets.FashionMNIST(root=root, download=True, train=False, transform=tfms),
        batch_size=2 * bs,
        shuffle=True,
        **kwargs,
    )
    return train_dl, valid_dl, test_dl

# Helper class to create a PyTorch Dataset from a Subset
class DatasetFromSubset(Dataset):
    def __init__(self, subset, transform=None):
        super().__init__()
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)
```

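With the definitions above in scope, a quick sanity check (not part of `fmnist.py`) is to build the dataloaders and look at one batch. Note that `get_dls` reads the module-level `kwargs` dictionary defined later in the script, so a standalone test has to provide it; the batch size of 64 here is just an example value.

``` python
kwargs = {}  # get_dls expects this module-level dict; empty is fine for a CPU-only test
train_dl, valid_dl, test_dl = get_dls(bs=64)

xb, yb = next(iter(train_dl))
print(xb.shape)  # torch.Size([64, 1, 224, 224]) -- one batch of resized grayscale images
print(yb.shape)  # torch.Size([64]) -- the corresponding labels
```
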
# The model #

We will use the ResNet-18 model from [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385). ResNet-18 is an 18-layer-deep convolutional neural network (CNN). We will not implement it from scratch; instead, we will use the version available in PyTorch that has been pretrained on the [ImageNet](https://www.image-net.org/challenges/LSVRC/index.php) dataset, which has 1000 classes. This is an example of transfer learning.

By default, the ResNet models expect three-channel RGB images. We will modify the input layer so that it accepts our single-channel grayscale images, and the output layer so that it outputs 10 classes instead of 1000.

``` python
def get_model():
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    model.fc = nn.Linear(model.fc.in_features, 10)
    return model
```

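With `get_model` defined, you can confirm the modified layers by passing a dummy single-channel batch through the network and checking the output shape. This is again only a sanity check, not part of `fmnist.py`, and assumes the imports from the script header.

``` python
model = get_model()
x = torch.randn(4, 1, 224, 224)  # a fake batch of 4 grayscale 224x224 images
out = model(x)
print(out.shape)                 # torch.Size([4, 10]) -- one logit per class
```
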
# Training and testing #

We define the following two functions to train and test the model. We use the `.to()` method to move the model and the data onto the GPU.

We also turn on the cuDNN auto-tuner so that it can select the best-performing kernels for computing the convolutions.

``` python
def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    start_time = time.time()
    for epoch in range(epochs):
        model.train()
        for data in train_dl:
            xb, yb = data[0].to(device), data[1].to(device)
            loss = loss_func(model(xb), yb)
            loss.backward()
            opt.step()
            opt.zero_grad()
        valid_loss, valid_acc = predict_stats(model, valid_dl)
        print(f"Epoch {epoch + 1}/{epochs}, Validation loss: {valid_loss:.4f}, Validation accuracy: {valid_acc:.4f}")
    print(f"Training time: {time.time() - start_time:.4f}s")
    return valid_loss, valid_acc

def predict_stats(model, dl):
    model.eval()
    if device == torch.device("cuda"):
        torch.cuda.empty_cache()
    with torch.no_grad():
        tot_loss = tot_acc = count = 0
        for data in dl:
            xb, yb = data[0].to(device), data[1].to(device)
            pred = model(xb)
            n = len(xb)
            count += n
            tot_loss += loss_func(pred, yb).item() * n
            tot_acc += multiclass_accuracy(pred, yb).item() * n
    return tot_loss / count, tot_acc / count


torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_model().to(device)
opt = torch.optim.Adadelta(model.parameters())
loss_func = nn.CrossEntropyLoss()

# For faster and asynchronous memory copies to the GPU
kwargs = {"num_workers": 1, "pin_memory": True} if device == torch.device("cuda") else {}

train_dl, valid_dl, test_dl = get_dls(bs=64)  # batch size of 64 (an example value; adjust as needed)
epochs = 5

fit(epochs, model, loss_func, opt, train_dl, valid_dl)
test_loss, test_acc = predict_stats(model, test_dl)
print(f"Test loss: {test_loss:.4f}, Test accuracy: {test_acc:.4f}")
```

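The script does not save the trained weights. If you want to reuse the model without retraining, you could add something like the following after the call to `fit`; this is a minimal sketch, and the filename is arbitrary.

``` python
# Save only the learned parameters (the recommended PyTorch pattern)
torch.save(model.state_dict(), "fmnist_resnet18.pth")

# Later, rebuild the architecture and load the weights back
model = get_model()
model.load_state_dict(torch.load("fmnist_resnet18.pth", map_location=device))
model.to(device)
model.eval()
```
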
# Submission script #

``` bash
#!/bin/bash

#SBATCH --job-name=fashion-mnist-test-run
#SBATCH --time=00:30:00
#SBATCH --ntasks-per-node=1
#SBATCH --mem-per-cpu=6G
#SBATCH --partition=standard
#SBATCH --account=<YOUR_GROUP>
#SBATCH --gres=gpu:1

module load python/3.8
module load cuda11/11.8 cuda11-dnn/8.9.2 cuda11-sdk/22.11

source pyvenv/bin/activate

python3 -u fmnist.py
```

# Submit the job #

Submit the job on Ocelote:

``` console
[netid@junonia ~] sbatch fmnist.slurm
Submitted batch job 2418647
```

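While the job is queued or running, you can check on it with standard Slurm commands (the exact output will vary):

``` console
[netid@junonia ~] squeue -j 2418647
[netid@junonia ~] squeue -u $USER
```
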
# Output #

``` console
[netid@wentletrap ~] cat slurm-2417845.out
Epoch 1/5, Validation loss: 0.2412, Validation accuracy: 0.9153
Epoch 2/5, Validation loss: 0.2098, Validation accuracy: 0.9277
Epoch 3/5, Validation loss: 0.2232, Validation accuracy: 0.9230
Epoch 4/5, Validation loss: 0.1790, Validation accuracy: 0.9393
Epoch 5/5, Validation loss: 0.1955, Validation accuracy: 0.9359
Training time: 461.9267s
Test loss: 0.1955, Test accuracy: 0.9359
Detailed performance metrics for this job will be available at https://metrics.hpc.arizona.edu/#job_viewer?action=show&realm=SUPREMM&resource_id=5&local_job_id=2417845 by 8am on 2023/09/01.
```

With just 5 epochs we reach about 93% accuracy on the test set. To improve the accuracy, you can try image augmentation, changing the output layers of the model, using larger models, and so on. Training takes around 8 minutes; you can try [tuning the performance](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html) to improve the training time.
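
As one example from the tuning guide, mixed-precision training often speeds up convolution-heavy models on recent GPUs. The sketch below shows how the inner loop of `fit` could be adapted; it assumes a CUDA device and leaves the rest of the script unchanged.

``` python
scaler = torch.cuda.amp.GradScaler()

for data in train_dl:
    xb, yb = data[0].to(device), data[1].to(device)
    with torch.cuda.amp.autocast():    # run the forward pass in mixed precision
        loss = loss_func(model(xb), yb)
    scaler.scale(loss).backward()      # scale the loss to avoid underflow in fp16 gradients
    scaler.step(opt)
    scaler.update()
    opt.zero_grad()
```
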
Machine-Learning-Examples/README.md

# Machine Learning Examples #

> :bulb: If you have never run a batch job before, see our [Quick Start Guide](https://public.confluence.arizona.edu/display/UAHPC/Puma+Quick+Start) for a walkthrough. We also have a [video recording of our Intro to HPC workshop](https://public.confluence.arizona.edu/display/UAHPC/Training#Training-IntroductiontoHPC) that goes over system use and batch scripts. Intro to HPC also comes with a [companion page](https://ua-researchcomputing-hpc.github.io/Intro-to-HPC/).

This page has a collection of machine learning examples that users may find helpful. For a basic introduction to machine learning on HPC, see our [machine learning workshop](/Intro-to-Machine-Learning).

## Examples ##

### [Image Classification with PyTorch](Image-Classification-with-PyTorch) ###

A basic image classification example with PyTorch to get started.

README.md

Added the Machine Learning Examples button to the front-page list of example buttons:

[![](/Images/ml-button.png)](Machine-Learning-Examples)
