diff --git a/README.md b/README.md
index 57dd4bde..8c248132 100644
--- a/README.md
+++ b/README.md
@@ -170,7 +170,19 @@ chmod +x tools/dist_test.sh
## Fine-tuning YOLO-World
-We provide the details about fine-tuning YOLO-World in [docs/fine-tuning](./docs/finetuning.md).
+
+
+
+
+
+YOLO-World supports **zero-shot inference**, and three types of **fine-tuning recipes**: **(1) normal fine-tuning**, **(2) prompt tuning**, and **(3) reparameterized fine-tuning**.
+
+
+* Normal Fine-tuning: we provide the details about fine-tuning YOLO-World in [docs/fine-tuning](./docs/finetuning.md).
+
+* Prompt Tuning: we provide more details ahout prompt tuning in [docs/prompt_yolo_world](./docs/prompt_yolo_world.md).
+
+* Reparameterized Fine-tuning: the reparameterized YOLO-World is more suitable for specific domains far from generic scenes. You can find more details in [`docs/reparameterize`](./docs/reparameterize.md).
## Deployment
diff --git a/assets/finetune_yoloworld.png b/assets/finetune_yoloworld.png
new file mode 100644
index 00000000..235230e4
Binary files /dev/null and b/assets/finetune_yoloworld.png differ
diff --git a/assets/reparameterize.png b/assets/reparameterize.png
new file mode 100644
index 00000000..56a52de9
Binary files /dev/null and b/assets/reparameterize.png differ
diff --git a/configs/finetune_coco/README.md b/configs/finetune_coco/README.md
index b70ba02c..84e06f6b 100644
--- a/configs/finetune_coco/README.md
+++ b/configs/finetune_coco/README.md
@@ -18,11 +18,13 @@ BTW, the COCO fine-tuning results are updated with higher performance (with `mas
| [YOLO-World-v2-S](./yolo_world_v2_s_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 37.5 | 46.1 | 62.0 | 49.9 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_s_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_ep80-492dc329.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_s_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_20240327_110411.log) |
| [YOLO-World-v2-M](./yolo_world_v2_m_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 42.8 | 51.0 | 67.5 | 55.2 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_m_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_ep80-69c27ac7.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_m_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_20240327_110411.log) |
| [YOLO-World-v2-L](./yolo_world_v2_l_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 45.1 | 53.9 | 70.9 | 58.8 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_l_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_ep80-81c701ee.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_l_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_20240326_160313.log) |
-| [YOLO-World-v2-L](./yolo_world_v2_l_efficient_neck_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✔️ | 45.1 | | | | [HF Checkpoints]() | [log]() |
| [YOLO-World-v2-X](./yolo_world_v2_x_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 46.8 | 54.7 | 71.6 | 59.6 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_x_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_ep80-76bc0cbd.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_x_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_20240322_181232.log) |
| [YOLO-World-v2-L](./yolo_world_v2_l_vlpan_bn_sgd_1e-3_40e_8gpus_finetune_coco.py) 🔥 | SGD, 1e-3, 40e | ✖️ | ✖️ | 45.1 | 52.8 | 69.5 | 57.8 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_l_vlpan_bn_sgd_1e-3_40e_8gpus_finetune_coco_ep80-e1288152.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_l_vlpan_bn_sgd_1e-3_40e_8gpus_finetuning_coco_20240327_014902.log) |
+### Reparameterized Training
-### Reparameterized Training
+| model | Schedule | `mask-refine` | efficient neck | APZS| AP | AP50 | AP75 | weights | log |
+| :---- | :-------: | :----------: |:-------------: | :------------: | :-: | :--------------:| :-------------: |:------: | :-: |
+| [YOLO-World-v2-S](./yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 37.5 | 46.3 | 62.8 | 50.4 | [HF Checkpoints]() | [log]() |
\ No newline at end of file
diff --git a/configs/finetune_coco/yolo_world_v2_l_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py b/configs/finetune_coco/yolo_world_v2_s_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py
similarity index 82%
rename from configs/finetune_coco/yolo_world_v2_l_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py
rename to configs/finetune_coco/yolo_world_v2_s_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py
index ec275a29..49801101 100644
--- a/configs/finetune_coco/yolo_world_v2_l_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py
+++ b/configs/finetune_coco/yolo_world_v2_s_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py
@@ -1,5 +1,5 @@
_base_ = ('../../third_party/mmyolo/configs/yolov8/'
- 'yolov8_l_mask-refine_syncbn_fast_8xb16-500e_coco.py')
+ 'yolov8_s_mask-refine_syncbn_fast_8xb16-500e_coco.py')
custom_imports = dict(imports=['yolo_world'], allow_failed_imports=False)
# hyper-parameters
@@ -11,11 +11,13 @@
text_channels = 512
neck_embed_channels = [128, 256, _base_.last_stage_out_channels // 2]
neck_num_heads = [4, 8, _base_.last_stage_out_channels // 2 // 32]
-base_lr = 2e-3
+base_lr = 2e-4
weight_decay = 0.05
train_batch_size_per_gpu = 16
-load_from = '/group/40034/adriancheng/notebooks/rep_models/yolo_world_v2_x_obj365v1_goldg_cc3mlite_pretrain_1280ft-14996a36_repconv.pth'
+load_from = '../FastDet/output_models/pretrain_yolow-v8_s_clipv2_frozen_te_noprompt_t2i_bn_2e-3adamw_scale_lr_wd_32xb16-100e_obj365v1_goldg_cc3mram250k_train_lviseval-e3592307_rep_conv.pth'
persistent_workers = False
+mixup_prob = 0.15
+copypaste_prob = 0.3
# model settings
model = dict(type='SimpleYOLOWorldDetector',
@@ -28,14 +30,12 @@
type='MultiModalYOLOBackbone',
text_model=None,
image_model={{_base_.model.backbone}},
- frozen_stages=4,
with_text_model=False),
neck=dict(type='YOLOWorldPAFPN',
- guide_channels=num_classes,
+ guide_channels=text_channels,
embed_channels=neck_embed_channels,
num_heads=neck_num_heads,
- block_cfg=dict(type='RepConvMaxSigmoidCSPLayerWithTwoConv',
- guide_channels=num_classes)),
+ block_cfg=dict(type='EfficientCSPLayerWithTwoConv')),
bbox_head=dict(head_module=dict(type='RepYOLOWorldHeadModule',
embed_dims=text_channels,
num_guide=num_classes,
@@ -53,7 +53,7 @@
img_scale=_base_.img_scale,
pad_val=114.0,
pre_transform=_base_.pre_transform),
- dict(type='YOLOv5CopyPaste', prob=_base_.copypaste_prob),
+ dict(type='YOLOv5CopyPaste', prob=copypaste_prob),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
@@ -69,7 +69,7 @@
train_pipeline = [
*_base_.pre_transform, *mosaic_affine_transform,
dict(type='YOLOv5MixUp',
- prob=_base_.mixup_prob,
+ prob=mixup_prob,
pre_transform=[*_base_.pre_transform, *mosaic_affine_transform]),
*_base_.last_transform[:-1], *final_transform
]
@@ -135,16 +135,6 @@
lr=base_lr,
weight_decay=weight_decay,
batch_size_per_gpu=train_batch_size_per_gpu),
- paramwise_cfg=dict(bias_decay_mult=0.0,
- norm_decay_mult=0.0,
- custom_keys={
- 'backbone.text_model':
- dict(lr_mult=0.01),
- 'logit_scale':
- dict(weight_decay=0.0),
- 'embeddings':
- dict(weight_decay=0.0)
- }),
constructor='YOLOWv5OptimizerConstructor')
# evaluation settings
@@ -153,4 +143,3 @@
proposal_nums=(100, 1, 10),
ann_file='data/coco/annotations/instances_val2017.json',
metric='bbox')
-find_unused_parameters = True
diff --git a/configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_1e-3_80e_8gpus_mask-refine_finetune_coco.py b/configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_2e-3_80e_8gpus_mask-refine_finetune_coco.py
similarity index 99%
rename from configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_1e-3_80e_8gpus_mask-refine_finetune_coco.py
rename to configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_2e-3_80e_8gpus_mask-refine_finetune_coco.py
index 261a4123..1a843973 100644
--- a/configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_1e-3_80e_8gpus_mask-refine_finetune_coco.py
+++ b/configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_2e-3_80e_8gpus_mask-refine_finetune_coco.py
@@ -11,7 +11,7 @@
text_channels = 512
neck_embed_channels = [128, 256, _base_.last_stage_out_channels // 2]
neck_num_heads = [4, 8, _base_.last_stage_out_channels // 2 // 32]
-base_lr = 1e-3
+base_lr = 2e-3
weight_decay = 0.0005
train_batch_size_per_gpu = 16
load_from = '../FastDet/output_models/yolo_world_s_clip_t2i_bn_2e-3adamw_32xb16-100e_obj365v1_goldg_train-55b943ea_rep_conv.pth'
diff --git a/configs/finetune_coco/yolo_world_v2_s_rep_vlpan_bn_sgd_1e-3_80e_8gpus_mask-refine_finetune_coco.py b/configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_5e-4_80e_8gpus_mask-refine_finetune_coco.py
similarity index 96%
rename from configs/finetune_coco/yolo_world_v2_s_rep_vlpan_bn_sgd_1e-3_80e_8gpus_mask-refine_finetune_coco.py
rename to configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_5e-4_80e_8gpus_mask-refine_finetune_coco.py
index 7a1c7236..6df7c4c5 100644
--- a/configs/finetune_coco/yolo_world_v2_s_rep_vlpan_bn_sgd_1e-3_80e_8gpus_mask-refine_finetune_coco.py
+++ b/configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_5e-4_80e_8gpus_mask-refine_finetune_coco.py
@@ -11,7 +11,7 @@
text_channels = 512
neck_embed_channels = [128, 256, _base_.last_stage_out_channels // 2]
neck_num_heads = [4, 8, _base_.last_stage_out_channels // 2 // 32]
-base_lr = 1e-3
+base_lr = 5e-4
weight_decay = 0.0005
train_batch_size_per_gpu = 16
load_from = '../FastDet/output_models/yolo_world_s_clip_t2i_bn_2e-3adamw_32xb16-100e_obj365v1_goldg_train-55b943ea_rep_conv.pth'
@@ -32,11 +32,10 @@
image_model={{_base_.model.backbone}},
with_text_model=False),
neck=dict(type='YOLOWorldPAFPN',
- guide_channels=num_classes,
+ guide_channels=text_channels,
embed_channels=neck_embed_channels,
num_heads=neck_num_heads,
- block_cfg=dict(type='RepConvMaxSigmoidCSPLayerWithTwoConv',
- guide_channels=num_classes)),
+ block_cfg=dict(type='EfficientCSPLayerWithTwoConv')),
bbox_head=dict(head_module=dict(type='RepYOLOWorldHeadModule',
embed_dims=text_channels,
num_guide=num_classes,
@@ -140,7 +139,6 @@
batch_size_per_gpu=train_batch_size_per_gpu),
constructor='YOLOWv5OptimizerConstructor')
-
# evaluation settings
val_evaluator = dict(_delete_=True,
type='mmdet.CocoMetric',
diff --git a/configs/finetune_coco/yolo_world_v2_x_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py b/configs/finetune_coco/yolo_world_v2_x_rep_efficient_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py
similarity index 82%
rename from configs/finetune_coco/yolo_world_v2_x_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py
rename to configs/finetune_coco/yolo_world_v2_x_rep_efficient_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py
index ac872d55..1a6643b7 100644
--- a/configs/finetune_coco/yolo_world_v2_x_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py
+++ b/configs/finetune_coco/yolo_world_v2_x_rep_efficient_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py
@@ -1,6 +1,9 @@
-_base_ = ('../../third_party/mmyolo/configs/yolov8/'
- 'yolov8_x_mask-refine_syncbn_fast_8xb16-500e_coco.py')
-custom_imports = dict(imports=['yolo_world'], allow_failed_imports=False)
+_base_ = (
+ '../../third_party/mmyolo/configs/yolov8/'
+ 'yolov8_x_mask-refine_syncbn_fast_8xb16-500e_coco.py')
+custom_imports = dict(
+ imports=['yolo_world'],
+ allow_failed_imports=False)
# hyper-parameters
num_classes = 80
@@ -11,12 +14,13 @@
text_channels = 512
neck_embed_channels = [128, 256, _base_.last_stage_out_channels // 2]
neck_num_heads = [4, 8, _base_.last_stage_out_channels // 2 // 32]
-base_lr = 2e-3
+base_lr = 2e-4
weight_decay = 0.05
train_batch_size_per_gpu = 16
-load_from = '/group/40034/adriancheng/notebooks/rep_models/yolo_world_v2_x_obj365v1_goldg_cc3mlite_pretrain_1280ft-14996a36_repconv.pth'
+load_from = '../YOLOWorld_Master/yolo_models/'
persistent_workers = False
+
# model settings
model = dict(type='SimpleYOLOWorldDetector',
mm_neck=True,
@@ -28,14 +32,12 @@
type='MultiModalYOLOBackbone',
text_model=None,
image_model={{_base_.model.backbone}},
- frozen_stages=4,
with_text_model=False),
neck=dict(type='YOLOWorldPAFPN',
- guide_channels=num_classes,
+ guide_channels=text_channels,
embed_channels=neck_embed_channels,
num_heads=neck_num_heads,
- block_cfg=dict(type='RepConvMaxSigmoidCSPLayerWithTwoConv',
- guide_channels=num_classes)),
+ block_cfg=dict(type='EfficientCSPLayerWithTwoConv')),
bbox_head=dict(head_module=dict(type='RepYOLOWorldHeadModule',
embed_dims=text_channels,
num_guide=num_classes,
@@ -135,16 +137,6 @@
lr=base_lr,
weight_decay=weight_decay,
batch_size_per_gpu=train_batch_size_per_gpu),
- paramwise_cfg=dict(bias_decay_mult=0.0,
- norm_decay_mult=0.0,
- custom_keys={
- 'backbone.text_model':
- dict(lr_mult=0.01),
- 'logit_scale':
- dict(weight_decay=0.0),
- 'embeddings':
- dict(weight_decay=0.0)
- }),
constructor='YOLOWv5OptimizerConstructor')
# evaluation settings
@@ -153,4 +145,3 @@
proposal_nums=(100, 1, 10),
ann_file='data/coco/annotations/instances_val2017.json',
metric='bbox')
-find_unused_parameters = True
diff --git a/deploy/easydeploy/model/model.py b/deploy/easydeploy/model/model.py
index 02301178..21cf50f7 100644
--- a/deploy/easydeploy/model/model.py
+++ b/deploy/easydeploy/model/model.py
@@ -28,12 +28,14 @@ def __init__(self,
baseModel: nn.Module,
backend: MMYOLOBackend,
postprocess_cfg: Optional[ConfigDict] = None,
- with_nms=True):
+ with_nms=True,
+ without_bbox_decoder=False):
super().__init__()
self.baseModel = baseModel
self.baseHead = baseModel.bbox_head
self.backend = backend
self.with_nms = with_nms
+ self.without_bbox_decoder = without_bbox_decoder
if postprocess_cfg is None:
self.with_postprocess = False
else:
@@ -103,7 +105,8 @@ def pred_by_feat(self,
bbox_decoder = yolox_bbox_decoder
else:
bbox_decoder = self.bbox_decoder
-
+ print(bbox_decoder)
+
num_imgs = cls_scores[0].shape[0]
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
@@ -112,7 +115,6 @@ def pred_by_feat(self,
device=device)
flatten_priors = torch.cat(mlvl_priors)
-
mlvl_strides = [
flatten_priors.new_full(
(featmap_size[0] * featmap_size[1] * self.num_base_priors, ),
@@ -121,8 +123,6 @@ def pred_by_feat(self,
]
flatten_stride = torch.cat(mlvl_strides)
- # flatten cls_scores, bbox_preds and objectness
- # using score.shape
text_len = cls_scores[0].shape[1]
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, text_len)
@@ -145,7 +145,9 @@ def pred_by_feat(self,
cls_scores = cls_scores * (flatten_objectness.unsqueeze(-1))
scores = cls_scores
-
+ bboxes = flatten_bbox_preds
+ if self.without_bbox_decoder:
+ return scores, bboxes
bboxes = bbox_decoder(flatten_priors[None], flatten_bbox_preds,
flatten_stride)
diff --git a/deploy/export_onnx.py b/deploy/export_onnx.py
index 990056f1..183c3938 100644
--- a/deploy/export_onnx.py
+++ b/deploy/export_onnx.py
@@ -37,6 +37,9 @@ def parse_args():
parser.add_argument('--without-nms',
action='store_true',
help='Expore model without NMS')
+ parser.add_argument('--without-bbox-decoder',
+ action='store_true',
+ help='Expore model without Bbox Decoder (for INT8 Quantization)')
parser.add_argument('--work-dir',
default='./work_dirs',
help='Path to save export model')
@@ -129,7 +132,8 @@ def main():
deploy_model = DeployModel(baseModel=baseModel,
backend=backend,
postprocess_cfg=postprocess_cfg,
- with_nms=not args.without_nms)
+ with_nms=not args.without_nms,
+ without_bbox_decoder=args.without_bbox_decoder)
deploy_model.eval()
fake_input = torch.randn(args.batch_size, 3,
diff --git a/docs/reparameterize.md b/docs/reparameterize.md
index 74b26ece..9115783d 100644
--- a/docs/reparameterize.md
+++ b/docs/reparameterize.md
@@ -2,6 +2,10 @@
The reparameterization incorporates text embeddings as parameters into the model. For example, in the final classification layer, text embeddings are reparameterized into a simple 1x1 convolutional layer.
+
+
+
+
### Key Advantages from Reparameterization
> Reparameterized YOLO-World still has zero-shot ability!
@@ -15,13 +19,59 @@ For example, fine-tuning the **reparameterized YOLO-World** obtains *46.3 AP* on
#### 1. Prepare cutstom text embeddings
-You need to generate the text embeddings
-
+You need to generate the text embeddings by [`toos/generate_text_prompts.py`](../tools/generate_text_prompts.py) and save it as a `numpy.array` with shape `NxD`.
#### 2. Reparameterizing
+Reparameterizing will generate a new checkpoint with text embeddings!
+
+Check those files first:
+
+* model checkpoint
+* text embeddings
+We mainly reparameterize two groups of modules:
+
+* head (`YOLOWorldHeadModule`)
+* neck (`MaxSigmoidCSPLayerWithTwoConv`)
+
+```bash
+python tools/reparameterize_yoloworld.py \
+ --model path/to/checkpoint \
+ --out-dir path/to/save/re-parameterized/ \
+ --text-embed path/to/text/embeddings \
+ --conv-neck
+```
#### 3. Prepare the model config
+Please see the sample config: [`finetune_coco/yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py`](../configs/finetune_coco/yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) for reparameterized training.
+
+
+* `RepConvMaxSigmoidCSPLayerWithTwoConv`:
+
+```python
+neck=dict(type='YOLOWorldPAFPN',
+ guide_channels=num_classes,
+ embed_channels=neck_embed_channels,
+ num_heads=neck_num_heads,
+ block_cfg=dict(type='RepConvMaxSigmoidCSPLayerWithTwoConv',
+ guide_channels=num_classes)),
+```
+
+* `RepYOLOWorldHeadModule`:
+
+```python
+bbox_head=dict(head_module=dict(type='RepYOLOWorldHeadModule',
+ embed_dims=text_channels,
+ num_guide=num_classes,
+ num_classes=num_classes)),
+
+```
+
+#### 4. Reparameterized Training
+
+**Reparameterized YOLO-World** is easier to fine-tune and can be treated as an enhanced and pre-trained YOLOv8!
+
+You can check [`finetune_coco/yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py`](../configs/finetune_coco/yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) for more details.
\ No newline at end of file
diff --git a/tools/reparameterize_yoloworld.py b/tools/reparameterize_yoloworld.py
index c28a53bf..0257f637 100644
--- a/tools/reparameterize_yoloworld.py
+++ b/tools/reparameterize_yoloworld.py
@@ -49,6 +49,8 @@ def reparameterize_head(state_dict, embeds):
def convert_neck_split_conv(input_state_dict, block_name, text_embeds,
num_heads):
+ if block_name + '.guide_fc.weight' not in input_state_dict:
+ return input_state_dict
guide_fc_weight = input_state_dict[block_name + '.guide_fc.weight']
guide_fc_bias = input_state_dict[block_name + '.guide_fc.bias']
guide = text_embeds @ guide_fc_weight.transpose(0,
@@ -77,12 +79,15 @@ def convert_neck_weight(input_state_dict, block_name, embeds, num_heads):
def reparameterize_neck(state_dict, embeds, type='conv'):
+
neck_blocks = [
'neck.top_down_layers.0.attn_block',
'neck.top_down_layers.1.attn_block',
'neck.bottom_up_layers.0.attn_block',
'neck.bottom_up_layers.1.attn_block'
]
+ if "neck.top_down_layers.0.attn_block.bias" not in state_dict:
+ return state_dict
for block in neck_blocks:
num_heads = state_dict[block + '.bias'].shape[0]
if type == 'conv':
diff --git a/yolo_world/models/layers/yolo_bricks.py b/yolo_world/models/layers/yolo_bricks.py
index ac16ed2e..0c39131c 100644
--- a/yolo_world/models/layers/yolo_bricks.py
+++ b/yolo_world/models/layers/yolo_bricks.py
@@ -542,7 +542,8 @@ def __init__(self,
def forward(self, x: Tensor, guide: Tensor) -> Tensor:
"""Forward process."""
x = self.project_conv(x)
- x = x * x.sigmoid()
+ # remove sigmoid
+ # x = x * x.sigmoid()
return x