Skip to content

Commit

Permalink
add ut for value zero
Browse files Browse the repository at this point in the history
  • Loading branch information
jinyouzhi committed Dec 21, 2023
1 parent b93b796 commit fa1d0bd
Showing 1 changed file with 24 additions and 7 deletions.
31 changes: 24 additions & 7 deletions test/legacy_test/test_activation_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4320,21 +4320,23 @@ def setUp(self):
self.op_type = "thresholded_relu"
self.init_dtype()
self.init_shape()
self.set_attrs()
self.python_api = paddle.nn.functional.thresholded_relu

threshold = 15
value = 5

np.random.seed(1024)
x = np.random.uniform(-20, 20, self.shape).astype(self.dtype)
x[np.abs(x) < 0.005] = 0.02
out = ref_thresholded_relu(x, threshold, value)
out = ref_thresholded_relu(x, self.threshold, self.value)

self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {"threshold": threshold, "value": value}
self.attrs = {"threshold": self.threshold, "value": self.value}
self.convert_input_output()

def set_attrs(self):
self.threshold = 15
self.value = 5

def init_shape(self):
self.shape = [10, 12]

Expand All @@ -4347,6 +4349,12 @@ def test_check_output(self):
self.check_output(check_pir=True)


class TestThresholdedRelu_ZeroValue(TestThresholdedRelu):
def set_attrs(self):
self.threshold = 15
self.value = 0


class TestThresholdedRelu_ZeroDim(TestThresholdedRelu):
def init_shape(self):
self.shape = []
Expand All @@ -4355,8 +4363,7 @@ def init_shape(self):
class TestThresholdedReluAPI(unittest.TestCase):
# test paddle.nn.ThresholdedReLU, paddle.nn.functional.thresholded_relu
def setUp(self):
self.threshold = 15
self.value = 5
self.set_attrs()
np.random.seed(1024)
self.x_np = np.random.uniform(-20, 20, [10, 12]).astype(np.float64)
self.x_np[np.abs(self.x_np) < 0.005] = 0.02
Expand All @@ -4366,6 +4373,10 @@ def setUp(self):
else paddle.CPUPlace()
)

def set_attrs(self):
self.threshold = 15
self.value = 5

@test_with_pir_api
def test_static_api(self):
with static_guard():
Expand Down Expand Up @@ -4415,6 +4426,12 @@ def test_errors(self):
F.thresholded_relu(x_fp16)


class TestThresholdedReluAPI_ZeroValue(TestThresholdedReluAPI):
def set_attrs(self):
self.threshold = 15
self.value = 0


def ref_hardsigmoid(x, slope=0.166666666666667, offset=0.5):
return np.maximum(np.minimum(x * slope + offset, 1.0), 0.0).astype(x.dtype)

Expand Down

0 comments on commit fa1d0bd

Please sign in to comment.