Skip to content

Commit 852a2de

Browse files
Add files via upload
1 parent b49501e commit 852a2de

File tree

1 file changed

+223
-0
lines changed

1 file changed

+223
-0
lines changed
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import copy
18+
import unittest
19+
import numpy as np
20+
from op_test import OpTest, skip_check_grad_ci
21+
22+
23+
def rrelu_np(x: np.ndarray,
             lower_bound: float=0.125,
             upper_bound: float=0.3333,
             is_test: bool=False):
    """Reference NumPy implementation of the RReLU activation.

    Positive elements pass through unchanged (mask == 1). Negative
    elements are scaled by a slope drawn uniformly from
    [lower_bound, upper_bound] in training mode, or by the fixed
    midpoint (lower_bound + upper_bound) / 2 in test/inference mode.

    Args:
        x: input array of any shape; internally cast to float32.
        lower_bound: lower end of the random negative-slope range.
        upper_bound: upper end of the random negative-slope range.
        is_test: if True, use the deterministic midpoint slope.

    Returns:
        Tuple (out, mask) where out = x * mask; both have x's shape.
    """
    x = x.astype(np.float32)
    if is_test:
        # Inference uses the deterministic midpoint of the slope range.
        middle_value = (lower_bound + upper_bound) / 2.0
        mask = np.where(x >= 0.0, 1.0, middle_value).astype(np.float32)
    else:
        mask = np.ones_like(x)
        neg = x < 0.0
        # One slope per negative element, drawn in flattened (C-order)
        # sequence. With the legacy np.random RandomState, a single
        # vectorized uniform(size=n) consumes the same stream as n
        # scalar draws, so this reproduces the element-wise loop's
        # results under a fixed np.random seed.
        mask[neg] = np.random.uniform(lower_bound, upper_bound,
                                      size=int(neg.sum()))
    out = x * mask
    return out, mask
50+
51+
52+
class TestRReLUOp(OpTest):
    """Forward/backward check of the rrelu op on a 1-D input in train mode."""

    def setUp(self):
        self.op_type = "rrelu"
        seed = 100
        x = np.random.uniform(low=-100, high=10, size=(32, )).astype("float32")
        self.inputs = {'X': x}
        self.attrs = {
            'lower_bound': 0.0,
            'upper_bound': 0.5,
            'fix_seed': True,
            'seed': seed,
            'is_test': False,
        }
        # Seed numpy so the reference mask matches the fixed-seed op output.
        np.random.seed(seed)
        out, mask = rrelu_np(
            x=x, lower_bound=0.0, upper_bound=0.5, is_test=False)
        self.outputs = {'Out': out, 'Mask': mask}

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X'], 'Out')
82+
83+
84+
class TestRReLUOp2(TestRReLUOp):
    """Same checks as TestRReLUOp, on a 2-D input with a higher slope range."""

    def setUp(self):
        self.op_type = "rrelu"
        seed = 3
        x = np.random.uniform(low=-100, high=10, size=(8, 16)).astype("float32")
        self.inputs = {'X': x}
        self.attrs = {
            'lower_bound': 0.4,
            'upper_bound': 0.99,
            'fix_seed': True,
            'seed': seed,
            'is_test': False,
        }
        # Seed numpy so the reference result matches the fixed-seed op.
        np.random.seed(seed)
        out, mask = rrelu_np(
            x=x, lower_bound=0.4, upper_bound=0.99, is_test=False)
        self.outputs = {'Out': out, 'Mask': mask}
108+
109+
110+
class TestRReLUOp3(TestRReLUOp):
    """Same checks as TestRReLUOp, on a 3-D input with a narrow slope range."""

    def setUp(self):
        self.op_type = "rrelu"
        seed = 5
        x = np.random.uniform(
            low=-100, high=10, size=(8, 16, 32)).astype("float32")
        self.inputs = {'X': x}
        self.attrs = {
            'lower_bound': 0.5,
            'upper_bound': 0.51,
            'fix_seed': True,
            'seed': seed,
            'is_test': False,
        }
        # Seed numpy so the reference result matches the fixed-seed op.
        np.random.seed(seed)
        out, mask = rrelu_np(
            x=x, lower_bound=0.5, upper_bound=0.51, is_test=False)
        self.outputs = {'Out': out, 'Mask': mask}
135+
136+
137+
@skip_check_grad_ci(reason="For inference, check_grad is not required.")
class TestRReLUOp4(OpTest):
    """Inference-mode (is_test=True) output check on a 1-D input.

    The mask is the deterministic midpoint slope, so only 'Out' is
    compared and no numpy seeding is required.
    """

    def setUp(self):
        self.op_type = "rrelu"
        x = np.random.uniform(low=-100, high=10, size=(32, )).astype("float32")
        self.inputs = {'X': x}
        self.attrs = {
            'lower_bound': 0.0,
            'upper_bound': 0.3,
            'fix_seed': True,
            'seed': 11,
            'is_test': True,
        }
        out, _ = rrelu_np(
            x=x, lower_bound=0.0, upper_bound=0.3, is_test=True)
        self.outputs = {'Out': out}

    def test_check_output(self):
        self.check_output()
164+
165+
166+
@skip_check_grad_ci(reason="For inference, check_grad is not required.")
class TestRReLUOp5(OpTest):
    """Inference-mode output check on a 3-D input, without seed attrs."""

    def setUp(self):
        self.op_type = "rrelu"
        x = np.random.uniform(
            low=-100, high=10, size=(32, 16, 8)).astype("float32")
        self.inputs = {'X': x}
        self.attrs = {
            'lower_bound': 0.0,
            'upper_bound': 0.3,
            'is_test': True,
        }
        out, _ = rrelu_np(
            x=x, lower_bound=0.0, upper_bound=0.3, is_test=True)
        self.outputs = {'Out': out}

    def test_check_output(self):
        self.check_output()
190+
191+
192+
class TestRReLUOpWithSeed(OpTest):
    """Train-mode check where the seed arrives via the 'Seed' input
    tensor rather than the 'seed' attribute."""

    def setUp(self):
        self.op_type = "rrelu"
        x = np.random.uniform(
            low=-100, high=10, size=(32, 16)).astype("float32")
        self.inputs = {'X': x, 'Seed': np.asarray([125], dtype="int32")}
        self.attrs = {
            'lower_bound': 0.0,
            'upper_bound': 0.3,
            'is_test': False,
        }
        # Seed numpy with the same value fed through the 'Seed' input.
        np.random.seed(125)
        out, mask = rrelu_np(
            x=x, lower_bound=0.0, upper_bound=0.3, is_test=False)
        self.outputs = {'Out': out, 'Mask': mask}

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X'], 'Out', max_relative_error=0.05)
220+
221+
222+
# Run the op tests when this file is executed directly.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)