diff --git a/datasets/dreampose_dataset.py b/datasets/dreampose_dataset.py
new file mode 100644
index 0000000..486ec00
--- /dev/null
+++ b/datasets/dreampose_dataset.py
@@ -0,0 +1,180 @@
+from torch.utils.data import Dataset
+from pathlib import Path
+from torchvision import transforms
+import torch
+import torch.nn.functional as F
+from PIL import Image
+import numpy as np
+import os, cv2, glob
+
+class DreamPoseDataset(Dataset):
+    """
+    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+    It pre-processes the images and tokenizes the prompts.
+    """
+
+    def __init__(
+        self,
+        instance_data_root,
+        class_data_root=None,
+        class_prompt=None,
+        size=512,
+        center_crop=False,
+        train=True,
+        p_jitter=0.9
+    ):
+        self.size = (640, 512)
+        self.center_crop = center_crop
+        self.train = train
+
+        self.instance_data_root = Path(instance_data_root)
+        if not self.instance_data_root.exists():
+            raise ValueError("Instance images root doesn't exist.")
+
+        # Load UBC Fashion Dataset
+        self.instance_images_path = glob.glob(instance_data_root+'/*png')
+
+        self.num_instance_images = len(self.instance_images_path)
+        self._length = self.num_instance_images
+
+        if class_data_root is not None:
+            self.class_data_root = Path(class_data_root)
+            self.class_data_root.mkdir(parents=True, exist_ok=True)
+            self.class_images_path = list(self.class_data_root.iterdir())
+            self.num_class_images = len(self.class_images_path)
+            self._length = max(self.num_class_images, self.num_instance_images)
+            self.class_prompt = class_prompt
+        else:
+            self.class_data_root = None
+
+        self.image_transforms = transforms.Compose(
+            [
+                transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.ToTensor(),
+            ]
+        )
+
+        self.tensor_transforms = transforms.Compose(
+            [
+            ]
+        )
+
+    def __len__(self):
+        return self._length
+
+    # resize sparse uv flow to size
+    def resize_pose(self, pose):
+        h1, w1 = pose.shape
+        h2, w2 = self.size
+        resized_pose = np.zeros((h2, w2))
+        x_vals = np.where(pose != 0)[0]
+        y_vals = np.where(pose != 0)[1]
+        for (x, y) in list(zip(x_vals, y_vals)):
+            # find new coordinates
+            x2, y2 = int(x * h2 / h1), int(y * w2 / w1)
+            resized_pose[x2, y2] = pose[x, y]
+        return resized_pose
+
+    def __getitem__(self, index):
+        example = {}
+
+        frame_path = self.instance_images_path[index % self.num_instance_images]
+        frame_folder = frame_path.replace(os.path.basename(frame_path), '')
+        #frame_number = int(os.path.basename(frame_path).split('frame_')[-1].replace('.png', ''))
+
+        # load frame i
+        instance_image = Image.open(frame_path)
+        if not instance_image.mode == "RGB":
+            instance_image = instance_image.convert("RGB")
+
+        example["frame_i"] = self.image_transforms(instance_image)
+        example["frame_prev"] = self.image_transforms(instance_image)
+
+        assert example["frame_i"].shape == (3, 640, 512)
+
+        # Select another frame in this folder that has a corresponding DensePose file
+        frame_paths = glob.glob(frame_folder+'/*png')
+        frame_paths = [p for p in frame_paths if os.path.exists(p.replace('.png', '_densepose.npy'))]
+
+        # load frame j
+        frame_j_path = np.random.choice(frame_paths)
+        instance_image = Image.open(frame_j_path)
+        if not instance_image.mode == "RGB":
+            instance_image = instance_image.convert("RGB")
+        example["frame_j"] = self.image_transforms(instance_image)
+
+        # construct 5 input poses (the same DensePose map of frame j is repeated for each slot)
+        poses = []
+        h, w = 640, 512
+        for pose_number in range(5):
+            dp_path = frame_j_path.replace('.png', '_densepose.npy')
+            dp_i = F.interpolate(torch.from_numpy(np.load(dp_path).astype('float32')).unsqueeze(0), (h, w), mode='bilinear').squeeze(0)
+            poses.append(self.tensor_transforms(dp_i))
+        input_pose = torch.cat(poses, 0)
+        example["pose_j"] = input_pose
+
+        ''' Data Augmentation '''
+        key_frame = example["frame_i"]
+        frame = example["frame_j"]
+        prev_frame = example["frame_prev"]
+
+        #dp = transforms.ToPILImage()(dp)
+
+        # Apply random transforms to the target frame 70% of the time
+        p = np.random.randint(0, 100)
+        if p < 70:
+            ang = np.random.randint(-15, 15) # rotation angle
+            distort = np.random.uniform(0, 1)
+            top, left = np.random.randint(0, 25), np.random.randint(0, 25)
+            h_ = np.random.randint(self.size[0]-25, self.size[0]-top)
+            w_ = int(h_ / h * w)
+
+            t = transforms.Compose([
+                transforms.ToPILImage(),
+                transforms.Resize((h, w), interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.ToTensor(),
+            ])
+
+            # Apply transforms
+            frame = transforms.functional.crop(frame, top, left, h_, w_) # random crop
+
+            example["frame_j"] = t(frame)
+
+            for pose_id in range(5):
+                start, end = 2*pose_id, 2*pose_id+2
+                # convert dense pose to PIL image
+                dp = example['pose_j'][start:end]
+                c, h, w = dp.shape
+                dp = torch.cat((dp, torch.zeros(1, h, w)), 0)
+                dp = transforms.functional.crop(dp, top, left, h_, w_) # random crop
+                dp = t(dp)[0:2] # Remove extra channel from input pose
+                example["pose_j"][start:end] = dp.clone()
+
+            # slightly perturb transforms to previous frame, to prevent copy/paste
+            top += np.random.randint(0, 5)
+            left += np.random.randint(0, 5)
+            h_ += np.random.randint(0, 5)
+            w_ += np.random.randint(0, 5)
+            prev_frame = transforms.functional.crop(prev_frame, top, left, h_, w_) # random crop
+            example["frame_prev"] = t(prev_frame)
+        else:
+            # slightly perturb transforms to previous frame, to prevent copy/paste
+            top, left = np.random.randint(0, 5), np.random.randint(0, 5)
+            h_ = np.random.randint(self.size[0]-5, self.size[0]-top)
+            w_ = int(h_ / h * w)
+
+            t = transforms.Compose([
+                transforms.ToPILImage(),
+                transforms.Resize((h, w), interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.ToTensor(),
+            ])
+
+            prev_frame = transforms.functional.crop(prev_frame, top, left, h_, w_) # random crop
+            example["frame_prev"] = t(prev_frame)
+
+            for pose_id in range(5):
+                start, end = 2*pose_id, 2*pose_id+2
+                dp = example['pose_j'][start:end]
+                example["pose_j"][start:end] = dp.clone()
+
+        return example
diff --git a/datasets/train_vae_dataset.py b/datasets/train_vae_dataset.py
new file mode 100644
index 0000000..9582ad9
--- /dev/null
+++ b/datasets/train_vae_dataset.py
@@ -0,0 +1,134 @@
+from torch.utils.data import Dataset
+from pathlib import Path
+from torchvision import transforms
+import torch
+import torch.nn.functional as F
+from PIL import Image
+import numpy as np
+import os, cv2, glob
+
+class DreamPoseDataset(Dataset):
+    """
+    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+    It pre-processes the images and tokenizes the prompts.
+ """ + + def __init__( + self, + instance_data_root, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + train=True, + p_jitter=0.9 + ): + self.size = (640, 512) + self.center_crop = center_crop + self.train = train + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + # Load UBC Fashion Dataset + self.instance_images_path = [path for path in glob.glob(instance_data_root+'/*/*/*') if 'frame_i.png' in path] + + if len(self.instance_images_path) == 0: + self.instance_images_path = [path for path in glob.glob(instance_data_root+'/*') if 'png' in path] + + len1 = len(self.instance_images_path) + # Load Deep Fashion Dataset + #self.instance_images_path.extend([path for path in glob.glob('../Deep_Fashion_Dataset/img_highres/*/*/*/*.jpg') \ + # if os.path.exists(path.replace('.jpg', '_densepose.npy'))]) + + len2 = len(self.instance_images_path) + print(f"Train Dataset: {len1} UBC Fashion images, {len2-len1} Deep Fashion images.") + + self.num_instance_images = len(self.instance_images_path) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), + #transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.3, hue=0.3), + transforms.ToTensor(), + #transforms.Normalize([0.5], [0.5]), + ] + ) + + self.tensor_transforms = transforms.Compose( + [ + #transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + # resize sparse uv flow to size + def resize_pose(self, pose): + h1, w1 = pose.shape + h2, w2 = self.size[0], self.size[1] + resized_pose = np.zeros((h2, w2)) + x_vals = np.where(pose != 0)[0] + y_vals = np.where(pose != 0)[1] + for (x, y) in list(zip(x_vals, y_vals)): + # find new coordinates + x2, y2 = int(x * h2 / h1), int(y * w2 / w1) + resized_pose[x2, y2] = pose[x, y] + return resized_pose + + def __getitem__(self, index): + example = {} + + frame_path = self.instance_images_path[index % self.num_instance_images] + + # load frame j + frame_path = frame_path.replace('frame_i', 'frame_j') + instance_image = Image.open(frame_path) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + frame_j = instance_image + frame_j = frame_j.resize((self.size[1], self.size[0])) + + # Load pose j + h, w = self.size[0], self.size[1] + dp_path = self.instance_images_path[index % self.num_instance_images].replace('frame_i', 'frame_j').replace('.png', '_densepose.npy') + dp_j = F.interpolate(torch.from_numpy(np.load(dp_path, allow_pickle=True).astype('float32')).unsqueeze(0), (h, w), mode='bilinear').squeeze(0) + + # Load joints j + #pose_path = self.instance_images_path[index % self.num_instance_images].replace('frame', 'pose').replace('.png', '_refined.npy') + #pose = np.load(pose_path).astype('float32') + #pose = self.resize_pose(pose / 32).astype('float32') + #joints_j = torch.from_numpy(pose).unsqueeze(0) + + # Apply random crops + max_crop = int(0.1*min(frame_j.size[0], 
+        top, left = np.random.randint(0, max_crop), np.random.randint(0, max_crop)
+        h_ = np.random.randint(self.size[0]-max_crop, self.size[0]-top)
+        w_ = int(h_ / h * w)
+        #print(self.size[0]-max_crop, self.size[0]-top, h_, w_)
+        frame_j = transforms.functional.crop(frame_j, top, left, h_, w_) # random crop
+        dp_j = transforms.functional.crop(dp_j, top, left, h_, w_) # random crop
+        #joints_j = transforms.functional.crop(joints_j, top, left, h_, w_) # random crop
+
+        # Apply resize and normalization
+        example["frame_j"] = self.image_transforms(frame_j)
+        dp_j = self.tensor_transforms(dp_j)
+        example["pose_j"] = F.interpolate(dp_j.unsqueeze(0), (h, w), mode='bilinear').squeeze(0)
+
+        #joints_j = self.resize_pose(joints_j[0].numpy())
+        #example["joints_j"] = torch.from_numpy(joints_j).unsqueeze(0)
+
+        return example
diff --git a/datasets/ubc_dataset.py b/datasets/ubc_dataset.py
new file mode 100644
index 0000000..4d094d4
--- /dev/null
+++ b/datasets/ubc_dataset.py
@@ -0,0 +1,179 @@
+from torch.utils.data import Dataset
+from pathlib import Path
+from torchvision import transforms
+import torch
+import torch.nn.functional as F
+from PIL import Image
+import numpy as np
+import os, cv2, glob
+
+'''
+    - Passes 5 consecutive input poses per sample
+    - Ensures at least one pair of consecutive frames per batch
+'''
+class DreamPoseDataset(Dataset):
+    """
+    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+    It pre-processes the images and tokenizes the prompts.
+    """
+
+    def __init__(
+        self,
+        instance_data_root,
+        class_data_root=None,
+        class_prompt=None,
+        size=512,
+        center_crop=False,
+        train=True,
+        p_jitter=0.9,
+        n_poses=5
+    ):
+        self.size = (640, 512)
+        self.center_crop = center_crop
+        self.train = train
+        self.n_poses = n_poses
+
+        self.instance_data_root = Path(instance_data_root)
+        if not self.instance_data_root.exists():
+            raise ValueError("Instance images root doesn't exist.")
+
+        # Load UBC Fashion Dataset
+        self.instance_images_path = glob.glob('../UBC_Fashion_Dataset/train-frames/*/*png')
+        self.instance_images_path = [p for p in self.instance_images_path if os.path.exists(p.replace('.png', '_densepose.npy'))]
+        len1 = len(self.instance_images_path)
+
+        # Load Deep Fashion Dataset
+        self.instance_images_path.extend([path for path in glob.glob('../Deep_Fashion_Dataset/img_highres/*/*/*/*.jpg') \
+            if os.path.exists(path.replace('.jpg', '_densepose.npy'))])
+
+        len2 = len(self.instance_images_path)
+        print(f"Train Dataset: {len1} UBC Fashion images, {len2-len1} Deep Fashion images.")
+
+        self.num_instance_images = len(self.instance_images_path)
+        self._length = self.num_instance_images
+
+        if class_data_root is not None:
+            self.class_data_root = Path(class_data_root)
+            self.class_data_root.mkdir(parents=True, exist_ok=True)
+            self.class_images_path = list(self.class_data_root.iterdir())
+            self.num_class_images = len(self.class_images_path)
+            self._length = max(self.num_class_images, self.num_instance_images)
+            self.class_prompt = class_prompt
+        else:
+            self.class_data_root = None
+
+        self.image_transforms = transforms.Compose(
+            [
+                transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
+                #transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.3, hue=0.3),
+                #transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.ToTensor(),
+                #transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+
+        self.tensor_transforms = transforms.Compose(
+            [
+                #transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+
+    def __len__(self):
+        return self._length
+
+    # resize sparse uv flow to size
+    def resize_pose(self, pose):
+        h1, w1 = pose.shape
+        h2, w2 = self.size
+        resized_pose = np.zeros((h2, w2))
+        x_vals = np.where(pose != 0)[0]
+        y_vals = np.where(pose != 0)[1]
+        for (x, y) in list(zip(x_vals, y_vals)):
+            # find new coordinates
+            x2, y2 = int(x * h2 / h1), int(y * w2 / w1)
+            resized_pose[x2, y2] = pose[x, y]
+        return resized_pose
+
+    # return two consecutive frames per call
+    def __getitem__(self, index):
+        example = {}
+
+        '''
+
+        Prepare frame #1
+
+        '''
+        # load frame i
+        frame_path = self.instance_images_path[index % self.num_instance_images]
+        instance_image = Image.open(frame_path)
+        if not instance_image.mode == "RGB":
+            instance_image = instance_image.convert("RGB")
+        example["frame_i"] = self.image_transforms(instance_image)
+
+        # Get additional frames in this folder
+        sample_folder = frame_path.replace(os.path.basename(frame_path), '')
+        samples = [path for path in glob.glob(sample_folder+'/*') if 'npy' not in path]
+        samples = [path for path in samples if os.path.exists(path.replace('.jpg', '_densepose.npy').replace('.png', '_densepose.npy'))]
+
+        if 'Deep_Fashion' in frame_path:
+            idx = os.path.basename(frame_path).split('_')[0]
+            samples = [s for s in samples if os.path.basename(s).split('_')[0] == idx]
+            #print("Frame Path = ", frame_path)
+            #print("Samples = ", samples)
+
+        frame_j_path = samples[np.random.choice(range(len(samples)))]
+        pose_j_path = frame_j_path.replace('.jpg', '_densepose.npy')
+
+        # load frame j
+        instance_image = Image.open(frame_j_path)
+        if not instance_image.mode == "RGB":
+            instance_image = instance_image.convert("RGB")
+        example["frame_j"] = self.image_transforms(instance_image)
+
+        # Load 5 poses for frame j (the same DensePose map is repeated for each slot)
+        _, h, w = example["frame_i"].shape
+        poses = []
+        idx1 = int(self.n_poses // 2)
+        idx2 = self.n_poses - idx1
+        for pose_number in range(5):
+            dp_path = frame_j_path.replace('.jpg', '_densepose.npy').replace('.png', '_densepose.npy')
+            dp_i = F.interpolate(torch.from_numpy(np.load(dp_path, allow_pickle=True).astype('float32')).unsqueeze(0), (h, w), mode='bilinear').squeeze(0)
+            poses.append(self.tensor_transforms(dp_i))
+
+        example["pose_j"] = torch.cat(poses, 0)
+
+        '''
+
+        Prepare frame #2
+
+        '''
+        new_frame_path = samples[np.random.choice(range(len(samples)))]
+        frame_path = new_frame_path
+
+        # load frame i
+        instance_image = Image.open(frame_path)
+        if not instance_image.mode == "RGB":
+            instance_image = instance_image.convert("RGB")
+        example["frame_i"] = torch.stack((example["frame_i"], self.image_transforms(instance_image)), 0)
+
+        assert example["frame_i"].shape == (2, 3, 640, 512)
+
+        # Load frame j
+        frame_j_path = samples[np.random.choice(range(len(samples)))]
+        instance_image = Image.open(frame_j_path)
+        if not instance_image.mode == "RGB":
+            instance_image = instance_image.convert("RGB")
+        example["frame_j"] = torch.stack((example['frame_j'], self.image_transforms(instance_image)), 0)
+
+        # Load 5 poses for frame j (the same DensePose map is repeated for each slot)
+        poses = []
+        for pose_number in range(5):
+            dp_path = frame_j_path.replace('.jpg', '_densepose.npy').replace('.png', '_densepose.npy')
+            dp_i = F.interpolate(torch.from_numpy(np.load(dp_path, allow_pickle=True).astype('float32')).unsqueeze(0), (h, w), mode='bilinear').squeeze(0)
+            poses.append(self.tensor_transforms(dp_i))
+
+        poses = torch.cat(poses, 0)
+        example["pose_j"] = torch.stack((example["pose_j"], poses), 0)
+
+        #print(example["frame_i"].shape, example["frame_j"].shape, example["pose_j"].shape)
example["frame_j"].shape, example["pose_j"].shape) + return example diff --git a/models/unet_dual_encoder.py b/models/unet_dual_encoder.py index 4dae7b8..648174c 100644 --- a/models/unet_dual_encoder.py +++ b/models/unet_dual_encoder.py @@ -13,8 +13,6 @@ from diffusers import AutoencoderKL from diffusers.models import UNet2DConditionModel -from diffusers.models import BasicTransformerBlock - def get_unet(pretrained_model_name_or_path, revision, resolution=256, n_poses=5): # Load pretrained UNet layers diff --git a/pipelines/dual_encoder_pipeline.py b/pipelines/dual_encoder_pipeline.py index 6691335..a88b041 100644 --- a/pipelines/dual_encoder_pipeline.py +++ b/pipelines/dual_encoder_pipeline.py @@ -396,7 +396,6 @@ def __call__( guidance_scale: Optional[float] = 7.5, s1: float = 1.0, # strength of input pose s2: float = 1.0, # strength of input image - s3: float = 0.0, # strength of input image negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, @@ -462,10 +461,6 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ - message = "Please use `image` instead of `init_image`." - init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs) - - image = init_image or image # 1. Check inputs self.check_inputs(prompt, strength, callback_steps) diff --git a/test.py b/test.py new file mode 100644 index 0000000..3ecec67 --- /dev/null +++ b/test.py @@ -0,0 +1,190 @@ +import os +import torch +from diffusers import UNet2DConditionModel, DDIMScheduler +from pipelines.dual_encoder_pipeline import StableDiffusionImg2ImgPipeline +import argparse +from torchvision import transforms +import torch +import cv2, PIL, glob, random +import numpy as np +from torch.cuda.amp import autocast +from torchvision import transforms +from collections import OrderedDict +from torch import nn +import torch, cv2 +import torch.nn.functional as F +from models.unet_dual_encoder import get_unet, Embedding_Adapter + +parser = argparse.ArgumentParser() +parser.add_argument("--folder", default='dreampose-1', help="Path to custom pretrained checkpoints folder.",) +parser.add_argument("--pose_folder", default='../UBC_Fashion_Dataset/valid/91iZ9x8NI0S.mp4', help="Path to test frames, poses, and joints.",) +parser.add_argument("--test_poses", default=None, help="Path to test frames, poses, and joints.",) +parser.add_argument("--epoch", type=int, default=44, required=True, help="Pretrained custom model checkpoint epoch number.",) +parser.add_argument("--key_frame_path", default='../UBC_Fashion_Dataset/dreampose/91iZ9x8NI0S.mp4/key_frame.png', help="Path to key frame.",) +parser.add_argument("--pose_path", default='../UBC_Fashion_Dataset/valid/A1F1j+kNaDS.mp4/85_to_95_to_116/skeleton_i.npy', help="Pretrained model checkpoint step number.",) +parser.add_argument("--strength", type=float, default=1.0, required=False, help="How much noise to add to input image.",) +parser.add_argument("--s1", type=float, default=0.5, required=False, help="Classifier free guidance of input image.",) +parser.add_argument("--s2", type=float, default=0.5, required=False, help="Classifier free guidance of input pose.",) +parser.add_argument("--iters", default=1, type=int, help="# times to do stochastic sampling for all frames.") +parser.add_argument("--sampler", default='PNDM', help="PNDM or DDIM.") +parser.add_argument("--n_steps", default=100, type=int, 
help="Number of denoising steps.") +parser.add_argument("--output_dir", default=None, help="Where to save results.") +parser.add_argument("--j", type=int, default=-1, required=False, help="Specific frame number.",) +parser.add_argument("--min_j", type=int, default=0, required=False, help="Lowest predicted frame id.",) +parser.add_argument("--max_j", type=int, default=-1, required=False, help="Max predicted frame id.",) +parser.add_argument("--custom_vae", default=None, help="Path use custom VAE checkpoint.") +parser.add_argument("--batch_size", type=int, default=1, required=False, help="# frames to infer at once.",) +args = parser.parse_args() + +save_folder = args.output_dir if args.output_dir is not None else args.folder #'results-fashion/' +if not os.path.exists(save_folder): + os.mkdir(save_folder) + +# Load custom model +model_id = f"{args.folder}/checkpoint-{args.epoch}" #if args.step > 0 else "CompVis/stable-diffusion-v1-4" +device = "cuda" + +# Load UNet +unet = get_unet('CompVis/stable-diffusion-v1-4', "ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c", resolution=512) +unet_path = f"{args.folder}/unet_epoch_{args.epoch}.pth" +print("Loading ", unet_path) +unet_state_dict = torch.load(unet_path) +new_state_dict = OrderedDict() +for k, v in unet_state_dict.items(): + name = k[7:] if k[:7] == 'module' else k + new_state_dict[name] = v +unet.load_state_dict(new_state_dict) +unet = unet.cuda() + +print("Loading custom model from: ", model_id) +pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, revision="fp16") +pipe.safety_checker = lambda images, clip_input: (images, False) # disable safety check + +#pipe.unet.load_state_dict(torch.load(f'{save_folder}/unet_epoch_{args.epoch}.pth')) #'results/epoch_1/unet.pth')) +#pipe.unet = pipe.unet.cuda() + +adapter_chkpt = f'{args.folder}/adapter_{args.epoch}.pth' +print("Loading ", adapter_chkpt) +adapter_state_dict = torch.load(adapter_chkpt) +new_state_dict = OrderedDict() +for k, v in adapter_state_dict.items(): + name = k[7:] if k[:7] == 'module' else k + new_state_dict[name] = v +print(pipe.adapter.linear1.weight) +pipe.adapter = Embedding_Adapter() +pipe.adapter.load_state_dict(new_state_dict) +print(pipe.adapter.linear1.weight) +pipe.adapter = pipe.adapter.cuda() + +if args.custom_vae is not None: + vae_chkpt = args.custom_vae + print("Loading custom vae checkpoint from ", vae_chkpt, '...') + vae_state_dict = torch.load(vae_chkpt) + new_state_dict = OrderedDict() + for k, v in vae_state_dict.items(): + name = k[7:] if k[:7] == 'module' else k + new_state_dict[name] = v + pipe.vae.load_state_dict(new_state_dict) + pipe.vae = pipe.vae.cuda() + +# Change scheduler +if args.sampler == 'DDIM': + print("Default scheduler = ", pipe.scheduler) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + print("New scheduler = ", pipe.scheduler) + +def inputs2img(input): + target_images = (input / 2 + 0.5).clamp(0, 1) + target_images = target_images.detach().cpu().numpy() + target_images = (target_images * 255).round().astype("uint8") + return target_images + +def visualize_dp(im, dp): + #im = im.transpose((2, 0, 1)) + print(im.shape, dp.shape) + hsv = np.zeros(im.shape, dtype=np.uint8) + hsv[..., 1] = 255 + + dp = dp.cpu().detach().numpy() + mag, ang = cv2.cartToPolar(dp[0], dp[1]) + hsv[..., 0] = ang * 180 / np.pi / 2 + hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) + bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) + + return bgr + +n_images_per_sample = 1 + +frame_numbers = 
+frame_numbers = [int(path.split('frame_')[-1].replace('_densepose.npy', '')) for path in glob.glob(f'{args.pose_folder}/frame_*.npy')]
+frame_numbers = sorted(set(frame_numbers))
+pose_paths = [f'{args.pose_folder}/frame_{num}_densepose.npy' for num in frame_numbers]
+
+if args.max_j > -1:
+    pose_paths = pose_paths[args.min_j:args.max_j]
+    frame_numbers = frame_numbers[args.min_j:args.max_j]
+else:
+    pose_paths = pose_paths[args.min_j:]
+    frame_numbers = frame_numbers[args.min_j:]
+
+imSize = (512, 640)
+image_transforms = transforms.Compose(
+    [
+        transforms.Resize(imSize, interpolation=transforms.InterpolationMode.BILINEAR),
+        transforms.ToTensor(),
+        transforms.Normalize([0.5], [0.5]),
+    ]
+)
+tensor_transforms = transforms.Compose(
+    [
+        transforms.Normalize([0.5], [0.5]),
+    ]
+)
+
+# Load key frame
+input_image = PIL.Image.open(args.key_frame_path).resize(imSize)
+
+if args.j >= 0:
+    j = args.j
+    pose_paths = pose_paths[j:j+1]
+    frame_numbers = frame_numbers[j:j+1]
+
+# Iterate samples
+prev_image = input_image
+for i, pose_path in enumerate(pose_paths):
+    frame_number = int(frame_numbers[i])
+    h, w = imSize[1], imSize[0]
+
+    # construct 5 input poses
+    poses = []
+    for pose_number in range(frame_number-2, frame_number+3):
+        dp_path = pose_path.replace(str(frame_number), str(pose_number))
+        if not os.path.exists(dp_path):
+            dp_path = pose_path
+        print(dp_path)
+        dp_i = F.interpolate(torch.from_numpy(np.load(dp_path).astype('float32')).unsqueeze(0), (h, w), mode='bilinear').squeeze(0)
+        poses.append(tensor_transforms(dp_i))
+    input_pose = torch.cat(poses, 0).unsqueeze(0)
+
+    print(pose_path.split('_'))
+    j = int(pose_path.split('_')[-2])
+    print("j = ", j)
+
+    with autocast():
+        image = pipe(prompt="",
+                     image=input_image,
+                     pose=input_pose,
+                     strength=args.strength,
+                     num_inference_steps=args.n_steps,
+                     guidance_scale=7.5,
+                     s1=args.s1,
+                     s2=args.s2,
+                     callback_steps=1,
+                     frames=[]
+                     )[0][0]
+
+    # Save pose and image
+    save_path = f"{save_folder}/pred_#{j}.png"
+    image = image.convert('RGB')
+    image = np.array(image)
+    cv2.imwrite(save_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))