[primTorch] Enable regex error testing for some refs (pytorch#87765)
Pull Request resolved: pytorch#87765
Approved by: https://github.com/mruberry
nkaretnikov authored and pytorchmergebot committed Nov 23, 2022
1 parent 3ad2a03 · commit 0a1a530
Showing 7 changed files with 81 additions and 87 deletions.
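
In short: instead of skipping TestCommon.test_python_ref_errors with expectedFailure, the refs' error messages are aligned with eager PyTorch (or given ref-specific expected messages selected via is_ref=True), so each op's ErrorInput regexes can be checked directly. A minimal sketch of the mechanism, using a hypothetical check_error helper in place of the real ErrorInput/test harness:

```python
import re

import torch


# Hypothetical stand-in for the OpInfo ErrorInput / test_python_ref_errors
# machinery: run `fn` and assert it raises `error_type` with a message
# matching `error_regex`.
def check_error(fn, error_type, error_regex):
    try:
        fn()
    except error_type as e:
        assert re.search(error_regex, str(e)), (
            f"message {str(e)!r} did not match regex {error_regex!r}"
        )
        return
    raise AssertionError(f"expected {error_type.__name__} to be raised")


# The shapes and regex below are taken verbatim from error_inputs_prelu in
# this diff; after this change the eager op and the ref raise the same
# message, so one regex covers both.
inp = torch.randn(2, 8, 3)
bad_weight = torch.randn(2, 4)  # neither a scalar nor a 1D tensor
check_error(
    lambda: torch.nn.functional.prelu(inp, bad_weight),
    RuntimeError,
    r"prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = 2",
)
```
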
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Activation.cpp
@@ -699,7 +699,7 @@ Tensor prelu_cpu(const Tensor& self, const Tensor& weight_) {
auto as_nd = [&](const Tensor& t) {
TORCH_CHECK(
t.dim() == 1 || t.dim() == 0,
- "prelu: Expected `weight` to be a scalar or 1D tensor, but got ndim = ", t.dim());
+ "prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = ", t.dim());
if (ndim >= 2) {
sizes[1] = t.dim() == 1 ? t.size(0) : 1;
strides[1] = t.dim() == 1 ? t.stride(0) : 0;
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Activation.cpp
@@ -114,7 +114,7 @@ Tensor prelu_cuda(const Tensor& self, const Tensor& weight_) {
Tensor result = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

TORCH_CHECK(weight_dim == 0 || weight_dim == 1,
- "prelu: Expected `weight` to be a scalar or 1D tensor, but got ndim = ",
+ "prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = ",
weight_dim);

// case1: shared weight for all channels
13 changes: 0 additions & 13 deletions test/test_nn.py
@@ -4175,19 +4175,6 @@ def test_mse_loss_size_warning(self):
self.assertEqual(len(w), 1)
self.assertIn('Please ensure they have the same size.', str(w[0]))

- def test_poisson_nll_loss_reduction_modes(self):
- input = torch.tensor([0.5, 1.5, 2.5])
- target = torch.tensor([1., 2., 3.])
- component_wise_loss = torch.exp(input) - target * input
- self.assertEqual(component_wise_loss,
- F.poisson_nll_loss(input, target, reduction='none'))
- self.assertEqual(torch.sum(component_wise_loss),
- F.poisson_nll_loss(input, target, reduction='sum'))
- self.assertEqual(torch.mean(component_wise_loss),
- F.poisson_nll_loss(input, target, reduction='mean'))
- with self.assertRaisesRegex(ValueError, 'is not valid'):
- F.poisson_nll_loss(input, target, reduction='total')
-
def test_gaussian_nll_loss_broadcasting(self):
input = torch.tensor([[0.5, 1.5, 2.5], [2., 4., 6.]])
target_full = torch.tensor([[1., 2., 3.], [1., 2., 3.]])
8 changes: 5 additions & 3 deletions torch/_prims_common/__init__.py
@@ -721,12 +721,14 @@ def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]:
lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
)
if dim is not None:
+ # Convert to list to produce a compatible error message with core
+ # PyTorch, which prints sequences in square brackets.
+ shape = list(shape)
check(
newsize != 0,
- lambda: f"cannot reshape tensor fo 0 elements into shape {shape} because the "
- f"unspecified dimension size -1 can be any value and is ambiguous",
+ lambda: (f"cannot reshape tensor of 0 elements into shape {shape} because the "
+ f"unspecified dimension size -1 can be any value and is ambiguous"),
)
- shape = list(shape)
shape[dim] = numel // newsize
return tuple(shape)

57 changes: 37 additions & 20 deletions torch/_refs/__init__.py
@@ -761,9 +761,14 @@ def nan_to_num(


def _neg_meta(a: TensorLikeType):
- if a.dtype is torch.bool:
- msg = "neg is not supported on bool tensors."
- raise RuntimeError(msg)
+ check(
+ a.dtype is not torch.bool,
+ lambda: (
+ "Negation, the `-` operator, on a bool tensor is not supported. "
+ "If you are trying to invert a mask, use the `~` or `logical_not()` "
+ "operator instead."
+ ),
+ )


@_make_elementwise_unary_reference(
@@ -2328,11 +2333,14 @@ def mean(
# reduces over all dimensions if dim=() is passed
if dim == () or dim == []:
dim = None
+ orig_dtype = dtype
if dtype is None:
dtype = a.dtype
# can't use out wrapper because of this argument
- if out is not None and out.dtype != dtype:
- raise RuntimeError("expected out dtype and dtype to match")
+ check(
+ out is None or out.dtype == dtype,
+ lambda: f"Expected out tensor to have dtype {dtype}, but got {out.dtype} instead",
+ )
result = _reduction(
a,
prims.sum,
@@ -2342,8 +2350,14 @@
out=None,
output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE,
)
- if utils.is_integer_dtype(dtype):
- raise RuntimeError("result type should be floating point or complex")
+ check(
+ utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
+ lambda: (
+ f"mean(): could not infer output dtype. "
+ f"{'Input' if orig_dtype is None else 'Optional'} dtype must be either "
+ f"a floating point or complex dtype. Got: {dtype}"
+ ),
+ )
if isinstance(dim, Dim):
dim = (dim,) # type: ignore[assignment]
dims = utils.reduction_dims(a.shape, dim) # type: ignore[arg-type]
@@ -3371,7 +3385,7 @@ def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
dim = utils.canonicalize_dim(t.ndim, dim)
check(
len(t.shape) > 0,
- lambda: "dimension specified as 0 but tensor has no dimensions",
+ lambda: "Dimension specified as 0 but tensor has no dimensions",
IndexError,
)
return tuple(
@@ -3621,12 +3635,12 @@ def vsplit(
check(
(split_size != 0 and a.shape[0] % split_size == 0),
lambda: (
- "torch.vsplit attempted to split along dimension 0 "
- + ", but the size of the dimension "
- + str(a.shape[0])
- + " is not divisible by the split_size "
- + str(split_size)
- + "!"
+ f"torch.vsplit attempted to split along dimension 0"
+ f", but the size of the dimension "
+ f"{a.shape[0]}"
+ f" is not divisible by the split_size "
+ f"{split_size}"
+ f"!"
),
)
return tensor_split(a, split_size, 0)
@@ -3792,7 +3806,7 @@ def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType:
)
if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0):
raise RuntimeError(
- "torch._refs.dsplit attempted to split along dimension 2, "
+ "torch.dsplit attempted to split along dimension 2, "
+ f"but the size of the dimension {a.shape[2]} is not divisible by the split_size {sections}!"
)
return tensor_split(a, sections, 2)
@@ -4446,12 +4460,14 @@ def movedim(
if type(destination) is int:
destination = (destination,)

+ # Converts to list to produce a compatible error message with core PyTorch,
+ # which prints sequences in square brackets.
utils.check(
len(source) == len(destination), # type: ignore[arg-type]
lambda: (
- "movedim: Invalid source or destination dims: source "
- f"({source} dims) should contain the same number of dims as "
- f"destination ({destination} dims)"
+ "movedim: Invalid source or destination dims: source " # type: ignore[arg-type]
+ f"({list(source)} dims) should contain the same number of dims as "
+ f"destination ({list(destination)} dims)"
),
)

@@ -4462,13 +4478,14 @@
sss = set(ss)
dss = set(ds)

+ # See above on why this converts to list in error messages.
utils.check(
len(ss) == len(sss),
- lambda: f"movedim: repeated dim in `source` {source}",
+ lambda: f"movedim: repeated dim in `source` ({list(source)})", # type: ignore[arg-type]
)
utils.check(
len(ds) == len(dss),
- lambda: f"movedim: repeated dim in `destination` {destination}",
+ lambda: f"movedim: repeated dim in `destination` ({list(destination)})", # type: ignore[arg-type]
)

m = dict(zip(ds, ss))
2 changes: 1 addition & 1 deletion torch/nn/functional.py
@@ -2764,7 +2764,7 @@ def poisson_nll_loss(
reduction = _Reduction.legacy_get_string(size_average, reduce)
if reduction != "none" and reduction != "mean" and reduction != "sum":
ret = input
- raise ValueError(reduction + " is not valid")
+ raise ValueError(reduction + " is not a valid value for reduction")

ret = torch.poisson_nll_loss(input, target, log_input, full, eps, _Reduction.get_enum(reduction))
return ret
84 changes: 36 additions & 48 deletions torch/testing/_internal/common_methods_invocations.py
@@ -527,7 +527,7 @@ def error_inputs_prelu(op, device):
inp = make_tensor((2, 8, 3), device=device, dtype=torch.float32)
weight = make_tensor((2, 4), device=device, dtype=torch.float32)
yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}),
- error_regex="prelu: Expected `weight` to be a scalar or 1D tensor, but got ndim = 2")
+ error_regex="prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = 2")

# src and index tensors must have the same # of dimensions
def sample_inputs_norm(op_info, device, dtype, requires_grad, **kwargs):
@@ -2428,7 +2428,7 @@ def sample_inputs_take_along_dim(op_info, device, dtype, requires_grad, **kwargs
make_arg((S, S)), gather_variable((S, S // 2), 0, S, True, device=device))


- def error_inputs_aminmax_amax_amin(op_info, device, **kwargs):
+ def error_inputs_aminmax_amax_amin(op_info, device, is_ref=False, **kwargs):

# Error Inputs for zero-dim tensors, when 'dim' arg is not provided.
shape = (S, 0, S)
@@ -2461,7 +2461,15 @@ def error_inputs_aminmax_amax_amin(op_info, device, **kwargs):
min_values = torch.empty(L, dtype=torch.double, device=device)
illegal_values = torch.empty(L, dtype=torch.int, device=device)

- err_msg_amax_amin2 = "Expected the dtype for input and out to match"
+ # Unlike regular PyTorch, amax and amin refs don't require input and out
+ # dtypes to match exactly:
+ # https://github.com/pytorch/pytorch/pull/87765#pullrequestreview-1162023824
+ if is_ref:
+ err_msg_amax_amin2 = ("Attempting to cast from torch.float32 to out tensor with dtype "
+ "torch.int32, but this can't be cast because it is not safe!")
+ else:
+ err_msg_amax_amin2 = ("Expected the dtype for input and out to match, but got Float "
+ "for input's dtype and Int for out's dtype.")
err_msg_aminmax2 = "Expected out tensor to have dtype float, but got double instead"

if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
@@ -7336,7 +7344,7 @@ def error_inputs_poisson_nll_loss(op_info, device, **kwargs):
yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),),
kwargs={'reduction': 'abc'}),
error_type=ValueError,
- error_regex='abc is not valid')
+ error_regex='abc is not a valid value for reduction')
# invalid input shapes
yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)),
error_regex=(r'(Attempting to broadcast a dimension of length|'
@@ -8138,18 +8146,28 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
numel = torch.tensor(t.shape)[kwargs.get('dim')].prod()
yield ((), {'correction': numel // 2})

- def error_inputs_mean(op_info, device, **kwargs):
- err_msg1 = (r"mean\(\): could not infer output dtype. "
- r"Input dtype must be either a floating point or complex dtype. "
- r"Got: Long")
+ def error_inputs_mean(op_info, device, is_ref=False, **kwargs):
+ if is_ref:
+ err_msg1 = (r"mean\(\): could not infer output dtype. "
+ r"Input dtype must be either a floating point or complex dtype. "
+ r"Got: torch.int64")
+ else:
+ err_msg1 = (r"mean\(\): could not infer output dtype. "
+ r"Input dtype must be either a floating point or complex dtype. "
+ r"Got: Long")
yield ErrorInput(
SampleInput(make_tensor((3, 4, 5), dtype=torch.int64, device=device), []),
error_regex=err_msg1,
)

- err_msg2 = (r"mean\(\): could not infer output dtype. "
- r"Optional dtype must be either a floating point or complex dtype. "
- r"Got: Long")
+ if is_ref:
+ err_msg2 = (r"mean\(\): could not infer output dtype. "
+ r"Optional dtype must be either a floating point or complex dtype. "
+ r"Got: torch.int64")
+ else:
+ err_msg2 = (r"mean\(\): could not infer output dtype. "
+ r"Optional dtype must be either a floating point or complex dtype. "
+ r"Got: Long")
yield ErrorInput(
SampleInput(
make_tensor((3, 4, 5), dtype=torch.float32, device=device),
@@ -8158,7 +8176,10 @@ def error_inputs_mean(op_info, device, **kwargs):
error_regex=err_msg2
)

- err_msg3 = "Expected out tensor to have dtype double, but got float instead"
+ if is_ref:
+ err_msg3 = "Expected out tensor to have dtype torch.float64, but got torch.float32 instead"
+ else:
+ err_msg3 = "Expected out tensor to have dtype double, but got float instead"
yield ErrorInput(
SampleInput(
make_tensor((3, 4, 5), dtype=torch.int64, device=device),
@@ -17125,9 +17146,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
aliases=('moveaxis',),
torch_opinfo_name="movedim",
supports_nvfuser=False,
- skips=(
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
- ),
),
PythonRefInfo(
"_refs.bucketize",
@@ -17323,9 +17341,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
ElementwiseUnaryPythonRefInfo(
"_refs.neg",
torch_opinfo_name="neg",
- skips=(
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
- ),
),
ElementwiseUnaryPythonRefInfo(
"_refs.positive",
@@ -17568,16 +17583,10 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
PythonRefInfo(
"_refs.nn.functional.poisson_nll_loss",
torch_opinfo_name="nn.functional.poisson_nll_loss",
- skips=(
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
- ),
),
ElementwiseUnaryPythonRefInfo(
"_refs.nn.functional.prelu",
torch_opinfo_name="nn.functional.prelu",
- skips=(
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
- ),
),
ElementwiseUnaryPythonRefInfo(
"_refs.nn.functional.relu",
@@ -18339,9 +18348,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
"_refs.dsplit",
torch_opinfo_name="dsplit",
supports_nvfuser=False,
- skips=(
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
- ),
),
PythonRefInfo(
"_refs.diag",
@@ -18465,9 +18471,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
"_refs.reshape",
torch_opinfo_name="reshape",
supports_nvfuser=False,
- skips=(
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
- ),
),
PythonRefInfo(
"_refs.reshape_as",
@@ -18516,9 +18519,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
"_refs.vsplit",
torch_opinfo_name="vsplit",
supports_nvfuser=False,
- skips=(
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
- ),
),
PythonRefInfo(
"_refs.transpose",
@@ -18552,9 +18552,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
"_refs.view",
torch_opinfo_name="view",
supports_nvfuser=False,
- skips=(
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
- ),
),
PythonRefInfo(
"_refs.view_as",
@@ -18579,9 +18576,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
"_refs.unbind",
torch_opinfo_name="unbind",
supports_nvfuser=False,
- skips=(
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
- ),
),
#
# Reduction Reference OpInfos
@@ -18593,16 +18587,12 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
ReductionPythonRefInfo(
"_refs.amax",
torch_opinfo_name="amax",
- skips=(
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
- ),
+ error_inputs_func=partial(error_inputs_aminmax_amax_amin, is_ref=True),
),
ReductionPythonRefInfo(
"_refs.amin",
torch_opinfo_name="amin",
- skips=(
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
- ),
+ error_inputs_func=partial(error_inputs_aminmax_amax_amin, is_ref=True),
),
ReductionPythonRefInfo(
"_refs.any",
@@ -18612,9 +18602,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
"_refs.mean",
torch_opinfo_name="mean",
supports_out=True,
- skips=(
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
- ),
+ error_inputs_func=partial(error_inputs_mean, is_ref=True),
),
ReductionPythonRefInfo(
"_refs.std",
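
Many of the message changes above funnel through the check helper that the refs use (utils.check / torch._prims_common.check), which takes the message as a lambda so formatting only happens on the error path. A simplified sketch of its contract; the parameter names, typing, and default exception type here are assumptions and may differ from the real helper:

```python
from typing import Callable, Type


def check(
    b: bool,
    s: Callable[[], str],
    exc_type: Type[Exception] = RuntimeError,
) -> None:
    """Simplified sketch of the refs' check helper.

    Raises exc_type(s()) when the condition is False; taking the message as a
    callable keeps the (often f-string) formatting off the success path.
    """
    if not b:
        raise exc_type(s())


# Usage mirroring the unbind ref above: the failing condition raises IndexError
# with a message that TestCommon.test_python_ref_errors can match by regex.
shape = ()
try:
    check(
        len(shape) > 0,
        lambda: "Dimension specified as 0 but tensor has no dimensions",
        IndexError,
    )
except IndexError as e:
    print(e)  # Dimension specified as 0 but tensor has no dimensions
```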