Skip to content

Commit 188960f

Browse files
committed
fix elementwise_add_ and elementwise_sub_ broadcast problem
1 parent ee118ec commit 188960f

File tree

4 files changed

+62
-6
lines changed

4 files changed

+62
-6
lines changed

python/paddle/fluid/tests/unittests/test_elementwise_add_op.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,32 @@ def _executed_api(self):
449449
self.add = paddle.add_
450450

451451

452+
class TestAddInplaceBroadcast(unittest.TestCase):
453+
def test_broadcast_success(self):
454+
paddle.disable_static()
455+
x_numpy = np.random.rand(2, 3, 4).astype('float')
456+
y_numpy = np.random.rand(3, 4).astype('float')
457+
458+
x = paddle.to_tensor(x_numpy)
459+
y = paddle.to_tensor(y_numpy)
460+
461+
inplace_result = paddle.add_(x, y)
462+
numpy_result = x_numpy + y_numpy
463+
self.assertEqual((inplace_result.numpy() == numpy_result).all(), True)
464+
paddle.enable_static()
465+
466+
def test_broadcast_errors(self):
467+
paddle.disable_static()
468+
x = paddle.rand([3, 4])
469+
y = paddle.rand([2, 3, 4])
470+
471+
def broadcast_shape_error():
472+
paddle.add_(x, y)
473+
474+
self.assertRaises(ValueError, broadcast_shape_error)
475+
paddle.enable_static()
476+
477+
452478
class TestComplexElementwiseAddOp(OpTest):
453479
def setUp(self):
454480
self.op_type = "elementwise_add"

python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,32 @@ def _executed_api(self):
289289
self.subtract = paddle.subtract_
290290

291291

292+
class TestSubtractInplaceBroadcast(unittest.TestCase):
293+
def test_broadcast_success(self):
294+
paddle.disable_static()
295+
x_numpy = np.random.rand(2, 3, 4).astype('float')
296+
y_numpy = np.random.rand(3, 4).astype('float')
297+
298+
x = paddle.to_tensor(x_numpy)
299+
y = paddle.to_tensor(y_numpy)
300+
301+
inplace_result = paddle.subtract_(x, y)
302+
numpy_result = x_numpy - y_numpy
303+
self.assertEqual((inplace_result.numpy() == numpy_result).all(), True)
304+
paddle.enable_static()
305+
306+
def test_broadcast_errors(self):
307+
paddle.disable_static()
308+
x = paddle.rand([3, 4])
309+
y = paddle.rand([2, 3, 4])
310+
311+
def broadcast_shape_error():
312+
paddle.subtract_(x, y)
313+
314+
self.assertRaises(ValueError, broadcast_shape_error)
315+
paddle.enable_static()
316+
317+
292318
if __name__ == '__main__':
293319
paddle.enable_static()
294320
unittest.main()

python/paddle/tensor/math.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,8 +321,12 @@ def add_(x, y, name=None):
321321
if in_dygraph_mode():
322322
op_type = 'elementwise_add_'
323323
axis = -1
324-
return _elementwise_op_in_dygraph(
324+
inplace_shape = x.shape
325+
out = _elementwise_op_in_dygraph(
325326
x, y, axis=axis, op_name=op_type)
327+
if inplace_shape != out.shape:
328+
raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out.shape, inplace_shape))
329+
return out
326330
_print_warning_in_static_mode("elementwise_add")
327331
return _elementwise_op(LayerHelper('elementwise_add', **locals()))
328332

@@ -396,8 +400,12 @@ def subtract_(x, y, name=None):
396400
if in_dygraph_mode():
397401
axis = -1
398402
act = None
399-
return _elementwise_op_in_dygraph(
403+
inplace_shape = x.shape
404+
out = _elementwise_op_in_dygraph(
400405
x, y, axis=axis, act=act, op_name='elementwise_sub_')
406+
if inplace_shape != out.shape:
407+
raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out.shape, inplace_shape))
408+
return out
401409
_print_warning_in_static_mode("elementwise_sub")
402410
return _elementwise_op(LayerHelper('elementwise_sub', **locals()))
403411

tools/wlist.json

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,6 @@
7878
"name":"clip_",
7979
"annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy"
8080
},
81-
{
82-
"name":"log_",
83-
"annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy"
84-
},
8581
{
8682
"name":"scale_",
8783
"annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy"

0 commit comments

Comments
 (0)