From 7ef63f265bc1d167c19088f3272259fc69014308 Mon Sep 17 00:00:00 2001
From: In-Ho Yi
Date: Sun, 10 Apr 2022 16:35:32 -0400
Subject: [PATCH] Patch clip model for ONNX compatibility (#219)

* Patch clip model for ONNX compatibility

Changes to use INT32 for tokenization, since ONNX doesn't yet support ArgMax(INT64)

Use explicit dimension for norm

* Add compatibility fix for torch 1.7
---
 clip/clip.py  | 10 +++++++---
 clip/model.py |  4 ++--
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/clip/clip.py b/clip/clip.py
index 2c911d060..00abbc77c 100644
--- a/clip/clip.py
+++ b/clip/clip.py
@@ -192,7 +192,7 @@ def patch_float(module):
     return model, _transform(model.input_resolution.item())
 
 
-def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
+def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
     """
     Returns the tokenized representation of given input string(s)
 
@@ -209,7 +209,8 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: b
 
     Returns
     -------
-    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
+    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
     """
     if isinstance(texts, str):
         texts = [texts]
@@ -217,7 +218,10 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: b
     sot_token = _tokenizer.encoder["<|startoftext|>"]
     eot_token = _tokenizer.encoder["<|endoftext|>"]
     all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
-    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
+        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+    else:
+        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
 
     for i, tokens in enumerate(all_tokens):
         if len(tokens) > context_length:
diff --git a/clip/model.py b/clip/model.py
index f7958f171..e743d2c78 100644
--- a/clip/model.py
+++ b/clip/model.py
@@ -356,8 +356,8 @@ def forward(self, image, text):
         text_features = self.encode_text(text)
 
         # normalized features
-        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
-        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+        image_features = image_features / image_features.norm(dim=1, keepdim=True)
+        text_features = text_features / text_features.norm(dim=1, keepdim=True)
 
         # cosine similarity as logits
         logit_scale = self.logit_scale.exp()
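
For illustration, below is a minimal sketch (not part of the patch) of how the int32 tokens produced by the patched tokenize() would feed an ONNX export of the text tower. The TextEncoder wrapper class, the checkpoint name, the output filename, and the opset version are assumptions made for this example, not code from this repository; encode_text() takes an argmax over the token tensor, which is the ArgMax(INT64) case the commit message says ONNX does not yet support, so emitting int32 tokens sidesteps it.

import torch
import clip

# On torch >= 1.8 the patched tokenize() yields int32 tokens (int64 on older torch),
# so the argmax inside encode_text() traces with int32 inputs during export.
tokens = clip.tokenize(["a diagram", "a dog", "a cat"])

model, _ = clip.load("ViT-B/32", device="cpu", jit=False)

class TextEncoder(torch.nn.Module):
    """Hypothetical wrapper exposing only the text encoder for export."""
    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model

    def forward(self, text):
        return self.clip_model.encode_text(text)

torch.onnx.export(
    TextEncoder(model).eval(),
    (tokens,),
    "clip_text.onnx",              # assumed output path
    input_names=["text"],
    output_names=["text_features"],
    opset_version=12,              # assumed opset
)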