Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch] Add aten::roll support for Swin Transformer #9371

Merged
merged 9 commits into from
Oct 26, 2021

Conversation

masahi
Copy link
Member

@masahi masahi commented Oct 26, 2021

Closes #9368

This PR adds missing converters necessary for importing the recent state of the art model Swin-Transoformer.

In particular, aten::roll is an interesting op to implement and generally useful op to have. It can be implemented via gather but the encoding is not obvious. Here are the references:
https://pytorch.org/docs/stable/generated/torch.roll.html
https://numpy.org/doc/stable/reference/generated/numpy.roll.html

please review @comaniac @jcf94 @junrushao1994

@Kyrie-Zhao aten::rand is not needed, it is used in dropout but if you trace the model with eval mode, dropout is gone.

Now the following script works with the error 1.4081733e-07.

import numpy as np
import tvm
import torch
from tvm import relay
from swin_transformer import SwinTransformer

net = SwinTransformer().eval()

img = torch.randn(1, 3, 224, 224)

scripted_model = torch.jit.trace(net, img).eval()
input_name = "input0"
shape_list = [(input_name, img.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

with torch.no_grad():
    pt_result = net(img).numpy()

target = "llvm"

with tvm.transform.PassContext(opt_level=3):
    json, lib, params = relay.build(mod, target=target, params=params)

ctx = tvm.device(target, 0)
runtime = tvm.contrib.graph_executor.create(json, lib, ctx)
runtime.set_input(**params)
runtime.set_input("input0", img.numpy())
runtime.run()

tvm_result = runtime.get_output(0).asnumpy()

print(np.mean(np.abs(tvm_result - pt_result)))

Copy link
Member

@junrushao junrushao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks awesome! Thank you @masahi!

@junrushao junrushao merged commit 0df4edc into apache:main Oct 26, 2021
ylc pushed a commit to ylc/tvm that referenced this pull request Jan 7, 2022
* add test

* first impl

* basic example working

* all test cases working

* support adaptive avg and max pool

* cleanup

* axes transpose logic fixed for roll

* pylint

* fixed roll dim indexing
ylc pushed a commit to ylc/tvm that referenced this pull request Jan 13, 2022
* add test

* first impl

* basic example working

* all test cases working

* support adaptive avg and max pool

* cleanup

* axes transpose logic fixed for roll

* pylint

* fixed roll dim indexing
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug] A list of missing op conversion (Swin Transformer)
2 participants