Patch clip model for ONNX compatibility (openai#219)
* Patch clip model for ONNX compatibility

  Change tokenization to use INT32, since ONNX doesn't yet support ArgMax(INT64).
  Use an explicit dimension for norm.

* Add compatibility fix for torch 1.7
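For context, here is a minimal sketch (not part of this commit) of the kind of ONNX export these changes unblock. The model name, opset version, and output file are illustrative assumptions.

```python
# Sketch only: names, opset, and file path are assumptions, not part of this commit.
import clip
import torch

model, _ = clip.load("ViT-B/32", device="cpu")  # fp32 weights on CPU
image = torch.randn(1, 3, 224, 224)
text = clip.tokenize(["a diagram", "a dog"])  # int32 token ids on torch >= 1.8.0

# encode_text locates the end-of-text embedding with an argmax over the token
# ids, so int32 ids keep ArgMax over INT64 inputs out of the exported graph.
torch.onnx.export(
    model,
    (image, text),
    "clip.onnx",
    input_names=["image", "text"],
    output_names=["logits_per_image", "logits_per_text"],
    opset_version=14,
)
```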
chajath authored Apr 10, 2022
1 parent 40f5484 commit 7ef63f2
Showing 2 changed files with 9 additions and 5 deletions.
10 changes: 7 additions & 3 deletions clip/clip.py
```diff
@@ -192,7 +192,7 @@ def patch_float(module):
     return model, _transform(model.input_resolution.item())
 
 
-def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
+def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
     """
     Returns the tokenized representation of given input string(s)
 
@@ -209,15 +209,19 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
 
     Returns
     -------
-    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
+    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
     """
     if isinstance(texts, str):
         texts = [texts]
 
     sot_token = _tokenizer.encoder["<|startoftext|>"]
     eot_token = _tokenizer.encoder["<|endoftext|>"]
     all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
-    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
+        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+    else:
+        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
 
     for i, tokens in enumerate(all_tokens):
         if len(tokens) > context_length:
```
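A quick check of the new tokenize behavior (a hedged sketch; note the hunk assumes `packaging` is importable inside clip.py, e.g. via `from pkg_resources import packaging`):

```python
import clip

tokens = clip.tokenize(["hello world", "a photo of a cat"])
print(tokens.shape)  # torch.Size([2, 77])
# int32 on torch >= 1.8.0; int64 (long) on older torch, whose
# index_select still requires long indices
print(tokens.dtype)
```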
4 changes: 2 additions & 2 deletions clip/model.py
```diff
@@ -356,8 +356,8 @@ def forward(self, image, text):
         text_features = self.encode_text(text)
 
         # normalized features
-        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
-        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+        image_features = image_features / image_features.norm(dim=1, keepdim=True)
+        text_features = text_features / text_features.norm(dim=1, keepdim=True)
 
         # cosine similarity as logits
         logit_scale = self.logit_scale.exp()
```
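Both feature tensors here are 2-D ([batch, embed_dim]), so `dim=1` and `dim=-1` address the same axis and the outputs are unchanged; the explicit positive index simply spares the ONNX exporter from resolving a negative axis. A small sketch confirming the equivalence:

```python
import torch

features = torch.randn(4, 512)  # [batch, embed_dim], as in CLIP's forward()
normed_neg = features / features.norm(dim=-1, keepdim=True)
normed_pos = features / features.norm(dim=1, keepdim=True)
assert torch.allclose(normed_neg, normed_pos)  # identical for 2-D tensors
```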
