forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[torch.onnx] support
torch.nn.functional.grid_sample
summary - Adds `F.grid_sample` support - Adds a test case Fixes pytorch#27212 Pull Request resolved: pytorch#76159 Approved by: https://github.com/justinchuby, https://github.com/BowenBao
- Loading branch information
1 parent
e14f533
commit 0ae3aa6
Showing
8 changed files
with
109 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# EDITING THIS FILE? READ THIS FIRST! | ||
# see Note [Edit Symbolic Files] in symbolic_helper.py | ||
|
||
# This file exports ONNX ops for opset 16 | ||
|
||
# Note [ONNX Operators that are added/updated in opset 16] | ||
# | ||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
# https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set | ||
# New operators: | ||
# GridSample https://github.com/onnx/onnx/pull/3557 | ||
# | ||
# Updated operators: | ||
# Identity | ||
# If | ||
# LeakyRelu | ||
# Loop | ||
# PRelu | ||
# RoiAlign | ||
# Scan | ||
# ScatterElemenets | ||
# ScatterND | ||
# Where | ||
# GreaterOrEqual | ||
# LessOrEqual | ||
# SequenceMap | ||
|
||
from torch.onnx.symbolic_helper import parse_args | ||
|
||
from torch.nn.functional import GRID_SAMPLE_INTERPOLATION_MODES, GRID_SAMPLE_PADDING_MODES | ||
|
||
|
||
# note (mkozuki): Why `grid_sampler` instead of `grid_sample`? | ||
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`. | ||
@parse_args("v", "v", "i", "i", "b") | ||
def grid_sampler(g, input, grid, mode_enum, padding_mode_enum, align_corners): | ||
mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg] | ||
padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum] # type: ignore[call-arg] | ||
return g.op( | ||
"GridSample", | ||
input, | ||
grid, | ||
align_corners_i=int(align_corners), | ||
mode_s=mode_s, | ||
padding_mode_s=padding_mode_s, | ||
) |