initial commit
Varun Kelkar authored and Varun Kelkar committed Jan 2, 2024
1 parent b831c3d commit daabf76
Showing 52 changed files with 8,321 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
results
**/__pycache__
57 changes: 57 additions & 0 deletions README.md
@@ -0,0 +1,57 @@
# AmbientFlow: Invertible generative models from incomplete, noisy measurements

PyTorch implementation

**Paper**: https://arxiv.org/abs/2309.04856

Varun A. Kelkar, Rucha Deshpande, Arindam Banerjee, Mark A. Anastasio <br />
University of Illinois at Urbana-Champaign, Urbana, IL - 61801, USA

**Contact**: vak2@illinois.edu, maa@illinois.edu

**Abstract**: Generative models have gained popularity for their potential applications in imaging science, such as image reconstruction, posterior sampling and data sharing. Flow-based generative models are particularly attractive due to their ability to tractably provide exact density estimates along with fast, inexpensive and diverse samples. Training such models, however, requires a large, high quality dataset of objects. In applications such as computed imaging, it is often difficult to acquire such data due to requirements such as long acquisition time or high radiation dose, while acquiring noisy or partially observed measurements of these objects is more feasible. In this work, we propose AmbientFlow, a framework for learning flow-based generative models directly from noisy and incomplete data. Using variational Bayesian methods, a novel framework for establishing flow-based generative models from noisy, incomplete data is proposed. Extensive numerical studies demonstrate the effectiveness of AmbientFlow in learning the object distribution. The utility of AmbientFlow in a downstream inference task of image reconstruction is demonstrated.

<p align="center">
<img src="./docs/schematic.png" alt="AmbientFlow schematic" width="500"/>
<img src="./docs/AmbientFlow_generated_images_mri.png" alt="AmbientFlow MRI images" width="1000"/>
</p>


## System Requirements
- Linux/Unix-based systems are recommended. The code has not been tested on Windows.
- 64-bit Python 3.8+. The code has been tested with Python 3.8.13 installed via Anaconda.

Other packages required are listed in the `requirements.txt` file.

## Directory structure

- The directory `toy` contains the code for the toy 2D distribution study.

- The directory `scripts/training` contains the top-level shell scripts for training the regular and ambient models. Within this directory, `paper_configs.txt` contains the hyperparameter configurations used for the results in the paper. The numbers of GPUs and the batch sizes listed there reflect training on a single NVIDIA Quadro RTX 8000 48 GB GPU.

- The directory `src` stores all Python code defining the flow networks, forward models (degradation models), data processing pipelines, and utilities. The subdirectory `src/metrics` contains code for evaluating the trained generative models via FID score, log-odds, empirical mean and covariance estimates, and radiomic features (for the MRI study).

- The directory `masks_mri` contains the undersampling masks used by the Fourier undersampling forward operators in the stylized MRI study.

- The directory `recon` contains the Python code for running image reconstruction/restoration algorithms, including baseline classical algorithms, compressed sensing using generative models (CSGM), and posterior sampling via annealed Langevin dynamics. The details of these approaches can be found in our paper.
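
To make the role of the undersampling masks concrete, here is a minimal NumPy sketch of a masked Fourier forward operator followed by a zero-filled inverse FFT, the simplest classical baseline. It is an illustration under stated assumptions, not the code in `src` or `recon`: the function names `make_cartesian_mask`, `forward` and `zero_filled` are hypothetical, and the mask is generated synthetically because the layout of the `.npy` files in `masks_mri` is not documented here.

```python
import numpy as np


def make_cartesian_mask(n=128, fold=4, num_center=8, seed=0):
    """Hypothetical stand-in for a Cartesian mask: keep n // fold k-space columns.

    A small block of central (low-frequency) columns is always kept and the
    remaining columns are drawn at random. The masks shipped in masks_mri
    may be laid out differently.
    """
    rng = np.random.default_rng(seed)
    mask = np.zeros((n, n), dtype=bool)
    center = np.arange(n // 2 - num_center // 2, n // 2 + num_center // 2)
    mask[:, center] = True
    remaining = np.setdiff1d(np.arange(n), center)
    extra = rng.choice(remaining, size=n // fold - num_center, replace=False)
    mask[:, extra] = True
    return mask


def forward(x, mask):
    """Noiseless forward operator: centered 2D FFT, then zero the unsampled entries."""
    kspace = np.fft.fftshift(np.fft.fft2(x, norm="ortho"))
    return kspace * mask


def zero_filled(y):
    """Zero-filled reconstruction: inverse FFT of the masked k-space data."""
    return np.fft.ifft2(np.fft.ifftshift(y), norm="ortho")


# Usage: undersample a random test image 4-fold and reconstruct it.
x = np.random.default_rng(1).standard_normal((128, 128))
mask = make_cartesian_mask(n=128, fold=4)
x_zero_filled = zero_filled(forward(x, mask)).real
```

With an all-ones mask the two functions invert each other, which is a convenient sanity check before plugging in a real mask loaded with `np.load`.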

## Data

The data for the toy study is generated according to `toy/data.py`. The data for the MNIST, CelebA and MRI studies are obtained from the following resources:
- MNIST : Torchvision dataset: https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html
- CelebA-HQ : https://github.com/tkarras/progressive_growing_of_gans
- MRI: T2-weighted brain image volumes from the FastMRI Initiative Database: https://fastmri.med.nyu.edu/

## Model weights

The weights of our trained models will be made available soon.

## Citations
```
@article{kelkar2023ambientflow,
title={AmbientFlow: Invertible generative models from incomplete, noisy measurements},
author={Kelkar, Varun A and Deshpande, Rucha and Banerjee, Arindam and Anastasio, Mark A},
journal={arXiv preprint arXiv:2309.04856},
year={2023}
}
```
Binary file added docs/AmbientFlow_generated_images_mri.png
Binary file added docs/schematic.png
Binary file added masks_mri/cartesian_2fold_128.npy
Binary file added masks_mri/cartesian_2fold_128.png
Binary file added masks_mri/cartesian_4fold_128.npy
Binary file added masks_mri/cartesian_4fold_128.png
Binary file added masks_mri/cartesian_8fold_128.npy
Binary file added masks_mri/cartesian_8fold_128.png
86 changes: 86 additions & 0 deletions recon/recon_bm3d.py
@@ -0,0 +1,86 @@
""" Copyright (c) 2022-2023 authors
Author : Varun A. Kelkar
Email : vak2@illinois.edu
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import sys
sys.path.append('../src')  # must run before importing the modules that live in src/

import numpy as np
import bm3d
import torch
import dataset_tool
import degradations
import ast
import os
import argparse
import imageio as io
import glob
import numpy.linalg as la

parser = argparse.ArgumentParser()
# recon args
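parser.add_argument("--num_images", type=int, default=50, help="Number of images to reconstruct")
# The script reads args.reg_parameter below, but the argument was never defined.
parser.add_argument("--reg_parameter", type=float, required=True, help="Noise standard deviation passed to BM3D as sigma_psd")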
parser.add_argument("--num_bits", type=int, default=0)
parser.add_argument("--results_dir", type=str, default='', help="Results dir")
parser.add_argument("--dataset", action='store_true')

# data args
parser.add_argument("--input_shape", type=int, nargs='+', default=[3, 28, 28])
parser.add_argument("--data_type", type=str, default='MNISTDataset')
parser.add_argument("--data_args", type=ast.literal_eval, default={'power_of_two': False})
parser.add_argument("--degradation_type", type=str, default='GaussianNoise')
parser.add_argument("--degradation_args", type=ast.literal_eval, default={'mean':0., 'std':0.3})
parser.add_argument("--tune", action='store_true', help="Tune the regularization parameter")

args = parser.parse_args()

torch.manual_seed(0)

# forward model
args.data_args['input_shape'] = args.input_shape
degradation = getattr(degradations, args.degradation_type)(**args.degradation_args, input_shape=args.input_shape, num_bits=args.num_bits)

# data
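# glob returns files in arbitrary order; sort both listings so that the same index
# refers to the same image (assuming the two directories use matching file names).
fnames = sorted(glob.glob("/shared/aristotle/SOMS/varun/ambientflow/data/CelebAHQDataset-GaussianNoise-0.2-image-data/*.png"))
fnames_gt = sorted(glob.glob("/shared/aristotle/SOMS/varun/ambientflow/data/CelebAHQDataset-real-image-data/*.png"))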

for idx in range(args.num_images):
    print(idx)
    ymeas = io.imread(fnames[5000+idx]) / 255
    xgt = io.imread(fnames_gt[5000+idx]) / 255
    # ymeas = np.swapaxes(ymeas,0,1).T
    # ymeas = ymeas.reshape(1,3,64,64)
    # ymeas,_ = noisy_dataset[idx]; ymeas = ymeas.reshape(1, *ymeas.shape)
    # ymeas = degradation.rev(ymeas, mode='real', use_device=False)
    # print(xgt.min(), xgt.max(), ymeas.min(), ymeas.max())

    # ymeas = np.squeeze(ymeas.numpy())
    # ymeas = np.swapaxes(ymeas.T, 0,1)
    # xgt = np.squeeze(xgt.numpy())
    # xgt = np.swapaxes(xgt.T, 0,1)

    # if args.tune:
    #     xests = []
    #     for i in range(5):
    #         reg_parameter = args.reg_parameter * ( args.reg_parameter2 / args.reg_parameter )**(i/4)
    xest = bm3d.bm3d(ymeas, sigma_psd=args.reg_parameter, stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING)
    print("RMSE error : ", la.norm(xest - xgt)/np.sqrt(np.prod(xgt.shape)))

    # ymeas = np.clip(ymeas,-0.5,0.5)
    if args.dataset:
        os.makedirs(os.path.join(args.results_dir, 'dataset'), exist_ok=True)

    if args.dataset:
        io.imsave( os.path.join(args.results_dir, 'dataset', f'xest_{idx}_reg{args.reg_parameter}.png' ), xest)
    else:
        np.save( os.path.join(args.results_dir, f'ymeas_{idx}_reg{args.reg_parameter}.npy'), ymeas)
        np.save( os.path.join(args.results_dir, f'xest_{idx}_reg{args.reg_parameter}.npy' ), xest)
        np.save( os.path.join(args.results_dir, f'xgt_{idx}_reg{args.reg_parameter}.npy' ), xgt)
        io.imsave( os.path.join(args.results_dir, f'ymeas_{idx}_reg{args.reg_parameter}.png' ), ymeas)
        io.imsave( os.path.join(args.results_dir, f'xest_{idx}_reg{args.reg_parameter}.png' ), xest)
        io.imsave( os.path.join(args.results_dir, f'xgt_{idx}_reg{args.reg_parameter}.png' ), xgt)

108 changes: 108 additions & 0 deletions recon/recon_csgm.py
@@ -0,0 +1,108 @@
""" Copyright (c) 2022-2023 authors
Author : Varun A. Kelkar
Email : vak2@illinois.edu
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import sys
sys.path.append('../src')  # must run before importing the modules that live in src/

import torch
import torch.optim
import numpy as np
import dataset_tool
from models2 import *
import degradations
import ast
import os
import argparse
import imageio as io

parser = argparse.ArgumentParser()
# recon args
parser.add_argument("--idx", type=int, default=0, help="Index of the image to be reconstructed")
parser.add_argument("--lamda", type=float, default=0, help="MAP regularization parameter")
parser.add_argument("--tv", type=float, default=0, help="TV regularization parameter")
parser.add_argument("--step", type=float, default=1e-03, help="Step size")
parser.add_argument("--num_iter", type=int, default=10000, help="Number of iterations for recon")
parser.add_argument("--num_bits", type=int, default=0)
parser.add_argument("--results_dir", type=str, default='', help="Results dir")
parser.add_argument("--project_on_mask", action='store_true', help="Only applicable for inpainting")

# model args
parser.add_argument("--input_shape", type=int, nargs='+', default=[3, 28, 28])
parser.add_argument("--model_path", type=str, default='', help="Path to model pkl file")

# data args
parser.add_argument("--data_type", type=str, default='MNISTDataset')
parser.add_argument("--data_args", type=ast.literal_eval, default={'power_of_two': False})
parser.add_argument("--degradation_type", type=str, default='GaussianNoise')
parser.add_argument("--degradation_args", type=ast.literal_eval, default={'mean':0., 'std':0.3})

args = parser.parse_args()
print(args, flush=True)

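def total_variation(img):
    # Anisotropic TV: sum of absolute differences between neighboring pixels
    # along the last two axes (height and width).
    x = img[:,:,1:,:] - img[:,:,:-1,:]
    y = img[:,:,:,1:] - img[:,:,:,:-1]
    tvnorm = torch.sum(torch.abs(x)) + torch.sum(torch.abs(y))
    return tvnorm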

device = f'cuda:0' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)

# data
# noisy_dataset = getattr(dataset_tool, args.data_type)(train=False, ambient=True, degradation=degradation, **args.data_args)
clean_dataset = getattr(dataset_tool, args.data_type)(train=False, ambient=False, input_shape=args.input_shape, **args.data_args)

# forward model
args.data_args['input_shape'] = args.input_shape
degradation = getattr(degradations, args.degradation_type)(**args.degradation_args, input_shape=args.input_shape, num_bits=args.num_bits)

xgt ,_ = clean_dataset[args.idx]; xgt = xgt.reshape(1,*xgt.shape)
# ymeas,_ = noisy_dataset[args.idx]; ymeas = ymeas.reshape(1, *ymeas.shape); ymeas = ymeas.to(device)
ymeas = degradation(xgt).to(device)
xgt = xgt.to(device)
print(xgt.min(), xgt.max(), abs(ymeas).min(), abs(ymeas).max())

# model
if '/ambient/' in args.model_path:
    model = load_model(args.model_path, ambient=True).to(device)
else:
    model = load_model(args.model_path).to(device)
model.eval()

# z = torch.zeros([1, np.prod(args.input_shape)], requires_grad=True).to(device)
z = torch.zeros( [1, np.prod(args.input_shape)], requires_grad=True, device=device)
# z = torch.nn.Parameter(z).to(device)
# model.initialize_actnorm(ymeas)
optimizer = torch.optim.Adam([z], lr=args.step)

for i in range(args.num_iter):

    optimizer.zero_grad()
    x,_ = model.reverse(z)
    map_reg = args.lamda * torch.sum(z**2)
    tv_reg = args.tv * total_variation(x)
    loss = 0.5*torch.norm(ymeas - degradation.fwd_noiseless(x, use_device=True))**2 + tv_reg + map_reg
    loss.backward()
    optimizer.step()

    if (i < 10) or (i < 50 and i % 10 == 0) or (i % 50 == 0) or (i == args.num_iter-1):
        if args.project_on_mask:
            x[:, degradation.mask.astype(bool)] = xgt[:, degradation.mask.astype(bool)]
        recon_error = torch.mean( (x - xgt)**2 )
        print(f"Idx : {i:05}, loss : {loss.cpu().detach().numpy()}, tv reg : {tv_reg}, recon. error : {recon_error}", flush=True)

ymeas = np.squeeze(ymeas.cpu().detach().numpy())
ymeas = np.swapaxes(ymeas.T, 0,1)
xest = np.squeeze(x.cpu().detach().numpy())
xest = np.swapaxes(xest.T, 0,1)
xgt = np.squeeze(xgt.cpu().detach().numpy())
xgt = np.swapaxes(xgt.T, 0,1)
np.save( os.path.join(args.results_dir, f'ymeas_{args.idx}_lam{args.lamda}_tv{args.tv}_niter{args.num_iter}_step{args.step}.npy'), ymeas)
np.save( os.path.join(args.results_dir, f'xest_{args.idx}_lam{args.lamda}_tv{args.tv}_niter{args.num_iter}_step{args.step}.npy' ), xest)
np.save( os.path.join(args.results_dir, f'xgt_{args.idx}_lam{args.lamda}_tv{args.tv}_niter{args.num_iter}_step{args.step}.npy' ), xgt)
io.imsave( os.path.join(args.results_dir, f'ymeas_{args.idx}_lam{args.lamda}_tv{args.tv}_niter{args.num_iter}_step{args.step}.png' ), ymeas)
io.imsave( os.path.join(args.results_dir, f'xest_{args.idx}_lam{args.lamda}_tv{args.tv}_niter{args.num_iter}_step{args.step}.png' ), (np.clip((xest+0.5)*255, 0, 255)))
io.imsave( os.path.join(args.results_dir, f'xgt_{args.idx}_lam{args.lamda}_tv{args.tv}_niter{args.num_iter}_step{args.step}.png' ), xgt)
97 changes: 97 additions & 0 deletions recon/recon_dip_tv.py
@@ -0,0 +1,97 @@
""" Copyright (c) 2022-2023 authors
Author : Varun A. Kelkar
Email : vak2@illinois.edu
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import sys
sys.path.append('../src')  # must run before importing the modules that live in src/

import torch
import torch.optim
import numpy as np
import dataset_tool
import models
import degradations
import ast
import os
import argparse
import imageio as io


parser = argparse.ArgumentParser()
# recon args
parser.add_argument("--idx", type=int, default=0, help="Index of the image to be reconstructed")
parser.add_argument("--reg_parameter", type=float, default=0, help="TV regularization parameter")
parser.add_argument("--step", type=float, default=1e-03, help="Step size")
parser.add_argument("--num_iter", type=int, default=10000, help="Number of iterations for recon")
parser.add_argument("--num_bits", type=int, default=0)
parser.add_argument("--results_dir", type=str, default='', help="Results dir")

# model args
parser.add_argument("--input_shape", type=int, nargs='+', default=[3, 28, 28])
parser.add_argument("--model_type", type=str, default='ConvINN')
parser.add_argument("--model_args", type=ast.literal_eval, default={'num_conv_layers': [4,12], 'num_fc_layers': [4]})

# data args
parser.add_argument("--data_type", type=str, default='MNISTDataset')
parser.add_argument("--data_args", type=ast.literal_eval, default={'power_of_two': False})
parser.add_argument("--degradation_type", type=str, default='GaussianNoise')
parser.add_argument("--degradation_args", type=ast.literal_eval, default={'mean':0., 'std':0.3})

args = parser.parse_args()

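def total_variation(img):
    # Anisotropic TV: sum of absolute differences between neighboring pixels
    # along the last two axes (height and width).
    x = img[:,:,1:,:] - img[:,:,:-1,:]
    y = img[:,:,:,1:] - img[:,:,:,:-1]
    tvnorm = torch.sum(torch.abs(x)) + torch.sum(torch.abs(y))
    return tvnorm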

device = f'cuda:0' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)

# forward model
args.data_args['input_shape'] = args.input_shape
degradation = getattr(degradations, args.degradation_type)(**args.degradation_args, input_shape=args.input_shape, num_bits=args.num_bits)

# model
model = getattr(models, args.model_type)(args.input_shape, **args.model_args, device=device).to(device)

# data
noisy_dataset = getattr(dataset_tool, args.data_type)(train=False, ambient=True, degradation=degradation, **args.data_args)
clean_dataset = getattr(dataset_tool, args.data_type)(train=False, ambient=False, **args.data_args)

xgt ,_ = clean_dataset[args.idx]; xgt = xgt.reshape(1,*xgt.shape); xgt = xgt.to(device)
ymeas,_ = noisy_dataset[args.idx]; ymeas = ymeas.reshape(1, *ymeas.shape); ymeas = ymeas.to(device)
print(xgt.min(), xgt.max(), ymeas.min(), ymeas.max())

z = torch.randn([1, *args.input_shape]).to(device)
# model.initialize_actnorm(ymeas)
optimizer = torch.optim.Adam(list(model.trainable_parameters), lr=args.step)

for i in range(args.num_iter):

    optimizer.zero_grad()
    x = model(z)
    tv_reg = args.reg_parameter * total_variation(x)
    loss = 0.5*torch.norm(ymeas - x)**2 + tv_reg
    loss.backward()
    optimizer.step()

    if (i < 10) or (i < 50 and i % 10 == 0) or (i % 50 == 0):
        recon_error = torch.mean( (x - xgt)**2 )
        print(f"Idx : {i:05}, loss : {loss.cpu().detach().numpy()}, tv reg : {tv_reg}, recon. error : {recon_error}")

ymeas = np.squeeze(ymeas.cpu().detach().numpy())
ymeas = np.swapaxes(ymeas.T, 0,1)
xest = np.squeeze(x.cpu().detach().numpy())
xest = np.swapaxes(xest.T, 0,1)
xgt = np.squeeze(xgt.cpu().detach().numpy())
xgt = np.swapaxes(xgt.T, 0,1)
np.save( os.path.join(args.results_dir, f'ymeas_{args.idx}_reg{args.reg_parameter}_niter{args.num_iter}_step{args.step}.npy'), ymeas)
np.save( os.path.join(args.results_dir, f'xest_{args.idx}_reg{args.reg_parameter}_niter{args.num_iter}_step{args.step}.npy' ), xest)
np.save( os.path.join(args.results_dir, f'xgt_{args.idx}_reg{args.reg_parameter}_niter{args.num_iter}_step{args.step}.npy' ), xgt)
io.imsave( os.path.join(args.results_dir, f'ymeas_{args.idx}_reg{args.reg_parameter}_niter{args.num_iter}_step{args.step}.png' ), ymeas)
io.imsave( os.path.join(args.results_dir, f'xest_{args.idx}_reg{args.reg_parameter}_niter{args.num_iter}_step{args.step}.png' ), (np.clip((xest+0.5)*255, 0, 255)))
io.imsave( os.path.join(args.results_dir, f'xgt_{args.idx}_reg{args.reg_parameter}_niter{args.num_iter}_step{args.step}.png' ), xgt)