paddle/fluid/extension/include/dispatch.h (0 additions, 68 deletions)

@@ -69,23 +69,6 @@ namespace paddle {
} \
}()

///////// Complex Dispatch Macro ///////////

#define PD_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \
::paddle::complex64, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \
::paddle::complex128, __VA_ARGS__) \
default: \
throw std::runtime_error("function " #NAME \
" not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()

///////// Floating and Integral Dispatch Macro ///////////

#define PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(TYPE, NAME, ...) \
@@ -112,57 +95,6 @@ namespace paddle {
} \
}()

///////// Floating and Complex Dispatch Macro ///////////

#define PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \
::paddle::complex64, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \
::paddle::complex128, __VA_ARGS__) \
default: \
throw std::runtime_error("function " #NAME \
" not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()

///////// Floating, Integral and Complex Dispatch Macro ///////////

#define PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT64, int64_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT8, int8_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::UINT8, uint8_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \
::paddle::complex64, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \
::paddle::complex128, __VA_ARGS__) \
default: \
throw std::runtime_error("function " #NAME \
" not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()

// TODO(chenweihang): Add more Macros in the future if needed

} // namespace paddle
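For context, these dispatch macros are what custom operator kernels use to map a runtime DataType onto a compile-time C++ type; after this change only the floating/integral combinations survive. Below is a minimal usage sketch: the ReluForward kernel is hypothetical (not part of this diff) and assumes the custom-op Tensor API (reshape, data, mutable_data, size) shown in the tensor.cc changes further down.

```cpp
// Hypothetical kernel sketch: PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES
// switches on the runtime dtype and runs the lambda with `data_t` aliased
// to the matching C++ type (float, double, int, int64_t). The complex
// variants removed above can no longer be dispatched this way.
#include <vector>

#include "paddle/fluid/extension/include/dispatch.h"
#include "paddle/fluid/extension/include/tensor.h"

std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
  paddle::Tensor out(paddle::PlaceType::kCPU);
  out.reshape(x.shape());

  PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(
      x.type(), "ReluForward", ([&] {
        const data_t* x_data = x.data<data_t>();
        data_t* out_data = out.mutable_data<data_t>();
        for (int64_t i = 0; i < x.size(); ++i) {
          out_data[i] = x_data[i] > static_cast<data_t>(0)
                            ? x_data[i]
                            : static_cast<data_t>(0);
        }
      }));
  return {out};
}
```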
paddle/fluid/extension/include/dtype.h (13 additions, 37 deletions)

@@ -11,34 +11,22 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
#include <cstdint>
#include <stdexcept>
#include <string>

namespace paddle {

using float16 = paddle::platform::float16;
using bfloat16 = paddle::platform::bfloat16;
using complex64 = paddle::platform::complex64;
using complex128 = paddle::platform::complex128;

enum DataType {
enum class DataType {
BOOL,
INT8,
UINT8,
INT16,
INT32,
INT64,
FLOAT16,
BFLOAT16,
FLOAT32,
FLOAT64,
COMPLEX64,
COMPLEX128,
// TODO(JiabinYang) support more data types if needed.
};

@@ -56,36 +44,24 @@ inline std::string ToString(DataType dtype) {
return "int32_t";
case DataType::INT64:
return "int64_t";
case DataType::FLOAT16:
return "float16";
case DataType::BFLOAT16:
return "bfloat16";
case DataType::FLOAT32:
return "float";
case DataType::FLOAT64:
return "double";
case DataType::COMPLEX64:
return "complex64";
case DataType::COMPLEX128:
return "complex128";
default:
throw std::runtime_error("Unsupported paddle enum data type.");
}
}

#define PD_FOR_EACH_DATA_TYPE(_) \
_(bool, DataType::BOOL) \
_(int8_t, DataType::INT8) \
_(uint8_t, DataType::UINT8) \
_(int16_t, DataType::INT16) \
_(int, DataType::INT32) \
_(int64_t, DataType::INT64) \
_(float16, DataType::FLOAT16) \
_(bfloat16, DataType::BFLOAT16) \
_(float, DataType::FLOAT32) \
_(double, DataType::FLOAT64) \
_(complex64, DataType::COMPLEX64) \
_(complex128, DataType::COMPLEX128)
#define PD_FOR_EACH_DATA_TYPE(_) \
_(bool, DataType::BOOL) \
_(int8_t, DataType::INT8) \
_(uint8_t, DataType::UINT8) \
_(int16_t, DataType::INT16) \
_(int, DataType::INT32) \
_(int64_t, DataType::INT64) \
_(float, DataType::FLOAT32) \
_(double, DataType::FLOAT64)

template <paddle::DataType T>
struct DataTypeToCPPType;
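The diff truncates the rest of dtype.h here, but PD_FOR_EACH_DATA_TYPE is an X-macro, and DataTypeToCPPType is the trait it typically populates. A sketch of the likely expansion pattern follows; the helper macro name is an assumption, not taken from this diff:

```cpp
// Sketch: consuming the X-macro. Each _(cpp_type, data_type) entry stamps
// out one specialization mapping the enum constant to its C++ type.
// PD_SPECIALIZE_DataTypeToCPPType is a hypothetical helper name.
#define PD_SPECIALIZE_DataTypeToCPPType(cpp_type, data_type) \
  template <>                                                \
  struct DataTypeToCPPType<data_type> {                      \
    using type = cpp_type;                                   \
  };

PD_FOR_EACH_DATA_TYPE(PD_SPECIALIZE_DataTypeToCPPType)
#undef PD_SPECIALIZE_DataTypeToCPPType

// After expansion, DataTypeToCPPType<DataType::FLOAT32>::type is float,
// DataTypeToCPPType<DataType::INT64>::type is int64_t, and so on. Trimming
// the list above is what drops float16/bfloat16/complex from the trait.
```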
paddle/fluid/extension/include/tensor.h (1 addition, 0 deletions)

@@ -24,6 +24,7 @@ namespace paddle {
namespace framework {
class CustomTensorUtils;
} // namespace framework

class PD_DLL_DECL Tensor {
public:
/// \brief Construct a Tensor on target Place for CustomOp.
paddle/fluid/extension/src/tensor.cc (2 additions, 56 deletions)

@@ -159,17 +159,10 @@ DataType Tensor::type() const {
return DataType::UINT8;
} else if (type == framework::proto::VarType::FP64) {
return DataType::FLOAT64;
} else if (type == framework::proto::VarType::BF16) {
return DataType::BFLOAT16;
} else if (type == framework::proto::VarType::FP16) {
return DataType::FLOAT16;
} else if (type == framework::proto::VarType::COMPLEX64) {
return DataType::COMPLEX64;
} else if (type == framework::proto::VarType::COMPLEX128) {
return DataType::COMPLEX128;
} else if (type == framework::proto::VarType::BOOL) {
return DataType::BOOL;
}
// TODO(JiabinYang) Support more dtype here
return DataType::FLOAT32;
}

@@ -207,14 +200,6 @@ Tensor Tensor::copy_to(const PlaceType &target_place) const {
return target;
}

template PD_DLL_DECL Tensor
Tensor::copy_to<paddle::platform::float16>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::bfloat16>(
const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex64>(
const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex128>(
const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<float>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
@@ -238,14 +223,6 @@ template PD_DLL_DECL int64_t *Tensor::data<int64_t>() const;
template PD_DLL_DECL int32_t *Tensor::data<int32_t>() const;
template PD_DLL_DECL uint8_t *Tensor::data<uint8_t>() const;
template PD_DLL_DECL int8_t *Tensor::data<int8_t>() const;
template PD_DLL_DECL paddle::platform::float16 *
Tensor::data<paddle::platform::float16>() const;
template PD_DLL_DECL paddle::platform::bfloat16 *
Tensor::data<paddle::platform::bfloat16>() const;
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::data<paddle::platform::complex128>() const;
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::data<paddle::platform::complex64>() const;
template PD_DLL_DECL int16_t *Tensor::data<int16_t>() const;
template PD_DLL_DECL bool *Tensor::data<bool>() const;

@@ -255,14 +232,6 @@ template PD_DLL_DECL int64_t *Tensor::mutable_data<int64_t>();
template PD_DLL_DECL int32_t *Tensor::mutable_data<int32_t>();
template PD_DLL_DECL uint8_t *Tensor::mutable_data<uint8_t>();
template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>();
template PD_DLL_DECL paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>();
template PD_DLL_DECL paddle::platform::bfloat16 *
Tensor::mutable_data<paddle::platform::bfloat16>();
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>();
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>();
template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>();
template PD_DLL_DECL bool *Tensor::mutable_data<bool>();

@@ -277,14 +246,6 @@ template PD_DLL_DECL uint8_t *Tensor::mutable_data<uint8_t>(
const PlaceType &place);
template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>(
const PlaceType &place);
template PD_DLL_DECL paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::bfloat16 *
Tensor::mutable_data<paddle::platform::bfloat16>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place);
template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>(
const PlaceType &place);
template PD_DLL_DECL bool *Tensor::mutable_data<bool>(const PlaceType &place);
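The deletions above are explicit template instantiations. Tensor::data<T>, Tensor::mutable_data<T>, and Tensor::copy_to<T> are defined in tensor.cc rather than in the header, so every supported T must be explicitly instantiated and exported via PD_DLL_DECL, or custom ops fail to link; removing the float16/bfloat16/complex instantiations is what actually drops those dtypes from the exported surface. A minimal sketch of the mechanism, with hypothetical names:

```cpp
// buffer.h: declaration only; the template body is hidden in buffer.cc.
#include <cstdint>

template <typename T>
T* GetBuffer();

// buffer.cc: definition plus explicit instantiations for supported types.
template <typename T>
T* GetBuffer() {
  static T storage[16];
  return storage;
}

template float* GetBuffer<float>();      // exported, callers link fine
template int64_t* GetBuffer<int64_t>();  // exported, callers link fine
// There is no instantiation for double, so GetBuffer<double>() in another
// translation unit compiles but fails at link time with an undefined
// symbol. That is the effect this PR has for float16, bfloat16,
// complex64, and complex128.
```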
@@ -320,14 +281,6 @@ Tensor Tensor::cast(const DataType &target_type) const {
auto dst_type =
framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(target_type);
switch (src_type) {
case framework::proto::VarType::FP16:
framework::VisitDataType(
dst_type, CastDataType<platform::float16>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::BF16:
framework::VisitDataType(dst_type, CastDataType<platform::bfloat16>(
*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::FP32:
framework::VisitDataType(dst_type,
CastDataType<float>(*tensor, rlt_tensor_, ctx));
@@ -356,14 +309,7 @@ Tensor Tensor::cast(const DataType &target_type) const {
framework::VisitDataType(
dst_type, CastDataType<uint8_t>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::COMPLEX64:
framework::VisitDataType(dst_type, CastDataType<platform::complex64>(
*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::COMPLEX128:
framework::VisitDataType(dst_type, CastDataType<platform::complex128>(
*tensor, rlt_tensor_, ctx));
break;
// TODO(JiabinYang) Support more dtype here
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when casting data type.",
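Net effect of the cast changes: conversions among the remaining types still work, while anything involving the removed types now falls through to the default branch. A short usage sketch under that assumption:

```cpp
// Sketch: Tensor::cast after this change. Types kept in the switch (bool,
// int8/int16/int32/int64, uint8, float, double) still convert; float16,
// bfloat16, complex64, and complex128 now reach the default branch and
// throw errors::Unimplemented.
paddle::Tensor t(paddle::PlaceType::kCPU);
t.reshape({2, 3});
float* p = t.mutable_data<float>();
for (int64_t i = 0; i < t.size(); ++i) {
  p[i] = static_cast<float>(i);
}

paddle::Tensor t_f64 = t.cast(paddle::DataType::FLOAT64);  // still supported
// t.cast(paddle::DataType::COMPLEX64) no longer even compiles, since
// COMPLEX64 was removed from the DataType enum in dtype.h above.
```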
paddle/fluid/framework/custom_tensor_test.cc (1 addition, 47 deletions)

@@ -91,22 +91,14 @@ void TestCast(paddle::DataType data_type) {
t1.reshape(tensor_shape);
t1.template mutable_data<T>();
auto t2 = t1.cast(data_type);
CHECK_EQ(t2.type(), data_type);
CHECK(t2.type() == data_type);
}
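The switch from CHECK_EQ to CHECK just above is most likely forced by DataType becoming an enum class: on failure, glog's CHECK_EQ streams both operands into the error message, and a scoped enum has neither an operator<< nor an implicit integer conversion, whereas plain CHECK only needs the comparison itself to compile. A small illustrative sketch:

```cpp
// Sketch: why CHECK_EQ stops compiling once the enum is scoped.
#include <glog/logging.h>

enum class Color { kRed, kGreen };  // stand-in for paddle::DataType

void Demo(Color c) {
  // CHECK_EQ(c, Color::kRed);  // error: CHECK_EQ must stream its operands
  //                            // on failure, and Color has no operator<<.
  CHECK(c == Color::kRed);      // fine: only operator== is required
}
```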

void GroupTestCopy() {
VLOG(2) << "Float cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<float>();
VLOG(2) << "Double cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<double>();
VLOG(2) << "Fp16 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::platform::float16>();
VLOG(2) << "BF16 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::platform::bfloat16>();
VLOG(2) << "complex128 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::platform::complex128>();
VLOG(2) << "complex64 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::platform::complex64>();
VLOG(2) << "int cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<int>();
VLOG(2) << "int64 cpu-cpu-gpu-gpu-cpu";
@@ -128,31 +120,17 @@ void GroupTestCast() {
TestCast<int64_t>(paddle::DataType::FLOAT32);
VLOG(2) << "double cast";
TestCast<double>(paddle::DataType::FLOAT32);
VLOG(2) << "bfloat16 cast";
TestCast<paddle::platform::bfloat16>(paddle::DataType::FLOAT32);
VLOG(2) << "float16 cast";
TestCast<paddle::platform::float16>(paddle::DataType::FLOAT32);
VLOG(2) << "bool cast";
TestCast<bool>(paddle::DataType::FLOAT32);
VLOG(2) << "uint8 cast";
TestCast<uint8_t>(paddle::DataType::FLOAT32);
VLOG(2) << "float cast";
TestCast<float>(paddle::DataType::FLOAT32);
VLOG(2) << "complex64 cast";
TestCast<float>(paddle::DataType::FLOAT32);
VLOG(2) << "complex128 cast";
TestCast<float>(paddle::DataType::FLOAT32);
}

void GroupTestDtype() {
CHECK(TestDtype<float>() == paddle::DataType::FLOAT32);
CHECK(TestDtype<double>() == paddle::DataType::FLOAT64);
CHECK(TestDtype<paddle::platform::float16>() == paddle::DataType::FLOAT16);
CHECK(TestDtype<paddle::platform::bfloat16>() == paddle::DataType::BFLOAT16);
CHECK(TestDtype<paddle::platform::complex128>() ==
paddle::DataType::COMPLEX128);
CHECK(TestDtype<paddle::platform::complex64>() ==
paddle::DataType::COMPLEX64);
CHECK(TestDtype<int>() == paddle::DataType::INT32);
CHECK(TestDtype<int64_t>() == paddle::DataType::INT64);
CHECK(TestDtype<int16_t>() == paddle::DataType::INT16);
@@ -162,24 +140,12 @@ void GroupTestDtype() {

void GroupTestDtypeConvert() {
// enum -> proto
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::COMPLEX128) ==
paddle::framework::proto::VarType::COMPLEX128);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::COMPLEX64) ==
paddle::framework::proto::VarType::COMPLEX64);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::FLOAT64) ==
paddle::framework::proto::VarType::FP64);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::FLOAT32) ==
paddle::framework::proto::VarType::FP32);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::FLOAT16) ==
paddle::framework::proto::VarType::FP16);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::BFLOAT16) ==
paddle::framework::proto::VarType::BF16);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::UINT8) ==
paddle::framework::proto::VarType::UINT8);
@@ -197,24 +163,12 @@ void GroupTestDtypeConvert() {
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::BOOL) == paddle::framework::proto::VarType::BOOL);
// proto -> enum
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::COMPLEX128) ==
paddle::DataType::COMPLEX128);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::COMPLEX64) ==
paddle::DataType::COMPLEX64);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::FP64) ==
paddle::DataType::FLOAT64);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::FP32) ==
paddle::DataType::FLOAT32);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::FP16) ==
paddle::DataType::FLOAT16);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::BF16) ==
paddle::DataType::BFLOAT16);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::INT64) ==
paddle::DataType::INT64);