Skip to content

Commit

Permalink
Move positional embedding functions from ijepa to utils (#1322)
Browse files Browse the repository at this point in the history
* Add type hints
* Cleanup docstrings
---------

Co-authored-by: guarin <guarin@lightly.ai>
  • Loading branch information
Natyren and guarin authored Jul 17, 2023
1 parent 8ed30b7 commit 615a44f
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 1 deletion.
2 changes: 1 addition & 1 deletion lightly/models/modules/ijepa.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(
self.predictor_pos_embed = nn.Parameter(
torch.zeros(1, num_patches, predictor_embed_dim), requires_grad=False
)
predictor_pos_embed = _get_2d_sincos_pos_embed(
predictor_pos_embed = utils.get_2d_sincos_pos_embed(
self.predictor_pos_embed.shape[-1], int(num_patches**0.5), cls_token=False
)
self.predictor_pos_embed.data.copy_(
Expand Down
118 changes: 118 additions & 0 deletions lightly/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import warnings
from typing import Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from numpy.typing import NDArray
from torch.nn import Module
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parameter import Parameter
Expand Down Expand Up @@ -596,3 +598,119 @@ def repeat_interleave_batch(x, B, repeat):
dim=0,
)
return x


def get_2d_sincos_pos_embed(
embed_dim: int, grid_size: int, cls_token: bool = False
) -> NDArray[np.float_]:
"""Returns 2D sin-cos embeddings. Code from [0].
- [0]: https://github.com/facebookresearch/ijepa
Args:
embed_dim:
Embedding dimension.
grid_size:
Grid height and width. Should usually be set to sqrt(sequence length).
cls_token:
If True, a positional embedding for the class token is prepended to the returned
embeddings.
Returns:
Positional embeddings array with size (grid_size * grid_size, embed_dim) if cls_token is False.
If cls_token is True, a (1 + grid_size * grid_size, embed_dim) array is returned.
"""
grid_h = np.arange(grid_size, dtype=float)
grid_w = np.arange(grid_size, dtype=float)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)

grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed


def get_2d_sincos_pos_embed_from_grid(
embed_dim: int, grid: NDArray[np.int_]
) -> NDArray[np.float_]:
"""Returns 2D sin-cos embeddings grid. Code from [0].
- [0]: https://github.com/facebookresearch/ijepa
Args:
embed_dim:
Embedding dimension.
grid:
2-dimensional grid to embed.
Returns:
Positional embeddings array with size (grid_size * grid_size, embed_dim).
"""
assert embed_dim % 2 == 0

# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)

emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb


def get_1d_sincos_pos_embed(
embed_dim: int, grid_size: int, cls_token: bool = False
) -> NDArray[np.float_]:
"""Returns 1D sin-cos embeddings. Code from [0].
- [0]: https://github.com/facebookresearch/ijepa
Args:
embed_dim:
Embedding dimension.
grid_size:
Grid height and width. Should usually be set to sqrt(sequence length).
cls_token:
If True, a positional embedding for the class token is prepended to the returned
embeddings.
Returns:
Positional embeddings array with size (grid_size, embed_dim) if cls_token is False.
If cls_token is True, a (1 + grid_size, embed_dim) array is returned.
"""
grid = np.arange(grid_size, dtype=float)
pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed


def get_1d_sincos_pos_embed_from_grid(
embed_dim: int, pos: NDArray[np.int_]
) -> NDArray[np.float_]:
"""Returns 1D sin-cos embeddings grid. Code from [0].
- [0]: https://github.com/facebookresearch/ijepa
Args:
embed_dim:
Embedding dimension.
pos:
1-dimensional grid to embed.
Returns:
Positional embeddings array with size (grid_size, embed_dim).
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=float)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)

pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product

emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)

emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb

0 comments on commit 615a44f

Please sign in to comment.