Skip to content

Commit 430dcfa

Browse files
zhwesky2010chenwhql
authored andcommitted
[Custom OP]add PD_THROW and PD_CHECK for User Error message (#31253)
* [Custom OP]add PD_THROW and PD_CHECK for User error message * PD_THROW and PD_CHECK, fix comment * fix Windows error message * fix Windows error message * fix CI
1 parent ddb7bae commit 430dcfa

File tree

9 files changed

+341
-47
lines changed

9 files changed

+341
-47
lines changed

paddle/fluid/extension/include/ext_all.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ limitations under the License. */
2626

2727
#include "ext_dispatch.h" // NOLINT
2828
#include "ext_dtype.h" // NOLINT
29+
#include "ext_exception.h" // NOLINT
2930
#include "ext_op_meta_info.h" // NOLINT
3031
#include "ext_place.h" // NOLINT
3132
#include "ext_tensor.h" // NOLINT

paddle/fluid/extension/include/ext_dispatch.h

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ limitations under the License. */
1414

1515
#pragma once
1616

17-
#include "ext_dtype.h" // NOLINT
17+
#include "ext_dtype.h" // NOLINT
18+
#include "ext_exception.h" // NOLINT
1819

1920
namespace paddle {
2021

@@ -32,19 +33,18 @@ namespace paddle {
3233

3334
///////// Floating Dispatch Marco ///////////
3435

35-
#define PD_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
36-
[&] { \
37-
const auto& __dtype__ = TYPE; \
38-
switch (__dtype__) { \
39-
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
40-
__VA_ARGS__) \
41-
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
42-
__VA_ARGS__) \
43-
default: \
44-
throw std::runtime_error("function " #NAME \
45-
" not implemented for data type `" + \
46-
::paddle::ToString(__dtype__) + "`"); \
47-
} \
36+
#define PD_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
37+
[&] { \
38+
const auto& __dtype__ = TYPE; \
39+
switch (__dtype__) { \
40+
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
41+
__VA_ARGS__) \
42+
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
43+
__VA_ARGS__) \
44+
default: \
45+
PD_THROW("function " #NAME " is not implemented for data type `", \
46+
::paddle::ToString(__dtype__), "`"); \
47+
} \
4848
}()
4949

5050
///////// Integral Dispatch Marco ///////////
@@ -63,9 +63,8 @@ namespace paddle {
6363
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \
6464
__VA_ARGS__) \
6565
default: \
66-
throw std::runtime_error("function " #NAME \
67-
" not implemented for data type `" + \
68-
::paddle::ToString(__dtype__) + "`"); \
66+
PD_THROW("function " #NAME " is not implemented for data type `" + \
67+
::paddle::ToString(__dtype__) + "`"); \
6968
} \
7069
}()
7170

@@ -89,9 +88,8 @@ namespace paddle {
8988
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \
9089
__VA_ARGS__) \
9190
default: \
92-
throw std::runtime_error("function " #NAME \
93-
" not implemented for data type `" + \
94-
::paddle::ToString(__dtype__) + "`"); \
91+
PD_THROW("function " #NAME " is not implemented for data type `" + \
92+
::paddle::ToString(__dtype__) + "`"); \
9593
} \
9694
}()
9795

paddle/fluid/extension/include/ext_dtype.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@ limitations under the License. */
1414
#pragma once
1515

1616
#include <cstdint>
17-
#include <stdexcept>
1817
#include <string>
1918

19+
#include "ext_exception.h" // NOLINT
20+
2021
namespace paddle {
2122

2223
enum class DataType {
@@ -50,7 +51,7 @@ inline std::string ToString(DataType dtype) {
5051
case DataType::FLOAT64:
5152
return "double";
5253
default:
53-
throw std::runtime_error("Unsupported paddle enum data type.");
54+
PD_THROW("Unsupported paddle enum data type.");
5455
}
5556
}
5657

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <iostream>
18+
#include <sstream>
19+
#include <string>
20+
21+
namespace paddle {
22+
23+
//////////////// Exception handling and Error Message /////////////////
24+
#if !defined(_WIN32)
25+
#define PD_UNLIKELY(expr) (__builtin_expect(static_cast<bool>(expr), 0))
26+
#define PD_LIKELY(expr) (__builtin_expect(static_cast<bool>(expr), 1))
27+
#else
28+
#define PD_UNLIKELY(expr) (expr)
29+
#define PD_LIKELY(expr) (expr)
30+
#endif
31+
32+
struct PD_Exception : public std::exception {
33+
public:
34+
template <typename... Args>
35+
explicit PD_Exception(const std::string& msg, const char* file, int line,
36+
const char* default_msg) {
37+
std::ostringstream sout;
38+
if (msg.empty()) {
39+
sout << default_msg << "\n [" << file << ":" << line << "]";
40+
} else {
41+
sout << msg << "\n [" << file << ":" << line << "]";
42+
}
43+
err_msg_ = sout.str();
44+
}
45+
46+
const char* what() const noexcept override { return err_msg_.c_str(); }
47+
48+
private:
49+
std::string err_msg_;
50+
};
51+
52+
class ErrorMessage {
53+
public:
54+
template <typename... Args>
55+
explicit ErrorMessage(const Args&... args) {
56+
build_string(args...);
57+
}
58+
59+
void build_string() { oss << ""; }
60+
61+
template <typename T>
62+
void build_string(const T& t) {
63+
oss << t;
64+
}
65+
66+
template <typename T, typename... Args>
67+
void build_string(const T& t, const Args&... args) {
68+
build_string(t);
69+
build_string(args...);
70+
}
71+
72+
std::string to_string() { return oss.str(); }
73+
74+
private:
75+
std::ostringstream oss;
76+
};
77+
78+
#if defined _WIN32
79+
#define HANDLE_THE_ERROR try {
80+
#define END_HANDLE_THE_ERROR \
81+
} \
82+
catch (const std::exception& e) { \
83+
std::cerr << e.what() << std::endl; \
84+
throw e; \
85+
}
86+
#else
87+
#define HANDLE_THE_ERROR
88+
#define END_HANDLE_THE_ERROR
89+
#endif
90+
91+
#define PD_CHECK(COND, ...) \
92+
do { \
93+
if (PD_UNLIKELY(!(COND))) { \
94+
auto __message__ = ::paddle::ErrorMessage(__VA_ARGS__).to_string(); \
95+
throw ::paddle::PD_Exception(__message__, __FILE__, __LINE__, \
96+
"Expected " #COND \
97+
", but it's not satisfied."); \
98+
} \
99+
} while (0)
100+
101+
#define PD_THROW(...) \
102+
do { \
103+
auto __message__ = ::paddle::ErrorMessage(__VA_ARGS__).to_string(); \
104+
throw ::paddle::PD_Exception(__message__, __FILE__, __LINE__, \
105+
"An error occured."); \
106+
} while (0)
107+
108+
} // namespace paddle

paddle/fluid/extension/include/ext_op_meta_info.h

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ limitations under the License. */
2121

2222
#include <boost/any.hpp>
2323

24-
#include "ext_dll_decl.h" // NOLINT
25-
#include "ext_tensor.h" // NOLINT
24+
#include "ext_dll_decl.h" // NOLINT
25+
#include "ext_exception.h" // NOLINT
26+
#include "ext_tensor.h" // NOLINT
2627

2728
/**
2829
* Op Meta Info Related Define.
@@ -47,26 +48,6 @@ using Tensor = paddle::Tensor;
4748
classname& operator=(const classname&) = delete; \
4849
classname& operator=(classname&&) = delete
4950

50-
#if defined _WIN32
51-
#define HANDLE_THE_ERROR try {
52-
#define END_HANDLE_THE_ERROR \
53-
} \
54-
catch (const std::exception& e) { \
55-
std::cerr << e.what() << std::endl; \
56-
throw e; \
57-
}
58-
#else
59-
#define HANDLE_THE_ERROR
60-
#define END_HANDLE_THE_ERROR
61-
#endif
62-
63-
#define PD_THROW(err_msg) \
64-
do { \
65-
HANDLE_THE_ERROR \
66-
throw std::runtime_error(err_msg); \
67-
END_HANDLE_THE_ERROR \
68-
} while (0)
69-
7051
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
7152
struct __test_global_namespace_##uniq_name##__ {}; \
7253
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \

python/paddle/fluid/tests/custom_op/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 120)
2323
py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py)
2424
set_tests_properties(test_custom_attrs_jit PROPERTIES TIMEOUT 120)
2525

26+
cc_test(test_check_error SRCS test_check_error.cc DEPS gtest)
27+
2628
if(NOT LINUX)
2729
return()
2830
endif()

python/paddle/fluid/tests/custom_op/custom_relu_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
7979
} else if (x.place() == paddle::PlaceType::kGPU) {
8080
return relu_cuda_forward(x);
8181
} else {
82-
throw std::runtime_error("Not implemented.");
82+
PD_THROW("Not implemented.");
8383
}
8484
}
8585

@@ -92,7 +92,7 @@ std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
9292
} else if (x.place() == paddle::PlaceType::kGPU) {
9393
return relu_cuda_backward(x, out, grad_out);
9494
} else {
95-
throw std::runtime_error("Not implemented.");
95+
PD_THROW("Not implemented.");
9696
}
9797
}
9898

0 commit comments

Comments
 (0)