Skip to content

Commit

Permalink
add reparameterized version
Browse files Browse the repository at this point in the history
  • Loading branch information
wondervictor committed May 9, 2024
1 parent 0f28744 commit bb9d2fd
Show file tree
Hide file tree
Showing 13 changed files with 113 additions and 59 deletions.
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,19 @@ chmod +x tools/dist_test.sh

## Fine-tuning YOLO-World

We provide the details about fine-tuning YOLO-World in [docs/fine-tuning](./docs/finetuning.md).
<div align="center">
<img src="./assets/finetune_yoloworld.png" width=800px>
</div>


YOLO-World supports **zero-shot inference**, and three types of **fine-tuning recipes**: **(1) normal fine-tuning**, **(2) prompt tuning**, and **(3) reparameterized fine-tuning**.


* Normal Fine-tuning: we provide the details about fine-tuning YOLO-World in [docs/fine-tuning](./docs/finetuning.md).

* Prompt Tuning: we provide more details ahout prompt tuning in [docs/prompt_yolo_world](./docs/prompt_yolo_world.md).

* Reparameterized Fine-tuning: the reparameterized YOLO-World is more suitable for specific domains far from generic scenes. You can find more details in [`docs/reparameterize`](./docs/reparameterize.md).

## Deployment

Expand Down
Binary file added assets/finetune_yoloworld.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/reparameterize.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 4 additions & 2 deletions configs/finetune_coco/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ BTW, the COCO fine-tuning results are updated with higher performance (with `mas
| [YOLO-World-v2-S](./yolo_world_v2_s_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 37.5 | 46.1 | 62.0 | 49.9 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_s_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_ep80-492dc329.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_s_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_20240327_110411.log) |
| [YOLO-World-v2-M](./yolo_world_v2_m_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 42.8 | 51.0 | 67.5 | 55.2 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_m_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_ep80-69c27ac7.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_m_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_20240327_110411.log) |
| [YOLO-World-v2-L](./yolo_world_v2_l_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 45.1 | 53.9 | 70.9 | 58.8 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_l_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_ep80-81c701ee.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_l_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_20240326_160313.log) |
| [YOLO-World-v2-L](./yolo_world_v2_l_efficient_neck_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✔️ | 45.1 | | | | [HF Checkpoints]() | [log]() |
| [YOLO-World-v2-X](./yolo_world_v2_x_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 46.8 | 54.7 | 71.6 | 59.6 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_x_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_ep80-76bc0cbd.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_x_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_20240322_181232.log) |
| [YOLO-World-v2-L](./yolo_world_v2_l_vlpan_bn_sgd_1e-3_40e_8gpus_finetune_coco.py) 🔥 | SGD, 1e-3, 40e | ✖️ | ✖️ | 45.1 | 52.8 | 69.5 | 57.8 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_l_vlpan_bn_sgd_1e-3_40e_8gpus_finetune_coco_ep80-e1288152.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_l_vlpan_bn_sgd_1e-3_40e_8gpus_finetuning_coco_20240327_014902.log) |


### Reparameterized Training


### Reparameterized Training
| model | Schedule | `mask-refine` | efficient neck | AP<sup>ZS</sup>| AP | AP<sub>50</sub> | AP<sub>75</sub> | weights | log |
| :---- | :-------: | :----------: |:-------------: | :------------: | :-: | :--------------:| :-------------: |:------: | :-: |
| [YOLO-World-v2-S](./yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 37.5 | 46.3 | 62.8 | 50.4 | [HF Checkpoints]() | [log]() |
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
_base_ = ('../../third_party/mmyolo/configs/yolov8/'
'yolov8_l_mask-refine_syncbn_fast_8xb16-500e_coco.py')
'yolov8_s_mask-refine_syncbn_fast_8xb16-500e_coco.py')
custom_imports = dict(imports=['yolo_world'], allow_failed_imports=False)

# hyper-parameters
Expand All @@ -11,11 +11,13 @@
text_channels = 512
neck_embed_channels = [128, 256, _base_.last_stage_out_channels // 2]
neck_num_heads = [4, 8, _base_.last_stage_out_channels // 2 // 32]
base_lr = 2e-3
base_lr = 2e-4
weight_decay = 0.05
train_batch_size_per_gpu = 16
load_from = '/group/40034/adriancheng/notebooks/rep_models/yolo_world_v2_x_obj365v1_goldg_cc3mlite_pretrain_1280ft-14996a36_repconv.pth'
load_from = '../FastDet/output_models/pretrain_yolow-v8_s_clipv2_frozen_te_noprompt_t2i_bn_2e-3adamw_scale_lr_wd_32xb16-100e_obj365v1_goldg_cc3mram250k_train_lviseval-e3592307_rep_conv.pth'
persistent_workers = False
mixup_prob = 0.15
copypaste_prob = 0.3

# model settings
model = dict(type='SimpleYOLOWorldDetector',
Expand All @@ -28,14 +30,12 @@
type='MultiModalYOLOBackbone',
text_model=None,
image_model={{_base_.model.backbone}},
frozen_stages=4,
with_text_model=False),
neck=dict(type='YOLOWorldPAFPN',
guide_channels=num_classes,
guide_channels=text_channels,
embed_channels=neck_embed_channels,
num_heads=neck_num_heads,
block_cfg=dict(type='RepConvMaxSigmoidCSPLayerWithTwoConv',
guide_channels=num_classes)),
block_cfg=dict(type='EfficientCSPLayerWithTwoConv')),
bbox_head=dict(head_module=dict(type='RepYOLOWorldHeadModule',
embed_dims=text_channels,
num_guide=num_classes,
Expand All @@ -53,7 +53,7 @@
img_scale=_base_.img_scale,
pad_val=114.0,
pre_transform=_base_.pre_transform),
dict(type='YOLOv5CopyPaste', prob=_base_.copypaste_prob),
dict(type='YOLOv5CopyPaste', prob=copypaste_prob),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
Expand All @@ -69,7 +69,7 @@
train_pipeline = [
*_base_.pre_transform, *mosaic_affine_transform,
dict(type='YOLOv5MixUp',
prob=_base_.mixup_prob,
prob=mixup_prob,
pre_transform=[*_base_.pre_transform, *mosaic_affine_transform]),
*_base_.last_transform[:-1], *final_transform
]
Expand Down Expand Up @@ -135,16 +135,6 @@
lr=base_lr,
weight_decay=weight_decay,
batch_size_per_gpu=train_batch_size_per_gpu),
paramwise_cfg=dict(bias_decay_mult=0.0,
norm_decay_mult=0.0,
custom_keys={
'backbone.text_model':
dict(lr_mult=0.01),
'logit_scale':
dict(weight_decay=0.0),
'embeddings':
dict(weight_decay=0.0)
}),
constructor='YOLOWv5OptimizerConstructor')

# evaluation settings
Expand All @@ -153,4 +143,3 @@
proposal_nums=(100, 1, 10),
ann_file='data/coco/annotations/instances_val2017.json',
metric='bbox')
find_unused_parameters = True
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
text_channels = 512
neck_embed_channels = [128, 256, _base_.last_stage_out_channels // 2]
neck_num_heads = [4, 8, _base_.last_stage_out_channels // 2 // 32]
base_lr = 1e-3
base_lr = 2e-3
weight_decay = 0.0005
train_batch_size_per_gpu = 16
load_from = '../FastDet/output_models/yolo_world_s_clip_t2i_bn_2e-3adamw_32xb16-100e_obj365v1_goldg_train-55b943ea_rep_conv.pth'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
text_channels = 512
neck_embed_channels = [128, 256, _base_.last_stage_out_channels // 2]
neck_num_heads = [4, 8, _base_.last_stage_out_channels // 2 // 32]
base_lr = 1e-3
base_lr = 5e-4
weight_decay = 0.0005
train_batch_size_per_gpu = 16
load_from = '../FastDet/output_models/yolo_world_s_clip_t2i_bn_2e-3adamw_32xb16-100e_obj365v1_goldg_train-55b943ea_rep_conv.pth'
Expand All @@ -32,11 +32,10 @@
image_model={{_base_.model.backbone}},
with_text_model=False),
neck=dict(type='YOLOWorldPAFPN',
guide_channels=num_classes,
guide_channels=text_channels,
embed_channels=neck_embed_channels,
num_heads=neck_num_heads,
block_cfg=dict(type='RepConvMaxSigmoidCSPLayerWithTwoConv',
guide_channels=num_classes)),
block_cfg=dict(type='EfficientCSPLayerWithTwoConv')),
bbox_head=dict(head_module=dict(type='RepYOLOWorldHeadModule',
embed_dims=text_channels,
num_guide=num_classes,
Expand Down Expand Up @@ -140,7 +139,6 @@
batch_size_per_gpu=train_batch_size_per_gpu),
constructor='YOLOWv5OptimizerConstructor')


# evaluation settings
val_evaluator = dict(_delete_=True,
type='mmdet.CocoMetric',
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
_base_ = ('../../third_party/mmyolo/configs/yolov8/'
'yolov8_x_mask-refine_syncbn_fast_8xb16-500e_coco.py')
custom_imports = dict(imports=['yolo_world'], allow_failed_imports=False)
_base_ = (
'../../third_party/mmyolo/configs/yolov8/'
'yolov8_x_mask-refine_syncbn_fast_8xb16-500e_coco.py')
custom_imports = dict(
imports=['yolo_world'],
allow_failed_imports=False)

# hyper-parameters
num_classes = 80
Expand All @@ -11,12 +14,13 @@
text_channels = 512
neck_embed_channels = [128, 256, _base_.last_stage_out_channels // 2]
neck_num_heads = [4, 8, _base_.last_stage_out_channels // 2 // 32]
base_lr = 2e-3
base_lr = 2e-4
weight_decay = 0.05
train_batch_size_per_gpu = 16
load_from = '/group/40034/adriancheng/notebooks/rep_models/yolo_world_v2_x_obj365v1_goldg_cc3mlite_pretrain_1280ft-14996a36_repconv.pth'
load_from = '../YOLOWorld_Master/yolo_models/'
persistent_workers = False


# model settings
model = dict(type='SimpleYOLOWorldDetector',
mm_neck=True,
Expand All @@ -28,14 +32,12 @@
type='MultiModalYOLOBackbone',
text_model=None,
image_model={{_base_.model.backbone}},
frozen_stages=4,
with_text_model=False),
neck=dict(type='YOLOWorldPAFPN',
guide_channels=num_classes,
guide_channels=text_channels,
embed_channels=neck_embed_channels,
num_heads=neck_num_heads,
block_cfg=dict(type='RepConvMaxSigmoidCSPLayerWithTwoConv',
guide_channels=num_classes)),
block_cfg=dict(type='EfficientCSPLayerWithTwoConv')),
bbox_head=dict(head_module=dict(type='RepYOLOWorldHeadModule',
embed_dims=text_channels,
num_guide=num_classes,
Expand Down Expand Up @@ -135,16 +137,6 @@
lr=base_lr,
weight_decay=weight_decay,
batch_size_per_gpu=train_batch_size_per_gpu),
paramwise_cfg=dict(bias_decay_mult=0.0,
norm_decay_mult=0.0,
custom_keys={
'backbone.text_model':
dict(lr_mult=0.01),
'logit_scale':
dict(weight_decay=0.0),
'embeddings':
dict(weight_decay=0.0)
}),
constructor='YOLOWv5OptimizerConstructor')

# evaluation settings
Expand All @@ -153,4 +145,3 @@
proposal_nums=(100, 1, 10),
ann_file='data/coco/annotations/instances_val2017.json',
metric='bbox')
find_unused_parameters = True
14 changes: 8 additions & 6 deletions deploy/easydeploy/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ def __init__(self,
baseModel: nn.Module,
backend: MMYOLOBackend,
postprocess_cfg: Optional[ConfigDict] = None,
with_nms=True):
with_nms=True,
without_bbox_decoder=False):
super().__init__()
self.baseModel = baseModel
self.baseHead = baseModel.bbox_head
self.backend = backend
self.with_nms = with_nms
self.without_bbox_decoder = without_bbox_decoder
if postprocess_cfg is None:
self.with_postprocess = False
else:
Expand Down Expand Up @@ -103,7 +105,8 @@ def pred_by_feat(self,
bbox_decoder = yolox_bbox_decoder
else:
bbox_decoder = self.bbox_decoder

print(bbox_decoder)

num_imgs = cls_scores[0].shape[0]
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]

Expand All @@ -112,7 +115,6 @@ def pred_by_feat(self,
device=device)

flatten_priors = torch.cat(mlvl_priors)

mlvl_strides = [
flatten_priors.new_full(
(featmap_size[0] * featmap_size[1] * self.num_base_priors, ),
Expand All @@ -121,8 +123,6 @@ def pred_by_feat(self,
]
flatten_stride = torch.cat(mlvl_strides)

# flatten cls_scores, bbox_preds and objectness
# using score.shape
text_len = cls_scores[0].shape[1]
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, text_len)
Expand All @@ -145,7 +145,9 @@ def pred_by_feat(self,
cls_scores = cls_scores * (flatten_objectness.unsqueeze(-1))

scores = cls_scores

bboxes = flatten_bbox_preds
if self.without_bbox_decoder:
return scores, bboxes
bboxes = bbox_decoder(flatten_priors[None], flatten_bbox_preds,
flatten_stride)

Expand Down
6 changes: 5 additions & 1 deletion deploy/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def parse_args():
parser.add_argument('--without-nms',
action='store_true',
help='Expore model without NMS')
parser.add_argument('--without-bbox-decoder',
action='store_true',
help='Expore model without Bbox Decoder (for INT8 Quantization)')
parser.add_argument('--work-dir',
default='./work_dirs',
help='Path to save export model')
Expand Down Expand Up @@ -129,7 +132,8 @@ def main():
deploy_model = DeployModel(baseModel=baseModel,
backend=backend,
postprocess_cfg=postprocess_cfg,
with_nms=not args.without_nms)
with_nms=not args.without_nms,
without_bbox_decoder=args.without_bbox_decoder)
deploy_model.eval()

fake_input = torch.randn(args.batch_size, 3,
Expand Down
54 changes: 52 additions & 2 deletions docs/reparameterize.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

The reparameterization incorporates text embeddings as parameters into the model. For example, in the final classification layer, text embeddings are reparameterized into a simple 1x1 convolutional layer.

<div align="center">
<img width="600" src="../assets/reparameterize.png">
</div>

### Key Advantages from Reparameterization

> Reparameterized YOLO-World still has zero-shot ability!
Expand All @@ -15,13 +19,59 @@ For example, fine-tuning the **reparameterized YOLO-World** obtains *46.3 AP* on

#### 1. Prepare cutstom text embeddings

You need to generate the text embeddings

You need to generate the text embeddings by [`toos/generate_text_prompts.py`](../tools/generate_text_prompts.py) and save it as a `numpy.array` with shape `NxD`.

#### 2. Reparameterizing

Reparameterizing will generate a new checkpoint with text embeddings!

Check those files first:

* model checkpoint
* text embeddings

We mainly reparameterize two groups of modules:

* head (`YOLOWorldHeadModule`)
* neck (`MaxSigmoidCSPLayerWithTwoConv`)

```bash
python tools/reparameterize_yoloworld.py \
--model path/to/checkpoint \
--out-dir path/to/save/re-parameterized/ \
--text-embed path/to/text/embeddings \
--conv-neck
```


#### 3. Prepare the model config

Please see the sample config: [`finetune_coco/yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py`](../configs/finetune_coco/yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) for reparameterized training.


* `RepConvMaxSigmoidCSPLayerWithTwoConv`:

```python
neck=dict(type='YOLOWorldPAFPN',
guide_channels=num_classes,
embed_channels=neck_embed_channels,
num_heads=neck_num_heads,
block_cfg=dict(type='RepConvMaxSigmoidCSPLayerWithTwoConv',
guide_channels=num_classes)),
```

* `RepYOLOWorldHeadModule`:

```python
bbox_head=dict(head_module=dict(type='RepYOLOWorldHeadModule',
embed_dims=text_channels,
num_guide=num_classes,
num_classes=num_classes)),

```

#### 4. Reparameterized Training

**Reparameterized YOLO-World** is easier to fine-tune and can be treated as an enhanced and pre-trained YOLOv8!

You can check [`finetune_coco/yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py`](../configs/finetune_coco/yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) for more details.
5 changes: 5 additions & 0 deletions tools/reparameterize_yoloworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def reparameterize_head(state_dict, embeds):

def convert_neck_split_conv(input_state_dict, block_name, text_embeds,
num_heads):
if block_name + '.guide_fc.weight' not in input_state_dict:
return input_state_dict
guide_fc_weight = input_state_dict[block_name + '.guide_fc.weight']
guide_fc_bias = input_state_dict[block_name + '.guide_fc.bias']
guide = text_embeds @ guide_fc_weight.transpose(0,
Expand Down Expand Up @@ -77,12 +79,15 @@ def convert_neck_weight(input_state_dict, block_name, embeds, num_heads):


def reparameterize_neck(state_dict, embeds, type='conv'):

neck_blocks = [
'neck.top_down_layers.0.attn_block',
'neck.top_down_layers.1.attn_block',
'neck.bottom_up_layers.0.attn_block',
'neck.bottom_up_layers.1.attn_block'
]
if "neck.top_down_layers.0.attn_block.bias" not in state_dict:
return state_dict
for block in neck_blocks:
num_heads = state_dict[block + '.bias'].shape[0]
if type == 'conv':
Expand Down
Loading

0 comments on commit bb9d2fd

Please sign in to comment.