You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: python/paddle/tensor/math.py
+8-30Lines changed: 8 additions & 30 deletions
Original file line number
Diff line number
Diff line change
@@ -321,24 +321,13 @@ def add_(x, y, name=None):
321
321
ifin_dygraph_mode():
322
322
op_type='elementwise_add_'
323
323
axis=-1
324
-
x_shape=x.shape
325
-
y_shape=y.shape
326
-
len_diff=len(x_shape) -len(y_shape)
327
-
shape_valid=True
328
-
iflen_diff<0:
329
-
shape_valid=False
330
-
foriinrange(len(y_shape)):
331
-
ifx_shape[i+len_diff] ==1andy_shape[i] !=1:
332
-
shape_valid=False
333
-
break
334
-
335
-
ifnotshape_valid:
336
-
raiseValueError("The shape of inplace tensor {} should not be changed in the Inplace operation.".format(x.shape))
324
+
325
+
out_shape=broadcast_shape(x.shape, y.shape)
326
+
ifout_shape!=x.shape:
327
+
raiseValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out_shape, x.shape))
337
328
338
329
out=_elementwise_op_in_dygraph(
339
330
x, y, axis=axis, op_name=op_type)
340
-
# if inplace_shape != out.shape:
341
-
# raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out.shape, inplace_shape))
@@ -413,24 +402,13 @@ def subtract_(x, y, name=None):
413
402
ifin_dygraph_mode():
414
403
axis=-1
415
404
act=None
416
-
x_shape=x.shape
417
-
y_shape=y.shape
418
-
len_diff=len(x_shape) -len(y_shape)
419
-
shape_valid=True
420
-
iflen_diff<0:
421
-
shape_valid=False
422
-
foriinrange(len(y_shape)):
423
-
ifx_shape[i+len_diff] ==1andy_shape[i] !=1:
424
-
shape_valid=False
425
-
break
426
-
427
-
ifnotshape_valid:
428
-
raiseValueError("The shape of inplace tensor {} should not be changed in the Inplace operation.".format(x.shape))
405
+
406
+
out_shape=broadcast_shape(x.shape, y.shape)
407
+
ifout_shape!=x.shape:
408
+
raiseValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out_shape, x.shape))
429
409
430
410
out=_elementwise_op_in_dygraph(
431
411
x, y, axis=axis, act=act, op_name='elementwise_sub_')
432
-
# if inplace_shape != out.shape:
433
-
# raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out.shape, inplace_shape))
0 commit comments