Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hmm #13

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
3 changes: 3 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,11 @@


def line_number(info):
# linkcode gets angry about namedtuple lol
if "ICONLoss" in info["fullname"]:
return 22
if "BendingLoss" in info["fullname"]:
return 235
mod = icon_registration
for elem in info["module"].split(".")[1:]:
mod = getattr(mod, elem)
Expand Down
104 changes: 104 additions & 0 deletions src/icon_registration/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,111 @@ def forward(self, image_A, image_B) -> ICONLoss:
transform_magnitude,
flips(self.phi_BA_vectorfield),
)

# Loss components returned by BendingEnergyNet.forward.
BendingLoss = namedtuple(
    "BendingLoss",
    [
        "all_loss",
        "bending_energy_loss",
        "similarity_loss",
        "transform_magnitude",
        "flips",
    ],
)

class BendingEnergyNet(network_wrappers.RegistrationModule):
    """Wrap a registration network in a training loss.

    The loss is ``similarity(warp(A), B) + lmbda * bending_energy(phi_AB)``:
    regularity is enforced with a bending-energy penalty (squared second
    finite differences of the predicted transform) instead of an
    inverse-consistency term.

    Args:
        network: registration module mapping ``(image_A, image_B)`` to a
            function ``phi_AB`` that acts on coordinate tensors.
        similarity: callable computing a similarity loss between the warped
            moving image and the fixed image.
        lmbda: weight of the bending-energy regularizer.
    """

    def __init__(self, network, similarity, lmbda):
        super().__init__()

        self.regis_net = network
        self.lmbda = lmbda
        self.similarity = similarity

    def compute_bending_energy_loss(self, phi_AB_vectorfield):
        """Mean squared second finite difference of the transform.

        One term is accumulated per spatial axis. ``identity_map`` has shape
        (batch, channel, *spatial), so spatial axes start at dim 2; this
        handles 1-D, 2-D and 3-D registration uniformly (the original code
        spelled out the three cases by hand).
        """
        num_spatial_dims = len(self.identity_map.shape) - 2
        bending_energy = 0
        for axis in range(2, 2 + num_spatial_dims):
            extent = phi_AB_vectorfield.shape[axis] - 2
            left = phi_AB_vectorfield.narrow(axis, 0, extent)
            center = phi_AB_vectorfield.narrow(axis, 1, extent)
            right = phi_AB_vectorfield.narrow(axis, 2, extent)
            # Second difference along this axis; the overall sign of the
            # stencil is irrelevant because it is squared.
            bending_energy = bending_energy + torch.mean(
                (left - 2 * center + right) ** 2
            )

        return bending_energy

    def compute_similarity_measure(self, phi_AB_vectorfield, image_A, image_B):
        """Warp image_A by phi_AB and compare the result to image_B."""

        # tag images during warping so that the similarity measure
        # can use information about whether a sample is interpolated
        # or extrapolated
        inbounds_tag = torch.zeros(tuple(image_A.shape), device=image_A.device)
        if len(self.input_shape) - 2 == 3:
            inbounds_tag[:, :, 1:-1, 1:-1, 1:-1] = 1.0
        elif len(self.input_shape) - 2 == 2:
            inbounds_tag[:, :, 1:-1, 1:-1] = 1.0
        else:
            inbounds_tag[:, :, 1:-1] = 1.0

        # Kept as an attribute so callers can inspect the warped image
        # after a forward pass.
        self.warped_image_A = self.as_function(
            torch.cat([image_A, inbounds_tag], axis=1)
        )(phi_AB_vectorfield)

        similarity_loss = self.similarity(
            self.warped_image_A, image_B
        )
        return similarity_loss

    def forward(self, image_A, image_B) -> BendingLoss:
        """Register image_A to image_B and return the loss components."""

        assert self.identity_map.shape[2:] == image_A.shape[2:]
        assert self.identity_map.shape[2:] == image_B.shape[2:]

        # Tag used elsewhere for optimization.
        # Must be set at beginning of forward b/c not preserved by .cuda() etc
        self.identity_map.isIdentity = True

        self.phi_AB = self.regis_net(image_A, image_B)
        self.phi_AB_vectorfield = self.phi_AB(self.identity_map)

        similarity_loss = self.compute_similarity_measure(
            self.phi_AB_vectorfield, image_A, image_B
        )

        bending_energy_loss = self.compute_bending_energy_loss(
            self.phi_AB_vectorfield
        )

        all_loss = self.lmbda * bending_energy_loss + similarity_loss

        transform_magnitude = torch.mean(
            (self.identity_map - self.phi_AB_vectorfield) ** 2
        )
        # BUG FIX: the original returned ICONLoss and referenced
        # self.phi_BA_vectorfield, which this module never computes (it only
        # registers in the A->B direction) — that was a guaranteed
        # AttributeError. Return the BendingLoss namedtuple defined above,
        # with flips measured on the transform we actually produced.
        return BendingLoss(
            all_loss,
            bending_energy_loss,
            similarity_loss,
            transform_magnitude,
            flips(self.phi_AB_vectorfield),
        )

def normalize(image):
dimension = len(image.shape) - 2
Expand Down
61 changes: 36 additions & 25 deletions src/icon_registration/network_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def adjust_batch_size(self, size):

def forward(image_A, image_B):
"""Register a pair of images:
return a python function that warps a tensor of coordinates such that
return a python function phi_AB that warps a tensor of coordinates such that

.. code-block:: python

Expand Down Expand Up @@ -109,17 +109,25 @@ def __init__(self, net):
self.net = net

def forward(self, image_A, image_B):
vectorfield_phi = self.net(image_A, image_B)

def ret(input_):
if hasattr(input_, "isIdentity") and vectorfield_phi.shape == input_.shape:
return input_ + vectorfield_phi
else:
return input_ + compute_warped_image_multiNC(
vectorfield_phi, input_, self.spacing, 1
)

return ret
displacement_field = self.as_function(self.net(image_A, image_B))
return lambda coordinates: coordinates + displacement_field(coordinates)

class SquaringVelocityField(RegistrationModule):
    """Integrate a predicted stationary velocity field by scaling and squaring."""

    def __init__(self, net):
        super().__init__()
        self.net = net
        self.n_steps = 256

    def forward(self, image_A, image_B):
        # Scale the raw prediction down to a single small step, then double
        # the accumulated displacement 8 times (2 ** 8 == n_steps).
        displacement = self.net(image_A, image_B) / self.n_steps

        for _ in range(8):
            displacement = displacement + self.as_function(displacement)(
                displacement + self.identity_map
            )

        def transform(coordinate_tensor):
            return coordinate_tensor + self.as_function(displacement)(
                coordinate_tensor
            )

        return transform


def multiply_matrix_vectorfield(matrix, vectorfield):
Expand All @@ -145,15 +153,15 @@ def __init__(self, net):
def forward(self, image_A, image_B):
matrix_phi = self.net(image_A, image_B)

def ret(input_):
shape = list(input_.shape)
def transform(tensor_of_coordinates):
shape = list(tensor_of_coordinates.shape)
shape[1] = 1
input_homogeneous = torch.cat(
[input_, torch.ones(shape, device=input_.device)], axis=1
coordinates_homogeneous = torch.cat(
[tensor_of_coordinates, torch.ones(shape, device=tensor_of_coordinates.device)], axis=1
)
return multiply_matrix_vectorfield(matrix_phi, input_homogeneous)[:, :-1]
return multiply_matrix_vectorfield(matrix_phi, coordinates_homogeneous)[:, :-1]

return ret
return transform


class RandomShift(RegistrationModule):
Expand Down Expand Up @@ -189,15 +197,18 @@ def __init__(self, netPhi, netPsi):
self.netPsi = netPsi

def forward(self, image_A, image_B):
# Tag for optimization. Must be set at the beginning of forward because it is not preserved by .to(config.device)

# Tag for shortcutting hack. Must be set at the beginning of
# forward because it is not preserved by .to(config.device)
self.identity_map.isIdentity = True

phi = self.netPhi(image_A, image_B)
phi_vectorfield = phi(self.identity_map)
self.image_A_comp_phi = self.as_function(image_A)(phi_vectorfield)
psi = self.netPsi(self.image_A_comp_phi, image_B)

ret = lambda input_: phi(psi(input_))
return ret
psi = self.netPsi(
self.as_function(image_A)(phi(self.identity_map)),
image_B
)
return lambda tensor_of_coordinates: phi(psi(tensor_of_coordinates))



class DownsampleRegistration(RegistrationModule):
Expand Down
39 changes: 39 additions & 0 deletions src/icon_registration/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,45 @@
from icon_registration import config


class ConvNet(nn.Module):
    """Plain convolutional encoder producing a flat feature vector.

    Takes a pair of single-channel images (concatenated along the channel
    axis, hence the input width of 2) and maps them through a stack of
    conv + avg-pool stages followed by two dense layers.
    """

    def __init__(self, dimension=2, output_dim=100):
        super().__init__()
        self.dimension = dimension

        if dimension == 2:
            self.Conv = nn.Conv2d
            self.avg_pool = F.avg_pool2d
        else:
            self.Conv = nn.Conv3d
            self.avg_pool = F.avg_pool3d

        self.features = [2, 16, 32, 64, 128, 128, 256]
        self.convs = nn.ModuleList(
            self.Conv(in_channels, out_channels, kernel_size=3, padding=1)
            for in_channels, out_channels in zip(
                self.features[:-1], self.features[1:]
            )
        )
        self.dense2 = nn.Linear(256, 300)
        self.dense3 = nn.Linear(300, output_dim)

    def forward(self, x, y):
        out = torch.cat([x, y], 1)
        # relu -> conv -> pool at every stage, halving spatial size each time.
        for conv in self.convs:
            out = conv(F.relu(out))
            out = self.avg_pool(out, 2, ceil_mode=True)
        # Collapse whatever spatial extent remains to 1 per channel.
        out = self.avg_pool(out, out.shape[2:], ceil_mode=True)
        out = torch.reshape(out, (-1, 256))
        out = F.relu(self.dense2(out))
        return self.dense3(out)


class Autoencoder(nn.Module):
def __init__(self, num_layers, channels):
super().__init__()
Expand Down
2 changes: 1 addition & 1 deletion src/icon_registration/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def train_datasets(net, optimizer, d1, d2, epochs=400):
loss_object.all_loss.backward()
optimizer.step()

loss_history.append(to_floats(loss_object))
loss_history.append(to_floats(loss_object))
return loss_history


Expand Down