forked from pix2pixzero/pix2pix-zero
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ddccacd
commit 25b8857
Showing
38 changed files
with
1,348 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import os, pdb | ||
|
||
import argparse | ||
import numpy as np | ||
import torch | ||
import requests | ||
from PIL import Image | ||
|
||
from diffusers import DDIMScheduler | ||
from utils.ddim_inv import DDIMInversion | ||
from utils.edit_directions import construct_direction | ||
from utils.edit_pipeline import EditingPipeline | ||
|
||
|
||
if __name__=="__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--inversion', required=True) | ||
parser.add_argument('--prompt', type=str, required=True) | ||
parser.add_argument('--task_name', type=str, default='cat2dog') | ||
parser.add_argument('--results_folder', type=str, default='output/test_cat') | ||
parser.add_argument('--num_ddim_steps', type=int, default=50) | ||
parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4') | ||
parser.add_argument('--xa_guidance', default=0.1, type=float) | ||
parser.add_argument('--negative_guidance_scale', default=5.0, type=float) | ||
parser.add_argument('--use_float_16', action='store_true') | ||
|
||
args = parser.parse_args() | ||
|
||
os.makedirs(os.path.join(args.results_folder, "edit"), exist_ok=True) | ||
os.makedirs(os.path.join(args.results_folder, "reconstruction"), exist_ok=True) | ||
|
||
if args.use_float_16: | ||
torch_dtype = torch.float16 | ||
else: | ||
torch_dtype = torch.float32 | ||
|
||
# if the inversion is a folder, the prompt should also be a folder | ||
assert (os.path.isdir(args.inversion)==os.path.isdir(args.prompt)), "If the inversion is a folder, the prompt should also be a folder" | ||
if os.path.isdir(args.inversion): | ||
l_inv_paths = sorted(glob(os.path.join(args.inversion, "*.pt"))) | ||
l_bnames = [os.path.basename(x) for x in l_inv_paths] | ||
l_prompt_paths = [os.path.join(args.prompt, x.replace(".pt",".txt")) for x in l_bnames] | ||
else: | ||
l_inv_paths = [args.inversion] | ||
l_prompt_paths = [args.prompt] | ||
|
||
# Make the editing pipeline | ||
pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda") | ||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) | ||
|
||
|
||
for inv_path, prompt_path in zip(l_inv_paths, l_prompt_paths): | ||
prompt_str = open(prompt_path).read().strip() | ||
rec_pil, edit_pil = pipe(prompt_str, | ||
num_inference_steps=args.num_ddim_steps, | ||
x_in=torch.load(inv_path).unsqueeze(0), | ||
edit_dir=construct_direction(args.task_name), | ||
guidance_amount=args.xa_guidance, | ||
guidance_scale=args.negative_guidance_scale, | ||
negative_prompt=prompt_str # use the unedited prompt for the negative prompt | ||
) | ||
|
||
bname = os.path.basename(args.inversion).split(".")[0] | ||
edit_pil[0].save(os.path.join(args.results_folder, f"edit/{bname}.png")) | ||
rec_pil[0].save(os.path.join(args.results_folder, f"reconstruction/{bname}.png")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import os, pdb | ||
|
||
import argparse | ||
import numpy as np | ||
import torch | ||
import requests | ||
from PIL import Image | ||
|
||
from diffusers import DDIMScheduler | ||
from utils.edit_directions import construct_direction | ||
from utils.edit_pipeline import EditingPipeline | ||
|
||
|
||
if __name__=="__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--prompt_str', type=str, required=True) | ||
parser.add_argument('--random_seed', default=0) | ||
parser.add_argument('--task_name', type=str, default='cat2dog') | ||
parser.add_argument('--results_folder', type=str, default='output/test_cat') | ||
parser.add_argument('--num_ddim_steps', type=int, default=50) | ||
parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4') | ||
parser.add_argument('--xa_guidance', default=0.15, type=float) | ||
parser.add_argument('--negative_guidance_scale', default=5.0, type=float) | ||
parser.add_argument('--use_float_16', action='store_true') | ||
args = parser.parse_args() | ||
|
||
os.makedirs(args.results_folder, exist_ok=True) | ||
|
||
if args.use_float_16: | ||
torch_dtype = torch.float16 | ||
else: | ||
torch_dtype = torch.float32 | ||
|
||
# make the input noise map | ||
torch.cuda.manual_seed(args.random_seed) | ||
x = torch.randn((1,4,64,64), device="cuda") | ||
|
||
# Make the editing pipeline | ||
pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda") | ||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) | ||
|
||
rec_pil, edit_pil = pipe(args.prompt_str, | ||
num_inference_steps=args.num_ddim_steps, | ||
x_in=x, | ||
edit_dir=construct_direction(args.task_name), | ||
guidance_amount=args.xa_guidance, | ||
guidance_scale=args.negative_guidance_scale, | ||
negative_prompt="" # use the empty string for the negative prompt | ||
) | ||
|
||
edit_pil[0].save(os.path.join(args.results_folder, f"edit.png")) | ||
rec_pil[0].save(os.path.join(args.results_folder, f"reconstruction.png")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import os, pdb | ||
|
||
import argparse | ||
import numpy as np | ||
import torch | ||
import requests | ||
from PIL import Image | ||
|
||
from lavis.models import load_model_and_preprocess | ||
|
||
from utils.ddim_inv import DDIMInversion | ||
from utils.scheduler import DDIMInverseScheduler | ||
|
||
if __name__=="__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--input_image', type=str, default='assets/test_images/cat_a.png') | ||
parser.add_argument('--results_folder', type=str, default='output/test_cat') | ||
parser.add_argument('--num_ddim_steps', type=int, default=50) | ||
parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4') | ||
parser.add_argument('--use_float_16', action='store_true') | ||
args = parser.parse_args() | ||
|
||
# make the output folders | ||
os.makedirs(os.path.join(args.results_folder, "inversion"), exist_ok=True) | ||
os.makedirs(os.path.join(args.results_folder, "prompt"), exist_ok=True) | ||
|
||
if args.use_float_16: | ||
torch_dtype = torch.float16 | ||
else: | ||
torch_dtype = torch.float32 | ||
|
||
|
||
# load the BLIP model | ||
model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device("cuda")) | ||
# make the DDIM inversion pipeline | ||
pipe = DDIMInversion.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda") | ||
pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config) | ||
|
||
|
||
# if the input is a folder, collect all the images as a list | ||
if os.path.isdir(args.input_image): | ||
l_img_paths = sorted(glob(os.path.join(args.input_image, "*.png"))) | ||
else: | ||
l_img_paths = [args.input_image] | ||
|
||
|
||
for img_path in l_img_paths: | ||
bname = os.path.basename(args.input_image).split(".")[0] | ||
img = Image.open(args.input_image).resize((512,512), Image.Resampling.LANCZOS) | ||
# generate the caption | ||
_image = vis_processors["eval"](img).unsqueeze(0).cuda() | ||
prompt_str = model_blip.generate({"image": _image})[0] | ||
x_inv, x_inv_image, x_dec_img = pipe( | ||
prompt_str, | ||
guidance_scale=1, | ||
num_inversion_steps=args.num_ddim_steps, | ||
img=img, | ||
torch_dtype=torch_dtype | ||
) | ||
# save the inversion | ||
torch.save(x_inv[0], os.path.join(args.results_folder, f"inversion/{bname}.pt")) | ||
# save the prompt string | ||
with open(os.path.join(args.results_folder, f"prompt/{bname}.txt"), "w") as f: | ||
f.write(prompt_str) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import os, pdb | ||
|
||
import argparse | ||
import numpy as np | ||
import torch | ||
import requests | ||
from PIL import Image | ||
|
||
from diffusers import DDIMScheduler | ||
from utils.edit_pipeline import EditingPipeline | ||
|
||
|
||
## convert sentences to sentence embeddings | ||
def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"): | ||
with torch.no_grad(): | ||
l_embeddings = [] | ||
for sent in l_sentences: | ||
text_inputs = tokenizer( | ||
sent, | ||
padding="max_length", | ||
max_length=tokenizer.model_max_length, | ||
truncation=True, | ||
return_tensors="pt", | ||
) | ||
text_input_ids = text_inputs.input_ids | ||
prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0] | ||
l_embeddings.append(prompt_embeds) | ||
return torch.concatenate(l_embeddings, dim=0).mean(dim=0).unsqueeze(0) | ||
|
||
|
||
if __name__=="__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--file_source_sentences', required=True) | ||
parser.add_argument('--file_target_sentences', required=True) | ||
parser.add_argument('--output_folder', required=True) | ||
parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4') | ||
args = parser.parse_args() | ||
|
||
# load the model | ||
pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch.float16).to("cuda") | ||
bname_src = os.path.basename(args.file_source_sentences).strip(".txt") | ||
outf_src = os.path.join(args.output_folder, bname_src+".pt") | ||
if os.path.exists(outf_src): | ||
print(f"Skipping source file {outf_src} as it already exists") | ||
else: | ||
with open(args.file_source_sentences, "r") as f: | ||
l_sents = [x.strip() for x in f.readlines()] | ||
mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda") | ||
print(mean_emb.shape) | ||
torch.save(mean_emb, outf_src) | ||
|
||
bname_tgt = os.path.basename(args.file_target_sentences).strip(".txt") | ||
outf_tgt = os.path.join(args.output_folder, bname_tgt+".pt") | ||
if os.path.exists(outf_tgt): | ||
print(f"Skipping target file {outf_tgt} as it already exists") | ||
else: | ||
with open(args.file_target_sentences, "r") as f: | ||
l_sents = [x.strip() for x in f.readlines()] | ||
mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda") | ||
print(mean_emb.shape) | ||
torch.save(mean_emb, outf_tgt) |
Oops, something went wrong.