Skip to content

Commit

Permalink
Add some missing _Float16 support (#8174)
Browse files Browse the repository at this point in the history
(Changes extracted from #8169, which may or may not land in its current form)

Some missing support for _Float16 that will likely be handy:
- Allow _Float16 to be detected for Clang 15 (since my local XCode Clang 15 definitely supports it)
- Expr(_Float16)
- HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(_Float16);
- Add _Float16 to the convert matrix in halide_image_io.h
  • Loading branch information
steven-johnson authored Apr 4, 2024
1 parent a4158c0 commit 3b8a532
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 1 deletion.
5 changes: 5 additions & 0 deletions src/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,11 @@ struct Expr : public Internal::IRHandle {
Expr(bfloat16_t x)
: IRHandle(Internal::FloatImm::make(BFloat(16), (double)x)) {
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
explicit Expr(_Float16 x)
: IRHandle(Internal::FloatImm::make(Float(16), (double)x)) {
}
#endif
Expr(float x)
: IRHandle(Internal::FloatImm::make(Float(32), x)) {
}
Expand Down
3 changes: 3 additions & 0 deletions src/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(Halide::float16_t);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(Halide::bfloat16_t);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(halide_task_t);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(halide_loop_task_t);
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(_Float16);
#endif
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(float);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(double);
HALIDE_DECLARE_EXTERN_STRUCT_TYPE(halide_buffer_t);
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/HalideRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ extern "C" {
// Ideally there would be a better way to detect if the type
// is supported, even in a compiler independent fashion, but
// coming up with one has proven elusive.
#if defined(__clang__) && (__clang_major__ >= 16) && !defined(__EMSCRIPTEN__) && !defined(__i386__)
#if defined(__clang__) && (__clang_major__ >= 15) && !defined(__EMSCRIPTEN__) && !defined(__i386__)
#if defined(__is_identifier)
#if !__is_identifier(_Float16)
#define HALIDE_CPP_COMPILER_HAS_FLOAT16
Expand Down
118 changes: 118 additions & 0 deletions tools/halide_image_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ template<>
inline bool convert(const int64_t &in) {
return in != 0;
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline bool convert(const _Float16 &in) {
return (float)in != 0;
}
#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline bool convert(const float &in) {
return in != 0;
Expand Down Expand Up @@ -165,6 +171,12 @@ template<>
inline uint8_t convert(const int64_t &in) {
return convert<uint8_t, uint64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline uint8_t convert(const _Float16 &in) {
return (uint8_t)std::lround((float)in * 255.0f);
}
#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline uint8_t convert(const float &in) {
return (uint8_t)std::lround(in * 255.0f);
Expand Down Expand Up @@ -211,6 +223,12 @@ template<>
inline uint16_t convert(const int64_t &in) {
return convert<uint16_t, uint64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline uint16_t convert(const _Float16 &in) {
return (uint16_t)std::lround((float)in * 65535.0f);
}
#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline uint16_t convert(const float &in) {
return (uint16_t)std::lround(in * 65535.0f);
Expand Down Expand Up @@ -257,6 +275,12 @@ template<>
inline uint32_t convert(const int64_t &in) {
return convert<uint32_t, uint64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline uint32_t convert(const _Float16 &in) {
return (uint32_t)std::llround((float)in * 4294967295.0);
}
#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline uint32_t convert(const float &in) {
return (uint32_t)std::llround(in * 4294967295.0);
Expand Down Expand Up @@ -303,6 +327,12 @@ template<>
inline uint64_t convert(const int64_t &in) {
return convert<uint64_t, uint64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline uint64_t convert(const _Float16 &in) {
return convert<uint64_t, uint32_t>((uint32_t)std::llround((float)in * 4294967295.0));
}
#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline uint64_t convert(const float &in) {
return convert<uint64_t, uint32_t>((uint32_t)std::llround(in * 4294967295.0));
Expand Down Expand Up @@ -349,6 +379,12 @@ template<>
inline int8_t convert(const int64_t &in) {
return convert<uint8_t, int64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline int8_t convert(const _Float16 &in) {
return convert<uint8_t, float>((float)in);
}
#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline int8_t convert(const float &in) {
return convert<uint8_t, float>(in);
Expand Down Expand Up @@ -395,6 +431,12 @@ template<>
inline int16_t convert(const int64_t &in) {
return convert<uint16_t, int64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline int16_t convert(const _Float16 &in) {
return convert<uint16_t, float>((float)in);
}
#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline int16_t convert(const float &in) {
return convert<uint16_t, float>(in);
Expand Down Expand Up @@ -441,6 +483,12 @@ template<>
inline int32_t convert(const int64_t &in) {
return convert<uint32_t, int64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline int32_t convert(const _Float16 &in) {
return convert<uint32_t, float>((float)in);
}
#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline int32_t convert(const float &in) {
return convert<uint32_t, float>(in);
Expand Down Expand Up @@ -487,6 +535,12 @@ template<>
inline int64_t convert(const int64_t &in) {
return convert<uint64_t, int64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline int64_t convert(const _Float16 &in) {
return convert<uint64_t, float>((float)in);
}
#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline int64_t convert(const float &in) {
return convert<uint64_t, float>(in);
Expand All @@ -496,6 +550,58 @@ inline int64_t convert(const double &in) {
return convert<uint64_t, double>(in);
}

#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
// Convert to f16
template<>
inline _Float16 convert(const bool &in) {
return in;
}
template<>
inline _Float16 convert(const uint8_t &in) {
return (_Float16)(in / 255.0f);
}
template<>
inline _Float16 convert(const uint16_t &in) {
return (_Float16)(in / 65535.0f);
}
template<>
inline _Float16 convert(const uint32_t &in) {
return (_Float16)(in / 4294967295.0);
}
template<>
inline _Float16 convert(const uint64_t &in) {
return convert<_Float16, uint32_t>(uint32_t(in >> 32));
}
template<>
inline _Float16 convert(const int8_t &in) {
return convert<_Float16, uint8_t>(in);
}
template<>
inline _Float16 convert(const int16_t &in) {
return convert<_Float16, uint16_t>(in);
}
template<>
inline _Float16 convert(const int32_t &in) {
return convert<_Float16, uint64_t>(in);
}
template<>
inline _Float16 convert(const int64_t &in) {
return convert<_Float16, uint64_t>(in);
}
template<>
inline _Float16 convert(const _Float16 &in) {
return in;
}
template<>
inline _Float16 convert(const float &in) {
return (_Float16)in;
}
template<>
inline _Float16 convert(const double &in) {
return (_Float16)in;
}
#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16

// Convert to f32
template<>
inline float convert(const bool &in) {
Expand Down Expand Up @@ -533,6 +639,12 @@ template<>
inline float convert(const int64_t &in) {
return convert<float, uint64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline float convert(const _Float16 &in) {
return (float)in;
}
#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline float convert(const float &in) {
return in;
Expand Down Expand Up @@ -579,6 +691,12 @@ template<>
inline double convert(const int64_t &in) {
return convert<double, uint64_t>(in);
}
#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline double convert(const _Float16 &in) {
return (double)in;
}
#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16
template<>
inline double convert(const float &in) {
return (double)in;
Expand Down

0 comments on commit 3b8a532

Please sign in to comment.