diff --git a/README.md b/README.md index 57dd4bde..8c248132 100644 --- a/README.md +++ b/README.md @@ -170,7 +170,19 @@ chmod +x tools/dist_test.sh ## Fine-tuning YOLO-World -We provide the details about fine-tuning YOLO-World in [docs/fine-tuning](./docs/finetuning.md). +
+ +
+ + +YOLO-World supports **zero-shot inference**, and three types of **fine-tuning recipes**: **(1) normal fine-tuning**, **(2) prompt tuning**, and **(3) reparameterized fine-tuning**. + + +* Normal Fine-tuning: we provide the details about fine-tuning YOLO-World in [docs/fine-tuning](./docs/finetuning.md). + +* Prompt Tuning: we provide more details about prompt tuning in [docs/prompt_yolo_world](./docs/prompt_yolo_world.md). + +* Reparameterized Fine-tuning: the reparameterized YOLO-World is more suitable for specific domains far from generic scenes. You can find more details in [`docs/reparameterize`](./docs/reparameterize.md). ## Deployment diff --git a/assets/finetune_yoloworld.png b/assets/finetune_yoloworld.png new file mode 100644 index 00000000..235230e4 Binary files /dev/null and b/assets/finetune_yoloworld.png differ diff --git a/assets/reparameterize.png b/assets/reparameterize.png new file mode 100644 index 00000000..56a52de9 Binary files /dev/null and b/assets/reparameterize.png differ diff --git a/configs/finetune_coco/README.md b/configs/finetune_coco/README.md index b70ba02c..84e06f6b 100644 --- a/configs/finetune_coco/README.md +++ b/configs/finetune_coco/README.md @@ -18,11 +18,13 @@ BTW, the COCO fine-tuning results are updated with higher performance (with `mas | [YOLO-World-v2-S](./yolo_world_v2_s_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 37.5 | 46.1 | 62.0 | 49.9 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_s_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_ep80-492dc329.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_s_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_20240327_110411.log) | | [YOLO-World-v2-M](./yolo_world_v2_m_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 42.8 | 51.0 | 67.5 | 55.2 | [HF 
Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_m_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_ep80-69c27ac7.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_m_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_20240327_110411.log) | | [YOLO-World-v2-L](./yolo_world_v2_l_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 45.1 | 53.9 | 70.9 | 58.8 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_l_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_ep80-81c701ee.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_l_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_20240326_160313.log) | -| [YOLO-World-v2-L](./yolo_world_v2_l_efficient_neck_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✔️ | 45.1 | | | | [HF Checkpoints]() | [log]() | | [YOLO-World-v2-X](./yolo_world_v2_x_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 46.8 | 54.7 | 71.6 | 59.6 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_x_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_ep80-76bc0cbd.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_x_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_20240322_181232.log) | | [YOLO-World-v2-L](./yolo_world_v2_l_vlpan_bn_sgd_1e-3_40e_8gpus_finetune_coco.py) 🔥 | SGD, 1e-3, 40e | ✖️ | ✖️ | 45.1 | 52.8 | 69.5 | 57.8 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_l_vlpan_bn_sgd_1e-3_40e_8gpus_finetune_coco_ep80-e1288152.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_l_vlpan_bn_sgd_1e-3_40e_8gpus_finetuning_coco_20240327_014902.log) | +### Reparameterized Training -### Reparameterized Training +| model | Schedule | `mask-refine` | efficient neck | APZS| AP | AP50 | AP75 | 
weights | log | +| :---- | :-------: | :----------: |:-------------: | :------------: | :-: | :--------------:| :-------------: |:------: | :-: | +| [YOLO-World-v2-S](./yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 37.5 | 46.3 | 62.8 | 50.4 | [HF Checkpoints]() | [log]() | \ No newline at end of file diff --git a/configs/finetune_coco/yolo_world_v2_l_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py b/configs/finetune_coco/yolo_world_v2_s_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py similarity index 82% rename from configs/finetune_coco/yolo_world_v2_l_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py rename to configs/finetune_coco/yolo_world_v2_s_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py index ec275a29..49801101 100644 --- a/configs/finetune_coco/yolo_world_v2_l_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py +++ b/configs/finetune_coco/yolo_world_v2_s_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py @@ -1,5 +1,5 @@ _base_ = ('../../third_party/mmyolo/configs/yolov8/' - 'yolov8_l_mask-refine_syncbn_fast_8xb16-500e_coco.py') + 'yolov8_s_mask-refine_syncbn_fast_8xb16-500e_coco.py') custom_imports = dict(imports=['yolo_world'], allow_failed_imports=False) # hyper-parameters @@ -11,11 +11,13 @@ text_channels = 512 neck_embed_channels = [128, 256, _base_.last_stage_out_channels // 2] neck_num_heads = [4, 8, _base_.last_stage_out_channels // 2 // 32] -base_lr = 2e-3 +base_lr = 2e-4 weight_decay = 0.05 train_batch_size_per_gpu = 16 -load_from = '/group/40034/adriancheng/notebooks/rep_models/yolo_world_v2_x_obj365v1_goldg_cc3mlite_pretrain_1280ft-14996a36_repconv.pth' +load_from = '../FastDet/output_models/pretrain_yolow-v8_s_clipv2_frozen_te_noprompt_t2i_bn_2e-3adamw_scale_lr_wd_32xb16-100e_obj365v1_goldg_cc3mram250k_train_lviseval-e3592307_rep_conv.pth' persistent_workers = False +mixup_prob = 0.15 +copypaste_prob = 0.3 # model settings model = dict(type='SimpleYOLOWorldDetector', @@ 
-28,14 +30,12 @@ type='MultiModalYOLOBackbone', text_model=None, image_model={{_base_.model.backbone}}, - frozen_stages=4, with_text_model=False), neck=dict(type='YOLOWorldPAFPN', - guide_channels=num_classes, + guide_channels=text_channels, embed_channels=neck_embed_channels, num_heads=neck_num_heads, - block_cfg=dict(type='RepConvMaxSigmoidCSPLayerWithTwoConv', - guide_channels=num_classes)), + block_cfg=dict(type='EfficientCSPLayerWithTwoConv')), bbox_head=dict(head_module=dict(type='RepYOLOWorldHeadModule', embed_dims=text_channels, num_guide=num_classes, @@ -53,7 +53,7 @@ img_scale=_base_.img_scale, pad_val=114.0, pre_transform=_base_.pre_transform), - dict(type='YOLOv5CopyPaste', prob=_base_.copypaste_prob), + dict(type='YOLOv5CopyPaste', prob=copypaste_prob), dict( type='YOLOv5RandomAffine', max_rotate_degree=0.0, @@ -69,7 +69,7 @@ train_pipeline = [ *_base_.pre_transform, *mosaic_affine_transform, dict(type='YOLOv5MixUp', - prob=_base_.mixup_prob, + prob=mixup_prob, pre_transform=[*_base_.pre_transform, *mosaic_affine_transform]), *_base_.last_transform[:-1], *final_transform ] @@ -135,16 +135,6 @@ lr=base_lr, weight_decay=weight_decay, batch_size_per_gpu=train_batch_size_per_gpu), - paramwise_cfg=dict(bias_decay_mult=0.0, - norm_decay_mult=0.0, - custom_keys={ - 'backbone.text_model': - dict(lr_mult=0.01), - 'logit_scale': - dict(weight_decay=0.0), - 'embeddings': - dict(weight_decay=0.0) - }), constructor='YOLOWv5OptimizerConstructor') # evaluation settings @@ -153,4 +143,3 @@ proposal_nums=(100, 1, 10), ann_file='data/coco/annotations/instances_val2017.json', metric='bbox') -find_unused_parameters = True diff --git a/configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_1e-3_80e_8gpus_mask-refine_finetune_coco.py b/configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_2e-3_80e_8gpus_mask-refine_finetune_coco.py similarity index 99% rename from 
configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_1e-3_80e_8gpus_mask-refine_finetune_coco.py rename to configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_2e-3_80e_8gpus_mask-refine_finetune_coco.py index 261a4123..1a843973 100644 --- a/configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_1e-3_80e_8gpus_mask-refine_finetune_coco.py +++ b/configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_2e-3_80e_8gpus_mask-refine_finetune_coco.py @@ -11,7 +11,7 @@ text_channels = 512 neck_embed_channels = [128, 256, _base_.last_stage_out_channels // 2] neck_num_heads = [4, 8, _base_.last_stage_out_channels // 2 // 32] -base_lr = 1e-3 +base_lr = 2e-3 weight_decay = 0.0005 train_batch_size_per_gpu = 16 load_from = '../FastDet/output_models/yolo_world_s_clip_t2i_bn_2e-3adamw_32xb16-100e_obj365v1_goldg_train-55b943ea_rep_conv.pth' diff --git a/configs/finetune_coco/yolo_world_v2_s_rep_vlpan_bn_sgd_1e-3_80e_8gpus_mask-refine_finetune_coco.py b/configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_5e-4_80e_8gpus_mask-refine_finetune_coco.py similarity index 96% rename from configs/finetune_coco/yolo_world_v2_s_rep_vlpan_bn_sgd_1e-3_80e_8gpus_mask-refine_finetune_coco.py rename to configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_5e-4_80e_8gpus_mask-refine_finetune_coco.py index 7a1c7236..6df7c4c5 100644 --- a/configs/finetune_coco/yolo_world_v2_s_rep_vlpan_bn_sgd_1e-3_80e_8gpus_mask-refine_finetune_coco.py +++ b/configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_5e-4_80e_8gpus_mask-refine_finetune_coco.py @@ -11,7 +11,7 @@ text_channels = 512 neck_embed_channels = [128, 256, _base_.last_stage_out_channels // 2] neck_num_heads = [4, 8, _base_.last_stage_out_channels // 2 // 32] -base_lr = 1e-3 +base_lr = 5e-4 weight_decay = 0.0005 train_batch_size_per_gpu = 16 load_from = '../FastDet/output_models/yolo_world_s_clip_t2i_bn_2e-3adamw_32xb16-100e_obj365v1_goldg_train-55b943ea_rep_conv.pth' @@ -32,11 +32,10 @@ 
image_model={{_base_.model.backbone}}, with_text_model=False), neck=dict(type='YOLOWorldPAFPN', - guide_channels=num_classes, + guide_channels=text_channels, embed_channels=neck_embed_channels, num_heads=neck_num_heads, - block_cfg=dict(type='RepConvMaxSigmoidCSPLayerWithTwoConv', - guide_channels=num_classes)), + block_cfg=dict(type='EfficientCSPLayerWithTwoConv')), bbox_head=dict(head_module=dict(type='RepYOLOWorldHeadModule', embed_dims=text_channels, num_guide=num_classes, @@ -140,7 +139,6 @@ batch_size_per_gpu=train_batch_size_per_gpu), constructor='YOLOWv5OptimizerConstructor') - # evaluation settings val_evaluator = dict(_delete_=True, type='mmdet.CocoMetric', diff --git a/configs/finetune_coco/yolo_world_v2_x_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py b/configs/finetune_coco/yolo_world_v2_x_rep_efficient_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py similarity index 82% rename from configs/finetune_coco/yolo_world_v2_x_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py rename to configs/finetune_coco/yolo_world_v2_x_rep_efficient_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py index ac872d55..1a6643b7 100644 --- a/configs/finetune_coco/yolo_world_v2_x_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py +++ b/configs/finetune_coco/yolo_world_v2_x_rep_efficient_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py @@ -1,6 +1,9 @@ -_base_ = ('../../third_party/mmyolo/configs/yolov8/' - 'yolov8_x_mask-refine_syncbn_fast_8xb16-500e_coco.py') -custom_imports = dict(imports=['yolo_world'], allow_failed_imports=False) +_base_ = ( + '../../third_party/mmyolo/configs/yolov8/' + 'yolov8_x_mask-refine_syncbn_fast_8xb16-500e_coco.py') +custom_imports = dict( + imports=['yolo_world'], + allow_failed_imports=False) # hyper-parameters num_classes = 80 @@ -11,12 +14,13 @@ text_channels = 512 neck_embed_channels = [128, 256, _base_.last_stage_out_channels // 2] neck_num_heads = [4, 8, _base_.last_stage_out_channels // 2 // 32] -base_lr = 
2e-3 +base_lr = 2e-4 weight_decay = 0.05 train_batch_size_per_gpu = 16 -load_from = '/group/40034/adriancheng/notebooks/rep_models/yolo_world_v2_x_obj365v1_goldg_cc3mlite_pretrain_1280ft-14996a36_repconv.pth' +load_from = '../YOLOWorld_Master/yolo_models/' persistent_workers = False + # model settings model = dict(type='SimpleYOLOWorldDetector', mm_neck=True, @@ -28,14 +32,12 @@ type='MultiModalYOLOBackbone', text_model=None, image_model={{_base_.model.backbone}}, - frozen_stages=4, with_text_model=False), neck=dict(type='YOLOWorldPAFPN', - guide_channels=num_classes, + guide_channels=text_channels, embed_channels=neck_embed_channels, num_heads=neck_num_heads, - block_cfg=dict(type='RepConvMaxSigmoidCSPLayerWithTwoConv', - guide_channels=num_classes)), + block_cfg=dict(type='EfficientCSPLayerWithTwoConv')), bbox_head=dict(head_module=dict(type='RepYOLOWorldHeadModule', embed_dims=text_channels, num_guide=num_classes, @@ -135,16 +137,6 @@ lr=base_lr, weight_decay=weight_decay, batch_size_per_gpu=train_batch_size_per_gpu), - paramwise_cfg=dict(bias_decay_mult=0.0, - norm_decay_mult=0.0, - custom_keys={ - 'backbone.text_model': - dict(lr_mult=0.01), - 'logit_scale': - dict(weight_decay=0.0), - 'embeddings': - dict(weight_decay=0.0) - }), constructor='YOLOWv5OptimizerConstructor') # evaluation settings @@ -153,4 +145,3 @@ proposal_nums=(100, 1, 10), ann_file='data/coco/annotations/instances_val2017.json', metric='bbox') -find_unused_parameters = True diff --git a/deploy/easydeploy/model/model.py b/deploy/easydeploy/model/model.py index 02301178..21cf50f7 100644 --- a/deploy/easydeploy/model/model.py +++ b/deploy/easydeploy/model/model.py @@ -28,12 +28,14 @@ def __init__(self, baseModel: nn.Module, backend: MMYOLOBackend, postprocess_cfg: Optional[ConfigDict] = None, - with_nms=True): + with_nms=True, + without_bbox_decoder=False): super().__init__() self.baseModel = baseModel self.baseHead = baseModel.bbox_head self.backend = backend self.with_nms = with_nms + 
self.without_bbox_decoder = without_bbox_decoder if postprocess_cfg is None: self.with_postprocess = False else: @@ -103,7 +105,8 @@ def pred_by_feat(self, bbox_decoder = yolox_bbox_decoder else: bbox_decoder = self.bbox_decoder - + print(bbox_decoder) + num_imgs = cls_scores[0].shape[0] featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores] @@ -112,7 +115,6 @@ def pred_by_feat(self, device=device) flatten_priors = torch.cat(mlvl_priors) - mlvl_strides = [ flatten_priors.new_full( (featmap_size[0] * featmap_size[1] * self.num_base_priors, ), @@ -121,8 +123,6 @@ def pred_by_feat(self, ] flatten_stride = torch.cat(mlvl_strides) - # flatten cls_scores, bbox_preds and objectness - # using score.shape text_len = cls_scores[0].shape[1] flatten_cls_scores = [ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, text_len) @@ -145,7 +145,9 @@ def pred_by_feat(self, cls_scores = cls_scores * (flatten_objectness.unsqueeze(-1)) scores = cls_scores - + bboxes = flatten_bbox_preds + if self.without_bbox_decoder: + return scores, bboxes bboxes = bbox_decoder(flatten_priors[None], flatten_bbox_preds, flatten_stride) diff --git a/deploy/export_onnx.py b/deploy/export_onnx.py index 990056f1..183c3938 100644 --- a/deploy/export_onnx.py +++ b/deploy/export_onnx.py @@ -37,6 +37,9 @@ def parse_args(): parser.add_argument('--without-nms', action='store_true', help='Expore model without NMS') + parser.add_argument('--without-bbox-decoder', + action='store_true', + help='Expore model without Bbox Decoder (for INT8 Quantization)') parser.add_argument('--work-dir', default='./work_dirs', help='Path to save export model') @@ -129,7 +132,8 @@ def main(): deploy_model = DeployModel(baseModel=baseModel, backend=backend, postprocess_cfg=postprocess_cfg, - with_nms=not args.without_nms) + with_nms=not args.without_nms, + without_bbox_decoder=args.without_bbox_decoder) deploy_model.eval() fake_input = torch.randn(args.batch_size, 3, diff --git a/docs/reparameterize.md 
b/docs/reparameterize.md index 74b26ece..9115783d 100644 --- a/docs/reparameterize.md +++ b/docs/reparameterize.md @@ -2,6 +2,10 @@ The reparameterization incorporates text embeddings as parameters into the model. For example, in the final classification layer, text embeddings are reparameterized into a simple 1x1 convolutional layer. +
+ +
+ ### Key Advantages from Reparameterization > Reparameterized YOLO-World still has zero-shot ability! @@ -15,13 +19,59 @@ For example, fine-tuning the **reparameterized YOLO-World** obtains *46.3 AP* on #### 1. Prepare cutstom text embeddings -You need to generate the text embeddings - +You need to generate the text embeddings by [`tools/generate_text_prompts.py`](../tools/generate_text_prompts.py) and save it as a `numpy.array` with shape `NxD`. #### 2. Reparameterizing +Reparameterizing will generate a new checkpoint with text embeddings! + +Check those files first: + +* model checkpoint +* text embeddings +We mainly reparameterize two groups of modules: + +* head (`YOLOWorldHeadModule`) +* neck (`MaxSigmoidCSPLayerWithTwoConv`) + +```bash +python tools/reparameterize_yoloworld.py \ + --model path/to/checkpoint \ + --out-dir path/to/save/re-parameterized/ \ + --text-embed path/to/text/embeddings \ + --conv-neck +``` #### 3. Prepare the model config +Please see the sample config: [`finetune_coco/yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py`](../configs/finetune_coco/yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) for reparameterized training. + + +* `RepConvMaxSigmoidCSPLayerWithTwoConv`: + +```python +neck=dict(type='YOLOWorldPAFPN', + guide_channels=num_classes, + embed_channels=neck_embed_channels, + num_heads=neck_num_heads, + block_cfg=dict(type='RepConvMaxSigmoidCSPLayerWithTwoConv', + guide_channels=num_classes)), +``` + +* `RepYOLOWorldHeadModule`: + +```python +bbox_head=dict(head_module=dict(type='RepYOLOWorldHeadModule', + embed_dims=text_channels, + num_guide=num_classes, + num_classes=num_classes)), + +``` + +#### 4. Reparameterized Training + +**Reparameterized YOLO-World** is easier to fine-tune and can be treated as an enhanced and pre-trained YOLOv8! 
+ +You can check [`finetune_coco/yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py`](../configs/finetune_coco/yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) for more details. \ No newline at end of file diff --git a/tools/reparameterize_yoloworld.py b/tools/reparameterize_yoloworld.py index c28a53bf..0257f637 100644 --- a/tools/reparameterize_yoloworld.py +++ b/tools/reparameterize_yoloworld.py @@ -49,6 +49,8 @@ def reparameterize_head(state_dict, embeds): def convert_neck_split_conv(input_state_dict, block_name, text_embeds, num_heads): + if block_name + '.guide_fc.weight' not in input_state_dict: + return input_state_dict guide_fc_weight = input_state_dict[block_name + '.guide_fc.weight'] guide_fc_bias = input_state_dict[block_name + '.guide_fc.bias'] guide = text_embeds @ guide_fc_weight.transpose(0, @@ -77,12 +79,15 @@ def convert_neck_weight(input_state_dict, block_name, embeds, num_heads): def reparameterize_neck(state_dict, embeds, type='conv'): + neck_blocks = [ 'neck.top_down_layers.0.attn_block', 'neck.top_down_layers.1.attn_block', 'neck.bottom_up_layers.0.attn_block', 'neck.bottom_up_layers.1.attn_block' ] + if "neck.top_down_layers.0.attn_block.bias" not in state_dict: + return state_dict for block in neck_blocks: num_heads = state_dict[block + '.bias'].shape[0] if type == 'conv': diff --git a/yolo_world/models/layers/yolo_bricks.py b/yolo_world/models/layers/yolo_bricks.py index ac16ed2e..0c39131c 100644 --- a/yolo_world/models/layers/yolo_bricks.py +++ b/yolo_world/models/layers/yolo_bricks.py @@ -542,7 +542,8 @@ def __init__(self, def forward(self, x: Tensor, guide: Tensor) -> Tensor: """Forward process.""" x = self.project_conv(x) - x = x * x.sigmoid() + # remove sigmoid + # x = x * x.sigmoid() return x