14
14
15
15
#include " paddle/phi/kernels/put_along_axis_kernel.h"
16
16
17
- #include " paddle/fluid/framework/convert_utils.h"
18
17
#include " paddle/fluid/operators/gather_scatter_kernel.h"
19
18
#include " paddle/phi/backends/cpu/cpu_context.h"
19
+ #include " paddle/phi/common/data_type.h"
20
20
#include " paddle/phi/common/place.h"
21
21
#include " paddle/phi/core/kernel_registry.h"
22
22
#include " paddle/phi/core/tensor_utils.h"
@@ -37,29 +37,28 @@ void PutAlongAxisKernel(const Context& dev_ctx,
37
37
errors::PreconditionNotMet (" PutAlongAxisOpKernel only runs on CPU." ));
38
38
39
39
phi::Copy (dev_ctx, x, dev_ctx.GetPlace (), false , out);
40
- const auto & index_type =
41
- paddle::framework::TransToProtoVarType (index .dtype ());
40
+ const auto & index_type = index .dtype ();
42
41
if (reduce == " add" ) {
43
- if (index_type == paddle::framework::proto::VarType ::INT32) {
42
+ if (index_type == DataType ::INT32) {
44
43
paddle::operators::cpu_scatter_add_kernel<T, int32_t >(
45
44
*out, axis, index , value, dev_ctx);
46
- } else if (index_type == paddle::framework::proto::VarType ::INT64) {
45
+ } else if (index_type == DataType ::INT64) {
47
46
paddle::operators::cpu_scatter_add_kernel<T, int64_t >(
48
47
*out, axis, index , value, dev_ctx);
49
48
}
50
49
} else if (reduce == " multiply" || reduce == " mul" ) {
51
- if (index_type == paddle::framework::proto::VarType ::INT32) {
50
+ if (index_type == DataType ::INT32) {
52
51
paddle::operators::cpu_scatter_mul_kernel<T, int32_t >(
53
52
*out, axis, index , value, dev_ctx);
54
- } else if (index_type == paddle::framework::proto::VarType ::INT64) {
53
+ } else if (index_type == DataType ::INT64) {
55
54
paddle::operators::cpu_scatter_mul_kernel<T, int64_t >(
56
55
*out, axis, index , value, dev_ctx);
57
56
}
58
57
} else if (reduce == " assign" ) {
59
- if (index_type == paddle::framework::proto::VarType ::INT32) {
58
+ if (index_type == DataType ::INT32) {
60
59
paddle::operators::cpu_scatter_assign_kernel<T, int32_t >(
61
60
*out, axis, index , value, dev_ctx);
62
- } else if (index_type == paddle::framework::proto::VarType ::INT64) {
61
+ } else if (index_type == DataType ::INT64) {
63
62
paddle::operators::cpu_scatter_assign_kernel<T, int64_t >(
64
63
*out, axis, index , value, dev_ctx);
65
64
}
0 commit comments