diff --git a/pytorch_pretrained_vit/__init__.py b/pytorch_pretrained_vit/__init__.py index a303bf7..8767bc1 100644 --- a/pytorch_pretrained_vit/__init__.py +++ b/pytorch_pretrained_vit/__init__.py @@ -1,5 +1,5 @@ -__version__ = "0.0.6" +__version__ = "0.0.7" from .model import ViT from .configs import * -from .utils import load_pretrained_weights \ No newline at end of file +from .utils import load_pretrained_weights diff --git a/pytorch_pretrained_vit/utils.py b/pytorch_pretrained_vit/utils.py index 0ef76fc..da40561 100755 --- a/pytorch_pretrained_vit/utils.py +++ b/pytorch_pretrained_vit/utils.py @@ -1,6 +1,7 @@ """utils.py - Helper functions """ +import numpy as np import torch from torch.utils import model_zoo @@ -107,6 +108,7 @@ def resize_positional_embedding_(posemb, posemb_new, has_class_token=True): zoom_factor = (gs_new / gs_old, gs_new / gs_old, 1) posemb_grid = zoom(posemb_grid, zoom_factor, order=1) posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) + posemb_grid = torch.from_numpy(posemb_grid) # Deal with class token and return posemb = torch.cat([posemb_tok, posemb_grid], dim=1) diff --git a/setup.py b/setup.py index 24f6112..0baab1a 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ EMAIL = 'luke.melas@gmail.com' AUTHOR = 'Luke' REQUIRES_PYTHON = '>=3.5.0' -VERSION = '0.0.6' +VERSION = '0.0.7' # What packages are required for this module to be executed? REQUIRED = [