Added link to arxiv.org paper #19

Open · wants to merge 24 commits into base: main
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Johanna Karras

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
59 changes: 58 additions & 1 deletion README.md
@@ -1,2 +1,59 @@
# DreamPose
Official implementation of "DreamPose: Fashion Image-to-Video Synthesis via Stable Diffusion" by Johanna Karras, Aleksander Holynski, Ting-Chun Wang, and Ira Kemelmacher-Shlizerman.

* [Project Page](https://grail.cs.washington.edu/projects/dreampose)
* [Paper](https://arxiv.org/abs/2304.06025)

![Teaser Image](media/Teaser.png "Teaser")

## Demo

You can generate a video with DreamPose using our pretrained models.

1. [Download](https://drive.google.com/drive/folders/15SaT3kZFRIjxuHT6UrGr6j0183clTK_D?usp=share_link) the pretrained models (custom-chkpts.zip) and unzip them into demo/custom-chkpts
2. [Download](https://drive.google.com/drive/folders/1CjzcOp_ZUt-dyrzNAFE0T8bS3cbKTsVG?usp=share_link) the input poses (poses.zip) and unzip them into demo/sample/poses
3. Run test.py using the command below:
```
python test.py --epoch 499 --folder demo/custom-chkpts --pose_folder demo/sample/poses --key_frame_path demo/sample/key_frame.png --s1 8 --s2 3 --n_steps 100 --output_dir demo/sample/results --custom_vae demo/custom-chkpts/vae_1499.pth
```
## Data Preparation

To prepare a sample for finetuning, create a directory with train and test subdirectories: train holds frames of the desired subject, and test holds frames of the desired pose sequence. Note that the test frames need not show the same subject. See demo/sample for an example; a sketch of the expected layout is shown below.
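(The frame filenames below are illustrative; the *_densepose.npy files are produced by the DensePose step described next.)

```
demo/sample/
    train/                      # frames of the desired subject
        frame_0.png
        frame_0_densepose.npy
        ...
    test/                       # frames of the desired pose sequence
        frame_0.png
        frame_0_densepose.npy
        ...
```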

Then, run [DensePose](https://github.com/facebookresearch/detectron2/tree/main/projects/DensePose) with the "densepose_rcnn_R_50_FPN_s1x" checkpoint on all images in the sample directory. Finally, reformat the pickled DensePose output with utils/densepose.py, after updating its "outpath" filepath to point to the pickled DensePose output. A rough sketch of this conversion step is shown below.
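As a hedged illustration only (utils/densepose.py is the authoritative version), the conversion might look roughly like this. The record field names follow the detectron2 DensePose dump format and are assumptions, as is keeping only the top detection per frame:

```
# Hedged sketch: reformat pickled DensePose output into the per-frame
# *_densepose.npy files that datasets/dreampose_dataset.py expects.
# The record fields ('file_name', 'pred_boxes_XYXY', 'pred_densepose') follow
# the detectron2 DensePose dump format and are assumptions; check utils/densepose.py.
import pickle

import numpy as np
from PIL import Image

outpath = "results.pkl"  # path to the pickled DensePose output (assumption)

with open(outpath, "rb") as f:
    records = pickle.load(f)  # some DensePose versions may need torch.load instead

for rec in records:
    frame_path = rec["file_name"]
    w, h = Image.open(frame_path).size

    # keep only the top-scoring detection (assumption) and its per-box UV map
    x0, y0, _, _ = map(int, rec["pred_boxes_XYXY"][0])
    uv = rec["pred_densepose"][0].uv.cpu().numpy()  # shape (2, box_h, box_w)

    # paste the box-sized UV map into a full-frame, 2-channel array
    dense_uv = np.zeros((2, h, w), dtype=np.float32)
    bh = min(uv.shape[1], h - y0)
    bw = min(uv.shape[2], w - x0)
    dense_uv[:, y0:y0 + bh, x0:x0 + bw] = uv[:, :bh, :bw]

    np.save(frame_path.replace(".png", "_densepose.npy"), dense_uv)
```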

## Download or Finetune Base Model

DreamPose is finetuned on the UBC Fashion Dataset from a pretrained Stable Diffusion checkpoint. You can download our pretrained base model from [Google Drive](https://drive.google.com/file/d/10JjayW2mMqGxhUyM9ds_GHEvuqCTDaH3/view?usp=share_link), or finetune pretrained Stable Diffusion on your own image dataset. We train on 2 NVIDIA A100 GPUs.

```
accelerate launch --num_processes=4 train.py --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" --instance_data_dir=../path/to/dataset --output_dir=checkpoints --resolution=512 --train_batch_size=2 --gradient_accumulation_steps=4 --learning_rate=5e-6 --lr_scheduler="constant" --lr_warmup_steps=0 --num_train_epochs=300 --run_name dreampose --dropout_rate=0.15 --revision "ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c"
```

## Finetune on Sample

In this next step, we finetune DreamPose on one or more input frames to create a subject-specific model.

1. Finetune the UNet

```
accelerate launch finetune-unet.py --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" --instance_data_dir=demo/sample/train --output_dir=demo/custom-chkpts --resolution=512 --train_batch_size=1 --gradient_accumulation_steps=1 --learning_rate=1e-5 --num_train_epochs=500 --dropout_rate=0.0 --custom_chkpt=checkpoints/unet_epoch_20.pth --revision "ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c"
```

2. Finetune the VAE decoder

```
accelerate launch --num_processes=1 finetune-vae.py --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" --instance_data_dir=demo/sample/train --output_dir=demo/custom-chkpts --instance_prompt="" --resolution=512 --train_batch_size=4 --gradient_accumulation_steps=4 --learning_rate=5e-5 --num_train_epochs=1500 --run_name finetuning/ubc-vae --revision "ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c"
```

## Testing

Once you have finetuned your custom, subject-specific DreamPose model, you can generate frames using the following command:

```
python test.py --epoch 499 --folder demo/custom-chkpts --pose_folder demo/sample/poses --key_frame_path demo/sample/key_frame.png --s1 8 --s2 3 --n_steps 100 --output_dir results --custom_vae demo/custom-chkpts/vae_1499.pth
```

### Acknowledgment

This code is largely adapted from the [HuggingFace diffusers repo](https://github.com/huggingface/diffusers).
188 changes: 188 additions & 0 deletions datasets/dreampose_dataset.py
@@ -0,0 +1,188 @@
from torch.utils.data import Dataset
from pathlib import Path
from torchvision import transforms
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import os, cv2, glob

class DreamPoseDataset(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
It pre-processes the images and the tokenizes prompts.
"""

def __init__(
self,
instance_data_root,
class_data_root=None,
class_prompt=None,
size=512,
center_crop=False,
train=True,
p_jitter=0.9
):
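# note: the working resolution is fixed to (H, W) = (640, 512); the `size` argument above is unused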
self.size = (640, 512)
self.center_crop = center_crop
self.train = train

self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exist.")

# Load UBC Fashion Dataset
self.instance_images_path = glob.glob(instance_data_root+'/*png')

self.num_instance_images = len(self.instance_images_path)
self._length = self.num_instance_images

if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = list(self.class_data_root.iterdir())
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images)
self.class_prompt = class_prompt
else:
self.class_data_root = None

self.image_transforms = transforms.Compose(
[
transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
#transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.3, hue=0.3),
#transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
#transforms.Normalize([0.5], [0.5]),
]
)

self.tensor_transforms = transforms.Compose(
[
transforms.Normalize([0.5], [0.5]),
]
)

def __len__(self):
return self._length

# resize sparse uv flow to size
def resize_pose(self, pose):
h1, w1 = pose.shape
h2, w2 = self.size  # self.size is an (H, W) tuple
resized_pose = np.zeros((h2, w2))
x_vals = np.where(pose != 0)[0]
y_vals = np.where(pose != 0)[1]
for (x, y) in list(zip(x_vals, y_vals)):
# find new coordinates
x2, y2 = int(x * h2 / h1), int(y * w2 / w1)
resized_pose[x2, y2] = pose[x, y]
return resized_pose

def __getitem__(self, index):
example = {}

frame_path = self.instance_images_path[index % self.num_instance_images]
frame_folder = frame_path.replace(os.path.basename(frame_path), '')
#frame_number = int(os.path.basename(frame_path).split('frame_')[-1].replace('.png', ''))

# load frame i
instance_image = Image.open(frame_path)
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
example["frame_i"] = self.image_transforms(instance_image)
example["frame_prev"] = self.image_transforms(instance_image)

assert example["frame_i"].shape == (3, 640, 512)

# Select other frame in this folder
frame_paths = glob.glob(frame_folder+'/*png')
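# keep only frames that have a matching DensePose map on disk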
frame_paths = [p for p in frame_paths if os.path.exists(p.replace('.png', '_densepose.npy'))]
frame_j_path = np.random.choice(frame_paths)

# load frame j
instance_image = Image.open(frame_j_path)
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
example["frame_j"] = self.image_transforms(instance_image)


# construct 5 input poses
poses = []
h, w = 640, 512
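# the DensePose map of frame j is repeated five times, giving a 10-channel pose tensor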
for pose_number in range(5):
dp_path = frame_j_path.replace('.png', '_densepose.npy')
dp_i = F.interpolate(torch.from_numpy(np.load(dp_path).astype('float32')).unsqueeze(0), (h, w), mode='bilinear').squeeze(0)
poses.append(self.tensor_transforms(dp_i))
input_pose = torch.cat(poses, 0)
example["pose_j"] = input_pose

''' Data Augmentation '''
key_frame = example["frame_i"]
frame = example["frame_j"]
prev_frame = example["frame_prev"]

#dp = transforms.ToPILImage()(dp)

# Get random transforms to target 70% of the time
p = np.random.randint(0, 100)
if p < 70:
ang = np.random.randint(-15, 15) # rotation angle
distort = np.random.rand()  # random distortion factor (currently unused)
top, left = np.random.randint(0, 25), np.random.randint(0, 25)
h_ = np.random.randint(self.size[0]-25, self.size[0]-top)
w_ = int(h_ / h * w)

t = transforms.Compose([transforms.ToPILImage(),\
transforms.Resize((h,w), interpolation=transforms.InterpolationMode.BILINEAR), \
transforms.ToTensor(),\
transforms.Normalize([0.5], [0.5])])

# Apply transforms
frame = transforms.functional.crop(frame, top, left, h_, w_) # random crop

example["frame_i"] = transforms.Normalize([0.5], [0.5])(key_frame) #t(key_frame)
example["frame_j"] = t(frame)

for pose_id in range(5):
start, end = 2*pose_id, 2*pose_id+2
# convert dense pose to PIL image
dp = example['pose_j'][start:end]
c, h, w = dp.shape
dp = torch.cat((dp, torch.zeros(1, h, w)), 0)
dp = transforms.functional.crop(dp, top, left, h_, w_) # random crop
dp = t(dp)[0:2] # Remove extra channel from input pose
example["pose_j"][start:end] = dp.clone()

# slightly perturb transforms to previous frame, to prevent copy/paste
top += np.random.randint(0, 5)
left += np.random.randint(0, 5)
h_ += np.random.randint(0, 5)
w_ += np.random.randint(0, 5)
prev_frame = transforms.functional.crop(prev_frame, top, left, h_, w_) # random crop
example["frame_prev"] = t(prev_frame)
else:
# slightly perturb transforms to previous frame, to prevent copy/paste
top, left = np.random.randint(0, 5), np.random.randint(0, 5)
h_ = np.random.randint(self.size[0]-5, self.size[0]-top)
w_ = int(h_ / h * w)

t = transforms.Compose([transforms.ToPILImage(),\
transforms.Resize((h,w), interpolation=transforms.InterpolationMode.BILINEAR), \
transforms.ToTensor(),\
transforms.Normalize([0.5], [0.5])])

prev_frame = transforms.functional.crop(prev_frame, top, left, h_, w_) # random crop
example["frame_prev"] = t(prev_frame)

example["frame_i"] = transforms.Normalize([0.5], [0.5])(key_frame) #t(key_frame)
example["frame_j"] = transforms.Normalize([0.5], [0.5])(frame)

for pose_id in range(5):
start, end = 2*pose_id, 2*pose_id+2
dp = example['pose_j'][start:end]
dp = transforms.Normalize([0.5], [0.5])(dp)[0:2]
example["pose_j"][start:end] = dp.clone()

return example
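For orientation, a minimal usage sketch of this dataset (the path is illustrative and assumes the frames have matching *_densepose.npy maps from the Data Preparation step):

```
# Hedged usage sketch; not part of the repository.
from torch.utils.data import DataLoader

from datasets.dreampose_dataset import DreamPoseDataset

dataset = DreamPoseDataset(instance_data_root="demo/sample/train")
loader = DataLoader(dataset, batch_size=1, shuffle=True)

batch = next(iter(loader))
print(batch["frame_i"].shape)  # (1, 3, 640, 512)  key frame
print(batch["frame_j"].shape)  # (1, 3, 640, 512)  target frame
print(batch["pose_j"].shape)   # (1, 10, 640, 512) five stacked 2-channel DensePose maps
```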