Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
dragonbook committed Nov 29, 2018
1 parent 5bb0285 commit b6ffccd
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 54 deletions.
2 changes: 1 addition & 1 deletion datasets/tooth13_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class Tooth13Dataset(Dataset):
def __init__(self, root, mode, transform=None):
if not mode in ['train', 'test']: raise ValueError('Invalid mode')
if not mode in ['train', 'val', 'test']: raise ValueError('Invalid mode')

self.mesh_names_file = os.path.join(root, mode + '_names.txt')
self.keypoints_file = os.path.join(root, mode + '_keypoints.txt')
Expand Down
40 changes: 9 additions & 31 deletions src/v2v_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,42 +21,34 @@
data_dir = r''
keypoints_num = 7

cubic_size, cropped_size, original_size = 140, 88, 96
data_sizes = (cubic_size, cropped_size, original_size)


# Transformation
def to_tensor(x):
return torch.from_numpy(x).to('cpu', torch.float)
return torch.from_numpy(x)


def transform_train(sample):
mesh_name, keypoints, refpoint = sample['mesh_name'], sample['keypoints'], sample['refpoint']
assert(keypoints.shape[0] == keypoints_num)

vertices = read_mesh_vertices(mesh_name)

voxelization_train = V2VVoxelization(data_sizes, pool_factor=2, std=1.7, augmentation=True)

voxelization_train = V2VVoxelization()
input, heatmap = voxelization_train({'points': vertices, 'keypoints': keypoints, 'refpoint': refpoint})

return (to_tensor(input), to_tensor(heatmap))


def transform_val(sample):
mesh_name, keypoints, refpoint = sample['mesh_name'], sample['keypoints'], sample['refpoint']
vertices = read_mesh_vertices(mesh_name)

voxelization_val = V2VVoxelization(data_sizes, pool_factor=2, std=1.7, augmentation=True)
assert(keypoints.shape[0] == keypoints_num)

input, heatmap = voxelization_val({'points': vertices, 'keypoints': keypoints, 'refpoint': refpoint})
return (to_tensor(input), to_tensor(heatmap))

def transform_test(sample):
mesh_name, keypoints, refpoint = sample['mesh_name'], sample['keypoints'], sample['refpoint']
vertices = read_mesh_vertices(mesh_name)

voxelization_test = V2VVoxelization(data_sizes, pool_factor=2, std=1.7, augmentation=False)
voxelization_val = V2VVoxelization()
input, heatmap = voxelization_val({'points': vertices, 'keypoints': keypoints, 'refpoint': refpoint})

input, heatmap = voxelization_test({'points': vertices, 'keypoints': keypoints, 'refpoint': refpoint})
return to_tensor(input)
return (to_tensor(input), to_tensor(heatmap))


train_set = Tooth13Dataset(root=data_dir, mode='train', transform=transform_train)
Expand All @@ -65,9 +57,6 @@ def transform_test(sample):
val_set = Tooth13Dataset(root=data_dir, mode='val', transform=transform_val)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=1, shuffle=True, num_workers=6)

test_set = Tooth13Dataset(root=data_dir, mode='test', transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False, num_workers=6)


# Model, criterion and optimizer
net = V2VModel(input_channels=1, output_channels=keypoints_num)
Expand All @@ -86,14 +75,3 @@ def transform_test(sample):
print('Epoch: {}'.format(epoch))
train_epoch(net, criterion, optimizer, train_loader, device=device, dtype=dtype)
val_epoch(net, criterion, val_loader, device=device, dtype=dtype)


# Test
print('Start test ..')
def test_epoch(model, test_loader, device=torch.device('cuda'), dtype=torch.float):
model.eval()

with torch.no_grad():
for batch_idx, inputs in enumerate(test_loader):
inputs = inputs.to(device, dtype)
# TODO:
53 changes: 31 additions & 22 deletions src/v2v_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,19 @@ def discretize(coord, cropped_size):


def scattering(coord, cropped_size):
# coord: [0, cropped_size]
# Assign range[0, 1) -> 0, [1, 2) -> 1, .. [cropped_size-1, cropped_size) -> cropped_size-1
# That is, around center 0.5 -> 0, around center 1.5 -> 1 .. around center cropped_size-0.5 -> cropped_size-1
coord = coord.astype(np.int32)

mask = (coord[:, 0] >= 0) & (coord[:, 0] < cropped_size) & \
(coord[:, 1] >= 0) & (coord[:, 1] < cropped_size) & \
(coord[:, 2] >= 0) & (coord[:, 2] < cropped_size)

coord = coord[mask, :]

cubic = np.zeros((cropped_size, cropped_size, cropped_size))

# Note, directly map point coordinate (x, y, z) to index (i, j, k), instead of (k, j, i)
# Need to be consistent with heatmap generating and coordinates extration from heatmap
cubic[coord[:, 0], coord[:, 1], coord[:, 2]] = 1
Expand All @@ -34,15 +40,17 @@ def scattering(coord, cropped_size):

def warp2continuous(coord, refpoint, cubic_size, cropped_size):
'''
[0, cropped_size] -> [-cropped_size/2+refpoint, cropped_size/2 + refpoint]
Map coordinates in set [0, 1, .., cropped_size-1] to range [-cropped_size/2+refpoint, cropped_size/2 + refpoint]
'''
# Note discrete coord can represents real range [coord, coord+1), see function scattering()
# So, move coord to range center for better fittness
coord += 0.5

min_normalized = -1
max_normalized = 1

scale = (max_normalized - min_normalized) / cropped_size
# Why need scale/2?
#coord = coord * scale + min_normalized + scale / 2 # author's implementation
coord = coord * scale + min_normalized
coord = coord * scale + min_normalized # -> [-1, 1]

coord = coord * cubic_size / 2 + refpoint

Expand Down Expand Up @@ -72,11 +80,13 @@ def generate_coord(points, refpoint, new_size, angle, trans, sizes):
# points shape: (n, 3)
coord = points

# normalize, candidates will lie in range [0, 1]
coord = (coord - refpoint) / (cubic_size/2)
# note, will consider points within range [refpoint-cubic_size/2, refpoint+cubic_size/2] as candidates

# normalize
coord = (coord - refpoint) / (cubic_size/2) # -> [-1, 1]

# discretize
coord = discretize(coord, cropped_size) # candidates in range [0, cropped_size)
coord = discretize(coord, cropped_size) # -> [0, cropped_size]
coord += (original_size / 2 - cropped_size / 2)

# resize
Expand Down Expand Up @@ -111,44 +121,43 @@ def generate_coord(points, refpoint, new_size, angle, trans, sizes):


def generate_cubic_input(points, refpoint, new_size, angle, trans, sizes):
cubic_size, cropped_size, original_size = sizes
_, cropped_size, _ = sizes
coord = generate_coord(points, refpoint, new_size, angle, trans, sizes)

# scattering
coord = coord.astype(np.int32) # [0, cropped_size)
cubic = scattering(coord, cropped_size)

return cubic


def generate_heatmap_gt(keypoints, refpoint, new_size, angle, trans, sizes, d3outputs, pool_factor, std):
cubic_size, cropped_size, original_size = sizes
_, cropped_size, _ = sizes
d3output_x, d3output_y, d3output_z = d3outputs

coord = generate_coord(keypoints, refpoint, new_size, angle, trans, sizes)
coord /= pool_factor
coord += 1
coord = generate_coord(keypoints, refpoint, new_size, angle, trans, sizes) # [0, cropped_size]
coord /= pool_factor # [0, cropped_size/pool_factor]

# heatmap generation
heatmap = np.zeros((keypoints.shape[0], cropped_size, cropped_size, cropped_size))
for i in range(coord.shape[0]):
xi, yi, zi= coord[i] - 1
heatmap[i] = np.exp(-(np.power((d3output_x-xi)/std, 2)/2 + \
np.power((d3output_y-yi)/std, 2)/2 + \
np.power((d3output_z-zi)/std, 2)/2))
xi, yi, zi= coord[i]
heatmap[i] = np.exp(-(np.power((d3output_x+0.5-xi)/std, 2)/2 + \
np.power((d3output_y+0.5-yi)/std, 2)/2 + \
np.power((d3output_z+0.5-zi)/std, 2)/2)) # +0.5, move coordinate to range center

return heatmap


class V2VVoxelization(object):
def __init__(self, sizes, pool_factor, std, augmentation=True):
self.sizes = sizes
self.cubic_size, self.cropped_size, self.original_size = self.sizes
self.pool_factor = pool_factor
self.std = std
def __init__(self, augmentation=True):
self.cubic_size, self.cropped_size, self.original_size = 140, 88, 96
self.sizes = (self.cubic_size, self.cropped_size, self.original_size)
self.pool_factor = 2
self.std = 1.7
self.augmentation = augmentation

output_size = self.cropped_size / self.pool_factor
# Note, range(size) and indexing = 'ij'
self.d3outputs = np.meshgrid(np.arange(output_size), np.arange(output_size), np.arange(output_size), indexing='ij')

def __call__(self, sample):
Expand Down

0 comments on commit b6ffccd

Please sign in to comment.