Skip to content

Commit

Permalink
Fix casting of emb_matrix if it is not loaded.
Browse files Browse the repository at this point in the history
  • Loading branch information
kohjingyu committed Jul 17, 2023
1 parent c7de07a commit d85ad06
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion gill/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,7 @@ def load_gill(model_dir: str, load_ret_embs: bool = True) -> GILL:
assert model_kwargs['share_ret_gen'], 'Model loading only supports share_ret_gen=True for now.'
model.model.input_embeddings.weight[-model_kwargs['num_tokens']:, :].copy_(img_token_embeddings)

if len(embs_paths) > 0:
if load_ret_embs and len(embs_paths) > 0:
logit_scale = model.model.logit_scale.exp()
emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)
emb_matrix = emb_matrix / emb_matrix.norm(dim=1, keepdim=True)
Expand Down

0 comments on commit d85ad06

Please sign in to comment.