Skip to content

Commit 0698f48

Browse files
committed
elementwise inplace api give error message before run the op
1 parent 188960f commit 0698f48

File tree

3 files changed

+118
-22
lines changed

3 files changed

+118
-22
lines changed

python/paddle/fluid/tests/unittests/test_elementwise_add_op.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -449,24 +449,47 @@ def _executed_api(self):
449449
self.add = paddle.add_
450450

451451

452-
class TestAddInplaceBroadcast(unittest.TestCase):
452+
class TestAddInplaceBroadcastSuccess(unittest.TestCase):
453+
def init_data(self):
454+
self.x_numpy = np.random.rand(2, 3, 4).astype('float')
455+
self.y_numpy = np.random.rand(3, 4).astype('float')
456+
453457
def test_broadcast_success(self):
454458
paddle.disable_static()
455-
x_numpy = np.random.rand(2, 3, 4).astype('float')
456-
y_numpy = np.random.rand(3, 4).astype('float')
459+
self.init_data()
457460

458-
x = paddle.to_tensor(x_numpy)
459-
y = paddle.to_tensor(y_numpy)
461+
x = paddle.to_tensor(self.x_numpy)
462+
y = paddle.to_tensor(self.y_numpy)
460463

461464
inplace_result = paddle.add_(x, y)
462-
numpy_result = x_numpy + y_numpy
465+
numpy_result = self.x_numpy + self.y_numpy
463466
self.assertEqual((inplace_result.numpy() == numpy_result).all(), True)
464467
paddle.enable_static()
465468

469+
470+
class TestAddInplaceBroadcastSuccess2(TestAddInplaceBroadcastSuccess):
471+
def init_data(self):
472+
self.x_numpy = np.random.rand(1, 2, 3, 1).astype('float')
473+
self.y_numpy = np.random.rand(3, 1).astype('float')
474+
475+
476+
class TestAddInplaceBroadcastSuccess3(TestAddInplaceBroadcastSuccess):
477+
def init_data(self):
478+
self.x_numpy = np.random.rand(2, 3, 1, 5).astype('float')
479+
self.y_numpy = np.random.rand(1, 3, 1, 5).astype('float')
480+
481+
482+
class TestAddInplaceBroadcastError(unittest.TestCase):
483+
def init_data(self):
484+
self.x_numpy = np.random.rand(3, 4).astype('float')
485+
self.y_numpy = np.random.rand(2, 3, 4).astype('float')
486+
466487
def test_broadcast_errors(self):
467488
paddle.disable_static()
468-
x = paddle.rand([3, 4])
469-
y = paddle.rand([2, 3, 4])
489+
self.init_data()
490+
491+
x = paddle.to_tensor(self.x_numpy)
492+
y = paddle.to_tensor(self.y_numpy)
470493

471494
def broadcast_shape_error():
472495
paddle.add_(x, y)
@@ -475,6 +498,18 @@ def broadcast_shape_error():
475498
paddle.enable_static()
476499

477500

501+
class TestAddInplaceBroadcastError2(TestAddInplaceBroadcastError):
502+
def init_data(self):
503+
self.x_numpy = np.random.rand(2, 1, 4).astype('float')
504+
self.y_numpy = np.random.rand(2, 3, 4).astype('float')
505+
506+
507+
class TestAddInplaceBroadcastError3(TestAddInplaceBroadcastError):
508+
def init_data(self):
509+
self.x_numpy = np.random.rand(5, 2, 1, 4).astype('float')
510+
self.y_numpy = np.random.rand(2, 3, 4).astype('float')
511+
512+
478513
class TestComplexElementwiseAddOp(OpTest):
479514
def setUp(self):
480515
self.op_type = "elementwise_add"

python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -289,24 +289,47 @@ def _executed_api(self):
289289
self.subtract = paddle.subtract_
290290

291291

292-
class TestSubtractInplaceBroadcast(unittest.TestCase):
292+
class TestSubtractInplaceBroadcastSuccess(unittest.TestCase):
293+
def init_data(self):
294+
self.x_numpy = np.random.rand(2, 3, 4).astype('float')
295+
self.y_numpy = np.random.rand(3, 4).astype('float')
296+
293297
def test_broadcast_success(self):
294298
paddle.disable_static()
295-
x_numpy = np.random.rand(2, 3, 4).astype('float')
296-
y_numpy = np.random.rand(3, 4).astype('float')
299+
self.init_data()
297300

298-
x = paddle.to_tensor(x_numpy)
299-
y = paddle.to_tensor(y_numpy)
301+
x = paddle.to_tensor(self.x_numpy)
302+
y = paddle.to_tensor(self.y_numpy)
300303

301304
inplace_result = paddle.subtract_(x, y)
302-
numpy_result = x_numpy - y_numpy
305+
numpy_result = self.x_numpy - self.y_numpy
303306
self.assertEqual((inplace_result.numpy() == numpy_result).all(), True)
304307
paddle.enable_static()
305308

309+
310+
class TestSubtractInplaceBroadcastSuccess2(TestSubtractInplaceBroadcastSuccess):
311+
def init_data(self):
312+
self.x_numpy = np.random.rand(1, 2, 3, 1).astype('float')
313+
self.y_numpy = np.random.rand(3, 1).astype('float')
314+
315+
316+
class TestSubtractInplaceBroadcastSuccess3(TestSubtractInplaceBroadcastSuccess):
317+
def init_data(self):
318+
self.x_numpy = np.random.rand(2, 3, 1, 5).astype('float')
319+
self.y_numpy = np.random.rand(1, 3, 1, 5).astype('float')
320+
321+
322+
class TestSubtractInplaceBroadcastError(unittest.TestCase):
323+
def init_data(self):
324+
self.x_numpy = np.random.rand(3, 4).astype('float')
325+
self.y_numpy = np.random.rand(2, 3, 4).astype('float')
326+
306327
def test_broadcast_errors(self):
307328
paddle.disable_static()
308-
x = paddle.rand([3, 4])
309-
y = paddle.rand([2, 3, 4])
329+
self.init_data()
330+
331+
x = paddle.to_tensor(self.x_numpy)
332+
y = paddle.to_tensor(self.y_numpy)
310333

311334
def broadcast_shape_error():
312335
paddle.subtract_(x, y)
@@ -315,6 +338,18 @@ def broadcast_shape_error():
315338
paddle.enable_static()
316339

317340

341+
class TestSubtractInplaceBroadcastError2(TestSubtractInplaceBroadcastError):
342+
def init_data(self):
343+
self.x_numpy = np.random.rand(2, 1, 4).astype('float')
344+
self.y_numpy = np.random.rand(2, 3, 4).astype('float')
345+
346+
347+
class TestSubtractInplaceBroadcastError3(TestSubtractInplaceBroadcastError):
348+
def init_data(self):
349+
self.x_numpy = np.random.rand(5, 2, 1, 4).astype('float')
350+
self.y_numpy = np.random.rand(2, 3, 4).astype('float')
351+
352+
318353
if __name__ == '__main__':
319354
paddle.enable_static()
320355
unittest.main()

python/paddle/tensor/math.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -321,11 +321,24 @@ def add_(x, y, name=None):
321321
if in_dygraph_mode():
322322
op_type = 'elementwise_add_'
323323
axis = -1
324-
inplace_shape = x.shape
324+
x_shape = x.shape
325+
y_shape = y.shape
326+
len_diff = len(x_shape) - len(y_shape)
327+
shape_valid = True
328+
if len_diff < 0:
329+
shape_valid = False
330+
for i in range(len(y_shape)):
331+
if x_shape[i+len_diff] == 1 and y_shape[i] != 1:
332+
shape_valid = False
333+
break
334+
335+
if not shape_valid:
336+
raise ValueError("The shape of inplace tensor {} should not be changed in the Inplace operation.".format(x.shape))
337+
325338
out = _elementwise_op_in_dygraph(
326339
x, y, axis=axis, op_name=op_type)
327-
if inplace_shape != out.shape:
328-
raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out.shape, inplace_shape))
340+
# if inplace_shape != out.shape:
341+
# raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out.shape, inplace_shape))
329342
return out
330343
_print_warning_in_static_mode("elementwise_add")
331344
return _elementwise_op(LayerHelper('elementwise_add', **locals()))
@@ -400,11 +413,24 @@ def subtract_(x, y, name=None):
400413
if in_dygraph_mode():
401414
axis = -1
402415
act = None
403-
inplace_shape = x.shape
416+
x_shape = x.shape
417+
y_shape = y.shape
418+
len_diff = len(x_shape) - len(y_shape)
419+
shape_valid = True
420+
if len_diff < 0:
421+
shape_valid = False
422+
for i in range(len(y_shape)):
423+
if x_shape[i+len_diff] == 1 and y_shape[i] != 1:
424+
shape_valid = False
425+
break
426+
427+
if not shape_valid:
428+
raise ValueError("The shape of inplace tensor {} should not be changed in the Inplace operation.".format(x.shape))
429+
404430
out = _elementwise_op_in_dygraph(
405431
x, y, axis=axis, act=act, op_name='elementwise_sub_')
406-
if inplace_shape != out.shape:
407-
raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out.shape, inplace_shape))
432+
# if inplace_shape != out.shape:
433+
# raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out.shape, inplace_shape))
408434
return out
409435
_print_warning_in_static_mode("elementwise_sub")
410436
return _elementwise_op(LayerHelper('elementwise_sub', **locals()))

0 commit comments

Comments
 (0)