Merged

Changes from all 31 commits:
19b4e1c  add relu6_ hardsigmoid_ leaky_relu_ Inplace APIs (pangyoki, Jan 20, 2021)
cd325a4  add softmax_with_cross_entropy_ Inplace API (pangyoki, Jan 20, 2021)
7b71976  add clip_ scale_ add_ subtract_ Inplace APIs (pangyoki, Jan 20, 2021)
8a3878c  add wlist (pangyoki, Jan 20, 2021)
71db93a  fix parameter of scale api (pangyoki, Jan 21, 2021)
ee118ec  add add_n_ Inplace API and remove log_ Inplace API (pangyoki, Jan 21, 2021)
188960f  fix elementwise_add_ and elementwise_sub_ broadcast problem (pangyoki, Jan 22, 2021)
0698f48  elementwise inplace api give error message before run the op (pangyoki, Jan 26, 2021)
a2cec25  use broadcast_shape in elementwise inplace op (pangyoki, Jan 27, 2021)
ec71854  add 8 inplace apis that is auto generated (pangyoki, Jan 27, 2021)
234f7d8  add unittest for all inplace apis (pangyoki, Jan 27, 2021)
6c20097  add decorator for inplace apis in static mode (pangyoki, Jan 27, 2021)
46ba29c  fix windows blas fail of exp inplace api, change array_equal to allclose (pangyoki, Jan 28, 2021)
ae14bbe  add flatten inplace api (pangyoki, Jan 28, 2021)
ea0777a  add flatten unittest (pangyoki, Jan 28, 2021)
40ccfe9  fix flatten unittest (pangyoki, Jan 28, 2021)
52190d5  merge 18 inplace apis (pangyoki, Jan 29, 2021)
af6a388  add decorator (pangyoki, Feb 1, 2021)
dc4ad2c  solve conflict (pangyoki, Apr 26, 2021)
43fc847  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (pangyoki, Apr 26, 2021)
d330959  fix grad.numpy in test_pylayer_op (pangyoki, Apr 26, 2021)
5a357c0  solve conflict of __init__ (pangyoki, Apr 27, 2021)
388afb1  solve conflict of __init__ v2 (pangyoki, Apr 28, 2021)
29737d4  unsupport softmax_with_cross_entropy_ (pangyoki, Apr 29, 2021)
bc90468  add test_inplace_softmax_with_cross_entropy to static_mode_white_list (pangyoki, Apr 29, 2021)
a5293e6  delete __all__ in inplace_utils (pangyoki, Apr 29, 2021)
94b2713  solve conflict (pangyoki, Apr 29, 2021)
eb11b3b  delete activation inplace function and add Tensor.inplace_func (pangyoki, Apr 29, 2021)
1001d9e  change paddle.inplace_ to Tensor.inplace_ (pangyoki, Apr 29, 2021)
4faff99  fix little problem (pangyoki, Apr 29, 2021)
f46d3fe  add paddle in inplace_utils (pangyoki, Apr 29, 2021)
3 changes: 2 additions & 1 deletion paddle/fluid/imperative/basic_engine.cc
@@ -408,7 +408,8 @@ void BasicEngine::Execute() {
           VLOG(10) << "create temporary var of " << var->Name()
                    << " for sum gradient within this graph!";
         } else if (!inplace_grad_name_map.empty() &&
-                   inplace_grad_name_map.count(pair.first)) {
+                   inplace_grad_name_map.count(pair.first) &&
+                   bwd_ins.count(inplace_grad_name_map.at(pair.first))) {
           // When calculate Inplace grad op, create a new output var.
           // If a tmp var has been created, there is no need to create it
           // again.
37 changes: 1 addition & 36 deletions paddle/fluid/operators/flatten_op.h
@@ -120,51 +120,16 @@ template <typename DeviceContext, typename T>
 class FlattenContiguousRangeKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
-    auto &start_axis = context.Attr<int>("start_axis");
-    auto &stop_axis = context.Attr<int>("stop_axis");
-
     auto *in = context.Input<framework::LoDTensor>("X");
-    auto x_dims = in->dims();
-    int in_dims_size = x_dims.size();
-    int real_start_axis = start_axis, real_stop_axis = stop_axis;
-    if (start_axis < 0) {
-      real_start_axis = start_axis + in_dims_size;
-    }
-    if (stop_axis < 0) {
-      real_stop_axis = stop_axis + in_dims_size;
-    }
     auto *out = context.Output<framework::LoDTensor>("Out");
-
-    auto out_dims = framework::make_ddim(
-        GetOutputShape(real_start_axis, real_stop_axis, x_dims));
+    auto out_dims = out->dims();
 
     out->mutable_data(context.GetPlace(), in->type());
    framework::TensorCopy(
        *in, context.GetPlace(),
        context.template device_context<platform::DeviceContext>(), out);
    out->Resize(out_dims);
  }
-  static std::vector<int32_t> GetOutputShape(const int start_axis,
-                                             const int stop_axis,
-                                             const framework::DDim &in_dims) {
-    int64_t outer = 1;
-    std::vector<int32_t> out_shape;
-    int in_dims_size = in_dims.size();
-    out_shape.reserve(in_dims_size - stop_axis + start_axis);
-
-    for (int i = 0; i < start_axis; ++i) {
-      out_shape.push_back(in_dims[i]);
-    }
-    for (int i = start_axis; i <= stop_axis; i++) {
-      outer *= in_dims[i];
-    }
-    out_shape.push_back(outer);
-    for (int i = stop_axis + 1; i < in_dims_size; i++) {
-      out_shape.push_back(in_dims[i]);
-    }
-
-    return out_shape;
-  }
 };
 
 template <typename DeviceContext, typename T>
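The surviving kernel body simply copies X and restores the output shape that shape inference has already recorded on Out (via out->dims()), which stays correct when Out reuses X's buffer in the inplace case. For reference, a minimal Python sketch of the shape rule that the deleted GetOutputShape helper computed; this is an illustration only, not Paddle code:

from math import prod

def flatten_shape(shape, start_axis, stop_axis):
    # Axes in [start_axis, stop_axis] (negative values count from the
    # end) collapse into one dimension; the rest are kept as-is.
    n = len(shape)
    start = start_axis + n if start_axis < 0 else start_axis
    stop = stop_axis + n if stop_axis < 0 else stop_axis
    outer = prod(shape[start:stop + 1])   # prod([]) == 1, as in the C++ code
    return list(shape[:start]) + [outer] + list(shape[stop + 1:])

print(flatten_shape([2, 3, 4, 5], 1, 2))   # [2, 12, 5]
print(flatten_shape([2, 3, 4, 5], 0, -1))  # [120]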
2 changes: 2 additions & 0 deletions python/paddle/fluid/dygraph/__init__.py
@@ -58,6 +58,8 @@
 
 from .math_op_patch import monkey_patch_math_varbase
 
+from .inplace_utils import inplace_apis_in_dygraph_only
+
 __all__ = []
 __all__ += layers.__all__
 __all__ += base.__all__
38 changes: 38 additions & 0 deletions python/paddle/fluid/dygraph/inplace_utils.py
@@ -0,0 +1,38 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..wrapped_decorator import wrap_decorator
from ..framework import in_dygraph_mode
import warnings
import paddle


# NOTE(pangyoki): The Inplace APIs with underline(`_`) is only valid for the method of calling `core.ops`
# in dygraph mode. If static mode is used, the inplace mechanism will not be used, and the static method
# of the original API will be called.
def _inplace_apis_in_dygraph_only_(func):
    def __impl__(*args, **kwargs):
        if not in_dygraph_mode():
            origin_api_name = func.__name__[:-1]
            warnings.warn(
                "In static mode, {}() is the same as {}() and does not perform inplace operation.".
                format(func.__name__, origin_api_name))
            origin_func = "{}.{}".format(func.__module__, origin_api_name)
            return eval(origin_func)(*args, **kwargs)
        return func(*args, **kwargs)

    return __impl__


inplace_apis_in_dygraph_only = wrap_decorator(_inplace_apis_in_dygraph_only_)
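The decorator above warns and falls back to the out-of-place API (reconstructed from func.__module__ plus the name minus its trailing underscore, then resolved with eval) whenever execution is not in dygraph mode. A framework-free sketch of the same fallback pattern; relu/relu_ are illustrative stand-ins, not the Paddle implementations, and globals() stands in for the eval-on-module-path lookup:

import warnings

IN_DYGRAPH_MODE = False          # stand-in for in_dygraph_mode()

def relu(x):                     # the out-of-place "origin" API
    return [max(v, 0.0) for v in x]

def dygraph_only_inplace(func):
    def __impl__(*args, **kwargs):
        if not IN_DYGRAPH_MODE:
            origin = func.__name__[:-1]     # strip the trailing '_'
            warnings.warn("In static mode, {}() is the same as {}().".format(
                func.__name__, origin))
            return globals()[origin](*args, **kwargs)
        return func(*args, **kwargs)
    return __impl__

@dygraph_only_inplace
def relu_(x):                    # in-place variant, used only in dygraph mode
    for i, v in enumerate(x):
        x[i] = max(v, 0.0)
    return x

print(relu_([-1.0, 2.0]))        # warns, then returns a new list: [0.0, 2.0]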
32 changes: 31 additions & 1 deletion python/paddle/fluid/layers/layer_function_generator.py
@@ -25,7 +25,8 @@
 from ..data_feeder import check_variable_and_dtype
 
 __all__ = [
-    'generate_layer_fn', 'generate_activation_fn', 'autodoc', 'templatedoc'
+    'generate_layer_fn', 'generate_activation_fn', 'generate_inplace_fn',
+    'autodoc', 'templatedoc'
 ]
@@ -283,6 +284,35 @@ def func(x, name=None):
     return func
 
 
+def generate_inplace_fn(inplace_op_type):
+    """Register the Python layer for an Inplace Operator without Attribute.
+
+    Args:
+        inplace_op_type: The name of the inplace operator to be created.
+
+    This function takes in the inplace operator type (exp_ , ceil_ etc) and
+    creates the operator functionality.
+    """
+    origin_op_type = inplace_op_type[:-1]
+
+    def func(x, name=None):
+        if in_dygraph_mode():
+            op = getattr(core.ops, inplace_op_type)
+            return op(x)
+        warnings.warn(
+            "In static mode, {}() is the same as {}() and does not perform inplace operation.".
+            format(inplace_op_type, origin_op_type))
+        return generate_activation_fn(origin_op_type)(x, name)
+
+    func.__name__ = inplace_op_type
+    func.__doc__ = """
+Inplace version of ``{0}`` API, the output Tensor will be inplaced with input ``x``.
+Please refer to :ref:`api_fluid_layers_{1}`.
+""".format(origin_op_type, origin_op_type)
+
+    return func
+
+
 def autodoc(comment=""):
     def __impl__(func):
         func.__doc__ = _generate_doc_string_(OpProtoHolder.instance(
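generate_inplace_fn is a name-driven factory: in dygraph mode it dispatches to the raw core.ops kernel via getattr, in static mode it warns and delegates to the out-of-place activation, and it stamps __name__ and __doc__ onto the generated function. A minimal self-contained sketch of that factory pattern, with a plain dict standing in for core.ops; all names here are illustrative:

import math
import warnings

def _ceil_inplace(xs):           # stand-in for the core.ops.ceil_ kernel
    for i, v in enumerate(xs):
        xs[i] = math.ceil(v)
    return xs

core_ops = {'ceil_': _ceil_inplace}
in_dygraph = True                # stand-in for in_dygraph_mode()

def make_inplace_fn(inplace_op_type):
    origin_op_type = inplace_op_type[:-1]

    def func(x, name=None):
        if in_dygraph:
            return core_ops[inplace_op_type](x)
        warnings.warn("In static mode, {}() is the same as {}().".format(
            inplace_op_type, origin_op_type))
        return [math.ceil(v) for v in x]    # delegate to the origin op

    func.__name__ = inplace_op_type
    func.__doc__ = "Inplace version of ``{}``.".format(origin_op_type)
    return func

ceil_ = make_inplace_fn('ceil_')
xs = [0.2, 1.7]
print(ceil_(xs), xs)             # [1, 2] [1, 2]: the input list was mutated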
21 changes: 20 additions & 1 deletion python/paddle/fluid/layers/ops.py
@@ -14,7 +14,7 @@
 
 from __future__ import print_function
 import os
-from .layer_function_generator import generate_layer_fn, generate_activation_fn, add_sample_code
+from .layer_function_generator import generate_layer_fn, generate_activation_fn, generate_inplace_fn, add_sample_code
 from .. import core
 from ..framework import convert_np_dtype_to_dtype_, Variable
 from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
@@ -55,6 +55,16 @@
     'square',
 ]
 
+__inplace_unary_func__ = [
+    'exp_',
+    'sqrt_',
+    'rsqrt_',
+    'ceil_',
+    'floor_',
+    'round_',
+    'reciprocal_',
+]
+
 __all__ = []
 
 for _OP in set(__all__):
@@ -69,6 +79,7 @@
 
 __all__ += __activations_noattr__
 __all__ += __unary_func__
+__all__ += __inplace_unary_func__
 
 for _OP in set(__activations_noattr__):
     _new_OP = _OP
@@ -87,6 +98,14 @@
     func = deprecated(since="2.0.0", update_to="paddle.%s" % (_new_OP))(func)
     globals()[_OP] = func
 
+for _OP in set(__inplace_unary_func__):
+    _new_OP = _OP
+    if _OP in __deprecated_func_name__:
+        _new_OP = __deprecated_func_name__[_OP]
+    func = generate_inplace_fn(_OP)
+    func = deprecated(since="2.0.0", update_to="paddle.%s" % (_new_OP))(func)
+    globals()[_OP] = func
+
 add_sample_code(globals()["sigmoid"], r"""
 Examples:
     .. code-block:: python
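After this loop runs, each name in __inplace_unary_func__ is bound in the module globals, and the same functions are also exposed as Tensor methods. A dygraph usage sketch for two of the registered APIs, assuming a Paddle build that includes this PR (2.1 or later):

import paddle

x = paddle.to_tensor([1.0, 4.0, 9.0])
x.sqrt_()                 # mutates x's storage in place
print(x.numpy())          # [1. 2. 3.]

y = paddle.to_tensor([1.2, 3.7])
y.floor_()
print(y.numpy())          # [1. 3.]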
48 changes: 31 additions & 17 deletions python/paddle/fluid/tests/unittests/test_clip_op.py
@@ -124,6 +124,9 @@ def test_dtype():
 
 
 class TestClipAPI(unittest.TestCase):
+    def _executed_api(self, x, min=None, max=None):
+        return paddle.clip(x, min, max)
+
     def test_clip(self):
         paddle.enable_static()
         data_shape = [1, 9, 9, 4]
@@ -136,18 +139,20 @@ def test_clip(self):
         ) else fluid.CPUPlace()
         exe = fluid.Executor(place)
 
-        out_1 = paddle.clip(images, min=min, max=max)
-        out_2 = paddle.clip(images, min=0.2, max=0.9)
-        out_3 = paddle.clip(images, min=0.3)
-        out_4 = paddle.clip(images, max=0.7)
-        out_5 = paddle.clip(images, min=min)
-        out_6 = paddle.clip(images, max=max)
-        out_7 = paddle.clip(images, max=-1.)
-        out_8 = paddle.clip(images)
-        out_9 = paddle.clip(paddle.cast(images, 'float64'), min=0.2, max=0.9)
-
-        out_10 = paddle.clip(paddle.cast(images * 10, 'int32'), min=2, max=8)
-        out_11 = paddle.clip(paddle.cast(images * 10, 'int64'), min=2, max=8)
+        out_1 = self._executed_api(images, min=min, max=max)
+        out_2 = self._executed_api(images, min=0.2, max=0.9)
+        out_3 = self._executed_api(images, min=0.3)
+        out_4 = self._executed_api(images, max=0.7)
+        out_5 = self._executed_api(images, min=min)
+        out_6 = self._executed_api(images, max=max)
+        out_7 = self._executed_api(images, max=-1.)
+        out_8 = self._executed_api(images)
+        out_9 = self._executed_api(
+            paddle.cast(images, 'float64'), min=0.2, max=0.9)
+        out_10 = self._executed_api(
+            paddle.cast(images * 10, 'int32'), min=2, max=8)
+        out_11 = self._executed_api(
+            paddle.cast(images * 10, 'int64'), min=2, max=8)
 
         res1, res2, res3, res4, res5, res6, res7, res8, res9, res10, res11 = exe.run(
             fluid.default_main_program(),
@@ -188,12 +193,16 @@ def test_clip_dygraph(self):
         v_min = paddle.to_tensor(np.array([0.2], dtype=np.float32))
         v_max = paddle.to_tensor(np.array([0.8], dtype=np.float32))
 
-        out_1 = paddle.clip(images, min=0.2, max=0.8)
-        out_2 = paddle.clip(images, min=0.2, max=0.9)
-        out_3 = paddle.clip(images, min=v_min, max=v_max)
-
-        out_4 = paddle.clip(paddle.cast(images * 10, 'int32'), min=2, max=8)
-        out_5 = paddle.clip(paddle.cast(images * 10, 'int64'), min=2, max=8)
+        out_1 = self._executed_api(images, min=0.2, max=0.8)
+        images = paddle.to_tensor(data, dtype='float32')
+        out_2 = self._executed_api(images, min=0.2, max=0.9)
+        images = paddle.to_tensor(data, dtype='float32')
+        out_3 = self._executed_api(images, min=v_min, max=v_max)
+
+        out_4 = self._executed_api(
+            paddle.cast(images * 10, 'int32'), min=2, max=8)
+        out_5 = self._executed_api(
+            paddle.cast(images * 10, 'int64'), min=2, max=8)
 
         self.assertTrue(np.allclose(out_1.numpy(), data.clip(0.2, 0.8)))
         self.assertTrue(np.allclose(out_2.numpy(), data.clip(0.2, 0.9)))
@@ -212,5 +221,10 @@ def test_errors(self):
         paddle.disable_static()
 
 
+class TestInplaceClipAPI(TestClipAPI):
+    def _executed_api(self, x, min=None, max=None):
+        return x.clip_(min, max)
+
+
 if __name__ == '__main__':
     unittest.main()
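TestInplaceClipAPI inherits every test from TestClipAPI and overrides only the _executed_api dispatch hook, so the full suite reruns against the inplace variant. A short dygraph sketch of the behavior those reused tests now cover, assuming a Paddle build with this PR:

import paddle

x = paddle.to_tensor([0.1, 0.5, 0.9])
x.clip_(min=0.2, max=0.8)    # clamps x in place instead of allocating out
print(x.numpy())             # [0.2 0.5 0.8]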
74 changes: 70 additions & 4 deletions python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
@@ -408,13 +408,16 @@ def test_errors(self):
         self.assertRaises(TypeError, fluid.layers.elementwise_add, x2, y2)
 
 
-class TestAddOp(unittest.TestCase):
+class TestAddApi(unittest.TestCase):
+    def _executed_api(self, x, y, name=None):
+        return paddle.add(x, y, name)
+
     def test_name(self):
         with fluid.program_guard(fluid.Program()):
             x = fluid.data(name="x", shape=[2, 3], dtype="float32")
             y = fluid.data(name='y', shape=[2, 3], dtype='float32')
 
-            y_1 = paddle.add(x, y, name='add_res')
+            y_1 = self._executed_api(x, y, name='add_res')
             self.assertEqual(('add_res' in y_1.name), True)
 
     def test_declarative(self):
@@ -428,7 +431,7 @@ def gen_data():
 
         x = fluid.data(name="x", shape=[3], dtype='float32')
         y = fluid.data(name="y", shape=[3], dtype='float32')
-        z = paddle.add(x, y)
+        z = self._executed_api(x, y)
 
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
@@ -442,12 +445,75 @@ def test_dygraph(self):
             np_y = np.array([1, 5, 2]).astype('float64')
             x = fluid.dygraph.to_variable(np_x)
             y = fluid.dygraph.to_variable(np_y)
-            z = paddle.add(x, y)
+            z = self._executed_api(x, y)
             np_z = z.numpy()
             z_expected = np.array([3., 8., 6.])
             self.assertEqual((np_z == z_expected).all(), True)
 
 
+class TestAddInplaceApi(TestAddApi):
+    def _executed_api(self, x, y, name=None):
+        return x.add_(y, name)
+
+
+class TestAddInplaceBroadcastSuccess(unittest.TestCase):
+    def init_data(self):
+        self.x_numpy = np.random.rand(2, 3, 4).astype('float')
+        self.y_numpy = np.random.rand(3, 4).astype('float')
+
+    def test_broadcast_success(self):
+        paddle.disable_static()
+        self.init_data()
+        x = paddle.to_tensor(self.x_numpy)
+        y = paddle.to_tensor(self.y_numpy)
+        inplace_result = x.add_(y)
+        numpy_result = self.x_numpy + self.y_numpy
+        self.assertEqual((inplace_result.numpy() == numpy_result).all(), True)
+        paddle.enable_static()
+
+
+class TestAddInplaceBroadcastSuccess2(TestAddInplaceBroadcastSuccess):
+    def init_data(self):
+        self.x_numpy = np.random.rand(1, 2, 3, 1).astype('float')
+        self.y_numpy = np.random.rand(3, 1).astype('float')
+
+
+class TestAddInplaceBroadcastSuccess3(TestAddInplaceBroadcastSuccess):
+    def init_data(self):
+        self.x_numpy = np.random.rand(2, 3, 1, 5).astype('float')
+        self.y_numpy = np.random.rand(1, 3, 1, 5).astype('float')
+
+
+class TestAddInplaceBroadcastError(unittest.TestCase):
+    def init_data(self):
+        self.x_numpy = np.random.rand(3, 4).astype('float')
+        self.y_numpy = np.random.rand(2, 3, 4).astype('float')
+
+    def test_broadcast_errors(self):
+        paddle.disable_static()
+        self.init_data()
+        x = paddle.to_tensor(self.x_numpy)
+        y = paddle.to_tensor(self.y_numpy)
+
+        def broadcast_shape_error():
+            x.add_(y)
+
+        self.assertRaises(ValueError, broadcast_shape_error)
+        paddle.enable_static()
+
+
+class TestAddInplaceBroadcastError2(TestAddInplaceBroadcastError):
+    def init_data(self):
+        self.x_numpy = np.random.rand(2, 1, 4).astype('float')
+        self.y_numpy = np.random.rand(2, 3, 4).astype('float')
+
+
+class TestAddInplaceBroadcastError3(TestAddInplaceBroadcastError):
+    def init_data(self):
+        self.x_numpy = np.random.rand(5, 2, 1, 4).astype('float')
+        self.y_numpy = np.random.rand(2, 3, 4).astype('float')
+
+
 class TestComplexElementwiseAddOp(OpTest):
     def setUp(self):
         self.op_type = "elementwise_add"