
[numpy] Add torch.nan_to_num #44592


Closed
Changes from all commits
33 commits
f28d58b
add nan_to_num kernel impl
kshitij12345 Sep 12, 2020
1473c8a
Merge branch 'master' into develop/numpy/nan_to_num
kshitij12345 Sep 12, 2020
039001b
extra machinery for c10::Half support
kshitij12345 Sep 12, 2020
f2b4aba
add docs
kshitij12345 Sep 12, 2020
5ec18c5
update overrides.py
kshitij12345 Sep 12, 2020
90083ca
update signature
kshitij12345 Sep 13, 2020
9a0dcb4
support gradient
kshitij12345 Sep 13, 2020
c14529f
add test vs numpy
kshitij12345 Sep 13, 2020
d159a0d
rename pos_inf -> posinf, neg_inf -> neginf
kshitij12345 Sep 13, 2020
bab7127
Merge branch 'master' into develop/numpy/nan_to_num
kshitij12345 Sep 21, 2020
d9b2e59
remove merge stray
kshitij12345 Sep 21, 2020
cecd74a
update argument names
kshitij12345 Sep 21, 2020
0746035
address comment
kshitij12345 Sep 21, 2020
0a66921
replace randn < 0.2 -> rand < 0.2
kshitij12345 Sep 21, 2020
1345137
address comment
kshitij12345 Sep 22, 2020
6408886
address comment
kshitij12345 Sep 22, 2020
5f5af08
update UnaryUfuncInfo DB
kshitij12345 Sep 22, 2020
ed94f3c
Merge branch 'master' into develop/numpy/nan_to_num
kshitij12345 Sep 22, 2020
823bfa3
update types for nan_to_num
kshitij12345 Sep 22, 2020
e771f8a
try fixing docs warnings
kshitij12345 Sep 22, 2020
7df69bb
try fixing doc warnings
kshitij12345 Sep 23, 2020
f8fc57d
try fixing docs - esacpe `s`
kshitij12345 Sep 23, 2020
06b23e0
update gradient formula
kshitij12345 Sep 24, 2020
3dc297b
fix kernel name
kshitij12345 Sep 24, 2020
c260821
update docs
kshitij12345 Sep 24, 2020
a7ba1c9
replace dispatch for intergers with inplace copy_
kshitij12345 Sep 24, 2020
81258a9
Merge branch 'develop/numpy/nan_to_num' into develop/numpy/nan_to_num
kshitij12345 Sep 25, 2020
d675bd7
move test to test_unary_funcs
kshitij12345 Sep 25, 2020
79428e4
Merge branch 'master' into develop/numpy/nan_to_num
kshitij12345 Sep 25, 2020
5288932
update kernel impl
kshitij12345 Sep 26, 2020
534f40b
remove unused _isfinite
kshitij12345 Sep 26, 2020
dffa5a9
address comment
kshitij12345 Sep 29, 2020
6141370
Merge branch 'master' into develop/numpy/nan_to_num
kshitij12345 Sep 29, 2020
4 changes: 2 additions & 2 deletions aten/src/ATen/NumericUtils.h
@@ -42,12 +42,12 @@ inline bool _isnan(T val) {
template <typename T,
typename std::enable_if<std::is_same<T, at::Half>::value, int>::type = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
-  return at::_isnan(float(val));
+  return at::_isnan(static_cast<float>(val));
}

inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
-  return at::_isnan(float(val));
+  return at::_isnan(static_cast<float>(val));
}

template <typename T>
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -502,6 +502,7 @@ _(aten, multinomial) \
_(aten, mv) \
_(aten, mvlgamma) \
_(aten, nansum) \
_(aten, nan_to_num) \
_(aten, narrow) \
_(aten, narrow_copy) \
_(aten, native_batch_norm) \
36 changes: 36 changes & 0 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -387,6 +387,41 @@ Tensor& logit_(Tensor& self, c10::optional<double> eps) {
return at::logit_out(self, self, eps);
}

Tensor& nan_to_num_out(
Tensor& result,
const Tensor& self,
c10::optional<double> nan,
c10::optional<double> pos_inf,
c10::optional<double> neg_inf) {

if (c10::isIntegralType(self.scalar_type())) {
result.resize_as_(self);
result.copy_(self);
return result;
}

auto iter = TensorIterator::unary_op(result, self);
Review comment — mruberry (Collaborator), Sep 24, 2020:
cc @ezyang @anjali411 this is another case where we could consider returning self (for integer tensors)

nan_to_num_stub(iter.device_type(), iter, nan, pos_inf, neg_inf);
return result;
}

Tensor nan_to_num(
const Tensor& self,
c10::optional<double> nan,
c10::optional<double> pos_inf,
c10::optional<double> neg_inf) {
auto result = at::empty_like(self);
return at::nan_to_num_out(result, self, nan, pos_inf, neg_inf);
}

Tensor& nan_to_num_(
Tensor& self,
c10::optional<double> nan,
c10::optional<double> pos_inf,
c10::optional<double> neg_inf) {
return at::nan_to_num_out(self, self, nan, pos_inf, neg_inf);
}

Tensor& tanh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, tanh_stub); }
Tensor tanh(const Tensor& self) { return unary_op_impl(self, at::tanh_out); }
Tensor& tanh_(Tensor& self) { return unary_op_impl_(self, at::tanh_out); }
@@ -645,6 +680,7 @@ DEFINE_DISPATCH(log1p_stub);
DEFINE_DISPATCH(log2_stub);
DEFINE_DISPATCH(logical_not_stub);
DEFINE_DISPATCH(neg_stub);
DEFINE_DISPATCH(nan_to_num_stub);
DEFINE_DISPATCH(polygamma_stub);
DEFINE_DISPATCH(reciprocal_stub);
DEFINE_DISPATCH(round_stub);
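
Aside from the replacement path, the isIntegralType branch in nan_to_num_out above turns the op into a plain copy for integer tensors, since integer dtypes cannot represent NaN or infinities. A minimal sketch of the resulting Python-level behavior, assuming a PyTorch build that includes this PR:

import torch

# Integer dtypes cannot hold NaN/inf, so nan_to_num simply copies the input.
a = torch.arange(5, dtype=torch.int64)
b = torch.nan_to_num(a)

assert torch.equal(a, b)             # values unchanged
assert a.data_ptr() != b.data_ptr()  # but a fresh tensor, not a view of a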
7 changes: 7 additions & 0 deletions aten/src/ATen/native/UnaryOps.h
@@ -77,6 +77,13 @@ DECLARE_DISPATCH(void(*)(TensorIterator&, c10::optional<Generator>), random_stub
DECLARE_DISPATCH(void(*)(TensorIterator&, const int64_t), polygamma_stub);
DECLARE_DISPATCH(void(*)(TensorIterator&, Scalar a, Scalar b), clamp_stub);
DECLARE_DISPATCH(void(*)(Tensor&, const Tensor&, int64_t, bool, c10::optional<Generator>), multinomial_stub);
DECLARE_DISPATCH(
void (*)(
TensorIterator&,
c10::optional<double>,
c10::optional<double>,
c10::optional<double>),
nan_to_num_stub);

// Missing unary functions
// digamma
28 changes: 28 additions & 0 deletions aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -383,6 +383,33 @@ static void polygamma_kernel(TensorIterator& iter, int64_t n) {
}
}

static void nan_to_num_kernel(
TensorIterator& iter,
c10::optional<double> nan,
c10::optional<double> pos_inf,
c10::optional<double> neg_inf) {
AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.dtype(), "nan_to_num", [&]() {
scalar_t nan_replacement = static_cast<scalar_t>(nan.value_or(0.));
scalar_t pos_inf_replacement = pos_inf.has_value()
? static_cast<scalar_t>(pos_inf.value())
: std::numeric_limits<scalar_t>::max();
scalar_t neg_inf_replacement = neg_inf.has_value()
? static_cast<scalar_t>(neg_inf.value())
: std::numeric_limits<scalar_t>::lowest();

cpu_kernel(iter, [=](scalar_t a) -> scalar_t {
return (
at::_isnan(a)
? nan_replacement
: (a == std::numeric_limits<scalar_t>::infinity()
? pos_inf_replacement
: (a == -std::numeric_limits<scalar_t>::infinity()
? neg_inf_replacement
: a)));
});
});
}

static void clamp_kernel(TensorIterator& iter, Scalar min_scalar, Scalar max_scalar) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "clamp_cpu", [&]() {
c10::scalar_value_type<scalar_t>::type (*zabs_)(scalar_t) = zabs;
@@ -648,6 +675,7 @@ REGISTER_DISPATCH(bitwise_not_stub, &bitwise_not_kernel);
REGISTER_DISPATCH(logical_not_stub, &logical_not_kernel);
REGISTER_DISPATCH(frac_stub, &frac_kernel);
REGISTER_DISPATCH(reciprocal_stub, &reciprocal_kernel);
REGISTER_DISPATCH(nan_to_num_stub, &nan_to_num_kernel);
REGISTER_DISPATCH(neg_stub, &neg_kernel);
REGISTER_DISPATCH(sign_stub, &sign_kernel);
REGISTER_DISPATCH(signbit_stub, &signbit_kernel);
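
For reference, the element-wise rule nan_to_num_kernel implements above (NaN -> nan, +inf -> posinf or the dtype's largest finite value, -inf -> neginf or the dtype's lowest finite value) can be sketched in NumPy. The helper name below is illustrative, not part of the PR:

import numpy as np

def nan_to_num_reference(a, nan=0.0, posinf=None, neginf=None):
    # Defaults mirror the kernel: fall back to the dtype's extreme finite values.
    finfo = np.finfo(a.dtype)
    posinf = finfo.max if posinf is None else posinf
    neginf = finfo.min if neginf is None else neginf
    out = a.copy()
    out[np.isnan(a)] = nan
    out[a == np.inf] = posinf
    out[a == -np.inf] = neginf
    return out

x = np.array([np.nan, np.inf, -np.inf, 3.14], dtype=np.float32)
print(nan_to_num_reference(x))  # NaN -> 0.0, +inf -> float32 max, -inf -> float32 lowest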
28 changes: 28 additions & 0 deletions aten/src/ATen/native/cuda/UnaryOpsKernel.cu
@@ -10,6 +10,7 @@
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/NumericUtils.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/complex.h>

@@ -180,6 +181,32 @@ void clamp_max_kernel_cuda(TensorIterator& iter, Scalar max_value) {
});
}

void nan_to_num_kernel_cuda(
TensorIterator& iter,
c10::optional<double> nan,
c10::optional<double> pos_inf,
c10::optional<double> neg_inf) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "nan_to_num_cuda", [&]() {
scalar_t nan_replacement = static_cast<scalar_t>(nan.value_or(0.));
scalar_t pos_inf_replacement = pos_inf.has_value()
? static_cast<scalar_t>(pos_inf.value())
: std::numeric_limits<scalar_t>::max();
scalar_t neg_inf_replacement = neg_inf.has_value()
? static_cast<scalar_t>(neg_inf.value())
: std::numeric_limits<scalar_t>::lowest();
gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t a) -> scalar_t {
return (
at::_isnan(a)
? nan_replacement
: (a == std::numeric_limits<scalar_t>::infinity()
? pos_inf_replacement
: (a == -std::numeric_limits<scalar_t>::infinity()
? neg_inf_replacement
: a)));
});
});
}

void kaiser_window_kernel_cuda(TensorIterator& iter, int64_t window_length, double beta){
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "kaiser_window_cuda", [&](){
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "kaiser_window_cuda", [&] {
@@ -206,6 +233,7 @@ REGISTER_DISPATCH(erfinv_stub, &erfinv_kernel_cuda);
REGISTER_DISPATCH(clamp_stub, &clamp_kernel_cuda);
REGISTER_DISPATCH(clamp_min_stub, &clamp_min_kernel_cuda);
REGISTER_DISPATCH(clamp_max_stub, &clamp_max_kernel_cuda);
REGISTER_DISPATCH(nan_to_num_stub, &nan_to_num_kernel_cuda);
REGISTER_DISPATCH(kaiser_window_stub, &kaiser_window_kernel_cuda);

} // namespace native
10 changes: 10 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1981,6 +1981,16 @@
CPU: layer_norm_backward_cpu
CUDA: layer_norm_backward_cuda

- func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
use_c10_dispatcher: full
variants: function, method

- func: nan_to_num_(Tensor(a!) self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor(a!)
use_c10_dispatcher: full
variants: function, method

- func: nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!)

- func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
use_c10_dispatcher: full
python_module: nn
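
The three schemas registered above surface as the functional, in-place, and out variants of the operator. A short usage sketch, assuming a build that includes this PR:

import torch

x = torch.tensor([float('nan'), float('inf'), -float('inf'), 1.5])

y = torch.nan_to_num(x, nan=0.0, posinf=1e4, neginf=-1e4)  # functional: returns a new tensor
x.nan_to_num_(nan=0.0, posinf=1e4, neginf=-1e4)            # in-place: modifies x itself
out = torch.empty_like(x)
torch.nan_to_num(x, out=out)                               # out variant: writes into out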
2 changes: 2 additions & 0 deletions docs/source/tensors.rst
@@ -453,6 +453,8 @@ view of a storage and defines numeric operations on it.
.. automethod:: narrow
.. automethod:: narrow_copy
.. automethod:: ndimension
.. automethod:: nan_to_num
.. automethod:: nan_to_num_
.. automethod:: ne
.. automethod:: ne_
.. automethod:: not_equal
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -312,6 +312,7 @@ Pointwise Ops
mul
multiply
mvlgamma
nan_to_num
neg
negative
nextafter
27 changes: 27 additions & 0 deletions test/test_autograd.py
@@ -4635,6 +4635,33 @@ def test(inp, inp_dtype, out_dtype):
test(inp, torch.float, torch.double)
test(inp, torch.double, torch.float)

def test_nan_to_num(self):
a = torch.randn(3, 3, 3, 3)
with torch.no_grad():
a[torch.rand_like(a) < 0.2] = float('nan')
a[torch.rand_like(a) < 0.2] = float('inf')
a[torch.rand_like(a) < 0.2] = -float('inf')

a.requires_grad = True

gradcheck(lambda x: x.nan_to_num(), a)
gradgradcheck(lambda x: x.nan_to_num(), a)

gradcheck(lambda x: x.nan_to_num(nan=1.2), a)
gradgradcheck(lambda x: x.nan_to_num(nan=1.2), a)

gradcheck(lambda x: x.nan_to_num(nan=1.2, posinf=2.0), a)
gradgradcheck(lambda x: x.nan_to_num(nan=1.2, posinf=2.0), a)

gradcheck(lambda x: x.nan_to_num(nan=1.2, posinf=2.0, neginf=-2.0), a)
gradgradcheck(lambda x: x.nan_to_num(nan=1.2, posinf=2.0, neginf=-2.0), a)

gradcheck(lambda x: x.nan_to_num(posinf=2.0, neginf=-2.0), a)
gradgradcheck(lambda x: x.nan_to_num(posinf=2.0, neginf=-2.0), a)

gradcheck(lambda x: x.nan_to_num(neginf=-2.0), a)
gradgradcheck(lambda x: x.nan_to_num(neginf=-2.0), a)

def test_custom_function_error(self):
class BadFw(Function):
@staticmethod
36 changes: 36 additions & 0 deletions test/test_unary_ufuncs.py
@@ -1,6 +1,7 @@
import math
from itertools import product, chain
from numbers import Number
import random

import unittest

@@ -377,6 +378,41 @@ def test_batch_vs_slicing(self, device, dtype, op):

self.assertEqual(actual, expected)

@dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False)))
def test_nan_to_num(self, device, dtype):
for contiguous in [False, True]:
x = make_tensor((64, 64), low=0., high=100., dtype=dtype, device=device)

if dtype.is_floating_point:
# Add extremal values.
extremals = [float('nan'), float('inf'), -float('inf')]
for idx, extremal in zip(torch.randint(0, 63, (3,)), extremals):
x[idx, :] = extremal

if not contiguous:
x = x.T

# With args
nan = random.random()
posinf = random.random() * 5
neginf = random.random() * 10

self.compare_with_numpy(lambda x: x.nan_to_num(nan=nan, posinf=posinf),
lambda x: np.nan_to_num(x, nan=nan, posinf=posinf),
x)
self.compare_with_numpy(lambda x: x.nan_to_num(posinf=posinf, neginf=neginf),
lambda x: np.nan_to_num(x, posinf=posinf, neginf=neginf),
x)

# Out Variant
out = torch.empty_like(x)
result = torch.nan_to_num(x)
torch.nan_to_num(x, out=out)
self.assertEqual(result, out)

result = torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
torch.nan_to_num(x, out=out, nan=nan, posinf=posinf, neginf=neginf)
self.assertEqual(result, out)

instantiate_device_type_tests(TestUnaryUfuncs, globals())

3 changes: 3 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -749,6 +749,9 @@
- name: mvlgamma(Tensor self, int p) -> Tensor
self: mvlgamma_backward(grad, self, p)

- name: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
self: grad * at::isfinite(self)

- name: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
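
The backward formula for nan_to_num above, grad * at::isfinite(self), passes the incoming gradient through wherever the input was already finite and zeroes it where the input was NaN or infinite (those positions are replaced by constants). A small sketch of the effect, assuming a build that includes this PR:

import torch

x = torch.tensor([float('nan'), float('inf'), 2.0], requires_grad=True)
y = torch.nan_to_num(x, nan=0.0, posinf=10.0, neginf=-10.0)
y.sum().backward()
print(x.grad)  # tensor([0., 0., 1.]) -- gradient only flows through the finite entry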

4 changes: 3 additions & 1 deletion tools/autograd/gen_variable_type.py
@@ -145,7 +145,9 @@
'quantize_per_tensor', 'quantize_per_channel',
# Functions that return integers should not have output that require gradients
'argmax', 'argmin', 'argsort', 'searchsorted',
-'bucketize'
+'bucketize',
+# Functions that return booleans are not differentiable
+'isnan', 'isposinf', 'isneginf', 'isinf'
}

# Some operators invalidate the grad_accumulator. Let's reset it.
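
The isnan/isposinf/isneginf/isinf entries added above go into the list of functions whose outputs never require grad: they return boolean tensors, and boolean tensors cannot carry gradients. A one-line illustration (torch.isnan is one of the ops listed; the snippet itself is just standard PyTorch):

import torch

x = torch.randn(3, requires_grad=True)
mask = torch.isnan(x)                  # boolean output
print(mask.dtype, mask.requires_grad)  # torch.bool False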
12 changes: 12 additions & 0 deletions torch/_tensor_docs.py
@@ -2342,6 +2342,18 @@ def callable(a, b) -> number
Alias for :meth:`~Tensor.dim()`
""")

add_docstr_all('nan_to_num', r"""
nan_to_num(nan=0.0, posinf=None, neginf=None) -> Tensor

See :func:`torch.nan_to_num`.
""")

add_docstr_all('nan_to_num_', r"""
nan_to_num_(nan=0.0, posinf=None, neginf=None) -> Tensor

In-place version of :meth:`~Tensor.nan_to_num`.
""")

add_docstr_all('ne', r"""
ne(other) -> Tensor

35 changes: 35 additions & 0 deletions torch/_torch_docs.py
@@ -5295,6 +5295,41 @@ def merge_dicts(*dicts):
[ 8, 9]])
""")

add_docstr(torch.nan_to_num,
r"""
nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None) -> Tensor

Replaces :literal:`NaN`, positive infinity, and negative infinity values in :attr:`input`
with the values specified by :attr:`nan`, :attr:`posinf`, and :attr:`neginf`, respectively.
By default, :literal:`NaN`s are replaced with zero, positive infinity is replaced with the
greatest finite value representable by :attr:`input`'s dtype, and negative infinity
is replaced with the least finite value representable by :attr:`input`'s dtype.

Args:
{input}
nan (Number, optional): the value to replace :literal:`NaN`\s with. Default is zero.
posinf (Number, optional): if a Number, the value to replace positive infinity values with.
If None, positive infinity values are replaced with the greatest finite value representable by :attr:`input`'s dtype.
Default is None.
neginf (Number, optional): if a Number, the value to replace negative infinity values with.
If None, negative infinity values are replaced with the lowest finite value representable by :attr:`input`'s dtype.
Default is None.

Keyword args:
{out}

Example::

>>> x = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
>>> torch.nan_to_num(x)
tensor([ 0.0000e+00, 3.4028e+38, -3.4028e+38, 3.1400e+00])
>>> torch.nan_to_num(x, nan=2.0)
tensor([ 2.0000e+00, 3.4028e+38, -3.4028e+38, 3.1400e+00])
>>> torch.nan_to_num(x, nan=2.0, posinf=1.0)
tensor([ 2.0000e+00, 1.0000e+00, -3.4028e+38, 3.1400e+00])

""".format(**common_args))

add_docstr(torch.ne, r"""
ne(input, other, *, out=None) -> Tensor

1 change: 1 addition & 0 deletions torch/overrides.py
@@ -488,6 +488,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.mv: lambda input, vec, out=None: -1,
torch.mvlgamma: lambda input, p: -1,
torch.narrow: lambda input, dim, start, length: -1,
torch.nan_to_num: lambda input, nan=0.0, posinf=None, neginf=None, out=None: -1,
torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1,
torch.native_layer_norm: lambda input, weight, bias, M, N, eps: -1,
torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,