Authors: Ismail R. Alkhouri*, Shijun Liang*, Cheng-Han Huang, Jimmy Dai, Qing Qu, Saiprasad Ravishankar, and Rongrong Wang.
Diffusion models (DMs) are a class of generative models that allow sampling from a distribution learned over a training set. When applied to solving inverse imaging problems (IPs), the reverse sampling steps of DMs are typically modified to approximately sample from a measurement-conditioned distribution in the image space. However, these modifications may be unsuitable for certain settings (such as in the presence of measurement noise) and non-linear tasks, as they often struggle to correct errors from earlier sampling steps and generally require a large number of optimization and/or sampling steps. To address these challenges, we state three conditions for achieving measurement-consistent diffusion trajectories. Building on these conditions, we propose a new optimization-based sampling method that not only enforces the standard data manifold measurement consistency and forward diffusion consistency, as seen in previous studies, but also incorporates backward diffusion consistency that maintains a diffusion trajectory by optimizing over the input of the pre-trained model at every sampling step. By enforcing these conditions, either implicitly or explicitly, our sampler requires significantly fewer reverse steps. Therefore, we refer to our accelerated sampling method as Step-wise Triple-Consistent Sampling (SITCOM). Compared to existing state-of-the-art baseline methods, under different levels of measurement noise, our extensive experiments across five linear and three non-linear image restoration tasks demonstrate that SITCOM achieves competitive or superior results in terms of standard image similarity metrics while requiring a significantly reduced run-time across all considered tasks.
- python 3.8
- pytorch 1.11.0
- CUDA 11.3.1
- nvidia-docker (if you use GPU in docker container)
It is fine to use a lower version of CUDA with a matching pytorch version, for example CUDA 10.2 with pytorch 1.7.0.
### For phase retrieval with the ODE solver, refer to the separate repository [SITCOM-ODE](https://github.com/sjames40/SITCOM_ODE).
From the link, download the checkpoint "ffhq_10m.pt" and place it in ./models/
mkdir models
mv {DOWNLOAD_DIR}/ffhq_10m.pt ./models/
{DOWNLOAD_DIR} is the directory to which you downloaded the checkpoint.
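Before running the samplers, it can help to confirm that the checkpoint ended up where the scripts expect it. The following is a minimal sketch only: the helper name `checkpoint_ready` is hypothetical and not part of this repository, and the only detail taken from the steps above is the `./models/ffhq_10m.pt` location.

```python
from pathlib import Path


def checkpoint_ready(models_dir="./models", name="ffhq_10m.pt"):
    """Return True if the checkpoint file exists and is non-empty.

    `checkpoint_ready` is an illustrative helper, not a repository function.
    """
    ckpt = Path(models_dir) / name
    return ckpt.is_file() and ckpt.stat().st_size > 0


if __name__ == "__main__":
    if checkpoint_ready():
        print("Checkpoint found.")
    else:
        print("Checkpoint missing: expected ./models/ffhq_10m.pt")
```

A check like this catches a misspelled file name early, before any model loading is attempted.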
We use external code for motion blurring and non-linear deblurring, similar to DPS (see references below).
git clone https://github.com/VinAIResearch/blur-kernel-space-exploring bkse
git clone https://github.com/LeviBorodenko/motionblur motionblur
Install dependencies
conda create -n new_diffusion python=3.8
conda activate new_diffusion
pip install torch
pip install numpy
pip install diffuser
pip install matplotlib
For the noiseless-measurement case, run python SITCOM.py.
For noisy measurements (for example, a noise level of 0.05), run SITCOM_with_noise.py with arguments such as the following super-resolution (SR) example:
python3 SITCOM_with_noise.py \
--model_config=configs/model_config.yaml \
--diffusion_config=configs/diffusion_config.yaml \
--task_config=configs/super_resolution_config.yaml \
--gpu=0 \
--file_path=/path/to/input_file \
--save_path=/path/to/output_dir \
--device=cuda \
--learning_rate=0.02 \
--num_steps=30 \
--n_step=20 \
--noiselevel=0.05 \
--random_seed=42
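When running several tasks in a row, it may be convenient to assemble this command programmatically. The sketch below is an assumption-laden illustration: `build_sitcom_command` is a hypothetical helper, and every flag and value is copied from the example command above rather than from the script's argument parser.

```python
import shlex


def build_sitcom_command(task_config, file_path, save_path, noise_level=0.05):
    """Assemble the argument list for SITCOM_with_noise.py.

    Hypothetical helper; flags and defaults mirror the example command only.
    """
    return [
        "python3", "SITCOM_with_noise.py",
        "--model_config=configs/model_config.yaml",
        "--diffusion_config=configs/diffusion_config.yaml",
        f"--task_config={task_config}",
        "--gpu=0",
        f"--file_path={file_path}",
        f"--save_path={save_path}",
        "--device=cuda",
        "--learning_rate=0.02",
        "--num_steps=30",
        "--n_step=20",
        f"--noiselevel={noise_level}",
        "--random_seed=42",
    ]


cmd = build_sitcom_command(
    "configs/super_resolution_config.yaml",
    "/path/to/input_file",
    "/path/to/output_dir",
)
# Print a shell-safe rendering of the command for inspection.
print(shlex.join(cmd))
# To launch it, one could pass the list to subprocess.run(cmd, check=True).
```

Passing the arguments as a list avoids shell quoting problems, including the line-continuation issue that arises when a backslash directly follows the script name.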
The threshold may vary with the task, as it depends on the image size: for instance, it is set to 20 for super-resolution, 80 for phase retrieval, and 50 for the other imaging tasks.
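The per-task thresholds quoted above can be kept in a small lookup table. This is a sketch under stated assumptions: the task keys and the helper `threshold_for_task` are hypothetical names chosen for illustration, and which command-line flag receives the threshold is deliberately left open.

```python
# Threshold values as quoted in the text; the task keys are illustrative names.
TASK_THRESHOLDS = {
    "super_resolution": 20,
    "phase_retrieval": 80,
}
DEFAULT_THRESHOLD = 50  # quoted value for the other imaging tasks


def threshold_for_task(task):
    """Return the quoted threshold for a task, falling back to the default."""
    return TASK_THRESHOLDS.get(task, DEFAULT_THRESHOLD)


for name in ("super_resolution", "phase_retrieval", "inpainting"):
    print(name, threshold_for_task(name))
```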
For ImageNet, use configs/imagenet_model_config.yaml as the model config.
# Linear inverse problems
- configs/super_resolution_config.yaml
- configs/gaussian_deblur_config.yaml
- configs/motion_deblur_config.yaml
- configs/inpainting_config.yaml
# Non-linear inverse problems
- configs/nonlinear_deblur_config.yaml
- configs/phase_retrieval_config.yaml
- configs/hdr.yaml
Set your data directory in data.root. The default is ./data/samples, which contains 10 sample images from the FFHQ validation set.
data:
  name: ffhq
  root: ./data/samples/
measurement:
  operator:
    name: # check candidates in guided_diffusion/measurements.py
  noise:
    name: # gaussian
    sigma: # if you use name: gaussian, set this.
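A task config can be sanity-checked once it has been loaded into a Python dict. The sketch below is illustrative only: `validate_task_config` is a hypothetical helper, the operator name in the example is an assumed value, and because the nesting of `noise` is an assumption, the code accepts it either under `measurement` or at the top level.

```python
from pathlib import Path


def validate_task_config(cfg):
    """Return a list of problems found in a task config dict (hypothetical helper)."""
    problems = []

    # data.root must be set and point at an existing directory.
    root = (cfg.get("data") or {}).get("root")
    if not root:
        problems.append("data.root is not set")
    elif not Path(root).is_dir():
        problems.append(f"data.root does not exist: {root}")

    # measurement.operator.name selects the forward operator.
    measurement = cfg.get("measurement") or {}
    if not (measurement.get("operator") or {}).get("name"):
        problems.append("measurement.operator.name is not set")

    # Assumption: noise may sit under measurement or at the top level.
    noise = measurement.get("noise") or cfg.get("noise") or {}
    if noise.get("name") == "gaussian" and noise.get("sigma") is None:
        problems.append("noise.sigma is required when noise.name is gaussian")

    return problems


example = {
    "data": {"name": "ffhq", "root": "./data/samples/"},
    "measurement": {
        "operator": {"name": "super_resolution"},  # assumed operator name
        "noise": {"name": "gaussian", "sigma": 0.05},
    },
}
print(validate_task_config(example))
```

An empty list means none of the checked fields are missing; the function does not verify that the operator name is one of the candidates in guided_diffusion/measurements.py.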
@inproceedings{sitcomICML25,
author = {Alkhouri, Ismail and Liang, Shijun and Huang, Cheng-Han and Dai, Jimmy and Qu, Qing and Ravishankar, Saiprasad and Wang, Rongrong},
title = {SITCOM: Step-wise Triple-Consistent Diffusion Sampling For Inverse Problems},
booktitle = {International Conference on Machine Learning (ICML)},
year = {2025}
}


