forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathCompileTimeFunctionPointer.h
57 lines (50 loc) · 1.66 KB
/
CompileTimeFunctionPointer.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#pragma once
#include <c10/util/TypeTraits.h>
#include <type_traits>
namespace c10 {
/**
* Represent a function pointer as a C++ type.
* This allows using the function pointer as a type
* in a template and calling it from inside the template
* allows the compiler to inline the call because it
* knows the function pointer at compile time.
*
* Example 1:
* int add(int a, int b) {return a + b;}
* using Add = TORCH_FN_TYPE(add);
* template<class Func> struct Executor {
* int execute(int a, int b) {
* return Func::func_ptr()(a, b);
* }
* };
* Executor<Add> executor;
* EXPECT_EQ(3, executor.execute(1, 2));
*
* Example 2:
* int add(int a, int b) {return a + b;}
* template<class Func> int execute(Func, int a, int b) {
* return Func::func_ptr()(a, b);
* }
* EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2));
*/
template <class FuncType_, FuncType_* func_ptr_>
struct CompileTimeFunctionPointer final {
static_assert(
guts::is_function_type<FuncType_>::value,
"TORCH_FN can only wrap function types.");
using FuncType = FuncType_;
static constexpr FuncType* func_ptr() {
return func_ptr_;
}
};
template <class T>
struct is_compile_time_function_pointer : std::false_type {};
template <class FuncType, FuncType* func_ptr>
struct is_compile_time_function_pointer<
CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};
} // namespace c10
#define TORCH_FN_TYPE(func) \
::c10::CompileTimeFunctionPointer< \
std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \
func>
#define TORCH_FN(func) TORCH_FN_TYPE(func)()