The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
Below are 6 random images with their respective labels:
The Python package torchvision provides a ready-made CIFAR10 dataset class and image transforms. The dataset can be wrapped in torch.utils.data.DataLoader to get shuffled mini-batches.
Below is an example of how to load the CIFAR10 dataset using torchvision:
import torch
import torchvision
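# Load the CIFAR10 training data
train_dataset = torchvision.datasets.CIFAR10(
    root='./train_data', train=True, download=True,
    transform=torchvision.transforms.ToTensor())  # convert PIL images to tensors so they can be batched
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)

Requirements:
- Python>=3.6
- PyTorch>=1.4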
- Other libraries are listed in requirenments.txt
I used a pretrained resnet18 for model training. You can use any other pretrained model that suits your problem.
import torchvision.models as models
# Called without arguments, these constructors return randomly initialized (not pretrained) models
alexnet = models.alexnet()
vgg16 = models.vgg16()
densenet = models.densenet161()
inception = models.inception_v3()

There are two ways to train the PyTorch model:
- Notebook - you can just download and play with it
- Python scripts:
# Start training with:
python main.py
# You can manually pass the attributes for the training:
python main.py --lr=0.01 --epoch 20 --model_path './cifar_model.pth'
# Start inference with:
python3.6 prediction.py --model_path './cifar_model.pth'
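
The source of main.py is not shown here, so the following is only a hypothetical sketch of how the flags used in the commands above (`--lr`, `--epoch`, `--model_path`) could be parsed with `argparse`. The function name `build_parser` and all default values are assumptions made for illustration, not the script's actual behaviour.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags shown in the commands above."""
    parser = argparse.ArgumentParser(description="CIFAR-10 training options (illustrative sketch)")
    # Defaults below are assumptions, not values taken from main.py.
    parser.add_argument("--lr", type=float, default=0.01, help="learning rate")
    parser.add_argument("--epoch", type=int, default=20, help="number of training epochs")
    parser.add_argument("--model_path", type=str, default="./cifar_model.pth",
                        help="where to save or load the model weights")
    return parser


# Parse the same arguments as in the second command above.
args = build_parser().parse_args(
    ["--lr=0.01", "--epoch", "20", "--model_path", "./cifar_model.pth"]
)
print(args.lr, args.epoch, args.model_path)
```

Both the `--lr=0.01` and `--epoch 20` spellings work because `argparse` accepts a long option's value either after `=` or as the next argument.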


