Authors: Wooseok Gwak
This repository implements the Vision Transformer (Alexey Dosovitskiy et al., 2021) using TensorFlow 2. The paper can be found here.
The official JAX repository is here.
I recommend using Python 3 and a conda virtual environment.
conda create -n myenv python=3.7
conda activate myenv
conda install --yes --file requirements.txt
After creating the virtual environment, clone the repository and use the model in your own project. When you are done working on the project, deactivate the virtual environment with conda deactivate.
import numpy as np
import tensorflow as tf

from model.model import ViT

vit = ViT(
    d_model=50,
    mlp_dim=100,
    num_heads=10,
    dropout_rate=0.1,
    num_layers=3,
    patch_size=32,
    num_classes=102,
)

img = np.random.randn(1, 3, 256, 256).astype("float32")
preds = vit(img)
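With this configuration, a 256x256 input and patch_size = 32 correspond to (256 / 32)^2 = 64 patches per image, and preds should contain one score per class, i.e. a tensor of shape (1, 102).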
Because of a dependency problem with Anaconda packages, this repository uses TensorFlow 2.3 and implements multi-head attention from scratch (the code can be found here). If you are on TensorFlow 2.5 or later, I recommend using tf.keras.layers.MultiHeadAttention instead.
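For reference, a minimal sketch of a scaled dot-product multi-head self-attention layer that works on TensorFlow 2.3 is shown below. The class and argument names are illustrative only and do not necessarily match the implementation in this repository.

import tensorflow as tf

class MultiHeadSelfAttention(tf.keras.layers.Layer):
    """Minimal multi-head self-attention for TF 2.3 (illustrative sketch)."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.depth = d_model // num_heads
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.out = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, depth)
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, x):
        batch_size = tf.shape(x)[0]
        q = self.split_heads(self.wq(x), batch_size)
        k = self.split_heads(self.wk(x), batch_size)
        v = self.split_heads(self.wv(x), batch_size)

        # Scaled dot-product attention
        scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(
            tf.cast(self.depth, tf.float32))
        weights = tf.nn.softmax(scores, axis=-1)
        attention = tf.matmul(weights, v)

        # Merge heads back to (batch, seq_len, d_model)
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        concat = tf.reshape(attention, (batch_size, -1, self.num_heads * self.depth))
        return self.out(concat)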
python train.py
train.py is a sample training script used to verify that the model behaves as expected. You can modify it to train the model on a specific dataset.
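As an illustration, one possible adaptation to CIFAR-10 is sketched below. It assumes that ViT is a tf.keras.Model subclass returning logits; the hyperparameters, the channels-first transpose, and the loss configuration are assumptions rather than the settings actually used in train.py.

import tensorflow as tf
from model.model import ViT  # assumed import path, as in the usage example above

# Load a small benchmark dataset (CIFAR-10: 10 classes, 32x32 RGB images).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Match the channels-first layout of the usage example above; drop this
# transpose if the model expects channels-last input.
x_train = x_train.transpose(0, 3, 1, 2)
x_test = x_test.transpose(0, 3, 1, 2)

vit = ViT(
    d_model=64,
    mlp_dim=128,
    num_heads=8,
    dropout_rate=0.1,
    num_layers=4,
    patch_size=8,       # 32x32 images -> (32 / 8)^2 = 16 patches
    num_classes=10,
)

vit.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

vit.fit(x_train, y_train, batch_size=128, epochs=10,
        validation_data=(x_test, y_test))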
- 2021.11.30: the TensorFlow warning "WARNING:tensorflow:Gradients do not exist for variables" is not yet resolved.
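If you want to track down which variables trigger this warning, one possible diagnostic (a sketch, not part of the repository) is to run a single step under tf.GradientTape and print every variable whose gradient comes back as None; the loss configuration below is an assumption.

import numpy as np
import tensorflow as tf
from model.model import ViT  # assumed import path, as in the usage example above

# Build the model and a dummy batch matching the usage example above.
vit = ViT(d_model=50, mlp_dim=100, num_heads=10, dropout_rate=0.1,
          num_layers=3, patch_size=32, num_classes=102)
images = np.random.randn(2, 3, 256, 256).astype("float32")
labels = np.array([0, 1])

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
with tf.GradientTape() as tape:
    logits = vit(images, training=True)
    loss = loss_fn(labels, logits)

# Variables with a None gradient are the ones the warning refers to.
grads = tape.gradient(loss, vit.trainable_variables)
for var, grad in zip(vit.trainable_variables, grads):
    if grad is None:
        print("No gradient for:", var.name)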