Skip to content

Commit

Permalink
code cleanup benchmarking
Browse files Browse the repository at this point in the history
  • Loading branch information
shayansiddiqui committed Oct 30, 2018
1 parent 448a871 commit eb029ed
Show file tree
Hide file tree
Showing 13 changed files with 190 additions and 657 deletions.
5 changes: 5 additions & 0 deletions experiments/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#Everything
*

#Except this file
! .gitignore
25 changes: 11 additions & 14 deletions quicknat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""ClassificationCNN"""
import numpy as np
import torch
import torch.nn as nn
import numpy as np
from nn_common_modules import modules as sm


Expand Down Expand Up @@ -60,11 +60,11 @@ def forward(self, input):

def enable_test_dropout(self):
attr_dict = self.__dict__['_modules']
for i in range(1,5):
encode_block, decode_block = attr_dict['encode'+str(i)], attr_dict['decode'+str(i)]
for i in range(1, 5):
encode_block, decode_block = attr_dict['encode' + str(i)], attr_dict['decode' + str(i)]
encode_block.drop_out = encode_block.drop_out.apply(nn.Module.train)
decode_block.drop_out = decode_block.drop_out.apply(nn.Module.train)

@property
def is_cuda(self):
"""
Expand All @@ -83,30 +83,27 @@ def save(self, path):
print('Saving model... %s' % path)
torch.save(self, path)

def predict(self, X, device = 0, enable_dropout = False):
def predict(self, X, device=0, enable_dropout=False):
"""
Predicts the outout after the model is trained.
Inputs:
- X: Volume to be predicted
"""
"""
self.eval()

if type(X) is np.ndarray:
X = torch.tensor(X, requires_grad = False).type(torch.FloatTensor).cuda(device, non_blocking=True)
X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor).cuda(device, non_blocking=True)
elif type(X) is torch.Tensor and not X.is_cuda:
X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)

if enable_dropout:
self.enable_test_dropout()
with torch.no_grad():

with torch.no_grad():
out = self.forward(X)
max_val, idx = torch.max(out,1)

max_val, idx = torch.max(out, 1)
idx = idx.data.cpu().numpy()
prediction = np.squeeze(idx)
del X, out, idx, max_val
return prediction



1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
matplotlib
numpy
torch
torchvision
Expand Down
Loading

0 comments on commit eb029ed

Please sign in to comment.