
Commit

update
dragonbook committed Nov 30, 2018
1 parent 672a166 commit c65e73b
Showing 2 changed files with 27 additions and 5 deletions.
30 changes: 26 additions & 4 deletions src/v2v_main.py
@@ -81,7 +81,7 @@ def transform_val(sample):

 ## Train and validate
 print('Start train ..')
-for epoch in range(2):
+for epoch in range(200):
     print('Epoch: {}'.format(epoch))
     train_epoch(net, criterion, optimizer, train_loader, device=device, dtype=dtype)
     val_epoch(net, criterion, val_loader, device=device, dtype=dtype)
@@ -103,7 +103,7 @@ def test(model, test_loader, output_transform, device=torch.device('cuda'), dtyp
         refpoints = refpoints.cpu().numpy()
 
         # (batch, keypoints_num, 3)
-        keypoints_batch = output_transform(outputs, refpoints)
+        keypoints_batch = output_transform((outputs, refpoints))
 
         if keypoints is None:
             # Initialize keypoints until dimensions awailable now
@@ -117,8 +117,22 @@ def test(model, test_loader, output_transform, device=torch.device('cuda'), dtyp
     return keypoints
 
 
+def remove_dataset_scale(x):
+    if isinstance(x, tuple):
+        for e in x: e /= dataset_scale
+    else: x /= dataset_scale
+
+    return x
+
+
 voxelization_test = voxelization_train
 
+def output_transform(x):
+    heatmaps, refpoints = x
+    keypoints = voxelization_test.evaluate(heatmaps, refpoints)
+    return remove_dataset_scale(keypoints)
+
+
 def transform_test(sample):
     vertices, refpoint = sample['vertices'], sample['refpoint']
     vertices, refpoint = apply_dataset_scale((vertices, refpoint))
@@ -129,8 +143,6 @@ def transform_test(sample):
 test_set = Tooth13Dataset(root=data_dir, mode='test', transform=transform_test)
 test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False, num_workers=6)
 
-output_transform = voxelization_test.evaluate
-
 print('Start test ..')
 keypoints_estimate = test(net, test_loader, output_transform, device, dtype)
 
@@ -140,4 +152,14 @@ def transform_test(sample):
 result = keypoints_estimate.reshape(keypoints_estimate.shape[0], -1)
 np.savetxt(test_res_filename, result, fmt='%0.4f')
 
+
+print('Start save fit ..')
+fit_set = Tooth13Dataset(root=data_dir, mode='train', transform=transform_test)
+fit_loader = torch.utils.data.DataLoader(fit_set, batch_size=1, shuffle=False, num_workers=6)
+keypoints_fit = test(net, fit_loader, output_transform)
+fit_res_filename = r'./fit_res.txt'
+print('Write fit result to ', fit_res_filename)
+fit_result = keypoints_fit.reshape(keypoints_fit.shape[0], -1)
+np.savetxt(fit_res_filename, fit_result, fmt='%0.4f')
+
 print('All done ..')
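
Note on the v2v_main.py changes above: test() now passes output_transform a single (outputs, refpoints) tuple, and the module-level output_transform/remove_dataset_scale pair replaces the earlier direct use of voxelization_test.evaluate, dividing the decoded keypoints by dataset_scale before they are written out. A minimal sketch of that calling convention, assuming placeholder shapes, a placeholder dataset_scale value, and an evaluate_stub standing in for voxelization_test.evaluate (none of these are the repository's actual values):

import numpy as np

dataset_scale = 10.0  # placeholder; the real scale is defined elsewhere in v2v_main.py

def evaluate_stub(heatmaps, refpoints):
    # Stand-in for voxelization_test.evaluate: return a (batch, keypoints_num, 3)
    # array centred on each sample's reference point.
    batch, keypoints_num = heatmaps.shape[0], heatmaps.shape[1]
    return np.zeros((batch, keypoints_num, 3)) + refpoints[:, None, :]

def output_transform(x):
    heatmaps, refpoints = x            # single tuple argument, matching the new call site
    keypoints = evaluate_stub(heatmaps, refpoints)
    return keypoints / dataset_scale   # undo the dataset scaling, as remove_dataset_scale does

keypoints = output_transform((np.zeros((2, 13, 44, 44, 44), dtype=np.float32),
                              np.ones((2, 3), dtype=np.float32)))
print(keypoints.shape)  # (2, 13, 3)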
2 changes: 1 addition & 1 deletion src/v2v_util.py
@@ -19,7 +19,7 @@ def discretize(coord, cropped_size):

 def warp2continuous(coord, refpoint, cubic_size, cropped_size):
     '''
-    Map coordinates in set [0, 1, .., cropped_size-1] to original range [-cropped_size/2+refpoint, cropped_size/2 + refpoint]
+    Map coordinates in set [0, 1, .., cropped_size-1] to original range [-cubic_size/2+refpoint, cubic_size/2 + refpoint]
     '''
     min_normalized = -1
     max_normalized = 1
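
The v2v_util.py change only corrects the docstring: the voxel-to-world mapping takes indices in [0, cropped_size-1] back into a cube of side cubic_size (not cropped_size) centred at refpoint. A hedged sketch of that mapping, with cubic_size=200 and cropped_size=88 chosen purely for illustration (the repository's actual normalization, e.g. any half-voxel offset, may differ in detail):

import numpy as np

def warp2continuous_sketch(coord, refpoint, cubic_size, cropped_size):
    # [0, cropped_size-1] -> [-1, 1], then scale by cubic_size/2 and shift by refpoint,
    # giving the documented range [-cubic_size/2 + refpoint, cubic_size/2 + refpoint].
    min_normalized, max_normalized = -1, 1
    scale = (max_normalized - min_normalized) / (cropped_size - 1)
    normalized = coord * scale + min_normalized
    return normalized * cubic_size / 2 + refpoint

coord = np.array([[0, 0, 0], [87, 87, 87]], dtype=np.float64)
refpoint = np.array([10.0, 20.0, 30.0])
print(warp2continuous_sketch(coord, refpoint, cubic_size=200, cropped_size=88))
# first row = refpoint - 100, second row = refpoint + 100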
