forked from franciszzj/Leffa
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ca8322d
commit 511469f
Showing
48 changed files
with
20,852 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# VTON | ||
|
||
Based on `genads/imagenet`. | ||
|
||
|
||
## Methods | ||
|
||
- IDM-VTON | ||
- CatVTON | ||
- SimpleVTON | ||
- Learning Flow Fields in Attention | ||
|
||
|
||
## Prerequisites | ||
|
||
- **Genie CLI**. Jobs are launched with `genie`. Install the CLI on your devserver with `feature install genie --persist`. For more information, visit https://www.internalfb.com/intern/wiki/Genie/Genie_101/Genie_CLI/ | ||
|
||
- **ACL permissions**. To launch jobs, you need to be added to [`oncall_ai_genads`](https://www.internalfb.com/omh/view/ai_genads/oncall_management_settings/members). | ||
|
||
- **Hive permissions**. Once you have your dataset you will need it | ||
|
||
### Starting a New project: | ||
Create a fbpkg: | ||
`fbpkg create genads.train.idm_vton --oncall-team=ai_genads --pkg-desc='idm_vton for genads usage' --acl-name ai_genads_fbpkg` | ||
Create a new model type: | ||
https://www.internalfb.com/mlhub/models/model_type | ||
|
||
|
||
## Train | ||
|
||
### Local training | ||
|
||
You need a local machine with at least 1 GPUs of 80G. | ||
|
||
``` | ||
CUDA_VISIBLE_DEVICES=7 genie launch genads/idm_vton --config pkg://genads/idm_vton/launcher/launch_train.yaml launcher=local launcher.torchx_options.num_processes=1 conf@app.run_fn.cfg=train_local | ||
``` | ||
|
||
### Cluster training (MAST, Distributed) | ||
If you are granted access to the pool | ||
|
||
``` | ||
genie launch genads/idm_vton --config pkg://genads/idm_vton/launcher/launch_train.yaml launcher=mast launcher.torchx_options.num_hosts=4 launcher.torchx_options.num_processes=8 conf@app.run_fn.cfg=train.yaml ++launcher.torchx_options.job_name_suffix=vton_v0_0 | ||
genie launch genads/idm_vton --config pkg://genads/idm_vton/launcher/launch_train.yaml launcher=mast launcher.torchx_options.host_type=grandteton launcher.torchx_options.num_hosts=4 launcher.torchx_options.num_processes=8 conf@app.run_fn.cfg=train.yaml ++launcher.torchx_options.job_name_suffix=vton_v0_0 | ||
``` | ||
|
||
|
||
## Predict | ||
|
||
First: | ||
``` | ||
cd /path/to/genads/idm_vton/ | ||
``` | ||
|
||
### Local predicting | ||
``` | ||
CUDA_VISIBLE_DEVICES=7 torchx run --scheduler local_penv fb.dist.hpc -m leffa.predict -j 1x1 -- -cn predict.yaml max_steps_per_epoch=2 | ||
``` | ||
|
||
### Cluster predicting | ||
More information, please see this [doc](https://docs.google.com/document/d/1PW-ABvpjtUiwghXz6ZqL_sobjFCMinRfLN6QICS37-E/edit). | ||
|
||
``` | ||
torchx run --scheduler mast fb.dist.hpc -m leffa.predict -j 2x8 -- -cn predict.yaml | ||
``` |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
load("//gen_ai/genie/components:macros.bzl", "genie_hydra_config_bundle") | ||
|
||
oncall("genads_infra") | ||
|
||
# All configs including dataloading, torchtnt, and profiling. | ||
# Note you need all dependencies for hydra instantiation here. | ||
genie_hydra_config_bundle( | ||
name = "idm_vton_hydra_configs", | ||
srcs = glob(["**/*.yaml"]), | ||
deps = [ | ||
"//caffe2:torch", | ||
"//genads/common/data:transforms", | ||
"//genads/idm_vton:idm_vton_lib", | ||
"//media_dataloader/api:api", | ||
"//torchmultimodal/fb/genai/transforms:transforms", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# for virtual try-on | ||
# height: 512 | ||
# width: 384 | ||
# batch_size: 8 | ||
height: 1024 | ||
width: 768 | ||
batch_size: 2 | ||
|
||
# for pose transfer | ||
# height: 256 | ||
# width: 176 | ||
# batch_size: 8 | ||
# height: 512 | ||
# width: 352 | ||
# batch_size: 4 | ||
# height: 1024 | ||
# width: 704 | ||
# batch_size: 1 | ||
|
||
precision: bf16 | ||
|
||
max_steps: null | ||
max_epochs: 200 | ||
max_train_steps_per_epoch: null | ||
|
||
evaluate_every_n_train_steps: null | ||
evaluate_every_n_train_epochs: null | ||
max_eval_steps_per_eval_epoch: null | ||
|
||
use_torchsnapshot: false | ||
checkpoint_every_n_steps: 500 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
deepfashion_test: | ||
dataset: | ||
_target_: media_dataloader.api.EnrichingDataset | ||
datasource: | ||
_target_: media_dataloader.api.LazyHiveDataSource | ||
namespace: mgenai | ||
table: deepfashion_pose_transfer | ||
partition_filter_predicate_list: ["ds = '2024-08-15' AND set_name = 'val'"] | ||
enrichments: | ||
- _target_: media_dataloader.api.media_lookups.ManifoldLookups | ||
lookup_handle_to_media_columns: | ||
to_img_manifold_path: "image" | ||
from_img_manifold_path: "cloth" | ||
to_img_iuv_manifold_path: "image_densepose" | ||
from_img_iuv_manifold_path: "cloth_densepose" | ||
collate_fn: | ||
- _target_: media_dataloader.api.Collate | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: image | ||
blob_field: image | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: cloth | ||
blob_field: cloth | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: image_densepose | ||
blob_field: image_densepose | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: cloth_densepose | ||
blob_field: cloth_densepose | ||
- _target_: leffa.datasets.transform.VtonTransform | ||
height: ${constants.height} | ||
width: ${constants.width} | ||
is_train: false | ||
dataset: deepfashion | ||
aug_garment_ratio: 0.0 | ||
get_garment_from_person_ratio: 0.0 | ||
aug_mask_ratio: 0.0 | ||
|
||
dataloader: | ||
_target_: media_dataloader.api.StatefulDataLoader | ||
dataset: ${datasets.deepfashion_test.dataset} | ||
batch_size: ${constants.batch_size} | ||
num_workers: 0 | ||
prefetch_factor: null | ||
pin_memory: true | ||
persistent_workers: false | ||
multiprocessing_context: null |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
deepfashion_train: | ||
dataset: | ||
_target_: media_dataloader.api.EnrichingDataset | ||
datasource: | ||
_target_: media_dataloader.api.LazyHiveDataSource | ||
namespace: mgenai | ||
table: deepfashion_pose_transfer | ||
partition_filter_predicate_list: ["ds = '2024-08-15' AND set_name = 'train'"] | ||
enrichments: | ||
- _target_: media_dataloader.api.media_lookups.ManifoldLookups | ||
lookup_handle_to_media_columns: | ||
to_img_manifold_path: "image" | ||
from_img_manifold_path: "cloth" | ||
to_img_iuv_manifold_path: "image_densepose" | ||
from_img_iuv_manifold_path: "cloth_densepose" | ||
collate_fn: | ||
- _target_: media_dataloader.api.Collate | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: image | ||
blob_field: image | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: cloth | ||
blob_field: cloth | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: image_densepose | ||
blob_field: image_densepose | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: cloth_densepose | ||
blob_field: cloth_densepose | ||
- _target_: leffa.datasets.transform.VtonTransform | ||
height: ${constants.height} | ||
width: ${constants.width} | ||
is_train: true | ||
dataset: deepfashion | ||
aug_garment_ratio: 0.0 | ||
get_garment_from_person_ratio: 0.0 | ||
aug_mask_ratio: 0.0 | ||
|
||
dataloader: | ||
_target_: media_dataloader.api.StatefulDataLoader | ||
dataset: ${datasets.deepfashion_train.dataset} | ||
batch_size: ${constants.batch_size} | ||
num_workers: 4 | ||
prefetch_factor: 2 | ||
pin_memory: true | ||
persistent_workers: true | ||
multiprocessing_context: forkserver |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
dress_code_test: | ||
dataset: | ||
_target_: media_dataloader.api.EnrichingDataset | ||
datasource: | ||
_target_: media_dataloader.api.LazyHiveDataSource | ||
namespace: ad_metrics | ||
table: vton_public_dataset_dress_code_test_paired_v2 | ||
# table: vton_public_dataset_dress_code_test_unpaired_v2 | ||
# table: vton_public_dataset_dress_code_test_upper_body_paired_v2 | ||
# table: vton_public_dataset_dress_code_test_upper_body_unpaired_v2 | ||
partition_filter_predicate_list: ["ds = '2024-09-14'"] | ||
# table: vton_public_dataset_dress_code_test_lower_body_paired_v2 | ||
# table: vton_public_dataset_dress_code_test_lower_body_unpaired_v2 | ||
# table: vton_public_dataset_dress_code_test_dresses_paired_v2 | ||
# table: vton_public_dataset_dress_code_test_dresses_unpaired_v2 | ||
# partition_filter_predicate_list: ["ds = '2024-09-16'"] | ||
enrichments: | ||
- _target_: media_dataloader.api.media_lookups.ManifoldLookups | ||
lookup_handle_to_media_columns: | ||
image_manifold_path: "image" | ||
cloth_manifold_path: "cloth" | ||
agnostic_mask_manifold_path: "agnostic_mask" | ||
dense_manifold_path: "image_densepose" | ||
label_map_manifold_path: "image_parse" | ||
collate_fn: | ||
- _target_: media_dataloader.api.Collate | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: image | ||
blob_field: image | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: cloth | ||
blob_field: cloth | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: agnostic_mask | ||
blob_field: agnostic_mask | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: image_densepose | ||
blob_field: image_densepose | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: image_parse | ||
blob_field: image_parse | ||
- _target_: leffa.datasets.transform.VtonTransform | ||
height: ${constants.height} | ||
width: ${constants.width} | ||
is_train: false | ||
dataset: dress_code | ||
|
||
dataloader: | ||
_target_: media_dataloader.api.StatefulDataLoader | ||
dataset: ${datasets.dress_code_test.dataset} | ||
batch_size: ${constants.batch_size} | ||
num_workers: 0 | ||
prefetch_factor: null | ||
pin_memory: true | ||
persistent_workers: false | ||
multiprocessing_context: null |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
dress_code_train: | ||
dataset: | ||
_target_: media_dataloader.api.EnrichingDataset | ||
datasource: | ||
_target_: media_dataloader.api.LazyHiveDataSource | ||
namespace: ad_metrics | ||
table: vton_public_dataset_dress_code_train_v2 | ||
# table: vton_public_dataset_dress_code_train_upper_body_v2 | ||
partition_filter_predicate_list: ["ds = '2024-09-14'"] | ||
# table: vton_public_dataset_dress_code_train_lower_body_v2 | ||
# table: vton_public_dataset_dress_code_train_dresses_v2 | ||
# partition_filter_predicate_list: ["ds = '2024-09-15'"] | ||
enrichments: | ||
- _target_: media_dataloader.api.media_lookups.ManifoldLookups | ||
lookup_handle_to_media_columns: | ||
image_manifold_path: "image" | ||
cloth_manifold_path: "cloth" | ||
agnostic_mask_manifold_path: "agnostic_mask" | ||
dense_manifold_path: "image_densepose" | ||
label_map_manifold_path: "image_parse" | ||
collate_fn: | ||
- _target_: media_dataloader.api.Collate | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: image | ||
blob_field: image | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: cloth | ||
blob_field: cloth | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: agnostic_mask | ||
blob_field: agnostic_mask | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: image_densepose | ||
blob_field: image_densepose | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: image_parse | ||
blob_field: image_parse | ||
- _target_: leffa.datasets.transform.VtonTransform | ||
height: ${constants.height} | ||
width: ${constants.width} | ||
is_train: true | ||
dataset: dress_code | ||
aug_garment_ratio: 0.0 | ||
get_garment_from_person_ratio: 0.0 | ||
aug_mask_ratio: 0.0 | ||
|
||
dataloader: | ||
_target_: media_dataloader.api.StatefulDataLoader | ||
dataset: ${datasets.dress_code_train.dataset} | ||
batch_size: ${constants.batch_size} | ||
num_workers: 4 | ||
prefetch_factor: 2 | ||
pin_memory: true | ||
persistent_workers: true | ||
multiprocessing_context: forkserver |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
viton_hd_test: | ||
dataset: | ||
_target_: media_dataloader.api.EnrichingDataset | ||
datasource: | ||
_target_: media_dataloader.api.LazyHiveDataSource | ||
namespace: ad_metrics | ||
table: vton_public_dataset_viton_hd_test_paired_v2 | ||
partition_filter_predicate_list: ["ds = '2024-10-30'"] | ||
# table: vton_public_dataset_viton_hd_test_unpaired_v1 | ||
# partition_filter_predicate_list: ["ds = '2024-09-12'"] | ||
enrichments: | ||
- _target_: media_dataloader.api.media_lookups.ManifoldLookups | ||
lookup_handle_to_media_columns: | ||
image_manifold_path: "image" | ||
cloth_manifold_path: "cloth" | ||
agnostic_mask_manifold_path: "agnostic_mask" | ||
image_densepose_manifold_path: "image_densepose" | ||
cloth_mask_manifold_path: "cloth_mask" | ||
image_parse_v3_manifold_path: "image_parse" | ||
collate_fn: | ||
- _target_: media_dataloader.api.Collate | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: image | ||
blob_field: image | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: cloth | ||
blob_field: cloth | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: agnostic_mask | ||
blob_field: agnostic_mask | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: image_densepose | ||
blob_field: image_densepose | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: cloth_mask | ||
blob_field: cloth_mask | ||
- _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform | ||
image_field: image_parse | ||
blob_field: image_parse | ||
- _target_: leffa.datasets.transform.VtonTransform | ||
height: ${constants.height} | ||
width: ${constants.width} | ||
is_train: false | ||
dataset: viton_hd | ||
|
||
dataloader: | ||
_target_: media_dataloader.api.StatefulDataLoader | ||
dataset: ${datasets.viton_hd_test.dataset} | ||
batch_size: ${constants.batch_size} | ||
num_workers: 0 | ||
prefetch_factor: null | ||
pin_memory: true | ||
persistent_workers: false | ||
multiprocessing_context: null |
Oops, something went wrong.