|  | 
|  | 1 | +#   Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | 
|  | 2 | +# | 
|  | 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 4 | +# you may not use this file except in compliance with the License. | 
|  | 5 | +# You may obtain a copy of the License at | 
|  | 6 | +# | 
|  | 7 | +#     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 8 | +# | 
|  | 9 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 10 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 12 | +# See the License for the specific language governing permissions and | 
|  | 13 | +# limitations under the License. | 
|  | 14 | + | 
|  | 15 | +from common_import import * | 
|  | 16 | +import paddle | 
|  | 17 | +from paddle import _legacy_C_ops | 
|  | 18 | + | 
|  | 19 | + | 
|  | 20 | +# no dropout_nd in pytorch | 
|  | 21 | +def dropout_nd(x, | 
|  | 22 | +               p=0.5, | 
|  | 23 | +               axis=None, | 
|  | 24 | +               training=True, | 
|  | 25 | +               mode="upscale_in_train", | 
|  | 26 | +               name=None): | 
|  | 27 | +    drop_axes = [axis] if isinstance(axis, int) else list(axis) | 
|  | 28 | +    seed = None | 
|  | 29 | +    mode = ('downgrade_in_infer' | 
|  | 30 | +            if mode == 'downscale_in_infer' else mode)  # semantic transfer | 
|  | 31 | +    out = _legacy_C_ops.dropout_nd( | 
|  | 32 | +        x, | 
|  | 33 | +        'dropout_prob', | 
|  | 34 | +        p, | 
|  | 35 | +        'is_test', | 
|  | 36 | +        not training, | 
|  | 37 | +        'fix_seed', | 
|  | 38 | +        seed is not None, | 
|  | 39 | +        'seed', | 
|  | 40 | +        seed if seed is not None else 0, | 
|  | 41 | +        'dropout_implementation', | 
|  | 42 | +        mode, | 
|  | 43 | +        'axis', | 
|  | 44 | +        drop_axes, ) | 
|  | 45 | +    return out | 
|  | 46 | + | 
|  | 47 | + | 
|  | 48 | +@benchmark_registry.register("dropout_nd") | 
|  | 49 | +class PaddleDropoutNdConfig(APIConfig): | 
|  | 50 | +    def __init__(self): | 
|  | 51 | +        super(PaddleDropoutNdConfig, self).__init__('dropout_nd') | 
|  | 52 | +        self.run_torch = False | 
|  | 53 | + | 
|  | 54 | + | 
|  | 55 | +@benchmark_registry.register("dropout_nd") | 
|  | 56 | +class PaddleDropoutNd(PaddleOpBenchmarkBase): | 
|  | 57 | +    def build_graph(self, config): | 
|  | 58 | +        x = self.variable(name='x', shape=config.x_shape, dtype=config.x_dtype) | 
|  | 59 | +        result = dropout_nd( | 
|  | 60 | +            x=x, p=config.p, axis=config.axis, mode=config.mode) | 
|  | 61 | + | 
|  | 62 | +        self.feed_list = [x] | 
|  | 63 | +        self.fetch_list = [result] | 
|  | 64 | +        if config.backward: | 
|  | 65 | +            self.append_gradients(result[0], self.feed_list) | 
0 commit comments