diff --git a/lightly/models/modules/ijepa.py b/lightly/models/modules/ijepa.py
index 3eb14a247..e07376623 100644
--- a/lightly/models/modules/ijepa.py
+++ b/lightly/models/modules/ijepa.py
@@ -72,7 +72,7 @@ def __init__(
         self.predictor_pos_embed = nn.Parameter(
             torch.zeros(1, num_patches, predictor_embed_dim), requires_grad=False
         )
-        predictor_pos_embed = _get_2d_sincos_pos_embed(
+        predictor_pos_embed = utils.get_2d_sincos_pos_embed(
             self.predictor_pos_embed.shape[-1], int(num_patches**0.5), cls_token=False
         )
         self.predictor_pos_embed.data.copy_(
diff --git a/lightly/models/utils.py b/lightly/models/utils.py
index 7215395a3..bd7c4e26d 100644
--- a/lightly/models/utils.py
+++ b/lightly/models/utils.py
@@ -7,9 +7,11 @@
 import warnings
 from typing import Iterable, List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from numpy.typing import NDArray
 from torch.nn import Module
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.parameter import Parameter
@@ -596,3 +598,119 @@ def repeat_interleave_batch(x, B, repeat):
         dim=0,
     )
     return x
+
+
+def get_2d_sincos_pos_embed(
+    embed_dim: int, grid_size: int, cls_token: bool = False
+) -> NDArray[np.float_]:
+    """Returns 2D sin-cos embeddings. Code from [0].
+
+    - [0]: https://github.com/facebookresearch/ijepa
+
+    Args:
+        embed_dim:
+            Embedding dimension.
+        grid_size:
+            Grid height and width. Should usually be set to sqrt(sequence length).
+        cls_token:
+            If True, a positional embedding for the class token is prepended to the returned
+            embeddings.
+
+    Returns:
+        Positional embeddings array with size (grid_size * grid_size, embed_dim) if cls_token is False.
+        If cls_token is True, a (1 + grid_size * grid_size, embed_dim) array is returned.
+    """
+    grid_h = np.arange(grid_size, dtype=float)
+    grid_w = np.arange(grid_size, dtype=float)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(
+    embed_dim: int, grid: NDArray[np.float_]
+) -> NDArray[np.float_]:
+    """Returns 2D sin-cos embeddings grid. Code from [0].
+
+    - [0]: https://github.com/facebookresearch/ijepa
+
+    Args:
+        embed_dim:
+            Embedding dimension.
+        grid:
+            2-dimensional grid to embed.
+
+    Returns:
+        Positional embeddings array with size (H * W, embed_dim).
+    """
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed(
+    embed_dim: int, grid_size: int, cls_token: bool = False
+) -> NDArray[np.float_]:
+    """Returns 1D sin-cos embeddings. Code from [0].
+
+    - [0]: https://github.com/facebookresearch/ijepa
+
+    Args:
+        embed_dim:
+            Embedding dimension.
+        grid_size:
+            Grid length. Should usually be set to the sequence length.
+        cls_token:
+            If True, a positional embedding for the class token is prepended to the returned
+            embeddings.
+
+    Returns:
+        Positional embeddings array with size (grid_size, embed_dim) if cls_token is False.
+        If cls_token is True, a (1 + grid_size, embed_dim) array is returned.
+    """
+    grid = np.arange(grid_size, dtype=float)
+    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_1d_sincos_pos_embed_from_grid(
+    embed_dim: int, pos: NDArray[np.float_]
+) -> NDArray[np.float_]:
+    """Returns 1D sin-cos embeddings grid. Code from [0].
+
+    - [0]: https://github.com/facebookresearch/ijepa
+
+    Args:
+        embed_dim:
+            Embedding dimension.
+        pos:
+            1-dimensional grid of positions to embed.
+
+    Returns:
+        Positional embeddings array with size (M, embed_dim), where M is the number of positions.
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=float)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb