[Community] Implementation of the IADB community pipeline #3996

Merged
merged 3 commits into from
Jul 13, 2023
60 changes: 60 additions & 0 deletions examples/community/README.md
@@ -38,6 +38,7 @@ If a community doesn't work as expected, please open an issue and ping the autho
| Stable Diffusion IPEX Pipeline | Accelerate Stable Diffusion inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion on IPEX](#stable-diffusion-on-ipex) | - | [Yingjie Han](https://github.com/yingjie-han/) |
| CLIP Guided Images Mixing Stable Diffusion Pipeline | Combine images using usual diffusion models. | [CLIP Guided Images Mixing Using Stable Diffusion](#clip-guided-images-mixing-with-stable-diffusion) | - | [Karachev Denis](https://github.com/TheDenk) |
| TensorRT Stable Diffusion Inpainting Pipeline | Accelerates the Stable Diffusion Inpainting Pipeline using TensorRT | [TensorRT Stable Diffusion Inpainting Pipeline](#tensorrt-inpainting-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon) |

To load a custom pipeline, pass the `custom_pipeline` argument to `DiffusionPipeline`, set to the name of one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines; we will merge them quickly.
```py
@@ -1707,3 +1708,62 @@ output = pipeline(
```
![Input_Image](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/input_image.png)
![mixture_canvas_results](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/canvas.png)


### IADB pipeline

This pipeline implements the [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) paper.
IADB is a minimalist, deterministic diffusion model.

The following code shows how to use the IADB pipeline to generate images with a pretrained celebahq-256 model.

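```python
import matplotlib.pyplot as plt
from diffusers import DiffusionPipeline

# Load the pretrained weights together with the community `iadb` pipeline code
pipeline_iadb = DiffusionPipeline.from_pretrained("thomasc4/iadb-celebahq-256", custom_pipeline="iadb")
pipeline_iadb = pipeline_iadb.to("cuda")

# Generate 4 images with 128 deblending steps
output = pipeline_iadb(batch_size=4, num_inference_steps=128)

# `output.images` is a list of PIL images (default `output_type="pil"`)
for image in output.images:
    plt.imshow(image)
    plt.show()
```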

Sampling with the IADB formulation is easy, and can be done in a few lines (the pipeline already implements it):

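```python
import torch


def sample_iadb(model, x0, nb_step):
    # Start from pure noise (alpha = 0) and integrate towards the data (alpha = 1)
    x_alpha = x0
    for t in range(nb_step):
        alpha = t / nb_step
        alpha_next = (t + 1) / nb_step

        # The model predicts the direction from x0 to x1
        d = model(x_alpha, torch.tensor(alpha, device=x_alpha.device))["sample"]
        x_alpha = x_alpha + (alpha_next - alpha) * d

    return x_alpha
```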
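
As a sanity check on the update rule above, the sketch below runs the same loop on plain Python floats. The names `sample_iadb_scalar` and `ideal_direction` are hypothetical, introduced here only for illustration and not part of the pipeline. The sketch assumes a perfect model that always returns `x1 - x0`; under that assumption the step sizes sum to 1, so the loop ends at `x1` for any number of steps.

```python
def sample_iadb_scalar(direction, x0, nb_step):
    """Scalar version of the Euler update used by the sampler (illustrative only)."""
    x_alpha = x0
    for t in range(nb_step):
        alpha = t / nb_step
        alpha_next = (t + 1) / nb_step
        # Same update as in the sampler: move along the predicted direction
        x_alpha = x_alpha + (alpha_next - alpha) * direction(x_alpha, alpha)
    return x_alpha


x0, x1 = -2.0, 3.0  # arbitrary "noise" and "data" values


def ideal_direction(x_alpha, alpha):
    # Assumption: a perfect model that always predicts x1 - x0
    return x1 - x0


for nb_step in (1, 8, 128):
    x_final = sample_iadb_scalar(ideal_direction, x0, nb_step)
    print(nb_step, x_final)
```

With a trained network the predicted direction changes along the trajectory, which is why more steps usually give better samples.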

The training loop is also straightforward:

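```python
# Training loop
# `sample_noise`, `sample_dataset`, `D`, `optimizer` and `batch_size` are placeholders.
while True:
    x0 = sample_noise()
    x1 = sample_dataset()

    # One blending coefficient per sample, reshaped so that it broadcasts
    # over image batches of shape (batch_size, C, H, W)
    alpha = torch.rand(batch_size)
    alpha_b = alpha.view(-1, 1, 1, 1)

    # Blend
    x_alpha = (1 - alpha_b) * x0 + alpha_b * x1

    # Loss: the model should predict the direction from x0 to x1
    loss = torch.sum((D(x_alpha, alpha) - (x1 - x0)) ** 2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```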
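
The loop above leaves the model, data and optimizer undefined. The following is a minimal runnable sketch of a few training steps, under stated assumptions: `TinyDirectionNet` is a hypothetical stand-in for `D` (not the U-Net used by the pretrained checkpoint), `training_step` is a helper name introduced here, and the "dataset" is a random tensor used only to exercise the code. Note that `alpha` has shape `(batch_size,)` and is reshaped to `(batch_size, 1, 1, 1)` before blending, so that it broadcasts over image tensors.

```python
import torch
from torch import nn

torch.manual_seed(0)


class TinyDirectionNet(nn.Module):
    """Hypothetical stand-in for D: a small conv net that receives alpha as an extra channel."""

    def __init__(self, channels=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels + 1, 16, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, channels, kernel_size=3, padding=1),
        )

    def forward(self, x_alpha, alpha):
        # Broadcast alpha to a (B, 1, H, W) map and append it to the input channels
        alpha_map = alpha.view(-1, 1, 1, 1).expand(-1, 1, *x_alpha.shape[2:])
        return self.net(torch.cat([x_alpha, alpha_map], dim=1))


def training_step(model, optimizer, x1):
    x0 = torch.randn_like(x1)                          # noise sample
    alpha = torch.rand(x1.shape[0], device=x1.device)  # one alpha per sample
    a = alpha.view(-1, 1, 1, 1)                        # reshape for broadcasting
    x_alpha = (1 - a) * x0 + a * x1                    # blend
    loss = torch.sum((model(x_alpha, alpha) - (x1 - x0)) ** 2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


model = TinyDirectionNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
fake_batch = torch.rand(4, 3, 8, 8) * 2 - 1  # placeholder data in [-1, 1], not a real dataset
losses = [training_step(model, optimizer, fake_batch) for _ in range(5)]
print(losses)
```

In practice `fake_batch` would be replaced by batches from a real dataset, and the model by a network such as the `UNet2DModel` that the pipeline expects.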
149 changes: 149 additions & 0 deletions examples/community/iadb.py
@@ -0,0 +1,149 @@
from typing import List, Optional, Tuple, Union

import torch

from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import ImagePipelineOutput
from diffusers.schedulers.scheduling_utils import SchedulerMixin


class IADBScheduler(SchedulerMixin, ConfigMixin):
    """
    IADBScheduler is a scheduler for the Iterative α-(de)Blending denoising method. It is simple and minimalist.

    For more details, see the original paper: https://arxiv.org/abs/2305.03486 and the blog post: https://ggx-research.github.io/publication/2023/05/10/publication-iadb.html
    """

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        x_alpha: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Advance the sample by one Euler step of the deblending ODE, moving from
        `alpha = timestep / num_inference_steps` to `alpha_next = (timestep + 1) / num_inference_steps`.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model. It is the direction from x0 to x1.
            timestep (`int`): index of the current step, in `[0, num_inference_steps)`.
            x_alpha (`torch.FloatTensor`): x_alpha sample for the current timestep

        Returns:
            `torch.FloatTensor`: the sample at the next step (`alpha_next`)

        """
        # `num_inference_steps` only exists once `set_timesteps` has been called
        if getattr(self, "num_inference_steps", None) is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        alpha = timestep / self.num_inference_steps
        alpha_next = (timestep + 1) / self.num_inference_steps

        d = model_output

        x_alpha = x_alpha + (alpha_next - alpha) * d

        return x_alpha

    def set_timesteps(self, num_inference_steps: int):
        self.num_inference_steps = num_inference_steps

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        alpha: torch.FloatTensor,
    ) -> torch.FloatTensor:
        # Blend: alpha = 0 gives pure noise, alpha = 1 gives the original samples
        return original_samples * alpha + noise * (1 - alpha)

    def __len__(self):
        return self.config.num_train_timesteps


class IADBPipeline(DiffusionPipeline):
    r"""
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        unet ([`UNet2DModel`]): U-Net architecture that predicts the direction from x0 to x1.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to deblend the image. This pipeline is written for
            [`IADBScheduler`].
    """

    def __init__(self, unet, scheduler):
        super().__init__()

        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 50,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                Random generator(s) used to sample the initial noise.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        # Sample gaussian noise to begin loop
        if isinstance(self.unet.config.sample_size, int):
            image_shape = (
                batch_size,
                self.unet.config.in_channels,
                self.unet.config.sample_size,
                self.unet.config.sample_size,
            )
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        image = torch.randn(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)
        x_alpha = image.clone()
        for t in self.progress_bar(range(num_inference_steps)):
            alpha = t / num_inference_steps

            # 1. predict the direction from x0 to x1
            model_output = self.unet(x_alpha, torch.tensor(alpha, device=x_alpha.device)).sample

            # 2. step
            x_alpha = self.scheduler.step(model_output, t, x_alpha)

        # Map from [-1, 1] to [0, 1] and convert to channels-last numpy arrays
        image = (x_alpha * 0.5 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)