ViT implementation (with patchify and conv stem) (facebookresearch#153)
Summary:
References:
    https://arxiv.org/abs/2010.11929 (An Image is Worth 16x16 Words)
    https://arxiv.org/abs/2106.14881 (Early Convolutions Help Transformers See Better)
    https://github.com/google-research/vision_transformer

Pull Request resolved: facebookresearch#153

Reviewed By: pdollar

Differential Revision: D29593362

Pulled By: Tete-Xiao

fbshipit-source-id: caaeb5f912173b618ba0fa1d219445c06d2de6e8
Tete Xiao authored and facebook-github-bot committed Jul 7, 2021
1 parent b0316d8 commit d74b462
Showing 5 changed files with 416 additions and 7 deletions.
10 changes: 9 additions & 1 deletion pycls/core/builders.py
@@ -13,10 +13,18 @@
 from pycls.models.effnet import EffNet
 from pycls.models.regnet import RegNet
 from pycls.models.resnet import ResNet
+from pycls.models.vit import ViT


 # Supported models
-_models = {"anynet": AnyNet, "effnet": EffNet, "resnet": ResNet, "regnet": RegNet}
+_models = {
+    "anynet": AnyNet,
+    "effnet": EffNet,
+    "resnet": ResNet,
+    "regnet": RegNet,
+    "vit": ViT,
+}


 # Supported loss functions
 _loss_funs = {"cross_entropy": SoftCrossEntropyLoss}
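
With "vit" registered in _models, a ViT is constructed through the usual pycls builder path. A minimal usage sketch, assuming (as in existing pycls code, not part of this diff) that cfg.MODEL.TYPE selects the entry from _models and that pycls.core.builders exposes build_model():

from pycls.core.config import cfg
import pycls.core.builders as builders

cfg.MODEL.TYPE = "vit"          # pick the newly registered model class
cfg.MODEL.NUM_CLASSES = 1000    # e.g. ImageNet
model = builders.build_model()  # ViT instance configured from the VIT.* options below
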
34 changes: 34 additions & 0 deletions pycls/core/config.py
@@ -172,6 +172,40 @@
 _C.EN.DROPOUT_RATIO = 0.0


+# ---------------------------- Vision Transformer options ---------------------------- #
+_C.VIT = CfgNode()
+
+# Patch Size (TRAIN.IM_SIZE must be divisible by PATCH_SIZE)
+_C.VIT.PATCH_SIZE = 16
+
+# Type of stem select from {'patchify', 'conv'}
+_C.VIT.STEM_TYPE = "patchify"
+
+# C-stem conv kernel sizes (https://arxiv.org/abs/2106.14881)
+_C.VIT.C_STEM_KERNELS = []
+
+# C-stem conv strides (the product of which must equal PATCH_SIZE)
+_C.VIT.C_STEM_STRIDES = []
+
+# C-stem conv output dims (last dim must equal HIDDEN_DIM)
+_C.VIT.C_STEM_DIMS = []
+
+# Number of layers in the encoder
+_C.VIT.NUM_LAYERS = 12
+
+# Number of self attention heads
+_C.VIT.NUM_HEADS = 12
+
+# Hidden dimension
+_C.VIT.HIDDEN_DIM = 768
+
+# Dimension of the MLP in the encoder
+_C.VIT.MLP_DIM = 3072
+
+# Type of classifier select from {'token', 'pooled'}
+_C.VIT.CLASSIFIER_TYPE = "token"
+
+
 # -------------------------------- Batch norm options -------------------------------- #
 _C.BN = CfgNode()

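
The comments encode two coupling constraints: the C-stem strides must multiply out to PATCH_SIZE, and the last C-stem dim must equal HIDDEN_DIM. A hedged example of a ViT-B/16 with the convolutional stem of the second reference; the specific kernel/dim values are illustrative, not taken from this commit:

import math

from pycls.core.config import cfg

# ViT-B encoder settings (these match the defaults above)
cfg.VIT.NUM_LAYERS, cfg.VIT.NUM_HEADS = 12, 12
cfg.VIT.HIDDEN_DIM, cfg.VIT.MLP_DIM = 768, 3072

# Convolutional stem variant
cfg.VIT.STEM_TYPE = "conv"
cfg.VIT.C_STEM_KERNELS = [3, 3, 3, 3, 1]
cfg.VIT.C_STEM_STRIDES = [2, 2, 2, 2, 1]
cfg.VIT.C_STEM_DIMS = [64, 128, 256, 512, 768]

# Both stated constraints hold for this setting
assert math.prod(cfg.VIT.C_STEM_STRIDES) == cfg.VIT.PATCH_SIZE  # 2*2*2*2*1 == 16
assert cfg.VIT.C_STEM_DIMS[-1] == cfg.VIT.HIDDEN_DIM            # 768
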
74 changes: 69 additions & 5 deletions pycls/models/blocks.py
@@ -24,6 +24,11 @@ def conv2d(w_in, w_out, k, *, stride=1, groups=1, bias=False):
     return nn.Conv2d(w_in, w_out, k, stride=s, padding=p, groups=g, bias=b)


+def patchify2d(w_in, w_out, k, *, bias=True):
+    """Helper for building a patchify layer as used by ViT models."""
+    return nn.Conv2d(w_in, w_out, k, stride=k, padding=0, bias=bias)
+
+
 def norm2d(w_in):
     """Helper for building a norm2d layer."""
     return nn.BatchNorm2d(num_features=w_in, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
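
patchify2d is a stride-k convolution, so each non-overlapping k x k patch is mapped independently to a w_out-dimensional embedding. A quick shape sketch (the flatten/transpose step is the usual way such output becomes a token sequence; it is illustrative, not code from this commit):

import torch

stem = patchify2d(3, 768, 16)    # nn.Conv2d(3, 768, 16, stride=16, padding=0)
x = torch.zeros(2, 3, 224, 224)  # two 224x224 RGB images
patches = stem(x)                # (2, 768, 14, 14): a 14x14 grid of patch embeddings
tokens = patches.flatten(2).transpose(1, 2)  # (2, 196, 768) sequence for the encoder
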
@@ -40,21 +45,28 @@ def gap2d(_w_in):
     return nn.AdaptiveAvgPool2d((1, 1))


+def layernorm(w_in):
+    """Helper for building a layernorm layer."""
+    return nn.LayerNorm(w_in, eps=cfg.LN.EPS)
+
+
 def linear(w_in, w_out, *, bias=False):
     """Helper for building a linear layer."""
     return nn.Linear(w_in, w_out, bias=bias)


-def activation():
+def activation(activation_fun=None):
     """Helper for building an activation layer."""
-    activation_fun = cfg.MODEL.ACTIVATION_FUN.lower()
+    activation_fun = (activation_fun or cfg.MODEL.ACTIVATION_FUN).lower()
     if activation_fun == "relu":
         return nn.ReLU(inplace=cfg.MODEL.ACTIVATION_INPLACE)
     elif activation_fun == "silu" or activation_fun == "swish":
         try:
             return torch.nn.SiLU()
         except AttributeError:
             return SiLU()
+    elif activation_fun == "gelu":
+        return torch.nn.GELU()
     else:
         raise AssertionError("Unknown MODEL.ACTIVATION_FUN: " + activation_fun)

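The optional argument lets ViT code request GELU explicitly while existing models keep reading cfg.MODEL.ACTIVATION_FUN. A small illustrative sketch, not from the diff:

act = activation("gelu")  # explicit choice -> torch.nn.GELU()
default = activation()    # falls back to cfg.MODEL.ACTIVATION_FUN, as before
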
@@ -73,6 +85,18 @@ def conv2d_cx(cx, w_in, w_out, k, *, stride=1, groups=1, bias=False):
     return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}


+def patchify2d_cx(cx, w_in, w_out, k, *, bias=True):
+    """Accumulates complexity of patchify2d into cx = (h, w, flops, params, acts)."""
+    err_str = "Only kernel sizes divisible by the input size are supported."
+    assert cx["h"] % k == 0 and cx["w"] % k == 0, err_str
+    h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
+    h, w = h // k, w // k
+    flops += k * k * w_in * w_out * h * w + (w_out * h * w if bias else 0)
+    params += k * k * w_in * w_out + (w_out if bias else 0)
+    acts += w_out * h * w
+    return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
+
+
 def norm2d_cx(cx, w_in):
     """Accumulates complexity of norm2d into cx = (h, w, flops, params, acts)."""
     h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
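
For the standard ViT-B/16 patchify stem on a 224x224 RGB input, the accounting works out as follows (a worked example, not part of the diff):

cx = {"h": 224, "w": 224, "flops": 0, "params": 0, "acts": 0}
cx = patchify2d_cx(cx, w_in=3, w_out=768, k=16)
# h, w:   224 // 16 == 14 (a 14x14 grid of patches)
# params: 16 * 16 * 3 * 768 + 768 == 590_592
# flops:  16 * 16 * 3 * 768 * 14 * 14 + 768 * 14 * 14 == 115_756_032 (~0.12 GF)
# acts:   768 * 14 * 14 == 150_528
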
@@ -95,12 +119,19 @@ def gap2d_cx(cx, _w_in):
     return {"h": 1, "w": 1, "flops": flops, "params": params, "acts": acts}


-def linear_cx(cx, w_in, w_out, *, bias=False):
+def layernorm_cx(cx, w_in):
+    """Accumulates complexity of layernorm into cx = (h, w, flops, params, acts)."""
+    h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
+    params += 2 * w_in
+    return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
+
+
+def linear_cx(cx, w_in, w_out, *, bias=False, num_locations=1):
     """Accumulates complexity of linear into cx = (h, w, flops, params, acts)."""
     h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
-    flops += w_in * w_out + (w_out if bias else 0)
+    flops += w_in * w_out * num_locations + (w_out * num_locations if bias else 0)
     params += w_in * w_out + (w_out if bias else 0)
-    acts += w_out
+    acts += w_out * num_locations
     return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}


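The new num_locations argument distinguishes a linear layer applied at every sequence position (as in a transformer MLP) from one applied once (as in a classifier head): parameters are shared across positions, while flops and activations scale with sequence length. An illustrative check for a ViT-B MLP expansion layer at seq_len = 197 (196 patches + class token):

cx = {"h": 14, "w": 14, "flops": 0, "params": 0, "acts": 0}
cx = linear_cx(cx, 768, 3072, bias=True, num_locations=197)
# params: 768 * 3072 + 3072          ==   2_362_368  (independent of seq_len)
# flops:  (768 * 3072 + 3072) * 197  == 465_386_496
# acts:   3072 * 197                 ==     605_184
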
@@ -145,6 +176,39 @@ def complexity(cx, w_in, w_se):
         return cx


+class MultiheadAttention(Module):
+    """Multi-head Attention block from Transformer models."""
+
+    def __init__(self, hidden_d, n_heads):
+        super(MultiheadAttention, self).__init__()
+        self.block = nn.MultiheadAttention(hidden_d, n_heads)
+
+    def forward(self, query, key, value, need_weights=False):
+        return self.block(query=query, key=key, value=value, need_weights=need_weights)
+
+    @staticmethod
+    def complexity(cx, hidden_d, n_heads, seq_len):
+        # See https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py
+        h, w = cx["h"], cx["w"]
+        flops, params, acts = cx["flops"], cx["params"], cx["acts"]
+        # q, k, v = linear(input).chunk(3)
+        flops += seq_len * (hidden_d * hidden_d * 3 + hidden_d * 3)
+        params += hidden_d * hidden_d * 3 + hidden_d * 3
+        acts += hidden_d * 3 * seq_len
+        # attn_output_weights = torch.bmm(q, k.transpose)
+        head_d = hidden_d // n_heads
+        flops += n_heads * (seq_len * head_d * seq_len)
+        acts += n_heads * seq_len * seq_len
+        # attn_output = torch.bmm(attn_output_weights, v)
+        flops += n_heads * (seq_len * seq_len * head_d)
+        acts += n_heads * seq_len * head_d
+        # attn_output = linear(attn_output)
+        flops += seq_len * (hidden_d * hidden_d + hidden_d)
+        params += hidden_d * hidden_d + hidden_d
+        acts += hidden_d * seq_len
+        return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
+
+
 # ---------------------------------- Miscellaneous ----------------------------------- #

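
The parameter count here can be checked directly against nn.MultiheadAttention (in_proj weight/bias plus out_proj weight/bias), and the two bmm terms are the part that grows quadratically with sequence length. A sanity-check sketch for ViT-B at 224x224 (hidden_d=768, n_heads=12, seq_len=197); the numbers are illustrative, not from the commit:

import torch.nn as nn

cx = {"h": 14, "w": 14, "flops": 0, "params": 0, "acts": 0}
cx = MultiheadAttention.complexity(cx, hidden_d=768, n_heads=12, seq_len=197)

# params: 3 * 768 * 768 + 3 * 768 + 768 * 768 + 768 == 2_362_368
assert cx["params"] == sum(p.numel() for p in nn.MultiheadAttention(768, 12).parameters())
# flops: qkv projection + attention scores + weighted sum + output projection
#   349_039_872 + 29_805_312 + 29_805_312 + 116_346_624 == 524_997_120 (~0.52 GF)
# Softmax and the 1/sqrt(head_d) scaling are not counted, consistent with the
# matmul-only accounting above.
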
