Commit f2c52fe
Add 12 inplace APIs, including auto-generated ones (#32573)
* add relu6_, hardsigmoid_, and leaky_relu_ Inplace APIs
* add softmax_with_cross_entropy_ Inplace API
* add clip_, scale_, add_, and subtract_ Inplace APIs
* add wlist
* fix parameter of scale API
* add add_n_ Inplace API and remove log_ Inplace API
* fix elementwise_add_ and elementwise_sub_ broadcast problem
* make elementwise inplace APIs give an error message before running the op
* use broadcast_shape in elementwise inplace ops
* add 8 auto-generated inplace APIs
* add unit tests for all inplace APIs
* add decorator for inplace APIs in static mode
* fix Windows BLAS failure of the exp inplace API; change array_equal to allclose
* add flatten inplace API
* add flatten unit test
* fix flatten unit test
* add decorator
* fix grad.numpy in test_pylayer_op
* drop support for softmax_with_cross_entropy_
* add test_inplace_softmax_with_cross_entropy to static_mode_white_list
* delete __all__ in inplace_utils
* delete activation inplace functions and add Tensor.inplace_func
* change paddle.inplace_ to Tensor.inplace_
* fix minor issues
* add paddle import in inplace_utils
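In dygraph mode, the underscore-suffixed APIs listed above mutate their input Tensor instead of allocating a new output. A minimal usage sketch (values are illustrative), using two of the APIs exercised by the tests in this commit:

    import paddle

    paddle.disable_static()
    x = paddle.to_tensor([-1.0, 0.5, 2.0])

    x.clip_(min=0.0, max=1.0)                  # x is now [0.0, 0.5, 1.0]
    x.add_(paddle.to_tensor([1.0, 1.0, 1.0]))  # x is now [1.0, 1.5, 2.0]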
1 parent 1a417a4 commit f2c52fe

File tree

18 files changed: +997 −135 lines

paddle/fluid/imperative/basic_engine.cc

Lines changed: 2 additions & 1 deletion

@@ -408,7 +408,8 @@ void BasicEngine::Execute() {
           VLOG(10) << "create temporary var of " << var->Name()
                    << " for sum gradient within this graph!";
         } else if (!inplace_grad_name_map.empty() &&
-                   inplace_grad_name_map.count(pair.first)) {
+                   inplace_grad_name_map.count(pair.first) &&
+                   bwd_ins.count(inplace_grad_name_map.at(pair.first))) {
           // When calculate Inplace grad op, create a new output var.
           // If a tmp var has been created, there is no need to create it
           // again.

paddle/fluid/operators/flatten_op.h

Lines changed: 1 addition & 36 deletions

@@ -120,51 +120,16 @@ template <typename DeviceContext, typename T>
 class FlattenContiguousRangeKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
-    auto &start_axis = context.Attr<int>("start_axis");
-    auto &stop_axis = context.Attr<int>("stop_axis");
-
     auto *in = context.Input<framework::LoDTensor>("X");
-    auto x_dims = in->dims();
-    int in_dims_size = x_dims.size();
-    int real_start_axis = start_axis, real_stop_axis = stop_axis;
-    if (start_axis < 0) {
-      real_start_axis = start_axis + in_dims_size;
-    }
-    if (stop_axis < 0) {
-      real_stop_axis = stop_axis + in_dims_size;
-    }
     auto *out = context.Output<framework::LoDTensor>("Out");
-
-    auto out_dims = framework::make_ddim(
-        GetOutputShape(real_start_axis, real_stop_axis, x_dims));
+    auto out_dims = out->dims();

     out->mutable_data(context.GetPlace(), in->type());
     framework::TensorCopy(
         *in, context.GetPlace(),
         context.template device_context<platform::DeviceContext>(), out);
     out->Resize(out_dims);
   }
-  static std::vector<int32_t> GetOutputShape(const int start_axis,
-                                             const int stop_axis,
-                                             const framework::DDim &in_dims) {
-    int64_t outer = 1;
-    std::vector<int32_t> out_shape;
-    int in_dims_size = in_dims.size();
-    out_shape.reserve(in_dims_size - stop_axis + start_axis);
-
-    for (int i = 0; i < start_axis; ++i) {
-      out_shape.push_back(in_dims[i]);
-    }
-    for (int i = start_axis; i <= stop_axis; i++) {
-      outer *= in_dims[i];
-    }
-    out_shape.push_back(outer);
-    for (int i = stop_axis + 1; i < in_dims_size; i++) {
-      out_shape.push_back(in_dims[i]);
-    }
-
-    return out_shape;
-  }
 };

 template <typename DeviceContext, typename T>

python/paddle/fluid/dygraph/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -58,6 +58,8 @@

 from .math_op_patch import monkey_patch_math_varbase

+from .inplace_utils import inplace_apis_in_dygraph_only
+
 __all__ = []
 __all__ += layers.__all__
 __all__ += base.__all__
python/paddle/fluid/dygraph/inplace_utils.py (new file)

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..wrapped_decorator import wrap_decorator
+from ..framework import in_dygraph_mode
+import warnings
+import paddle
+
+
+# NOTE(pangyoki): The Inplace APIs with underline(`_`) is only valid for the method of calling `core.ops`
+# in dygraph mode. If static mode is used, the inplace mechanism will not be used, and the static method
+# of the original API will be called.
+def _inplace_apis_in_dygraph_only_(func):
+    def __impl__(*args, **kwargs):
+        if not in_dygraph_mode():
+            origin_api_name = func.__name__[:-1]
+            warnings.warn(
+                "In static mode, {}() is the same as {}() and does not perform inplace operation.".
+                format(func.__name__, origin_api_name))
+            origin_func = "{}.{}".format(func.__module__, origin_api_name)
+            return eval(origin_func)(*args, **kwargs)
+        return func(*args, **kwargs)

+    return __impl__
+
+
+inplace_apis_in_dygraph_only = wrap_decorator(_inplace_apis_in_dygraph_only_)
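In static mode the decorator warns and re-dispatches to the non-inplace API by evaluating "{func.__module__}.{name-without-underscore}", so it only works for functions defined in a paddle module that eval can resolve (which is why inplace_utils imports paddle). A hedged sketch of how an inplace API might be wrapped; `tanh_` and `core.ops.tanh_` are hypothetical names used only for illustration, not part of this commit:

    import paddle
    from paddle.fluid import core
    from paddle.fluid.dygraph import inplace_apis_in_dygraph_only

    @inplace_apis_in_dygraph_only
    def tanh_(x, name=None):
        # Dygraph-only fast path (hypothetical op name). In static mode the
        # decorator never reaches this body: it warns and resolves
        # "<this module>.tanh" instead, so the function must live in a module
        # reachable from the imported `paddle` package.
        return core.ops.tanh_(x)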

python/paddle/fluid/layers/layer_function_generator.py

Lines changed: 31 additions & 1 deletion

@@ -25,7 +25,8 @@
 from ..data_feeder import check_variable_and_dtype

 __all__ = [
-    'generate_layer_fn', 'generate_activation_fn', 'autodoc', 'templatedoc'
+    'generate_layer_fn', 'generate_activation_fn', 'generate_inplace_fn',
+    'autodoc', 'templatedoc'
 ]


@@ -283,6 +284,35 @@ def func(x, name=None):
     return func


+def generate_inplace_fn(inplace_op_type):
+    """Register the Python layer for an Inplace Operator without Attribute.
+
+    Args:
+        inplace_op_type: The name of the inplace operator to be created.
+
+    This function takes in the inplace operator type (exp_, ceil_ etc) and
+    creates the operator functionality.
+    """
+    origin_op_type = inplace_op_type[:-1]
+
+    def func(x, name=None):
+        if in_dygraph_mode():
+            op = getattr(core.ops, inplace_op_type)
+            return op(x)
+        warnings.warn(
+            "In static mode, {}() is the same as {}() and does not perform inplace operation.".
+            format(inplace_op_type, origin_op_type))
+        return generate_activation_fn(origin_op_type)(x, name)
+
+    func.__name__ = inplace_op_type
+    func.__doc__ = """
+Inplace version of ``{0}`` API, the output Tensor will be inplaced with input ``x``.
+Please refer to :ref:`api_fluid_layers_{1}`.
+""".format(origin_op_type, origin_op_type)
+
+    return func
+
+
 def autodoc(comment=""):
     def __impl__(func):
         func.__doc__ = _generate_doc_string_(OpProtoHolder.instance(
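The generated wrapper dispatches to `core.ops.<name>_` in dygraph mode and falls back, with a warning, to the ordinary activation layer in static mode. A sketch of what the factory yields, using `exp_` as registered in ops.py below (values illustrative):

    import paddle
    from paddle.fluid.layers.layer_function_generator import generate_inplace_fn

    exp_ = generate_inplace_fn('exp_')

    paddle.disable_static()
    x = paddle.to_tensor([0.0, 1.0])
    y = exp_(x)   # dygraph path: core.ops.exp_ runs, so x itself now holds [1.0, e]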

python/paddle/fluid/layers/ops.py

Lines changed: 20 additions & 1 deletion

@@ -14,7 +14,7 @@

 from __future__ import print_function
 import os
-from .layer_function_generator import generate_layer_fn, generate_activation_fn, add_sample_code
+from .layer_function_generator import generate_layer_fn, generate_activation_fn, generate_inplace_fn, add_sample_code
 from .. import core
 from ..framework import convert_np_dtype_to_dtype_, Variable
 from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype

@@ -55,6 +55,16 @@
     'square',
 ]

+__inplace_unary_func__ = [
+    'exp_',
+    'sqrt_',
+    'rsqrt_',
+    'ceil_',
+    'floor_',
+    'round_',
+    'reciprocal_',
+]
+
 __all__ = []

 for _OP in set(__all__):

@@ -69,6 +79,7 @@

 __all__ += __activations_noattr__
 __all__ += __unary_func__
+__all__ += __inplace_unary_func__

 for _OP in set(__activations_noattr__):
     _new_OP = _OP

@@ -87,6 +98,14 @@
     func = deprecated(since="2.0.0", update_to="paddle.%s" % (_new_OP))(func)
     globals()[_OP] = func

+for _OP in set(__inplace_unary_func__):
+    _new_OP = _OP
+    if _OP in __deprecated_func_name__:
+        _new_OP = __deprecated_func_name__[_OP]
+    func = generate_inplace_fn(_OP)
+    func = deprecated(since="2.0.0", update_to="paddle.%s" % (_new_OP))(func)
+    globals()[_OP] = func
+
 add_sample_code(globals()["sigmoid"], r"""
 Examples:
     .. code-block:: python
python/paddle/fluid/tests/unittests/test_clip_op.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ def test_dtype():
124124

125125

126126
class TestClipAPI(unittest.TestCase):
127+
def _executed_api(self, x, min=None, max=None):
128+
return paddle.clip(x, min, max)
129+
127130
def test_clip(self):
128131
paddle.enable_static()
129132
data_shape = [1, 9, 9, 4]
@@ -136,18 +139,20 @@ def test_clip(self):
136139
) else fluid.CPUPlace()
137140
exe = fluid.Executor(place)
138141

139-
out_1 = paddle.clip(images, min=min, max=max)
140-
out_2 = paddle.clip(images, min=0.2, max=0.9)
141-
out_3 = paddle.clip(images, min=0.3)
142-
out_4 = paddle.clip(images, max=0.7)
143-
out_5 = paddle.clip(images, min=min)
144-
out_6 = paddle.clip(images, max=max)
145-
out_7 = paddle.clip(images, max=-1.)
146-
out_8 = paddle.clip(images)
147-
out_9 = paddle.clip(paddle.cast(images, 'float64'), min=0.2, max=0.9)
148-
149-
out_10 = paddle.clip(paddle.cast(images * 10, 'int32'), min=2, max=8)
150-
out_11 = paddle.clip(paddle.cast(images * 10, 'int64'), min=2, max=8)
142+
out_1 = self._executed_api(images, min=min, max=max)
143+
out_2 = self._executed_api(images, min=0.2, max=0.9)
144+
out_3 = self._executed_api(images, min=0.3)
145+
out_4 = self._executed_api(images, max=0.7)
146+
out_5 = self._executed_api(images, min=min)
147+
out_6 = self._executed_api(images, max=max)
148+
out_7 = self._executed_api(images, max=-1.)
149+
out_8 = self._executed_api(images)
150+
out_9 = self._executed_api(
151+
paddle.cast(images, 'float64'), min=0.2, max=0.9)
152+
out_10 = self._executed_api(
153+
paddle.cast(images * 10, 'int32'), min=2, max=8)
154+
out_11 = self._executed_api(
155+
paddle.cast(images * 10, 'int64'), min=2, max=8)
151156

152157
res1, res2, res3, res4, res5, res6, res7, res8, res9, res10, res11 = exe.run(
153158
fluid.default_main_program(),
@@ -188,12 +193,16 @@ def test_clip_dygraph(self):
188193
v_min = paddle.to_tensor(np.array([0.2], dtype=np.float32))
189194
v_max = paddle.to_tensor(np.array([0.8], dtype=np.float32))
190195

191-
out_1 = paddle.clip(images, min=0.2, max=0.8)
192-
out_2 = paddle.clip(images, min=0.2, max=0.9)
193-
out_3 = paddle.clip(images, min=v_min, max=v_max)
196+
out_1 = self._executed_api(images, min=0.2, max=0.8)
197+
images = paddle.to_tensor(data, dtype='float32')
198+
out_2 = self._executed_api(images, min=0.2, max=0.9)
199+
images = paddle.to_tensor(data, dtype='float32')
200+
out_3 = self._executed_api(images, min=v_min, max=v_max)
194201

195-
out_4 = paddle.clip(paddle.cast(images * 10, 'int32'), min=2, max=8)
196-
out_5 = paddle.clip(paddle.cast(images * 10, 'int64'), min=2, max=8)
202+
out_4 = self._executed_api(
203+
paddle.cast(images * 10, 'int32'), min=2, max=8)
204+
out_5 = self._executed_api(
205+
paddle.cast(images * 10, 'int64'), min=2, max=8)
197206

198207
self.assertTrue(np.allclose(out_1.numpy(), data.clip(0.2, 0.8)))
199208
self.assertTrue(np.allclose(out_2.numpy(), data.clip(0.2, 0.9)))
@@ -212,5 +221,10 @@ def test_errors(self):
212221
paddle.disable_static()
213222

214223

224+
class TestInplaceClipAPI(TestClipAPI):
225+
def _executed_api(self, x, min=None, max=None):
226+
return x.clip_(min, max)
227+
228+
215229
if __name__ == '__main__':
216230
unittest.main()
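Overriding `_executed_api` lets `TestInplaceClipAPI` rerun every `TestClipAPI` case against `clip_`; the extra `paddle.to_tensor(data, ...)` reassignments in the dygraph test keep later calls from seeing an input that was already clipped in place. The equivalence being tested boils down to this sketch:

    import numpy as np
    import paddle

    paddle.disable_static()
    data = np.random.random([3, 4]).astype('float32')
    x = paddle.to_tensor(data)

    out = paddle.clip(x, min=0.2, max=0.8)  # out of place: x is untouched
    x.clip_(min=0.2, max=0.8)               # in place: x itself is clipped
    assert np.allclose(out.numpy(), x.numpy())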

python/paddle/fluid/tests/unittests/test_elementwise_add_op.py

Lines changed: 70 additions & 4 deletions

@@ -408,13 +408,16 @@ def test_errors(self):
         self.assertRaises(TypeError, fluid.layers.elementwise_add, x2, y2)


-class TestAddOp(unittest.TestCase):
+class TestAddApi(unittest.TestCase):
+    def _executed_api(self, x, y, name=None):
+        return paddle.add(x, y, name)
+
     def test_name(self):
         with fluid.program_guard(fluid.Program()):
             x = fluid.data(name="x", shape=[2, 3], dtype="float32")
             y = fluid.data(name='y', shape=[2, 3], dtype='float32')

-            y_1 = paddle.add(x, y, name='add_res')
+            y_1 = self._executed_api(x, y, name='add_res')
             self.assertEqual(('add_res' in y_1.name), True)

     def test_declarative(self):

@@ -428,7 +431,7 @@ def gen_data():

             x = fluid.data(name="x", shape=[3], dtype='float32')
             y = fluid.data(name="y", shape=[3], dtype='float32')
-            z = paddle.add(x, y)
+            z = self._executed_api(x, y)

             place = fluid.CPUPlace()
             exe = fluid.Executor(place)

@@ -442,12 +445,75 @@ def test_dygraph(self):
         np_y = np.array([1, 5, 2]).astype('float64')
         x = fluid.dygraph.to_variable(np_x)
         y = fluid.dygraph.to_variable(np_y)
-        z = paddle.add(x, y)
+        z = self._executed_api(x, y)
         np_z = z.numpy()
         z_expected = np.array([3., 8., 6.])
         self.assertEqual((np_z == z_expected).all(), True)


+class TestAddInplaceApi(TestAddApi):
+    def _executed_api(self, x, y, name=None):
+        return x.add_(y, name)
+
+
+class TestAddInplaceBroadcastSuccess(unittest.TestCase):
+    def init_data(self):
+        self.x_numpy = np.random.rand(2, 3, 4).astype('float')
+        self.y_numpy = np.random.rand(3, 4).astype('float')
+
+    def test_broadcast_success(self):
+        paddle.disable_static()
+        self.init_data()
+        x = paddle.to_tensor(self.x_numpy)
+        y = paddle.to_tensor(self.y_numpy)
+        inplace_result = x.add_(y)
+        numpy_result = self.x_numpy + self.y_numpy
+        self.assertEqual((inplace_result.numpy() == numpy_result).all(), True)
+        paddle.enable_static()
+
+
+class TestAddInplaceBroadcastSuccess2(TestAddInplaceBroadcastSuccess):
+    def init_data(self):
+        self.x_numpy = np.random.rand(1, 2, 3, 1).astype('float')
+        self.y_numpy = np.random.rand(3, 1).astype('float')
+
+
+class TestAddInplaceBroadcastSuccess3(TestAddInplaceBroadcastSuccess):
+    def init_data(self):
+        self.x_numpy = np.random.rand(2, 3, 1, 5).astype('float')
+        self.y_numpy = np.random.rand(1, 3, 1, 5).astype('float')
+
+
+class TestAddInplaceBroadcastError(unittest.TestCase):
+    def init_data(self):
+        self.x_numpy = np.random.rand(3, 4).astype('float')
+        self.y_numpy = np.random.rand(2, 3, 4).astype('float')
+
+    def test_broadcast_errors(self):
+        paddle.disable_static()
+        self.init_data()
+        x = paddle.to_tensor(self.x_numpy)
+        y = paddle.to_tensor(self.y_numpy)
+
+        def broadcast_shape_error():
+            x.add_(y)
+
+        self.assertRaises(ValueError, broadcast_shape_error)
+        paddle.enable_static()
+
+
+class TestAddInplaceBroadcastError2(TestAddInplaceBroadcastError):
+    def init_data(self):
+        self.x_numpy = np.random.rand(2, 1, 4).astype('float')
+        self.y_numpy = np.random.rand(2, 3, 4).astype('float')
+
+
+class TestAddInplaceBroadcastError3(TestAddInplaceBroadcastError):
+    def init_data(self):
+        self.x_numpy = np.random.rand(5, 2, 1, 4).astype('float')
+        self.y_numpy = np.random.rand(2, 3, 4).astype('float')
+
+
 class TestComplexElementwiseAddOp(OpTest):
     def setUp(self):
         self.op_type = "elementwise_add"
