
Commit b047c7a

Merge pull request #3668 from zchen0211/develop
Scatter update implemented with in-place memory; gradient_check modified slightly to support in-place ops
2 parents 6e8eccb + bfeecfd commit b047c7a

File tree

10 files changed, +222 −5 lines changed

paddle/operators/CMakeLists.txt
Lines changed: 1 addition & 0 deletions

@@ -47,6 +47,7 @@ cc_test(gather_test SRCS gather_test.cc DEPS tensor)
 op_library(gather_op SRCS gather_op.cc gather_op.cu)

 cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
+op_library(scatter_op SRCS scatter_op.cc scatter_op.cu)

 cc_library(net_op SRCS net_op.cc DEPS op_registry)
 cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)

paddle/operators/scatter_op.cc
Lines changed: 86 additions & 0 deletions (new file)

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/scatter_op.h"
#include "paddle/framework/ddim.h"

namespace paddle {
namespace operators {

class ScatterOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Index")->dims().size(), 1,
                      "Update Index should be 1-D.");
    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Ref")->dims().size(),
                      ctx.Input<Tensor>("Updates")->dims().size(),
                      "Ref and Updates should have the same rank.");
    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Updates")->dims()[0],
                      ctx.Input<Tensor>("Index")->dims()[0],
                      "Updates and Index should have the same batch size.");
    // Every dimension after the first must match between Updates and Ref.
    framework::DDim data_dim(ctx.Input<Tensor>("Updates")->dims());
    for (int i = 1; i < data_dim.size(); ++i)
      PADDLE_ENFORCE_EQ(data_dim[i], ctx.Input<Tensor>("Ref")->dims()[i]);
    ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("Ref")->dims());
  }
};

class ScatterGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
    auto *Updates = ctx.Input<Tensor>("Updates");
    auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
    auto *Ref = ctx.Input<Tensor>("Ref");

    dRef->Resize(Ref->dims());
    dUpdates->Resize(Updates->dims());
  }
};

class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ScatterOpMaker(framework::OpProto *proto,
                 framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Ref", "The source input of scatter op");
    AddInput("Index",
             "The index input of scatter op where Ref will be updated");
    AddInput("Updates", "The update values scattered into Ref");
    AddOutput("Out", "The output of scatter op");
    AddComment(R"DOC(
Scatter operator: updates Ref along the first axis at the rows selected by Index.

Out = Ref
Out[Index] = Ref[Index] + Updates
)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, scatter_grad,
            ops::ScatterGradOp);
REGISTER_OP_CPU_KERNEL(scatter,
                       ops::ScatterOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    scatter_grad,
    ops::ScatterGradientOpKernel<paddle::platform::CPUPlace, float>);
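To make the semantics concrete, here is a minimal numpy sketch of the forward computation (the helper name scatter_forward is illustrative, not part of this PR, and it assumes unique indices as in the unit tests below; with duplicate indices, numpy.add.at would be needed to accumulate):

import numpy

def scatter_forward(ref, index, updates):
    # Out = Ref; Out[Index] = Ref[Index] + Updates
    out = numpy.copy(ref)
    out[index] += updates
    return out

ref = numpy.ones((3, 3), dtype="float32")
index = numpy.array([1, 2], dtype="int32")
updates = numpy.random.random((2, 3)).astype("float32")
print(scatter_forward(ref, index, updates))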

paddle/operators/scatter_op.cu
Lines changed: 20 additions & 0 deletions (new file)

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/scatter_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(scatter,
                       ops::ScatterOpKernel<paddle::platform::GPUPlace, float>);
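Note that only the forward kernel gets a GPU registration; this commit adds no GPU kernel for scatter_grad, which lines up with the CPU-only registration (USE_CPU_ONLY_OP) in pybind.cc below.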

paddle/operators/scatter_op.h
Lines changed: 60 additions & 0 deletions (new file)

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "gather.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "scatter.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename Place, typename T>
class ScatterOpKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *Ref = ctx.Input<Tensor>("Ref");
    auto *Index = ctx.Input<Tensor>("Index");
    auto *Updates = ctx.Input<Tensor>("Updates");
    auto *Out = ctx.Output<Tensor>("Out");

    // In-place output: Out aliases Ref's memory, so Out = Ref costs no copy.
    Out->ShareDataWith<T>(*Ref);
    // Apply ScatterUpdate: Out[Index] += Updates[:]
    ScatterUpdate<T>(ctx.GetPlace(), Updates, Index, Out);
  }
};

template <typename Place, typename T>
class ScatterGradientOpKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
    auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
    auto *Index = ctx.Input<Tensor>("Index");
    auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));

    // In-place gradient: dRef aliases dOut (dRef = dOut).
    dRef->ShareDataWith<T>(*dOut);
    dUpdates->mutable_data<T>(ctx.GetPlace());
    // Gradient by Gather: dUpdates = dOut[Index].
    Gather<T>(ctx.GetPlace(), dOut, Index, dUpdates);
  }
};

}  // namespace operators
}  // namespace paddle
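The backward pass in the same illustrative numpy style (scatter_backward is a hypothetical helper, not part of this PR): since Out[i] = Ref[i] everywhere except that Out[Index] = Ref[Index] + Updates, the gradient w.r.t. Ref is dOut unchanged, and the gradient w.r.t. Updates is a gather of dOut at Index.

import numpy

def scatter_backward(d_out, index):
    d_ref = d_out             # dRef = dOut (the kernel shares memory instead of copying)
    d_updates = d_out[index]  # dUpdates = Gather(dOut, Index)
    return d_ref, d_updates

d_out = numpy.ones((3, 3), dtype="float32")
index = numpy.array([1, 2], dtype="int32")
d_ref, d_updates = scatter_backward(d_out, index)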

paddle/pybind/CMakeLists.txt
Lines changed: 1 addition & 0 deletions

@@ -4,6 +4,7 @@ cc_library(paddle_pybind SHARED
     DEPS pybind python backward
     sgd_op
     gather_op
+    scatter_op
     add_op
     mul_op
     rowwise_add_op

paddle/pybind/pybind.cc
Lines changed: 1 addition & 0 deletions

@@ -47,6 +47,7 @@ USE_OP(scale);
 USE_OP_ITSELF(identity);
 USE_OP(minus);
 USE_CPU_ONLY_OP(gather);
+USE_CPU_ONLY_OP(scatter);

 namespace paddle {
 namespace framework {

python/paddle/v2/framework/tests/CMakeLists.txt
Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@ py_test(test_sigmoid_op SRCS test_sigmoid_op.py)
 py_test(test_softmax_op SRCS test_softmax_op.py)
 py_test(test_cross_entropy_op SRCS test_cross_entropy_op.py)
 py_test(test_gather_op SRCS test_gather_op.py)
+py_test(test_scatter_op SRCS test_scatter_op.py)
 py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py)

 py_test(gradient_checker SRCS gradient_checker.py)

python/paddle/v2/framework/tests/gradient_checker.py
Lines changed: 14 additions & 2 deletions

@@ -32,7 +32,8 @@ def get_numeric_gradient(op,
                         output_name,
                         input_to_check,
                         delta=0.005,
-                        local_scope=None):
+                        local_scope=None,
+                        in_place=False):
     """
     Get Numeric Gradient for an operator's input.

@@ -81,6 +82,11 @@ def get_output():
     def product(dim):
         return reduce(lambda a, b: a * b, dim, 1)

+    def restore_inputs():
+        for var_name in input_values:
+            tensor_ = local_scope.find_var(var_name).get_tensor()
+            tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace())
+
     # get the input tensor whose numeric gradient we want to compute.
     tensor_to_check = local_scope.find_var(input_to_check).get_tensor()
     tensor_size = product(tensor_to_check.get_dims())

@@ -90,6 +96,8 @@ def product(dim):
     # we only compute the gradient of one element each time.
     # we use a for loop to compute the gradient of every element.
     for i in xrange(tensor_size):
+        if in_place:
+            restore_inputs()
         # get one input element by its index i.
         origin = tensor_to_check.get_float_element(i)

@@ -99,6 +107,8 @@ def product(dim):
         y_pos = get_output()

         # subtract delta from this element, run op and get the sum of the result tensor.
+        if in_place:
+            restore_inputs()
         x_neg = origin - delta
         tensor_to_check.set_float_element(i, x_neg)
         y_neg = get_output()

@@ -251,6 +261,7 @@ def check_grad(self,
                   output_name,
                   no_grad_set=None,
                   only_cpu=False,
+                  in_place=False,
                   max_relative_error=0.005):
        """
        :param forward_op: used to create backward_op

@@ -283,7 +294,8 @@ def check_grad(self,

        # get numerical gradients
        numeric_grads = [
-           get_numeric_gradient(forward_op, input_vars, output_name, name)
+           get_numeric_gradient(
+               forward_op, input_vars, output_name, name, in_place=in_place)
            for name in inputs_to_check
        ]
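Why the restoration matters: an in-place op such as scatter makes Out share Ref's buffer, so every forward run mutates the op's own inputs; the checker therefore resets all inputs from saved numpy copies before each perturbed evaluation. A self-contained sketch of that central-difference loop (function and parameter names here are illustrative, not the checker's real API):

import numpy

def numeric_gradient(run_op, x, delta=0.005, restore=None):
    # Central differences, one element at a time. `run_op` returns the
    # scalar sum of the op's output; `restore` resets all op inputs.
    grad = numpy.zeros_like(x)
    for i in range(x.size):
        origin = x.flat[i]
        if restore is not None:
            restore()  # undo any in-place mutation from the previous run
        x.flat[i] = origin + delta
        y_pos = run_op()
        if restore is not None:
            restore()
        x.flat[i] = origin - delta
        y_neg = run_op()
        x.flat[i] = origin
        grad.flat[i] = (y_pos - y_neg) / (2.0 * delta)
    return grad

# Example: d/dx sum(x**2) = 2x (no in-place op involved, so restore is None).
x = numpy.array([1.0, 2.0, 3.0])
print(numeric_gradient(lambda: float((x ** 2).sum()), x))  # approx [2. 4. 6.]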

python/paddle/v2/framework/tests/test_gather_op.py
Lines changed: 0 additions & 3 deletions

@@ -21,12 +21,9 @@ def setUp(self):

 class TestGatherGradOp(GradientChecker):
     def test_gather_grad(self):
-        print 'creating op'
         op = create_op("gather")
-        print 'creating op done'
         xnp = numpy.random.random((10, 20)).astype("float32")
         inputs = {'X': xnp, 'Index': numpy.array([1, 3, 5]).astype("int32")}
-        print 'correct before check gradient'
         self.check_grad(op, inputs, set("X"), "Out")

python/paddle/v2/framework/tests/test_scatter_op.py
Lines changed: 38 additions & 0 deletions (new file)

import unittest
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
import numpy
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator


class TestScatterOp(unittest.TestCase):
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "scatter"
        ref_np = numpy.ones((3, 3)).astype("float32")
        index_np = numpy.array([1, 2]).astype("int32")
        updates_np = numpy.random.random((2, 3)).astype("float32")
        output_np = numpy.copy(ref_np)
        output_np[index_np] += updates_np
        self.inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np}
        self.outputs = {'Out': output_np}


class TestScatterGradOp(GradientChecker):
    def test_scatter_grad(self):
        op = create_op("scatter")
        # test data setup
        ref_np = numpy.ones((3, 10)).astype("float32")
        index_np = numpy.array([1, 2]).astype("int32")
        updates_np = numpy.random.random((2, 10)).astype("float32")
        output_np = numpy.copy(ref_np)
        output_np[index_np] += updates_np
        inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np}
        self.check_grad(
            op, inputs, set(["Updates", "Ref"]), "Out", in_place=True)


if __name__ == "__main__":
    unittest.main()
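The in_place=True flag is essential here: because the scatter kernel's ShareDataWith makes Out alias Ref, every forward run of the op overwrites its own inputs, and without restore_inputs() between evaluations the numeric gradients would be computed against stale data.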
