FFT: disable dimension wrapping for scalar tensors (pytorch#89234)
Fixes pytorch#88985

By default, `maybe_wrap_dim` allows through `dim=0` or `dim=-1`
for scalar tensors, which leads to an invalid dimension being used to
index into `tensor.sizes()`, as in the code sample from the issue.
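
A minimal sketch of the failure mode (the exact sample lives in pytorch#88985; `torch.fft.fft` stands in here for any of the affected transforms):

    import torch

    t = torch.tensor(1.0)  # scalar (0-dim) tensor: t.dim() == 0, t.shape == ()

    # Previously, dim=-1 was silently wrapped to 0 and then used to index
    # into the empty tensor.sizes(). With this change, the call raises:
    #   IndexError: Dimension specified as -1 but tensor has no dimensions
    torch.fft.fft(t)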

Pull Request resolved: pytorch#89234
Approved by: https://github.com/mruberry
peterbell10 authored and pytorchmergebot committed Nov 23, 2022
1 parent 50e2e4f commit ac19c5b
Showing 5 changed files with 90 additions and 25 deletions.
40 changes: 32 additions & 8 deletions aten/src/ATen/WrapDimUtils.h
@@ -38,14 +38,29 @@ inline int64_t maybe_wrap_dim(
return maybe_wrap_dim(dim, tensor_sizes[0].size());
}

-// wrap each dim in the dims array, taking dim_post_expr as the true number of
-// dimensions
+// Given an array of dimensions `dims` of length `ndims`, this function "Wraps"
+// each dim in-place for a tensor of rank `dim_post_expr`, allowing dims to be
+// specified using negative indices.
+//
+// Additionally, if `wrap_scalars` is true then scalar tensors with rank 0 will
+// allow dimensions in the range [-1, 0]. Otherwise, an IndexError is raised
+// for dimensions not in the range [-dim_post_expr, dim_post_expr).
inline void maybe_wrap_dims_n(
int64_t* dims,
int64_t ndims,
-int64_t dim_post_expr) {
+int64_t dim_post_expr,
+bool wrap_scalars = true) {
if (dim_post_expr <= 0) {
-dim_post_expr = 1; // this will make range [-1, 0]
+if (wrap_scalars) {
+dim_post_expr = 1; // this will make range [-1, 0]
+} else {
+TORCH_CHECK_INDEX(
+ndims == 0,
+"Dimension specified as ",
+dims[0],
+" but tensor has no dimensions");
+return;
+}
}
int64_t min = -dim_post_expr;
int64_t max = dim_post_expr - 1;
@@ -67,11 +82,20 @@ inline void maybe_wrap_dims_n(
}
}

-// Wrap each dim in a contiguous container, taking dim_post_expr as the true
-// number of dimensions E.g. could also be std::array or c10::SmallVector
+// Given a contiguous container of dimensions `dims`, this function "Wraps"
+// each dim in-place for a tensor of rank `dim_post_expr`, allowing dims to be
+// specified using negative indices.
+//
+// Additionally, if `wrap_scalars` is true then scalar tensors with rank 0 will
+// allow dimensions in the range [-1, 0]. Otherwise, an IndexError is raised
+// for dimensions not in the range [-dim_post_expr, dim_post_expr).
template <typename Container>
-inline void maybe_wrap_dims(Container& dims, int64_t dim_post_expr) {
-return maybe_wrap_dims_n(dims.data(), dims.size(), dim_post_expr);
+inline void maybe_wrap_dims(
+Container& dims,
+int64_t dim_post_expr,
+bool wrap_scalars = true) {
+return maybe_wrap_dims_n(
+dims.data(), dims.size(), dim_post_expr, wrap_scalars);
}

// previously, size [0] tensors were the only possible empty tensors; thus, it
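
For intuition, here is a minimal Python model of the wrapping rule `maybe_wrap_dims_n` implements; the name `wrap_dims` and the exact out-of-range message are illustrative, not part of the ATen API:

    def wrap_dims(dims, rank, wrap_scalars=True):
        # Map possibly-negative dims into [0, rank) for a tensor of rank `rank`.
        if rank <= 0:
            if wrap_scalars:
                rank = 1  # scalar tensors accept dims in [-1, 0]
            elif dims:
                raise IndexError(
                    f"Dimension specified as {dims[0]} but tensor has no dimensions")
            else:
                return []  # no dims requested for a scalar tensor: nothing to do
        lo, hi = -rank, rank - 1
        out = []
        for d in dims:
            if not lo <= d <= hi:
                raise IndexError(
                    f"Dimension out of range (expected [{lo}, {hi}], but got {d})")
            out.append(d + rank if d < 0 else d)
        return out

    wrap_dims([-1, 0], 2)                   # -> [1, 0]
    wrap_dims([-1], 0)                      # -> [0]  (scalar wrapping on by default)
    wrap_dims([-1], 0, wrap_scalars=False)  # raises IndexError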
12 changes: 6 additions & 6 deletions aten/src/ATen/native/SpectralOps.cpp
@@ -199,7 +199,7 @@ Tensor fft_c2r(c10::string_view function_name,
" expects a floating point output tensor, but got ", out.scalar_type());
input = promote_tensor_fft(input, /*require_complex=*/true);
const auto input_dim = input.dim();
-const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim);
+const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim, /*wrap_scalar=*/false);
const auto n = n_opt.value_or(2*(input.sizes()[dim] - 1));
TORCH_CHECK(n >= 1, "Invalid number of data points (", n, ") specified");
if (n_opt) {
@@ -225,7 +225,7 @@ Tensor fft_r2c(c10::string_view function_name,
" expects a complex output tensor, but got ", out.scalar_type());
input = promote_tensor_fft(input);
const auto input_dim = input.dim();
-const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim);
+const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim, /*wrap_scalar=*/false);
const auto n = n_opt.value_or(input.sizes()[dim]);
TORCH_CHECK(n >= 1, "Invalid number of data points (", n, ") specified");
if (n_opt) {
@@ -257,7 +257,7 @@ Tensor fft_c2c(c10::string_view function_name,
TORCH_CHECK(input.is_complex(), function_name,
" expects a complex input tensor, but got ", input.scalar_type());
const auto input_dim = input.dim();
-const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim);
+const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim, /*wrap_scalar=*/false);
const auto n = n_opt.value_or(input.sizes()[dim]);
TORCH_CHECK(n >= 1, "Invalid number of data points (", n, ") specified");
if (n_opt) {
@@ -284,7 +284,7 @@ ShapeAndDims canonicalize_fft_shape_and_dim_args(
if (dim) {
ret.dim.resize(dim->size());
std::copy(dim->begin(), dim->end(), ret.dim.begin());
-maybe_wrap_dims(ret.dim, input_dim);
+maybe_wrap_dims(ret.dim, input_dim, /*wrap_scalars=*/false);

// Check dims are unique
DimVector copy = ret.dim;
@@ -750,7 +750,7 @@ DimVector default_alldims(const Tensor& self, at::OptionalIntArrayRef dim_opt) {
IntArrayRef dim_unwrapped = *dim_opt;
dim.resize(dim_unwrapped.size());
for (const auto i : c10::irange(dim.size())) {
-dim[i] = maybe_wrap_dim(dim_unwrapped[i], self.dim());
+dim[i] = maybe_wrap_dim(dim_unwrapped[i], self.dim(), /*wrap_scalars=*/false);
}
} else {
dim.resize(self.dim());
@@ -1182,7 +1182,7 @@ void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) {
const auto input_strides = input.strides();
TORCH_CHECK(dim_.size() > 0);
DimVector dim(dim_.begin(), dim_.end());
-at::maybe_wrap_dims(dim, input_strides.size());
+at::maybe_wrap_dims(dim, input_strides.size(), /*wrap_scalars=*/false);

if (input.numel() == 0 || input_sizes[dim.back()] <= 2) {
return; // No elements need writing
15 changes: 8 additions & 7 deletions torch/_prims_common/__init__.py
@@ -473,8 +473,9 @@ def validate_exclusive_idx(rank: int, ex_idx: int):


# "Wraps" a dim (up to one time) for the given rank, allowing dims to be
-# specified using negative indices. For scalar tensors with rank 0, then idx
-# must be in the range [-1, 0]. Otherwise, idx should be in the range [-rank, rank-1].
+# specified using negative indices. If `wrap_scalar` is true then scalar
+# tensors of rank 0 will allow dimensions in the range [-1, 0]. Otherwise,
+# idx should be in the range [-rank, rank-1].
def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int:
if rank < 0:
msg = f"Rank cannot be negative but got {rank}"
@@ -507,20 +508,20 @@ def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int:
# Takes a dimension or sequence of dimensions and "wraps" them,
# mapping negative offsets to positive ones
@overload
-def canonicalize_dims(rank: int, indices: Sequence[int]) -> Tuple[int, ...]:
+def canonicalize_dims(rank: int, indices: Sequence[int], wrap_scalar: bool = True) -> Tuple[int, ...]:
pass


@overload
-def canonicalize_dims(rank: int, indices: int) -> int:
+def canonicalize_dims(rank: int, indices: int, wrap_scalar: bool = True) -> int:
pass


-def canonicalize_dims(rank, indices):
+def canonicalize_dims(rank, indices, wrap_scalar=True):
if isinstance(indices, Dim):
-return canonicalize_dim(rank, indices)
+return canonicalize_dim(rank, indices, wrap_scalar)

-return tuple(canonicalize_dim(rank, x) for x in indices)
+return tuple(canonicalize_dim(rank, x, wrap_scalar) for x in indices)


def is_valid_permutation(rank: int, perm: DimsSequenceType) -> bool:
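
A short usage sketch of the updated helpers; the return values follow the semantics described in the comment above:

    from torch._prims_common import canonicalize_dim, canonicalize_dims

    canonicalize_dims(4, (-1, -2))              # -> (3, 2)
    canonicalize_dim(0, -1)                     # -> 0 (scalar wrapping on by default)
    canonicalize_dim(0, -1, wrap_scalar=False)  # raises IndexError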
8 changes: 4 additions & 4 deletions torch/_refs/fft.py
@@ -115,7 +115,7 @@ def _fft_c2r(
) -> TensorLikeType:
"""Common code for performing any complex to real FFT (irfft or hfft)"""
input = _maybe_promote_tensor_fft(input, require_complex=True)
-dims = (utils.canonicalize_dim(input.ndim, dim),)
+dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
check(last_dim_size >= 1, lambda: f"Invalid number of data points ({n}) specified")

@@ -144,7 +144,7 @@ def _fft_r2c(
lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
)
input = _maybe_promote_tensor_fft(input)
-dims = (utils.canonicalize_dim(input.ndim, dim),)
+dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)

if n is not None:
input = _resize_fft_input(input, dims, (n,))
@@ -167,7 +167,7 @@ def _fft_c2c(
input.dtype.is_complex,
lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
)
-dims = (utils.canonicalize_dim(input.ndim, dim),)
+dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)

if n is not None:
input = _resize_fft_input(input, dims, (n,))
@@ -263,7 +263,7 @@ def _canonicalize_fft_shape_and_dim_args(
if dim is not None:
if not isinstance(dim, Sequence):
dim = (dim,)
-ret_dims = utils.canonicalize_dims(input_dim, dim)
+ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False)

# Check dims are unique
check(len(set(dim)) == len(dim), lambda: "FFT dims must be unique")
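
These reference implementations change in lockstep with the ATen kernels, so `torch._refs.fft` should raise the same IndexError for scalar inputs; the OpInfo error tests below exercise both paths. A quick sanity check, assuming a build that includes this patch:

    import torch
    import torch._refs.fft

    torch._refs.fft.fft(torch.tensor(1.0))  # IndexError, matching torch.fft.fft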
40 changes: 40 additions & 0 deletions torch/testing/_internal/opinfo/definitions/fft.py
@@ -1,4 +1,5 @@
import unittest
+from functools import partial
from typing import List

import numpy as np
@@ -15,6 +16,7 @@
from torch.testing._internal.common_utils import TEST_SCIPY, TEST_WITH_ROCM
from torch.testing._internal.opinfo.core import (
DecorateInfo,
+ErrorInput,
OpInfo,
SampleInput,
SpectralFuncInfo,
@@ -65,6 +67,26 @@ def __init__(
super().__init__(**ukwargs)


+def error_inputs_fft(op_info, device, **kwargs):
+make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+# Zero-dimensional tensor has no dimension to take FFT of
+yield ErrorInput(
+SampleInput(make_arg()),
+error_type=IndexError,
+error_regex="Dimension specified as -1 but tensor has no dimensions",
+)
+
+
+def error_inputs_fftn(op_info, device, **kwargs):
+make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+# Specifying a dimension on a zero-dimensional tensor
+yield ErrorInput(
+SampleInput(make_arg(), dim=(0,)),
+error_type=IndexError,
+error_regex="Dimension specified as 0 but tensor has no dimensions",
+)


def sample_inputs_fftshift(op_info, device, dtype, requires_grad, **kwargs):
def mt(shape, **kwargs):
return make_tensor(
@@ -97,6 +119,7 @@ def mt(shape, **kwargs):
else (torch.half, torch.complex32)
),
),
+error_inputs_func=error_inputs_fft,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
@@ -121,6 +144,7 @@ def mt(shape, **kwargs):
else (torch.half, torch.complex32)
),
),
+error_inputs_func=error_inputs_fftn,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
@@ -146,6 +170,7 @@ def mt(shape, **kwargs):
else (torch.half, torch.complex32)
),
),
+error_inputs_func=error_inputs_fftn,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
@@ -171,6 +196,7 @@ def mt(shape, **kwargs):
else (torch.half, torch.complex32)
),
),
+error_inputs_func=error_inputs_fft,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
@@ -205,6 +231,7 @@ def mt(shape, **kwargs):
else (torch.half, torch.complex32)
),
),
+error_inputs_func=error_inputs_fftn,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
@@ -245,6 +272,7 @@ def mt(shape, **kwargs):
else (torch.half, torch.complex32)
),
),
+error_inputs_func=error_inputs_fftn,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
@@ -280,6 +308,7 @@ def mt(shape, **kwargs):
dtypesIfCUDA=all_types_and(
torch.bool, *(() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,))
),
+error_inputs_func=error_inputs_fft,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
@@ -300,6 +329,7 @@ def mt(shape, **kwargs):
dtypesIfCUDA=all_types_and(
torch.bool, *(() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,))
),
+error_inputs_func=error_inputs_fftn,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
@@ -322,6 +352,7 @@ def mt(shape, **kwargs):
dtypesIfCUDA=all_types_and(
torch.bool, *(() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,))
),
+error_inputs_func=error_inputs_fftn,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
@@ -338,6 +369,7 @@ def mt(shape, **kwargs):
decomp_aten_name="_fft_c2c",
ref=np.fft.ifft,
ndimensional=SpectralFuncType.OneD,
+error_inputs_func=error_inputs_fft,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
@@ -362,6 +394,7 @@ def mt(shape, **kwargs):
decomp_aten_name="_fft_c2c",
ref=np.fft.ifft2,
ndimensional=SpectralFuncType.TwoD,
+error_inputs_func=error_inputs_fftn,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
@@ -393,6 +426,7 @@ def mt(shape, **kwargs):
decomp_aten_name="_fft_c2c",
ref=np.fft.ifftn,
ndimensional=SpectralFuncType.ND,
+error_inputs_func=error_inputs_fftn,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
@@ -424,6 +458,7 @@ def mt(shape, **kwargs):
decomp_aten_name="_fft_r2c",
ref=np.fft.ihfft,
ndimensional=SpectralFuncType.OneD,
+error_inputs_func=error_inputs_fft,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# See https://github.com/pytorch/pytorch/pull/78358
@@ -443,6 +478,7 @@ def mt(shape, **kwargs):
decomp_aten_name="_fft_r2c",
ref=scipy.fft.ihfftn if has_scipy_fft else None,
ndimensional=SpectralFuncType.TwoD,
+error_inputs_func=error_inputs_fftn,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
@@ -474,6 +510,7 @@ def mt(shape, **kwargs):
decomp_aten_name="_fft_r2c",
ref=scipy.fft.ihfftn if has_scipy_fft else None,
ndimensional=SpectralFuncType.ND,
+error_inputs_func=error_inputs_fftn,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
@@ -504,6 +541,7 @@ def mt(shape, **kwargs):
decomp_aten_name="_fft_c2r",
ref=np.fft.irfft,
ndimensional=SpectralFuncType.OneD,
+error_inputs_func=error_inputs_fft,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
@@ -529,6 +567,7 @@ def mt(shape, **kwargs):
decomp_aten_name="_fft_c2r",
ref=np.fft.irfft2,
ndimensional=SpectralFuncType.TwoD,
+error_inputs_func=error_inputs_fftn,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
@@ -561,6 +600,7 @@ def mt(shape, **kwargs):
decomp_aten_name="_fft_c2r",
ref=np.fft.irfftn,
ndimensional=SpectralFuncType.ND,
+error_inputs_func=error_inputs_fftn,
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
supports_forward_ad=True,
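
The ErrorInput generators added at the top of this file encode exactly this behavior; a manual equivalent of what the OpInfo harness checks with assertRaisesRegex:

    import torch

    try:
        torch.fft.fftn(torch.tensor(1.0), dim=(0,))
    except IndexError as e:
        assert "Dimension specified as 0 but tensor has no dimensions" in str(e)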
