From 210b4c9dfe582ee7dcae5c795db88465cafa888d Mon Sep 17 00:00:00 2001 From: masahi Date: Wed, 27 Oct 2021 02:50:37 +0900 Subject: [PATCH] [Torch] Add aten::roll support for Swin Transformer (#9371) * add test * first impl * basic example working * all test cases working * support adaptive avg and max pool * cleanup * axes transpose logic fixed for roll * pylint * fixed roll dim indexing --- python/tvm/relay/frontend/pytorch.py | 78 ++++++++++++++----- tests/python/frontend/pytorch/test_forward.py | 30 ++++++- 2 files changed, 87 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 3fc202a7cc91..13704ff7aad9 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -849,35 +849,23 @@ def hard_swish(self, inputs, input_types): data = inputs[0] return data * self.hard_sigmoid(inputs, input_types) - def adaptive_avg_pool_2d(self, inputs, input_types): + def adaptive_avg_pool(self, op, inputs, input_types): data = inputs[0] output_size = inputs[1] def func(x): - return _op.nn.adaptive_avg_pool2d(x, output_size=output_size) + return op(x, output_size=output_size) if self.is_quantized_tensor(data): return qnn_torch.apply_with_upcast(data, func) return func(data) - def adaptive_max_pool_2d(self, inputs, input_types): + def adaptive_max_pool(self, op, inputs, input_types): data = inputs[0] output_size = inputs[1] - # returns dummy indices too - return _op.nn.adaptive_max_pool2d(data, output_size=output_size), None - - def adaptive_max_pool_3d(self, inputs, input_types): - data = inputs[0] - output_size = inputs[1] - # returns dummy indices too - return _op.nn.adaptive_max_pool3d(data, output_size=output_size), None - - def adaptive_avg_pool_3d(self, inputs, input_types): - data = inputs[0] - output_size = inputs[1] - return _op.nn.adaptive_avg_pool3d(data, output_size=output_size) + return op(data, output_size=output_size), None @staticmethod def convert_const_list(data): @@ -2794,6 +2782,39 @@ def searchsorted(self, inputs, input_types): def bucketize(self, inputs, input_types): return self.searchsorted_common(inputs[1], inputs[0], inputs[2], inputs[3]) + def roll(self, inputs, input_types): + def slide_axes(inp, shape, ax): + axes = list(range(len(shape))) + axes = axes[:ax] + [-1] + axes[ax:-1] + return _op.transpose(inp, axes) + + x = inputs[0] + shifts = inputs[1] + dims = inputs[2] + shape = self.infer_shape(x) + start = _expr.const(0, "int64") + step = _expr.const(1, "int64") + + out = x + for i, dim in enumerate(dims): + roll_dim = _expr.const(shape[dim], "int64") + indices_1d = _op.mod( + _op.transform.arange(start, roll_dim, step, "int64") + - _expr.const(shifts[i], "int64") + + roll_dim, + roll_dim, + ) + # First fill in the last axis with roll indices, and then do transpose to + # bring the roll indices into the desired axis. + indices = slide_axes( + _op.tile(indices_1d, shape[:dim] + shape[dim + 1 :] + (1,)), + shape, + dim, + ) + out = _op.gather(out, dim, indices) + + return out + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -2851,9 +2872,26 @@ def create_convert_map(self): "aten::gelu": self.gelu, "aten::selu": self.selu, "aten::silu": self.silu, + "aten::silu_": self.silu, "aten::log_sigmoid": self.log_sigmoid, - "aten::adaptive_avg_pool2d": self.adaptive_avg_pool_2d, - "aten::adaptive_max_pool2d": self.adaptive_max_pool_2d, + "aten::adaptive_avg_pool1d": functools.partial( + self.adaptive_avg_pool, _op.nn.adaptive_avg_pool1d + ), + "aten::adaptive_avg_pool2d": functools.partial( + self.adaptive_avg_pool, _op.nn.adaptive_avg_pool2d + ), + "aten::adaptive_avg_pool3d": functools.partial( + self.adaptive_avg_pool, _op.nn.adaptive_avg_pool3d + ), + "aten::adaptive_max_pool1d": functools.partial( + self.adaptive_max_pool, _op.nn.adaptive_max_pool1d + ), + "aten::adaptive_max_pool2d": functools.partial( + self.adaptive_max_pool, _op.nn.adaptive_max_pool2d + ), + "aten::adaptive_max_pool3d": functools.partial( + self.adaptive_max_pool, _op.nn.adaptive_max_pool3d + ), "aten::max_pool2d": self.maxpool_2d, "aten::max_pool2d_with_indices": self.maxpool_2d_with_indices, "aten::max_pool1d": self.maxpool_1d, @@ -2939,6 +2977,7 @@ def create_convert_map(self): "aten::rsqrt": self.make_unary("rsqrt"), "aten::ceil": self.make_unary("ceil"), "aten::floor": self.make_unary("floor"), + "aten::floor_": self.make_unary("floor"), "aten::round": self.make_unary("round"), "aten::isfinite": self.make_unary("isfinite"), "aten::isinf": self.make_unary("isinf"), @@ -2964,8 +3003,6 @@ def create_convert_map(self): "aten::bitwise_xor": self.bitwise_xor, "aten::Bool": self.Bool, "aten::Float": self.Float, - "aten::adaptive_avg_pool3d": self.adaptive_avg_pool_3d, - "aten::adaptive_max_pool3d": self.adaptive_max_pool_3d, "aten::rsub": self.rsub, "aten::embedding": self.embedding, "aten::one_hot": self.one_hot, @@ -3021,6 +3058,7 @@ def create_convert_map(self): "aten::any": functools.partial(self.all_any_common, _op.any), "aten::searchsorted": self.searchsorted, "aten::bucketize": self.bucketize, + "aten::roll": self.roll, } def update_convert_map(self, custom_map): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 0031f4143fab..5057f0d2b6b8 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -735,13 +735,30 @@ def test_forward_log_sigmoid(): @tvm.testing.uses_gpu -def test_forward_adaptiveavgpool(): +def test_forward_adaptive_avgpool(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] input_data = torch.rand(input_shape).float() verify_model(torch.nn.AdaptiveAvgPool2d([1, 1]).eval(), input_data=input_data) verify_model(torch.nn.AdaptiveAvgPool2d([10, 10]).eval(), input_data=input_data) + input_data = torch.rand([1, 3, 10]).float() + verify_model(torch.nn.AdaptiveAvgPool1d([1]).eval(), input_data=input_data) + verify_model(torch.nn.AdaptiveAvgPool1d([5]).eval(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_adaptive_maxpool(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10, 10] + input_data = torch.rand(input_shape).float() + verify_model(torch.nn.AdaptiveMaxPool2d([1, 1]).eval(), input_data=input_data) + verify_model(torch.nn.AdaptiveMaxPool2d([10, 10]).eval(), input_data=input_data) + + input_data = torch.rand([1, 3, 10]).float() + verify_model(torch.nn.AdaptiveMaxPool1d([1]).eval(), input_data=input_data) + verify_model(torch.nn.AdaptiveMaxPool1d([5]).eval(), input_data=input_data) + @tvm.testing.uses_gpu def test_forward_maxpool2d(): @@ -3992,5 +4009,16 @@ def test_fn(out_int32=False, right=False): verify_model(test_fn(out_int32=True, right=True), [values, boundaries]) +@tvm.testing.uses_gpu +def test_roll(): + def test_fn(shifts, dims): + return lambda x: torch.roll(x, shifts, dims) + + x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2) + verify_model(test_fn(1, 0), [x]) + verify_model(test_fn(-1, 0), [x]) + verify_model(test_fn(shifts=(2, 1), dims=(0, 1)), [x]) + + if __name__ == "__main__": pytest.main([__file__])