1 change: 1 addition & 0 deletions paddle/fluid/extension/include/ext_all.h
@@ -26,6 +26,7 @@ limitations under the License. */

 #include "ext_dispatch.h"      // NOLINT
 #include "ext_dtype.h"         // NOLINT
+#include "ext_exception.h"     // NOLINT
 #include "ext_op_meta_info.h"  // NOLINT
 #include "ext_place.h"         // NOLINT
 #include "ext_tensor.h"        // NOLINT
38 changes: 18 additions & 20 deletions paddle/fluid/extension/include/ext_dispatch.h
@@ -14,7 +14,8 @@ limitations under the License. */

 #pragma once

-#include "ext_dtype.h"      // NOLINT
+#include "ext_dtype.h"      // NOLINT
+#include "ext_exception.h"  // NOLINT

 namespace paddle {

@@ -32,19 +33,18 @@ namespace paddle {

 ///////// Floating Dispatch Macro ///////////

-#define PD_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)                    \
-  [&] {                                                                \
-    const auto& __dtype__ = TYPE;                                      \
-    switch (__dtype__) {                                               \
-      PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float,   \
-                           __VA_ARGS__)                                \
-      PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double,  \
-                           __VA_ARGS__)                                \
-      default:                                                         \
-        throw std::runtime_error("function " #NAME                     \
-                                 " not implemented for data type `" +  \
-                                 ::paddle::ToString(__dtype__) + "`"); \
-    }                                                                  \
+#define PD_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)                       \
+  [&] {                                                                   \
+    const auto& __dtype__ = TYPE;                                         \
+    switch (__dtype__) {                                                  \
+      PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float,      \
+                           __VA_ARGS__)                                   \
+      PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double,     \
+                           __VA_ARGS__)                                   \
+      default:                                                            \
+        PD_THROW("function " #NAME " is not implemented for data type `", \
+                 ::paddle::ToString(__dtype__), "`");                     \
+    }                                                                     \
 }()

///////// Integral Dispatch Macro ///////////

@@ -63,9 +63,8 @@ namespace paddle {
       PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t,       \
                            __VA_ARGS__)                                    \
       default:                                                             \
-        throw std::runtime_error("function " #NAME                         \
-                                 " not implemented for data type `" +      \
-                                 ::paddle::ToString(__dtype__) + "`");     \
+        PD_THROW("function " #NAME " is not implemented for data type `" + \
+                 ::paddle::ToString(__dtype__) + "`");                     \
     }                                                                      \
   }()

@@ -89,9 +88,8 @@ namespace paddle {
       PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t,       \
                            __VA_ARGS__)                                    \
       default:                                                             \
-        throw std::runtime_error("function " #NAME                         \
-                                 " not implemented for data type `" +      \
-                                 ::paddle::ToString(__dtype__) + "`");     \
+        PD_THROW("function " #NAME " is not implemented for data type `" + \
+                 ::paddle::ToString(__dtype__) + "`");                     \
     }                                                                      \
   }()

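For orientation, these dispatch macros are meant to be called from a custom operator's kernel launcher: the lambda body is compiled once per supported dtype, with `data_t` bound inside each case by `PD_PRIVATE_CASE_TYPE`. A minimal sketch under those assumptions — the kernel and launcher names below follow the custom-op tutorial pattern and are illustrative, not part of this PR:

```cpp
#include <algorithm>
#include <vector>

#include "paddle/extension.h"

// Hypothetical elementwise kernel; data_t is the type alias introduced by
// the dispatch macro's matched case.
template <typename data_t>
void relu_cpu_forward_kernel(const data_t* x, data_t* out, int64_t numel) {
  for (int64_t i = 0; i < numel; ++i) {
    out[i] = std::max(static_cast<data_t>(0), x[i]);
  }
}

std::vector<paddle::Tensor> ReluCPUForward(const paddle::Tensor& x) {
  auto out = paddle::Tensor(paddle::PlaceType::kCPU);
  out.reshape(x.shape());
  // Dispatch on the runtime dtype; an unsupported dtype falls into the
  // default case, which after this PR reports through PD_THROW.
  PD_DISPATCH_FLOATING_TYPES(x.type(), "relu_cpu_forward", ([&] {
    relu_cpu_forward_kernel<data_t>(
        x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
  }));
  return {out};
}
```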
5 changes: 3 additions & 2 deletions paddle/fluid/extension/include/ext_dtype.h
@@ -14,9 +14,10 @@ limitations under the License. */
 #pragma once

 #include <cstdint>
-#include <stdexcept>
 #include <string>

+#include "ext_exception.h"  // NOLINT
+
 namespace paddle {

 enum class DataType {
@@ -50,7 +51,7 @@ inline std::string ToString(DataType dtype) {
     case DataType::FLOAT64:
       return "double";
     default:
-      throw std::runtime_error("Unsupported paddle enum data type.");
+      PD_THROW("Unsupported paddle enum data type.");
   }
 }

108 changes: 108 additions & 0 deletions paddle/fluid/extension/include/ext_exception.h
@@ -0,0 +1,108 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <iostream>
#include <sstream>
#include <string>

namespace paddle {

//////////////// Exception handling and Error Message /////////////////
#if !defined(_WIN32)
#define PD_UNLIKELY(expr) (__builtin_expect(static_cast<bool>(expr), 0))
#define PD_LIKELY(expr) (__builtin_expect(static_cast<bool>(expr), 1))
#else
#define PD_UNLIKELY(expr) (expr)
#define PD_LIKELY(expr) (expr)
#endif

struct PD_Exception : public std::exception {
public:
template <typename... Args>
explicit PD_Exception(const std::string& msg, const char* file, int line,
const char* default_msg) {
std::ostringstream sout;
if (msg.empty()) {
sout << default_msg << "\n [" << file << ":" << line << "]";
} else {
sout << msg << "\n [" << file << ":" << line << "]";
}
err_msg_ = sout.str();
}

const char* what() const noexcept override { return err_msg_.c_str(); }

private:
std::string err_msg_;
};

class ErrorMessage {
public:
template <typename... Args>
explicit ErrorMessage(const Args&... args) {
build_string(args...);
}

void build_string() { oss << ""; }

template <typename T>
void build_string(const T& t) {
oss << t;
}

template <typename T, typename... Args>
void build_string(const T& t, const Args&... args) {
build_string(t);
build_string(args...);
}

std::string to_string() { return oss.str(); }

private:
std::ostringstream oss;
};

#if defined _WIN32
#define HANDLE_THE_ERROR try {
#define END_HANDLE_THE_ERROR \
} \
catch (const std::exception& e) { \
std::cerr << e.what() << std::endl; \
throw e; \
}
#else
#define HANDLE_THE_ERROR
#define END_HANDLE_THE_ERROR
#endif

#define PD_CHECK(COND, ...) \
do { \
if (PD_UNLIKELY(!(COND))) { \
auto __message__ = ::paddle::ErrorMessage(__VA_ARGS__).to_string(); \
throw ::paddle::PD_Exception(__message__, __FILE__, __LINE__, \
"Expected " #COND \
", but it's not satisfied."); \
} \
} while (0)

#define PD_THROW(...) \
do { \
auto __message__ = ::paddle::ErrorMessage(__VA_ARGS__).to_string(); \
throw ::paddle::PD_Exception(__message__, __FILE__, __LINE__, \
"An error occured."); \
} while (0)

} // namespace paddle
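Taken together, `ErrorMessage` lets `PD_CHECK` and `PD_THROW` accept a comma-separated list of fragments that are streamed into one message, with the file and line appended by `PD_Exception`. A small usage sketch — the helper function and its parameters below are illustrative, not from this PR:

```cpp
#include "ext_exception.h"  // NOLINT

// Illustrative validation helper built on the new macros.
void CheckInput(int64_t numel, bool is_gpu_place) {
  // Everything after the condition is concatenated into the error message.
  PD_CHECK(numel > 0, "Expected a non-empty input, but got numel = ", numel,
           ".");
  if (is_gpu_place) {
    // PD_THROW takes the same variadic message format; called with no
    // arguments, it falls back to the default "An error occurred." text.
    PD_THROW("GPU inputs are not supported here (numel = ", numel, ").");
  }
}
```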
25 changes: 3 additions & 22 deletions paddle/fluid/extension/include/ext_op_meta_info.h
@@ -21,8 +21,9 @@ limitations under the License. */

 #include <boost/any.hpp>

-#include "ext_dll_decl.h"  // NOLINT
-#include "ext_tensor.h"    // NOLINT
+#include "ext_dll_decl.h"   // NOLINT
+#include "ext_exception.h"  // NOLINT
+#include "ext_tensor.h"     // NOLINT

/**
* Op Meta Info Related Define.
@@ -47,26 +48,6 @@ using Tensor = paddle::Tensor;
classname& operator=(const classname&) = delete; \
classname& operator=(classname&&) = delete

-#if defined _WIN32
-#define HANDLE_THE_ERROR try {
-#define END_HANDLE_THE_ERROR                  \
-  }                                           \
-  catch (const std::exception& e) {           \
-    std::cerr << e.what() << std::endl;       \
-    throw e;                                  \
-  }
-#else
-#define HANDLE_THE_ERROR
-#define END_HANDLE_THE_ERROR
-#endif
-
-#define PD_THROW(err_msg)                     \
-  do {                                        \
-    HANDLE_THE_ERROR                          \
-    throw std::runtime_error(err_msg);        \
-    END_HANDLE_THE_ERROR                      \
-  } while (0)

#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
2 changes: 2 additions & 0 deletions python/paddle/fluid/tests/custom_op/CMakeLists.txt
@@ -23,6 +23,8 @@ set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 120)
 py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py)
 set_tests_properties(test_custom_attrs_jit PROPERTIES TIMEOUT 120)

+cc_test(test_check_error SRCS test_check_error.cc DEPS gtest)
+
 if(NOT LINUX)
   return()
 endif()
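The new `cc_test` target suggests a unit test that triggers both macros and inspects the thrown message. The actual `test_check_error.cc` is not shown in this diff, so the following is only a plausible sketch:

```cpp
#include <string>

#include "gtest/gtest.h"
#include "paddle/fluid/extension/include/ext_exception.h"

TEST(PD_THROW, basic) {
  bool caught = false;
  try {
    PD_THROW("test throw");
  } catch (const std::exception& e) {
    caught = true;
    // PD_Exception::what() appends "\n [file:line]" to the message.
    std::string msg(e.what());
    EXPECT_NE(msg.find("test throw"), std::string::npos);
  }
  EXPECT_TRUE(caught);
}

TEST(PD_CHECK, failed) {
  bool caught = false;
  try {
    PD_CHECK(1 == 2, "PD_CHECK must fail here");
  } catch (const std::exception&) {
    caught = true;
  }
  EXPECT_TRUE(caught);
}
```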
4 changes: 2 additions & 2 deletions python/paddle/fluid/tests/custom_op/custom_relu_op.cc
@@ -79,7 +79,7 @@ std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
   } else if (x.place() == paddle::PlaceType::kGPU) {
     return relu_cuda_forward(x);
   } else {
-    throw std::runtime_error("Not implemented.");
+    PD_THROW("Not implemented.");
   }
 }

@@ -92,7 +92,7 @@ std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
   } else if (x.place() == paddle::PlaceType::kGPU) {
     return relu_cuda_backward(x, out, grad_out);
   } else {
-    throw std::runtime_error("Not implemented.");
+    PD_THROW("Not implemented.");
   }
 }
