
Commit 7e7e940

[PTen] Migrate proto::VarType outside of Pten (PaddlePaddle#39411)

* #1 migrate dist-related type() -> dtype()
* move datatype function from pten -> fluid/framework
* change type() in imperative into convert(dtype())
* modify xx_tensor->type into xx_tensor->dtype
* change the set_type interface and the caller
* modify xx_tensor.type into xx_tensor.dtype
* fix mutable_data(place, dtype())
* change callers of mutable_data in pten and distributed
* change callers of mutable_data in fluid/framework
* change callers of mutable_data in the imperative directory
* mutable_data: inference
* update the calls of mutable_data
* transfer MakePenScalarArray MakePtenScalar ResetHolderWithType
* pass the compile; the next step is to remove VarType from Pten
* fix all and remove VarType from pten; succeeds on Linux, other platforms are next
* fix conflict with develop
* fix compile error
* fix reset conversion
* fix conflict
* fix compile problem
* fix typo
* fix << in tensor_utils.cc
* fix type -> dtype
* fix unittest
* fix tensor init constructor
* fix DataTypeSize for BFloat16
* fix code style
* fix npu compile error
* fix npu
* compile npu successfully
* fix conflict
* fix conflict

Co-authored-by: xiongkun <xiongkun03@baidu.com>
1 parent 9c2cee1 commit 7e7e940
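
In short: fluid code stops reading the legacy proto::VarType::Type off tensors (tensor->type()) and instead reads the pten DataType (tensor->dtype()), converting explicitly wherever the proto enum is still required. A minimal editor's sketch of the recurring before/after pattern, assuming the Paddle source tree at this commit (the conversion helpers are defined in paddle/fluid/framework/convert_utils.cc, added below); it is an illustration, not code from the commit:

// Editor's sketch (illustration only, not part of the commit).
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"

namespace example {

void InspectTensor(const paddle::framework::LoDTensor& tensor) {
  // Before: fluid code read the proto enum straight off the tensor.
  //   auto type = tensor.type();                                // removed
  //   auto len = tensor.numel() * framework::SizeOfType(type);  // removed
  // After: the tensor carries a pten DataType; convert on demand.
  auto proto_type = paddle::framework::TransToProtoVarType(tensor.dtype());
  auto data_len =
      tensor.numel() * paddle::framework::DataTypeSize(tensor.dtype());
  (void)proto_type;  // silence unused-variable warnings in this sketch
  (void)data_len;
}

}  // namespace example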

File tree

352 files changed: +2175 −1445 lines


paddle/fluid/distributed/fleet_executor/dist_model.cc

Lines changed: 1 addition & 1 deletion

@@ -562,7 +562,7 @@ bool DistModel::FetchResults(std::vector<DistModelTensor> *output_data,
     framework::FetchType &fetch_var =
         framework::GetFetchVariable(*scope, "fetch", idx);
     auto &fetch = BOOST_GET(framework::LoDTensor, fetch_var);
-    auto type = fetch.type();
+    auto type = framework::TransToProtoVarType(fetch.dtype());
     auto output = &(output_data->at(i));
     output->name = idx_to_fetches_[idx];
     bool rst = false;
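
The converted value keeps proto-enum dispatch working in the fetch path. A hedged sketch of that dispatch (hypothetical HandleFetch, not taken from the commit; the per-type copying in FetchResults lies outside this hunk):

// Illustration only: once the pten dtype is converted back to
// proto::VarType::Type, existing switch-based dispatch is unchanged.
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"

namespace example {

bool HandleFetch(const paddle::framework::LoDTensor& fetch) {
  namespace fw = paddle::framework;
  switch (fw::TransToProtoVarType(fetch.dtype())) {
    case fw::proto::VarType::FP32:
      return true;  // ... copy float32 results ...
    case fw::proto::VarType::INT64:
      return true;  // ... copy int64 results ...
    default:
      return false;  // unsupported fetch type
  }
}

}  // namespace example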

paddle/fluid/distributed/ps/service/brpc_utils.cc

Lines changed: 41 additions & 26 deletions

@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/distributed/ps/service/brpc_utils.h"
+
 #include <arpa/inet.h>
 #include <netdb.h>
+
+#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
@@ -98,25 +101,29 @@ void SerializeLodTensor(framework::Variable* var,
       }
     }
   }
-  var_msg->set_data_type(static_cast<VarMsg::Type>(tensor->type()));
+  var_msg->set_data_type(static_cast<VarMsg::Type>(
+      framework::TransToProtoVarType(tensor->dtype())));
   for (auto& dim : framework::vectorize(tensor->dims())) {
     var_msg->add_dims(dim);
   }
   // IO Buffer
   if (platform::is_cpu_place(tensor->place())) {
-    auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
+    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
     iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
     iobuf->append(reinterpret_cast<const char*>(tensor->data()), data_len);
   } else {
 #ifdef PADDLE_WITH_CUDA
-    char* temp_ptr = new char[tensor->numel() *
-                              framework::SizeOfType(tensor->type())];  // NOLINT
+    char* temp_ptr =
+        new char[tensor->numel() *
+                 framework::DataTypeSize(tensor->dtype())];  // NOLINT
     auto stream =
         reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
     memory::Copy(
         platform::CPUPlace(), temp_ptr, tensor->place(), tensor->data(),
-        tensor->numel() * framework::SizeOfType(tensor->type()), stream);
-    auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
+        tensor->numel() * framework::SizeOfType(
+                              framework::TransToProtoVarType(tensor->dtype())),
+        stream);
+    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
     iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
     iobuf->append(reinterpret_cast<const char*>(temp_ptr), data_len);
     delete[] temp_ptr;
@@ -139,25 +146,29 @@ void SerializeSelectedRows(framework::Variable* var,
   var_data->resize(rows->size() * sizeof(int64_t));
   char* data_ptr = const_cast<char*>(var_data->data());
   memcpy(data_ptr, &((*rows)[0]), rows->size() * sizeof(int64_t));
-  var_msg->set_data_type(static_cast<VarMsg::Type>(tensor->type()));
+  var_msg->set_data_type(static_cast<VarMsg::Type>(
+      framework::TransToProtoVarType(tensor->dtype())));
   for (auto& dim : framework::vectorize(tensor->dims())) {
     var_msg->add_dims(dim);
   }
   // IO Buffer
   if (platform::is_cpu_place(tensor->place())) {
-    auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
+    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
     iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
     iobuf->append(reinterpret_cast<const char*>(tensor->data()), data_len);
   } else {
 #ifdef PADDLE_WITH_CUDA
-    char* temp_ptr = new char[tensor->numel() *
-                              framework::SizeOfType(tensor->type())];  // NOLINT
+    char* temp_ptr =
+        new char[tensor->numel() *
+                 framework::DataTypeSize(tensor->dtype())];  // NOLINT
     auto stream =
         reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
     memory::Copy(
         platform::CPUPlace(), temp_ptr, tensor->place(), tensor->data(),
-        tensor->numel() * framework::SizeOfType(tensor->type()), stream);
-    auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
+        tensor->numel() * framework::SizeOfType(
+                              framework::TransToProtoVarType(tensor->dtype())),
+        stream);
+    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
     iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
     iobuf->append(reinterpret_cast<const char*>(temp_ptr), data_len);
     delete[] temp_ptr;
@@ -225,8 +236,9 @@ void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg,
   }
   tensor->set_lod(lod);
 
-  void* tensor_data =
-      tensor->mutable_data(place, VarMessageToVarType(msg.data_type()));
+  void* tensor_data = tensor->mutable_data(
+      place,
+      framework::TransToPtenDataType(VarMessageToVarType(msg.data_type())));
 
   // IO Buffer
   if (platform::is_cpu_place(place)) {
@@ -236,15 +248,16 @@ void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg,
   } else if (platform::is_gpu_place(place)) {
 #ifdef PADDLE_WITH_CUDA
     unsigned long data_len;  // NOLINT
-    char* temp_ptr = new char[tensor->numel() *
-                              framework::SizeOfType(tensor->type())];  // NOLINT
-    io_buffer_itr.copy_and_forward((void*)(&data_len), 8);  // NOLINT
-    io_buffer_itr.copy_and_forward((void*)temp_ptr, data_len);  // NOLINT
+    char* temp_ptr =
+        new char[tensor->numel() *
+                 framework::DataTypeSize(tensor->dtype())];  // NOLINT
+    io_buffer_itr.copy_and_forward((void*)(&data_len), 8);      // NOLINT
+    io_buffer_itr.copy_and_forward((void*)temp_ptr, data_len);  // NOLINT
     auto stream =
         reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
     memory::Copy(
         place, tensor_data, platform::CPUPlace(), (void*)temp_ptr,  // NOLINT
-        tensor->numel() * framework::SizeOfType(tensor->type()), stream);
+        tensor->numel() * framework::DataTypeSize(tensor->dtype()), stream);
     delete[] temp_ptr;
 #endif
   }
@@ -266,24 +279,26 @@ void DeserializeSelectedRows(
     vec_dim.push_back(x);
   }
   tensor->Resize(framework::make_ddim(vec_dim));
-  void* tensor_data =
-      tensor->mutable_data(place, VarMessageToVarType(msg.data_type()));
+  void* tensor_data = tensor->mutable_data(
+      place,
+      framework::TransToPtenDataType(VarMessageToVarType(msg.data_type())));
   // IO Buffer
   if (platform::is_cpu_place(place)) {
     unsigned long data_len;  // NOLINT
     io_buffer_itr.copy_and_forward((void*)(&data_len), 8);  // NOLINT
     io_buffer_itr.copy_and_forward(tensor_data, data_len);
   } else if (platform::is_gpu_place(place)) {
 #ifdef PADDLE_WITH_CUDA
-    char* temp_ptr = new char[tensor->numel() *
-                              framework::SizeOfType(tensor->type())];  // NOLINT
-    unsigned long data_len;  // NOLINT
-    io_buffer_itr.copy_and_forward((void*)(&data_len), 8);  // NOLINT
+    char* temp_ptr =
+        new char[tensor->numel() *
+                 framework::DataTypeSize(tensor->dtype())];  // NOLINT
+    unsigned long data_len;                                  // NOLINT
+    io_buffer_itr.copy_and_forward((void*)(&data_len), 8);   // NOLINT
     io_buffer_itr.copy_and_forward(temp_ptr, data_len);
     auto stream =
         reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
     memory::Copy(place, tensor_data, platform::CPUPlace(), temp_ptr,
-                 tensor->numel() * framework::SizeOfType(tensor->type()),
+                 tensor->numel() * framework::DataTypeSize(tensor->dtype()),
                  stream);
     delete[] temp_ptr;
 #endif
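
The byte-length computation tensor->numel() * framework::DataTypeSize(tensor->dtype()) recurs in every branch above; note that in this commit the GPU memory::Copy calls still size the copy through framework::SizeOfType(framework::TransToProtoVarType(tensor->dtype())), while the appended data_len already uses DataTypeSize directly. A hedged helper capturing the new computation (hypothetical TensorByteLen, not part of the commit):

#include <cstdint>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"

namespace example {

// Bytes occupied by a dense tensor, sized from the pten dtype rather
// than the legacy proto::VarType::Type.
inline uint64_t TensorByteLen(const paddle::framework::LoDTensor& tensor) {
  return static_cast<uint64_t>(tensor.numel()) *
         paddle::framework::DataTypeSize(tensor.dtype());
}

}  // namespace example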

paddle/fluid/distributed/ps/service/heter_client.cc

Lines changed: 4 additions & 2 deletions

@@ -13,6 +13,8 @@
 // limitations under the License.
 
 #include "paddle/fluid/distributed/ps/service/heter_client.h"
+
+#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/platform/profiler.h"
 #include "paddle/fluid/string/split.h"
 
@@ -39,13 +41,13 @@ int GetMicroId(const platform::DeviceContext& ctx,
   } else {
 #ifdef PADDLE_WITH_CUDA
     std::vector<char> temp;
-    temp.resize(tensor->numel() * framework::SizeOfType(tensor->type()));
+    temp.resize(tensor->numel() * framework::DataTypeSize(tensor->dtype()));
     char* temp_ptr = temp.data();
     auto stream =
         reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
     memory::Copy(
         platform::CPUPlace(), temp_ptr, tensor->place(), tensor->data(),
-        tensor->numel() * framework::SizeOfType(tensor->type()), stream);
+        tensor->numel() * framework::DataTypeSize(tensor->dtype()), stream);
     float* temp_ptr_float = reinterpret_cast<float*>(temp_ptr);
     micro_id = static_cast<int>(temp_ptr_float[0]);
 #endif

paddle/fluid/eager/grad_tensor_holder.cc

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@
 #include "paddle/fluid/eager/grad_tensor_holder.h"
 #include "paddle/fluid/imperative/gradient_accumulator.h"
 
+#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/var_type.h"
 #include "paddle/pten/kernels/funcs/math_function.h"
paddle/fluid/framework/CMakeLists.txt

Lines changed: 6 additions & 0 deletions

@@ -452,4 +452,10 @@ endif()
 
 cc_test(scope_guard_test SRCS scope_guard_test.cc)
 cc_test(pten_utils_test SRCS pten_utils_test.cc DEPS pten_utils)
+
+if(WITH_GPU OR WITH_ROCM)
+  cc_library(fluid_convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info)
+else()
+  cc_library(fluid_convert_utils SRCS convert_utils.cc DEPS data_type place)
+endif()
 cc_test(custom_kernel_test SRCS custom_kernel_test.cc DEPS custom_kernel pten_tensor)
paddle/fluid/framework/convert_utils.cc

Lines changed: 185 additions & 0 deletions (new file)

@@ -0,0 +1,185 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include "paddle/fluid/framework/convert_utils.h"
+// See Note [ Why still include the fluid headers? ]
+#include "paddle/fluid/platform/device/gpu/gpu_info.h"
+
+namespace paddle {
+namespace framework {
+
+paddle::experimental::DataType TransToPtenDataType(
+    const paddle::framework::proto::VarType::Type& dtype) {
+  // Set the order of case branches according to the frequency with
+  // the data type is used
+  switch (dtype) {
+    case paddle::framework::proto::VarType::FP32:
+      return DataType::FLOAT32;
+    case paddle::framework::proto::VarType::FP64:
+      return DataType::FLOAT64;
+    case paddle::framework::proto::VarType::INT64:
+      return DataType::INT64;
+    case paddle::framework::proto::VarType::INT32:
+      return DataType::INT32;
+    case paddle::framework::proto::VarType::INT8:
+      return DataType::INT8;
+    case paddle::framework::proto::VarType::UINT8:
+      return DataType::UINT8;
+    case paddle::framework::proto::VarType::INT16:
+      return DataType::INT16;
+    case paddle::framework::proto::VarType::COMPLEX64:
+      return DataType::COMPLEX64;
+    case paddle::framework::proto::VarType::COMPLEX128:
+      return DataType::COMPLEX128;
+    case paddle::framework::proto::VarType::FP16:
+      return DataType::FLOAT16;
+    case paddle::framework::proto::VarType::BF16:
+      return DataType::BFLOAT16;
+    case paddle::framework::proto::VarType::BOOL:
+      return DataType::BOOL;
+    default:
+      return DataType::UNDEFINED;
+  }
+}
+
+paddle::framework::proto::VarType::Type TransToProtoVarType(
+    const paddle::experimental::DataType& dtype) {
+  // Set the order of case branches according to the frequency with
+  // the data type is used
+  switch (dtype) {
+    case DataType::FLOAT32:
+      return paddle::framework::proto::VarType::FP32;
+    case DataType::FLOAT64:
+      return paddle::framework::proto::VarType::FP64;
+    case DataType::INT64:
+      return paddle::framework::proto::VarType::INT64;
+    case DataType::INT32:
+      return paddle::framework::proto::VarType::INT32;
+    case DataType::INT8:
+      return paddle::framework::proto::VarType::INT8;
+    case DataType::UINT8:
+      return paddle::framework::proto::VarType::UINT8;
+    case DataType::INT16:
+      return paddle::framework::proto::VarType::INT16;
+    case DataType::COMPLEX64:
+      return paddle::framework::proto::VarType::COMPLEX64;
+    case DataType::COMPLEX128:
+      return paddle::framework::proto::VarType::COMPLEX128;
+    case DataType::FLOAT16:
+      return paddle::framework::proto::VarType::FP16;
+    case DataType::BFLOAT16:
+      return paddle::framework::proto::VarType::BF16;
+    case DataType::BOOL:
+      return paddle::framework::proto::VarType::BOOL;
+    default:
+      PADDLE_THROW(paddle::platform::errors::Unimplemented(
+          "Unsupported data type `%s` when casting it into "
+          "paddle data type.",
+          dtype));
+  }
+}
+
+size_t DataTypeSize(DataType dtype) {
+  switch (dtype) {
+    case DataType::UNDEFINED:
+      return 0;
+    case DataType::BOOL:
+      return sizeof(bool);
+    case DataType::INT8:
+      return sizeof(int8_t);
+    case DataType::UINT8:
+      return sizeof(uint8_t);
+    case DataType::INT16:
+      return sizeof(int16_t);
+    case DataType::INT32:
+      return sizeof(int);
+    case DataType::INT64:
+      return sizeof(int64_t);
+    case DataType::BFLOAT16:
+      return sizeof(paddle::platform::bfloat16);
+    case DataType::FLOAT16:
+      return sizeof(paddle::platform::float16);
+    case DataType::FLOAT32:
+      return sizeof(float);
+    case DataType::FLOAT64:
+      return sizeof(double);
+    case DataType::COMPLEX64:
+      return sizeof(paddle::platform::complex<float>);
+    case DataType::COMPLEX128:
+      return sizeof(paddle::platform::complex<double>);
+    default:
+      return 0;
+  }
+}
+
+DataType String2DataType(const std::string& str) {
+  if (str == "bool") {
+    return DataType::BOOL;
+  } else if (str == "float16") {
+    return DataType::FLOAT16;
+  } else if (str == "float32") {
+    return DataType::FLOAT32;
+  } else if (str == "float64") {
+    return DataType::FLOAT64;
+  } else if (str == "int8") {
+    return DataType::INT8;
+  } else if (str == "int16") {
+    return DataType::INT16;
+  } else if (str == "int32") {
+    return DataType::INT32;
+  } else if (str == "int64") {
+    return DataType::INT64;
+  } else if (str == "uint8") {
+    return DataType::UINT8;
+  } else if (str == "complex64") {
+    return DataType::COMPLEX64;
+  } else if (str == "complex128") {
+    return DataType::COMPLEX128;
+  } else {
+    return DataType::UNDEFINED;
+  }
+}
+
+std::string DataType2String(DataType dtype) {
+  switch (dtype) {
+    case DataType::BOOL:
+      return "bool";
+    case DataType::INT8:
+      return "int8";
+    case DataType::UINT8:
+      return "uint8";
+    case DataType::INT16:
+      return "int16";
+    case DataType::INT32:
+      return "int32";
+    case DataType::INT64:
+      return "int64";
+    case DataType::FLOAT16:
+      return "float16";
+    case DataType::FLOAT32:
+      return "float32";
+    case DataType::FLOAT64:
+      return "float64";
+    case DataType::COMPLEX64:
+      return "complex64";
+    case DataType::COMPLEX128:
+      return "complex128";
+    default:
+      PADDLE_THROW(paddle::platform::errors::InvalidArgument(
+          "Unknow pten::DataType, the int value = %d.",
+          static_cast<int>(dtype)));
+      return "";
+  }
+}
+}  // namespace framework
+}  // namespace paddle
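
A test-style usage sketch for the helpers above (editor's illustration, assuming the Paddle source tree; each assertion follows from a switch case shown in the file):

#include <cassert>

#include "paddle/fluid/framework/convert_utils.h"

int main() {
  using paddle::experimental::DataType;
  namespace fw = paddle::framework;

  // proto enum <-> pten DataType round-trips for supported types.
  DataType dt = fw::TransToPtenDataType(fw::proto::VarType::FP32);
  assert(dt == DataType::FLOAT32);
  assert(fw::TransToProtoVarType(dt) == fw::proto::VarType::FP32);

  // Element sizes are derived from the dtype; UNDEFINED maps to 0.
  assert(fw::DataTypeSize(DataType::FLOAT64) == sizeof(double));
  assert(fw::DataTypeSize(DataType::UNDEFINED) == 0);

  // String helpers round-trip the common names.
  assert(fw::String2DataType("int64") == DataType::INT64);
  assert(fw::DataType2String(DataType::INT64) == "int64");

  // Unknown strings fall back to UNDEFINED rather than throwing.
  assert(fw::String2DataType("not-a-dtype") == DataType::UNDEFINED);
  return 0;
}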
