Skip to content

Commit b0a156d

Browse files
committed
Remove deprecated commands, make valid set use test transform
1 parent 6256347 commit b0a156d

File tree

2 files changed

+56
-62
lines changed

2 files changed

+56
-62
lines changed

demo.py

Lines changed: 54 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,17 @@ def train_epoch(model, loader, optimizer, epoch, n_epochs, print_freq=1):
3939
for batch_idx, (input, target) in enumerate(loader):
4040
# Create variables
4141
if torch.cuda.is_available():
42-
input_var = torch.autograd.Variable(input.cuda(async=True))
43-
target_var = torch.autograd.Variable(target.cuda(async=True))
44-
else:
45-
input_var = torch.autograd.Variable(input)
46-
target_var = torch.autograd.Variable(target)
42+
input = input.cuda()
43+
target = target.cuda()
4744

4845
# compute output
49-
output = model(input_var)
50-
loss = torch.nn.functional.cross_entropy(output, target_var)
46+
output = model(input)
47+
loss = torch.nn.functional.cross_entropy(output, target)
5148

5249
# measure accuracy and record loss
5350
batch_size = target.size(0)
5451
_, pred = output.data.cpu().topk(1, dim=1)
55-
error.update(torch.ne(pred.squeeze(), target.cpu()).float().sum() / batch_size, batch_size)
52+
error.update(torch.ne(pred.squeeze(), target.cpu()).float().sum().item() / batch_size, batch_size)
5653
losses.update(loss.item(), batch_size)
5754

5855
# compute gradient and do SGD step
@@ -88,70 +85,57 @@ def test_epoch(model, loader, print_freq=1, is_test=True):
8885
model.eval()
8986

9087
end = time.time()
91-
for batch_idx, (input, target) in enumerate(loader):
92-
# Create variables
93-
if torch.cuda.is_available():
94-
input_var = torch.autograd.Variable(input.cuda(async=True), volatile=True)
95-
target_var = torch.autograd.Variable(target.cuda(async=True), volatile=True)
96-
else:
97-
input_var = torch.autograd.Variable(input, volatile=True)
98-
target_var = torch.autograd.Variable(target, volatile=True)
99-
100-
# compute output
101-
output = model(input_var)
102-
loss = torch.nn.functional.cross_entropy(output, target_var)
103-
104-
# measure accuracy and record loss
105-
batch_size = target.size(0)
106-
_, pred = output.data.cpu().topk(1, dim=1)
107-
error.update(torch.ne(pred.squeeze(), target.cpu()).float().sum() / batch_size, batch_size)
108-
losses.update(loss.data[0], batch_size)
109-
110-
# measure elapsed time
111-
batch_time.update(time.time() - end)
112-
end = time.time()
113-
114-
# print stats
115-
if batch_idx % print_freq == 0:
116-
res = '\t'.join([
117-
'Test' if is_test else 'Valid',
118-
'Iter: [%d/%d]' % (batch_idx + 1, len(loader)),
119-
'Time %.3f (%.3f)' % (batch_time.val, batch_time.avg),
120-
'Loss %.4f (%.4f)' % (losses.val, losses.avg),
121-
'Error %.4f (%.4f)' % (error.val, error.avg),
122-
])
123-
print(res)
88+
with torch.no_grad():
89+
for batch_idx, (input, target) in enumerate(loader):
90+
# Create variables
91+
if torch.cuda.is_available():
92+
input = input.cuda()
93+
target = target.cuda()
94+
95+
# compute output
96+
output = model(input)
97+
loss = torch.nn.functional.cross_entropy(output, target)
98+
99+
# measure accuracy and record loss
100+
batch_size = target.size(0)
101+
_, pred = output.data.cpu().topk(1, dim=1)
102+
error.update(torch.ne(pred.squeeze(), target.cpu()).float().sum().item() / batch_size, batch_size)
103+
losses.update(loss.item(), batch_size)
104+
105+
# measure elapsed time
106+
batch_time.update(time.time() - end)
107+
end = time.time()
108+
109+
# print stats
110+
if batch_idx % print_freq == 0:
111+
res = '\t'.join([
112+
'Test' if is_test else 'Valid',
113+
'Iter: [%d/%d]' % (batch_idx + 1, len(loader)),
114+
'Time %.3f (%.3f)' % (batch_time.val, batch_time.avg),
115+
'Loss %.4f (%.4f)' % (losses.val, losses.avg),
116+
'Error %.4f (%.4f)' % (error.val, error.avg),
117+
])
118+
print(res)
124119

125120
# Return summary statistics
126121
return batch_time.avg, losses.avg, error.avg
127122

128123

129-
def train(model, train_set, test_set, save, n_epochs=300, valid_size=5000,
124+
def train(model, train_set, valid_set, test_set, save, n_epochs=300,
130125
batch_size=64, lr=0.1, wd=0.0001, momentum=0.9, seed=None):
131126
if seed is not None:
132127
torch.manual_seed(seed)
133128

134-
# Create train/valid split
135-
if valid_size:
136-
indices = torch.randperm(len(train_set))
137-
train_indices = indices[:len(indices) - valid_size]
138-
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indices)
139-
valid_indices = indices[len(indices) - valid_size:]
140-
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indices)
141-
142129
# Data loaders
130+
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
131+
pin_memory=(torch.cuda.is_available()), num_workers=0)
143132
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False,
144133
pin_memory=(torch.cuda.is_available()), num_workers=0)
145-
if valid_size:
146-
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, sampler=train_sampler,
147-
pin_memory=(torch.cuda.is_available()), num_workers=0)
148-
valid_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, sampler=valid_sampler,
149-
pin_memory=(torch.cuda.is_available()), num_workers=0)
134+
if valid_set is None:
135+
valid_loader = None
150136
else:
151-
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
137+
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=False,
152138
pin_memory=(torch.cuda.is_available()), num_workers=0)
153-
valid_loader = None
154-
155139
# Model on cuda
156140
if torch.cuda.is_available():
157141
model = model.cuda()
@@ -264,6 +248,16 @@ def demo(data, save, depth=100, growth_rate=12, efficient=True, valid_size=5000,
264248
train_set = datasets.CIFAR10(data, train=True, transform=train_transforms, download=True)
265249
test_set = datasets.CIFAR10(data, train=False, transform=test_transforms, download=False)
266250

251+
if valid_size:
252+
valid_set = datasets.CIFAR10(data, train=True, transform=test_transforms)
253+
indices = torch.randperm(len(train_set))
254+
train_indices = indices[:len(indices) - valid_size]
255+
valid_indices = indices[len(indices) - valid_size:]
256+
train_set = torch.utils.data.Subset(train_set, train_indices)
257+
valid_set = torch.utils.data.Subset(valid_set, valid_indices)
258+
else:
259+
valid_set = None
260+
267261
# Models
268262
model = DenseNet(
269263
growth_rate=growth_rate,
@@ -281,8 +275,8 @@ def demo(data, save, depth=100, growth_rate=12, efficient=True, valid_size=5000,
281275
raise Exception('%s is not a dir' % save)
282276

283277
# Train the model
284-
train(model=model, train_set=train_set, test_set=test_set, save=save,
285-
valid_size=valid_size, n_epochs=n_epochs, batch_size=batch_size, seed=seed)
278+
train(model=model, train_set=train_set, valid_set=valid_set, test_set=test_set, save=save,
279+
n_epochs=n_epochs, batch_size=batch_size, seed=seed)
286280
print('Done!')
287281

288282

models/densenet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, efficien
2323
super(_DenseLayer, self).__init__()
2424
self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
2525
self.add_module('relu1', nn.ReLU(inplace=True)),
26-
self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
27-
growth_rate, kernel_size=1, stride=1, bias=False)),
26+
self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate,
27+
kernel_size=1, stride=1, bias=False)),
2828
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
2929
self.add_module('relu2', nn.ReLU(inplace=True)),
3030
self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,

0 commit comments

Comments
 (0)