AVX512 exp/log
paulbkoch committed Sep 26, 2024
1 parent 1419d9c commit db714e3
Showing 2 changed files with 81 additions and 12 deletions.
92 changes: 81 additions & 11 deletions shared/libebm/compute/avx512f_ebm/avx512f_32.cpp
@@ -25,6 +25,7 @@
#include "Registration.hpp"
#include "Objective.hpp"

#include "math.hpp"
#include "approximate_math.hpp"
#include "compute_wrapper.hpp"

@@ -102,6 +103,10 @@ struct alignas(k_cAlignment) Avx512f_32_Int final {
return Avx512f_32_Int(_mm512_add_epi32(m_data, other.m_data));
}

inline Avx512f_32_Int operator-(const Avx512f_32_Int& other) const noexcept {
return Avx512f_32_Int(_mm512_sub_epi32(m_data, other.m_data));
}

inline Avx512f_32_Int operator*(const T& other) const noexcept {
return Avx512f_32_Int(_mm512_mullo_epi32(m_data, _mm512_set1_epi32(other)));
}
@@ -118,6 +123,16 @@ struct alignas(k_cAlignment) Avx512f_32_Int final {
return Avx512f_32_Int(_mm512_and_si512(m_data, other.m_data));
}

inline Avx512f_32_Int operator|(const Avx512f_32_Int& other) const noexcept {
return Avx512f_32_Int(_mm512_or_si512(m_data, other.m_data));
}

friend inline Avx512f_32_Int IfThenElse(
const __mmask16& cmp, const Avx512f_32_Int& trueVal, const Avx512f_32_Int& falseVal) noexcept {
return Avx512f_32_Int(_mm512_castps_si512(
_mm512_mask_blend_ps(cmp, _mm512_castsi512_ps(falseVal.m_data), _mm512_castsi512_ps(trueVal.m_data))));
}

friend inline Avx512f_32_Int PermuteForInterleaf(const Avx512f_32_Int& val) noexcept {
// this function permutes the values into positions that the Interleaf function expects
// but for any SIMD implementation the positions can be variable as long as they work together
@@ -137,7 +152,28 @@ struct alignas(k_cAlignment) Avx512f_32_Int final {
static_assert(std::is_standard_layout<Avx512f_32_Int>::value && std::is_trivially_copyable<Avx512f_32_Int>::value,
"This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed");

template<bool bNegateInput = false,
bool bNaNPossible = true,
bool bUnderflowPossible = true,
bool bOverflowPossible = true>
inline Avx512f_32_Float Exp(const Avx512f_32_Float& val) noexcept;
template<bool bNegateOutput = false,
bool bNaNPossible = true,
bool bNegativePossible = true,
bool bZeroPossible = true,
bool bPositiveInfinityPossible = true>
inline Avx512f_32_Float Log(const Avx512f_32_Float& val) noexcept;

struct alignas(k_cAlignment) Avx512f_32_Float final {
template<bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
friend Avx512f_32_Float Exp(const Avx512f_32_Float& val) noexcept;
template<bool bNegateOutput,
bool bNaNPossible,
bool bNegativePossible,
bool bZeroPossible,
bool bPositiveInfinityPossible>
friend Avx512f_32_Float Log(const Avx512f_32_Float& val) noexcept;

using T = float;
using TPack = __m512;
using TInt = Avx512f_32_Int;
@@ -155,6 +191,7 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
inline Avx512f_32_Float(const double val) noexcept : m_data(_mm512_set1_ps(static_cast<T>(val))) {}
inline Avx512f_32_Float(const float val) noexcept : m_data(_mm512_set1_ps(static_cast<T>(val))) {}
inline Avx512f_32_Float(const int val) noexcept : m_data(_mm512_set1_ps(static_cast<T>(val))) {}
explicit Avx512f_32_Float(const Avx512f_32_Int& val) : m_data(_mm512_cvtepi32_ps(val.m_data)) {}

inline Avx512f_32_Float operator+() const noexcept { return *this; }

@@ -231,6 +268,10 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
return Avx512f_32_Float(val) / other;
}

friend inline __mmask16 operator<=(const Avx512f_32_Float& left, const Avx512f_32_Float& right) noexcept {
return _mm512_cmp_ps_mask(left.m_data, right.m_data, _CMP_LE_OQ);
}

inline static Avx512f_32_Float Load(const T* const a) noexcept { return Avx512f_32_Float(_mm512_load_ps(a)); }

inline void Store(T* const a) const noexcept { _mm512_store_ps(a, m_data); }
@@ -545,6 +586,11 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
return Avx512f_32_Float(_mm512_mask_blend_ps(mask, falseVal.m_data, trueVal.m_data));
}

friend inline Avx512f_32_Float IfThenElse(
const __mmask16& cmp, const Avx512f_32_Float& trueVal, const Avx512f_32_Float& falseVal) noexcept {
return Avx512f_32_Float(_mm512_mask_blend_ps(cmp, falseVal.m_data, trueVal.m_data));
}

friend inline Avx512f_32_Float IfEqual(const Avx512f_32_Float& cmp1,
const Avx512f_32_Float& cmp2,
const Avx512f_32_Float& trueVal,
@@ -572,6 +618,20 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
return Avx512f_32_Float(_mm512_mask_blend_ps(mask, falseVal.m_data, trueVal.m_data));
}

static inline __mmask16 ReinterpretInt(const __mmask16& val) noexcept { return val; }

static inline Avx512f_32_Int ReinterpretInt(const Avx512f_32_Float& val) noexcept {
return Avx512f_32_Int(_mm512_castps_si512(val.m_data));
}

static inline Avx512f_32_Float ReinterpretFloat(const Avx512f_32_Int& val) noexcept {
return Avx512f_32_Float(_mm512_castsi512_ps(val.m_data));
}

friend inline Avx512f_32_Float Round(const Avx512f_32_Float& val) noexcept {
return Avx512f_32_Float(_mm512_roundscale_ps(val.m_data, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}

friend inline Avx512f_32_Float Abs(const Avx512f_32_Float& val) noexcept {
return Avx512f_32_Float(
_mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(val.m_data), _mm512_set1_epi32(0x7FFFFFFF))));
@@ -609,14 +669,6 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
return Avx512f_32_Float(_mm512_sqrt_ps(val.m_data));
}

friend inline Avx512f_32_Float Exp(const Avx512f_32_Float& val) noexcept {
return ApplyFunc([](T x) { return std::exp(x); }, val);
}

friend inline Avx512f_32_Float Log(const Avx512f_32_Float& val) noexcept {
return ApplyFunc([](T x) { return std::log(x); }, val);
}

template<bool bDisableApprox,
bool bNegateInput = false,
bool bNaNPossible = true,
@@ -627,7 +679,7 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
static inline Avx512f_32_Float ApproxExp(const Avx512f_32_Float& val,
const int32_t addExpSchraudolphTerm = k_expTermZeroMeanErrorForSoftmaxWithZeroedLogit) noexcept {
UNUSED(addExpSchraudolphTerm);
return Exp(bNegateInput ? -val : val);
return Exp<bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
}

template<bool bDisableApprox,
@@ -687,8 +739,7 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
static inline Avx512f_32_Float ApproxLog(
const Avx512f_32_Float& val, const float addLogSchraudolphTerm = k_logTermLowerBoundInputCloseToOne) noexcept {
UNUSED(addLogSchraudolphTerm);
Avx512f_32_Float ret = Log(val);
return bNegateOutput ? -ret : ret;
return Log<bNegateOutput, bNaNPossible, bNegativePossible, bZeroPossible, bPositiveInfinityPossible>(val);
}

template<bool bDisableApprox,
@@ -772,6 +823,25 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
static_assert(std::is_standard_layout<Avx512f_32_Float>::value && std::is_trivially_copyable<Avx512f_32_Float>::value,
"This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed");

template<bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
inline Avx512f_32_Float Exp(const Avx512f_32_Float& val) noexcept {
return Exp32<Avx512f_32_Float, bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
}

template<bool bNegateOutput,
bool bNaNPossible,
bool bNegativePossible,
bool bZeroPossible,
bool bPositiveInfinityPossible>
inline Avx512f_32_Float Log(const Avx512f_32_Float& val) noexcept {
return Log32<Avx512f_32_Float,
bNegateOutput,
bNaNPossible,
bNegativePossible,
bZeroPossible,
bPositiveInfinityPossible>(val);
}

INTERNAL_IMPORT_EXPORT_BODY ErrorEbm ApplyUpdate_Avx512f_32(
const ObjectiveWrapper* const pObjectiveWrapper, ApplyUpdateBridge* const pData) {
const Objective* const pObjective = static_cast<const Objective*>(pObjectiveWrapper->m_pObjective);
1 change: 0 additions & 1 deletion shared/libebm/compute/math.hpp
@@ -93,7 +93,6 @@ static INLINE_ALWAYS TFloat Exp32(const TFloat val) {

ret = (ret + TFloat{1}) * rounded2;

// TODO: handling overflow/underflow possible faster see vectormath version2 code
if(bOverflowPossible) {
if(bNegateInput) {
ret = IfLess(val,
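A minimal usage sketch of the new templated wrappers, assuming the Avx512f_32_Float type and the Exp/Log declarations added in avx512f_32.cpp above; the flag values and the pScores pointer are illustrative only, not taken from the commit:

// pScores: a hypothetical 64-byte-aligned array of 16 floats
Avx512f_32_Float score = Avx512f_32_Float::Load(pScores);

// exp(-score); the flags promise there are no NaN inputs but keep under/overflow handling enabled
Avx512f_32_Float p = Exp</*bNegateInput=*/true,
      /*bNaNPossible=*/false,
      /*bUnderflowPossible=*/true,
      /*bOverflowPossible=*/true>(score);

// -log(p); the flags promise the input is never negative
Avx512f_32_Float negLogP = Log</*bNegateOutput=*/true,
      /*bNaNPossible=*/false,
      /*bNegativePossible=*/false,
      /*bZeroPossible=*/true,
      /*bPositiveInfinityPossible=*/true>(p);

Both wrappers forward to Exp32 and Log32 in shared/libebm/compute/math.hpp, replacing the previous per-lane std::exp / std::log fallbacks that this commit removes.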
