Skip to content

Commit a2cec25

Browse files
committed
use broadcast_shape in elementwise inplace op
1 parent 0698f48 commit a2cec25

File tree

1 file changed

+8
-30
lines changed

1 file changed

+8
-30
lines changed

python/paddle/tensor/math.py

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -321,24 +321,13 @@ def add_(x, y, name=None):
321321
if in_dygraph_mode():
322322
op_type = 'elementwise_add_'
323323
axis = -1
324-
x_shape = x.shape
325-
y_shape = y.shape
326-
len_diff = len(x_shape) - len(y_shape)
327-
shape_valid = True
328-
if len_diff < 0:
329-
shape_valid = False
330-
for i in range(len(y_shape)):
331-
if x_shape[i+len_diff] == 1 and y_shape[i] != 1:
332-
shape_valid = False
333-
break
334-
335-
if not shape_valid:
336-
raise ValueError("The shape of inplace tensor {} should not be changed in the Inplace operation.".format(x.shape))
324+
325+
out_shape = broadcast_shape(x.shape, y.shape)
326+
if out_shape != x.shape:
327+
raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out_shape, x.shape))
337328

338329
out = _elementwise_op_in_dygraph(
339330
x, y, axis=axis, op_name=op_type)
340-
# if inplace_shape != out.shape:
341-
# raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out.shape, inplace_shape))
342331
return out
343332
_print_warning_in_static_mode("elementwise_add")
344333
return _elementwise_op(LayerHelper('elementwise_add', **locals()))
@@ -413,24 +402,13 @@ def subtract_(x, y, name=None):
413402
if in_dygraph_mode():
414403
axis = -1
415404
act = None
416-
x_shape = x.shape
417-
y_shape = y.shape
418-
len_diff = len(x_shape) - len(y_shape)
419-
shape_valid = True
420-
if len_diff < 0:
421-
shape_valid = False
422-
for i in range(len(y_shape)):
423-
if x_shape[i+len_diff] == 1 and y_shape[i] != 1:
424-
shape_valid = False
425-
break
426-
427-
if not shape_valid:
428-
raise ValueError("The shape of inplace tensor {} should not be changed in the Inplace operation.".format(x.shape))
405+
406+
out_shape = broadcast_shape(x.shape, y.shape)
407+
if out_shape != x.shape:
408+
raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out_shape, x.shape))
429409

430410
out = _elementwise_op_in_dygraph(
431411
x, y, axis=axis, act=act, op_name='elementwise_sub_')
432-
# if inplace_shape != out.shape:
433-
# raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out.shape, inplace_shape))
434412
return out
435413
_print_warning_in_static_mode("elementwise_sub")
436414
return _elementwise_op(LayerHelper('elementwise_sub', **locals()))

0 commit comments

Comments
 (0)