Unify checks for normal (pytorch#70087)
Summary: Pull Request resolved: pytorch#70087

Test Plan: Imported from OSS

Reviewed By: davidberard98

Differential Revision: D34089965

Pulled By: bdhirsh

fbshipit-source-id: 17d7eab3d8d60d03ca8ee63875ff2813bb2992c8
(cherry picked from commit 6cd1e9f)
nkaretnikov authored and pytorchmergebot committed Mar 1, 2022
1 parent 08493e0 commit dad0e0c
Showing 3 changed files with 18 additions and 11 deletions.
aten/src/ATen/native/DistributionTemplates.h (25 changes: 16 additions & 9 deletions)
@@ -198,9 +198,22 @@ static bool resize_output_for_normal(at::Tensor& output, const at::Tensor& mean,
 }
 }

+#define CHECK_NORMAL_TENSOR_STD(std) \
+  do { \
+    TORCH_CHECK( \
+      !std.is_complex(), \
+      "normal expects standard deviation to be non-complex"); \
+    TORCH_CHECK( \
+      std.numel() == 0 || std.min().ge(0).item<bool>(), \
+      "normal expects all elements of std >= 0.0"); \
+  } while (0)
+
+#define CHECK_NORMAL_STD(std) \
+  TORCH_CHECK(std >= 0.0, "normal expects std >= 0.0, but found std ", std);
+
 template<template<typename> class normal_kernel, typename RNG>
 Tensor& normal_impl_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
-  TORCH_CHECK(std >= 0.0, "normal_ expects std >= 0.0, but found std=", std);
+  CHECK_NORMAL_STD(std);
   if (self.is_complex()) {
     auto float_tensor = at::view_as_real(self);
     // variance for normal distribution of the real and imaginary values
@@ -221,10 +234,7 @@ Tensor& normal_out_impl(Tensor& output, const Tensor& mean, double std, c10::opt

 template<template<typename> class normal_kernel, typename RNG>
 Tensor& normal_out_impl(Tensor& output, double mean, const Tensor& std, c10::optional<Generator> gen) {
-  TORCH_CHECK(!std.is_complex(), "normal expects standard deviation to be non-complex");
-  TORCH_CHECK(
-    std.min().ge(0).item<bool>(),
-    "normal expects all elements of std >= 0.0");
+  CHECK_NORMAL_TENSOR_STD(std);
   normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
   auto mean_tensor = at::full({}, mean, output.options());
   // CUDA NB: addcmul_out copies the tensor to be added into the output.
@@ -238,10 +248,7 @@ Tensor& normal_out_impl(Tensor& output, double mean, const Tensor& std, c10::opt

 template<template<typename> class normal_kernel, typename RNG>
 Tensor& normal_out_impl(Tensor& output, const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
-  TORCH_CHECK(!std.is_complex(), "normal expects standard deviation to be non-complex");
-  TORCH_CHECK(
-    std.numel() == 0 || std.min().ge(0).item<bool>(),
-    "normal expects all elements of std >= 0.0");
+  CHECK_NORMAL_TENSOR_STD(std);
   bool is_deprecated_th_impl = resize_output_for_normal(output, mean, std);
   normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
   // CUDA NB: addcmul_out copies the tensor to be added into the output.

aten/src/ATen/native/Distributions.cpp (2 changes: 1 addition & 1 deletion)
@@ -262,7 +262,7 @@ Tensor& normal_(Tensor& self, double mean, double std, c10::optional<Generator>
 }

 Tensor& normal_meta_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
-  TORCH_CHECK(std >= 0.0, "normal_ expects std >= 0.0, but found std=", std); // TODO: dedupe
+  CHECK_NORMAL_STD(std);
   return self;
 }

test/test_tensor_creation_ops.py (2 changes: 1 addition & 1 deletion)
@@ -3360,7 +3360,7 @@ def test_normal_std_error(self, device):
         std = torch.tensor(-1, dtype=torch.float32, device=device)

         for input in [0, a]:
-            with self.assertRaisesRegex(RuntimeError, r'normal_ expects std >= 0.0'):
+            with self.assertRaisesRegex(RuntimeError, r'normal expects std >= 0.0, but found std'):
                 torch.normal(input, -1, (10,))

         with self.assertRaisesRegex(RuntimeError, r'normal expects all elements of std >= 0.0'):
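For context, a minimal illustrative sketch (not part of the commit) of how the unified checks are expected to surface through the Python API on a build that includes this change; the exact rendering of the offending std value in the scalar error message is an assumption:

    import torch

    # Scalar std: expected to raise "normal expects std >= 0.0, but found std ..."
    # via CHECK_NORMAL_STD (used by both normal_impl_ and normal_meta_).
    try:
        torch.normal(0, -1, (10,))
    except RuntimeError as e:
        print(e)

    # Tensor std: a negative element should trip CHECK_NORMAL_TENSOR_STD and raise
    # "normal expects all elements of std >= 0.0".
    try:
        torch.normal(torch.zeros(2), torch.tensor([1.0, -1.0]))
    except RuntimeError as e:
        print(e)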
