
AgamChopra/TorchRegister


TorchRegister

Common 2D and 3D image registration methods such as rigid, affine, and flow field for PyTorch.
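As a rough illustration of what optimization-based affine registration looks like in plain PyTorch, the sketch below fits a 2x3 affine matrix by gradient descent using `torch.nn.functional.affine_grid` and `grid_sample`. It is not TorchRegister's implementation: the function name `affine_register_2d`, the MSE loss, the Adam optimizer, and the synthetic Gaussian test images are all assumptions made for this example.

```python
import torch
import torch.nn.functional as F


def affine_register_2d(moving, target, steps=200, lr=1e-2):
    """Hypothetical helper: fit a 2x3 affine matrix that warps `moving` toward `target`."""
    n = moving.shape[0]
    # Start from the identity transform, shape (N, 2, 3) as affine_grid expects.
    theta = torch.eye(2, 3, dtype=moving.dtype, device=moving.device)
    theta = theta.unsqueeze(0).repeat(n, 1, 1).requires_grad_(True)
    optimizer = torch.optim.Adam([theta], lr=lr)
    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        # Build a sampling grid from the current parameters and resample the moving image.
        grid = F.affine_grid(theta, list(moving.shape), align_corners=False)
        warped = F.grid_sample(moving, grid, align_corners=False)
        loss = F.mse_loss(warped, target)  # assumed similarity metric
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return theta.detach(), losses


# Synthetic data: a Gaussian blob and a horizontally shifted copy of it.
ys = torch.linspace(-1, 1, 32).view(-1, 1)
xs = torch.linspace(-1, 1, 32).view(1, -1)
moving_2d = torch.exp(-(xs ** 2 + ys ** 2) / 0.1).view(1, 1, 32, 32)
target_2d = torch.exp(-((xs - 0.2) ** 2 + ys ** 2) / 0.1).view(1, 1, 32, 32)

theta, losses = affine_register_2d(moving_2d, target_2d)
print('first loss: %.5f, last loss: %.5f' % (losses[0], losses[-1]))
```

The same idea extends to 3D by using a 3x4 matrix and a 5D input tensor, which `affine_grid` and `grid_sample` also accept.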


Fig. Example visualizations of 2D image registration.


Fig. Example visualizations of deep-learning-based flow-field brain MRI registration.


Fig. Example visualizations of PyTorch-based affine brain MRI registration.


Fig. Example visualizations of PyTorch-based rigid brain MRI registration.


Fig. Example loss curves for flow-field registration (left), affine registration (middle), and rigid registration (right).

Example:


    import torch
    from torchio.transforms import RandomAffine
    from numpy import load
    from matplotlib import pyplot as plt
    import TorchRegister as tr


    def rand_augment(x):
        affine = RandomAffine(image_interpolation='bspline',
                              degrees=25, translation=4, scales=(0.8, 1.2))
        y = affine(x[0])
        return y.view(x.shape)


    device = 'cuda'

    # Load one volume; data[0] is expected to be a 3D array
    path = 'R:/img (%d).pkl' % (1)
    data = load(path, allow_pickle=True)
    # Moving image: reshape to (batch, channel, dim1, dim2, dim3)
    moving = torch.from_numpy(data[0])
    moving = moving.view(1, 1, moving.shape[0], moving.shape[1], moving.shape[2]).to(
        dtype=torch.float, device=device)
    # Target image: a randomly affine-augmented copy of the same volume
    target = torch.from_numpy(data[0])
    target = rand_augment(target.view(1, 1, target.shape[0], target.shape[1], target.shape[2])).to(
        dtype=torch.float, device=device)

    plt.imshow(torch.squeeze(moving[:, :, :, :, 60]
                             ).detach().cpu().numpy(), cmap='gray')
    plt.title('Moving')
    plt.show()

    plt.imshow(torch.squeeze(target[:, :, :, :, 60]
                             ).detach().cpu().numpy(), cmap='gray')
    plt.title('Target')
    plt.show()

    # Rigid registration
    warping = tr.Register(mode='rigid', device=device, debug=False)
    warping.optim(moving, target, max_epochs=500, lr=1E-5)
    warped = warping(moving)

    plt.imshow(torch.squeeze(warped[:, :, :, :, 60]
                             ).detach().cpu().numpy(), cmap='gray')
    plt.title('Warped Moving 1')
    plt.show()

    # Affine registration
    moving = warped.detach()
    warping = tr.Register(mode='affine', device=device, debug=False)
    warping.optim(moving, target, max_epochs=200, lr=1E-5)
    warped = warping(moving)

    plt.imshow(torch.squeeze(warped[:, :, :, :, 60]
                             ).detach().cpu().numpy(), cmap='gray')
    plt.title('Warped Moving 2')
    plt.show()

    # Flow field based registration
    moving = warped.detach()
    warping = tr.Register(mode='flow', device=device, debug=False)
    warping.optim(moving, target, lr=1E-3, max_epochs=100)
    warped = warping(moving)

    plt.imshow(torch.squeeze(warped[:, :, :, :, 60]
                             ).detach().cpu().numpy(), cmap='gray')
    plt.title('Warped Moving 3')
    plt.show()

    plt.imshow(torch.moveaxis(torch.squeeze(tr.norm(
        torch.abs(warping.theta[:, :, :, :, 60]))), 0, -1).detach().cpu().numpy())
    plt.title('Flow Field')
    plt.show()
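The example above repeats the same slice-extraction expression for every plot. A small helper can factor that out, as in the sketch below. The name `last_axis_slice` is hypothetical and not part of TorchRegister; it assumes a single-batch, single-channel 5D tensor, as used throughout the example.

```python
import torch


def last_axis_slice(volume, index):
    """Hypothetical helper: return one slice along the last axis of a
    (1, 1, D, H, W) tensor as a 2D NumPy array ready for plt.imshow."""
    return volume[0, 0, :, :, index].detach().cpu().numpy()


# Small synthetic volume to show the resulting shape.
demo = torch.arange(24, dtype=torch.float32).view(1, 1, 2, 3, 4)
slice_2d = last_axis_slice(demo, 1)
print(slice_2d.shape)
```

With this in place, each plot in the example would reduce to something like `plt.imshow(last_axis_slice(warped, 60), cmap='gray')`.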

[GNU AGPL3 License]
