Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bridging for clang _Float16 type. #7201

Merged
merged 6 commits into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/Float16.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,19 @@
#include <cstdint>
#include <string>

// clang had _Float16 added as a reserved name in clang 8, but
// doesn't actually support it on most platforms until clang 15.
// Ideally there would be a better way to detect if the type
// is supported, even in a compiler independent fashion, but
// coming up with one has proven elusive.
#if !defined(__clang__) || (__clang_major__ >= 15)
#if defined(__is_identifier)
#if !__is_identifier(_Float16)
#define HALIDE_CPP_COMPILER_HAS_FLOAT16
#endif
#endif
#endif

namespace Halide {

/** Class that provides a type that implements half precision
Expand Down Expand Up @@ -38,6 +51,13 @@ struct float16_t {
* positive zero.*/
float16_t() = default;

#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
/** Construct a float16_t from compiler's built-in _Float16 type. */
explicit float16_t(_Float16 value) {
data = *(uint16_t *)&value;
}
#endif

/// @}

// Use explicit to avoid accidently raising the precision
Expand All @@ -48,6 +68,13 @@ struct float16_t {
/** Cast to int */
explicit operator int() const;

#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
/** Cast to compiler's built-in _Float16 type. */
explicit operator _Float16() const {
return *(const _Float16 *)&data;
}
#endif

/** Get a new float16_t that represents a special value */
// @{
static float16_t make_zero();
Expand Down
11 changes: 11 additions & 0 deletions test/correctness/float16_t.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,17 @@ int main(int argc, char **argv) {
}
}

#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16
{
float16_t f(1.0f16);
_Float16 f2 = (_Float16)f;
if (f2 != 1.0f16) {
printf("Roundtrip of 16-bit float via _Float16 failed.\n");
return -1;
}
}
#endif

printf("Success!\n");
return 0;
}