
Commit 5edca0f

Fix code and add unittests according to reviews
1 parent 7fd9557 commit 5edca0f

File tree: 3 files changed (+23, -6 lines)

paddle/fluid/operators/set_value_op.h

Lines changed: 17 additions & 1 deletion

@@ -21,6 +21,7 @@
 #include <utility>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/operators/assign_value_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -83,6 +84,10 @@ class SetValueKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     const int rank = ctx.Output<framework::LoDTensor>("Out")->dims().size();
+
+    // TODO(liym27): Find a more elegant way to do this. C++ requires the
+    // template integer to be a compile-time constant, but we had better have
+    // an alternative writing in the future.
     switch (rank) {
       case 1:
         SetValueCompute<1>(ctx);
@@ -127,7 +132,18 @@ class SetValueKernel : public framework::OpKernel<T> {
     auto& eigen_place =
         *ctx.template device_context<DeviceContext>().eigen_device();

-    out->ShareDataWith(*in);
+    // Copy data from the input here to avoid data loss at the PE and Graph level.
+    // TODO(liym27): Speed this up in a future version.
+    // - Q: Why not call ShareDataWith to speed this up?
+    // - A: Because calling ShareDataWith on an OP's input and output is not
+    //      supported; see
+    //      https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP
+    // - Q: Why not delete Input, since the input and output are the same
+    //      Tensor at the program level?
+    // - A: If Input were deleted, the graph would become complex: two ops would
+    //      point to the output in the graph (op1 -> output <- set_value), and
+    //      we would have to find a way to make set_value run in the order we want.
+    TensorCopy(*in, place, out);

     Tensor slice_t(dtype), pad_t(dtype);
     slice_t.mutable_data<T>(slice_dims, place);
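
The TODO in the second hunk refers to the usual Eigen constraint: the tensor rank must be a compile-time template argument, so the runtime rank is mapped to a template instantiation through a switch. Below is a minimal, standalone C++ sketch of that dispatch pattern; it has no Paddle dependency, and the function names and the rank limit are illustrative assumptions, not the exact kernel code.

// Sketch: map a runtime rank onto a compile-time template parameter.
#include <cstdio>
#include <stdexcept>

template <int Rank>
void SetValueCompute() {
  // The real kernel would slice and assign an Eigen tensor of rank `Rank` here.
  std::printf("running the rank-%d instantiation\n", Rank);
}

void Compute(int rank) {
  // Each case picks a distinct instantiation, since `rank` itself is not a
  // compile-time constant.
  switch (rank) {
    case 1: SetValueCompute<1>(); break;
    case 2: SetValueCompute<2>(); break;
    case 3: SetValueCompute<3>(); break;
    case 4: SetValueCompute<4>(); break;
    case 5: SetValueCompute<5>(); break;
    case 6: SetValueCompute<6>(); break;
    default: throw std::invalid_argument("rank is assumed to be in [1, 6]");
  }
}

int main() {
  Compute(3);  // dispatches to the SetValueCompute<3> instantiation
  return 0;
}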

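The Q&A comment in the third hunk boils down to aliasing versus copying: ShareDataWith would make the output another handle on the input's storage, while TensorCopy gives the output its own buffer. The following is a standalone C++ sketch of that difference, using std::vector as a stand-in for a Tensor (an analogy only, not the Paddle API).

#include <iostream>
#include <memory>
#include <vector>

int main() {
  auto input = std::make_shared<std::vector<float>>(4, 5.0f);

  // "Share"-style: the output aliases the input's storage, so an in-place
  // write through the output is also visible through the input.
  std::shared_ptr<std::vector<float>> shared_out = input;
  (*shared_out)[0] = 2.0f;

  // "Copy"-style: the output owns independent storage, so later writes
  // leave the input untouched.
  auto copied_out = std::make_shared<std::vector<float>>(*input);
  (*copied_out)[1] = 10.0f;

  std::cout << "input:  " << (*input)[0] << " " << (*input)[1] << "\n";            // 2 5
  std::cout << "copied: " << (*copied_out)[0] << " " << (*copied_out)[1] << "\n";  // 2 10
  return 0;
}
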
python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py

Lines changed: 5 additions & 4 deletions
@@ -23,6 +23,7 @@
 np.random.seed(SEED)
 prog_trans = paddle.jit.ProgramTranslator()

+
 @paddle.jit.to_static
 def test_slice_without_control_flow(x):
     # Python slice will not be transformed.
@@ -84,7 +85,7 @@ def test_slice_in_for_loop(x, iter_num=3):


 @paddle.jit.to_static
-def test_setitem(x):
+def test_set_value(x):
     x = paddle.to_tensor(x)
     x[0] = paddle.full(shape=[1], fill_value=2, dtype="float32")
     x[1:2, 0:1] = 10
@@ -140,12 +141,12 @@ def init_dygraph_func(self):
         self.dygraph_func = test_slice_in_for_loop


-class TestSetitem(TestSliceWithoutControlFlow):
+class TestSetValue(TestSliceWithoutControlFlow):
     def init_input(self):
-        self.input = np.full([3,4,5], 5).astype('float32')
+        self.input = np.full([3, 4, 5], 5).astype('float32')

     def init_dygraph_func(self):
-        self.dygraph_func = test_setitem
+        self.dygraph_func = test_set_value


 if __name__ == '__main__':

python/paddle/fluid/tests/unittests/test_setitem_op.py renamed to python/paddle/fluid/tests/unittests/test_set_value_op.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# Test setitem op in static mode
+# Test set_value op in static mode

 from __future__ import print_function

