Skip to content

Commit f7bfcd3

Browse files
committed
Update 14_cnn.py
1 parent 6b55eca commit f7bfcd3

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

14_cnn.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch.nn as nn
33
import torch.nn.functional as F
44
import torchvision
5-
import torchvision.transforms as transforms
5+
import torchvision.transforms.v2 as transforms
66
import matplotlib.pyplot as plt
77
import numpy as np
88
import ssl
@@ -22,13 +22,15 @@
2222

2323
# Hyper-parameters
2424
num_epochs = 20
25+
num_epochs = 5
2526
batch_size = 8
2627
learning_rate = 0.001
2728

2829
# dataset has PILImage images of range [0, 1].
2930
# We transform them to Tensors of normalized range [-1, 1]
3031
transform = transforms.Compose(
31-
[transforms.ToTensor(),
32+
[transforms.ToImage(),
33+
transforms.ToDtype(torch.float32, scale=True),
3234
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
3335

3436
# CIFAR10: 60000 32x32 color images in 10 classes, with 6000 images per class
@@ -59,7 +61,7 @@ def imshow(img):
5961

6062
# show images
6163
# imshow(torchvision.utils.make_grid(images))
62-
64+
6365
class ConvNet(nn.Module):
6466
def __init__(self):
6567
super(ConvNet, self).__init__()

0 commit comments

Comments
 (0)