Swin Transformer V2 #1150

Merged: 11 commits, Mar 21, 2022

Conversation

ChristophReich1996
Contributor

Hi, as discussed in #1147, here is the pull request for the Swin Transformer V2. I changed the docstrings to the timm style, cleaned up the code, and added create functions for all model sizes (tiny to giant). Happy to receive some feedback on the code and on what still needs to be fixed :)

@ChristophReich1996
Contributor Author

Just fixed a small but silly bug regarding the classification head; however, I did not manage to get torch.jit.script to work and hope for some help here.
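
(For reference, a quick way to surface the failure is just to script the model and run a dummy forward; the model name below is only a placeholder for whatever this branch registers, not a confirmed timm entry point.)

import torch
import timm

# Placeholder model name; substitute whatever the PR branch actually registers
model = timm.create_model('swin_v2_cr_tiny_224', pretrained=False)
model.eval()

scripted = torch.jit.script(model)           # raises on unsupported annotations/constructs
out = scripted(torch.randn(1, 3, 224, 224))  # sanity-check the scripted forward pass
print(out.shape)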

@rwightman
Collaborator

@ChristophReich1996 thanks, I'll likely have a chance to look at this tomorrow (Monday). torchscript compatibility is always a pain.. I think I've gone through that hell more times than most :)

@rwightman
Collaborator

rwightman commented Feb 23, 2022

@ChristophReich1996 as you may have gathered from the tweet, preliminary testing of v2 is looking promising... I've trained to epoch 10-20 a few times with some different variations on init and some other adjustments.

I've made some pretty extensive changes to reduce code volume a bit and remove some things that didn't seem to be working too well or needed some rethinking (the sequential attention impl and deformable attention included)...

Torchscript wasn't a big problem, mostly removing some annotations that it barfed on and a few minor things.

I did some reformatting and brought some of the naming closer to the original Swin and other timm ViT-based models, while keeping some aspects and flexibility that are unique to this model. Also, I ended up making the shift/window code a bit closer to the original Swin; it's actually a bit of a performance hit to keep switching BCHW <-> BLC frequently, so I now use BLC for the whole transformer block stack and BCHW in/out of the stages (still a bit different than Swin V1).

I'm almost ready to add my changes to the PR; do I have permission to update your fork?

@ChristophReich1996
Contributor Author

Hi @rwightman, yeah just saw the cool news on Twitter. Yes, you have permission for edits!
Interesting to see that omitting the BCHW <-> BLC switching improved performance; I thought it would only add minor overhead. Can you roughly say how much the speed improved (just curious ;)? I mainly used it because working with BCHW was more intuitive for me in some places.

* reformat and change some naming so closer to existing timm vision transformers
* remove typing that wasn't adding clarity (or causing torchscript issues)
* support non-square windows
* auto window size adjust from image size
* post-norm + main-branch no
@rwightman
Collaborator

> Hi @rwightman, yeah just saw the cool news on Twitter. Yes, you have permission for edits! Interesting to see that omitting the BCHW <-> BLC switching improved performance; I thought it would only add minor overhead. Can you roughly say how much the speed improved (just curious ;)? I mainly used it because working with BCHW was more intuitive for me in some places.

I think it was roughly ~10%? I've seen similar things when simply using LayerNorm and permuting between NCHW and NHWC for every instance; it slows things down. One of the biggest slowdowns in the convnext architecture is due to that (it also messes with channels-last optimizations, which isn't a factor here). As it stands right now the stages still use BCHW, but after downsample it's always BLC. Changing almost everything to BLC had minimal additional gain, so I left it in between V1 and your V2, which should make it a bit more friendly for 2D feature extraction.
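
(For anyone curious, a rough CPU micro-benchmark sketch of the kind of overhead being described; the shapes, iteration count, and timing harness below are arbitrary choices, not taken from the PR.)

import time
import torch
import torch.nn as nn

B, C, H, W = 64, 384, 14, 14
norm = nn.LayerNorm(C)

x_bchw = torch.randn(B, C, H, W)
x_blc = x_bchw.flatten(2).transpose(1, 2)  # (B, H*W, C)

def norm_blc(x):
    # Stay in (B, L, C): LayerNorm normalizes the last dim directly
    return norm(x)

def norm_bchw(x):
    # Permute to channels-last layout, normalize, permute back
    return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

def bench(fn, x, iters=100):
    # Warm-up, then time repeated calls
    for _ in range(10):
        fn(x)
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(x)
    return time.perf_counter() - t0

print("BLC :", bench(norm_blc, x_blc))
print("BCHW:", bench(norm_bchw, x_bchw))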

@ChristophReich1996
Contributor Author

Good to know! I hadn't thought it would have such an impact on performance...

@ChristophReich1996
Contributor Author

ChristophReich1996 commented Mar 1, 2022

Hi @rwightman, I think I have found a solution for the "novel implementation of sequential self-attention computation" that the Swin V2 paper mentions without any further details. In December last year, Google published an arXiv paper called Self-attention Does Not Need O(n^2) Memory. I'm pretty sure the Swin V2 sequential self-attention is very similar to this! Here is a toy example of both the (cosine) self-attention and the sequential (cosine) self-attention, with and without a numerically stable softmax implementation.

import torch
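
# Toy setup: query, key, and value are 1-D tensors, so each "token" is a scalar
# and the attention matrix reduces to the outer product of query and key.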


def self_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Compute attention matrix
    attention_matrix: torch.Tensor = torch.softmax(torch.outer(query, key), dim=-1)
    # Compute output
    output: torch.Tensor = attention_matrix @ value
    return output


def sequential_self_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Init v
    v: torch.Tensor = torch.zeros(value.shape[0])
    for query_index in range(value.shape[0]):  # type: int
        # Reset s
        s: torch.Tensor = torch.zeros(1)
        for value_index in range(value.shape[0]):  # type: int
            s_: torch.Tensor = torch.exp(query[query_index] * key[value_index])
            s: torch.Tensor = s + s_
            v[query_index] = v[query_index] + s_ * value[value_index]
        v[query_index] = v[query_index] / s
    return v


def sequential_stable_self_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Init v
    v: torch.Tensor = torch.zeros(value.shape[0])
    for query_index in range(value.shape[0]):  # type: int
        # Reset s and m
        s: torch.Tensor = torch.zeros(1)
        m: torch.Tensor = torch.tensor(float("-inf"))
        for value_index in range(value.shape[0]):  # type: int
            s_: torch.Tensor = query[query_index] * key[value_index]
            m_: torch.Tensor = torch.maximum(m, s_)
            s_: torch.Tensor = torch.exp(s_ - m_)
            m__: torch.Tensor = torch.exp(m - m_)
            s: torch.Tensor = s * m__ + s_
            v[query_index] = v[query_index] * m__ + s_ * value[value_index]
            m = m_
        v[query_index] = v[query_index] / s
    return v


def cosine_self_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, tau: float = 1.) -> torch.Tensor:
    # Compute attention matrix
    outer_product: torch.Tensor = torch.outer(query, key)
    outer_product_normalized: torch.Tensor = outer_product / torch.outer(
        torch.norm(query.view(-1, 1), dim=-1, keepdim=False),
        torch.norm(key.view(-1, 1), dim=-1, keepdim=False))
    attention_matrix: torch.Tensor = torch.softmax(outer_product_normalized / tau, dim=-1)
    # Compute output
    output: torch.Tensor = attention_matrix @ value
    return output


def sequential_cosine_self_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                                     tau: float = 1.) -> torch.Tensor:
    # Init v
    v: torch.Tensor = torch.zeros(query.shape[0])
    for query_index in range(query.shape[0]):  # type: int
        # Reset s
        s: torch.Tensor = torch.zeros(1)
        for value_index in range(query.shape[0]):  # type: int
            s_: torch.Tensor = query[query_index] * key[value_index]
            s_: torch.Tensor = s_ / (torch.norm(query[query_index]) * torch.norm(key[value_index]))
            s_: torch.Tensor = torch.exp(s_ / tau)
            s: torch.Tensor = s + s_
            v[query_index] = v[query_index] + s_ * value[value_index]
        v[query_index] = v[query_index] / s
    return v


def sequential_stable_cosine_self_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                                            tau: float = 1.) -> torch.Tensor:
    # Init v
    v: torch.Tensor = torch.zeros(value.shape[0])
    for query_index in range(value.shape[0]):  # type: int
        # Reset s and m
        s: torch.Tensor = torch.zeros(1)
        m: torch.Tensor = torch.tensor(float("-inf"))
        for value_index in range(value.shape[0]):  # type: int
            s_: torch.Tensor = query[query_index] * key[value_index]
            s_: torch.Tensor = s_ / (torch.norm(query[query_index]) * torch.norm(key[value_index]))
            s_: torch.Tensor = s_ / tau
            m_: torch.Tensor = torch.maximum(m, s_)
            s_: torch.Tensor = torch.exp(s_ - m_)
            m__: torch.Tensor = torch.exp(m - m_)
            s: torch.Tensor = s * m__ + s_
            v[query_index] = v[query_index] * m__ + s_ * value[value_index]
            m = m_
        v[query_index] = v[query_index] / s
    return v


if __name__ == '__main__':
    query = torch.randn(6)
    key = torch.randn(6)
    value = torch.randn(6)

    # Standard self-attention
    output = self_attention(query, key, value)
    output_sequential = sequential_self_attention(query, key, value)
    output_sequential_stable = sequential_stable_self_attention(query, key, value)
    print(output)
    print(output_sequential)
    print(output_sequential_stable)

    # Cosine self-attention
    output_cosine = cosine_self_attention(query, key, value)
    output_cosine_sequential = sequential_cosine_self_attention(query, key, value)
    output_cosine_sequential_stable = sequential_stable_cosine_self_attention(query, key, value)
    print(output_cosine)
    print(output_cosine_sequential)
    print(output_cosine_sequential_stable)

The "original sequential attention" paper suggests an implementation where parallel computation is balanced against memory complexity. A sophisticated PyTorch implementation of the sequential attention is available here. Do you think adopting this implementation to match the scaled cosine attention of the Swin V2 would add some value to this implementation? I could try a bit to get it to work...

@rwightman
Collaborator

@ChristophReich1996 sorry I didn't get back to this sooner; it's been a bit of a distracting week :(

When I was searching for possible solutions to this I did run across that paper; a quick scan suggested it might be along the same lines as what was being discussed in Swin V2, but I hadn't dug in. So yeah, I think there is something there.

Right now I'm setting up for full training runs on 1k and 21k (I did some shorter 1k runs that are looking okay, but I need to push higher).

Have you tested your impl to see if the memory savings are there with 'reasonable' slowdowns?

@ChristophReich1996
Contributor Author

Hi @rwightman, my implementation is only a naive approach to convey the idea of the algorithm. However, the amazing lucidrains already has a sophisticated PyTorch implementation, memory-efficient-attention-pytorch. The repo also includes a sequential version of the cosine attention, but no benchmarks are currently available. There is, however, another similar implementation (without cosine attention) that reports some runtimes: linear_mem_attention_pytorch.

Maybe I'll find some time to benchmark the current non-sequential cosine attention vs. lucidrains' sequential version.
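
(A rough sketch of how such a memory comparison might be set up; requires a CUDA device, and attn_chunked is a stand-in for whichever sequential/chunked implementation gets benchmarked.)

import torch

def peak_memory_mb(fn, *args):
    # Reset allocator stats, run the candidate once, and report peak usage
    torch.cuda.reset_peak_memory_stats()
    fn(*args)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2 ** 20

def attn_full(q, k, v):
    # Baseline: materializes the full L x L attention matrix
    return torch.softmax(q @ k.T / q.shape[-1] ** 0.5, dim=-1) @ v

q = k = v = torch.randn(4096, 64, device='cuda')
print(f"full attention peak: {peak_memory_mb(attn_full, q, k, v):.1f} MiB")
# print(f"chunked peak: {peak_memory_mb(attn_chunked, q, k, v):.1f} MiB")  # sequential/chunked impl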

@rwightman
Collaborator

@ChristophReich1996 I've got some decent results for tiny so far, but I'm still doing some comparison runs. I've managed 81.65 for tiny @ 224 and also an 81.6 for a different init.

However, the ability of the trained network to scale well to different resolutions appears to diminish by the end of training. It almost seems that the position MLP overfits to the original sizes. Partway through training, when the overall accuracy isn't great yet, the relative boost from increasing the resolution is better; it's a net increase vs. a net decrease at the end of the training curve. I'm currently doing a training run w/ dropout in the pos MLP.
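
(For reference, a minimal sketch of that kind of regularization; the layer sizes and log-spaced coordinate transform loosely follow the Swin V2 continuous position bias design, but the exact numbers here are assumptions, not the PR code.)

import torch
import torch.nn as nn

class RelativePositionBiasMLP(nn.Module):
    def __init__(self, num_heads: int, hidden_dim: int = 256, drop: float = 0.1):
        super().__init__()
        # Maps 2D (log-spaced) relative coordinates to a per-head bias
        self.mlp = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(drop),            # dropout inside the position MLP
            nn.Linear(hidden_dim, num_heads),
        )

    def forward(self, relative_coords: torch.Tensor) -> torch.Tensor:
        # relative_coords: (num_relative_positions, 2) -> (num_relative_positions, num_heads)
        return self.mlp(relative_coords)

# Example: bias table for an 8x8 window (relative offsets in [-7, 7])
coords = torch.stack(torch.meshgrid(
    torch.arange(-7, 8, dtype=torch.float32),
    torch.arange(-7, 8, dtype=torch.float32), indexing="ij"), dim=-1).reshape(-1, 2)
log_coords = torch.sign(coords) * torch.log1p(coords.abs())  # log-spaced coordinates
bias = RelativePositionBiasMLP(num_heads=3)(log_coords)
print(bias.shape)  # torch.Size([225, 3])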

All that said, I'm merging this now prior to a bigger merge I'm doing...
