Skip to content

Commit

Permalink
[torch.onnx] support torch.nn.functional.grid_sample
Browse files Browse the repository at this point in the history
summary

- Adds `F.grid_sample` support
- Adds a test case

Fixes pytorch#27212
Pull Request resolved: pytorch#76159
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
  • Loading branch information
crcrpar authored and pytorchmergebot committed May 2, 2022
1 parent e14f533 commit 0ae3aa6
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 4 deletions.
2 changes: 1 addition & 1 deletion scripts/onnx/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ fi

if [[ "${SHARD_NUMBER}" == "2" ]]; then
# Update the loop for new opsets
for i in $(seq 10 15); do
for i in $(seq 10 16); do
pytest "${args[@]}" \
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset$i"
done
Expand Down
25 changes: 25 additions & 0 deletions test/onnx/test_onnx_opset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import onnx

import io
import itertools

from torch.onnx.symbolic_helper import _export_onnx_opset_version
from torch.onnx import producer_name, producer_version
Expand Down Expand Up @@ -369,6 +370,30 @@ def forward(self, x):
x = torch.randn(20, 16, 50)
check_onnx_opsets_operator(MyDynamicModel(), x, ops, opset_versions=[9, 10])

def test_grid_sample(self):
n, c, h_in, w_in, h_out, w_out = 1, 1, 3, 2, 2, 4
ops = {16: [{"op_name": "GridSample"}]}

class MyModule(Module):
def forward(self, x, grid, mode, padding_mode, align_corers):
return torch.nn.functional.grid_sample(x, grid, mode, padding_mode, align_corners)

for mode, padding_mode, align_corners in itertools.product(
("bilinear", "nearest", "bicubic"),
("zeros", "border", "reflection"),
(True, False),
):

args = (
torch.randn(n, c, h_in, w_in), # x
torch.randn(n, h_out, w_out, 2), # grid,
mode,
padding_mode,
align_corners,
)
check_onnx_opsets_operator(MyModule(), args, ops, opset_versions=[16], training=torch.onnx.TrainingMode.TRAINING)
check_onnx_opsets_operator(MyModule(), args, ops, opset_versions=[16], training=torch.onnx.TrainingMode.EVAL)


if __name__ == "__main__":
run_tests()
32 changes: 32 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -9183,13 +9183,15 @@ def forward(self, boxes, size):
dynamic_axes={"size": [0, 1]},
test_with_inputs=[(boxes, size), (boxes, size_2)])

@skipIfUnsupportedMaxOpsetVersion(15) # TODO: Opset 16 RoiAlign result mismatch
@skipIfUnsupportedMinOpsetVersion(11)
def test_roi_align(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1., 2)
self.run_test(model, (x, single_roi))

@skipIfUnsupportedMaxOpsetVersion(15) # TODO: Opset 16 RoiAlign result mismatch
@skipIfUnsupportedMinOpsetVersion(11)
def test_roi_align_aligned(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
Expand Down Expand Up @@ -9295,6 +9297,7 @@ def forward(self, images, features: Dict[str, torch.Tensor]):
test_with_inputs=[(images, features), (images2, test_features)],
dict_check=False)

@skipIfUnsupportedMaxOpsetVersion(15) # TODO: Opset 16 RoiAlign result mismatch
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_multi_scale_roi_align(self):
Expand Down Expand Up @@ -10986,6 +10989,33 @@ def forward(self, x):
self.run_test(Module(False), x, rtol=1e-3, atol=1e-6)
self.run_test(Module(True), x, rtol=1e-3, atol=1e-6)

@skipIfUnsupportedMinOpsetVersion(16)
def test_grid_sample(self):
n, c, h_in, w_in, h_out, w_out = 1, 1, 3, 2, 2, 4

class GridSampleModule(torch.nn.Module):

def __init__(self, mode, padding_mode, align_corners) -> None:
super().__init__()
self.mode, self.padding_mode, self.align_corners = mode, padding_mode, align_corners

def forward(self, input, grid):
return torch.nn.functional.grid_sample(input, grid, self.mode, self.padding_mode, self.align_corners)

for mode, padding_mode, align_corners in itertools.product(
("bilinear", "nearest", "bicubic"),
("zeros", "border", "reflection"),
(True, False),
):
atol_rtol = {}
if (mode, padding_mode) == ("bicubic", "border"):
if align_corners:
atol_rtol.update({"atol": 0.3, "rtol": 0.4})
else:
atol_rtol.update({"atol": 0.02, "rtol": 0.02})
input, grid = torch.randn(n, c, h_in, w_in), torch.randn(n, h_out, w_out, 2)
self.run_test(GridSampleModule(mode, padding_mode, align_corners), (input, grid), **atol_rtol)


def make_test(name, base, layer, bidirectional, initial_state,
variable_length, dropout, script_test_min_opset_version,
Expand Down Expand Up @@ -11123,6 +11153,8 @@ def MakeTestCase(opset_version: int, keep_initializers_as_inputs: bool = True) -

TestONNXRuntime_opset15 = MakeTestCase(15, keep_initializers_as_inputs=False)

TestONNXRuntime_opset16 = MakeTestCase(16, keep_initializers_as_inputs=False)


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions torch/csrc/jit/passes/onnx/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ static const int OPSET_VERSION_12 = 12;
static const int OPSET_VERSION_13 = 13;
static const int OPSET_VERSION_14 = 14;
static const int OPSET_VERSION_15 = 15;
static const int OPSET_VERSION_16 = 16;

using ValueToParamPairMap = std::map<Value*, std::pair<std::string, IValue>>;

Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/jit/serialization/export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ namespace onnx = ::ONNX_NAMESPACE;
const static int kInvalidOpsetVersion = -1;
// Based on OP_SET_ID_VERSION_MAP in
// https://github.com/onnx/onnx/blob/master/onnx/helper.py.
constexpr static std::array<int64_t, 16> kOpsetVersionToIRVersion = {
constexpr static std::array<int64_t, 17> kOpsetVersionToIRVersion = {
kInvalidOpsetVersion,
3,
kInvalidOpsetVersion,
Expand All @@ -75,6 +75,7 @@ constexpr static std::array<int64_t, 16> kOpsetVersionToIRVersion = {
7,
7,
7,
8,
8};

std::string getNodeStackTraceString(const Node* n) {
Expand Down
2 changes: 1 addition & 1 deletion torch/onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
opset_version (int, default 13): The version of the
`default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
to target. Must be >= 7 and <= 15.
to target. Must be >= 7 and <= 16.
do_constant_folding (bool, default True): Apply the constant-folding optimization.
Constant-folding will replace some of the ops that have all constant inputs
with pre-computed constant nodes.
Expand Down
2 changes: 1 addition & 1 deletion torch/onnx/symbolic_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,7 +1034,7 @@ def args_have_same_dtype(args):
return has_same_dtype

_default_onnx_opset_version = 13
_onnx_main_opset = 15
_onnx_main_opset = 16
_onnx_stable_opsets = list(range(7, _onnx_main_opset))
_export_onnx_opset_version = _default_onnx_opset_version
_constant_folding_opset_versions = list(range(9, _onnx_main_opset + 1))
Expand Down
46 changes: 46 additions & 0 deletions torch/onnx/symbolic_opset16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

# This file exports ONNX ops for opset 16

# Note [ONNX Operators that are added/updated in opset 16]
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
# New operators:
# GridSample https://github.com/onnx/onnx/pull/3557
#
# Updated operators:
# Identity
# If
# LeakyRelu
# Loop
# PRelu
# RoiAlign
# Scan
# ScatterElemenets
# ScatterND
# Where
# GreaterOrEqual
# LessOrEqual
# SequenceMap

from torch.onnx.symbolic_helper import parse_args

from torch.nn.functional import GRID_SAMPLE_INTERPOLATION_MODES, GRID_SAMPLE_PADDING_MODES


# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@parse_args("v", "v", "i", "i", "b")
def grid_sampler(g, input, grid, mode_enum, padding_mode_enum, align_corners):
mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg]
padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum] # type: ignore[call-arg]
return g.op(
"GridSample",
input,
grid,
align_corners_i=int(align_corners),
mode_s=mode_s,
padding_mode_s=padding_mode_s,
)

0 comments on commit 0ae3aa6

Please sign in to comment.