Commit: first commit
bugWholesaler committed May 23, 2024
1 parent 2f2778a commit 34cc587
Showing 490 changed files with 138,586 additions and 0 deletions.
5 changes: 5 additions & 0 deletions Makefile
@@ -0,0 +1,5 @@
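# train: launches distributed training via train.sh; stdout is redirected to
# a timestamped log under logs/, so the logs/ directory must exist first.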
train:
bash train.sh $(shell hostname -i) train.py ${config} $(num_nodes) $(num_gpus) $(local_rank) > logs/$(shell date +"%Y-%m-%d-%T" ).log

eval:
CUDA_VISIBLE_DEVICES=${devices} python3 eval_fix.py --config ${config}
50 changes: 50 additions & 0 deletions README.md
@@ -0,0 +1,50 @@
## Update logic
Move reusable definitions and methods into the animatediff folder wherever possible, and configure everything through yaml files only.
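
The intended pattern, as a minimal sketch (assuming OmegaConf-style config loading, which AnimateDiff-style codebases commonly use; neither the loader nor the `train_data` key is confirmed by this commit):
```python
# Illustrative only: OmegaConf and the `train_data` key are assumptions,
# not this repo's confirmed schema.
from omegaconf import OmegaConf

from animatediff.data.dataset import PexelsDataset

config = OmegaConf.load(
    "configs/training_me/train1_magic_catnoise_codeback_apt0_ZT_OffNoi.yaml")
# Reusable classes live in the animatediff package; the yaml file only
# supplies their arguments.
dataset = PexelsDataset(**config.train_data)  # hypothetical key
```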

## Requirements

See `environment.yaml`.

This version supports the latest diffusers library.
Set up the environment with the following steps:
```bash
conda env create -f environment.yaml
pip install -U openmim
# Install the CUDA build of torchvision yourself:
# https://download.pytorch.org/whl/torch_stable.html
# wget https://download.pytorch.org/whl/cu117/torchvision-0.15.1%2Bcu117-cp310-cp310-linux_x86_64.whl
pip install torchvision-0.15.1%2Bcu117-cp310-cp310-linux_x86_64.whl
mim install mmengine
mim install "mmcv>=2.0.1"
mim install "mmdet>=3.1.0"
mim install "mmpose>=1.1.0"
# Install cudatoolkit and deepspeed yourself:
conda install -c conda-forge cudatoolkit-dev -y
pip3 install deepspeed
```
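
A quick post-install sanity check (a minimal sketch; it only verifies that the packages installed above import cleanly and that CUDA is visible):
```python
# Each import corresponds to an install step above; a failure here usually
# means that step was skipped.
import torch
import torchvision
import mmengine, mmcv, mmdet, mmpose
import deepspeed, diffusers

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("torchvision:", torchvision.__version__)
print("mmcv:", mmcv.__version__, "mmdet:", mmdet.__version__, "mmpose:", mmpose.__version__)
print("deepspeed:", deepspeed.__version__, "diffusers:", diffusers.__version__)
```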


## Training example
Start training with the `make train` command:
```bash
make train num_nodes={number of nodes} num_gpus={GPUs per node} local_rank={rank of the current node} config={config path}
```
Train the 2D part:
```bash
make train num_nodes=1 num_gpus=8 local_rank=0 config=configs/training_me/train1_magic_catnoise_codeback_apt0_ZT_OffNoi.yaml
```
Train the 3D part:
```bash
make train num_nodes=1 num_gpus=8 local_rank=0 config=configs/training_me/train12_magic_catnoise_codeback_apt0_from2D25000step_warp05_DiffMotion_ZT_OffNoi.yaml
```

## Evaluation
```bash
make eval devices={GPU ranks used for eval, set as CUDA_VISIBLE_DEVICES} config={config path}
```
```bash
make eval devices=0 config=./configs/prompts_me/infer12_magic_catnoise_CrossRefTemp_codeback_apt0_ZT_OffNoi.yaml
```
220 changes: 220 additions & 0 deletions animatediff/data/dataset.py
@@ -0,0 +1,220 @@
import os
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from animatediff.utils.util import zero_rank_print
import json
from PIL import Image
import torch
from decord import VideoReader

from einops import rearrange

import io, csv, math, random
# from util import zero_rank_print


class WebVid10M(Dataset):
def __init__(
self,
csv_path, video_folder,
sample_size=256, sample_stride=4, sample_n_frames=16,
is_image=False,
):
zero_rank_print(f"loading annotations from {csv_path} ...")
with open(csv_path, 'r') as csvfile:
self.dataset = list(csv.DictReader(csvfile))
self.length = len(self.dataset)
zero_rank_print(f"data scale: {self.length}")

self.video_folder = video_folder
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
self.is_image = is_image

sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
self.pixel_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.Resize(sample_size[0]),
transforms.CenterCrop(sample_size),
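            # Inputs arrive in [0, 1]; Normalize maps them to [-1, 1].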
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])

def get_batch(self, idx):
video_dict = self.dataset[idx]
videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']

video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
video_reader = VideoReader(video_dir)
video_length = len(video_reader)

if not self.is_image:
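            # Sample sample_n_frames frames spaced sample_stride apart starting
            # at a random offset; for short videos, linspace spaces the indices
            # more densely instead.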
clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
start_idx = random.randint(0, video_length - clip_length)
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
else:
batch_index = [random.randint(0, video_length - 1)]

pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
pixel_values = pixel_values / 255.
del video_reader

if self.is_image:
pixel_values = pixel_values[0]

return pixel_values, name

def __len__(self):
return self.length

def __getitem__(self, idx):
while True:
try:
pixel_values, name = self.get_batch(idx)
break

            except Exception:
                # Unreadable or missing video: resample a random index and retry.
                idx = random.randint(0, self.length - 1)

pixel_values = self.pixel_transforms(pixel_values)
sample = dict(pixel_values=pixel_values, text=name)
return sample


class PexelsDataset(Dataset):
"""
load video-only data, and get dwpose condition
"""

def __init__(
self,
json_path,
sample_size=(768, 512), sample_stride=1, sample_n_frames=16, is_test=False
):
print("load video-only data, and get dwpose condition")
if not isinstance(json_path, list):
zero_rank_print(f"loading annotations from {json_path} ...")
self.dataset = json.load(open(json_path))
else:
self.dataset = json_path

self.length = len(self.dataset)
zero_rank_print(f"data scale: {self.length}")

self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames

self.sample_size = sample_size

if not is_test:
self.pixel_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
else:
self.pixel_transforms = transforms.Compose([
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])

def get_batch(self, idx):
video_dir = self.dataset[idx]
video_reader = VideoReader(video_dir)
video_length = len(video_reader)

clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
start_idx = random.randint(0, video_length - clip_length)
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)

image_np = video_reader.get_batch(batch_index).asnumpy()
pixel_values = torch.from_numpy(image_np).permute(0, 3, 1, 2).contiguous()
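        # Scale to [-1, 1]. Note that __getitem__ rebuilds pixel_values from
        # image_np via resize_and_crop rather than using this tensor directly.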
pixel_values = (pixel_values / 255.0 - 0.5) * 2
del video_reader

return pixel_values, image_np

def __len__(self):
return self.length

def __getitem__(self, idx):

while True:
try:
pixel_values, image_np = self.get_batch(idx)
break

            except Exception:
                # Unreadable or missing video: resample a random index and retry.
                idx = random.randint(0, self.length - 1)

pixel_values = resize_and_crop(image_np, sample_size=self.sample_size)
pixel_values = self.pixel_transforms(pixel_values)

sample = dict(pixel_values=pixel_values)
return sample


def resize_and_crop(images, sample_size=(768, 512)):
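    """Resize each frame to preserve its aspect ratio, then zero-pad it to
    sample_size = (height, width); despite the name, no explicit crop step
    is applied. Returns a float tensor in [0, 1] with shape (f, c, h, w)."""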
image_np = []

for image in images:
image = Image.fromarray(image)
# Determine if width is larger than height or vice versa
if image.width > image.height:
aspect_ratio = image.width / image.height
new_width = int(sample_size[0] * aspect_ratio)
resize = transforms.Resize((sample_size[0], new_width))
else:
aspect_ratio = image.height / image.width
new_height = int(sample_size[1] * aspect_ratio)
resize = transforms.Resize((new_height, sample_size[1]))

# Apply the resize transformation
image = resize(image)

# Calculate padding
pad_left = (sample_size[1] - image.width) // 2
pad_right = sample_size[1] - image.width - pad_left
pad_top = (sample_size[0] - image.height) // 2
pad_bottom = sample_size[0] - image.height - pad_top

# Apply padding
padding = transforms.Pad((pad_left, pad_top, pad_right, pad_bottom), fill=0)
image = padding(image)

image_np.append(np.array(image))

image_np = np.stack(image_np)

pixel_values = torch.from_numpy(image_np).permute(0, 3, 1, 2).contiguous()
pixel_values = pixel_values / 255.

return pixel_values


def get_pose_conditions(image_np, dwpose_model=None):
dwpose = dwpose_model

num_frames = image_np.shape[0]
dwpose_conditions = []

    for frame_id in range(num_frames):
        # Estimate a DWPose skeleton map for each frame.
        pil_image = Image.fromarray(image_np[frame_id])
dwpose_image = dwpose(pil_image, output_type='np')
dwpose_image = torch.tensor(dwpose_image).unsqueeze(0)
dwpose_conditions.append(dwpose_image)

return torch.cat(dwpose_conditions, dim=0)


if __name__ == "__main__":
    from animatediff.utils.util import save_videos_grid

dataset = PexelsDataset(
json_path="/work00/magic_animate_unofficial/fashion_dataset.json")

dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16)
for idx, batch in enumerate(dataloader):
print(batch['pixel_values'].shape)
for i in range(batch['pixel_values'].shape[0]):
save_videos_grid(batch['pixel_values'][i:i + 1].permute(0, 2, 1, 3, 4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)

break