# bugfix: image demo & support image and text prompts
1 parent 323386a · commit ee57525 · 13 changed files with 475 additions and 78 deletions.
### configs/prompt_tuning_coco/yolo_world_v2_l_vlpan_bn_2e-4_80e_8gpus_mask-refine_prompt_tuning_coco.py (161 additions, 0 deletions)
```python
_base_ = ('../../third_party/mmyolo/configs/yolov8/'
          'yolov8_l_mask-refine_syncbn_fast_8xb16-500e_coco.py')
custom_imports = dict(imports=['yolo_world'], allow_failed_imports=False)

# hyper-parameters
num_classes = 80
num_training_classes = 80
max_epochs = 80  # Maximum training epochs
close_mosaic_epochs = 10
save_epoch_intervals = 5
text_channels = 512
neck_embed_channels = [128, 256, _base_.last_stage_out_channels // 2]
neck_num_heads = [4, 8, _base_.last_stage_out_channels // 2 // 32]
base_lr = 2e-3
weight_decay = 0.05
train_batch_size_per_gpu = 16
load_from = 'pretrained_models/yolo_world_l_clip_t2i_bn_2e-3adamw_32xb16-100e_obj365v1_goldg_cc3mlite_train-ca93cd1f.pth'
persistent_workers = False

# model settings
model = dict(type='YOLOWorldPromptDetector',
             mm_neck=True,
             num_train_classes=num_training_classes,
             num_test_classes=num_classes,
             embedding_path='embeddings/clip_vit_b32_coco_80_embeddings.npy',
             prompt_dim=text_channels,
             num_prompts=80,  # one prompt per COCO class
             data_preprocessor=dict(type='YOLOv5DetDataPreprocessor'),
             backbone=dict(_delete_=True,
                           type='MultiModalYOLOBackbone',
                           text_model=None,
                           image_model={{_base_.model.backbone}},
                           frozen_stages=4,  # freeze the image backbone
                           with_text_model=False),
             neck=dict(type='YOLOWorldPAFPN',
                       freeze_all=True,
                       guide_channels=text_channels,
                       embed_channels=neck_embed_channels,
                       num_heads=neck_num_heads,
                       block_cfg=dict(type='MaxSigmoidCSPLayerWithTwoConv')),
             bbox_head=dict(type='YOLOWorldHead',
                            head_module=dict(
                                type='YOLOWorldHeadModule',
                                freeze_all=True,
                                use_bn_head=True,
                                embed_dims=text_channels,
                                num_classes=num_training_classes)),
             train_cfg=dict(assigner=dict(num_classes=num_training_classes)))

# dataset settings
final_transform = [
    dict(type='mmdet.PackDetInputs',
         meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
                    'flip_direction'))
]
mosaic_affine_transform = [
    dict(type='Mosaic',
         img_scale=_base_.img_scale,
         pad_val=114.0,
         pre_transform=_base_.pre_transform),
    dict(type='YOLOv5CopyPaste', prob=_base_.copypaste_prob),
    dict(
        type='YOLOv5RandomAffine',
        max_rotate_degree=0.0,
        max_shear_degree=0.0,
        max_aspect_ratio=100.,
        scaling_ratio_range=(1 - _base_.affine_scale, 1 + _base_.affine_scale),
        # img_scale is (width, height)
        border=(-_base_.img_scale[0] // 2, -_base_.img_scale[1] // 2),
        border_val=(114, 114, 114),
        min_area_ratio=_base_.min_area_ratio,
        use_mask_refine=_base_.use_mask2refine)
]
train_pipeline = [
    *_base_.pre_transform, *mosaic_affine_transform,
    dict(type='YOLOv5MixUp',
         prob=_base_.mixup_prob,
         pre_transform=[*_base_.pre_transform, *mosaic_affine_transform]),
    *_base_.last_transform[:-1], *final_transform
]

train_pipeline_stage2 = [*_base_.train_pipeline_stage2[:-1], *final_transform]

coco_train_dataset = dict(type='YOLOv5CocoDataset',
                          data_root='data/coco',
                          ann_file='annotations/instances_train2017.json',
                          data_prefix=dict(img='train2017/'),
                          filter_cfg=dict(filter_empty_gt=False, min_size=32),
                          pipeline=train_pipeline)

train_dataloader = dict(persistent_workers=persistent_workers,
                        batch_size=train_batch_size_per_gpu,
                        collate_fn=dict(type='yolow_collate'),
                        dataset=coco_train_dataset)

test_pipeline = [
    *_base_.test_pipeline[:-1],
    dict(type='mmdet.PackDetInputs',
         meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                    'scale_factor', 'pad_param'))
]
coco_val_dataset = dict(type='YOLOv5CocoDataset',
                        data_root='data/coco',
                        ann_file='annotations/instances_val2017.json',
                        data_prefix=dict(img='val2017/'),
                        filter_cfg=dict(filter_empty_gt=False, min_size=32),
                        pipeline=test_pipeline)

val_dataloader = dict(dataset=coco_val_dataset)
test_dataloader = val_dataloader

# training settings
default_hooks = dict(param_scheduler=dict(scheduler_type='linear',
                                          lr_factor=0.01,
                                          max_epochs=max_epochs),
                     checkpoint=dict(max_keep_ckpts=-1,
                                     save_best=None,
                                     interval=save_epoch_intervals))
custom_hooks = [
    dict(type='EMAHook',
         ema_type='ExpMomentumEMA',
         momentum=0.0001,
         update_buffers=True,
         strict_load=False,
         priority=49),
    dict(type='mmdet.PipelineSwitchHook',
         switch_epoch=max_epochs - close_mosaic_epochs,
         switch_pipeline=train_pipeline_stage2)
]
train_cfg = dict(max_epochs=max_epochs,
                 val_interval=5,
                 dynamic_intervals=[((max_epochs - close_mosaic_epochs),
                                     _base_.val_interval_stage2)])
optim_wrapper = dict(optimizer=dict(
    _delete_=True,
    type='AdamW',
    lr=base_lr,
    weight_decay=weight_decay,
    batch_size_per_gpu=train_batch_size_per_gpu),
                     paramwise_cfg=dict(bias_decay_mult=0.0,
                                        norm_decay_mult=0.0,
                                        custom_keys={
                                            'backbone.text_model':
                                            dict(lr_mult=0.01),
                                            'logit_scale':
                                            dict(weight_decay=0.0),
                                            'embeddings':
                                            dict(weight_decay=0.0)
                                        }),
                     constructor='YOLOWv5OptimizerConstructor')

# evaluation settings
val_evaluator = dict(_delete_=True,
                     type='mmdet.CocoMetric',
                     proposal_nums=(100, 1, 10),
                     ann_file='data/coco/annotations/instances_val2017.json',
                     metric='bbox')
find_unused_parameters = True
```
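The `embedding_path` above points at a pre-computed prompt-embedding file. A minimal sanity check before training (a sketch, not part of the repository) verifies that the array matches `num_prompts` and `prompt_dim` from the config:

```python
# Sketch: the config expects an (N, D) array with N == num_prompts (80)
# and D == prompt_dim == text_channels (512).
import numpy as np

emb = np.load('embeddings/clip_vit_b32_coco_80_embeddings.npy')
assert emb.shape == (80, 512), f'unexpected embedding shape: {emb.shape}'
```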
## Prompt YOLO-World

### 1. Simple YOLO-World with Embeddings

To simplify YOLO-World and get rid of the language model, we define a new basic detector, `YOLOWorldPromptDetector`.

The `YOLOWorldPromptDetector` takes prompt embeddings as input and no longer contains a language model!
YOLO-World now adopts `embeddings` as its language inputs, and several kinds of embeddings are supported: (1) text embeddings from a language model, e.g., the CLIP text encoder, (2) image embeddings from a vision model, e.g., the CLIP vision encoder, (3) image-text fused embeddings, and (4) random embeddings.
Kinds (1), (2), and (3) support zero-shot inference, while all four kinds are designed for prompt tuning on your custom data.

The basic detector is defined as follows:
```python
class YOLOWorldPromptDetector(YOLODetector):
    """Implementation of YOLO World Series"""

    def __init__(self,
                 *args,
                 mm_neck: bool = False,
                 num_train_classes=80,
                 num_test_classes=80,
                 prompt_dim=512,
                 num_prompts=80,
                 embedding_path='',
                 freeze_prompt=False,
                 use_mlp_adapter=False,
                 **kwargs)
```

To use it in a zero-shot manner, you need to pre-compute the text embeddings (or image embeddings) and save them as a numpy array (`*.npy`) with shape `NxD`, where `N` is the number of prompts and `D` is the embedding dimension. Currently, we only support one prompt per class. You can use several prompts for one class, but you need to merge the results in the post-processing steps.
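As an illustration, here is a minimal sketch for pre-computing text embeddings with the Hugging Face `transformers` CLIP implementation (the model name, class list, output path, and L2 normalization are assumptions; any CLIP encoder whose output dimension matches `prompt_dim` should work):

```python
# Sketch: pre-compute one CLIP text embedding per class prompt and save it
# as an (N, D) .npy file, as expected by `embedding_path`.
import numpy as np
import torch
from transformers import CLIPModel, CLIPTokenizer

model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')  # D = 512
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')

class_names = ['person', 'bicycle', 'car']  # one prompt per class (N = 3 here)
inputs = tokenizer(class_names, padding=True, return_tensors='pt')
with torch.no_grad():
    text_feats = model.get_text_features(**inputs)
# Assumption: embeddings are L2-normalized before saving.
text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
np.save('embeddings/my_text_embeddings.npy', text_feats.numpy())  # shape (N, 512)
```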

### 2. Prompt Tuning YOLO-World

We introduce prompt tuning for YOLO-World to maintain its zero-shot ability while improving the performance on your custom datasets.

For more details about writing configs for prompt tuning, you can refer to [`prompt tuning for COCO data`](./../configs/prompt_tuning_coco/yolo_world_v2_l_vlpan_bn_2e-4_80e_8gpus_mask-refine_prompt_tuning_coco.py).
1. Use random prompts

```python
dict(type='YOLOWorldPromptDetector',
     mm_neck=True,
     num_train_classes=num_training_classes,
     num_test_classes=num_classes,
     prompt_dim=text_channels,
     num_prompts=80,  # no embedding_path: prompt embeddings are randomly initialized
     ...)
```

2. Use CLIP embeddings (text, image, or text-image embeddings)

The `clip_vit_b32_coco_80_embeddings.npy` file can be downloaded from [HuggingFace](https://huggingface.co/wondervictor/YOLO-World/blob/main/clip_vit_b32_coco_80_embeddings.npy).

```python
dict(type='YOLOWorldPromptDetector',
     mm_neck=True,
     num_train_classes=num_training_classes,
     num_test_classes=num_classes,
     embedding_path='embeddings/clip_vit_b32_coco_80_embeddings.npy',
     prompt_dim=text_channels,
     num_prompts=80,
     ...)
```

Using the CLIP model to obtain the image and text embeddings maintains the zero-shot performance.
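Image prompts can be pre-computed the same way; below is a hedged sketch using the CLIP vision encoder via Hugging Face `transformers` (the crop file names, output path, and normalization step are assumptions for illustration):

```python
# Sketch: encode one example image (e.g., a cropped instance) per class with
# the CLIP vision encoder to obtain image-prompt embeddings.
import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')

crops = [Image.open(p) for p in ['person.jpg', 'bicycle.jpg']]  # one per class
inputs = processor(images=crops, return_tensors='pt')
with torch.no_grad():
    image_feats = model.get_image_features(**inputs)
image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)  # assumed L2 norm
np.save('embeddings/my_image_embeddings.npy', image_feats.numpy())  # (N, 512)
```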

| Model | Config | AP | AP50 | AP75 | APS | APM | APL |
| :---- | :----: | :--: | :--: | :--: | :--: | :--: | :--: |
| YOLO-World-v2-L | Zero-shot | 45.7 | 61.6 | 49.8 | 29.9 | 50.0 | 60.8 |
| [YOLO-World-v2-L](./../configs/prompt_tuning_coco/yolo_world_v2_l_vlpan_bn_2e-4_80e_8gpus_mask-refine_prompt_tuning_coco.py) | Prompt tuning | 47.9 | 64.3 | 52.5 | 31.9 | 52.6 | 61.3 |