Skip to content

Commit 1d43e2d

Browse files
authored
【PaddlePaddle Hackathon 2】8、为 Paddle 新增 nanmean API (#40472)
* Update __init__.py * Update math.py * Create test_nanmean_api.py * Update __init__.py * Update __init__.py * Update math.py * Update test_nanmean_api.py * Update __init__.py * Update math.py * Update test_nanmean_api.py * Update test_nanmean_api.py * Update test_nanmean_api.py * Update math.py * Update test_nanmean_api.py * Update math.py Update the nanmean example code * Update math.py * Update math.py * Update math.py Remove redundant code in nanmean * Update math.py change default keepdim = False * Update test_nanmean_api.py add nan into self.x * Update test_nanmean_api.py rerun CI check * Update test_nanmean_api.py * update code of nanmean in python/paddle/tensor/math.py and test_nanmean_api.py * Update test_nanmean_api.py update code format * Update test_nanmean_api.py update code format * Update test_nanmean_api.py add check grad code. * Update math.py update nanmean's description of Args x * Update test_nanmean_api.py update format and release the test_case(self.x, keepdim=True) in check grad code. * Update test_nanmean_api.py Update gradient checking method * Update test_nanmean_api.py update code format * Update test_nanmean_api.py Update code format and copyright in test_nanmean_api.py * Update math.py update arguments description and code example * Update math.py 修改了nanmean的axis参数的文档描述。 * Update math.py update nanmean's sample code (:name: code-example1) * Update math.py 修改nanmean的example code 错误 * Update math.py update example code * Update math.py update example code of nanmean
1 parent 176df91 commit 1d43e2d

File tree

4 files changed

+209
-0
lines changed

4 files changed

+209
-0
lines changed

python/paddle/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@
213213
from .tensor.math import stanh # noqa: F401
214214
from .tensor.math import sum # noqa: F401
215215
from .tensor.math import nansum # noqa: F401
216+
from .tensor.math import nanmean # noqa: F401
216217
from .tensor.math import tanh # noqa: F401
217218
from .tensor.math import tanh_ # noqa: F401
218219
from .tensor.math import add_n # noqa: F401
@@ -545,6 +546,7 @@
545546
'not_equal',
546547
'sum',
547548
'nansum',
549+
'nanmean',
548550
'tile',
549551
'greater_equal',
550552
'isfinite',
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import unittest
18+
import numpy as np
19+
import paddle
20+
import paddle.fluid as fluid
21+
import paddle.fluid.core as core
22+
from paddle.fluid import Program, program_guard
23+
24+
np.random.seed(10)
25+
26+
27+
class TestNanmeanAPI(unittest.TestCase):
    """Tests for ``paddle.nanmean`` (``paddle.tensor.math.nanmean``).

    Forward results are compared against ``np.nanmean`` in both static and
    dygraph mode; the backward pass is checked through the sum of the
    gradient with respect to the input.
    """

    # (axis, keepdim) combinations exercised on the 4-D input ``self.x``.
    X_CASES = (
        (None, False),
        ([], False),
        (-1, False),
        (None, True),
        (2, True),
        ([0, 2], False),
        ((0, 2), False),
        ([0, 1, 2, 3], False),
    )

    # (axis, keepdim) combinations exercised on the 2-D input ``self.x_grad``.
    X_GRAD_CASES = (
        (None, False),
        ([], False),
        (-1, False),
        (None, True),
        (0, True),
        (1, False),
        ((0, 1), False),
    )

    def setUp(self):
        self.x_shape = [2, 3, 4, 5]
        self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
        # The whole first slice is NaN, so reductions that do not include
        # axis 0 contain slices made only of NaNs.
        self.x[0, :, :, :] = np.nan
        self.x_grad = np.array([[np.nan, np.nan, 3.],
                                [0., np.nan, 2.]]).astype(np.float32)
        self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
            else paddle.CPUPlace()

    def test_api_static(self):
        """All public aliases agree with ``np.nanmean`` in static mode."""
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.fluid.data('X', self.x_shape)
            out1 = paddle.nanmean(x)
            out2 = paddle.tensor.nanmean(x)
            out3 = paddle.tensor.math.nanmean(x)
            axis = np.arange(len(self.x_shape)).tolist()
            out4 = paddle.nanmean(x, axis)
            out5 = paddle.nanmean(x, tuple(axis))
            exe = paddle.static.Executor(self.place)
            res = exe.run(feed={'X': self.x},
                          fetch_list=[out1, out2, out3, out4, out5])
        out_ref = np.nanmean(self.x)
        for out in res:
            self.assertTrue(np.allclose(out, out_ref, rtol=1e-04))

    def test_api_dygraph(self):
        """Forward results agree with ``np.nanmean`` in dygraph mode."""
        paddle.disable_static(self.place)

        def run_case(x, axis=None, keepdim=False):
            x_tensor = paddle.to_tensor(x)
            out = paddle.nanmean(x_tensor, axis, keepdim)
            # NumPy wants a tuple (not a list) for multiple axes; an empty
            # list is mapped to None, i.e. a reduction over every axis.
            if isinstance(axis, list):
                axis = tuple(axis)
                if len(axis) == 0:
                    axis = None

            out_ref = np.nanmean(x, axis, keepdims=keepdim)
            if np.isnan(out_ref).sum():
                # Positions where the reference is NaN are zeroed on both
                # sides before comparing, so only the finite entries are
                # actually checked.
                nan_mask = np.isnan(out_ref)
                out_ref[nan_mask] = 0
                out_np = out.numpy()
                out_np[nan_mask] = 0
                self.assertTrue(np.allclose(out_np, out_ref, rtol=1e-04))
            else:
                self.assertTrue(
                    np.allclose(out.numpy(), out_ref, rtol=1e-04))

        try:
            for axis, keepdim in self.X_CASES:
                run_case(self.x, axis, keepdim)
        finally:
            # Restore static mode even when an assertion fails so later
            # tests are not left running in dygraph mode.
            paddle.enable_static()

    def test_errors(self):
        """An int32 input is rejected with ``TypeError`` in static mode."""
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.fluid.data('X', [10, 12], 'int32')
            self.assertRaises(TypeError, paddle.nanmean, x)

    def test_api_dygraph_grad(self):
        """The input gradient sums to the number of non-NaN outputs."""
        paddle.disable_static(self.place)

        def run_case(x, axis=None, keepdim=False):
            # An empty axis list means "reduce over every axis".
            if isinstance(axis, list) and len(axis) == 0:
                axis = None
            x_tensor = paddle.to_tensor(x, stop_gradient=False)
            y = paddle.nanmean(x_tensor, axis, keepdim)
            dx = paddle.grad(y, x_tensor)[0].numpy()
            # Expected gradient sum: one per output element, minus the
            # outputs that are NaN themselves.
            sum_dx_ref = np.prod(y.shape)
            if np.isnan(y.numpy()).sum():
                sum_dx_ref -= np.isnan(y.numpy()).sum()
            cnt = paddle.sum(~paddle.isnan(x_tensor),
                             axis=axis,
                             keepdim=keepdim)
            # When some reduced slice has no valid element, NaN gradient
            # entries are zeroed before summing.
            if (cnt == 0).sum():
                dx[np.isnan(dx)] = 0
            sum_dx = dx.sum()
            self.assertTrue(np.allclose(sum_dx, sum_dx_ref, rtol=1e-04))

        try:
            for axis, keepdim in self.X_CASES:
                run_case(self.x, axis, keepdim)
            for axis, keepdim in self.X_GRAD_CASES:
                run_case(self.x_grad, axis, keepdim)
        finally:
            # Restore static mode even when an assertion fails.
            paddle.enable_static()
134+
135+
136+
if __name__ == "__main__":
137+
unittest.main()

python/paddle/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@
165165
from .math import stanh # noqa: F401
166166
from .math import sum # noqa: F401
167167
from .math import nansum # noqa: F401
168+
from .math import nanmean # noqa: F401
168169
from .math import tanh # noqa: F401
169170
from .math import tanh_ # noqa: F401
170171
from .math import add_n # noqa: F401
@@ -333,6 +334,7 @@
333334
'stanh',
334335
'sum',
335336
'nansum',
337+
'nanmean',
336338
'tanh',
337339
'tanh_',
338340
'add_n',

python/paddle/tensor/math.py

100755100644
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,73 @@ def nansum(x, axis=None, dtype=None, keepdim=False, name=None):
10241024
return sum(tmp_tensor, axis, dtype, keepdim, name)
10251025

10261026

1027+
def nanmean(x, axis=None, keepdim=False, name=None):
    r"""
    Compute the arithmetic mean along the specified axis, ignoring NaNs.

    The result is the NaN-ignoring sum of ``x`` divided by the number of
    non-NaN elements, both reduced over ``axis``.

    Args:
        x (Tensor): The input Tensor with data type uint16, float16, float32,
            float64.
        axis (int|list|tuple, optional): The axis or axes along which the
            mean is computed. ``axis`` should be int, list(int) or
            tuple(int). If ``axis`` is a list/tuple, the mean is computed
            over all of the listed dimensions. ``axis`` or each element of
            ``axis`` should be in range [-D, D), where D is the number of
            dimensions of ``x`` . A negative value works the same way as
            :math:`axis + D` . If ``axis`` is None, the mean is computed
            over all elements of ``x``. Default is None.
        keepdim (bool, optional): Whether to keep the reduced dimension(s)
            in the output Tensor. If ``keepdim`` is True, the output has the
            same number of dimensions as ``x`` with the reduced dimensions
            of size 1. Otherwise the reduced dimensions are squeezed out.
            Default is False.
        name (str, optional): Name for the operation (optional, default is
            None). For more information, please refer to
            :ref:`api_guide_Name`.

    Returns:
        Tensor, results of arithmetic mean along ``axis`` of ``x``, with the
        same data type as ``x``.

    Examples:

        .. code-block:: python
            :name: code-example1

            import paddle
            # x is a 2-D Tensor:
            x = paddle.to_tensor([[float('nan'), 0.3, 0.5, 0.9],
                                  [0.1, 0.2, float('-nan'), 0.7]])
            out1 = paddle.nanmean(x)
            # [0.44999996]
            out2 = paddle.nanmean(x, axis=0)
            # [0.1, 0.25, 0.5, 0.79999995]
            out3 = paddle.nanmean(x, axis=0, keepdim=True)
            # [[0.1, 0.25, 0.5, 0.79999995]]
            out4 = paddle.nanmean(x, axis=1)
            # [0.56666666 0.33333334]
            out5 = paddle.nanmean(x, axis=1, keepdim=True)
            # [[0.56666666]
            #  [0.33333334]]

            # y is a 3-D Tensor:
            y = paddle.to_tensor([[[1, float('nan')], [3, 4]],
                                  [[5, 6], [float('-nan'), 8]]])
            out6 = paddle.nanmean(y, axis=[1, 2])
            # [2.66666675, 6.33333349]
            out7 = paddle.nanmean(y, axis=[0, 1])
            # [3., 6.]
    """
    supported_dtypes = ['uint16', 'float16', 'float32', 'float64']
    check_variable_and_dtype(x, 'x/input', supported_dtypes, 'nanmean')

    if axis is not None:
        check_type(axis, 'axis/dim', (int, list, tuple), 'nanmean')
        # A single int axis is wrapped in a list before being passed on.
        if isinstance(axis, int):
            axis = [axis]

    # Number of non-NaN entries contributing to each reduced element.
    valid_count = paddle.sum(~paddle.isnan(x), axis=axis, keepdim=keepdim)
    nan_free_total = paddle.nansum(x, axis=axis, keepdim=keepdim, name=name)
    # The count is cast to the dtype of ``x`` before the division.
    return paddle.divide(nan_free_total, valid_count.astype(x.dtype))
1092+
1093+
10271094
@templatedoc(op_type="sum")
10281095
def add_n(inputs, name=None):
10291096
"""
@@ -3941,6 +4008,7 @@ def diff(x, n=1, axis=-1, prepend=None, append=None, name=None):
39414008
else:
39424009
out = elementwise_sub(input_back, input_front, axis=axis)
39434010
return out
4011+
39444012
else:
39454013
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'bool', 'int32', 'int64'], 'diff')
39464014
check_type(axis, 'axis', (int), 'diff')

0 commit comments

Comments
 (0)