Skip to content

Commit

Permalink
[Torch] Add aten::roll support for Swin Transformer (apache#9371)
Browse files Browse the repository at this point in the history
* add test

* first impl

* basic example working

* all test cases working

* support adaptive avg and max pool

* cleanup

* axes transpose logic fixed for roll

* pylint

* fixed roll dim indexing
  • Loading branch information
masahi authored and ylc committed Jan 7, 2022
1 parent a9da24e commit 210b4c9
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 21 deletions.
78 changes: 58 additions & 20 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,35 +849,23 @@ def hard_swish(self, inputs, input_types):
data = inputs[0]
return data * self.hard_sigmoid(inputs, input_types)

def adaptive_avg_pool_2d(self, inputs, input_types):
def adaptive_avg_pool(self, op, inputs, input_types):
data = inputs[0]
output_size = inputs[1]

def func(x):
return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)
return op(x, output_size=output_size)

if self.is_quantized_tensor(data):
return qnn_torch.apply_with_upcast(data, func)

return func(data)

def adaptive_max_pool_2d(self, inputs, input_types):
def adaptive_max_pool(self, op, inputs, input_types):
data = inputs[0]
output_size = inputs[1]

# returns dummy indices too
return _op.nn.adaptive_max_pool2d(data, output_size=output_size), None

def adaptive_max_pool_3d(self, inputs, input_types):
data = inputs[0]
output_size = inputs[1]
# returns dummy indices too
return _op.nn.adaptive_max_pool3d(data, output_size=output_size), None

def adaptive_avg_pool_3d(self, inputs, input_types):
data = inputs[0]
output_size = inputs[1]
return _op.nn.adaptive_avg_pool3d(data, output_size=output_size)
return op(data, output_size=output_size), None

@staticmethod
def convert_const_list(data):
Expand Down Expand Up @@ -2794,6 +2782,39 @@ def searchsorted(self, inputs, input_types):
def bucketize(self, inputs, input_types):
return self.searchsorted_common(inputs[1], inputs[0], inputs[2], inputs[3])

def roll(self, inputs, input_types):
def slide_axes(inp, shape, ax):
axes = list(range(len(shape)))
axes = axes[:ax] + [-1] + axes[ax:-1]
return _op.transpose(inp, axes)

x = inputs[0]
shifts = inputs[1]
dims = inputs[2]
shape = self.infer_shape(x)
start = _expr.const(0, "int64")
step = _expr.const(1, "int64")

out = x
for i, dim in enumerate(dims):
roll_dim = _expr.const(shape[dim], "int64")
indices_1d = _op.mod(
_op.transform.arange(start, roll_dim, step, "int64")
- _expr.const(shifts[i], "int64")
+ roll_dim,
roll_dim,
)
# First fill in the last axis with roll indices, and then do transpose to
# bring the roll indices into the desired axis.
indices = slide_axes(
_op.tile(indices_1d, shape[:dim] + shape[dim + 1 :] + (1,)),
shape,
dim,
)
out = _op.gather(out, dim, indices)

return out

# Operator mappings
def create_convert_map(self):
self.convert_map = {
Expand Down Expand Up @@ -2851,9 +2872,26 @@ def create_convert_map(self):
"aten::gelu": self.gelu,
"aten::selu": self.selu,
"aten::silu": self.silu,
"aten::silu_": self.silu,
"aten::log_sigmoid": self.log_sigmoid,
"aten::adaptive_avg_pool2d": self.adaptive_avg_pool_2d,
"aten::adaptive_max_pool2d": self.adaptive_max_pool_2d,
"aten::adaptive_avg_pool1d": functools.partial(
self.adaptive_avg_pool, _op.nn.adaptive_avg_pool1d
),
"aten::adaptive_avg_pool2d": functools.partial(
self.adaptive_avg_pool, _op.nn.adaptive_avg_pool2d
),
"aten::adaptive_avg_pool3d": functools.partial(
self.adaptive_avg_pool, _op.nn.adaptive_avg_pool3d
),
"aten::adaptive_max_pool1d": functools.partial(
self.adaptive_max_pool, _op.nn.adaptive_max_pool1d
),
"aten::adaptive_max_pool2d": functools.partial(
self.adaptive_max_pool, _op.nn.adaptive_max_pool2d
),
"aten::adaptive_max_pool3d": functools.partial(
self.adaptive_max_pool, _op.nn.adaptive_max_pool3d
),
"aten::max_pool2d": self.maxpool_2d,
"aten::max_pool2d_with_indices": self.maxpool_2d_with_indices,
"aten::max_pool1d": self.maxpool_1d,
Expand Down Expand Up @@ -2939,6 +2977,7 @@ def create_convert_map(self):
"aten::rsqrt": self.make_unary("rsqrt"),
"aten::ceil": self.make_unary("ceil"),
"aten::floor": self.make_unary("floor"),
"aten::floor_": self.make_unary("floor"),
"aten::round": self.make_unary("round"),
"aten::isfinite": self.make_unary("isfinite"),
"aten::isinf": self.make_unary("isinf"),
Expand All @@ -2964,8 +3003,6 @@ def create_convert_map(self):
"aten::bitwise_xor": self.bitwise_xor,
"aten::Bool": self.Bool,
"aten::Float": self.Float,
"aten::adaptive_avg_pool3d": self.adaptive_avg_pool_3d,
"aten::adaptive_max_pool3d": self.adaptive_max_pool_3d,
"aten::rsub": self.rsub,
"aten::embedding": self.embedding,
"aten::one_hot": self.one_hot,
Expand Down Expand Up @@ -3021,6 +3058,7 @@ def create_convert_map(self):
"aten::any": functools.partial(self.all_any_common, _op.any),
"aten::searchsorted": self.searchsorted,
"aten::bucketize": self.bucketize,
"aten::roll": self.roll,
}

def update_convert_map(self, custom_map):
Expand Down
30 changes: 29 additions & 1 deletion tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,13 +735,30 @@ def test_forward_log_sigmoid():


@tvm.testing.uses_gpu
def test_forward_adaptiveavgpool():
def test_forward_adaptive_avgpool():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.AdaptiveAvgPool2d([1, 1]).eval(), input_data=input_data)
verify_model(torch.nn.AdaptiveAvgPool2d([10, 10]).eval(), input_data=input_data)

input_data = torch.rand([1, 3, 10]).float()
verify_model(torch.nn.AdaptiveAvgPool1d([1]).eval(), input_data=input_data)
verify_model(torch.nn.AdaptiveAvgPool1d([5]).eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_adaptive_maxpool():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.AdaptiveMaxPool2d([1, 1]).eval(), input_data=input_data)
verify_model(torch.nn.AdaptiveMaxPool2d([10, 10]).eval(), input_data=input_data)

input_data = torch.rand([1, 3, 10]).float()
verify_model(torch.nn.AdaptiveMaxPool1d([1]).eval(), input_data=input_data)
verify_model(torch.nn.AdaptiveMaxPool1d([5]).eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_maxpool2d():
Expand Down Expand Up @@ -3992,5 +4009,16 @@ def test_fn(out_int32=False, right=False):
verify_model(test_fn(out_int32=True, right=True), [values, boundaries])


@tvm.testing.uses_gpu
def test_roll():
def test_fn(shifts, dims):
return lambda x: torch.roll(x, shifts, dims)

x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
verify_model(test_fn(1, 0), [x])
verify_model(test_fn(-1, 0), [x])
verify_model(test_fn(shifts=(2, 1), dims=(0, 1)), [x])


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 210b4c9

Please sign in to comment.