Commit b9bb3f4 (1 parent: 7bdab77)
Johanna Karras authored and committed on Mar 21, 2023
Showing 6 changed files with 1,330 additions and 0 deletions.
@@ -0,0 +1,179 @@
from torch.utils.data import Dataset
from pathlib import Path
from torchvision import transforms
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import os, cv2, glob

'''
- Passes 5 consecutive input poses per sample
- Ensures at least one pair of consecutive frames per batch
'''
class DreamPoseDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_root,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
        train=True,
        p_jitter=0.9,
        n_poses=5
    ):
        self.size = (640, 512)  # (h, w); note: this overrides the `size` argument
        self.center_crop = center_crop
        self.train = train
        self.n_poses = n_poses

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        # Load UBC Fashion Dataset
        self.instance_images_path = glob.glob('../UBC_Fashion_Dataset/train-frames/*/*png')
        self.instance_images_path = [p for p in self.instance_images_path if os.path.exists(p.replace('.png', '_densepose.npy'))]
        len1 = len(self.instance_images_path)

        # Load Deep Fashion Dataset
        self.instance_images_path.extend([path for path in glob.glob('../Deep_Fashion_Dataset/img_highres/*/*/*/*.jpg') \
                                          if os.path.exists(path.replace('.jpg', '_densepose.npy'))])

        len2 = len(self.instance_images_path)
        print(f"Train Dataset: {len1} UBC Fashion images, {len2 - len1} Deep Fashion images.")

        self.num_instance_images = len(self.instance_images_path)
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
                #transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.3, hue=0.3),
                #transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        self.tensor_transforms = transforms.Compose(
            [
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    # resize sparse uv flow to size
    def resize_pose(self, pose):
        h1, w1 = pose.shape
        h2, w2 = self.size  # self.size is an (h, w) tuple
        resized_pose = np.zeros((h2, w2))
        x_vals = np.where(pose != 0)[0]
        y_vals = np.where(pose != 0)[1]
        for (x, y) in list(zip(x_vals, y_vals)):
            # find new coordinates
            x2, y2 = int(x * h2 / h1), int(y * w2 / w1)
            resized_pose[x2, y2] = pose[x, y]
        return resized_pose

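    # A toy example of the scatter behavior above (hypothetical values, not part
    # of this commit): with self.size == (640, 512),
    #   pose = np.zeros((4, 4)); pose[1, 1] = 7.0
    #   resize_pose(pose)[160, 128] == 7.0   # since (1 * 640 / 4, 1 * 512 / 4) = (160, 128)
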
    # return two frame pairs per call
    def __getitem__(self, index):
        example = {}

        '''
        Prepare frame pair #1
        '''
        # load frame i
        frame_path = self.instance_images_path[index % self.num_instance_images]
        instance_image = Image.open(frame_path)
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["frame_i"] = self.image_transforms(instance_image)

        # Get additional frames in this folder
        sample_folder = frame_path.replace(os.path.basename(frame_path), '')
        samples = [path for path in glob.glob(sample_folder + '/*') if 'npy' not in path]
        samples = [path for path in samples if os.path.exists(path.replace('.jpg', '_densepose.npy').replace('.png', '_densepose.npy'))]

        if 'Deep_Fashion' in frame_path:
            idx = os.path.basename(frame_path).split('_')[0]
            samples = [s for s in samples if os.path.basename(s).split('_')[0] == idx]
            #print("Frame Path = ", frame_path)
            #print("Samples = ", samples)

        frame_j_path = samples[np.random.choice(range(len(samples)))]
        pose_j_path = frame_j_path.replace('.jpg', '_densepose.npy')

        # load frame j
        instance_image = Image.open(frame_j_path)
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["frame_j"] = self.image_transforms(instance_image)

        # Load n_poses pose maps for frame j (the same DensePose map, repeated)
        _, h, w = example["frame_i"].shape
        poses = []
        idx1 = int(self.n_poses // 2)  # currently unused
        idx2 = self.n_poses - idx1     # currently unused
        for pose_number in range(self.n_poses):
            dp_path = frame_j_path.replace('.jpg', '_densepose.npy').replace('.png', '_densepose.npy')
            dp_i = F.interpolate(torch.from_numpy(np.load(dp_path, allow_pickle=True).astype('float32')).unsqueeze(0), (h, w), mode='bilinear').squeeze(0)
            poses.append(self.tensor_transforms(dp_i))

        example["pose_j"] = torch.cat(poses, 0)

        '''
        Prepare frame pair #2
        '''
        frame_path = samples[np.random.choice(range(len(samples)))]

        # load frame i
        instance_image = Image.open(frame_path)
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["frame_i"] = torch.stack((example["frame_i"], self.image_transforms(instance_image)), 0)

        assert example["frame_i"].shape == (2, 3, 640, 512)

        # Load frame j
        frame_j_path = samples[np.random.choice(range(len(samples)))]
        instance_image = Image.open(frame_j_path)
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["frame_j"] = torch.stack((example['frame_j'], self.image_transforms(instance_image)), 0)

        # Load n_poses pose maps for the second frame j
        poses = []
        for pose_number in range(self.n_poses):
            dp_path = frame_j_path.replace('.jpg', '_densepose.npy').replace('.png', '_densepose.npy')
            dp_i = F.interpolate(torch.from_numpy(np.load(dp_path, allow_pickle=True).astype('float32')).unsqueeze(0), (h, w), mode='bilinear').squeeze(0)
            poses.append(self.tensor_transforms(dp_i))

        poses = torch.cat(poses, 0)
        example["pose_j"] = torch.stack((example["pose_j"], poses), 0)

        #print(example["frame_i"].shape, example["frame_j"].shape, example["pose_j"].shape)
        return example
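
For reference, a minimal usage sketch of the dataset above, assuming the UBC Fashion data is on disk at the hard-coded paths; the batch size is an arbitrary choice for illustration, not part of this commit:

from torch.utils.data import DataLoader

# Hypothetical usage: each example packs a pair of frames, so batching adds a pair dimension.
dataset = DreamPoseDataset(instance_data_root='../UBC_Fashion_Dataset/train-frames')
loader = DataLoader(dataset, batch_size=2, shuffle=True)
batch = next(iter(loader))
# batch["frame_i"]: (B, 2, 3, 640, 512)
# batch["frame_j"]: (B, 2, 3, 640, 512)
# batch["pose_j"]:  (B, 2, 2 * n_poses, 640, 512)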
Binary file not shown.
@@ -0,0 +1,73 @@
# Load pretrained 2D UNet and modify with temporal attention

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import einsum
import torch.utils.checkpoint
from einops import rearrange

import math

from diffusers import AutoencoderKL
from diffusers.models import UNet2DConditionModel
from diffusers.models.attention import BasicTransformerBlock

def get_unet(pretrained_model_name_or_path, revision, resolution=256, n_poses=5):
    # Load pretrained UNet layers
    unet = UNet2DConditionModel.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        subfolder="unet",
        revision="ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c"
    )

    # Widen the input layer to accept 2*n_poses additional input channels (pose)
    weights = unet.conv_in.weight.clone()
    unet.conv_in = nn.Conv2d(4 + 2 * n_poses, weights.shape[0], kernel_size=3, padding=(1, 1))  # input noise + n poses
    with torch.no_grad():
        unet.conv_in.weight[:, :4] = weights  # original weights
        unet.conv_in.weight[:, 4:] = torch.zeros(unet.conv_in.weight[:, 4:].shape)  # new pose channels initialized to zero

    return unet

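# A minimal sanity check for the widened input layer (a sketch, assuming 64x64
# latents and the default n_poses=5; only conv_in is exercised, since a full
# forward pass also needs timesteps and encoder hidden states):
#
#   unet = get_unet(None, None)
#   latents = torch.randn(1, 4, 64, 64)          # dummy noisy VAE latents
#   poses = torch.randn(1, 10, 64, 64)           # 5 pose maps x 2 UV channels each
#   unet_input = torch.cat((latents, poses), 1)  # (1, 14, 64, 64), matching conv_in
#   features = unet.conv_in(unet_input)          # (1, 320, 64, 64) for SD v1-4
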
'''
This module takes in CLIP + VAE embeddings and outputs CLIP-compatible embeddings.
'''
class Embedding_Adapter(nn.Module):
    def __init__(self, input_nc=38, output_nc=4, norm_layer=nn.InstanceNorm2d, chkpt=None):
        super(Embedding_Adapter, self).__init__()

        self.save_method_name = "adapter"

        self.pool = nn.MaxPool2d(2)
        self.vae2clip = nn.Linear(1280, 768)

        self.linear1 = nn.Linear(54, 50)  # maps 54 concatenated tokens down to 50

        # initialize weights to identity on the first 50 tokens
        with torch.no_grad():
            self.linear1.weight = nn.Parameter(torch.eye(50, 54))

        if chkpt is not None:
            pass

    def forward(self, clip, vae):
        vae = self.pool(vae)                          # 1 4 80 64 --> 1 4 40 32
        vae = rearrange(vae, 'b c h w -> b c (h w)')  # 1 4 40 32 --> 1 4 1280

        vae = self.vae2clip(vae)  # 1 4 768

        # Concatenate along the token dimension: 50 CLIP tokens + 4 VAE tokens = 54
        concat = torch.cat((clip, vae), 1)

        # Encode: project the 54 tokens back down to 50
        concat = rearrange(concat, 'b c d -> b d c')
        concat = self.linear1(concat)
        concat = rearrange(concat, 'b d c -> b c d')

        return concat
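
The adapter's tensor flow can be traced with dummy inputs; the 50-token CLIP sequence and the 80x64 VAE feature map below are assumptions implied by the hard-coded Linear(54, 50) and MaxPool2d(2), not values taken from this commit:

# Hypothetical shape walkthrough for Embedding_Adapter.
adapter = Embedding_Adapter()
clip_emb = torch.randn(1, 50, 768)   # 50 CLIP tokens of width 768
vae_emb = torch.randn(1, 4, 80, 64)  # VAE latent features
out = adapter(clip_emb, vae_emb)     # pool -> flatten to 1280 -> 768, concat to 54 tokens, project to 50
print(out.shape)                     # torch.Size([1, 50, 768])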