Skip to content

Commit ca7fdfe

Browse files
cyyevermalfet
authored andcommitted
[Reland] Use static_assert to detect get_type_index used in device code (pytorch#139966)
pytorch#139173 was reverted due to an internal build break of using get_type_index in device code. This PR is created for ease of importing into META to further investigation. Pull Request resolved: pytorch#139966 Approved by: https://github.com/malfet, https://github.com/huydhn Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
1 parent e474f0d commit ca7fdfe

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

c10/util/TypeIndex.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,14 @@ inline constexpr uint64_t type_index_impl() {
9393

9494
template <typename T>
9595
inline constexpr type_index get_type_index() {
96-
#if !defined(__CUDA_ARCH__)
96+
#if defined(__CUDA_ARCH__)
97+
static_assert(false && sizeof(T), " Don't call me from device code");
98+
#endif
9799
// To enforce that this is really computed at compile time, we pass the
98100
// type index through std::integral_constant.
99101
return type_index{std::integral_constant<
100102
uint64_t,
101103
detail::type_index_impl<std::decay_t<T>>()>::value};
102-
#else
103-
// There's nothing in theory preventing us from running this on device code
104-
// except for nvcc throwing a compiler error if we enable it.
105-
return (abort(), type_index(0));
106-
#endif
107104
}
108105

109106
#if !defined(TORCH_PEDANTIC)

0 commit comments

Comments
 (0)