Position aware circular convolution #6288

Open
hkzhang-git opened this issue Jul 20, 2022 · 5 comments

@hkzhang-git

🚀 The feature

We recently proposed a new basic operation, position-aware global circular convolution (ParC). Unlike previous convolution operations, ParC has a global receptive field. Experimental results show that ParC consistently improves the performance of various typical models.

This work has been accepted to ECCV 2022, and we hope ParC can be useful to other researchers. Please refer to https://arxiv.org/abs/2203.03952 for more details.

Motivation, pitch

We were advised to reach out to torchvision when we posted this issue in pytorch:

pytorch/pytorch#80932 (comment)

Alternatives

We will prepare a CUDA-optimized version of ParC following the tutorial at https://pytorch.org/tutorials/advanced/cpp_extension.html.

We would also like to open a pull request to merge our code into torchvision. Would such a pull request be accepted? If so, could someone point us to an example of this kind of pull request?

Additional context

No response

@vfdev-5
Collaborator

vfdev-5 commented Jul 20, 2022

@hkzhang91 thanks for the feature request !

Do you have an implementation of the ParC convolution that you could share here for inspection? Edit: I found the repository linked from the arXiv page, https://github.com/hkzhang91/ParC-Net, but haven't yet spotted the convolution implementation.

As for integrating new ops into torchvision, it depends on how useful the op is and on how widely the original paper is cited. Your paper looks very recent and has only 1 citation. This can be discussed and decided with other project leaders.

For contributions you can check https://github.com/pytorch/vision/blob/main/CONTRIBUTING.md

@hkzhang-git
Author

hkzhang-git commented Jul 21, 2022

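import torch
from torch import nn
from torch.nn.init import trunc_normal_  # assumed import; the original post does not show where trunc_normal_ comes from


class ParC(nn.Module):
    def __init__(self, dim, type, global_kernel_size, use_pe=True):
        super().__init__()
        self.type = type  # 'H' (along height) or 'W' (along width)
        self.dim = dim
        self.use_pe = use_pe
        self.global_kernel_size = global_kernel_size
        # The kernel spans the whole feature map along the chosen axis.
        self.kernel_size = (global_kernel_size, 1) if self.type == 'H' else (1, global_kernel_size)
        # Depthwise convolution: one global kernel per channel.
        self.gcc_conv = nn.Conv2d(dim, dim, kernel_size=self.kernel_size, groups=dim)
        if use_pe:
            # Learnable position embedding along the chosen axis.
            if self.type == 'H':
                self.pe = nn.Parameter(torch.randn(1, dim, self.global_kernel_size, 1))
            elif self.type == 'W':
                self.pe = nn.Parameter(torch.randn(1, dim, 1, self.global_kernel_size))
            trunc_normal_(self.pe, std=.02)

    def forward(self, x):
        # x is expected to have shape (N, dim, global_kernel_size, global_kernel_size).
        if self.use_pe:
            x = x + self.pe.expand(1, self.dim, self.global_kernel_size, self.global_kernel_size)

        # Append all but the last row (or column) so the sliding kernel wraps around circularly.
        x_cat = torch.cat((x, x[:, :, :-1, :]), dim=2) if self.type == 'H' else torch.cat((x, x[:, :, :, :-1]), dim=3)
        x = self.gcc_conv(x_cat)

        return x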

The code above is the PyTorch implementation of ParC; as you can see, it is really simple.
Experimental results in our paper show that ParC consistently improves the performance of typical lightweight models.
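
For readers checking the concatenation step in `forward`: with a kernel of length H and an input extended from H to 2H-1 rows, the output has H rows again, and row i is the sum over k of w[k] * x[(i + k) mod H]. The sketch below is not part of the ParC repository; it is a minimal illustration, assuming a depthwise kernel of shape (C, 1, H, 1), that the concatenation gives the same result as circular padding via `torch.nn.functional.pad`. The function names `circular_conv_h_cat` and `circular_conv_h_pad` are hypothetical.

```python
import torch
import torch.nn.functional as F


def circular_conv_h_cat(x, weight):
    """Depthwise circular convolution along H using the concatenation trick."""
    # Append the first H-1 rows so the length-H kernel can wrap around.
    x_cat = torch.cat((x, x[:, :, :-1, :]), dim=2)
    return F.conv2d(x_cat, weight, groups=x.shape[1])


def circular_conv_h_pad(x, weight):
    """Same operation, written with circular padding at the bottom of H."""
    h = x.shape[2]
    # Pad tuple is (left, right, top, bottom) for the last two dimensions.
    x_pad = F.pad(x, (0, 0, 0, h - 1), mode="circular")
    return F.conv2d(x_pad, weight, groups=x.shape[1])


torch.manual_seed(0)
x = torch.randn(2, 4, 8, 8)        # (N, C, H, W), illustrative sizes
weight = torch.randn(4, 1, 8, 1)   # one global kernel of length H per channel

out_cat = circular_conv_h_cat(x, weight)
out_pad = circular_conv_h_pad(x, weight)
```

Both variants keep the spatial size of the input, which is why the kernel length has to match the feature-map size along the chosen axis.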

This paper was published very recently. If ParC can be integrated into torchvision as a new op, we will provide a CUDA-optimized version of ParC for efficient inference. ParC is a kind of large-kernel convolution, which is not yet well supported (see RepLKNet); we are working on this problem now.

@datumbox
Contributor

@hkzhang91 Thanks a lot for pinging us. We definitely welcome contributions from research teams; it's something we would like to get more of in the future.

I had a quick look at the paper and the work looks interesting. As @vfdev-5 said, to add a new operator to TorchVision it has to have strong adoption in research (this is to avoid bloating the library). Given this was just released, I recommend waiting a couple of months to see how it is received by the research community. If it picks up steam, we can discuss the details of how to structure the PR to minimize the back and forth. Does this make sense to you?

@hkzhang-git
Author

@datumbox Thanks for your reply. It sounds reasonable.

@CAM-FSS

CAM-FSS commented Jan 12, 2023

I have a question about the position embedding mentioned in ParC-Net. The original text is "peV is instance position embedding (PE) and it is generated from a base embedding epeV". Could someone kindly explain where the C × B × 1 shape comes from? Thanks.
