
Commit

fix bug of model output layer; try to use cuda instead of cudnn (cudnn+float cannot train well); demo example uses a fixed random seed to fit ONE sample

dragonbook committed Dec 2, 2018
1 parent 05955f2 commit 348dcc0
Showing 4 changed files with 143 additions and 79 deletions.
6 changes: 4 additions & 2 deletions lib/solver.py
@@ -16,7 +16,8 @@ def train_epoch(model, criterion, optimizer, train_loader, device=torch.device('
optimizer.step()

train_loss += loss.item()
progress_bar(batch_idx, len(train_loader), 'Loss: {0:.7f}'.format(train_loss/(batch_idx+1)))
progress_bar(batch_idx, len(train_loader), 'Loss: {0:.4e}'.format(train_loss/(batch_idx+1)))
#print('loss: {0: .4e}'.format(train_loss/(batch_idx+1)))


def val_epoch(model, criterion, val_loader, device=torch.device('cuda'), dtype=torch.float):
@@ -30,4 +31,5 @@ def val_epoch(model, criterion, val_loader, device=torch.device('cuda'), dtype=t
loss = criterion(outputs, targets)

val_loss += loss.item()
progress_bar(batch_idx, len(val_loader), 'Loss: {0:.7f}'.format(val_loss/(batch_idx+1)))
progress_bar(batch_idx, len(val_loader), 'Loss: {0:.4e}'.format(val_loss/(batch_idx+1)))
#print('loss: {0: .4e}'.format(val_loss/(batch_idx+1)))
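
The format change from '{0:.7f}' to '{0:.4e}' matters once the running MSE becomes very small; a quick illustration using nothing beyond Python's str.format:

print('Loss: {0:.7f}'.format(3.2e-9))  # 'Loss: 0.0000000' -- the value disappears in fixed-point
print('Loss: {0:.4e}'.format(3.2e-9))  # 'Loss: 3.2000e-09' -- still readable in scientific notation
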
170 changes: 106 additions & 64 deletions src/v2v_main.py
@@ -2,6 +2,7 @@
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import sampler

from lib.solver import train_epoch, val_epoch
from lib.mesh_util import read_mesh_vertices
@@ -13,14 +14,21 @@
import numpy as np


#torch.random.seed(1)
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed(1)
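
Seeding numpy and torch (CPU and CUDA) as above pins down most of the randomness in this script; for reference, a fuller determinism sketch, assuming cudnn were left enabled (it is disabled further below) and that Python's random module were also in use:

import random

random.seed(1)
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed_all(1)              # seeds all GPUs, not just the current one
torch.backends.cudnn.deterministic = True  # only relevant while cudnn is enabled
torch.backends.cudnn.benchmark = False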


# Basic configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float
#dtype = torch.double


# Data configuration
print('==> Preparing data ..')
data_dir = r'/home/maiqi/yalong/dataset/cases-tooth-keypoints/D-11-15-aug/v2v-23-same-ori/split1'
data_dir = r'/home/yalong/yalong/project/KeyPointsEstimation/V2V-PoseNet-pytorch/experiments/tooth/exp1/split1/'
dataset_scale = 10
keypoints_num = 7

@@ -38,11 +46,24 @@ def to_tensor(x):
return torch.from_numpy(x)


voxelization_train = V2VVoxelization(augmentation=True)
voxelization_val = voxelization_train
voxelization_train = V2VVoxelization(augmentation=False)
voxelization_val = V2VVoxelization(augmentation=False)


class ChunkSampler(sampler.Sampler):
def __init__(self, num_samples, start=0):
self.num_samples = num_samples
self.start = start

def __iter__(self):
return iter(range(self.start, self.start + self.num_samples))

def __len__(self):
return self.num_samples
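
ChunkSampler simply yields the indices start .. start + num_samples - 1 in order, which is how the demo below pins training to the very first sample; a usage sketch (the loader name is illustrative, train_set is defined further down):

overfit_loader = torch.utils.data.DataLoader(
    train_set, batch_size=1, shuffle=False,
    sampler=ChunkSampler(num_samples=1, start=0))  # always the same single sample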


def transform_train(sample):
vertices, keypoints, refpoint = sample['vertices'], sample['keypoints'], sample['refpoint']
vertices, keypoints, refpoint = sample['vertices'].copy(), sample['keypoints'].copy(), sample['refpoint'].copy()
assert(keypoints.shape[0] == keypoints_num)

vertices, keypoints, refpoint = apply_dataset_scale((vertices, keypoints, refpoint))
@@ -52,7 +73,8 @@ def transform_train(sample):


def transform_val(sample):
vertices, keypoints, refpoint = sample['vertices'], sample['keypoints'], sample['refpoint']
#vertices, keypoints, refpoint = sample['vertices'], sample['keypoints'], sample['refpoint']
vertices, keypoints, refpoint = sample['vertices'].copy(), sample['keypoints'].copy(), sample['refpoint'].copy()
assert(keypoints.shape[0] == keypoints_num)

vertices, keypoints, refpoint = apply_dataset_scale((vertices, keypoints, refpoint))
@@ -61,105 +83,125 @@ def transform_val(sample):
return (to_tensor(input), to_tensor(heatmap))


# Datasets
train_set = Tooth13Dataset(root=data_dir, mode='train', transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=True, num_workers=6)
train_num = 1
train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=False, num_workers=6,sampler=ChunkSampler(train_num, 0))
#train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=True, num_workers=6)

val_set = Tooth13Dataset(root=data_dir, mode='val', transform=transform_val)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=1, shuffle=True, num_workers=6)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=1, shuffle=False, num_workers=6)


# Model, criterion and optimizer
net = V2VModel(input_channels=1, output_channels=keypoints_num)

net = net.to(device)
net = net.to(device, dtype)
if device == 'cuda':
cudnn.benchmark = True
torch.backends.cudnn.enabled = False
#cudnn.benchmark = True
#cudnn.deterministic = True
print('backends: ', torch.backends.cudnn.enabled)
print('version: ', torch.backends.cudnn.version())


class Criterion(nn.Module):
def __init__(self):
super(Criterion, self).__init__()

def forward(self, outputs, targets):
# Assume batch = 1
return ((outputs - targets)**2).mean()


criterion = nn.MSELoss()
optimizer = optim.RMSprop(net.parameters(), lr=2.5e-4)
#criterion = Criterion()
#optimizer = optim.RMSprop(net.parameters(), lr=2.5e-4)
optimizer = optim.Adam(net.parameters())
#optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
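
The commented-out Criterion computes the same quantity as nn.MSELoss with its default mean reduction; a quick check using names already defined above (the 32^3 heatmap resolution is purely illustrative):

x = torch.randn(1, keypoints_num, 32, 32, 32)
y = torch.randn(1, keypoints_num, 32, 32, 32)
assert torch.allclose(nn.MSELoss()(x, y), ((x - y) ** 2).mean())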


## Train and validate
print('Start train ..')
for epoch in range(200):
print('Epoch: {}'.format(epoch))
train_epoch(net, criterion, optimizer, train_loader, device=device, dtype=dtype)
val_epoch(net, criterion, val_loader, device=device, dtype=dtype)
#val_epoch(net, criterion, val_loader, device=device, dtype=dtype)


## Test
def test(model, test_loader, output_transform, device=torch.device('cuda'), dtype=torch.float):
model.eval()
# Test
# def test(model, test_loader, output_transform, device=torch.device('cuda'), dtype=torch.float):
# model.eval()

samples_num = len(test_loader)
keypoints = None
idx = 0
# samples_num = len(test_loader)
# keypoints = None
# idx = 0

with torch.no_grad():
for batch_idx, (inputs, refpoints) in enumerate(test_loader):
outputs = model(inputs.to(device, dtype))
# with torch.no_grad():
# for batch_idx, (inputs, refpoints) in enumerate(test_loader):
# outputs = model(inputs.to(device, dtype))

outputs = outputs.cpu().numpy()
refpoints = refpoints.cpu().numpy()
# outputs = outputs.cpu().numpy()
# refpoints = refpoints.cpu().numpy()

# (batch, keypoints_num, 3)
keypoints_batch = output_transform((outputs, refpoints))
# # (batch, keypoints_num, 3)
# keypoints_batch = output_transform((outputs, refpoints))

if keypoints is None:
# Initialize keypoints once dimensions are available
keypoints = np.zeros((samples_num, *keypoints_batch.shape[1:]))
# if keypoints is None:
# # Initialize keypoints once dimensions are available
# keypoints = np.zeros((samples_num, *keypoints_batch.shape[1:]))

batch_size = keypoints_batch.shape[0]
keypoints[idx:idx+batch_size] = keypoints_batch
idx += batch_size
# batch_size = keypoints_batch.shape[0]
# keypoints[idx:idx+batch_size] = keypoints_batch
# idx += batch_size


return keypoints
# return keypoints


def remove_dataset_scale(x):
if isinstance(x, tuple):
for e in x: e /= dataset_scale
else: x /= dataset_scale
# def remove_dataset_scale(x):
# if isinstance(x, tuple):
# for e in x: e /= dataset_scale
# else: x /= dataset_scale

return x
# return x


voxelization_test = voxelization_train
# voxelization_test = voxelization_train

def output_transform(x):
heatmaps, refpoints = x
keypoints = voxelization_test.evaluate(heatmaps, refpoints)
return remove_dataset_scale(keypoints)
# def output_transform(x):
# heatmaps, refpoints = x
# keypoints = voxelization_test.evaluate(heatmaps, refpoints)
# return remove_dataset_scale(keypoints)


def transform_test(sample):
vertices, refpoint = sample['vertices'], sample['refpoint']
vertices, refpoint = apply_dataset_scale((vertices, refpoint))
input = voxelization_test.voxelize(vertices, refpoint)
return to_tensor(input), to_tensor(refpoint.reshape((1, -1)))
# def transform_test(sample):
# vertices, refpoint = sample['vertices'], sample['refpoint']
# vertices, refpoint = apply_dataset_scale((vertices, refpoint))
# input = voxelization_test.voxelize(vertices, refpoint)
# return to_tensor(input), to_tensor(refpoint.reshape((1, -1)))


test_set = Tooth13Dataset(root=data_dir, mode='test', transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False, num_workers=6)
# test_set = Tooth13Dataset(root=data_dir, mode='test', transform=transform_test)
# test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False, num_workers=6)

print('Start test ..')
keypoints_estimate = test(net, test_loader, output_transform, device, dtype)
# print('Start test ..')
# keypoints_estimate = test(net, test_loader, output_transform, device, dtype)

test_res_filename = r'./test_res.txt'
print('Write result to ', test_res_filename)
# Reshape one sample keypoints in one line
result = keypoints_estimate.reshape(keypoints_estimate.shape[0], -1)
np.savetxt(test_res_filename, result, fmt='%0.4f')
# test_res_filename = r'./test_res.txt'
# print('Write result to ', test_res_filename)
# # Reshape one sample keypoints in one line
# result = keypoints_estimate.reshape(keypoints_estimate.shape[0], -1)
# np.savetxt(test_res_filename, result, fmt='%0.4f')


print('Start save fit ..')
fit_set = Tooth13Dataset(root=data_dir, mode='train', transform=transform_test)
fit_loader = torch.utils.data.DataLoader(fit_set, batch_size=1, shuffle=False, num_workers=6)
keypoints_fit = test(net, fit_loader, output_transform)
fit_res_filename = r'./fit_res.txt'
print('Write fit result to ', fit_res_filename)
fit_result = keypoints_fit.reshape(keypoints_fit.shape[0], -1)
np.savetxt(fit_res_filename, fit_result, fmt='%0.4f')
# print('Start save fit ..')
# fit_set = Tooth13Dataset(root=data_dir, mode='train', transform=transform_test)
# fit_loader = torch.utils.data.DataLoader(fit_set, batch_size=1, shuffle=False, num_workers=6)
# keypoints_fit = test(net, fit_loader, output_transform, device=device, dtype=dtype)
# fit_res_filename = r'./fit_res.txt'
# print('Write fit result to ', fit_res_filename)
# fit_result = keypoints_fit.reshape(keypoints_fit.shape[0], -1)
# np.savetxt(fit_res_filename, fit_result, fmt='%0.4f')

print('All done ..')
# print('All done ..')
41 changes: 28 additions & 13 deletions src/v2v_model.py
@@ -51,23 +51,38 @@ def __init__(self, pool_size):
self.pool_size = pool_size

def forward(self, x):
return F.max_pool3d(x, self.pool_size, self.pool_size)
return F.max_pool3d(x, kernel_size=self.pool_size, stride=self.pool_size)


# class Upsample3DBlock(nn.Module):
# '''
# Note: the original Torch implementation could be written in PyTorch as
# 'upsample = nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2, padding=0.5, output_padding=1)',
# but the padding value is a float, which raises an error under pytorch-0.4.1 + python3.6
# (it may work with python2.7).
# So it is implemented with kernel = 3 instead.
# '''
# def __init__(self, in_planes, out_planes, kernel_size, stride):
# super(Upsample3DBlock, self).__init__()
# assert(kernel_size == 3)
# assert(stride == 2)
# self.block = nn.Sequential(
# nn.ConvTranspose3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=1, output_padding=1),
# nn.BatchNorm3d(out_planes),
# nn.ReLU(True)
# )

# def forward(self, x):
# return self.block(x)


class Upsample3DBlock(nn.Module):
'''
Note: the original Torch implementation could be written in PyTorch as
'upsample = nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2, padding=0.5, output_padding=1)',
but the padding value is a float, which raises an error under pytorch-0.4.1 + python3.6
(it may work with python2.7).
So it is implemented with kernel = 3 instead.
'''
def __init__(self, in_planes, out_planes, kernel_size, stride):
super(Upsample3DBlock, self).__init__()
assert(kernel_size == 3)
assert(kernel_size == 2)
assert(stride == 2)
self.block = nn.Sequential(
nn.ConvTranspose3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=1, output_padding=1),
nn.ConvTranspose3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=0, output_padding=0),
nn.BatchNorm3d(out_planes),
nn.ReLU(True)
)
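
Note that the docstring above still describes the kernel = 3 workaround, while the code now asserts kernel_size == 2 with padding=0. Both configurations double the spatial resolution, since out = (in - 1)*stride - 2*padding + kernel_size + output_padding; a quick shape check (channel counts and the 11^3 input are illustrative only):

import torch
import torch.nn as nn

up_old = nn.ConvTranspose3d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
up_new = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2, padding=0, output_padding=0)
x = torch.randn(1, 128, 11, 11, 11)
print(up_old(x).shape, up_new(x).shape)  # both torch.Size([1, 64, 22, 22, 22])
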
@@ -88,9 +103,9 @@ def __init__(self):
self.mid_res = Res3DBlock(128, 128)

self.decoder_res2 = Res3DBlock(128, 128)
self.decoder_upsample2 = Upsample3DBlock(128, 64, 3, 2)
self.decoder_upsample2 = Upsample3DBlock(128, 64, 2, 2)
self.decoder_res1 = Res3DBlock(64, 64)
self.decoder_upsample1 = Upsample3DBlock(64, 32, 3, 2)
self.decoder_upsample1 = Upsample3DBlock(64, 32, 2, 2)

self.skip_res1 = Res3DBlock(32, 32)
self.skip_res2 = Res3DBlock(64, 64)
Expand Down Expand Up @@ -135,7 +150,7 @@ def __init__(self, input_channels, output_channels):
Basic3DBlock(32, 32, 1),
)

self.output_layer = Basic3DBlock(32, output_channels, 1)
self.output_layer = nn.Conv3d(32, output_channels, kernel_size=1, stride=1, padding=0)

self._initialize_weights()

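This is the 'fix bug of model output layer' part of the commit: the head changes from a Basic3DBlock to a bare 1x1x1 Conv3d. A sketch of the difference, assuming Basic3DBlock is the usual Conv3d + BatchNorm3d + ReLU stack used elsewhere in this file (its definition is not shown in this diff):

import torch.nn as nn

keypoints_num = 7  # from src/v2v_main.py
# presumed old head: BatchNorm and ReLU at the output constrain the regressed values,
# which makes matching the Gaussian heatmap targets harder
old_head = nn.Sequential(
    nn.Conv3d(32, keypoints_num, kernel_size=1, stride=1, padding=0),
    nn.BatchNorm3d(keypoints_num),
    nn.ReLU(True),
)
# new head: a plain linear projection over the 32 feature channels
new_head = nn.Conv3d(32, keypoints_num, kernel_size=1, stride=1, padding=0)
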
5 changes: 5 additions & 0 deletions src/v2v_util.py
@@ -194,6 +194,11 @@ def voxelize(self, points, refpoint):
input = generate_cubic_input(points, refpoint, new_size, angle, trans, self.sizes)
return input.reshape((1, *input.shape))

def generate_heatmap(self, keypoints, refpoint):
new_size, angle, trans = 100, 0, self.original_size/2 - self.cropped_size/2 + 1
heatmap = generate_heatmap_gt(keypoints, refpoint, new_size, angle, trans, self.sizes, self.d3outputs, self.pool_factor, self.std)
return heatmap

def evaluate(self, heatmaps, refpoints):
coords = extract_coord_from_output(heatmaps)
coords *= self.pool_factor
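
The new generate_heatmap mirrors voxelize on the target side; a usage sketch with the names from src/v2v_main.py (vertices, keypoints and refpoint taken from one dataset sample):

voxelization = V2VVoxelization(augmentation=False)
input = voxelization.voxelize(vertices, refpoint)             # occupancy grid with a leading channel dim of 1
heatmap = voxelization.generate_heatmap(keypoints, refpoint)  # one Gaussian target volume per keypoint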
