Commit

initial code release

pix2pixzero committed Feb 11, 2023
1 parent ddccacd commit 25b8857
Showing 38 changed files with 1,348 additions and 6 deletions.
101 changes: 95 additions & 6 deletions README.md
@@ -1,8 +1,13 @@
# pix2pix-zero

## [**[website]**](https://pix2pixzero.github.io/)


This is the authors' reimplementation of "Zero-shot Image-to-Image Translation" using the diffusers library. <br>
The results in the paper are based on the [CompVis](https://github.com/CompVis/stable-diffusion) codebase; that version of the code will be released later.

**[New!]** Code for editing real and synthetic images released!

<br>
@@ -20,14 +25,31 @@
## Results
All our results are based on the [stable-diffusion-v1-4](https://github.com/CompVis/stable-diffusion) model. Please see the website for more results.




<div>
<p align="center">
<img src='assets/results_teaser.jpg' align="center" width=800px>
</p>
</div>
<hr>

For each result below, the top row shows editing of a real image and the bottom row shows editing of a synthetic image.
<div>
<p align="center">
<img src='assets/grid_dog2cat.jpg' align="center" width=800px>
</p>
<p align="center">
<img src='assets/grid_zebra2horse.jpg' align="center" width=800px>
</p>
<p align="center">
<img src='assets/grid_cat2dog.jpg' align="center" width=800px>
</p>
<p align="center">
<img src='assets/grid_horse2zebra.jpg' align="center" width=800px>
</p>
<p align="center">
<img src='assets/grid_tree2fall.jpg' align="center" width=800px>
</p>
</div>

## Real Image Editing
<div>
@@ -56,6 +78,73 @@
</p>
</div>


## Getting Started

**Environment Setup**
- We provide a [conda env file](environment.yml) that contains all the required dependencies.
```
conda env create -f environment.yml
```
- Following this, you can activate the conda environment with the command below.
```
conda activate pix2pix-zero
```

**Real Image Translation**
- First, run the inversion command below to obtain the input noise that reconstructs the image.
The command saves the inversion in the results folder as `output/test_cat/inversion/cat_1.pt`
and the BLIP-generated prompt as `output/test_cat/prompt/cat_1.txt`; a quick sanity check on the saved latent is sketched after the command.
```
python src/inversion.py \
--input_image "assets/test_images/cats/cat_1.png" \
--results_folder "output/test_cat"
```
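
As a quick sanity check, you can load the saved inversion and confirm it is a Stable Diffusion latent before editing. This is a minimal sketch; the expected shape is an assumption based on `src/inversion.py`, which saves `x_inv[0]` for a 512x512 input.
```
import torch

# Load the inverted noise map saved by src/inversion.py (path from the command above).
x_inv = torch.load("output/test_cat/inversion/cat_1.pt")

# For a 512x512 input, the SD-v1.4 latent should be 4x64x64 (assumption).
print(x_inv.shape)  # expected: torch.Size([4, 64, 64])
```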
- Next, we can perform image editing along the edit direction, as shown below.
The command below will save the edited image as `output/test_cat/edit/cat_1.png`.
```
python src/edit_real.py \
--inversion "output/test_cat/inversion/cat_1.pt" \
--prompt "output/test_cat/prompt/cat_1.txt" \
--task_name "cat2dog" \
--results_folder "output/test_cat/"
```
**Editing Synthetic Images**
- Similarly, we can edit synthetic images generated by Stable Diffusion with the following command.
```
python src/edit_synthetic.py \
--results_folder "output/synth_editing" \
    --prompt_str "a high resolution painting of a cat in the style of van gogh" \
    --task_name "cat2dog"
```
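
Since `src/edit_synthetic.py` seeds the CUDA generator with `--random_seed` (default 0), the same prompt reproduces the same synthetic image across runs. Below is a small sketch for exploring several seeds; the output folder names are arbitrary.
```
import subprocess

# Generate and edit a few reproducible synthetic variants by varying the seed.
for seed in range(3):
    subprocess.run([
        "python", "src/edit_synthetic.py",
        "--results_folder", f"output/synth_editing_seed{seed}",
        "--prompt_str", "a high resolution painting of a cat in the style of van gogh",
        "--task_name", "cat2dog",
        "--random_seed", str(seed),
    ], check=True)
```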
### **Tips and Debugging**
- **Controlling the Image Structure:**<br>
The `--xa_guidance` flag controls the amount of cross-attention guidance applied when performing the edit. If the edited output does not retain the structure of the input image, increasing the value will typically address the issue. We recommend changing the value in increments of 0.05; see the sweep sketched after this list.
- **Improving Image Quality:**<br>
If the output image quality is low or has artifacts, using more steps for both the inversion and the editing can help.
This is controlled with the `--num_ddim_steps` flag.
- **Reducing the VRAM Requirements:**<br>
You can reduce the VRAM requirements by running at lower precision; set the `--use_float_16` flag.
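
For example, a small helper can sweep `--xa_guidance` in 0.05 increments and keep each edit in its own folder. This is a sketch reusing the cat2dog paths from the commands above.
```
import subprocess

# Sweep the cross-attention guidance strength in 0.05 increments and compare
# how well each output preserves the input structure.
for xa in [0.05, 0.10, 0.15, 0.20, 0.25]:
    subprocess.run([
        "python", "src/edit_real.py",
        "--inversion", "output/test_cat/inversion/cat_1.pt",
        "--prompt", "output/test_cat/prompt/cat_1.txt",
        "--task_name", "cat2dog",
        "--results_folder", f"output/test_cat_xa_{xa:.2f}",
        "--xa_guidance", str(xa),
    ], check=True)
```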
<br>
**Finding Custom Edit Directions**<br>
- We provide some pre-computed directions in the assets [folder](assets/embeddings_sd_1.4).
To generate new edit directions, first create two files, each containing a large number of sentences (~1000), and then run the command shown below.
```
python src/make_edit_direction.py \
--file_source_sentences sentences/apple.txt \
--file_target_sentences sentences/orange.txt \
--output_folder assets/embeddings_sd_1.4
```
- After running the above command, you can set the flag `--task_name apple2orange` for the new edit; a sketch of how the saved embeddings form an edit direction follows.
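
For intuition, the edit direction is the difference between the two saved mean sentence embeddings. The sketch below is an assumption about how `construct_direction` (in the utils) consumes the `.pt` files written by `src/make_edit_direction.py`; it is not the exact implementation.
```
import torch

def construct_direction_sketch(task_name, emb_folder="assets/embeddings_sd_1.4"):
    # "apple2orange" -> source "apple", target "orange"
    src, tgt = task_name.split("2")
    emb_src = torch.load(f"{emb_folder}/{src}.pt")  # mean source embedding, (1, 77, 768)
    emb_tgt = torch.load(f"{emb_folder}/{tgt}.pt")  # mean target embedding, (1, 77, 768)
    # The direction added to the prompt embedding during editing is the difference.
    return emb_tgt - emb_src
```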
## Comparison
Comparisons with different baselines, including SDEdit + word swap, DDIM + word swap, and prompt-to-prompt. Our method successfully applies the edit while preserving the structure of the input image.
<div>
Binary file added assets/embeddings_sd_1.4/cat.pt
Binary file not shown.
Binary file added assets/embeddings_sd_1.4/dog.pt
Binary file not shown.
Binary file added assets/embeddings_sd_1.4/horse.pt
Binary file not shown.
Binary file added assets/embeddings_sd_1.4/zebra.pt
Binary file not shown.
Binary file added assets/grid_cat2dog.jpg
Binary file added assets/grid_dog2cat.jpg
Binary file added assets/grid_horse2zebra.jpg
Binary file added assets/grid_tree2fall.jpg
Binary file added assets/grid_zebra2horse.jpg
Binary file added assets/test_images/cats/cat_1.png
Binary file added assets/test_images/cats/cat_2.png
Binary file added assets/test_images/cats/cat_3.png
Binary file added assets/test_images/cats/cat_4.png
Binary file added assets/test_images/cats/cat_5.png
Binary file added assets/test_images/cats/cat_6.png
Binary file added assets/test_images/cats/cat_7.png
Binary file added assets/test_images/cats/cat_8.png
Binary file added assets/test_images/cats/cat_9.png
Binary file added assets/test_images/dogs/dog_1.png
Binary file added assets/test_images/dogs/dog_2.png
Binary file added assets/test_images/dogs/dog_3.png
Binary file added assets/test_images/dogs/dog_4.png
Binary file added assets/test_images/dogs/dog_5.png
Binary file added assets/test_images/dogs/dog_6.png
Binary file added assets/test_images/dogs/dog_7.png
Binary file added assets/test_images/dogs/dog_8.png
Binary file added assets/test_images/dogs/dog_9.png
65 changes: 65 additions & 0 deletions src/edit_real.py
@@ -0,0 +1,65 @@
import os
import argparse
from glob import glob

import torch

from diffusers import DDIMScheduler
from utils.ddim_inv import DDIMInversion
from utils.edit_directions import construct_direction
from utils.edit_pipeline import EditingPipeline


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--inversion', required=True)
    parser.add_argument('--prompt', type=str, required=True)
    parser.add_argument('--task_name', type=str, default='cat2dog')
    parser.add_argument('--results_folder', type=str, default='output/test_cat')
    parser.add_argument('--num_ddim_steps', type=int, default=50)
    parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
    parser.add_argument('--xa_guidance', default=0.1, type=float)
    parser.add_argument('--negative_guidance_scale', default=5.0, type=float)
    parser.add_argument('--use_float_16', action='store_true')

    args = parser.parse_args()

    os.makedirs(os.path.join(args.results_folder, "edit"), exist_ok=True)
    os.makedirs(os.path.join(args.results_folder, "reconstruction"), exist_ok=True)

    if args.use_float_16:
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    # if the inversion is a folder, the prompt should also be a folder
    assert os.path.isdir(args.inversion) == os.path.isdir(args.prompt), \
        "If the inversion is a folder, the prompt should also be a folder"
    if os.path.isdir(args.inversion):
        l_inv_paths = sorted(glob(os.path.join(args.inversion, "*.pt")))
        l_bnames = [os.path.basename(x) for x in l_inv_paths]
        l_prompt_paths = [os.path.join(args.prompt, x.replace(".pt", ".txt")) for x in l_bnames]
    else:
        l_inv_paths = [args.inversion]
        l_prompt_paths = [args.prompt]

    # make the editing pipeline
    pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    for inv_path, prompt_path in zip(l_inv_paths, l_prompt_paths):
        prompt_str = open(prompt_path).read().strip()
        rec_pil, edit_pil = pipe(prompt_str,
            num_inference_steps=args.num_ddim_steps,
            x_in=torch.load(inv_path).unsqueeze(0),
            edit_dir=construct_direction(args.task_name),
            guidance_amount=args.xa_guidance,
            guidance_scale=args.negative_guidance_scale,
            negative_prompt=prompt_str  # use the unedited prompt for the negative prompt
        )

        # use the per-image path so folder inputs do not overwrite a single output file
        bname = os.path.basename(inv_path).split(".")[0]
        edit_pil[0].save(os.path.join(args.results_folder, f"edit/{bname}.png"))
        rec_pil[0].save(os.path.join(args.results_folder, f"reconstruction/{bname}.png"))
52 changes: 52 additions & 0 deletions src/edit_synthetic.py
@@ -0,0 +1,52 @@
import os
import argparse

import torch

from diffusers import DDIMScheduler
from utils.edit_directions import construct_direction
from utils.edit_pipeline import EditingPipeline


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt_str', type=str, required=True)
    parser.add_argument('--random_seed', type=int, default=0)
    parser.add_argument('--task_name', type=str, default='cat2dog')
    parser.add_argument('--results_folder', type=str, default='output/test_cat')
    parser.add_argument('--num_ddim_steps', type=int, default=50)
    parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
    parser.add_argument('--xa_guidance', default=0.15, type=float)
    parser.add_argument('--negative_guidance_scale', default=5.0, type=float)
    parser.add_argument('--use_float_16', action='store_true')
    args = parser.parse_args()

    os.makedirs(args.results_folder, exist_ok=True)

    if args.use_float_16:
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    # make the input noise map (seeded for reproducibility)
    torch.cuda.manual_seed(args.random_seed)
    x = torch.randn((1, 4, 64, 64), device="cuda")

    # make the editing pipeline
    pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    rec_pil, edit_pil = pipe(args.prompt_str,
        num_inference_steps=args.num_ddim_steps,
        x_in=x,
        edit_dir=construct_direction(args.task_name),
        guidance_amount=args.xa_guidance,
        guidance_scale=args.negative_guidance_scale,
        negative_prompt=""  # use the empty string for the negative prompt
    )

    edit_pil[0].save(os.path.join(args.results_folder, "edit.png"))
    rec_pil[0].save(os.path.join(args.results_folder, "reconstruction.png"))
64 changes: 64 additions & 0 deletions src/inversion.py
@@ -0,0 +1,64 @@
import os
import argparse
from glob import glob

import torch
from PIL import Image

from lavis.models import load_model_and_preprocess

from utils.ddim_inv import DDIMInversion
from utils.scheduler import DDIMInverseScheduler

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_image', type=str, default='assets/test_images/cat_a.png')
    parser.add_argument('--results_folder', type=str, default='output/test_cat')
    parser.add_argument('--num_ddim_steps', type=int, default=50)
    parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
    parser.add_argument('--use_float_16', action='store_true')
    args = parser.parse_args()

    # make the output folders
    os.makedirs(os.path.join(args.results_folder, "inversion"), exist_ok=True)
    os.makedirs(os.path.join(args.results_folder, "prompt"), exist_ok=True)

    if args.use_float_16:
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    # load the BLIP model
    model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device("cuda"))
    # make the DDIM inversion pipeline
    pipe = DDIMInversion.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
    pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

    # if the input is a folder, collect all the images as a list
    if os.path.isdir(args.input_image):
        l_img_paths = sorted(glob(os.path.join(args.input_image, "*.png")))
    else:
        l_img_paths = [args.input_image]

    for img_path in l_img_paths:
        # use the per-image path so folder inputs are each processed and saved
        bname = os.path.basename(img_path).split(".")[0]
        img = Image.open(img_path).resize((512, 512), Image.Resampling.LANCZOS)
        # generate the caption with BLIP
        _image = vis_processors["eval"](img).unsqueeze(0).cuda()
        prompt_str = model_blip.generate({"image": _image})[0]
        x_inv, x_inv_image, x_dec_img = pipe(
            prompt_str,
            guidance_scale=1,
            num_inversion_steps=args.num_ddim_steps,
            img=img,
            torch_dtype=torch_dtype
        )
        # save the inversion
        torch.save(x_inv[0], os.path.join(args.results_folder, f"inversion/{bname}.pt"))
        # save the prompt string
        with open(os.path.join(args.results_folder, f"prompt/{bname}.txt"), "w") as f:
            f.write(prompt_str)
61 changes: 61 additions & 0 deletions src/make_edit_direction.py
@@ -0,0 +1,61 @@
import os
import argparse

import torch

from utils.edit_pipeline import EditingPipeline


## convert sentences to sentence embeddings
def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
    with torch.no_grad():
        l_embeddings = []
        for sent in l_sentences:
            text_inputs = tokenizer(
                sent,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
            l_embeddings.append(prompt_embeds)
    # average over all sentences to get a single (1, 77, 768) embedding
    return torch.cat(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--file_source_sentences', required=True)
    parser.add_argument('--file_target_sentences', required=True)
    parser.add_argument('--output_folder', required=True)
    parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
    args = parser.parse_args()

    # load the model
    pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch.float16).to("cuda")

    # use splitext rather than strip(".txt"), which would also eat trailing 't'/'x' characters
    bname_src = os.path.splitext(os.path.basename(args.file_source_sentences))[0]
    outf_src = os.path.join(args.output_folder, bname_src + ".pt")
    if os.path.exists(outf_src):
        print(f"Skipping source file {outf_src} as it already exists")
    else:
        with open(args.file_source_sentences, "r") as f:
            l_sents = [x.strip() for x in f.readlines()]
        mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda")
        print(mean_emb.shape)
        torch.save(mean_emb, outf_src)

    bname_tgt = os.path.splitext(os.path.basename(args.file_target_sentences))[0]
    outf_tgt = os.path.join(args.output_folder, bname_tgt + ".pt")
    if os.path.exists(outf_tgt):
        print(f"Skipping target file {outf_tgt} as it already exists")
    else:
        with open(args.file_target_sentences, "r") as f:
            l_sents = [x.strip() for x in f.readlines()]
        mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda")
        print(mean_emb.shape)
        torch.save(mean_emb, outf_tgt)
