Commit dc886fe: init

Snowfallingplum committed Oct 9, 2024 (0 parents)
Showing 115 changed files with 18,995 additions and 0 deletions.
437 changes: 437 additions & 0 deletions LICENSE


103 changes: 103 additions & 0 deletions README.md
@@ -0,0 +1,103 @@
# SHMT
[NeurIPS 2024] SHMT: Self-supervised Hierarchical Makeup Transfer via Latent Diffusion Models

# Note: This code is a preliminary organization and may not be completely correct. I have been very busy lately; I will update the code and upload the pre-trained models when I have time. Sorry for the delay.

## Requirements

A suitable [conda](https://conda.io/) environment named `ldm` can be created
and activated with:

```
conda env create -f environment.yaml
conda activate ldm
```
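
To confirm that the environment is usable for GPU training, a quick sanity check run inside the activated `ldm` environment (the expected version follows the pin in `environment.yaml`):

```
# Sanity check for the `ldm` environment: verify the pinned torch build and CUDA visibility.
import torch

print(torch.__version__)          # expected: 1.7.0 per environment.yaml
print(torch.cuda.is_available())  # should print True on a GPU machine
```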

## The Training of SHMT
1. Download a pretrained autoencoding model from [LDM](https://github.com/CompVis/latent-diffusion); VQ-f4 is used in our experiments.
2. Data preparation. Prepare the data according to the following directory structure; a layout-checking sketch follows the listing.
```
/MakeupData/train/
├── images # original images
│ ├── 00001.jpg
│ ├── 00002.jpg
│ ├── ...
├── segs # face parsing
│ ├── 00001.jpg
│ ├── 00002.jpg
│ ├── ...
├── 3d # 3D face images
│ ├── 00001_3d.jpg
│ ├── 00002_3d.jpg
│ ├── ...
```
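Before training, it is worth confirming that every image has a matching parsing map and 3D image. A minimal layout check, assuming `.jpg` files and the `_3d` suffix shown above:
```
# Minimal layout check for /MakeupData/train/ (paths and suffixes assumed from the listing above).
from pathlib import Path

root = Path("/MakeupData/train")
for img in sorted((root / "images").glob("*.jpg")):
    seg = root / "segs" / img.name               # face parsing map
    d3 = root / "3d" / f"{img.stem}_3d.jpg"      # 3D face image
    for f in (seg, d3):
        if not f.exists():
            print(f"missing {f} for {img.name}")
```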
3. Change the pre-trained model path and the data paths in the configuration files to your own. The configuration files are in `./configs/latent-diffusion/`; a config-loading sketch follows the training commands.

4. Execute the training scripts.
```
CUDA_VISIBLE_DEVICES=0 python main.py --base configs/latent-diffusion/shmt_h0.yaml -t --gpus 0,
CUDA_VISIBLE_DEVICES=0 python main.py --base configs/latent-diffusion/shmt_h4.yaml -t --gpus 0,
```
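
To double-check the edits from step 3, the configs can be loaded with `omegaconf` (pinned in `environment.yaml`); the keys below follow `shmt_h0.yaml` as shown later in this commit:

```
# Verify that the paths edited in step 3 resolve; keys follow configs/latent-diffusion/shmt_h0.yaml.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/latent-diffusion/shmt_h0.yaml")
print(cfg.model.params.first_stage_config.params.ckpt_path)  # VQ-f4 checkpoint
print(cfg.data.params.train.params.image_path)               # training images
print(cfg.data.params.train.params.seg_path)                 # face parsing maps
print(cfg.data.params.train.params.depth_path)               # 3D images
```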

## The Inference of SHMT
1. Data preparation. Prepare the data according to the following directory structure; a pairing sketch follows the listing.
```
/MakeupData/test/
├── images # original images
│ ├── non_makeup
│ │ ├── 00001.jpg
│ │ ├── 00002.jpg
│ │ ├── ...
│ ├── makeup
│ │ ├── 00001.jpg
│ │ ├── 00002.jpg
│ │ ├── ...
├── segs # face parsing
│ ├── non_makeup
│ │ ├── 00001.jpg
│ │ ├── 00002.jpg
│ │ ├── ...
│ ├── makeup
│ │ ├── 00001.jpg
│ │ ├── 00002.jpg
│ │ ├── ...
├── 3d # 3D images (non_makeup only)
│ ├── non_makeup
│ │ ├── 00001.jpg
│ │ ├── 00002.jpg
│ │ ├── ...
```
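Inference pairs each non-makeup source with a makeup reference. How the released scripts enumerate pairs is not specified here, but a simple all-pairs sketch over the layout above would be:
```
# Enumerate (non_makeup, makeup) pairs from /MakeupData/test/ (all-pairs strategy is an assumption).
from itertools import product
from pathlib import Path

root = Path("/MakeupData/test/images")
sources = sorted((root / "non_makeup").glob("*.jpg"))
refs = sorted((root / "makeup").glob("*.jpg"))
for src, ref in product(sources, refs):
    print(src.name, "->", ref.name)  # every source paired with every reference
```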
2. Execute the inference scripts. The `h0` and `h4` variants correspond to `level_of_Laplace: 0` and `level_of_Laplace: 4` in their configs; a batch-run sketch follows the commands.
```
CUDA_VISIBLE_DEVICES=0 python makeup_inference_h0.py \
--outdir your_output_dir \
--config configs/latent-diffusion/shmt_h0.yaml \
--ckpt your_ckpt_path \
--source_image_path your_non_makeup_images_path \
--source_seg_path your_non_makeup_segs_path \
--source_depth_path your_non_makeup_3d_path \
--ref_image_path your_makeup_images_path \
--ref_seg_path your_makeup_segs_path \
--seed 321 \
--ddim_steps 50
```
```
CUDA_VISIBLE_DEVICES=0 python makeup_inference_h4.py \
--outdir your_output_dir \
--config configs/latent-diffusion/shmt_h4.yaml \
--ckpt your_ckpt_path \
--source_image_path your_non_makeup_images_path \
--source_seg_path your_non_makeup_segs_path \
--source_depth_path your_non_makeup_3d_path \
--ref_image_path your_makeup_images_path \
--ref_seg_path your_makeup_segs_path \
--seed 321 \
--ddim_steps 50
```
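
A hedged batch-driver sketch that runs both variants over the same inputs; the data paths assume the test layout above, and `results/` plus `your_ckpt_path` are placeholders:

```
# Hypothetical wrapper: run both SHMT variants over the same inputs (paths are placeholders).
import subprocess

common = [
    "--source_image_path", "/MakeupData/test/images/non_makeup",
    "--source_seg_path", "/MakeupData/test/segs/non_makeup",
    "--source_depth_path", "/MakeupData/test/3d/non_makeup",
    "--ref_image_path", "/MakeupData/test/images/makeup",
    "--ref_seg_path", "/MakeupData/test/segs/makeup",
    "--seed", "321",
    "--ddim_steps", "50",
]
for h in ("h0", "h4"):
    subprocess.run(
        ["python", f"makeup_inference_{h}.py",
         "--outdir", f"results/{h}",
         "--config", f"configs/latent-diffusion/shmt_{h}.yaml",
         "--ckpt", "your_ckpt_path", *common],
        check=True,
    )
```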

## Comments

- Our code for the SHMT models builds heavily on [LDM](https://github.com/CompVis/latent-diffusion)
and [Paint-by-Example](https://github.com/Fantasy-Studio/Paint-by-Example).
Thanks for open-sourcing!

115 changes: 115 additions & 0 deletions configs/latent-diffusion/shmt_h0.yaml
@@ -0,0 +1,115 @@
model:
  base_learning_rate: 1.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    level_of_Laplace: 0
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: makeup
    image_size: 64
    channels: 3
    monitor: val/loss_simple_ema

    # scheduler_config: # 10000 warmup steps
    #   target: ldm.lr_scheduler.LambdaLinearScheduler
    #   params:
    #     warm_up_steps: [ 10000 ]
    #     cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
    #     f_start: [ 1.e-6 ]
    #     f_max: [ 1. ]
    #     f_min: [ 1. ]

    corr_stage_config:
      target: ldm.models.correspondence.Correspondence
      params:
        in_channels_list: [22, 83, 70]
        model_channels: 224
        num_res_blocks: 2
        channel_mult:
          - 1
          - 2
          - 3
          - 4

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 121
        out_channels: 3
        model_channels: 224
        attention_resolutions:
          # note: this isn't actually the resolution but
          # the downsampling factor, i.e. this corresponds to
          # attention on spatial resolutions 8, 16 and 32, as the
          # spatial resolution of the latents is 64 for f4
          - 8
          - 4
          - 2
        num_res_blocks: 2
        channel_mult:
          - 1
          - 2
          - 3
          - 4
        num_head_channels: 32
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ckpt_path: /mnt/workspace/workgroup/sunzhaoyang/data/ldm/first_stage_models/vq-f4/model.ckpt
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_unconditional__

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 16
    num_workers: 4
    wrap: false
    train:
      target: ldm.data.makeup_dataset.MakeupDataset
      params:
        is_train: True
        image_path: /mnt/workspace/workgroup/sunzhaoyang/data/MakeupData2/images
        seg_path: /mnt/workspace/workgroup/sunzhaoyang/data/MakeupData2/segs
        depth_path: /mnt/workspace/workgroup/sunzhaoyang/data/MakeupData2/3d
    validation:
      target: ldm.data.makeup_dataset.MakeupDataset
      params:
        is_train: False
        image_path: /mnt/workspace/workgroup/sunzhaoyang/data/MakeupData2/images
        seg_path: /mnt/workspace/workgroup/sunzhaoyang/data/MakeupData2/segs
        depth_path: /mnt/workspace/workgroup/sunzhaoyang/data/MakeupData2/3d


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 2000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
115 changes: 115 additions & 0 deletions configs/latent-diffusion/shmt_h4.yaml
@@ -0,0 +1,115 @@
model:
  base_learning_rate: 1.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    level_of_Laplace: 4
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: makeup
    image_size: 64
    channels: 3
    monitor: val/loss_simple_ema

    # scheduler_config: # 10000 warmup steps
    #   target: ldm.lr_scheduler.LambdaLinearScheduler
    #   params:
    #     warm_up_steps: [ 10000 ]
    #     cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
    #     f_start: [ 1.e-6 ]
    #     f_max: [ 1. ]
    #     f_min: [ 1. ]

    corr_stage_config:
      target: ldm.models.correspondence.Correspondence
      params:
        in_channels_list: [22, 68, 70]
        model_channels: 224
        num_res_blocks: 2
        channel_mult:
          - 1
          - 2
          - 3
          - 4

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 106
        out_channels: 3
        model_channels: 224
        attention_resolutions:
          # note: this isn't actually the resolution but
          # the downsampling factor, i.e. this corresponds to
          # attention on spatial resolutions 8, 16 and 32, as the
          # spatial resolution of the latents is 64 for f4
          - 8
          - 4
          - 2
        num_res_blocks: 2
        channel_mult:
          - 1
          - 2
          - 3
          - 4
        num_head_channels: 32
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ckpt_path: /mnt/workspace/workgroup/sunzhaoyang/data/ldm/first_stage_models/vq-f4/model.ckpt
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_unconditional__

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 16
    num_workers: 4
    wrap: false
    train:
      target: ldm.data.makeup_dataset.MakeupDataset
      params:
        is_train: True
        image_path: /mnt/workspace/workgroup/sunzhaoyang/data/MakeupData2/images
        seg_path: /mnt/workspace/workgroup/sunzhaoyang/data/MakeupData2/segs
        depth_path: /mnt/workspace/workgroup/sunzhaoyang/data/MakeupData2/3d
    validation:
      target: ldm.data.makeup_dataset.MakeupDataset
      params:
        is_train: False
        image_path: /mnt/workspace/workgroup/sunzhaoyang/data/MakeupData2/images
        seg_path: /mnt/workspace/workgroup/sunzhaoyang/data/MakeupData2/segs
        depth_path: /mnt/workspace/workgroup/sunzhaoyang/data/MakeupData2/3d


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 2000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
27 changes: 27 additions & 0 deletions environment.yaml
@@ -0,0 +1,27 @@
name: ldm
channels:
  - pytorch
  - defaults
dependencies:
  - python=3.8.5
  - pip=20.3
  - cudatoolkit=11.0
  - pytorch=1.7.0
  - torchvision=0.8.1
  - numpy=1.19.2
  - pip:
    - albumentations==0.4.3
    - opencv-python==4.1.2.30
    - pudb==2019.2
    - imageio==2.9.0
    - imageio-ffmpeg==0.4.2
    - pytorch-lightning==1.4.2
    - omegaconf==2.1.1
    - test-tube>=0.7.5
    - streamlit>=0.73.1
    - einops==0.3.0
    - torch-fidelity==0.3.0
    - transformers==4.3.1
    - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
    - -e git+https://github.com/openai/CLIP.git@main#egg=clip
    - -e .
Empty file added ldm/__init__.py
Empty file added ldm/data/__init__.py
Binary file added ldm/data/__pycache__/__init__.cpython-38.pyc
Binary file added ldm/data/__pycache__/__init__.cpython-39.pyc
Binary file added ldm/data/__pycache__/base.cpython-38.pyc
Binary file added ldm/data/__pycache__/base.cpython-39.pyc
23 changes: 23 additions & 0 deletions ldm/data/base.py
@@ -0,0 +1,23 @@
from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset


class Txt2ImgIterableBaseDataset(IterableDataset):
    '''
    Define an interface to make the IterableDatasets for text2img data chainable
    '''
    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        self.valid_ids = valid_ids
        self.sample_ids = valid_ids
        self.size = size

        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')

    def __len__(self):
        return self.num_records

    @abstractmethod
    def __iter__(self):
        pass
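
For reference, a minimal sketch of a concrete subclass; the record list and its dict layout are hypothetical:

```
# Hypothetical concrete subclass: yields records from an in-memory list (illustration only).
class ToyTxt2ImgDataset(Txt2ImgIterableBaseDataset):
    def __init__(self, records):
        super().__init__(num_records=len(records), valid_ids=list(range(len(records))))
        self.records = records

    def __iter__(self):
        for i in self.sample_ids:
            yield self.records[i]  # e.g. a dict with 'image' and 'caption' keys
```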
