Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wondervictor committed Mar 15, 2024
1 parent 83601a1 commit 323386a
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions yolo_world/models/backbones/mm_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,15 @@ def __init__(self,
image_model: ConfigType,
text_model: ConfigType,
frozen_stages: int = -1,
with_text_model: bool = True,
init_cfg: OptMultiConfig = None) -> None:

super().__init__(init_cfg)

self.with_text_model = with_text_model
self.image_model = MODELS.build(image_model)
self.text_model = MODELS.build(text_model)
if self.with_text_model:
self.text_model = MODELS.build(text_model)
else:
self.text_model = None
self.frozen_stages = frozen_stages
self._freeze_stages()

Expand All @@ -225,5 +228,8 @@ def train(self, mode: bool = True):
def forward(self, image: Tensor,
text: List[List[str]]) -> Tuple[Tuple[Tensor], Tensor]:
img_feats = self.image_model(image)
txt_feats = self.text_model(text)
return img_feats, txt_feats
if self.with_text_model:
txt_feats = self.text_model(text)
return img_feats, txt_feats
else:
return img_feats, None

0 comments on commit 323386a

Please sign in to comment.