Skip to content

Commit 2f34fc7

Browse files
authored
rm "paddle/fluid/framework/convert_utils.h" in phi (#48001)
1 parent f365020 commit 2f34fc7

20 files changed

+138
-161
lines changed

paddle/fluid/framework/convert_utils.cc

-35
Original file line numberDiff line numberDiff line change
@@ -162,40 +162,5 @@ DataType String2DataType(const std::string& str) {
162162
}
163163
}
164164

165-
std::string DataType2String(DataType dtype) {
166-
switch (dtype) {
167-
case DataType::BOOL:
168-
return "bool";
169-
case DataType::INT8:
170-
return "int8";
171-
case DataType::UINT8:
172-
return "uint8";
173-
case DataType::INT16:
174-
return "int16";
175-
case DataType::INT32:
176-
return "int32";
177-
case DataType::INT64:
178-
return "int64";
179-
case DataType::FLOAT16:
180-
return "float16";
181-
case DataType::FLOAT32:
182-
return "float32";
183-
case DataType::FLOAT64:
184-
return "float64";
185-
case DataType::COMPLEX64:
186-
return "complex64";
187-
case DataType::COMPLEX128:
188-
return "complex128";
189-
case DataType::PSTRING:
190-
return "pstring";
191-
case DataType::BFLOAT16:
192-
return "bfloat16";
193-
default:
194-
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
195-
"Unknow phi::DataType, the int value = %d.",
196-
static_cast<int>(dtype)));
197-
return "";
198-
}
199-
}
200165
} // namespace framework
201166
} // namespace paddle

paddle/fluid/framework/convert_utils.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License. */
2020
#include "paddle/phi/core/tensor_meta.h"
2121

2222
#include "paddle/fluid/framework/data_type.h"
23+
#include "paddle/phi/core/utils/data_type.h"
2324

2425
// TODO(chenweihang): this file may need to be removed
2526

@@ -37,7 +38,8 @@ paddle::framework::proto::VarType::Type TransToProtoVarType(
3738

3839
size_t DataTypeSize(DataType dtype);
3940
DataType String2DataType(const std::string& str);
40-
std::string DataType2String(DataType dtype);
41+
42+
using phi::DataType2String;
4143

4244
} // namespace framework
4345
} // namespace paddle

paddle/fluid/operators/prune_gate_by_capacity_op.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ class PruneGateByCapacityCUDAKernel : public framework::OpKernel<T> {
121121
framework::TensorCopy(*expert_count, context.GetPlace(), &expert_count_out);
122122
PruneGateByCapacityFunctor<DeviceContext, T> functor(
123123
context, gate_idx, &expert_count_out, new_gate_idx_data);
124-
VisitDataType(expert_count->type(), functor);
124+
::paddle::operators::VisitDataType(expert_count->type(), functor);
125125
}
126126
};
127127

paddle/phi/core/utils/data_type.h

+45
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ static std::map<int, phi::DataType> var_type_map{{1, phi::DataType::INT16},
4141
{6, phi::DataType::FLOAT64},
4242
{20, phi::DataType::UINT8}};
4343

44+
static std::map<phi::DataType, int> map_to_var_type{{phi::DataType::INT16, 1},
45+
{phi::DataType::INT32, 2},
46+
{phi::DataType::INT64, 3},
47+
{phi::DataType::FLOAT16, 4},
48+
{phi::DataType::FLOAT32, 5},
49+
{phi::DataType::FLOAT64, 6},
50+
{phi::DataType::UINT8, 20}};
51+
4452
#define _PhiForEachDataTypeHelper_(callback, cpp_type, data_type) \
4553
callback(cpp_type, data_type);
4654

@@ -129,4 +137,41 @@ inline DataType ToRealType(const DataType& type) {
129137
type));
130138
}
131139
}
140+
141+
inline std::string DataType2String(DataType dtype) {
142+
switch (dtype) {
143+
case DataType::BOOL:
144+
return "bool";
145+
case DataType::INT8:
146+
return "int8";
147+
case DataType::UINT8:
148+
return "uint8";
149+
case DataType::INT16:
150+
return "int16";
151+
case DataType::INT32:
152+
return "int32";
153+
case DataType::INT64:
154+
return "int64";
155+
case DataType::FLOAT16:
156+
return "float16";
157+
case DataType::FLOAT32:
158+
return "float32";
159+
case DataType::FLOAT64:
160+
return "float64";
161+
case DataType::COMPLEX64:
162+
return "complex64";
163+
case DataType::COMPLEX128:
164+
return "complex128";
165+
case DataType::PSTRING:
166+
return "pstring";
167+
case DataType::BFLOAT16:
168+
return "bfloat16";
169+
default:
170+
PADDLE_THROW(
171+
errors::InvalidArgument("Unknow phi::DataType, the int value = %d.",
172+
static_cast<int>(dtype)));
173+
return "";
174+
}
175+
}
176+
132177
} // namespace phi

paddle/phi/infermeta/unary.cc

+5-9
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ limitations under the License. */
1717
#include <algorithm>
1818
#include <set>
1919

20-
#include "paddle/fluid/framework/convert_utils.h"
2120
#include "paddle/phi/common/data_type.h"
2221
#include "paddle/phi/common/type_traits.h"
2322
#include "paddle/phi/core/enforce.h"
2423
#include "paddle/phi/core/infermeta_utils.h"
24+
#include "paddle/phi/core/utils/data_type.h"
2525
#include "paddle/phi/kernels/funcs/parse_qr_mode.h"
2626
#include "paddle/phi/kernels/funcs/pooling.h"
2727
#include "paddle/phi/kernels/funcs/slice_utils.h"
@@ -133,12 +133,9 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
133133
phi::errors::InvalidArgument(
134134
"The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
135135
"received [%s]",
136-
paddle::framework::DataTypeToString(
137-
paddle::framework::proto::VarType::INT32),
138-
paddle::framework::DataTypeToString(
139-
paddle::framework::proto::VarType::INT64),
140-
paddle::framework::DataTypeToString(
141-
static_cast<paddle::framework::proto::VarType::Type>(dtype))));
136+
phi::DataType2String(DataType::INT32),
137+
phi::DataType2String(DataType::INT64),
138+
phi::DataType2String(var_type_map[dtype])));
142139

143140
if (!config.is_runtime && axis.FromTensor()) {
144141
std::vector<int64_t> vec;
@@ -180,11 +177,10 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
180177
auto x_rank = x_dims.size();
181178
if (int_axis < 0) int_axis += x_rank;
182179
if (config.is_runtime) {
183-
if (dtype == paddle::framework::proto::VarType::INT32) {
180+
if (dtype == map_to_var_type[DataType::INT32]) {
184181
int64_t all_element_num = 0;
185182
if (flatten) {
186183
all_element_num = phi::product(x_dims);
187-
188184
} else {
189185
all_element_num = x_dims[int_axis];
190186
}

paddle/phi/kernels/cpu/index_sample_grad_kernel.cc

+9-13
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414

1515
#include "paddle/phi/kernels/index_sample_grad_kernel.h"
1616

17-
#include "paddle/fluid/framework/convert_utils.h"
1817
#include "paddle/fluid/framework/tensor_util.h"
1918
#include "paddle/phi/backends/cpu/cpu_context.h"
2019
#include "paddle/phi/common/data_type.h"
2120
#include "paddle/phi/core/kernel_registry.h"
21+
#include "paddle/phi/core/utils/data_type.h"
2222
namespace phi {
2323
template <typename T, typename Context, typename IndexT = int>
2424
void IndexSampleGradInner(const Context& context,
@@ -76,18 +76,14 @@ void IndexSampleGradKernel(const Context& ctx,
7676
auto index_type = index.dtype();
7777
bool index_type_match =
7878
index_type == DataType::INT32 || index_type == DataType::INT64;
79-
PADDLE_ENFORCE_EQ(
80-
index_type_match,
81-
true,
82-
errors::InvalidArgument(
83-
"Input(Index) holds the wrong type, it holds %s, but "
84-
"desires to be %s or %s",
85-
paddle::framework::DataTypeToString(
86-
paddle::framework::TransToProtoVarType(index_type)),
87-
paddle::framework::DataTypeToString(
88-
paddle::framework::TransToProtoVarType(DataType::INT32)),
89-
paddle::framework::DataTypeToString(
90-
paddle::framework::TransToProtoVarType((DataType::INT64)))));
79+
PADDLE_ENFORCE_EQ(index_type_match,
80+
true,
81+
errors::InvalidArgument(
82+
"Input(Index) holds the wrong type, it holds %s, but "
83+
"desires to be %s or %s",
84+
phi::DataType2String(index_type),
85+
phi::DataType2String(DataType::INT32),
86+
phi::DataType2String(DataType::INT64)));
9187
if (index_type == DataType::INT32) {
9288
IndexSampleGradInner<T, Context, int>(ctx, out_grad, index, x_grad);
9389
} else if (index_type == DataType::INT64) {

paddle/phi/kernels/cpu/index_sample_kernel.cc

+9-13
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
#include <utility>
2222
#include <vector>
2323

24-
#include "paddle/fluid/framework/convert_utils.h"
2524
#include "paddle/fluid/framework/tensor_util.h"
2625
#include "paddle/phi/backends/cpu/cpu_context.h"
2726
#include "paddle/phi/common/data_type.h"
2827
#include "paddle/phi/core/kernel_registry.h"
28+
#include "paddle/phi/core/utils/data_type.h"
2929
namespace phi {
3030
template <typename T, typename Context, typename IndexT = int>
3131
void IndexSampleInner(const Context &context,
@@ -89,18 +89,14 @@ void IndexSampleKernel(const Context &ctx,
8989
auto index_type = index.dtype();
9090
bool index_type_match =
9191
index_type == DataType::INT32 || index_type == DataType::INT64;
92-
PADDLE_ENFORCE_EQ(
93-
index_type_match,
94-
true,
95-
errors::InvalidArgument(
96-
"Input(Index) holds the wrong type, it holds %s, but "
97-
"desires to be %s or %s",
98-
paddle::framework::DataTypeToString(
99-
paddle::framework::TransToProtoVarType(index_type)),
100-
paddle::framework::DataTypeToString(
101-
paddle::framework::TransToProtoVarType(DataType::INT32)),
102-
paddle::framework::DataTypeToString(
103-
paddle::framework::TransToProtoVarType((DataType::INT64)))));
92+
PADDLE_ENFORCE_EQ(index_type_match,
93+
true,
94+
errors::InvalidArgument(
95+
"Input(Index) holds the wrong type, it holds %s, but "
96+
"desires to be %s or %s",
97+
phi::DataType2String(index_type),
98+
phi::DataType2String(DataType::INT32),
99+
phi::DataType2String(DataType::INT64)));
104100
if (index_type == DataType::INT32) {
105101
IndexSampleInner<T, Context, int>(ctx, x, index, out);
106102
} else if (index_type == DataType::INT64) {

paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc

+5-6
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
#include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
1616

17-
#include "paddle/fluid/framework/convert_utils.h"
1817
#include "paddle/fluid/operators/gather_scatter_kernel.h"
1918
#include "paddle/phi/backends/cpu/cpu_context.h"
19+
#include "paddle/phi/common/data_type.h"
2020
#include "paddle/phi/common/place.h"
2121
#include "paddle/phi/core/kernel_registry.h"
2222
#include "paddle/phi/core/tensor_utils.h"
@@ -37,11 +37,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
3737
true,
3838
errors::PreconditionNotMet("PutAlongAxisGradOpKernel only runs on CPU."));
3939

40-
const auto& index_type =
41-
paddle::framework::TransToProtoVarType(index.dtype());
40+
const auto& index_type = index.dtype();
4241
if (x_grad) {
4342
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
44-
if (index_type == paddle::framework::proto::VarType::INT32) {
43+
if (index_type == DataType::INT32) {
4544
paddle::operators::cpu_scatter_input_grad_kernel<T, int32_t>(
4645
// Here passing an unused argument out_grad, because it's
4746
// convenient to instantiate a bunch of template function with the
@@ -60,10 +59,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
6059
if (value_grad) {
6160
value_grad->Resize(index.dims());
6261
value_grad->mutable_data<T>(dev_ctx.GetPlace());
63-
if (index_type == paddle::framework::proto::VarType::INT32) {
62+
if (index_type == DataType::INT32) {
6463
paddle::operators::cpu_gather_kernel<T, int32_t>(
6564
out_grad, axis, index, *value_grad, dev_ctx);
66-
} else if (index_type == paddle::framework::proto::VarType::INT64) {
65+
} else if (index_type == DataType::INT64) {
6766
paddle::operators::cpu_gather_kernel<T, int64_t>(
6867
out_grad, axis, index, *value_grad, dev_ctx);
6968
}

paddle/phi/kernels/cpu/put_along_axis_kernel.cc

+8-9
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
#include "paddle/phi/kernels/put_along_axis_kernel.h"
1616

17-
#include "paddle/fluid/framework/convert_utils.h"
1817
#include "paddle/fluid/operators/gather_scatter_kernel.h"
1918
#include "paddle/phi/backends/cpu/cpu_context.h"
19+
#include "paddle/phi/common/data_type.h"
2020
#include "paddle/phi/common/place.h"
2121
#include "paddle/phi/core/kernel_registry.h"
2222
#include "paddle/phi/core/tensor_utils.h"
@@ -37,29 +37,28 @@ void PutAlongAxisKernel(const Context& dev_ctx,
3737
errors::PreconditionNotMet("PutAlongAxisOpKernel only runs on CPU."));
3838

3939
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
40-
const auto& index_type =
41-
paddle::framework::TransToProtoVarType(index.dtype());
40+
const auto& index_type = index.dtype();
4241
if (reduce == "add") {
43-
if (index_type == paddle::framework::proto::VarType::INT32) {
42+
if (index_type == DataType::INT32) {
4443
paddle::operators::cpu_scatter_add_kernel<T, int32_t>(
4544
*out, axis, index, value, dev_ctx);
46-
} else if (index_type == paddle::framework::proto::VarType::INT64) {
45+
} else if (index_type == DataType::INT64) {
4746
paddle::operators::cpu_scatter_add_kernel<T, int64_t>(
4847
*out, axis, index, value, dev_ctx);
4948
}
5049
} else if (reduce == "multiply" || reduce == "mul") {
51-
if (index_type == paddle::framework::proto::VarType::INT32) {
50+
if (index_type == DataType::INT32) {
5251
paddle::operators::cpu_scatter_mul_kernel<T, int32_t>(
5352
*out, axis, index, value, dev_ctx);
54-
} else if (index_type == paddle::framework::proto::VarType::INT64) {
53+
} else if (index_type == DataType::INT64) {
5554
paddle::operators::cpu_scatter_mul_kernel<T, int64_t>(
5655
*out, axis, index, value, dev_ctx);
5756
}
5857
} else if (reduce == "assign") {
59-
if (index_type == paddle::framework::proto::VarType::INT32) {
58+
if (index_type == DataType::INT32) {
6059
paddle::operators::cpu_scatter_assign_kernel<T, int32_t>(
6160
*out, axis, index, value, dev_ctx);
62-
} else if (index_type == paddle::framework::proto::VarType::INT64) {
61+
} else if (index_type == DataType::INT64) {
6362
paddle::operators::cpu_scatter_assign_kernel<T, int64_t>(
6463
*out, axis, index, value, dev_ctx);
6564
}

paddle/phi/kernels/cpu/take_along_axis_kernel.cc

+4-5
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
#include "paddle/phi/kernels/take_along_axis_kernel.h"
1616

17-
#include "paddle/fluid/framework/convert_utils.h"
1817
#include "paddle/fluid/operators/gather_scatter_kernel.h"
1918
#include "paddle/phi/backends/cpu/cpu_context.h"
19+
#include "paddle/phi/common/data_type.h"
2020
#include "paddle/phi/common/place.h"
2121
#include "paddle/phi/core/kernel_registry.h"
2222

@@ -36,12 +36,11 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
3636
out->Resize(index.dims());
3737
dev_ctx.template Alloc<T>(out);
3838

39-
const auto& index_type =
40-
paddle::framework::TransToProtoVarType(index.dtype());
41-
if (index_type == paddle::framework::proto::VarType::INT32) {
39+
const auto& index_type = index.dtype();
40+
if (index_type == DataType::INT32) {
4241
paddle::operators::cpu_gather_kernel<T, int32_t>(
4342
x, axis, index, *out, dev_ctx);
44-
} else if (index_type == paddle::framework::proto::VarType::INT64) {
43+
} else if (index_type == DataType::INT64) {
4544
paddle::operators::cpu_gather_kernel<T, int64_t>(
4645
x, axis, index, *out, dev_ctx);
4746
}

paddle/phi/kernels/funcs/math_function.h

-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ limitations under the License. */
1717
#include <memory>
1818
#include <vector>
1919

20-
#include "paddle/fluid/framework/convert_utils.h"
2120
#include "paddle/fluid/framework/operator.h"
2221
#include "paddle/fluid/framework/tensor.h"
2322
#include "paddle/fluid/framework/tensor_util.h"

0 commit comments

Comments
 (0)