Skip to content

Commit 618479a

Browse files
committed
More fixes
1 parent aa7b8c6 commit 618479a

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

aten/src/ATen/native/cuda/BinaryMulDivKernel.cu

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,16 @@ static inline __host__ __device__ typename std::enable_if<!std::is_same<scalar_t
6666
floor_(scalar_t a) {
6767
return std::floor(a);
6868
}
69+
// trunc_ overload that participates only when scalar_t is exactly float
// (std::enable_if on std::is_same<scalar_t, float>). It forwards to the
// explicitly single-precision std::truncf and is declared
// __host__ __device__.
// NOTE(review): presumably mirrors the neighbouring floor_ helper as a
// toolchain workaround -- confirm against the enclosing #if, which is not
// visible in this hunk.
template <typename scalar_t>
70+
static inline __host__ __device__ typename std::enable_if<std::is_same<scalar_t, float>::value, scalar_t>::type
71+
trunc_(scalar_t a) {
72+
return std::truncf(a);
73+
}
74+
// trunc_ overload for every scalar_t other than float (the enable_if
// condition is the negation of std::is_same<scalar_t, float>). It forwards
// to std::trunc and lets the argument type pick the overload; declared
// __host__ __device__.
template <typename scalar_t>
75+
static inline __host__ __device__ typename std::enable_if<!std::is_same<scalar_t, float>::value, scalar_t>::type
76+
trunc_(scalar_t a) {
77+
return std::trunc(a);
78+
}
6979
template <typename scalar_t1, typename scalar_t2>
7080
static inline __host__ __device__ typename std::enable_if<std::is_same<scalar_t1, float>::value && std::is_same<scalar_t2, float>::value, scalar_t1>::type
7181
copysign_(scalar_t1 a, scalar_t2 b) {
@@ -80,6 +90,7 @@ static inline __host__ __device__ typename std::enable_if<!std::is_same<scalar_t
8090
#else
8191
#define ceil_ std::ceil
8292
#define floor_ std::floor
93+
#define trunc_ std::trunc
8394
#define copysign_ std::copysign
8495
#endif
8596

@@ -121,13 +132,13 @@ void div_trunc_kernel_cuda(TensorIteratorBase& iter) {
121132
auto inv_b = accscalar_t(1.0) / iter.scalar_value<accscalar_t>(2);
122133
iter.remove_operand(2);
123134
gpu_kernel(iter, [inv_b] GPU_LAMBDA (scalar_t a) -> scalar_t {
124-
return std::trunc(a * inv_b);
135+
return trunc_(a * inv_b);
125136
});
126137
});
127138
} else {
128139
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, dtype, "div_trunc_cuda", [&]() {
129140
gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
130-
return std::trunc(a / b);
141+
return trunc_(a / b);
131142
});
132143
});
133144
}

c10/util/BFloat16-math.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,11 @@ inline c10::BFloat16 log(c10::BFloat16 a) { return std::log(float(a));}
1717
// Base-10 logarithm for c10::BFloat16: the argument is converted to float,
// std::log10 is evaluated on it, and the result converts back on return.
inline c10::BFloat16 log10(c10::BFloat16 a) {
  const float widened = static_cast<float>(a);
  return std::log10(widened);
}
1818
// log1p for c10::BFloat16: evaluates std::log1p on the float conversion of
// the argument; the float result converts back to BFloat16 on return.
inline c10::BFloat16 log1p(c10::BFloat16 a) {
  const float widened = static_cast<float>(a);
  return std::log1p(widened);
}
1919
// Base-2 logarithm for c10::BFloat16, computed by std::log2 on the float
// conversion of the argument and converted back on return.
inline c10::BFloat16 log2(c10::BFloat16 a) {
  const float widened = static_cast<float>(a);
  return std::log2(widened);
}
20-
// Ceiling for c10::BFloat16: std::ceil applied to the float conversion of
// the argument, with the result converted back on return.
inline c10::BFloat16 ceil(c10::BFloat16 a) {
  const float widened = static_cast<float>(a);
  return std::ceil(widened);
}
2120
// Cosine for c10::BFloat16: std::cos applied to the float conversion of the
// argument, with the result converted back on return.
inline c10::BFloat16 cos(c10::BFloat16 a) {
  const float radians = static_cast<float>(a);
  return std::cos(radians);
}
22-
// Floor for c10::BFloat16: std::floor applied to the float conversion of the
// argument, with the result converted back on return.
inline c10::BFloat16 floor(c10::BFloat16 a) {
  const float widened = static_cast<float>(a);
  return std::floor(widened);
}
2321
// nearbyint for c10::BFloat16: std::nearbyint applied to the float
// conversion of the argument, with the result converted back on return.
inline c10::BFloat16 nearbyint(c10::BFloat16 a) {
  const float widened = static_cast<float>(a);
  return std::nearbyint(widened);
}
2422
// Sine for c10::BFloat16: std::sin applied to the float conversion of the
// argument, with the result converted back on return.
inline c10::BFloat16 sin(c10::BFloat16 a) {
  const float radians = static_cast<float>(a);
  return std::sin(radians);
}
2523
// Tangent for c10::BFloat16: std::tan applied to the float conversion of the
// argument, with the result converted back on return.
inline c10::BFloat16 tan(c10::BFloat16 a) {
  const float radians = static_cast<float>(a);
  return std::tan(radians);
}
2624
// Hyperbolic tangent for c10::BFloat16: std::tanh applied to the float
// conversion of the argument, with the result converted back on return.
inline c10::BFloat16 tanh(c10::BFloat16 a) {
  const float widened = static_cast<float>(a);
  return std::tanh(widened);
}
27-
// Truncation toward zero for c10::BFloat16: std::trunc applied to the float
// conversion of the argument, with the result converted back on return.
inline c10::BFloat16 trunc(c10::BFloat16 a) {
  const float widened = static_cast<float>(a);
  return std::trunc(widened);
}
2825
// lgamma for c10::BFloat16: std::lgamma applied to the float conversion of
// the argument, with the result converted back on return.
inline c10::BFloat16 lgamma(c10::BFloat16 a) {
  const float widened = static_cast<float>(a);
  return std::lgamma(widened);
}
2926
// Square root for c10::BFloat16: std::sqrt applied to the float conversion
// of the argument, with the result converted back on return.
inline c10::BFloat16 sqrt(c10::BFloat16 a) {
  const float widened = static_cast<float>(a);
  return std::sqrt(widened);
}
3027
// Reciprocal square root for c10::BFloat16, computed on the float conversion
// of the argument. The numerator is a float literal so the division stays in
// single precision; a `1.0` (double) literal would promote the float sqrt
// result to double just to narrow it again on return.
inline c10::BFloat16 rsqrt(c10::BFloat16 a) {
  return 1.0f / std::sqrt(static_cast<float>(a));
}
@@ -36,6 +33,15 @@ inline c10::BFloat16 pow(c10::BFloat16 a, double b) { return std::pow(float(a),
3633
#else
3734
// pow with a BFloat16 base and a double exponent: the base is converted to
// float and passed with the unchanged double exponent to std::pow; the
// result converts back to BFloat16 on return.
inline c10::BFloat16 pow(c10::BFloat16 a, double b) {
  const float base = static_cast<float>(a);
  return std::pow(base, b);
}
3835
#endif
36+
// ceil/floor/trunc for c10::BFloat16, defined once per toolchain.
// When building with MSVC (_MSC_VER >= 1928) under nvcc (__CUDACC__ defined),
// the float-suffixed std::ceilf/std::floorf/std::truncf are called directly.
// NOTE(review): presumably a workaround for how the overloaded
// std::ceil/floor/trunc resolve on that compiler combination -- confirm.
#if defined(_MSC_VER) && _MSC_VER >= 1928 && defined(__CUDACC__)
37+
inline c10::BFloat16 ceil(c10::BFloat16 a) { return std::ceilf(float(a));}
38+
inline c10::BFloat16 floor(c10::BFloat16 a) { return std::floorf(float(a));}
39+
inline c10::BFloat16 trunc(c10::BFloat16 a) { return std::truncf(float(a));}
40+
#else
// Every other toolchain uses the overloaded std::ceil/std::floor/std::trunc
// on the float conversion of the argument.
41+
inline c10::BFloat16 ceil(c10::BFloat16 a) { return std::ceil(float(a));}
42+
inline c10::BFloat16 floor(c10::BFloat16 a) { return std::floor(float(a));}
43+
inline c10::BFloat16 trunc(c10::BFloat16 a) { return std::trunc(float(a));}
44+
#endif
3945
// pow with both operands as BFloat16: each is converted to float, std::pow
// is evaluated on the pair, and the result converts back on return.
inline c10::BFloat16 pow(c10::BFloat16 a, c10::BFloat16 b) {
  const float base = static_cast<float>(a);
  const float exponent = static_cast<float>(b);
  return std::pow(base, exponent);
}
4046
// Floating-point remainder for BFloat16 operands: both are converted to
// float, std::fmod is evaluated on the pair, and the result converts back
// on return.
inline c10::BFloat16 fmod(c10::BFloat16 a, c10::BFloat16 b) {
  const float dividend = static_cast<float>(a);
  const float divisor = static_cast<float>(b);
  return std::fmod(dividend, divisor);
}
4147

0 commit comments

Comments
 (0)