Commit

update app
franciszzj committed Dec 16, 2024
1 parent 4680a01 commit 0e32e3f
Showing 2 changed files with 155 additions and 106 deletions.
249 changes: 143 additions & 106 deletions app.py
@@ -6,90 +6,111 @@
from leffa.inference import LeffaInference
from utils.garment_agnostic_mask_predictor import AutoMasker
from utils.densepose_predictor import DensePosePredictor
from utils.utils import resize_and_center
from utils.utils import resize_and_center, list_dir

import gradio as gr

# Download checkpoints
snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")

mask_predictor = AutoMasker(
densepose_path="./ckpts/densepose",
schp_path="./ckpts/schp",
)

densepose_predictor = DensePosePredictor(
config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
weights_path="./ckpts/densepose/model_final_162be9.pkl",
)

vt_model = LeffaModel(
pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
pretrained_model="./ckpts/virtual_tryon.pth",
)
vt_inference = LeffaInference(model=vt_model)

pt_model = LeffaModel(
pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
pretrained_model="./ckpts/pose_transfer.pth",
)
pt_inference = LeffaInference(model=pt_model)


def leffa_predict(src_image_path, ref_image_path, control_type):
assert control_type in [
"virtual_tryon", "pose_transfer"], "Invalid control type: {}".format(control_type)
src_image = Image.open(src_image_path)
ref_image = Image.open(ref_image_path)
src_image = resize_and_center(src_image, 768, 1024)
ref_image = resize_and_center(ref_image, 768, 1024)

src_image_array = np.array(src_image)
ref_image_array = np.array(ref_image)

# Mask
if control_type == "virtual_tryon":
src_image = src_image.convert("RGB")
mask = mask_predictor(src_image, "upper")["mask"]
elif control_type == "pose_transfer":
mask = Image.fromarray(np.ones_like(src_image_array) * 255)

# DensePose
src_image_iuv_array = densepose_predictor.predict_iuv(src_image_array)
src_image_seg_array = densepose_predictor.predict_seg(src_image_array)
src_image_iuv = Image.fromarray(src_image_iuv_array)
src_image_seg = Image.fromarray(src_image_seg_array)
if control_type == "virtual_tryon":
densepose = src_image_seg
elif control_type == "pose_transfer":
densepose = src_image_iuv

# Leffa
transform = LeffaTransform()

data = {
"src_image": [src_image],
"ref_image": [ref_image],
"mask": [mask],
"densepose": [densepose],
}
data = transform(data)
if control_type == "virtual_tryon":
inference = vt_inference
elif control_type == "pose_transfer":
inference = pt_inference
output = inference(data)
gen_image = output["generated_image"][0]
# gen_image.save("gen_image.png")
return np.array(gen_image)


def leffa_predict_vt(src_image_path, ref_image_path):
return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")


def leffa_predict_pt(src_image_path, ref_image_path):
return leffa_predict(src_image_path, ref_image_path, "pose_transfer")

class LeffaPredictor(object):
def __init__(self):
self.mask_predictor = AutoMasker(
densepose_path="./ckpts/densepose",
schp_path="./ckpts/schp",
)

self.densepose_predictor = DensePosePredictor(
config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
weights_path="./ckpts/densepose/model_final_162be9.pkl",
)

vt_model = LeffaModel(
pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
pretrained_model="./ckpts/virtual_tryon.pth",
)
self.vt_inference = LeffaInference(model=vt_model)
self.vt_model_type = "viton_hd"

pt_model = LeffaModel(
pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
pretrained_model="./ckpts/pose_transfer.pth",
)
self.pt_inference = LeffaInference(model=pt_model)

def change_vt_model(self, vt_model_type):
if vt_model_type == self.vt_model_type:
return
if vt_model_type == "viton_hd":
pretrained_model = "./ckpts/virtual_tryon.pth"
elif vt_model_type == "dress_code":
pretrained_model = "./ckpts/virtual_tryon_dc.pth"
vt_model = LeffaModel(
pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
pretrained_model=pretrained_model,
)
self.vt_inference = LeffaInference(model=vt_model)
self.vt_model_type = vt_model_type

def leffa_predict(self, src_image_path, ref_image_path, control_type, step=50, scale=2.5, seed=42):
assert control_type in [
"virtual_tryon", "pose_transfer"], "Invalid control type: {}".format(control_type)
src_image = Image.open(src_image_path)
ref_image = Image.open(ref_image_path)
src_image = resize_and_center(src_image, 768, 1024)
ref_image = resize_and_center(ref_image, 768, 1024)

src_image_array = np.array(src_image)

# Mask
if control_type == "virtual_tryon":
src_image = src_image.convert("RGB")
mask = self.mask_predictor(src_image, "upper")["mask"]
elif control_type == "pose_transfer":
mask = Image.fromarray(np.ones_like(src_image_array) * 255)

# DensePose
if control_type == "virtual_tryon":
src_image_seg_array = self.densepose_predictor.predict_seg(
src_image_array)
src_image_seg = Image.fromarray(src_image_seg_array)
densepose = src_image_seg
elif control_type == "pose_transfer":
src_image_iuv_array = self.densepose_predictor.predict_iuv(
src_image_array)
src_image_iuv = Image.fromarray(src_image_iuv_array)
densepose = src_image_iuv

# Leffa
transform = LeffaTransform()

data = {
"src_image": [src_image],
"ref_image": [ref_image],
"mask": [mask],
"densepose": [densepose],
}
data = transform(data)
if control_type == "virtual_tryon":
inference = self.vt_inference
elif control_type == "pose_transfer":
inference = self.pt_inference
        output = inference(
            data,
            num_inference_steps=step,
            guidance_scale=scale,
            seed=seed,
        )
gen_image = output["generated_image"][0]
# gen_image.save("gen_image.png")
return np.array(gen_image)

def leffa_predict_vt(self, src_image_path, ref_image_path, step, scale, seed, vt_model_type="viton_hd"):
self.change_vt_model(vt_model_type)
        return self.leffa_predict(src_image_path, ref_image_path, "virtual_tryon", step, scale, seed)

def leffa_predict_pt(self, src_image_path, ref_image_path, step, scale, seed):
        return self.leffa_predict(src_image_path, ref_image_path, "pose_transfer", step, scale, seed)


if __name__ == "__main__":
@@ -100,6 +121,12 @@ def leffa_predict_pt(src_image_path, ref_image_path):
# control_type = sys.argv[3]
# leffa_predict(src_image_path, ref_image_path, control_type)

leffa_predictor = LeffaPredictor()
example_dir = "./ckpts/examples"
person1_images = list_dir(f"{example_dir}/person1")
person2_images = list_dir(f"{example_dir}/person2")
garment_images = list_dir(f"{example_dir}/garment")

title = "## Leffa: Learning Flow Fields in Attention for Controllable Person Image Generation"
link = "[📚 Paper](https://arxiv.org/abs/2412.08486) - [🤖 Code](https://github.com/franciszzj/Leffa) - [🔥 Demo](https://huggingface.co/spaces/franciszzj/Leffa) - [🤗 Model](https://huggingface.co/franciszzj/Leffa)"
news = """## News
@@ -130,12 +157,8 @@ def leffa_predict_pt(src_image_path, ref_image_path):

gr.Examples(
inputs=vt_src_image,
examples_per_page=5,
examples=["./ckpts/examples/person1/01350_00.jpg",
"./ckpts/examples/person1/01376_00.jpg",
"./ckpts/examples/person1/01416_00.jpg",
"./ckpts/examples/person1/05976_00.jpg",
"./ckpts/examples/person1/06094_00.jpg",],
examples_per_page=10,
examples=person1_images,
)

with gr.Column():
@@ -150,12 +173,8 @@ def leffa_predict_pt(src_image_path, ref_image_path):

gr.Examples(
inputs=vt_ref_image,
examples_per_page=5,
examples=["./ckpts/examples/garment/01449_00.jpg",
"./ckpts/examples/garment/01486_00.jpg",
"./ckpts/examples/garment/01853_00.jpg",
"./ckpts/examples/garment/02070_00.jpg",
"./ckpts/examples/garment/03553_00.jpg",],
examples_per_page=10,
examples=garment_images,
)

with gr.Column():
@@ -169,8 +188,24 @@ def leffa_predict_pt(src_image_path, ref_image_path):
with gr.Row():
vt_gen_button = gr.Button("Generate")

vt_gen_button.click(fn=leffa_predict_vt, inputs=[
vt_src_image, vt_ref_image], outputs=[vt_gen_image])
with gr.Accordion("Advanced Options", open=False):
vt_step = gr.Number(
label="Inference Steps", minimum=30, maximum=100, step=1, value=50)

vt_scale = gr.Number(
label="Guidance Scale", minimum=0.1, maximum=5.0, step=0.1, value=2.5)

vt_seed = gr.Number(
label="Random Seed", minimum=-1, maximum=2147483647, step=1, value=42)

vt_model_type = gr.Radio(
choices=["viton_hd", "dress_code"],
value="viton_hd",
label="Model Type",
)

vt_gen_button.click(fn=leffa_predictor.leffa_predict_vt, inputs=[
vt_src_image, vt_ref_image, vt_step, vt_scale, vt_seed, vt_model_type], outputs=[vt_gen_image])

with gr.Tab("Control Pose (Pose Transfer)"):
with gr.Row():
@@ -186,12 +221,8 @@ def leffa_predict_pt(src_image_path, ref_image_path):

gr.Examples(
inputs=pt_ref_image,
examples_per_page=5,
examples=["./ckpts/examples/person1/01350_00.jpg",
"./ckpts/examples/person1/01376_00.jpg",
"./ckpts/examples/person1/01416_00.jpg",
"./ckpts/examples/person1/05976_00.jpg",
"./ckpts/examples/person1/06094_00.jpg",],
examples_per_page=10,
examples=person1_images,
)

with gr.Column():
@@ -206,12 +237,8 @@ def leffa_predict_pt(src_image_path, ref_image_path):

gr.Examples(
inputs=pt_src_image,
examples_per_page=5,
examples=["./ckpts/examples/person2/01850_00.jpg",
"./ckpts/examples/person2/01875_00.jpg",
"./ckpts/examples/person2/02532_00.jpg",
"./ckpts/examples/person2/02902_00.jpg",
"./ckpts/examples/person2/05346_00.jpg",],
examples_per_page=10,
examples=person2_images,
)

with gr.Column():
Expand All @@ -225,8 +252,18 @@ def leffa_predict_pt(src_image_path, ref_image_path):
with gr.Row():
pose_transfer_gen_button = gr.Button("Generate")

pose_transfer_gen_button.click(fn=leffa_predict_pt, inputs=[
pt_src_image, pt_ref_image], outputs=[pt_gen_image])
with gr.Accordion("Advanced Options", open=False):
pt_step = gr.Number(
label="Inference Steps", minimum=30, maximum=100, step=1, value=50)

pt_scale = gr.Number(
label="Guidance Scale", minimum=0.1, maximum=5.0, step=0.1, value=2.5)

pt_seed = gr.Number(
label="Random Seed", minimum=-1, maximum=2147483647, step=1, value=42)

pose_transfer_gen_button.click(fn=leffa_predictor.leffa_predict_pt, inputs=[
pt_src_image, pt_ref_image, pt_step, pt_scale, pt_seed], outputs=[pt_gen_image])

gr.Markdown(note)

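To summarize the app.py refactor above: the module-level predictors and free functions are replaced by a LeffaPredictor class whose try-on and pose-transfer entry points now take inference steps, guidance scale, seed, and (for try-on) a model type. Below is a minimal sketch of driving that interface outside the Gradio UI; it assumes the Leffa checkpoints are already available under ./ckpts and that app.py can be imported, and the image paths are placeholders rather than files from the commit.

from PIL import Image
from app import LeffaPredictor  # assumes app.py is importable in this environment

predictor = LeffaPredictor()

# Virtual try-on, using the same options the new "Advanced Options" panel exposes.
vt_array = predictor.leffa_predict_vt(
    "person.jpg",              # placeholder path to a person image
    "garment.jpg",             # placeholder path to a garment image
    step=50,                   # inference steps
    scale=2.5,                 # guidance scale
    seed=42,                   # random seed
    vt_model_type="viton_hd",  # or "dress_code" to switch try-on checkpoints
)
Image.fromarray(vt_array).save("vt_result.png")

# Pose transfer takes the same positional step/scale/seed arguments.
pt_array = predictor.leffa_predict_pt("src_person.jpg", "ref_person.jpg", 50, 2.5, 42)
Image.fromarray(pt_array).save("pt_result.png")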
12 changes: 12 additions & 0 deletions utils/utils.py
@@ -1,3 +1,4 @@
import os
import cv2
import numpy as np
from PIL import Image
@@ -29,3 +30,14 @@ def resize_and_center(image, target_width, target_height):
padded_img[top:top + new_height, left:left + new_width] = resized_img

return Image.fromarray(padded_img)


def list_dir(folder_path):
# Collect all file paths within the directory
file_paths = []
for root, _, files in os.walk(folder_path):
for file in files:
file_paths.append(os.path.join(root, file))

file_paths = sorted(file_paths)
return file_paths
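
The new list_dir helper is what lets app.py replace its hard-coded example lists with whole directories (person1, person2, garment). A small sketch of that pattern, assuming the same example folders exist locally:

import gradio as gr
from utils.utils import list_dir

# Walk the folder recursively and get back a sorted list of file paths.
person1_images = list_dir("./ckpts/examples/person1")

with gr.Blocks() as demo:
    src_image = gr.Image(label="Person Image", type="filepath")
    gr.Examples(
        inputs=src_image,
        examples_per_page=10,
        examples=person1_images,  # one example row per discovered file
    )

demo.launch()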
