add leffa (WIP)
franciszzj committed Dec 6, 2024
1 parent ca8322d commit 511469f
Showing 48 changed files with 20,852 additions and 2 deletions.
66 changes: 66 additions & 0 deletions leffa/README.md
@@ -0,0 +1,66 @@
# VTON

Based on `genads/imagenet`.


## Methods

- IDM-VTON
- CatVTON
- SimpleVTON
- Learning Flow Fields in Attention


## Prerequisites

- **Genie CLI**. Jobs are launched with `genie`. Install the CLI on your devserver with `feature install genie --persist`. For more information, visit https://www.internalfb.com/intern/wiki/Genie/Genie_101/Genie_CLI/

- **ACL permissions**. To launch jobs, you need to be added to [`oncall_ai_genads`](https://www.internalfb.com/omh/view/ai_genads/oncall_management_settings/members).

- **Hive permissions**. Once you have your dataset, you will need Hive permissions to read it.

### Starting a new project

Create an fbpkg:
`fbpkg create genads.train.idm_vton --oncall-team=ai_genads --pkg-desc='idm_vton for genads usage' --acl-name ai_genads_fbpkg`

Create a new model type:
https://www.internalfb.com/mlhub/models/model_type


## Train

### Local training

You need a local machine with at least one GPU with 80 GB of memory.

```
CUDA_VISIBLE_DEVICES=7 genie launch genads/idm_vton --config pkg://genads/idm_vton/launcher/launch_train.yaml launcher=local launcher.torchx_options.num_processes=1 conf@app.run_fn.cfg=train_local
```

### Cluster training (MAST, Distributed)
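If you have been granted access to the pool, launch a distributed job on MAST with one of the commands below. The second command additionally sets the host type to `grandteton`.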

```
genie launch genads/idm_vton --config pkg://genads/idm_vton/launcher/launch_train.yaml launcher=mast launcher.torchx_options.num_hosts=4 launcher.torchx_options.num_processes=8 conf@app.run_fn.cfg=train.yaml ++launcher.torchx_options.job_name_suffix=vton_v0_0
genie launch genads/idm_vton --config pkg://genads/idm_vton/launcher/launch_train.yaml launcher=mast launcher.torchx_options.host_type=grandteton launcher.torchx_options.num_hosts=4 launcher.torchx_options.num_processes=8 conf@app.run_fn.cfg=train.yaml ++launcher.torchx_options.job_name_suffix=vton_v0_0
```
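
As a rough sanity check on the cluster commands above, the sketch below estimates the global batch size. It assumes that `num_processes` is the number of processes per host, that `batch_size` in `leffa/conf/constants/base.yaml` is a per-process value, and that there is no gradient accumulation; none of these is confirmed by this commit. The function name is hypothetical.

```python
def global_batch_size(num_hosts: int, num_processes: int, per_process_batch: int) -> int:
    """Samples consumed per optimizer step across all ranks.

    Assumes one data-parallel rank per process and no gradient accumulation.
    """
    world_size = num_hosts * num_processes
    return world_size * per_process_batch


# MAST command: 4 hosts x 8 processes, with batch_size: 2 from base.yaml.
print(global_batch_size(num_hosts=4, num_processes=8, per_process_batch=2))  # 64

# Local command: a single process on a single host.
print(global_batch_size(num_hosts=1, num_processes=1, per_process_batch=2))  # 2
```

If the per-process assumption holds, moving from the local command to the MAST command changes the effective batch size by a factor of 32, which may matter for learning-rate choices.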


## Predict

First, change to the project directory:
```
cd /path/to/genads/idm_vton/
```

### Local prediction
```
CUDA_VISIBLE_DEVICES=7 torchx run --scheduler local_penv fb.dist.hpc -m leffa.predict -j 1x1 -- -cn predict.yaml max_steps_per_epoch=2
```

### Cluster prediction
For more information, see this [doc](https://docs.google.com/document/d/1PW-ABvpjtUiwghXz6ZqL_sobjFCMinRfLN6QICS37-E/edit).

```
torchx run --scheduler mast fb.dist.hpc -m leffa.predict -j 2x8 -- -cn predict.yaml
```
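
The `-j` values in the predict commands (`1x1` locally, `2x8` on the cluster) can be read as `{nodes}x{processes per node}`. The sketch below parses that fixed form and derives the world size. It assumes `fb.dist.hpc` follows the same convention as the open-source TorchX `dist.ddp` component, which this commit does not confirm, and it does not handle elastic `min:max` forms. `JobShape` and `parse_job_shape` are hypothetical names.

```python
from typing import NamedTuple


class JobShape(NamedTuple):
    nodes: int
    procs_per_node: int

    @property
    def world_size(self) -> int:
        return self.nodes * self.procs_per_node


def parse_job_shape(spec: str) -> JobShape:
    """Parse a fixed-size 'NxM' job spec such as '2x8'."""
    nodes, sep, procs = spec.partition("x")
    if not sep or not nodes.isdigit() or not procs.isdigit():
        raise ValueError(f"expected '<nodes>x<procs_per_node>', got {spec!r}")
    return JobShape(int(nodes), int(procs))


print(parse_job_shape("1x1").world_size)  # 1
print(parse_job_shape("2x8").world_size)  # 16
```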
Empty file added leffa/__init__.py
Empty file.
17 changes: 17 additions & 0 deletions leffa/conf/TARGETS
@@ -0,0 +1,17 @@
load("//gen_ai/genie/components:macros.bzl", "genie_hydra_config_bundle")

oncall("genads_infra")

# All configs including dataloading, torchtnt, and profiling.
# Note you need all dependencies for hydra instantiation here.
genie_hydra_config_bundle(
    name = "idm_vton_hydra_configs",
    srcs = glob(["**/*.yaml"]),
    deps = [
        "//caffe2:torch",
        "//genads/common/data:transforms",
        "//genads/idm_vton:idm_vton_lib",
        "//media_dataloader/api:api",
        "//torchmultimodal/fb/genai/transforms:transforms",
    ],
)
31 changes: 31 additions & 0 deletions leffa/conf/constants/base.yaml
@@ -0,0 +1,31 @@
# for virtual try-on
# height: 512
# width: 384
# batch_size: 8
height: 1024
width: 768
batch_size: 2

# for pose transfer
# height: 256
# width: 176
# batch_size: 8
# height: 512
# width: 352
# batch_size: 4
# height: 1024
# width: 704
# batch_size: 1

precision: bf16

max_steps: null
max_epochs: 200
max_train_steps_per_epoch: null

evaluate_every_n_train_steps: null
evaluate_every_n_train_epochs: null
max_eval_steps_per_eval_epoch: null

use_torchsnapshot: false
checkpoint_every_n_steps: 500
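
The active and commented-out presets in `base.yaml` pair each resolution with a batch size. The short calculation below compares pixels per batch across them. Reading this as a memory-budget heuristic is an interpretation, not something the commit states, and the preset labels are hypothetical.

```python
# (height, width, batch_size) triples copied from base.yaml.
PRESETS = {
    "vton_512x384": (512, 384, 8),
    "vton_1024x768": (1024, 768, 2),  # the active setting
    "pose_256x176": (256, 176, 8),
    "pose_512x352": (512, 352, 4),
    "pose_1024x704": (1024, 704, 1),
}


def pixels_per_batch(height: int, width: int, batch_size: int) -> int:
    """Total pixel count in one batch of images at the given resolution."""
    return height * width * batch_size


for name, preset in PRESETS.items():
    print(f"{name}: {pixels_per_batch(*preset):,} pixels per batch")
```

Both virtual try-on presets come to 1,572,864 pixels per batch, and the two larger pose-transfer presets both come to 720,896, so doubling each side of the image is offset by a 4x smaller batch.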
47 changes: 47 additions & 0 deletions leffa/conf/datasets/deepfashion_test.yaml
@@ -0,0 +1,47 @@
deepfashion_test:
  dataset:
    _target_: media_dataloader.api.EnrichingDataset
    datasource:
      _target_: media_dataloader.api.LazyHiveDataSource
      namespace: mgenai
      table: deepfashion_pose_transfer
      partition_filter_predicate_list: ["ds = '2024-08-15' AND set_name = 'val'"]
    enrichments:
      - _target_: media_dataloader.api.media_lookups.ManifoldLookups
        lookup_handle_to_media_columns:
          to_img_manifold_path: "image"
          from_img_manifold_path: "cloth"
          to_img_iuv_manifold_path: "image_densepose"
          from_img_iuv_manifold_path: "cloth_densepose"
    collate_fn:
      - _target_: media_dataloader.api.Collate
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: image
        blob_field: image
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: cloth
        blob_field: cloth
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: image_densepose
        blob_field: image_densepose
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: cloth_densepose
        blob_field: cloth_densepose
      - _target_: leffa.datasets.transform.VtonTransform
        height: ${constants.height}
        width: ${constants.width}
        is_train: false
        dataset: deepfashion
        aug_garment_ratio: 0.0
        get_garment_from_person_ratio: 0.0
        aug_mask_ratio: 0.0

  dataloader:
    _target_: media_dataloader.api.StatefulDataLoader
    dataset: ${datasets.deepfashion_test.dataset}
    batch_size: ${constants.batch_size}
    num_workers: 0
    prefetch_factor: null
    pin_memory: true
    persistent_workers: false
    multiprocessing_context: null
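
Each `_target_` key in these configs names a dotted import path, and the sibling keys become constructor arguments. The sketch below illustrates that idea with the standard library only. It is a simplified stand-in for `hydra.utils.instantiate`, not its implementation: it ignores features such as `_partial_` and `_args_`, and it uses `datetime.timedelta` as a stand-in target because the `media_dataloader` classes are not available outside the repository.

```python
import importlib
from typing import Any


def instantiate(node: Any) -> Any:
    """Recursively build objects from dicts that carry a `_target_` key."""
    if isinstance(node, list):
        return [instantiate(item) for item in node]
    if not isinstance(node, dict):
        return node
    # Build nested children first so they can be passed as keyword arguments.
    kwargs = {key: instantiate(value) for key, value in node.items() if key != "_target_"}
    if "_target_" not in node:
        return kwargs
    module_path, _, attr = node["_target_"].rpartition(".")
    factory = getattr(importlib.import_module(module_path), attr)
    return factory(**kwargs)


config = {
    "timeouts": [
        {"_target_": "datetime.timedelta", "hours": 1, "minutes": 30},
        {"_target_": "datetime.timedelta", "seconds": 5},
    ]
}
built = instantiate(config)
print(built["timeouts"][0].total_seconds())  # 5400.0
```

The list handling mirrors how entries such as `enrichments` and `collate_fn` hold a sequence of `_target_` nodes, each built independently.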
47 changes: 47 additions & 0 deletions leffa/conf/datasets/deepfashion_train.yaml
@@ -0,0 +1,47 @@
deepfashion_train:
  dataset:
    _target_: media_dataloader.api.EnrichingDataset
    datasource:
      _target_: media_dataloader.api.LazyHiveDataSource
      namespace: mgenai
      table: deepfashion_pose_transfer
      partition_filter_predicate_list: ["ds = '2024-08-15' AND set_name = 'train'"]
    enrichments:
      - _target_: media_dataloader.api.media_lookups.ManifoldLookups
        lookup_handle_to_media_columns:
          to_img_manifold_path: "image"
          from_img_manifold_path: "cloth"
          to_img_iuv_manifold_path: "image_densepose"
          from_img_iuv_manifold_path: "cloth_densepose"
    collate_fn:
      - _target_: media_dataloader.api.Collate
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: image
        blob_field: image
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: cloth
        blob_field: cloth
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: image_densepose
        blob_field: image_densepose
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: cloth_densepose
        blob_field: cloth_densepose
      - _target_: leffa.datasets.transform.VtonTransform
        height: ${constants.height}
        width: ${constants.width}
        is_train: true
        dataset: deepfashion
        aug_garment_ratio: 0.0
        get_garment_from_person_ratio: 0.0
        aug_mask_ratio: 0.0

  dataloader:
    _target_: media_dataloader.api.StatefulDataLoader
    dataset: ${datasets.deepfashion_train.dataset}
    batch_size: ${constants.batch_size}
    num_workers: 4
    prefetch_factor: 2
    pin_memory: true
    persistent_workers: true
    multiprocessing_context: forkserver
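
The train and test dataloader blocks differ in a consistent way: the test config sets `num_workers: 0` and nulls out `prefetch_factor`, `persistent_workers`, and `multiprocessing_context`, while the train config enables all of them with four workers. The sketch below encodes the constraints that make those settings go together. It assumes `StatefulDataLoader` forwards these arguments to a PyTorch-style `DataLoader`, which rejects worker-only options when there are no workers; the helper name is hypothetical.

```python
from typing import List, Optional


def check_loader_settings(
    num_workers: int,
    prefetch_factor: Optional[int],
    persistent_workers: bool,
    multiprocessing_context: Optional[str],
) -> List[str]:
    """Return human-readable problems; an empty list means the settings are consistent."""
    problems: List[str] = []
    if num_workers < 0:
        problems.append("num_workers must be non-negative")
    if num_workers == 0:
        # With no worker processes there is nothing to prefetch, keep alive, or fork.
        if prefetch_factor is not None:
            problems.append("prefetch_factor requires num_workers > 0")
        if persistent_workers:
            problems.append("persistent_workers requires num_workers > 0")
        if multiprocessing_context is not None:
            problems.append("multiprocessing_context requires num_workers > 0")
    return problems


train = dict(num_workers=4, prefetch_factor=2, persistent_workers=True,
             multiprocessing_context="forkserver")
test = dict(num_workers=0, prefetch_factor=None, persistent_workers=False,
            multiprocessing_context=None)
broken = dict(test, prefetch_factor=2, persistent_workers=True)

print(check_loader_settings(**train))   # []
print(check_loader_settings(**test))    # []
print(check_loader_settings(**broken))  # two problems reported
```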
56 changes: 56 additions & 0 deletions leffa/conf/datasets/dress_code_test.yaml
@@ -0,0 +1,56 @@
dress_code_test:
  dataset:
    _target_: media_dataloader.api.EnrichingDataset
    datasource:
      _target_: media_dataloader.api.LazyHiveDataSource
      namespace: ad_metrics
      table: vton_public_dataset_dress_code_test_paired_v2
      # table: vton_public_dataset_dress_code_test_unpaired_v2
      # table: vton_public_dataset_dress_code_test_upper_body_paired_v2
      # table: vton_public_dataset_dress_code_test_upper_body_unpaired_v2
      partition_filter_predicate_list: ["ds = '2024-09-14'"]
      # table: vton_public_dataset_dress_code_test_lower_body_paired_v2
      # table: vton_public_dataset_dress_code_test_lower_body_unpaired_v2
      # table: vton_public_dataset_dress_code_test_dresses_paired_v2
      # table: vton_public_dataset_dress_code_test_dresses_unpaired_v2
      # partition_filter_predicate_list: ["ds = '2024-09-16'"]
    enrichments:
      - _target_: media_dataloader.api.media_lookups.ManifoldLookups
        lookup_handle_to_media_columns:
          image_manifold_path: "image"
          cloth_manifold_path: "cloth"
          agnostic_mask_manifold_path: "agnostic_mask"
          dense_manifold_path: "image_densepose"
          label_map_manifold_path: "image_parse"
    collate_fn:
      - _target_: media_dataloader.api.Collate
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: image
        blob_field: image
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: cloth
        blob_field: cloth
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: agnostic_mask
        blob_field: agnostic_mask
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: image_densepose
        blob_field: image_densepose
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: image_parse
        blob_field: image_parse
      - _target_: leffa.datasets.transform.VtonTransform
        height: ${constants.height}
        width: ${constants.width}
        is_train: false
        dataset: dress_code

  dataloader:
    _target_: media_dataloader.api.StatefulDataLoader
    dataset: ${datasets.dress_code_test.dataset}
    batch_size: ${constants.batch_size}
    num_workers: 0
    prefetch_factor: null
    pin_memory: true
    persistent_workers: false
    multiprocessing_context: null
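
Values such as `${constants.height}` and `${datasets.dress_code_test.dataset}` are interpolations that point at other nodes in the composed config tree. The sketch below shows how such references can be resolved over plain dictionaries. It is an illustration using only the standard library, not OmegaConf's actual resolver: it has no cycle detection and supports only absolute dotted paths. The composed layout (a top-level `constants` and `datasets` key) is inferred from the interpolation paths.

```python
import re
from typing import Any, Dict

_REF = re.compile(r"\$\{([^}]+)\}")


def lookup(root: Dict[str, Any], dotted: str) -> Any:
    """Follow a dotted path such as 'constants.height' from the config root."""
    node: Any = root
    for key in dotted.split("."):
        node = node[key]
    return node


def resolve(value: Any, root: Dict[str, Any]) -> Any:
    """Replace ${a.b.c} references with the values they point at."""
    if isinstance(value, dict):
        return {key: resolve(item, root) for key, item in value.items()}
    if isinstance(value, list):
        return [resolve(item, root) for item in value]
    if isinstance(value, str):
        whole = _REF.fullmatch(value)
        if whole:
            # A value that is only a reference keeps the referenced type (int stays int).
            return resolve(lookup(root, whole.group(1)), root)
        # References embedded in a longer string are substituted as text.
        return _REF.sub(lambda m: str(resolve(lookup(root, m.group(1)), root)), value)
    return value


config = {
    "constants": {"height": 1024, "width": 768, "batch_size": 2},
    "datasets": {
        "dress_code_test": {
            "dataset": {"height": "${constants.height}", "width": "${constants.width}"},
            "dataloader": {
                "dataset": "${datasets.dress_code_test.dataset}",
                "batch_size": "${constants.batch_size}",
            },
        }
    },
}
resolved = resolve(config, config)
print(resolved["datasets"]["dress_code_test"]["dataloader"])
```

Because the dataloader refers to the dataset node by path, changing `constants.height` in one place propagates to both the dataset and the dataloader that wraps it.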
55 changes: 55 additions & 0 deletions leffa/conf/datasets/dress_code_train.yaml
@@ -0,0 +1,55 @@
dress_code_train:
  dataset:
    _target_: media_dataloader.api.EnrichingDataset
    datasource:
      _target_: media_dataloader.api.LazyHiveDataSource
      namespace: ad_metrics
      table: vton_public_dataset_dress_code_train_v2
      # table: vton_public_dataset_dress_code_train_upper_body_v2
      partition_filter_predicate_list: ["ds = '2024-09-14'"]
      # table: vton_public_dataset_dress_code_train_lower_body_v2
      # table: vton_public_dataset_dress_code_train_dresses_v2
      # partition_filter_predicate_list: ["ds = '2024-09-15'"]
    enrichments:
      - _target_: media_dataloader.api.media_lookups.ManifoldLookups
        lookup_handle_to_media_columns:
          image_manifold_path: "image"
          cloth_manifold_path: "cloth"
          agnostic_mask_manifold_path: "agnostic_mask"
          dense_manifold_path: "image_densepose"
          label_map_manifold_path: "image_parse"
    collate_fn:
      - _target_: media_dataloader.api.Collate
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: image
        blob_field: image
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: cloth
        blob_field: cloth
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: agnostic_mask
        blob_field: agnostic_mask
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: image_densepose
        blob_field: image_densepose
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: image_parse
        blob_field: image_parse
      - _target_: leffa.datasets.transform.VtonTransform
        height: ${constants.height}
        width: ${constants.width}
        is_train: true
        dataset: dress_code
        aug_garment_ratio: 0.0
        get_garment_from_person_ratio: 0.0
        aug_mask_ratio: 0.0

  dataloader:
    _target_: media_dataloader.api.StatefulDataLoader
    dataset: ${datasets.dress_code_train.dataset}
    batch_size: ${constants.batch_size}
    num_workers: 4
    prefetch_factor: 2
    pin_memory: true
    persistent_workers: true
    multiprocessing_context: forkserver
54 changes: 54 additions & 0 deletions leffa/conf/datasets/viton_hd_test.yaml
@@ -0,0 +1,54 @@
viton_hd_test:
  dataset:
    _target_: media_dataloader.api.EnrichingDataset
    datasource:
      _target_: media_dataloader.api.LazyHiveDataSource
      namespace: ad_metrics
      table: vton_public_dataset_viton_hd_test_paired_v2
      partition_filter_predicate_list: ["ds = '2024-10-30'"]
      # table: vton_public_dataset_viton_hd_test_unpaired_v1
      # partition_filter_predicate_list: ["ds = '2024-09-12'"]
    enrichments:
      - _target_: media_dataloader.api.media_lookups.ManifoldLookups
        lookup_handle_to_media_columns:
          image_manifold_path: "image"
          cloth_manifold_path: "cloth"
          agnostic_mask_manifold_path: "agnostic_mask"
          image_densepose_manifold_path: "image_densepose"
          cloth_mask_manifold_path: "cloth_mask"
          image_parse_v3_manifold_path: "image_parse"
    collate_fn:
      - _target_: media_dataloader.api.Collate
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: image
        blob_field: image
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: cloth
        blob_field: cloth
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: agnostic_mask
        blob_field: agnostic_mask
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: image_densepose
        blob_field: image_densepose
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: cloth_mask
        blob_field: cloth_mask
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: image_parse
        blob_field: image_parse
      - _target_: leffa.datasets.transform.VtonTransform
        height: ${constants.height}
        width: ${constants.width}
        is_train: false
        dataset: viton_hd

  dataloader:
    _target_: media_dataloader.api.StatefulDataLoader
    dataset: ${datasets.viton_hd_test.dataset}
    batch_size: ${constants.batch_size}
    num_workers: 0
    prefetch_factor: null
    pin_memory: true
    persistent_workers: false
    multiprocessing_context: null