@@ -17,14 +17,11 @@ inline c10::BFloat16 log(c10::BFloat16 a) { return std::log(float(a));}
17
17
inline c10::BFloat16 log10 (c10::BFloat16 a) { return std::log10 (float (a));}
18
18
inline c10::BFloat16 log1p (c10::BFloat16 a) { return std::log1p (float (a));}
19
19
inline c10::BFloat16 log2 (c10::BFloat16 a) { return std::log2 (float (a));}
20
- inline c10::BFloat16 ceil (c10::BFloat16 a) { return std::ceil (float (a));}
21
20
inline c10::BFloat16 cos (c10::BFloat16 a) { return std::cos (float (a));}
22
- inline c10::BFloat16 floor (c10::BFloat16 a) { return std::floor (float (a));}
23
21
inline c10::BFloat16 nearbyint (c10::BFloat16 a) { return std::nearbyint (float (a));}
24
22
inline c10::BFloat16 sin (c10::BFloat16 a) { return std::sin (float (a));}
25
23
inline c10::BFloat16 tan (c10::BFloat16 a) { return std::tan (float (a));}
26
24
inline c10::BFloat16 tanh (c10::BFloat16 a) { return std::tanh (float (a));}
27
- inline c10::BFloat16 trunc (c10::BFloat16 a) { return std::trunc (float (a));}
28
25
inline c10::BFloat16 lgamma (c10::BFloat16 a) { return std::lgamma (float (a));}
29
26
inline c10::BFloat16 sqrt (c10::BFloat16 a) { return std::sqrt (float (a));}
30
27
inline c10::BFloat16 rsqrt (c10::BFloat16 a) { return 1.0 / std::sqrt (float (a));}
@@ -36,6 +33,15 @@ inline c10::BFloat16 pow(c10::BFloat16 a, double b) { return std::pow(float(a),
36
33
#else
37
34
inline c10::BFloat16 pow (c10::BFloat16 a, double b) { return std::pow (float (a), b);}
38
35
#endif
36
+ #if defined(_MSC_VER) && _MSC_VER >= 1928 && defined(__CUDACC__)
37
+ inline c10::BFloat16 ceil (c10::BFloat16 a) { return std::ceilf (float (a));}
38
+ inline c10::BFloat16 floor (c10::BFloat16 a) { return std::floorf (float (a));}
39
+ inline c10::BFloat16 trunc (c10::BFloat16 a) { return std::truncf (float (a));}
40
+ #else
41
+ inline c10::BFloat16 ceil (c10::BFloat16 a) { return std::ceil (float (a));}
42
+ inline c10::BFloat16 floor (c10::BFloat16 a) { return std::floor (float (a));}
43
+ inline c10::BFloat16 trunc (c10::BFloat16 a) { return std::trunc (float (a));}
44
+ #endif
39
45
inline c10::BFloat16 pow (c10::BFloat16 a, c10::BFloat16 b) { return std::pow (float (a), float (b));}
40
46
inline c10::BFloat16 fmod (c10::BFloat16 a, c10::BFloat16 b) { return std::fmod (float (a), float (b));}
41
47
0 commit comments