Skip to content

Commit

Permalink
[PIR] D-18 Adapt randint test_errors (#62807)
Browse files Browse the repository at this point in the history
  • Loading branch information
ooooo-create authored Mar 19, 2024
1 parent e14e0cc commit 6e6b0cb
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 22 deletions.
16 changes: 11 additions & 5 deletions python/paddle/base/data_feeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,18 +229,22 @@ def check_dtype(
def check_shape(
shape,
op_name,
expected_shape_type=(list, tuple, Variable),
expected_element_type=(int, Variable),
expected_shape_type=(list, tuple, Variable, Value),
expected_element_type=(int, Variable, Value),
expected_tensor_dtype=('int32', 'int64'),
):
# See NOTE [ Why skip dynamic graph check ]
if in_dygraph_mode():
return
check_type(shape, 'shape', expected_shape_type, op_name)
if expected_element_type is not None and not isinstance(shape, Variable):
if expected_element_type is not None and not isinstance(
shape, (Variable, Value)
):
for item in shape:
check_type(item, 'element of shape', expected_element_type, op_name)
if expected_tensor_dtype is not None and isinstance(item, Variable):
if expected_tensor_dtype is not None and isinstance(
item, (Variable, Value)
):
check_dtype(
item.dtype,
'element of shape',
Expand All @@ -250,7 +254,9 @@ def check_shape(
', '.join(expected_tensor_dtype)
),
)
if expected_tensor_dtype is not None and isinstance(shape, Variable):
if expected_tensor_dtype is not None and isinstance(
shape, (Variable, Value)
):
check_dtype(shape.dtype, 'shape', expected_tensor_dtype, op_name)


Expand Down
2 changes: 1 addition & 1 deletion python/paddle/tensor/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
low, high, shape, dtype, _current_expected_place()
)
elif in_pir_mode():
check_type(shape, 'shape', (list, tuple, paddle.pir.Value), 'randint')
check_shape(shape, 'randint')
check_dtype(dtype, 'dtype', ['int32', 'int64'], 'randint')
if paddle.utils._contain_var(shape):
shape = paddle.utils.get_int_tensor_list(
Expand Down
31 changes: 15 additions & 16 deletions test/legacy_test/test_randint_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
from op_test import OpTest

import paddle
from paddle import base
from paddle.base import core
from paddle.static import Program, program_guard
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()

Expand Down Expand Up @@ -54,8 +53,11 @@ def verify_output(self, outs):


class TestRandintOpError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
with program_guard(Program(), Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
self.assertRaises(TypeError, paddle.randint, 5, shape=np.array([2]))
self.assertRaises(TypeError, paddle.randint, 5, dtype='float32')
self.assertRaises(ValueError, paddle.randint, 5, 5)
Expand All @@ -67,14 +69,6 @@ def test_errors(self):
TypeError, paddle.randint, 5, shape=[shape_tensor]
)

def test_pir_error(self):
with paddle.pir_utils.IrGuard():
self.assertRaises(TypeError, paddle.randint, 5, shape=np.array([2]))
self.assertRaises(TypeError, paddle.randint, 5, dtype='float32')
self.assertRaises(ValueError, paddle.randint, 5, 5)
self.assertRaises(ValueError, paddle.randint, -5)
self.assertRaises(TypeError, paddle.randint, 5, shape=['2'])


class TestRandintOp_attr_tensorlist(OpTest):
def setUp(self):
Expand Down Expand Up @@ -125,7 +119,9 @@ def verify_output(self, outs):
# Test python API
class TestRandintAPI(unittest.TestCase):
def test_api(self):
with program_guard(Program(), Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
# results are from [0, 5).
out1 = paddle.randint(5)
# shape is a list and dtype is 'int32'
Expand Down Expand Up @@ -229,17 +225,20 @@ def test_dygraph(self):
self.assertEqual(x.shape, [])
paddle.enable_static()

@test_with_pir_api
def test_static(self):
with base.program_guard(base.Program(), base.Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.randint(-10, 10, [])

# Test compile shape
self.assertEqual(x.shape, ())
self.assertEqual(tuple(x.shape), ())

# Test runtime shape
exe = base.Executor()
exe = paddle.static.Executor()
result = exe.run(fetch_list=[x])
self.assertEqual(result[0].shape, ())
self.assertEqual(tuple(result[0].shape), ())

paddle.enable_static()

Expand Down

0 comments on commit 6e6b0cb

Please sign in to comment.