|
21 | 21 | #include <utility> |
22 | 22 | #include "paddle/fluid/framework/eigen.h" |
23 | 23 | #include "paddle/fluid/framework/op_registry.h" |
| 24 | +#include "paddle/fluid/framework/tensor_util.h" |
24 | 25 | #include "paddle/fluid/operators/assign_value_op.h" |
25 | 26 | #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" |
26 | 27 | #include "paddle/fluid/platform/enforce.h" |
@@ -83,6 +84,10 @@ class SetValueKernel : public framework::OpKernel<T> { |
83 | 84 | public: |
84 | 85 | void Compute(const framework::ExecutionContext& ctx) const { |
85 | 86 | const int rank = ctx.Output<framework::LoDTensor>("Out")->dims().size(); |
| 87 | + |
| 88 | + // TODO(liym27): A more elegant way to do this. C++ requires the template |
| 89 | + // integer to be a compile-time constant, but we had better provide an |
| 90 | + // alternative implementation in the future. |
86 | 91 | switch (rank) { |
87 | 92 | case 1: |
88 | 93 | SetValueCompute<1>(ctx); |
@@ -127,7 +132,18 @@ class SetValueKernel : public framework::OpKernel<T> { |
127 | 132 | auto& eigen_place = |
128 | 133 | *ctx.template device_context<DeviceContext>().eigen_device(); |
129 | 134 |
|
130 | | - out->ShareDataWith(*in); |
| 135 | + // Copy data from the input here to avoid data loss at the PE and Graph level. |
| 136 | + // TODO(liym27): Speed this up in a future version. |
| 137 | + // - Q: Why not call ShareDataWith to speed this up? |
| 138 | + // - A: Because ShareDataWith is not supported on an OP's input and output: |
| 139 | + // https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP |
| 140 | + // - Q: Why not delete Input? After all, the input and output are the same |
| 141 | + // Tensor at the program level. |
| 142 | + // - A: If Input were deleted, the graph would become complex; e.g. there would |
| 143 | + // be two ops pointing to the output in the graph: op1 -> output <- set_value. |
| 144 | + // In this case, we would have to find a way to ensure set_value runs in |
| 145 | + // the intended order. |
| 146 | + TensorCopy(*in, place, out); |
131 | 147 |
|
132 | 148 | Tensor slice_t(dtype), pad_t(dtype); |
133 | 149 | slice_t.mutable_data<T>(slice_dims, place); |
|
0 commit comments