"""
transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout = 0.)
Transformer as used in the base ViT architecture.
([reference](https://arxiv.org/abs/2010.11929)).
# Arguments
- `planes`: number of input channels
- `depth`: number of attention blocks
- `nheads`: number of attention heads
- `mlp_ratio`: ratio of MLP layers to the number of input channels
- `dropout`: dropout rate
"""
function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout = 0.0)
    # Each block pairs pre-norm multi-head self-attention with a pre-norm MLP,
    # both wrapped in residual (skip) connections.
    layers = [Chain(SkipConnection(prenorm(planes,
                                           MHAttention(planes, nheads; attn_drop = dropout,
                                                       proj_drop = dropout)), +),
                    SkipConnection(prenorm(planes,
                                           mlp_block(planes, floor(Int, mlp_ratio * planes);
                                                     dropout)), +))
              for _ in 1:depth]
    return Chain(layers)
end
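# A minimal usage sketch (hypothetical sizes; assumes Flux and the helper layers
# above are in scope). Every block is residual, so the encoder maps a
# (planes, npatches, batchsize) array to an array of the same shape:
#
#   enc = transformer_encoder(64, 2, 4)
#   x = rand(Float32, 64, 17, 2)       # 64 channels, 16 patches + class token, batch of 2
#   @assert size(enc(x)) == (64, 17, 2)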
"""
vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16),
embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout = 0.1,
emb_dropout = 0.1, pool = :class, nclasses = 1000)
Creates a Vision Transformer (ViT) model.
([reference](https://arxiv.org/abs/2010.11929)).
# Arguments
- `imsize`: image size
- `inchannels`: number of input channels
- `patch_size`: size of the patches
- `embedplanes`: the number of channels after the patch embedding
- `depth`: number of blocks in the transformer
- `nheads`: number of attention heads in the transformer
- `mlpplanes`: number of hidden channels in the MLP block in the transformer
- `dropout`: dropout rate
- `emb_dropout`: dropout rate for the positional embedding layer
- `pool`: pooling type, either :class or :mean
- `nclasses`: number of classes in the output
"""
function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16),
             embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout = 0.1,
             emb_dropout = 0.1, pool = :class, nclasses = 1000)
    @assert pool in [:class, :mean] "Pool type must be either :class (class token) or :mean (mean pooling)"
    npatches = prod(imsize .÷ patch_size)
    return Chain(Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes),
                       ClassTokens(embedplanes),
                       ViPosEmbedding(embedplanes, npatches + 1),
                       Dropout(emb_dropout),
                       transformer_encoder(embedplanes, depth, nheads; mlp_ratio, dropout),
                       # pool by taking the class token, or by averaging over the patch dimension
                       (pool == :class) ? x -> x[:, 1, :] : seconddimmean),
                 Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast)))
end
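# A minimal usage sketch (hypothetical sizes): with the defaults, a 256×256 image
# yields (256 ÷ 16)^2 = 256 patches, so the encoder sees 257 tokens including the
# class token, and the model maps a WHCN image batch to per-class outputs:
#
#   model = vit((256, 256); nclasses = 10)
#   x = rand(Float32, 256, 256, 3, 2)  # width × height × channels × batch
#   @assert size(model(x)) == (10, 2)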
# Standard ViT configurations. The rational `mlp_ratio`s for :giant and :gigantic
# yield integer MLP widths: 1408 * 48//11 == 6144 and 1664 * 64//13 == 8192.
vit_configs = Dict(:tiny => (depth = 12, embedplanes = 192, nheads = 3),
                   :small => (depth = 12, embedplanes = 384, nheads = 6),
                   :base => (depth = 12, embedplanes = 768, nheads = 12),
                   :large => (depth = 24, embedplanes = 1024, nheads = 16),
                   :huge => (depth = 32, embedplanes = 1280, nheads = 16),
                   :giant => (depth = 40, embedplanes = 1408, nheads = 16,
                              mlp_ratio = 48 // 11),
                   :gigantic => (depth = 48, embedplanes = 1664, nheads = 16,
                                 mlp_ratio = 64 // 13))
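# For illustration, a configuration is a named tuple that splats into `vit` as
# keyword arguments:
#
#   cfg = vit_configs[:tiny]           # (depth = 12, embedplanes = 192, nheads = 3)
#   m = vit((256, 256); nclasses = 10, cfg...)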
"""
ViT(mode::Symbol = base; imsize::Dims{2} = (256, 256), inchannels = 3,
patch_size::Dims{2} = (16, 16), pool = :class, nclasses = 1000)
Creates a Vision Transformer (ViT) model.
([reference](https://arxiv.org/abs/2010.11929)).
# Arguments
- `mode`: the model configuration, one of [:tiny, :small, :base, :large, :huge, :giant, :gigantic]
- `imsize`: image size
- `inchannels`: number of input channels
- `patch_size`: size of the patches
- `pool`: pooling type, either :class or :mean
- `nclasses`: number of classes in the output
See also [`Metalhead.vit`](#).
"""
struct ViT
layers::Any
end
function ViT(mode::Symbol = :base; imsize::Dims{2} = (256, 256), inchannels = 3,
             patch_size::Dims{2} = (16, 16), pool = :class, nclasses = 1000)
    @assert mode in keys(vit_configs) "`mode` must be one of $(keys(vit_configs))"
    kwargs = vit_configs[mode]
    layers = vit(imsize; inchannels, patch_size, nclasses, pool, kwargs...)
    return ViT(layers)
end
(m::ViT)(x) = m.layers(x)
backbone(m::ViT) = m.layers[1]
classifier(m::ViT) = m.layers[2]
@functor ViT
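# A minimal usage sketch (hypothetical sizes): construct a ViT-Tiny, run a batch
# of images through it, or use the backbone/classifier split:
#
#   m = ViT(:tiny; imsize = (256, 256), nclasses = 10)
#   x = rand(Float32, 256, 256, 3, 4)  # WHCN batch of 4 images
#   @assert size(m(x)) == (10, 4)
#   features = backbone(m)(x)          # pooled features of size (192, 4)
#   @assert size(classifier(m)(features)) == (10, 4)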