Commit c488c59

fix paddle.nn.loss.L1Loss OP, add paddle.nn.functional.l1_loss OP for API2.0, test=develop
1 parent 0cb60c7 commit c488c59

File tree

4 files changed: +295 additions, -159 deletions

python/paddle/fluid/tests/unittests/test_l1_loss.py

Lines changed: 149 additions & 99 deletions
@@ -20,111 +20,161 @@
 import unittest
 
 
-class TestL1Loss(unittest.TestCase):
-    def test_L1Loss_mean(self):
-        input_np = np.random.random(size=(10, 1)).astype(np.float32)
-        label_np = np.random.random(size=(10, 1)).astype(np.float32)
-        prog = fluid.Program()
-        startup_prog = fluid.Program()
-        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
-        ) else fluid.CPUPlace()
-        with fluid.program_guard(prog, startup_prog):
-            input = fluid.layers.data(
-                name='input', shape=[10, 1], dtype='float32')
-            label = fluid.layers.data(
-                name='label', shape=[10, 1], dtype='float32')
-            l1_loss = paddle.nn.loss.L1Loss()
-            ret = l1_loss(input, label)
-
-            exe = fluid.Executor(place)
-            static_result = exe.run(
-                prog,
-                feed={"input": input_np,
-                      "label": label_np},
-                fetch_list=[ret])
-
-        with fluid.dygraph.guard():
-            l1_loss = paddle.nn.loss.L1Loss()
-            dy_ret = l1_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
-            dy_result = dy_ret.numpy()
-
-        expected = np.mean(np.abs(input_np - label_np))
-        self.assertTrue(np.allclose(static_result, expected))
-        self.assertTrue(np.allclose(static_result, dy_result))
-        self.assertTrue(np.allclose(dy_result, expected))
+class TestFunctionalL1Loss(unittest.TestCase):
+    def setUp(self):
+        self.input_np = np.random.random(size=(10, 10, 5)).astype(np.float32)
+        self.label_np = np.random.random(size=(10, 10, 5)).astype(np.float32)
+
+    def run_imperative(self):
+        input = paddle.imperative.to_variable(self.input_np)
+        label = paddle.imperative.to_variable(self.label_np)
+        dy_result = paddle.nn.functional.l1_loss(input, label)
+        expected = np.mean(np.abs(self.input_np - self.label_np))
+        self.assertTrue(np.allclose(dy_result.numpy(), expected))
         self.assertTrue(dy_result.shape, [1])
 
-    def test_L1Loss_sum(self):
-        input_np = np.random.random(size=(10, 10, 5)).astype(np.float32)
-        label_np = np.random.random(size=(10, 10, 5)).astype(np.float32)
-        prog = fluid.Program()
-        startup_prog = fluid.Program()
-        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
-        ) else fluid.CPUPlace()
-        with fluid.program_guard(prog, startup_prog):
-            input = fluid.layers.data(
+        dy_result = paddle.nn.functional.l1_loss(input, label, reduction='sum')
+        expected = np.sum(np.abs(self.input_np - self.label_np))
+        self.assertTrue(np.allclose(dy_result.numpy(), expected))
+        self.assertTrue(dy_result.shape, [1])
+
+        dy_result = paddle.nn.functional.l1_loss(input, label, reduction='none')
+        expected = np.abs(self.input_np - self.label_np)
+        self.assertTrue(np.allclose(dy_result.numpy(), expected))
+        self.assertTrue(dy_result.shape, [10, 10, 5])
+
+    def run_static(self, use_gpu=False):
+        input = paddle.data(name='input', shape=[10, 10, 5], dtype='float32')
+        label = paddle.data(name='label', shape=[10, 10, 5], dtype='float32')
+        result0 = paddle.nn.functional.l1_loss(input, label)
+        result1 = paddle.nn.functional.l1_loss(input, label, reduction='sum')
+        result2 = paddle.nn.functional.l1_loss(input, label, reduction='none')
+        y = paddle.nn.functional.l1_loss(input, label, name='aaa')
+
+        place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        exe.run(fluid.default_startup_program())
+        static_result = exe.run(
+            feed={"input": self.input_np,
+                  "label": self.label_np},
+            fetch_list=[result0, result1, result2])
+
+        expected = np.mean(np.abs(self.input_np - self.label_np))
+        self.assertTrue(np.allclose(static_result[0], expected))
+        expected = np.sum(np.abs(self.input_np - self.label_np))
+        self.assertTrue(np.allclose(static_result[1], expected))
+        expected = np.abs(self.input_np - self.label_np)
+        self.assertTrue(np.allclose(static_result[2], expected))
+
+        self.assertTrue('aaa' in y.name)
+
+    def test_cpu(self):
+        with paddle.imperative.guard(paddle.fluid.CPUPlace()):
+            self.run_imperative()
+
+        with fluid.program_guard(fluid.Program()):
+            self.run_static()
+
+    def test_gpu(self):
+        if not fluid.core.is_compiled_with_cuda():
+            return
+
+        with paddle.imperative.guard(paddle.fluid.CUDAPlace(0)):
+            self.run_imperative()
+
+        with fluid.program_guard(fluid.Program()):
+            self.run_static(use_gpu=True)
+
+    # test the error message raised for an invalid reduction
+    def test_errors(self):
+        def test_value_error():
+            input = paddle.data(
                 name='input', shape=[10, 10, 5], dtype='float32')
-            label = fluid.layers.data(
+            label = paddle.data(
                 name='label', shape=[10, 10, 5], dtype='float32')
-            l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
-            ret = l1_loss(input, label)
-
-            exe = fluid.Executor(place)
-            static_result = exe.run(
-                prog,
-                feed={"input": input_np,
-                      "label": label_np},
-                fetch_list=[ret])
-
-        with fluid.dygraph.guard():
-            l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
-            dy_ret = l1_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
-            dy_result = dy_ret.numpy()
-
-        expected = np.sum(np.abs(input_np - label_np))
-        self.assertTrue(np.allclose(static_result, expected))
-        self.assertTrue(np.allclose(static_result, dy_result))
-        self.assertTrue(np.allclose(dy_result, expected))
+            loss = paddle.nn.functional.l1_loss(
+                input, label, reduction='reduce_mean')
+
+        self.assertRaises(ValueError, test_value_error)
+
+
+class TestClassL1Loss(unittest.TestCase):
+    def setUp(self):
+        self.input_np = np.random.random(size=(10, 10, 5)).astype(np.float32)
+        self.label_np = np.random.random(size=(10, 10, 5)).astype(np.float32)
+
+    def run_imperative(self):
+        input = paddle.imperative.to_variable(self.input_np)
+        label = paddle.imperative.to_variable(self.label_np)
+        l1_loss = paddle.nn.loss.L1Loss()
+        dy_result = l1_loss(input, label)
+        expected = np.mean(np.abs(self.input_np - self.label_np))
+        self.assertTrue(np.allclose(dy_result.numpy(), expected))
+        self.assertTrue(dy_result.shape, [1])
+
+        l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
+        dy_result = l1_loss(input, label)
+        expected = np.sum(np.abs(self.input_np - self.label_np))
+        self.assertTrue(np.allclose(dy_result.numpy(), expected))
         self.assertTrue(dy_result.shape, [1])
 
-    def test_L1Loss_none(self):
-        input_np = np.random.random(size=(10, 5)).astype(np.float32)
-        label_np = np.random.random(size=(10, 5)).astype(np.float32)
-        prog = fluid.Program()
-        startup_prog = fluid.Program()
-        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
-        ) else fluid.CPUPlace()
-        with fluid.program_guard(prog, startup_prog):
-            input = fluid.layers.data(
-                name='input', shape=[10, 5], dtype='float32')
-            label = fluid.layers.data(
-                name='label', shape=[10, 5], dtype='float32')
-            l1_loss = paddle.nn.loss.L1Loss(reduction='none')
-            ret = l1_loss(input, label)
-
-            exe = fluid.Executor(place)
-            static_result = exe.run(
-                prog,
-                feed={"input": input_np,
-                      "label": label_np},
-                fetch_list=[ret])
-
-        with fluid.dygraph.guard():
-            l1_loss = paddle.nn.loss.L1Loss(reduction='none')
-            dy_ret = l1_loss(
-                fluid.dygraph.to_variable(input_np),
-                fluid.dygraph.to_variable(label_np))
-            dy_result = dy_ret.numpy()
-
-        expected = np.abs(input_np - label_np)
-        self.assertTrue(np.allclose(static_result, expected))
-        self.assertTrue(np.allclose(static_result, dy_result))
-        self.assertTrue(np.allclose(dy_result, expected))
-        self.assertTrue(dy_result.shape, input.shape)
+        l1_loss = paddle.nn.loss.L1Loss(reduction='none')
+        dy_result = l1_loss(input, label)
+        expected = np.abs(self.input_np - self.label_np)
+        self.assertTrue(np.allclose(dy_result.numpy(), expected))
+        self.assertTrue(dy_result.shape, [10, 10, 5])
+
+    def run_static(self, use_gpu=False):
+        input = paddle.data(name='input', shape=[10, 10, 5], dtype='float32')
+        label = paddle.data(name='label', shape=[10, 10, 5], dtype='float32')
+        l1_loss = paddle.nn.loss.L1Loss()
+        result0 = l1_loss(input, label)
+        l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
+        result1 = l1_loss(input, label)
+        l1_loss = paddle.nn.loss.L1Loss(reduction='none')
+        result2 = l1_loss(input, label)
+        l1_loss = paddle.nn.loss.L1Loss(name='aaa')
+        result3 = l1_loss(input, label)
+
+        place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        exe.run(fluid.default_startup_program())
+        static_result = exe.run(
+            feed={"input": self.input_np,
+                  "label": self.label_np},
+            fetch_list=[result0, result1, result2])
+
+        expected = np.mean(np.abs(self.input_np - self.label_np))
+        self.assertTrue(np.allclose(static_result[0], expected))
+        expected = np.sum(np.abs(self.input_np - self.label_np))
+        self.assertTrue(np.allclose(static_result[1], expected))
+        expected = np.abs(self.input_np - self.label_np)
+        self.assertTrue(np.allclose(static_result[2], expected))
+        self.assertTrue('aaa' in result3.name)
+
+    def test_cpu(self):
+        with paddle.imperative.guard(paddle.fluid.CPUPlace()):
+            self.run_imperative()
+
+        with fluid.program_guard(fluid.Program()):
+            self.run_static()
+
+    def test_gpu(self):
+        if not fluid.core.is_compiled_with_cuda():
+            return
+
+        with paddle.imperative.guard(paddle.fluid.CUDAPlace(0)):
+            self.run_imperative()
+
+        with fluid.program_guard(fluid.Program()):
+            self.run_static(use_gpu=True)
+
+    # test the error message raised for an invalid reduction
+    def test_errors(self):
+        def test_value_error():
+            loss = paddle.nn.loss.L1Loss(reduction="reduce_mean")
+
+        self.assertRaises(ValueError, test_value_error)
 
 
 if __name__ == "__main__":
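All three reduction modes exercised above reduce the same elementwise absolute difference, as the `expected = np.mean(...)` / `np.sum(...)` / `np.abs(...)` assertions show. A minimal NumPy reference makes this explicit; the `l1_loss_ref` helper below is illustrative only and is not part of the commit:

import numpy as np

def l1_loss_ref(x, label, reduction='mean'):
    # elementwise absolute error, then the requested reduction
    diff = np.abs(x - label)
    if reduction == 'mean':
        return np.mean(diff)
    if reduction == 'sum':
        return np.sum(diff)
    return diff  # reduction='none': unreduced loss, same shape as x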

python/paddle/nn/functional/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -127,6 +127,7 @@
 from .loss import huber_loss #DEFINE_ALIAS
 from .loss import iou_similarity #DEFINE_ALIAS
 from .loss import kldiv_loss #DEFINE_ALIAS
+from .loss import l1_loss #DEFINE_ALIAS
 from .loss import log_loss #DEFINE_ALIAS
 from .loss import margin_rank_loss #DEFINE_ALIAS
 from .loss import mse_loss #DEFINE_ALIAS
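With this one-line alias, `l1_loss` becomes reachable from the public `paddle.nn.functional` namespace alongside the other losses. A quick sanity check, assuming a build that contains this commit:

import paddle.nn.functional as F

# the alias should now resolve to the function defined in functional/loss.py
assert hasattr(F, 'l1_loss')
print(F.l1_loss)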

python/paddle/nn/functional/loss.py

Lines changed: 95 additions & 0 deletions
@@ -13,6 +13,10 @@
 # limitations under the License.
 
 # TODO: define loss functions of neural network
+import paddle
+import paddle.fluid as fluid
+from ...fluid.framework import core, in_dygraph_mode
+from ...fluid.layers.nn import _elementwise_op_in_dygraph
 from ...fluid.layers import bpr_loss #DEFINE_ALIAS
 from ...fluid.layers import center_loss #DEFINE_ALIAS
 from ...fluid.layers import cross_entropy #DEFINE_ALIAS
@@ -45,6 +49,7 @@
     'huber_loss',
     'iou_similarity',
     'kldiv_loss',
+    'l1_loss',
     'log_loss',
     'margin_rank_loss',
     'mse_loss',
@@ -60,3 +65,93 @@
     'ssd_loss',
     'teacher_student_sigmoid_loss'
 ]
+
+
+def l1_loss(x, label, reduction='mean', name=None):
+    """
+    This operator computes the L1 Loss of Tensor ``x`` and ``label`` as follows.
+
+    If :attr:`reduction` is set to ``'none'``, the loss is:
+
+    .. math::
+        Out = \lvert x - label\rvert
+
+    If :attr:`reduction` is set to ``'mean'``, the loss is:
+
+    .. math::
+        Out = MEAN(\lvert x - label\rvert)
+
+    If :attr:`reduction` is set to ``'sum'``, the loss is:
+
+    .. math::
+        Out = SUM(\lvert x - label\rvert)
+
+
+    Parameters:
+        x (Tensor): The input tensor. Its shape is [N, *], where N is the batch size and `*` means any number of additional dimensions. Its data type should be float32, float64, int32 or int64.
+        label (Tensor): The label tensor. Its shape is [N, *], the same as ``x``, and its data type should be float32, float64, int32 or int64.
+        reduction (str, optional): Indicates the reduction to apply to the loss;
+            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
+            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
+            if :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
+            if :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
+            Default is ``'mean'``.
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+    Returns:
+        Tensor, the L1 Loss of Tensor ``x`` and ``label``.
+        If :attr:`reduction` is ``'none'``, the shape of the output loss is [N, *], the same as ``x``.
+        If :attr:`reduction` is ``'mean'`` or ``'sum'``, the shape of the output loss is [1], i.e. the output is a scalar.
+    Examples:
+        .. code-block:: python
+            import paddle
+            import numpy as np
+            from paddle.imperative import to_variable
+
+            paddle.enable_imperative()
+            x_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32")
+            label_data = np.array([[1.7, 1], [0.4, 0.5]]).astype("float32")
+            x = to_variable(x_data)
+            label = to_variable(label_data)
+
+            l1_loss = paddle.nn.functional.l1_loss(x, label)
+            print(l1_loss.numpy())
+            # [0.35]
+
+            l1_loss = paddle.nn.functional.l1_loss(x, label, reduction='none')
+            print(l1_loss.numpy())
+            # [[0.20000005 0.19999999]
+            #  [0.2        0.79999995]]
+
+            l1_loss = paddle.nn.functional.l1_loss(x, label, reduction='sum')
+            print(l1_loss.numpy())
+            # [1.4]
+    """
+    if reduction not in ['sum', 'mean', 'none']:
+        raise ValueError(
+            "The value of 'reduction' in L1Loss should be 'sum', 'mean' or 'none', but "
+            "received %s, which is not allowed." % reduction)
+
+    if in_dygraph_mode():
+        unreduced = _elementwise_op_in_dygraph(
+            x, label, axis=-1, act='abs', op_name='elementwise_sub')
+        if reduction == 'mean':
+            return core.ops.mean(unreduced)
+        elif reduction == 'sum':
+            return core.ops.reduce_sum(unreduced, 'dim', [0], 'keep_dim', False,
+                                       'reduce_all', True)
+        else:
+            return unreduced
+
+    fluid.data_feeder.check_variable_and_dtype(
+        x, 'x', ['float32', 'float64', 'int32', 'int64'], 'l1_loss')
+    fluid.data_feeder.check_variable_and_dtype(
+        label, 'label', ['float32', 'float64', 'int32', 'int64'], 'l1_loss')
+
+    if reduction == 'sum':
+        unreduced = paddle.elementwise_sub(x, label, act='abs')
+        return paddle.sum(unreduced, name=name)
+    elif reduction == 'mean':
+        unreduced = paddle.elementwise_sub(x, label, act='abs')
+        return paddle.mean(unreduced, name=name)
+    else:
+        return paddle.elementwise_sub(x, label, act='abs', name=name)
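For the graph-mode branch (everything past the `in_dygraph_mode()` early return), here is a small end-to-end sketch adapted from `run_static` in the tests above. The variable names, shapes, and feed values are illustrative, and the sketch assumes the 2.0-alpha API surface this commit targets (`paddle.data`, `fluid.Executor`):

import numpy as np
import paddle
import paddle.fluid as fluid

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = paddle.data(name='x', shape=[2, 2], dtype='float32')
    label = paddle.data(name='label', shape=[2, 2], dtype='float32')
    loss = paddle.nn.functional.l1_loss(x, label, reduction='sum')

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_prog)
out = exe.run(main_prog,
              feed={'x': np.ones((2, 2), dtype=np.float32),
                    'label': np.zeros((2, 2), dtype=np.float32)},
              fetch_list=[loss])
print(out[0])  # [4.], the sum of |1 - 0| over four elements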
