Swin Transformer V2 #1150
Conversation
… and change docstring style to match timm
Just fixed a small but silly bug regarding the classification head; however, I did not manage to get torchscript compatibility working yet.
@ChristophReich1996 thanks, I'll likely have a chance to look at this tomorrow (Monday). Torchscript compatibility is always a pain... I think I've gone through that hell more times than most :)
@ChristophReich1996 as you may have gathered from the tweet, preliminary testing of v2 is looking promising... I've trained to epoch 10-20 a few times with some different variations on init and some other adjustments. I've made some pretty extensive changes to reduce code volume a bit, and to remove some things that didn't seem to be working too well or need some rethink (sequential attention impl and deformable included there). Torchscript wasn't a big problem, mostly removing some annotations that it barfed on and a few minor things. I did some reformatting and brought some of the naming closer to the original Swin and other timm ViT-based models, while keeping some aspects and flexibility that are unique to this model. Also, I ended up making the shift/window code a bit closer to the original Swin; it's actually a bit of a performance hit to keep switching BCHW <-> BLC frequently, so I moved to BLC for the whole transformer block stack and BCHW in/out of the stages now (still a bit different from Swin v1). I'm almost ready to add my changes to the PR, do I have permission to update your fork?
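(For context, the shift/window handling being referred to can be sketched as below, closely following the original Swin window partition/reverse plus cyclic shift. This is a simplified illustration, not the exact timm code; the shapes and window size are arbitrary.)

import torch

def window_partition(x, window_size):
    # x: (B, H, W, C) -> (num_windows * B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    # Inverse of window_partition: back to (B, H, W, C)
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)

# Shifted windows: cyclically roll before partitioning, roll back after window attention
x = torch.randn(2, 56, 56, 96)                        # B, H, W, C
shifted = torch.roll(x, shifts=(-3, -3), dims=(1, 2))
windows = window_partition(shifted, window_size=7)    # (2 * 64, 7, 7, 96)
restored = torch.roll(window_reverse(windows, 7, 56, 56), shifts=(3, 3), dims=(1, 2))
print(torch.allclose(restored, x))                    # True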
Hi @rwightman, yeah just saw the cool news on Twitter. Yes, you have permission for edits! |
* reformat and change some naming so closer to existing timm vision transformers
* remove typing that wasn't adding clarity (or causing torchscript issues)
* support non-square windows
* auto window size adjust from image size
* post-norm + main-branch no
I think it was roughly ~10%? I've seen similar things simply from using LayerNorm and doing the permute between NCHW and NHWC for every instance; it slows things down. One of the biggest slowdowns in the convnext architecture is due to that (it also messes with channels-last optimizations, which isn't a factor here). As it stands right now the stages still use BCHW, but after downsample it's always BLC. Changing almost everything to BLC had minimal additional gain, so I left it in between V1 and your V2, which should make it a bit more friendly for 2d feature extraction.
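(For illustration, a rough micro-benchmark of the effect being described: normalizing a channels-last tensor directly vs. permuting NCHW <-> NHWC around every LayerNorm. The shapes, batch size, and iteration counts are arbitrary and the numbers will vary by hardware; this is a sketch, not the benchmark used in timm.)

import time
import torch

channels = 96
device = 'cuda' if torch.cuda.is_available() else 'cpu'
norm = torch.nn.LayerNorm(channels).to(device)
x_blc = torch.randn(8, 56 * 56, channels, device=device)   # B, L, C (channels-last)
x_bchw = torch.randn(8, channels, 56, 56, device=device)   # B, C, H, W

def norm_blc(x):
    # LayerNorm applied directly to the channels-last layout
    return norm(x)

def norm_bchw(x):
    # Permute NCHW -> NHWC, normalize, permute back and make contiguous (the round trip being discussed)
    return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()

def bench(fn, x, iters=50):
    for _ in range(10):  # warm up
        fn(x)
    if device == 'cuda':
        torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn(x)
    if device == 'cuda':
        torch.cuda.synchronize()
    return time.time() - start

print('BLC (no permutes):   ', bench(norm_blc, x_blc))
print('BCHW (with permutes):', bench(norm_bchw, x_bchw))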
…lusions for swin v2
Good to know! I hadn't thought that this would have such an impact on performance...
Hi @rwightman, I think I have found a solution for the "novel implementation of sequential self-attention computation" that the Swin V2 paper mentions without any further details. In December last year, Google published an arXiv paper called Self-attention Does Not Need O(n^2) Memory. I'm pretty sure the Swin V2 sequential self-attention is very similar to this! Here is a toy example of both the (cosine) self-attention and the sequential (cosine) self-attention, with and without a numerically stable softmax implementation.

import torch
def self_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Compute attention matrix
    attention_matrix: torch.Tensor = torch.softmax(torch.outer(query, key), dim=-1)
    # Compute output
    output: torch.Tensor = attention_matrix @ value
    return output


def sequential_self_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Init v
    v: torch.Tensor = torch.zeros(value.shape[0])
    for query_index in range(value.shape[0]):  # type: int
        # Reset s
        s: torch.Tensor = torch.zeros(1)
        for value_index in range(value.shape[0]):  # type: int
            s_: torch.Tensor = torch.exp(query[query_index] * key[value_index])
            s = s + s_
            v[query_index] = v[query_index] + s_ * value[value_index]
        v[query_index] = v[query_index] / s
    return v


def sequential_stable_self_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Init v
    v: torch.Tensor = torch.zeros(value.shape[0])
    for query_index in range(value.shape[0]):  # type: int
        # Reset s and m
        s: torch.Tensor = torch.zeros(1)
        m: torch.Tensor = torch.tensor(float("-inf"))
        for value_index in range(value.shape[0]):  # type: int
            s_: torch.Tensor = query[query_index] * key[value_index]
            m_: torch.Tensor = torch.maximum(m, s_)
            s_ = torch.exp(s_ - m_)
            m__: torch.Tensor = torch.exp(m - m_)
            s = s * m__ + s_
            v[query_index] = v[query_index] * m__ + s_ * value[value_index]
            m = m_
        v[query_index] = v[query_index] / s
    return v


def cosine_self_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, tau: float = 1.) -> torch.Tensor:
    # Compute attention matrix
    outer_product: torch.Tensor = torch.outer(query, key)
    outer_product_normalized: torch.Tensor = outer_product / torch.outer(
        torch.norm(query.view(-1, 1), dim=-1, keepdim=False),
        torch.norm(key.view(-1, 1), dim=-1, keepdim=False))
    attention_matrix: torch.Tensor = torch.softmax(outer_product_normalized / tau, dim=-1)
    # Compute output
    output: torch.Tensor = attention_matrix @ value
    return output


def sequential_cosine_self_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                                     tau: float = 1.) -> torch.Tensor:
    # Init v
    v: torch.Tensor = torch.zeros(query.shape[0])
    for query_index in range(query.shape[0]):  # type: int
        # Reset s
        s: torch.Tensor = torch.zeros(1)
        for value_index in range(query.shape[0]):  # type: int
            s_: torch.Tensor = query[query_index] * key[value_index]
            s_ = s_ / (torch.norm(query[query_index]) * torch.norm(key[value_index]))
            s_ = torch.exp(s_ / tau)
            s = s + s_
            v[query_index] = v[query_index] + s_ * value[value_index]
        v[query_index] = v[query_index] / s
    return v


def sequential_stable_cosine_self_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                                            tau: float = 1.) -> torch.Tensor:
    # Init v
    v: torch.Tensor = torch.zeros(value.shape[0])
    for query_index in range(value.shape[0]):  # type: int
        # Reset s and m
        s: torch.Tensor = torch.zeros(1)
        m: torch.Tensor = torch.tensor(float("-inf"))
        for value_index in range(value.shape[0]):  # type: int
            s_: torch.Tensor = query[query_index] * key[value_index]
            s_ = s_ / (torch.norm(query[query_index]) * torch.norm(key[value_index]))
            s_ = s_ / tau
            m_: torch.Tensor = torch.maximum(m, s_)
            s_ = torch.exp(s_ - m_)
            m__: torch.Tensor = torch.exp(m - m_)
            s = s * m__ + s_
            v[query_index] = v[query_index] * m__ + s_ * value[value_index]
            m = m_
        v[query_index] = v[query_index] / s
    return v


if __name__ == '__main__':
    query = torch.randn(6)
    key = torch.randn(6)
    value = torch.randn(6)
    # Standard self-attention
    output = self_attention(query, key, value)
    output_sequential = sequential_self_attention(query, key, value)
    output_sequential_stable = sequential_stable_self_attention(query, key, value)
    print(output)
    print(output_sequential)
    print(output_sequential_stable)
    # Cosine self-attention
    output_cosine = cosine_self_attention(query, key, value)
    output_cosine_sequential = sequential_cosine_self_attention(query, key, value)
    output_cosine_sequential_stable = sequential_stable_cosine_self_attention(query, key, value)
    print(output_cosine)
    print(output_cosine_sequential)
    print(output_cosine_sequential_stable)

The "original sequential attention" paper suggests an implementation where parallel computation is balanced against memory complexity. A sophisticated PyTorch implementation of the sequential attention is available here. Do you think adapting this implementation to match the scaled cosine attention of Swin V2 would add some value to this implementation? I could try a bit to get it to work...
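(As a rough sketch of how the chunked formulation from the Self-attention Does Not Need O(n^2) Memory paper could be combined with scaled cosine attention: single head, with illustrative names and chunk size. Note the real memory savings in the paper additionally require recomputing the per-chunk intermediates in the backward pass, e.g. via checkpointing, which this forward-only sketch does not show.)

import torch
import torch.nn.functional as F

def chunked_cosine_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                             tau: float = 1., chunk_size: int = 16) -> torch.Tensor:
    # Cosine attention computed over key/value chunks with a running, numerically
    # stable softmax (the vectorized counterpart of the per-element loops above).
    q = F.normalize(query, dim=-1)                          # (N, D)
    k = F.normalize(key, dim=-1)                            # (N, D)
    acc_value = torch.zeros(q.shape[0], value.shape[-1])    # running sum of exp(score) * value
    acc_weight = torch.zeros(q.shape[0], 1)                 # running softmax denominator
    acc_max = torch.full((q.shape[0], 1), float("-inf"))    # running max of scores
    for start in range(0, k.shape[0], chunk_size):
        k_chunk = k[start:start + chunk_size]               # (C, D)
        v_chunk = value[start:start + chunk_size]           # (C, Dv)
        scores = (q @ k_chunk.t()) / tau                    # (N, C) scaled cosine similarities
        new_max = torch.maximum(acc_max, scores.max(dim=-1, keepdim=True).values)
        exp_scores = torch.exp(scores - new_max)
        rescale = torch.exp(acc_max - new_max)              # rescale previously accumulated terms
        acc_value = acc_value * rescale + exp_scores @ v_chunk
        acc_weight = acc_weight * rescale + exp_scores.sum(dim=-1, keepdim=True)
        acc_max = new_max
    return acc_value / acc_weight

if __name__ == '__main__':
    q, k, v = torch.randn(64, 32), torch.randn(64, 32), torch.randn(64, 32)
    dense = torch.softmax(F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).t(), dim=-1) @ v
    print(torch.allclose(chunked_cosine_attention(q, k, v), dense, atol=1e-5))  # True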
@ChristophReich1996 sorry I didn't get back to this sooner, it's been a bit of a distracting week :( When I was searching for possible solutions to this I did run across that paper; I did a quick scan and it looked like it might be along the same lines as what was being discussed in Swin v2, but I hadn't dug in, so yeah, I think there is something there. Right now I'm setting up for a full train run on 1k and 21k (I did some shorter 1k runs that are looking okay but need to push higher). Have you tested your impl to see if the memory savings are there with 'reasonable' slowdowns?
Hi @rwightman, my implementation is only the naive approach to get the idea of the algorithm across. However, the amazing lucidrains already has a sophisticated PyTorch implementation, memory-efficient-attention-pytorch. The repo also includes a sequential version of the cosine attention, but no benchmarks are currently available. There is, however, also another similar implementation (but without cosine attention) reporting some runtimes, linear_mem_attention_pytorch. I may find some time to benchmark the current non-sequential cosine attention vs. the sequential version from lucidrains.
@ChristophReich1996 I've got some decent results for tiny so far, but I'm still doing some comparison runs. I've managed 81.65 for tiny @ 224 and also an 81.6 for a different init. However, the ability of the trained network to scale well to different resolutions appears to diminish by the end of training. It almost seems that the position MLP overfits to the original sizes. Part way through training (when the overall accuracy isn't great), the relative boost from increasing the resolution is better: it's a net increase, vs a net decrease at the end of the training curve. I'm currently doing a training run w/ dropout in the pos MLP. All that said, I'm merging this now prior to a bigger merge I'm doing...
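(For reference, a sketch of what "dropout in the pos MLP" could look like for a Swin V2-style continuous relative position bias, where a small MLP maps log-spaced relative coordinates to a per-head bias. The class name, hidden size, dropout rate, and coordinate scaling below are illustrative assumptions, not the exact merged timm code.)

import torch
import torch.nn as nn

class RelPosMlpSketch(nn.Module):
    # Continuous relative position bias: an MLP over log-spaced relative coordinates,
    # with dropout applied to the hidden activations ("dropout in the pos MLP").
    def __init__(self, window_size=(7, 7), num_heads=3, hidden_dim=512, drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.window_area = window_size[0] * window_size[1]
        self.mlp = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(drop),  # dropout on the hidden activations (illustrative placement)
            nn.Linear(hidden_dim, num_heads),
        )
        # All possible relative offsets, log-spaced: ((2*Wh - 1) * (2*Ww - 1), 2)
        coords_h = torch.arange(-(window_size[0] - 1), window_size[0], dtype=torch.float32)
        coords_w = torch.arange(-(window_size[1] - 1), window_size[1], dtype=torch.float32)
        rel_coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij'), dim=-1)
        rel_coords = torch.sign(rel_coords) * torch.log1p(rel_coords.abs())
        self.register_buffer('rel_coords', rel_coords.reshape(-1, 2))
        # Lookup table mapping every (query, key) pair in the window to its offset index
        coords = torch.stack(torch.meshgrid(
            torch.arange(window_size[0]), torch.arange(window_size[1]), indexing='ij'))
        coords_flat = coords.reshape(2, -1)
        rel = (coords_flat[:, :, None] - coords_flat[:, None, :]).permute(1, 2, 0)
        rel[..., 0] += window_size[0] - 1
        rel[..., 1] += window_size[1] - 1
        rel_index = rel[..., 0] * (2 * window_size[1] - 1) + rel[..., 1]
        self.register_buffer('rel_index', rel_index.reshape(-1))

    def forward(self) -> torch.Tensor:
        # (num_heads, window_area, window_area) bias to be added to the attention logits
        bias = self.mlp(self.rel_coords)[self.rel_index]
        return bias.reshape(self.window_area, self.window_area, self.num_heads).permute(2, 0, 1)

if __name__ == '__main__':
    print(RelPosMlpSketch()().shape)  # torch.Size([3, 49, 49])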
Hi, as discussed in #1147, here is the pull request for the Swin Transformer V2. I changed the docstrings to the timm style, cleaned up the code, and added create functions for all model sizes (tiny to giant). Happy to receive some feedback on the code and on what still needs to be fixed :)