Commit f46ddc0

[Cherry-Pick 2.0][setitem] Support Tensor setitem in static mode (#29708) (#30104)
1. Type of index: int, slice (step must be 1). 2. Type of value: (1) int32, int64, float32, bool; (2) numpy.array of int32, int64, float32 or bool (note: float64 is not supported); (3) paddle.Tensor of int32, int64, float32, float64 or bool.
1 parent b2ca2ca commit f46ddc0
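
For context, a rough user-level sketch of what this enables (illustrative only; the Python-side setitem changes are in other files of this commit, and the exact calls may differ):

import numpy as np
import paddle

paddle.enable_static()

# Build a static program in which slices of a Tensor are assigned.
x = paddle.ones(shape=[3, 4], dtype='float32')
x[0] = 0.0                                  # int index, Python scalar value
x[1:3] = np.zeros([2, 4], dtype='float32')  # slice index (step 1), numpy value
x[2] = paddle.full([1, 4], 2.0)             # paddle.Tensor value

exe = paddle.static.Executor(paddle.CPUPlace())
out = exe.run(paddle.static.default_main_program(), fetch_list=[x])
print(out[0])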

6 files changed: +928 -2 lines changed
Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/set_value_op.h"

#include <string>

namespace paddle {
namespace operators {

class SetValue : public framework::OperatorWithKernel {
 public:
  SetValue(const std::string &type, const framework::VariableNameMap &inputs,
           const framework::VariableNameMap &outputs,
           const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "SetValue");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SetValue");
    auto in_dims = ctx->GetInputDim("Input");
    PADDLE_ENFORCE_LT(
        in_dims.size(), 7,
        platform::errors::InvalidArgument(
            "The rank of input should be less than 7, but received %d.",
            in_dims.size()));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
        ctx.GetPlace());
  }
};

class SetValueMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input", "(Tensor) Input tensor of set_value operator.");
    AddInput("ValueTensor", "(Tensor) Value tensor of set_value operator.")
        .AsDispensable();
    AddOutput("Out",
              "(Tensor) Output tensor of set_value operator. The output is the "
              "same Tensor as input");

    AddAttr<int>("dtype", "data type of input.")
        .InEnum(
            {framework::proto::VarType::BOOL, framework::proto::VarType::INT32,
             framework::proto::VarType::INT64, framework::proto::VarType::FP32,
             framework::proto::VarType::FP64})
        .SetDefault(framework::proto::VarType::FP32);
    AddAttr<std::vector<int64_t>>(
        "axes", "(list<int64_t>) Axes that `starts` and `ends` apply to.");
    AddAttr<std::vector<int64_t>>(
        "starts",
        "(list<int64_t>) Starting indices of corresponding axis in `axes`");
    AddAttr<std::vector<int64_t>>(
        "ends",
        "(list<int64_t>) Ending indices of corresponding axis in `axes`.");

    AddAttr<std::vector<int>>("bool_values", "store the bool values")
        .SetDefault({});
    AddAttr<std::vector<float>>("fp32_values", "store the float32 values")
        .SetDefault({});
    AddAttr<std::vector<int>>("int32_values", "store the int32 values")
        .SetDefault({});
    AddAttr<std::vector<int64_t>>("int64_values", "store the int64 values")
        .SetDefault({});

    AddAttr<std::vector<int64_t>>("shape", "(vector<int64_t>) Shape of values.")
        .SetDefault({});
    AddComment(R"DOC(SetValue operator.
Assignment to a Tensor in static mode.
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(
    set_value, ops::SetValue, ops::SetValueMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OP_CPU_KERNEL(
    set_value, ops::SetValueKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SetValueKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::SetValueKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SetValueKernel<paddle::platform::CPUDeviceContext, double>,
    ops::SetValueKernel<paddle::platform::CPUDeviceContext, bool>);
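
For illustration, a minimal sketch of driving this operator directly through the static-graph block API, using the inputs, outputs and attributes declared in SetValueMaker above (the append_op wiring shown here is an assumption for demonstration, not the commit's own Python-side setitem code):

import paddle
from paddle.fluid import core

paddle.enable_static()
main_prog = paddle.static.default_main_program()

with paddle.static.program_guard(main_prog):
    x = paddle.full(shape=[3, 4], fill_value=1.0, dtype='float32')
    block = main_prog.current_block()
    # Write 5.0 into x[1:2, :] by appending a raw set_value op.
    block.append_op(
        type='set_value',
        inputs={'Input': x},
        outputs={'Out': x},  # in-place: Input and Out are the same Variable
        attrs={
            'dtype': core.VarDesc.VarType.FP32,
            'axes': [0],
            'starts': [1],
            'ends': [2],
            'shape': [1, 4],            # shape of the value
            'fp32_values': [5.0] * 4,   # flattened value data
        })

exe = paddle.static.Executor(paddle.CPUPlace())
print(exe.run(main_prog, fetch_list=[x])[0])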
Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/set_value_op.h"

namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(
    set_value, ops::SetValueKernel<paddle::platform::CUDADeviceContext, int>,
    ops::SetValueKernel<paddle::platform::CUDADeviceContext, int64_t>,
    ops::SetValueKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SetValueKernel<paddle::platform::CUDADeviceContext, double>,
    ops::SetValueKernel<paddle::platform::CUDADeviceContext, bool>);
Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/assign_value_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

inline std::string GetValueName(framework::proto::VarType::Type data_type) {
  std::string value_name;
  switch (data_type) {
    case framework::proto::VarType::INT32:
      value_name = "int32_values";
      break;
    case framework::proto::VarType::INT64:
      value_name = "int64_values";
      break;
    case framework::proto::VarType::FP32:
      value_name = "fp32_values";
      break;
    case framework::proto::VarType::BOOL:
      value_name = "bool_values";
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported data type (code %d) for SetValue operator; only "
          "bool, int32, float32 and int64 are supported.",
          data_type));
  }
  return value_name;
}

inline framework::DDim GetSliceDims(const framework::DDim in_dims,
                                    const std::vector<int64_t> axes,
                                    const std::vector<int64_t> starts,
                                    const std::vector<int64_t> ends) {
  framework::DDim slice_dims(in_dims);

  for (size_t i = 0; i < axes.size(); ++i) {
    int64_t axis = axes[i];
    int64_t dim_value = in_dims[axis];

    int64_t start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
    int64_t end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
    start = std::max(start, static_cast<int64_t>(0));
    end = std::min(end, dim_value);

    PADDLE_ENFORCE_GT(end, start,
                      platform::errors::InvalidArgument(
                          "end should be greater than start, but "
                          "received end = %d, start = %d",
                          end, start));
    slice_dims[axis] = end - start;
  }
  return slice_dims;
}

template <typename DeviceContext, typename T>
class SetValueKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    const int rank = ctx.Output<framework::LoDTensor>("Out")->dims().size();

    // TODO(liym27): Find a more elegant way to do this. C++ requires the
    // template integer to be a compile-time constant, but we had better have
    // an alternative way of writing this in the future.
    switch (rank) {
      case 1:
        SetValueCompute<1>(ctx);
        break;
      case 2:
        SetValueCompute<2>(ctx);
        break;
      case 3:
        SetValueCompute<3>(ctx);
        break;
      case 4:
        SetValueCompute<4>(ctx);
        break;
      case 5:
        SetValueCompute<5>(ctx);
        break;
      case 6:
        SetValueCompute<6>(ctx);
        break;
    }
  }

 private:
  template <size_t D>
  void SetValueCompute(const framework::ExecutionContext& ctx) const {
    auto* in = ctx.Input<framework::LoDTensor>("Input");
    auto* out = ctx.Output<framework::LoDTensor>("Out");

    auto dtype =
        static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
    auto axes = ctx.Attr<std::vector<int64_t>>("axes");
    auto starts = ctx.Attr<std::vector<int64_t>>("starts");
    auto ends = ctx.Attr<std::vector<int64_t>>("ends");
    auto shape = ctx.Attr<std::vector<int64_t>>("shape");
    auto* value_tensor = ctx.Input<framework::LoDTensor>("ValueTensor");

    auto in_dims = in->dims();
    auto value_dims = framework::make_ddim(shape);
    auto slice_dims = GetSliceDims(in_dims, axes, starts, ends);

    auto place = ctx.GetPlace();
    auto& eigen_place =
        *ctx.template device_context<DeviceContext>().eigen_device();

    // Copy data from the input here to avoid data loss at the PE and Graph
    // level.
    // TODO(liym27): Speed this up in a future version.
    // - Q: Why not call ShareDataWith to speed it up?
    // - A: Because ShareDataWith is not supported between an OP's input and
    //   output.
    //   https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP
    // - Q: Why not delete Input, since the input and the output are the same
    //   Tensor at the program level?
    // - A: If Input were deleted, the graph would become more complex; for
    //   example, two ops would point to the output in the graph:
    //   op1 -> output <- set_value. In that case, we would have to find a way
    //   to ensure set_value runs in the order we want.
    TensorCopy(*in, place, out);

    Tensor slice_t(dtype), pad_t(dtype);
    slice_t.mutable_data<T>(slice_dims, place);
    pad_t.mutable_data<T>(in_dims, place);

    auto pad_e = framework::EigenTensor<T, D>::From(pad_t, in_dims);
    auto out_e = framework::EigenTensor<T, D>::From(*out);
    auto slice_e = framework::EigenTensor<T, D>::From(slice_t, slice_dims);

    // Step 1: Set the value of out at `_index` to zero
    // - Step 1.1 Get a slice tensor from out
    Eigen::array<int64_t, D> offsets, extents;
    Eigen::array<std::pair<int64_t, int64_t>, D> paddings;

    for (size_t i = 0; i < D; ++i) {
      offsets[i] = 0;
      extents[i] = slice_dims[i];
    }
    int64_t start;
    for (size_t i = 0; i < axes.size(); ++i) {
      start = starts[i] < 0 ? (starts[i] + in_dims[axes[i]]) : starts[i];
      start = std::max(start, static_cast<int64_t>(0));
      offsets[axes[i]] = start;
    }
    for (size_t i = 0; i < paddings.size(); ++i) {
      paddings[i].first = offsets[i];
      paddings[i].second = (in_dims[i] - slice_dims[i]) - offsets[i];
    }

    slice_e.device(eigen_place) = out_e.slice(offsets, extents);

    // - Step 1.2 Get a padded tensor by padding the slice tensor with 0
    pad_e.device(eigen_place) = slice_e.pad(paddings, T(0));

    // - Step 1.3 Set 0 at `_index` of the out tensor
    out_e.device(eigen_place) = out_e - pad_e;

    // Step 2: Build a tensor with the same shape as the out tensor. Its data
    // at `_index` equals value_tensor, and its data outside `_index` is zero.

    // - Step 2.1 Set the data of the slice tensor to 0
    slice_e.device(eigen_place) = slice_e.constant(T(0));

    // - Step 2.2 Set the slice tensor with value
    if (value_tensor != nullptr) {
      // ElementwiseComputeEx can do broadcasting
      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
          ctx, &slice_t, value_tensor, -1, SubFunctor<T>(), &slice_t);
    } else {
      Tensor value_t(dtype);
      value_t.mutable_data<T>(value_dims, place);
      auto value_name = GetValueName(dtype);
      CopyVecotorToTensor<T>(value_name.c_str(), &value_t, ctx);
      value_t.Resize(value_dims);
      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
          ctx, &slice_t, &value_t, -1, SubFunctor<T>(), &slice_t);
    }

    // - Step 2.3 Pad the slice tensor with 0
    pad_e.device(eigen_place) = slice_e.pad(paddings, T(0));

    // Step 3: Set the out tensor with value_tensor
    out_e.device(eigen_place) = out_e - pad_e;
  }
};

}  // namespace operators
}  // namespace paddle
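
The Step 1-3 comments in SetValueCompute above describe a slice/pad/subtract scheme. A rough NumPy analogue of those steps (illustrative only; the helper and its arguments are made up for this sketch):

import numpy as np

def set_value_like(out, value, axes, starts, ends):
    # Build the index of the target region from axes/starts/ends.
    idx = [slice(None)] * out.ndim
    for axis, start, end in zip(axes, starts, ends):
        idx[axis] = slice(start, end)
    idx = tuple(idx)

    # Step 1: zero the target region (out = out - pad(slice(out))).
    pad = np.zeros_like(out)
    pad[idx] = out[idx]
    out = out - pad

    # Step 2: build a padded tensor holding -value in the target region
    # (the kernel computes slice(=0) - value via SubFunctor, then pads with 0).
    pad = np.zeros_like(out)
    pad[idx] = 0 - value

    # Step 3: subtract again, which writes +value into the target region.
    return out - pad

x = np.ones((3, 4), dtype=np.float32)
print(set_value_like(x, np.full((1, 4), 5.0, dtype=np.float32), [0], [1], [2]))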
