add video demo image demo
wondervictor committed Apr 28, 2024
1 parent da0fcb0 commit 2e0d8fb
Showing 30 changed files with 852 additions and 1,027 deletions.
README.md (43 changes: 10 additions & 33 deletions)
@@ -35,6 +35,7 @@
We recommend that everyone **use English to communicate on issues**, as this helps developers from around the world discuss, share experiences, and answer questions together.

## 🔥 Updates
`[2024-4-28]:` Long time no see! This update contains bug fixes and improvements: (1) ONNX demo; (2) image demo (supports tensor input); (3) new pre-trained models; (4) image prompts; (5) a simple version for fine-tuning / deployment; (6) an installation guide (including a `requirements.txt`).
`[2024-3-28]:` We provide: (1) more high-resolution pre-trained models (e.g., S, M, X) ([#142](https://github.com/AILab-CVC/YOLO-World/issues/142)); (2) pre-trained models with CLIP-Large text encoders. Most importantly, we preliminarily fix **fine-tuning without `mask-refine`** and explore a new fine-tuning setting ([#160](https://github.com/AILab-CVC/YOLO-World/issues/160), [#76](https://github.com/AILab-CVC/YOLO-World/issues/76)). In addition, fine-tuning YOLO-World with `mask-refine` also obtains significant improvements; check more details in [configs/finetune_coco](./configs/finetune_coco/).
`[2024-3-16]:` We fix the demo bugs ([#110](https://github.com/AILab-CVC/YOLO-World/issues/110), [#94](https://github.com/AILab-CVC/YOLO-World/issues/94), [#129](https://github.com/AILab-CVC/YOLO-World/issues/129), [#125](https://github.com/AILab-CVC/YOLO-World/issues/125)) related to the visualization of segmentation masks, and release [**YOLO-World with Embeddings**](./docs/prompt_yolo_world.md), which supports prompt tuning, text prompts, and image prompts.
`[2024-3-3]:` We add the **high-resolution YOLO-World**, which supports `1280x1280` resolution with higher accuracy and better performance for small objects!
@@ -101,7 +102,8 @@ We've pre-trained YOLO-World-S/M/L from scratch and evaluate on the `LVIS val-1.
| [YOLO-Worldv2-L (CLIP-Large)](./configs/pretrain/yolo_world_v2_l_clip_large_vlpan_bn_2e-3_100e_4x8gpus_obj365v1_goldg_train_800ft_lvis_minival.py) 🔥 | O365+GoldG | 800🔸 | 35.5 | 28.3 | 33.2 | 38.8 | 28.6 | 22.0 | 25.1 | 35.4 | [HF Checkpoints 🤗](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_l_clip_large_o365v1_goldg_pretrain_800ft-9df82e55.pth)|
| [YOLO-Worldv2-L](./configs/pretrain/yolo_world_v2_l_vlpan_bn_2e-3_100e_4x8gpus_obj365v1_goldg_train_lvis_minival.py) | O365+GoldG+CC3M-Lite | 640 | 32.9 | 25.3 | 31.1 | 35.8 | 26.1 | 20.6 | 22.6 | 32.3 | [HF Checkpoints 🤗](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_l_obj365v1_goldg_cc3mlite_pretrain-ca93cd1f.pth)|
| [YOLO-Worldv2-X](./configs/pretrain/yolo_world_v2_x_vlpan_bn_2e-3_100e_4x8gpus_obj365v1_goldg_train_lvis_minival.py) | O365+GoldG+CC3M-Lite | 640 | 35.4 | 28.7 | 32.9 | 38.7 | 28.4 | 20.6 | 25.6 | 35.0 | [HF Checkpoints 🤗](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_x_obj365v1_goldg_cc3mlite_pretrain-8698fbfa.pth) |
| [YOLO-Worldv2-XL](./configs/pretrain/yolo_world_v2_xl_vlpan_bn_2e-3_100e_4x8gpus_obj365v1_goldg_train_lvis_minival.py) | O365+GoldG+CC3M-Lite | 640 | 36.0 | 25.8 | 34.1 | 39.5 | 29.1 | 21.1 | 26.3 | 35.8 | [HF Checkpoints 🤗](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_x_obj365v1_goldg_cc3mlite_pretrain-8698fbfa.pth) |
| 🔥 [YOLO-Worldv2-X]() | O365+GoldG+CC3M-Lite | 1280🔸 | 37.4 | 30.5 | 35.2 | 40.7 | 29.8 | 21.1 | 26.8 | 37.0 | [HF Checkpoints 🤗](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_x_obj365v1_goldg_cc3mlite_pretrain_1280ft-14996a36.pth) |
| [YOLO-Worldv2-XL](./configs/pretrain/yolo_world_v2_xl_vlpan_bn_2e-3_100e_4x8gpus_obj365v1_goldg_train_lvis_minival.py) | O365+GoldG+CC3M-Lite | 640 | 36.0 | 25.8 | 34.1 | 39.5 | 29.1 | 21.1 | 26.3 | 35.8 | [HF Checkpoints 🤗](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_xl_obj365v1_goldg_cc3mlite_pretrain-5daf1395.pth) |

</font>
</div>
@@ -177,39 +179,14 @@ You can directly download the ONNX model through the online [demo](https://huggi

## Demo

### Gradio Demo

We provide the [Gradio](https://www.gradio.app/) demo for local devices:

```bash
pip install gradio==4.16.0
python demo.py path/to/config path/to/weights
```

Additionally, you can use a Dockerfile to build an image with Gradio. As a prerequisite, make sure you have the respective drivers installed alongside [nvidia-container-runtime](https://stackoverflow.com/questions/59691207/docker-build-with-nvidia-runtime). Replace `MODEL_NAME` and `WEIGHT_NAME` with the respective values, or omit them and use the default values from the [Dockerfile](Dockerfile#3); a worked example with concrete values follows the commands below.

```bash
docker build --build-arg="MODEL=MODEL_NAME" --build-arg="WEIGHT=WEIGHT_NAME" -t yolo_demo .
docker run --runtime nvidia -p 8080:8080 yolo_demo
```
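
For example, a hypothetical build that bakes in the YOLO-Worldv2-L config and checkpoint listed in the table above might look like the following. Treat this as a sketch: the exact values expected by the build args depend on the Dockerfile's defaults.

```bash
# Hypothetical values taken from the pre-trained model table above;
# adjust MODEL and WEIGHT to the config/checkpoint you actually want to serve.
docker build \
  --build-arg="MODEL=yolo_world_v2_l_vlpan_bn_2e-3_100e_4x8gpus_obj365v1_goldg_train_lvis_minival.py" \
  --build-arg="WEIGHT=yolo_world_v2_l_obj365v1_goldg_cc3mlite_pretrain-ca93cd1f.pth" \
  -t yolo_demo .
docker run --runtime nvidia -p 8080:8080 yolo_demo
```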

### Image Demo

We provide a simple image demo for inference on images with visualization outputs.

```bash
python image_demo.py path/to/config path/to/weights image/path/directory 'person,dog,cat' --topk 100 --threshold 0.005 --output-dir demo_outputs
```

**Notes:**
* The `image` argument can be a directory or a single image.
* The `texts` argument can be a comma-separated string of categories (noun phrases). We also support a `txt` file in which each line contains a category (noun phrase); a minimal example is shown below.
* The `topk` and `threshold` arguments control the number of predictions and the confidence threshold.
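
For example, a minimal category file (hypothetical name `custom_classes.txt`, assuming its path is passed in place of the comma-separated category string) could look like this:

```bash
# Hypothetical category file: one noun phrase per line.
cat > custom_classes.txt << 'EOF'
person
dog
cat
safety helmet
EOF

# Pass the txt file where the category string would normally go.
python image_demo.py path/to/config path/to/weights image/path/directory custom_classes.txt --topk 100 --threshold 0.005 --output-dir demo_outputs
```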

### Google Colab Notebook

We sincerely thank [Onuralp](https://github.com/onuralpszr) for sharing the [Colab Demo](https://colab.research.google.com/drive/1F_7S5lSaFM06irBCZqjhbN7MpUXo6WwO?usp=sharing); you can have a try 😊!
See [`demo`](./demo) for more details.

- [x] `gradio_demo.py`: Gradio demo with ONNX export.
- [x] `image_demo.py`: inference on images or a directory of images.
- [x] `simple_demo.py`: a simple YOLO-World demo that takes an image `array` as input (instead of a file path).
- [x] `video_demo.py`: YOLO-World inference on videos.
- [x] `inference.ipynb`: Jupyter notebook for YOLO-World.
- [x] [Google Colab Notebook](https://colab.research.google.com/drive/1F_7S5lSaFM06irBCZqjhbN7MpUXo6WwO?usp=sharing): we sincerely thank [Onuralp](https://github.com/onuralpszr) for sharing it; you can have a try 😊!

## Acknowledgement

@@ -15,7 +15,7 @@
base_lr = 2e-3
weight_decay = 0.05 / 2
train_batch_size_per_gpu = 16
text_model_name = '../pretrained_models/clip-vit-base-patch32-projection'
text_model_name = 'openai/clip-vit-large-patch14-336'
# text_model_name = 'openai/clip-vit-base-patch32'
# model settings
model = dict(
@@ -16,7 +16,7 @@
weight_decay = 0.05 / 2
train_batch_size_per_gpu = 16
text_model_name = '../pretrained_models/clip-vit-base-patch32-projection'
# text_model_name = 'openai/clip-vit-base-patch32'
text_model_name = 'openai/clip-vit-base-patch32'
# model settings
model = dict(
type='YOLOWorldDetector',
@@ -0,0 +1,199 @@
_base_ = ('../../third_party/mmyolo/configs/yolov8/'
'yolov8_x_syncbn_fast_8xb16-500e_coco.py')
custom_imports = dict(imports=['yolo_world'],
allow_failed_imports=False)

# hyper-parameters
num_classes = 1203
num_training_classes = 80
max_epochs = 100 # Maximum training epochs
close_mosaic_epochs = 2
save_epoch_intervals = 2
text_channels = 512
neck_embed_channels = [128, 256, _base_.last_stage_out_channels // 2]
neck_num_heads = [4, 8, _base_.last_stage_out_channels // 2 // 32]
base_lr = 2e-3
weight_decay = 0.05 / 2
train_batch_size_per_gpu = 16
text_model_name = '../pretrained_models/clip-vit-base-patch32-projection'
# text_model_name = 'openai/clip-vit-base-patch32'
img_scale = (1280, 1280)

# model settings
model = dict(
type='YOLOWorldDetector',
mm_neck=True,
num_train_classes=num_training_classes,
num_test_classes=num_classes,
data_preprocessor=dict(type='YOLOWDetDataPreprocessor'),
backbone=dict(
_delete_=True,
type='MultiModalYOLOBackbone',
image_model={{_base_.model.backbone}},
text_model=dict(
type='HuggingCLIPLanguageBackbone',
model_name=text_model_name,
frozen_modules=['all'])),
neck=dict(type='YOLOWorldPAFPN',
guide_channels=text_channels,
embed_channels=neck_embed_channels,
num_heads=neck_num_heads,
block_cfg=dict(type='MaxSigmoidCSPLayerWithTwoConv')),
bbox_head=dict(type='YOLOWorldHead',
head_module=dict(type='YOLOWorldHeadModule',
use_bn_head=True,
embed_dims=text_channels,
num_classes=num_training_classes)),
train_cfg=dict(assigner=dict(num_classes=num_training_classes)))

# dataset settings
text_transform = [
dict(type='RandomLoadText',
num_neg_samples=(num_classes, num_classes),
max_num_samples=num_training_classes,
padding_to_max=True,
padding_value=''),
dict(type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
'flip_direction', 'texts'))
]
train_pipeline = [
*_base_.pre_transform,
dict(type='MultiModalMosaic',
img_scale=img_scale,
pad_val=114.0,
pre_transform=_base_.pre_transform),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_shear_degree=0.0,
scaling_ratio_range=(1 - _base_.affine_scale, 1 + _base_.affine_scale),
max_aspect_ratio=_base_.max_aspect_ratio,
border=(-img_scale[0] // 2, -img_scale[1] // 2),
border_val=(114, 114, 114)),
*_base_.last_transform[:-1],
*text_transform,
]
train_pipeline_stage2 = [
*_base_.pre_transform,
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=True,
pad_val=dict(img=114.0)),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_shear_degree=0.0,
scaling_ratio_range=(1 - _base_.affine_scale, 1 + _base_.affine_scale),
max_aspect_ratio=_base_.max_aspect_ratio,
border_val=(114, 114, 114)),
*_base_.last_transform[:-1],
*text_transform
]

obj365v1_train_dataset = dict(
type='MultiModalDataset',
dataset=dict(
type='YOLOv5Objects365V1Dataset',
data_root='data/objects365v1/',
ann_file='annotations/objects365_train.json',
data_prefix=dict(img='train/'),
filter_cfg=dict(filter_empty_gt=False, min_size=32)),
class_text_path='data/texts/obj365v1_class_texts.json',
pipeline=train_pipeline)

mg_train_dataset = dict(type='YOLOv5MixedGroundingDataset',
data_root='data/mixed_grounding/',
ann_file='annotations/final_mixed_train_no_coco.json',
data_prefix=dict(img='gqa/images/'),
filter_cfg=dict(filter_empty_gt=False, min_size=32),
pipeline=train_pipeline)

flickr_train_dataset = dict(
type='YOLOv5MixedGroundingDataset',
data_root='data/flickr/',
ann_file='annotations/final_flickr_separateGT_train.json',
data_prefix=dict(img='full_images/'),
filter_cfg=dict(filter_empty_gt=True, min_size=32),
pipeline=train_pipeline)

train_dataloader = dict(batch_size=train_batch_size_per_gpu,
collate_fn=dict(type='yolow_collate'),
dataset=dict(_delete_=True,
type='ConcatDataset',
datasets=[
obj365v1_train_dataset,
flickr_train_dataset, mg_train_dataset
],
ignore_keys=['classes', 'palette']))

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=False,
pad_val=dict(img=114)),
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
dict(type='LoadText'),
dict(type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'pad_param', 'texts'))
]

coco_val_dataset = dict(
_delete_=True,
type='MultiModalDataset',
dataset=dict(type='YOLOv5LVISV1Dataset',
data_root='data/coco/',
test_mode=True,
ann_file='lvis/lvis_v1_minival_inserted_image_name.json',
data_prefix=dict(img=''),
batch_shapes_cfg=None),
class_text_path='data/texts/lvis_v1_class_texts.json',
pipeline=test_pipeline)
val_dataloader = dict(dataset=coco_val_dataset)
test_dataloader = val_dataloader

val_evaluator = dict(type='mmdet.LVISMetric',
ann_file='data/coco/lvis/lvis_v1_minival_inserted_image_name.json',
metric='bbox')
test_evaluator = val_evaluator

# training settings
default_hooks = dict(param_scheduler=dict(max_epochs=max_epochs),
checkpoint=dict(interval=save_epoch_intervals,
rule='greater'))
custom_hooks = [
dict(type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=0.0001,
update_buffers=True,
strict_load=False,
priority=49),
dict(type='mmdet.PipelineSwitchHook',
switch_epoch=max_epochs - close_mosaic_epochs,
switch_pipeline=train_pipeline_stage2)
]
train_cfg = dict(max_epochs=max_epochs,
val_interval=10,
dynamic_intervals=[((max_epochs - close_mosaic_epochs),
_base_.val_interval_stage2)])
optim_wrapper = dict(optimizer=dict(
_delete_=True,
type='AdamW',
lr=base_lr,
weight_decay=weight_decay,
batch_size_per_gpu=train_batch_size_per_gpu),
paramwise_cfg=dict(bias_decay_mult=0.0,
norm_decay_mult=0.0,
custom_keys={
'backbone.text_model':
dict(lr_mult=0.01),
'logit_scale':
dict(weight_decay=0.0)
}),
constructor='YOLOWv5OptimizerConstructor')
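
This new file appears to be a 1280x1280 pre-training config built on the YOLOv8-X base. As a rough sketch of how such a config is usually launched in MMYOLO-based repos (the file's actual path is not shown in this diff, so a placeholder is used; check the repository's training docs for the exact entry point):

```bash
# Hypothetical launch command, assuming the standard MMYOLO-style
# tools/dist_train.sh wrapper (arguments: CONFIG, then number of GPUs).
# Replace the placeholder with the real path of the config added above.
./tools/dist_train.sh configs/pretrain/<new_1280_yolo_world_v2_x_config>.py 8
```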
@@ -16,7 +16,7 @@
weight_decay = 0.05 / 2
train_batch_size_per_gpu = 16
text_model_name = '../pretrained_models/clip-vit-base-patch32-projection'
# text_model_name = 'openai/clip-vit-base-patch32'
text_model_name = 'openai/clip-vit-base-patch32'

# scaling model from X to XL
deepen_factor = 1.0