//   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/assign_value_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/enforce.h"

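// Kernel of the set_value operator: it assigns a value (taken from the
// optional `ValueTensor` input or from the `*_values` attributes) to the
// slice of `Input` selected by the `axes`, `starts` and `ends` attributes,
// and writes the result to `Out`, i.e. roughly
// `out = input; out[starts:ends] = value` along the given axes.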
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

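// Returns the name of the op attribute that carries the value data for the
// given data type, e.g. "fp32_values" for FP32.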
inline std::string GetValueName(framework::proto::VarType::Type data_type) {
  std::string value_name;
  switch (data_type) {
    case framework::proto::VarType::INT32:
      value_name = "int32_values";
      break;
    case framework::proto::VarType::INT64:
      value_name = "int64_values";
      break;
    case framework::proto::VarType::FP32:
      value_name = "fp32_values";
      break;
    case framework::proto::VarType::BOOL:
      value_name = "bool_values";
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported data type (code %d) for SetValue operator; only "
          "bool, int32, float32 and int64 are supported.",
          data_type));
  }
  return value_name;
}

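// Computes the shape of the slice of `in_dims` selected by `axes`, `starts`
// and `ends`. Negative starts/ends count from the end of the corresponding
// axis, and the resulting range is clipped to the axis bounds.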
inline framework::DDim GetSliceDims(const framework::DDim in_dims,
                                    const std::vector<int64_t> axes,
                                    const std::vector<int64_t> starts,
                                    const std::vector<int64_t> ends) {
  framework::DDim slice_dims(in_dims);

  for (size_t i = 0; i < axes.size(); ++i) {
    int64_t axis = axes[i];
    int64_t dim_value = in_dims[axis];

    int64_t start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
    int64_t end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
    start = std::max(start, static_cast<int64_t>(0));
    end = std::min(end, dim_value);

    PADDLE_ENFORCE_GT(end, start, platform::errors::InvalidArgument(
                                      "end should be greater than start, but "
                                      "received end = %d, start = %d",
                                      end, start));
    slice_dims[axis] = end - start;
  }
  return slice_dims;
}

template <typename DeviceContext, typename T>
class SetValueKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    const int rank = ctx.Output<framework::LoDTensor>("Out")->dims().size();

    // TODO(liym27): Find a more elegant way to do this. C++ requires the
    //  template integer to be a compile-time constant, so each rank is
    //  dispatched explicitly here; an alternative should be found in the
    //  future.
    switch (rank) {
      case 1:
        SetValueCompute<1>(ctx);
        break;
      case 2:
        SetValueCompute<2>(ctx);
        break;
      case 3:
        SetValueCompute<3>(ctx);
        break;
      case 4:
        SetValueCompute<4>(ctx);
        break;
      case 5:
        SetValueCompute<5>(ctx);
        break;
      case 6:
        SetValueCompute<6>(ctx);
        break;
      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
            "The rank of the input should be between 1 and 6, but received "
            "%d.",
            rank));
    }
  }

 private:
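  // Assigns the value to the slice of `Input` selected by axes/starts/ends
  // and writes the result to `Out`, in three Eigen steps:
  //   1. zero out the slice region of `Out`;
  //   2. build a tensor of the same shape as `Out` that holds (the negated)
  //      value inside the slice region and zeros elsewhere;
  //   3. subtract that tensor from `Out`, which leaves the value inside the
  //      slice region and the original data everywhere else.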
  template <size_t D>
  void SetValueCompute(const framework::ExecutionContext& ctx) const {
    auto* in = ctx.Input<framework::LoDTensor>("Input");
    auto* out = ctx.Output<framework::LoDTensor>("Out");

    auto dtype =
        static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
    auto axes = ctx.Attr<std::vector<int64_t>>("axes");
    auto starts = ctx.Attr<std::vector<int64_t>>("starts");
    auto ends = ctx.Attr<std::vector<int64_t>>("ends");
    auto shape = ctx.Attr<std::vector<int64_t>>("shape");
    auto* value_tensor = ctx.Input<framework::LoDTensor>("ValueTensor");

    auto in_dims = in->dims();
    auto value_dims = framework::make_ddim(shape);
    auto slice_dims = GetSliceDims(in_dims, axes, starts, ends);

    auto place = ctx.GetPlace();
    auto& eigen_place =
        *ctx.template device_context<DeviceContext>().eigen_device();

    // Copy the data from Input to Out here to avoid losing it at the PE and
    // Graph levels.
    // TODO(liym27): Speed this up in a future version.
    // - Q: Why not call ShareDataWith to speed it up?
    // - A: Because ShareDataWith between an OP's input and output is
    //   prohibited, see
    // https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP
    // - Q: Why not drop Input altogether, since the input and the output are
    //   the same Tensor at the program level?
    // - A: If Input were dropped, the graph would become more complex: two ops
    //   would point to the output in the graph, op1 -> output <- set_value,
    //   and we would then have to find a way to make set_value run in the
    //   order we want.
    TensorCopy(*in, place, out);

    Tensor slice_t(dtype), pad_t(dtype);
    slice_t.mutable_data<T>(slice_dims, place);
    pad_t.mutable_data<T>(in_dims, place);

    auto pad_e = framework::EigenTensor<T, D>::From(pad_t, in_dims);
    auto out_e = framework::EigenTensor<T, D>::From(*out);
    auto slice_e = framework::EigenTensor<T, D>::From(slice_t, slice_dims);

    // Step 1: Set the value of out at `_index` to zero
    // - Step 1.1 Get a slice tensor from out
    Eigen::array<int64_t, D> offsets, extents;
    Eigen::array<std::pair<int64_t, int64_t>, D> paddings;

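    // `offsets` and `extents` locate the slice region inside out for Eigen's
    // slice(), and `paddings` records how many zeros to add before and after
    // the slice on each axis so that pad() restores the full shape of out.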
    for (size_t i = 0; i < D; ++i) {
      offsets[i] = 0;
      extents[i] = slice_dims[i];
    }
    int64_t start;
    for (size_t i = 0; i < axes.size(); ++i) {
      start = starts[i] < 0 ? (starts[i] + in_dims[axes[i]]) : starts[i];
      start = std::max(start, static_cast<int64_t>(0));
      offsets[axes[i]] = start;
    }
    for (size_t i = 0; i < paddings.size(); ++i) {
      paddings[i].first = offsets[i];
      paddings[i].second = (in_dims[i] - slice_dims[i]) - offsets[i];
    }

    slice_e.device(eigen_place) = out_e.slice(offsets, extents);

    // - Step 1.2 Get a padded tensor by padding 0 around the slice tensor
    pad_e.device(eigen_place) = slice_e.pad(paddings, T(0));

    // - Step 1.3 Set out to 0 at `_index`
    out_e.device(eigen_place) = out_e - pad_e;

    // Step 2: Build a tensor with the same shape as out whose data at
    // `_index` carries the value (stored negated, see Step 2.2) and is zero
    // everywhere else.

    // - Step 2.1 Set the data of the slice tensor to 0
    slice_e.device(eigen_place) = slice_e.constant(T(0));

    // - Step 2.2 Fill the slice tensor with the value. SubFunctor computes
    //   0 - value, so slice_t actually holds -value; the subtraction in
    //   Step 3 turns this into +value at `_index`.
    if (value_tensor != nullptr) {
      // ElementwiseComputeEx can do broadcasting
      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
          ctx, &slice_t, value_tensor, -1, SubFunctor<T>(), &slice_t);
    } else {
      Tensor value_t(dtype);
      value_t.mutable_data<T>(value_dims, place);
      auto value_name = GetValueName(dtype);
      CopyVecotorToTensor<T>(value_name.c_str(), &value_t, ctx);
      value_t.Resize(value_dims);
      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
          ctx, &slice_t, &value_t, -1, SubFunctor<T>(), &slice_t);
    }

    // - Step 2.3 Pad the slice tensor with 0
    pad_e.device(eigen_place) = slice_e.pad(paddings, T(0));

    // Step 3: Write the value into out at `_index`: out is 0 and pad_e is
    // -value there, so out - pad_e leaves the value at `_index` and keeps the
    // original data elsewhere.
    out_e.device(eigen_place) = out_e - pad_e;
  }
};

}  // namespace operators
}  // namespace paddle