Commit 34cc587 (1 parent: 2f2778a)
Showing 490 changed files with 138,586 additions and 0 deletions.
@@ -0,0 +1,5 @@
# Launch distributed training via train.sh; stdout goes to a timestamped log file.
# num_nodes, num_gpus, local_rank and config are passed on the make command line.
train:
	bash train.sh $(shell hostname -i) train.py ${config} $(num_nodes) $(num_gpus) $(local_rank) > logs/$(shell date +"%Y-%m-%d-%T").log

# Run evaluation on the GPUs listed in `devices`.
eval:
	CUDA_VISIBLE_DEVICES=${devices} python3 eval_fix.py --config ${config}
@@ -0,0 +1,50 @@
## Update logic
Move reusable definitions and methods into the `animatediff` folder wherever possible; all configuration is done through YAML files.

## Requirements

See `environment.yaml`.

This version supports the latest diffusers library. Set up the environment with the following steps:
```bash
conda env create -f environment.yaml
pip install -U openmim
# Install the CUDA build of torchvision yourself:
# https://download.pytorch.org/whl/torch_stable.html
# wget https://download.pytorch.org/whl/cu117/torchvision-0.15.1%2Bcu117-cp310-cp310-linux_x86_64.whl
pip install torchvision-0.15.1%2Bcu117-cp310-cp310-linux_x86_64.whl
mim install mmengine
mim install "mmcv>=2.0.1"
mim install "mmdet>=3.1.0"
mim install "mmpose>=1.1.0"
# Install cudatoolkit and deepspeed yourself:
conda install -c conda-forge cudatoolkit-dev -y
pip3 install deepspeed
```
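
To sanity-check the installation, a quick import test such as the following can help (a minimal sketch; version numbers depend on the wheels you chose):
```python
# Verify that the core dependencies import and that CUDA is visible.
import torch
import torchvision
import mmcv
import mmdet
import mmpose

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("torchvision:", torchvision.__version__)
print("mmcv:", mmcv.__version__, "| mmdet:", mmdet.__version__, "| mmpose:", mmpose.__version__)
```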

## Training example
Start training with the `make train` command:
```bash
make train num_nodes={number of nodes} num_gpus={GPUs per node} local_rank={rank of the current node} config={config path}
```
Train the 2D part:
```bash
make train num_nodes=1 num_gpus=8 local_rank=0 config=configs/training_me/train1_magic_catnoise_codeback_apt0_ZT_OffNoi.yaml
```
Train the 3D part (the config name suggests it resumes from the 2D checkpoint at 25,000 steps):
```bash
make train num_nodes=1 num_gpus=8 local_rank=0 config=configs/training_me/train12_magic_catnoise_codeback_apt0_from2D25000step_warp05_DiffMotion_ZT_OffNoi.yaml
```

## Evaluation
```bash
make eval devices={GPU ranks to use for eval} config={config path}
```
For example:
```bash
make eval devices=0 config=./configs/prompts_me/infer12_magic_catnoise_CrossRefTemp_codeback_apt0_ZT_OffNoi.yaml
```
@@ -0,0 +1,220 @@
import os
import io, csv, math, random
import json

import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from PIL import Image
from decord import VideoReader
from einops import rearrange

from animatediff.utils.util import zero_rank_print
# from util import zero_rank_print
class WebVid10M(Dataset):
    def __init__(
        self,
        csv_path, video_folder,
        sample_size=256, sample_stride=4, sample_n_frames=16,
        is_image=False,
    ):
        zero_rank_print(f"loading annotations from {csv_path} ...")
        with open(csv_path, 'r') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        zero_rank_print(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.is_image = is_image

        # Augment and normalize: random flip, resize the short side to
        # sample_size[0], center-crop to sample_size, then map [0, 1] to [-1, 1].
        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']

        video_path = os.path.join(self.video_folder, f"{videoid}.mp4")
        video_reader = VideoReader(video_path)
        video_length = len(video_reader)

        if not self.is_image:
            # Sample a clip of sample_n_frames frames spaced sample_stride apart,
            # starting at a random offset (clamped to the video length).
            clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
        else:
            batch_index = [random.randint(0, video_length - 1)]

        pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        del video_reader

        if self.is_image:
            pixel_values = pixel_values[0]

        return pixel_values, name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break
            except Exception:
                # Corrupt or missing video: retry with a random index.
                idx = random.randint(0, self.length - 1)

        pixel_values = self.pixel_transforms(pixel_values)
        sample = dict(pixel_values=pixel_values, text=name)
        return sample
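
# Example usage (a sketch; assumes a WebVid-style CSV with columns
# 'videoid', 'name' and 'page_dir', and videos stored as <videoid>.mp4):
#   dataset = WebVid10M(csv_path="webvid.csv", video_folder="videos/",
#                       sample_size=256, sample_stride=4, sample_n_frames=16)
#   sample = dataset[0]  # {'pixel_values': (16, 3, 256, 256) tensor, 'text': caption}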


class PexelsDataset(Dataset):
    """
    Load video-only data and get the dwpose condition.
    """

    def __init__(
        self,
        json_path,
        sample_size=(768, 512), sample_stride=1, sample_n_frames=16, is_test=False
    ):
        print("load video-only data, and get dwpose condition")
        # json_path is either a path to a JSON list of video paths,
        # or an already-loaded list of video paths.
        if not isinstance(json_path, list):
            zero_rank_print(f"loading annotations from {json_path} ...")
            self.dataset = json.load(open(json_path))
        else:
            self.dataset = json_path

        self.length = len(self.dataset)
        zero_rank_print(f"data scale: {self.length}")

        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.sample_size = sample_size

        # Random horizontal flip only at training time; always map [0, 1] to [-1, 1].
        if not is_test:
            self.pixel_transforms = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ])
        else:
            self.pixel_transforms = transforms.Compose([
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ])
    def get_batch(self, idx):
        video_path = self.dataset[idx]
        video_reader = VideoReader(video_path)
        video_length = len(video_reader)

        clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
        start_idx = random.randint(0, video_length - clip_length)
        batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)

        image_np = video_reader.get_batch(batch_index).asnumpy()
        pixel_values = torch.from_numpy(image_np).permute(0, 3, 1, 2).contiguous()
        pixel_values = (pixel_values / 255.0 - 0.5) * 2
        del video_reader

        return pixel_values, image_np

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                pixel_values, image_np = self.get_batch(idx)
                break
            except Exception:
                # Corrupt or missing video: retry with a random index.
                idx = random.randint(0, self.length - 1)

        # Note: the normalized tensor from get_batch is discarded here; the raw
        # frames are re-processed so they can be resized and padded to sample_size.
        pixel_values = resize_and_crop(image_np, sample_size=self.sample_size)
        pixel_values = self.pixel_transforms(pixel_values)

        sample = dict(pixel_values=pixel_values)
        return sample


def resize_and_crop(images, sample_size=(768, 512)):
    """Resize each frame, then center it on a sample_size = (height, width) canvas."""
    image_np = []

    for image in images:
        image = Image.fromarray(image)
        # Landscape frames: match the target height; portrait frames: match the target width.
        if image.width > image.height:
            aspect_ratio = image.width / image.height
            new_width = int(sample_size[0] * aspect_ratio)
            resize = transforms.Resize((sample_size[0], new_width))
        else:
            aspect_ratio = image.height / image.width
            new_height = int(sample_size[1] * aspect_ratio)
            resize = transforms.Resize((new_height, sample_size[1]))

        # Apply the resize transformation
        image = resize(image)

        # Pad symmetrically (negative amounts trim) so the resized frame
        # is centered on the (height, width) canvas.
        pad_left = (sample_size[1] - image.width) // 2
        pad_right = sample_size[1] - image.width - pad_left
        pad_top = (sample_size[0] - image.height) // 2
        pad_bottom = sample_size[0] - image.height - pad_top
        padding = transforms.Pad((pad_left, pad_top, pad_right, pad_bottom), fill=0)
        image = padding(image)

        image_np.append(np.array(image))

    image_np = np.stack(image_np)

    pixel_values = torch.from_numpy(image_np).permute(0, 3, 1, 2).contiguous()
    pixel_values = pixel_values / 255.

    return pixel_values
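
# Worked example (a sketch): sixteen 720x480 portrait frames have aspect ratio
# 1.5, so they resize to exactly (768, 512) and need no padding:
#   frames = np.zeros((16, 720, 480, 3), dtype=np.uint8)
#   out = resize_and_crop(frames, sample_size=(768, 512))
#   out.shape  # torch.Size([16, 3, 768, 512]), values in [0, 1]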


def get_pose_conditions(image_np, dwpose_model=None):
    dwpose = dwpose_model

    num_frames = image_np.shape[0]
    dwpose_conditions = []

    for frame_id in range(num_frames):
        # Run the pose detector on each frame and stack the rendered pose maps.
        pil_image = Image.fromarray(image_np[frame_id])
        dwpose_image = dwpose(pil_image, output_type='np')
        dwpose_image = torch.tensor(dwpose_image).unsqueeze(0)
        dwpose_conditions.append(dwpose_image)

    return torch.cat(dwpose_conditions, dim=0)
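
# Example usage (a sketch; assumes `dwpose_model` is a callable DWPose detector
# that takes a PIL image and returns the rendered pose map as an np.ndarray):
#   poses = get_pose_conditions(image_np, dwpose_model=dwpose)
#   # poses: tensor of shape (num_frames, H, W, C)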


if __name__ == "__main__":
    # Quick smoke test: decode a few clips and save them as mp4 grids.
    from util import save_videos_grid

    dataset = PexelsDataset(
        json_path="/work00/magic_animate_unofficial/fashion_dataset.json")

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16)
    for idx, batch in enumerate(dataloader):
        print(batch['pixel_values'].shape)
        for i in range(batch['pixel_values'].shape[0]):
            save_videos_grid(batch['pixel_values'][i:i + 1].permute(0, 2, 1, 3, 4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)
        break