
lumo-tech/TokenLearner-pytorch


TokenLearner-pytorch

Unofficial reimplementation of TokenLearner by Google AI.

Only the Vision Transformer (ViT) variant is implemented.
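The core idea of TokenLearner can be sketched without any framework. Each of S output tokens is a weighted average of the N input tokens, and the weights come from a learned attention map over the N positions. The sketch below is a minimal, dependency-free illustration, not this repository's code. It assumes the attention logits come from a single linear projection `weights` (a hypothetical stand-in; the paper uses small MLP or convolutional layers) and normalizes them with a softmax over the input positions, roughly in the spirit of TokenLearner v1.1. The original formulation instead gates the input with sigmoid attention maps and applies spatial average pooling.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def token_learner(tokens, weights):
    """Pool N input tokens into S learned tokens.

    tokens:  N x C nested list (input tokens)
    weights: C x S nested list, a hypothetical single linear layer that
             scores every input token for every output slot
    returns: S x C nested list (learned tokens)
    """
    n, c, s = len(tokens), len(tokens[0]), len(weights[0])
    # logits[i][j]: score of input token i for output token j
    logits = [
        [sum(tokens[i][k] * weights[k][j] for k in range(c)) for j in range(s)]
        for i in range(n)
    ]
    out = []
    for j in range(s):
        # attention map over the N input positions for output token j
        attn = softmax([logits[i][j] for i in range(n)])
        # weighted average of the input tokens under that map
        out.append([sum(attn[i] * tokens[i][k] for i in range(n)) for k in range(c)])
    return out


tokens = [[float(i), float(2 * i)] for i in range(4)]  # N=4 tokens, C=2
zero_w = [[0.0] * 3 for _ in range(2)]                 # C=2 -> S=3, all-zero scores
print(token_learner(tokens, zero_w))  # [[1.5, 3.0], [1.5, 3.0], [1.5, 3.0]]
```

With all-zero scores every attention map is uniform, so each learned token collapses to the mean of the inputs; trained weights let different output tokens focus on different input positions.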

Usage

import torch
from models.vit_tokenlearner import ViT

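v = ViT(
    image_size=256,        # input resolution (square images)
    num_tokens=8,          # number of tokens produced by TokenLearner
    fuse=False,            # presumably toggles TokenFuser (name taken from the paper)
    v11=True,              # presumably selects the TokenLearner v1.1 module
    tokenlearner_loc=3,    # presumably the layer index at which TokenLearner is inserted
    patch_size=16,         # 256 / 16 = 16 patches per side, 256 patch tokens in total
    hidden_size=768,       # token embedding dimension
    depth=6,               # number of transformer layers
    heads=16,              # attention heads per layer
    mlp_dim=2048,          # hidden width of each MLP block
    dropout=0.1,
    emb_dropout=0.1        # dropout applied to the patch embeddings
)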

img = torch.randn(1, 3, 256, 256)  # dummy batch: one RGB image at image_size x image_size

preds = v(img)  # (1, 1000)
print(preds.shape)
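The arithmetic behind reducing the sequence to `num_tokens` tokens can be worked through for the configuration above. With `image_size=256` and `patch_size=16`, the image is cut into (256/16)^2 = 256 patches. The helper below is a back-of-the-envelope sketch, not part of this repository. It makes several hypothetical assumptions: the first `tokenlearner_loc` layers attend over all patch tokens, the remaining `depth - tokenlearner_loc` layers attend over only `num_tokens` tokens (the `fuse=False` case), any class token is ignored, and only the quadratic term of self-attention is counted.

```python
def attention_cost_ratio(image_size, patch_size, num_tokens, depth, tokenlearner_loc):
    """Return (num_patch_tokens, reduced_cost / full_cost) for the quadratic
    self-attention term.

    Hypothetical cost model: layers [0, tokenlearner_loc) see every patch
    token, layers [tokenlearner_loc, depth) see only num_tokens learned tokens.
    """
    patches = (image_size // patch_size) ** 2
    # baseline: every layer attends over all patch tokens
    full = depth * patches ** 2
    # with TokenLearner: full cost before the insertion point, tiny cost after
    reduced = tokenlearner_loc * patches ** 2 + (depth - tokenlearner_loc) * num_tokens ** 2
    return patches, reduced / full


patches, ratio = attention_cost_ratio(256, 16, num_tokens=8, depth=6, tokenlearner_loc=3)
print(patches, ratio)  # 256 0.50048828125
```

Under this model the layers placed before TokenLearner dominate the cost, so the ratio is close to `tokenlearner_loc / depth`; inserting TokenLearner earlier shrinks it further.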

Reference
