Skip to content

Commit bd98227

Browse files
authored
Add test_dropout_nd (#1673)
* add test_dropout_nd * PR comment
1 parent ed7ce0b commit bd98227

File tree

2 files changed

+131
-0
lines changed

2 files changed

+131
-0
lines changed

api/tests/dropout_nd.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from common_import import *
16+
import paddle
17+
from paddle import _legacy_C_ops
18+
19+
20+
# no dropout_nd in pytorch
def dropout_nd(x,
               p=0.5,
               axis=None,
               training=True,
               mode="upscale_in_train",
               name=None):
    """Apply paddle's legacy ``dropout_nd`` op along the given axes.

    Args:
        x: Input variable/tensor.
        p: Dropout probability.
        axis: An int or a sequence of ints selecting the axes to drop.
            NOTE(review): the default ``None`` would fail in ``list(axis)``
            below — callers are expected to always supply an axis.
        training: If False, the op runs in test mode (``is_test=True``).
        mode: Either ``"upscale_in_train"`` or ``"downscale_in_infer"``.
        name: Unused; kept for API symmetry.

    Returns:
        The result of the legacy ``dropout_nd`` op.
    """
    if isinstance(axis, int):
        axes = [axis]
    else:
        axes = list(axis)
    fixed_seed = None
    # The legacy op spells this mode 'downgrade_in_infer'.
    if mode == 'downscale_in_infer':
        mode = 'downgrade_in_infer'
    return _legacy_C_ops.dropout_nd(
        x,
        'dropout_prob', p,
        'is_test', not training,
        'fix_seed', fixed_seed is not None,
        'seed', 0 if fixed_seed is None else fixed_seed,
        'dropout_implementation', mode,
        'axis', axes)
46+
47+
48+
@benchmark_registry.register("dropout_nd")
class PaddleDropoutNdConfig(APIConfig):
    """Benchmark configuration for the ``dropout_nd`` op."""

    def __init__(self):
        super().__init__('dropout_nd')
        # There is no dropout_nd in pytorch, so skip the torch run.
        self.run_torch = False
53+
54+
55+
@benchmark_registry.register("dropout_nd")
class PaddleDropoutNd(PaddleOpBenchmarkBase):
    """Benchmark wrapper that builds a ``dropout_nd`` graph from a config."""

    def build_graph(self, config):
        inp = self.variable(
            name='x', shape=config.x_shape, dtype=config.x_dtype)
        out = dropout_nd(
            x=inp, p=config.p, axis=config.axis, mode=config.mode)

        self.feed_list = [inp]
        self.fetch_list = [out]
        if config.backward:
            # NOTE(review): out is indexed here but stored whole in
            # fetch_list above — presumably the legacy op returns a
            # (result, mask) pair; confirm against the base class.
            self.append_gradients(out[0], self.feed_list)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
[{
2+
"op": "dropout_nd",
3+
"param_info": {
4+
"mode": {
5+
"type": "string",
6+
"value": "upscale_in_train"
7+
},
8+
"p": {
9+
"type": "float",
10+
"value": "0.5"
11+
},
12+
"x": {
13+
"dtype": "float32",
14+
"shape": "[-1L, 16L, -1L]",
15+
"type": "Variable"
16+
},
17+
"axis": {
18+
"type": "list",
19+
"value": "[1]"
20+
}
21+
},
22+
"repeat": 2000
23+
}, {
24+
"op": "dropout_nd",
25+
"param_info": {
26+
"mode": {
27+
"type": "string",
28+
"value": "downscale_in_infer"
29+
},
30+
"p": {
31+
"type": "float",
32+
"value": "0.5"
33+
},
34+
"x": {
35+
"dtype": "float32",
36+
"shape": "[-1L, 16L, -1L, -1L]",
37+
"type": "Variable"
38+
},
39+
"axis": {
40+
"type": "list",
41+
"value": "[0,1]"
42+
}
43+
},
44+
"repeat": 2000
45+
}, {
46+
"op": "dropout_nd",
47+
"param_info": {
48+
"mode": {
49+
"type": "string",
50+
"value": "upscale_in_train"
51+
},
52+
"p": {
53+
"type": "float",
54+
"value": "0.1"
55+
},
56+
"x": {
57+
"dtype": "float32",
58+
"shape": "[32L, 128L, 768L]",
59+
"type": "Variable"
60+
},
61+
"axis": {
62+
"type": "list",
63+
"value": "[0]"
64+
}
65+
}
66+
}]

0 commit comments

Comments
 (0)