-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Varun Kelkar
authored and
Varun Kelkar
committed
Jan 2, 2024
1 parent
b831c3d
commit daabf76
Showing
52 changed files
with
8,321 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
results | ||
**/__pycache__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# AmbientFlow: Invertible generative models from incomplete, noisy measurements | ||
|
||
PyTorch implementation | ||
|
||
**Paper**: https://arxiv.org/abs/2309.04856 | ||
|
||
Varun A. Kelkar, Rucha Deshpande, Arindam Banerjee, Mark A. Anastasio <br /> | ||
University of Illinois at Urbana-Champaign, Urbana, IL - 61801, USA | ||
|
||
**Contact**: vak2@illinois.edu, maa@illinois.edu | ||
|
||
**Abstract**: Generative models have gained popularity for their potential applications in imaging science, such as image reconstruction, posterior sampling and data sharing. Flow-based generative models are particularly attractive due to their ability to tractably provide exact density estimates along with fast, inexpensive and diverse samples. Training such models, however, requires a large, high quality dataset of objects. In applications such as computed imaging, it is often difficult to acquire such data due to requirements such as long acquisition time or high radiation dose, while acquiring noisy or partially observed measurements of these objects is more feasible. In this work, we propose AmbientFlow, a framework for learning flow-based generative models directly from noisy and incomplete data. Using variational Bayesian methods, a novel framework for establishing flow-based generative models from noisy, incomplete data is proposed. Extensive numerical studies demonstrate the effectiveness of AmbientFlow in learning the object distribution. The utility of AmbientFlow in a downstream inference task of image reconstruction is demonstrated. | ||
|
||
<p align="center"> | ||
<img src="./docs/schematic.png" alt="AmbientFlow MRI images" width="500"/> | ||
<img src="./docs/AmbientFlow_generated_images_mri.png" alt="AmbientFlow MRI images" width="1000"/> | ||
</p> | ||
|
||
|
||
## System Requirements | ||
- Linux/Unix-based systems recommended. The code hasn't been tested on Windows. | ||
- 64 bit Python 3.8+. The code has been tested with Python 3.8.13 installed via Anaconda | ||
|
||
Other packages required are listed in the `requirements.txt` file. | ||
|
||
## Directory structure | ||
|
||
- The directory `toy` contains the codes corresponding to the toy 2D distribution. | ||
|
||
- The directory `scripts/training` contains the top level shell scripts for training the regular and ambient models. Within this directory, `paper_configs.txt` contains the hyperparameter configurations used for the results in the paper. The number of GPUs and batch sizes listed are according to training on a single NVIDIA Quadro RTX 8000 48 GB GPU. | ||
|
||
- The directory `src` stores all python code defining the flow networks, forward models (degradation models), data processing pipelines, and utilities. The subdirectory `src/metrics` contains codes for evaluating the trained generative models via FID score, log-odds, empirical mean and covariance estimates, and radiomic features (for the MRI study). | ||
|
||
- The directory `masks_mri` contains the undersampling masks relevant to the Fourier undersampling forward operators used for the stylized MRI study. | ||
|
||
- The directory `recon` contains the python codes for running image reconstruction/restoration algorithms including baseline classical algorithms, compressed sensing using generative models (CSGM), and posterior sampling via annealed Langevin dynamics. The details of these approaches can be found in our paper. | ||
|
||
## Data | ||
|
||
The data for the toy study is generated according to `toy/data.py`. The data for the MNIST, CelebA and MRI studies are obtained from the following resources: | ||
- MNIST : Torchvision dataset: https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html | ||
- CelebA-HQ : https://github.com/tkarras/progressive_growing_of_gans | ||
- MRI: T2-weighted brain image volumes from the FastMRI Initiative Database: https://fastmri.med.nyu.edu/ | ||
|
||
## Model weights | ||
|
||
The weights of our trained models can be found at (coming soon). | ||
|
||
## Citations | ||
``` | ||
@article{kelkar2023ambientflow, | ||
title={AmbientFlow: Invertible generative models from incomplete, noisy measurements}, | ||
author={Kelkar, Varun A and Deshpande, Rucha and Banerjee, Arindam and Anastasio, Mark A}, | ||
journal={arXiv preprint arXiv:2309.04856}, | ||
year={2023} | ||
} | ||
``` |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
""" Copyright (c) 2022-2023 authors | ||
Author : Varun A. Kelkar | ||
Email : vak2@illinois.edu | ||
This source code is licensed under the MIT license found in the | ||
LICENSE file in the root directory of this source tree. | ||
""" | ||
|
||
import numpy as np | ||
import bm3d | ||
import torch | ||
import dataset_tool | ||
import degradations | ||
import ast | ||
import os | ||
import argparse | ||
import imageio as io | ||
import glob | ||
import numpy.linalg as la | ||
import sys | ||
sys.path.append('../src') | ||
|
||
parser = argparse.ArgumentParser() | ||
# recon args | ||
parser.add_argument("--num_images", type=int, default=50, help="Index of the image to be reconstructed") | ||
parser.add_argument("--num_bits", type=int, default=0) | ||
parser.add_argument("--results_dir", type=str, default='', help="Results dir") | ||
parser.add_argument("--dataset", action='store_true') | ||
|
||
# data args | ||
parser.add_argument("--input_shape", type=int, nargs='+', default=[3, 28, 28]) | ||
parser.add_argument("--data_type", type=str, default='MNISTDataset') | ||
parser.add_argument("--data_args", type=ast.literal_eval, default={'power_of_two': False}) | ||
parser.add_argument("--degradation_type", type=str, default='GaussianNoise') | ||
parser.add_argument("--degradation_args", type=ast.literal_eval, default={'mean':0., 'std':0.3}) | ||
parser.add_argument("--tune", action='store_true', help="Tune the regularization parameter") | ||
|
||
args = parser.parse_args() | ||
|
||
torch.manual_seed(0) | ||
|
||
# forward model | ||
args.data_args['input_shape'] = args.input_shape | ||
degradation = getattr(degradations, args.degradation_type)(**args.degradation_args, input_shape=args.input_shape, num_bits=args.num_bits) | ||
|
||
# data | ||
fnames = glob.glob(f"/shared/aristotle/SOMS/varun/ambientflow/data/CelebAHQDataset-GaussianNoise-0.2-image-data/*.png") | ||
fnames_gt = glob.glob(f"/shared/aristotle/SOMS/varun/ambientflow/data/CelebAHQDataset-real-image-data/*.png") | ||
|
||
for idx in range(args.num_images): | ||
print(idx) | ||
ymeas = io.imread(fnames[5000+idx]) / 255 | ||
xgt = io.imread(fnames_gt[5000+idx]) / 255 | ||
# ymeas = np.swapaxes(ymeas,0,1).T | ||
# ymeas = ymeas.reshape(1,3,64,64) | ||
# ymeas,_ = noisy_dataset[idx]; ymeas = ymeas.reshape(1, *ymeas.shape) | ||
# ymeas = degradation.rev(ymeas, mode='real', use_device=False) | ||
# print(xgt.min(), xgt.max(), ymeas.min(), ymeas.max()) | ||
|
||
# ymeas = np.squeeze(ymeas.numpy()) | ||
# ymeas = np.swapaxes(ymeas.T, 0,1) | ||
# xgt = np.squeeze(xgt.numpy()) | ||
# xgt = np.swapaxes(xgt.T, 0,1) | ||
|
||
# if args.tune: | ||
# xests = [] | ||
# for i in range(5): | ||
# reg_parameter = args.reg_parameter * ( args.reg_parameter2 / args.reg_parameter )**(i/4) | ||
xest = bm3d.bm3d(ymeas, sigma_psd=args.reg_parameter, stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING) | ||
print("RMSE error : ", la.norm(xest - xgt)/np.sqrt(np.prod(xgt.shape))) | ||
|
||
# ymeas = np.clip(ymeas,-0.5,0.5) | ||
if args.dataset: | ||
os.makedirs(os.path.join(args.results_dir, 'dataset'), exist_ok=True) | ||
|
||
if args.dataset: | ||
io.imsave( os.path.join(args.results_dir, 'dataset', f'xest_{idx}_reg{args.reg_parameter}.png' ), xest) | ||
else: | ||
np.save( os.path.join(args.results_dir, f'ymeas_{idx}_reg{args.reg_parameter}.npy'), ymeas) | ||
np.save( os.path.join(args.results_dir, f'xest_{idx}_reg{args.reg_parameter}.npy' ), xest) | ||
np.save( os.path.join(args.results_dir, f'xgt_{idx}_reg{args.reg_parameter}.npy' ), xgt) | ||
io.imsave( os.path.join(args.results_dir, f'ymeas_{idx}_reg{args.reg_parameter}.png' ), ymeas) | ||
io.imsave( os.path.join(args.results_dir, f'xest_{idx}_reg{args.reg_parameter}.png' ), xest) | ||
io.imsave( os.path.join(args.results_dir, f'xgt_{idx}_reg{args.reg_parameter}.png' ), xgt) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
""" Copyright (c) 2022-2023 authors | ||
Author : Varun A. Kelkar | ||
Email : vak2@illinois.edu | ||
This source code is licensed under the MIT license found in the | ||
LICENSE file in the root directory of this source tree. | ||
""" | ||
|
||
import torch | ||
import torch.optim | ||
import numpy as np | ||
import dataset_tool | ||
from models2 import * | ||
import degradations | ||
import ast | ||
import os | ||
import argparse | ||
import imageio as io | ||
import sys | ||
sys.path.append('../src') | ||
|
||
parser = argparse.ArgumentParser() | ||
# recon args | ||
parser.add_argument("--idx", type=int, default=0, help="Index of the image to be reconstructed") | ||
parser.add_argument("--lamda", type=float, default=0, help="MAP regularization parameter") | ||
parser.add_argument("--tv", type=float, default=0, help="TV regularization parameter") | ||
parser.add_argument("--step", type=float, default=1e-03, help="Step size") | ||
parser.add_argument("--num_iter", type=int, default=10000, help="Number of iterations for recon") | ||
parser.add_argument("--num_bits", type=int, default=0) | ||
parser.add_argument("--results_dir", type=str, default='', help="Results dir") | ||
parser.add_argument("--project_on_mask", action='store_true', help="Only applicable for inpainting") | ||
|
||
# model args | ||
parser.add_argument("--input_shape", type=int, nargs='+', default=[3, 28, 28]) | ||
parser.add_argument("--model_path", type=str, default='', help="Path to model pkl file") | ||
|
||
# data args | ||
parser.add_argument("--data_type", type=str, default='MNISTDataset') | ||
parser.add_argument("--data_args", type=ast.literal_eval, default={'power_of_two': False}) | ||
parser.add_argument("--degradation_type", type=str, default='GaussianNoise') | ||
parser.add_argument("--degradation_args", type=ast.literal_eval, default={'mean':0., 'std':0.3}) | ||
|
||
args = parser.parse_args() | ||
print(args, flush=True) | ||
|
||
def total_variation(img): | ||
x = img[:,:,1:,:] - img[:,:,:-1,:] | ||
y = img[:,:,:,1:] - img[:,:,:,:-1] | ||
tvnorm = torch.sum(torch.abs(x)) + torch.sum(torch.abs(y)) | ||
return tvnorm | ||
|
||
device = f'cuda:0' if torch.cuda.is_available() else 'cpu' | ||
torch.manual_seed(0) | ||
|
||
# data | ||
# noisy_dataset = getattr(dataset_tool, args.data_type)(train=False, ambient=True, degradation=degradation, **args.data_args) | ||
clean_dataset = getattr(dataset_tool, args.data_type)(train=False, ambient=False, input_shape=args.input_shape, **args.data_args) | ||
|
||
# forward model | ||
args.data_args['input_shape'] = args.input_shape | ||
degradation = getattr(degradations, args.degradation_type)(**args.degradation_args, input_shape=args.input_shape, num_bits=args.num_bits) | ||
|
||
xgt ,_ = clean_dataset[args.idx]; xgt = xgt.reshape(1,*xgt.shape) | ||
# ymeas,_ = noisy_dataset[args.idx]; ymeas = ymeas.reshape(1, *ymeas.shape); ymeas = ymeas.to(device) | ||
ymeas = degradation(xgt).to(device) | ||
xgt = xgt.to(device) | ||
print(xgt.min(), xgt.max(), abs(ymeas).min(), abs(ymeas).max()) | ||
|
||
# model | ||
if '/ambient/' in args.model_path: | ||
model = load_model(args.model_path, ambient=True).to(device) | ||
else: model = load_model(args.model_path).to(device) | ||
model.eval() | ||
|
||
# z = torch.zeros([1, np.prod(args.input_shape)], requires_grad=True).to(device) | ||
z = torch.zeros( [1, np.prod(args.input_shape)], requires_grad=True, device=device) | ||
# z = torch.nn.Parameter(z).to(device) | ||
# model.initialize_actnorm(ymeas) | ||
optimizer = torch.optim.Adam([z], lr=args.step) | ||
|
||
for i in range(args.num_iter): | ||
|
||
optimizer.zero_grad() | ||
x,_ = model.reverse(z) | ||
map_reg = args.lamda * torch.sum(z**2) | ||
tv_reg = args.tv * total_variation(x) | ||
loss = 0.5*torch.norm(ymeas - degradation.fwd_noiseless(x, use_device=True))**2 + tv_reg + map_reg | ||
loss.backward() | ||
optimizer.step() | ||
|
||
if (i < 10) or (i < 50 and i % 10 == 0) or (i % 50 == 0) or (i == args.num_iter-1): | ||
if args.project_on_mask: | ||
x[:, degradation.mask.astype(bool)] = xgt[:, degradation.mask.astype(bool)] | ||
recon_error = torch.mean( (x - xgt)**2 ) | ||
print(f"Idx : {i:05}, loss : {loss.cpu().detach().numpy()}, tv reg : {tv_reg}, recon. error : {recon_error}", flush=True) | ||
|
||
ymeas = np.squeeze(ymeas.cpu().detach().numpy()) | ||
ymeas = np.swapaxes(ymeas.T, 0,1) | ||
xest = np.squeeze(x.cpu().detach().numpy()) | ||
xest = np.swapaxes(xest.T, 0,1) | ||
xgt = np.squeeze(xgt.cpu().detach().numpy()) | ||
xgt = np.swapaxes(xgt.T, 0,1) | ||
np.save( os.path.join(args.results_dir, f'ymeas_{args.idx}_lam{args.lamda}_tv{args.tv}_niter{args.num_iter}_step{args.step}.npy'), ymeas) | ||
np.save( os.path.join(args.results_dir, f'xest_{args.idx}_lam{args.lamda}_tv{args.tv}_niter{args.num_iter}_step{args.step}.npy' ), xest) | ||
np.save( os.path.join(args.results_dir, f'xgt_{args.idx}_lam{args.lamda}_tv{args.tv}_niter{args.num_iter}_step{args.step}.npy' ), xgt) | ||
io.imsave( os.path.join(args.results_dir, f'ymeas_{args.idx}_lam{args.lamda}_tv{args.tv}_niter{args.num_iter}_step{args.step}.png' ), ymeas) | ||
io.imsave( os.path.join(args.results_dir, f'xest_{args.idx}_lam{args.lamda}_tv{args.tv}_niter{args.num_iter}_step{args.step}.png' ), (np.clip((xest+0.5)*255, 0, 255))) | ||
io.imsave( os.path.join(args.results_dir, f'xgt_{args.idx}_lam{args.lamda}_tv{args.tv}_niter{args.num_iter}_step{args.step}.png' ), xgt) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
""" Copyright (c) 2022-2023 authors | ||
Author : Varun A. Kelkar | ||
Email : vak2@illinois.edu | ||
This source code is licensed under the MIT license found in the | ||
LICENSE file in the root directory of this source tree. | ||
""" | ||
|
||
import torch | ||
import torch.optim | ||
import numpy as np | ||
import dataset_tool | ||
import models | ||
import degradations | ||
import ast | ||
import os | ||
import argparse | ||
import imageio as io | ||
import sys | ||
sys.path.append('../src') | ||
|
||
|
||
parser = argparse.ArgumentParser() | ||
# recon args | ||
parser.add_argument("--idx", type=int, default=0, help="Index of the image to be reconstructed") | ||
parser.add_argument("--reg_parameter", type=float, default=0, help="TV regularization parameter") | ||
parser.add_argument("--step", type=float, default=1e-03, help="Step size") | ||
parser.add_argument("--num_iter", type=int, default=10000, help="Number of iterations for recon") | ||
parser.add_argument("--num_bits", type=int, default=0) | ||
parser.add_argument("--results_dir", type=str, default='', help="Results dir") | ||
|
||
# model args | ||
parser.add_argument("--input_shape", type=int, nargs='+', default=[3, 28, 28]) | ||
parser.add_argument("--model_type", type=str, default='ConvINN') | ||
parser.add_argument("--model_args", type=ast.literal_eval, default={'num_conv_layers': [4,12], 'num_fc_layers': [4]}) | ||
|
||
# data args | ||
parser.add_argument("--data_type", type=str, default='MNISTDataset') | ||
parser.add_argument("--data_args", type=ast.literal_eval, default={'power_of_two': False}) | ||
parser.add_argument("--degradation_type", type=str, default='GaussianNoise') | ||
parser.add_argument("--degradation_args", type=ast.literal_eval, default={'mean':0., 'std':0.3}) | ||
|
||
args = parser.parse_args() | ||
|
||
def total_variation(img): | ||
x = img[:,:,1:,:] - img[:,:,:-1,:] | ||
y = img[:,:,:,1:] - img[:,:,:,:-1] | ||
tvnorm = torch.sum(torch.abs(x)) + torch.sum(torch.abs(y)) | ||
return tvnorm | ||
|
||
device = f'cuda:0' if torch.cuda.is_available() else 'cpu' | ||
torch.manual_seed(0) | ||
|
||
# forward model | ||
args.data_args['input_shape'] = args.input_shape | ||
degradation = getattr(degradations, args.degradation_type)(**args.degradation_args, input_shape=args.input_shape, num_bits=args.num_bits) | ||
|
||
# model | ||
model = getattr(models, args.model_type)(args.input_shape, **args.model_args, device=device).to(device) | ||
|
||
# data | ||
noisy_dataset = getattr(dataset_tool, args.data_type)(train=False, ambient=True, degradation=degradation, **args.data_args) | ||
clean_dataset = getattr(dataset_tool, args.data_type)(train=False, ambient=False, **args.data_args) | ||
|
||
xgt ,_ = clean_dataset[args.idx]; xgt = xgt.reshape(1,*xgt.shape); xgt = xgt.to(device) | ||
ymeas,_ = noisy_dataset[args.idx]; ymeas = ymeas.reshape(1, *ymeas.shape); ymeas = ymeas.to(device) | ||
print(xgt.min(), xgt.max(), ymeas.min(), ymeas.max()) | ||
|
||
z = torch.randn([1, *args.input_shape]).to(device) | ||
# model.initialize_actnorm(ymeas) | ||
optimizer = torch.optim.Adam(list(model.trainable_parameters), lr=args.step) | ||
|
||
for i in range(args.num_iter): | ||
|
||
optimizer.zero_grad() | ||
x = model(z) | ||
tv_reg = args.reg_parameter * total_variation(x) | ||
loss = 0.5*torch.norm(ymeas - x)**2 + tv_reg | ||
loss.backward() | ||
optimizer.step() | ||
|
||
if (i < 10) or (i < 50 and i % 10 == 0) or (i % 50 == 0): | ||
recon_error = torch.mean( (x - xgt)**2 ) | ||
print(f"Idx : {i:05}, loss : {loss.cpu().detach().numpy()}, tv reg : {tv_reg}, recon. error : {recon_error}") | ||
|
||
ymeas = np.squeeze(ymeas.cpu().detach().numpy()) | ||
ymeas = np.swapaxes(ymeas.T, 0,1) | ||
xest = np.squeeze(x.cpu().detach().numpy()) | ||
xest = np.swapaxes(xest.T, 0,1) | ||
xgt = np.squeeze(xgt.cpu().detach().numpy()) | ||
xgt = np.swapaxes(xgt.T, 0,1) | ||
np.save( os.path.join(args.results_dir, f'ymeas_{args.idx}_reg{args.reg_parameter}_niter{args.num_iter}_step{args.step}.npy'), ymeas) | ||
np.save( os.path.join(args.results_dir, f'xest_{args.idx}_reg{args.reg_parameter}_niter{args.num_iter}_step{args.step}.npy' ), xest) | ||
np.save( os.path.join(args.results_dir, f'xgt_{args.idx}_reg{args.reg_parameter}_niter{args.num_iter}_step{args.step}.npy' ), xgt) | ||
io.imsave( os.path.join(args.results_dir, f'ymeas_{args.idx}_reg{args.reg_parameter}_niter{args.num_iter}_step{args.step}.png' ), ymeas) | ||
io.imsave( os.path.join(args.results_dir, f'xest_{args.idx}_reg{args.reg_parameter}_niter{args.num_iter}_step{args.step}.png' ), (np.clip((xest+0.5)*255, 0, 255))) | ||
io.imsave( os.path.join(args.results_dir, f'xgt_{args.idx}_reg{args.reg_parameter}_niter{args.num_iter}_step{args.step}.png' ), xgt) |
Oops, something went wrong.