SIMD exp/log
paulbkoch committed Sep 26, 2024
1 parent 25e958f commit ed7eec6
Showing 4 changed files with 302 additions and 11 deletions.
101 changes: 90 additions & 11 deletions shared/libebm/compute/avx2_ebm/avx2_32.cpp
@@ -25,6 +25,7 @@
#include "Registration.hpp"
#include "Objective.hpp"

#include "math.hpp"
#include "approximate_math.hpp"
#include "compute_wrapper.hpp"

@@ -94,6 +95,10 @@ struct alignas(k_cAlignment) Avx2_32_Int final {
return Avx2_32_Int(_mm256_add_epi32(m_data, other.m_data));
}

inline Avx2_32_Int operator-(const Avx2_32_Int& other) const noexcept {
return Avx2_32_Int(_mm256_sub_epi32(m_data, other.m_data));
}

inline Avx2_32_Int operator*(const T& other) const noexcept {
return Avx2_32_Int(_mm256_mullo_epi32(m_data, _mm256_set1_epi32(other)));
}
@@ -106,6 +111,15 @@ struct alignas(k_cAlignment) Avx2_32_Int final {
return Avx2_32_Int(_mm256_and_si256(m_data, other.m_data));
}

inline Avx2_32_Int operator|(const Avx2_32_Int& other) const noexcept {
return Avx2_32_Int(_mm256_or_si256(m_data, other.m_data));
}

friend inline Avx2_32_Int IfThenElse(
const Avx2_32_Int& cmp, const Avx2_32_Int& trueVal, const Avx2_32_Int& falseVal) noexcept {
return Avx2_32_Int(_mm256_blendv_epi8(falseVal.m_data, trueVal.m_data, cmp.m_data));
}

friend inline Avx2_32_Int PermuteForInterleaf(const Avx2_32_Int& val) noexcept {
// this function permutes the values into positions that the Interleaf function expects
// but for any SIMD implementation the positions can be variable as long as they work together
@@ -124,7 +138,28 @@ struct alignas(k_cAlignment) Avx2_32_Int final {
static_assert(std::is_standard_layout<Avx2_32_Int>::value && std::is_trivially_copyable<Avx2_32_Int>::value,
"This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed");

template<bool bNegateInput = false,
bool bNaNPossible = true,
bool bUnderflowPossible = true,
bool bOverflowPossible = true>
inline Avx2_32_Float Exp(const Avx2_32_Float& val) noexcept;
template<bool bNegateOutput = false,
bool bNaNPossible = true,
bool bNegativePossible = true,
bool bZeroPossible = true,
bool bPositiveInfinityPossible = true>
inline Avx2_32_Float Log(const Avx2_32_Float& val) noexcept;

struct alignas(k_cAlignment) Avx2_32_Float final {
template<bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
friend Avx2_32_Float Exp(const Avx2_32_Float& val) noexcept;
template<bool bNegateOutput,
bool bNaNPossible,
bool bNegativePossible,
bool bZeroPossible,
bool bPositiveInfinityPossible>
friend Avx2_32_Float Log(const Avx2_32_Float& val) noexcept;

using T = float;
using TPack = __m256;
using TInt = Avx2_32_Int;
@@ -142,6 +177,7 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
inline Avx2_32_Float(const double val) noexcept : m_data(_mm256_set1_ps(static_cast<T>(val))) {}
inline Avx2_32_Float(const float val) noexcept : m_data(_mm256_set1_ps(static_cast<T>(val))) {}
inline Avx2_32_Float(const int val) noexcept : m_data(_mm256_set1_ps(static_cast<T>(val))) {}
explicit Avx2_32_Float(const Avx2_32_Int& val) : m_data(_mm256_cvtepi32_ps(val.m_data)) {}

inline Avx2_32_Float operator+() const noexcept { return *this; }

@@ -150,6 +186,10 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
_mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(m_data), _mm256_set1_epi32(0x80000000))));
}

inline Avx2_32_Float operator~() const noexcept {
return Avx2_32_Float(_mm256_xor_ps(m_data, _mm256_castsi256_ps(_mm256_set1_epi32(-1))));
}

inline Avx2_32_Float operator+(const Avx2_32_Float& other) const noexcept {
return Avx2_32_Float(_mm256_add_ps(m_data, other.m_data));
}
@@ -218,6 +258,10 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
return Avx2_32_Float(val) / other;
}

friend inline Avx2_32_Float operator<=(const Avx2_32_Float& left, const Avx2_32_Float& right) noexcept {
return Avx2_32_Float(_mm256_cmp_ps(left.m_data, right.m_data, _CMP_LE_OQ));
}

inline static Avx2_32_Float Load(const T* const a) noexcept { return Avx2_32_Float(_mm256_load_ps(a)); }

inline void Store(T* const a) const noexcept { _mm256_store_ps(a, m_data); }
@@ -484,6 +528,11 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
return Avx2_32_Float(_mm256_blendv_ps(falseVal.m_data, trueVal.m_data, mask));
}

friend inline Avx2_32_Float IfThenElse(
const Avx2_32_Float& cmp, const Avx2_32_Float& trueVal, const Avx2_32_Float& falseVal) noexcept {
return Avx2_32_Float(_mm256_blendv_ps(falseVal.m_data, trueVal.m_data, cmp.m_data));
}

friend inline Avx2_32_Float IfEqual(const Avx2_32_Float& cmp1,
const Avx2_32_Float& cmp2,
const Avx2_32_Float& trueVal,
@@ -511,6 +560,26 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
return Avx2_32_Float(_mm256_blendv_ps(falseVal.m_data, trueVal.m_data, _mm256_castsi256_ps(mask)));
}

static inline Avx2_32_Int ReinterpretInt(const Avx2_32_Float& val) noexcept {
return Avx2_32_Int(_mm256_castps_si256(val.m_data));
}

static inline Avx2_32_Float ReinterpretFloat(const Avx2_32_Int& val) noexcept {
return Avx2_32_Float(_mm256_castsi256_ps(val.m_data));
}

friend inline Avx2_32_Float Round(const Avx2_32_Float& val) noexcept {
return Avx2_32_Float(_mm256_round_ps(val.m_data, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}

friend inline Avx2_32_Float Mantissa(const Avx2_32_Float& val) noexcept {
return ReinterpretFloat((ReinterpretInt(val) & 0x007FFFFF) | 0x3F000000);
}

friend inline Avx2_32_Int Exponent(const Avx2_32_Float& val) noexcept {
return ((ReinterpretInt(val) << 1) >> 24) - Avx2_32_Int(0x7F);
}

friend inline Avx2_32_Float Abs(const Avx2_32_Float& val) noexcept {
return Avx2_32_Float(_mm256_and_ps(val.m_data, _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF))));
}
@@ -553,14 +622,6 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
return Avx2_32_Float(_mm256_sqrt_ps(val.m_data));
}

friend inline Avx2_32_Float Exp(const Avx2_32_Float& val) noexcept {
return ApplyFunc([](T x) { return std::exp(x); }, val);
}

friend inline Avx2_32_Float Log(const Avx2_32_Float& val) noexcept {
return ApplyFunc([](T x) { return std::log(x); }, val);
}

template<bool bDisableApprox,
bool bNegateInput = false,
bool bNaNPossible = true,
@@ -571,7 +632,7 @@
static inline Avx2_32_Float ApproxExp(const Avx2_32_Float& val,
const int32_t addExpSchraudolphTerm = k_expTermZeroMeanErrorForSoftmaxWithZeroedLogit) noexcept {
UNUSED(addExpSchraudolphTerm);
return Exp(bNegateInput ? -val : val);
return Exp<bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
}

template<bool bDisableApprox,
@@ -631,8 +692,7 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
static inline Avx2_32_Float ApproxLog(
const Avx2_32_Float& val, const float addLogSchraudolphTerm = k_logTermLowerBoundInputCloseToOne) noexcept {
UNUSED(addLogSchraudolphTerm);
Avx2_32_Float ret = Log(val);
return bNegateOutput ? -ret : ret;
return Log<bNegateOutput, bNaNPossible, bNegativePossible, bZeroPossible, bPositiveInfinityPossible>(val);
}

template<bool bDisableApprox,
@@ -723,6 +783,25 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
static_assert(std::is_standard_layout<Avx2_32_Float>::value && std::is_trivially_copyable<Avx2_32_Float>::value,
"This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed");

template<bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
inline Avx2_32_Float Exp(const Avx2_32_Float& val) noexcept {
return Exp32<Avx2_32_Float, bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
}

template<bool bNegateOutput,
bool bNaNPossible,
bool bNegativePossible,
bool bZeroPossible,
bool bPositiveInfinityPossible>
inline Avx2_32_Float Log(const Avx2_32_Float& val) noexcept {
return Log32<Avx2_32_Float,
bNegateOutput,
bNaNPossible,
bNegativePossible,
bZeroPossible,
bPositiveInfinityPossible>(val);
}

INTERNAL_IMPORT_EXPORT_BODY ErrorEbm ApplyUpdate_Avx2_32(
const ObjectiveWrapper* const pObjectiveWrapper, ApplyUpdateBridge* const pData) {
const Objective* const pObjective = static_cast<const Objective*>(pObjectiveWrapper->m_pObjective);
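
Note on the new Mantissa/Exponent helpers: they split an IEEE-754 float into a fraction in [0.5, 1.0) and an unbiased exponent, the usual starting point for a vectorized log (presumably what Log32 in the new math.hpp builds on; that file is not shown in this diff). A minimal scalar sketch of the same bit manipulation, with hypothetical MantissaScalar/ExponentScalar helpers standing in for the intrinsics:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical scalar analogues of the Mantissa/Exponent members above,
// using memcpy for the float<->int reinterpret instead of _mm256_cast*.
static float MantissaScalar(const float x) noexcept {
   uint32_t bits;
   std::memcpy(&bits, &x, sizeof(bits));      // ReinterpretInt
   bits = (bits & 0x007FFFFFu) | 0x3F000000u; // keep the mantissa bits, force the exponent field of 0.5f
   float m;
   std::memcpy(&m, &bits, sizeof(m));         // ReinterpretFloat
   return m;                                  // in [0.5, 1.0) for positive normal inputs
}

static int32_t ExponentScalar(const float x) noexcept {
   uint32_t bits;
   std::memcpy(&bits, &x, sizeof(bits));
   return static_cast<int32_t>((bits << 1) >> 24) - 0x7F; // shift out the sign bit, then remove the bias
}

int main() {
   const float x = 6.5f;
   const float m = MantissaScalar(x);   // 0.8125
   const int32_t e = ExponentScalar(x); // 2
   // For positive normal x:  x == m * 2^(e + 1), so log(x) == log(m) + (e + 1) * ln(2),
   // and only log(m) on [0.5, 1.0) needs a polynomial approximation.
   std::printf("%g == %g * 2^%d\n", x, m, e + 1);
   std::printf("%g ~= %g\n", std::log(x), std::log(m) + static_cast<float>(e + 1) * std::log(2.0f));
   return 0;
}

The bNaNPossible/bNegativePossible/bZeroPossible/bPositiveInfinityPossible template flags on Log presumably let a caller skip the corresponding edge-case handling when the objective already guarantees the input range.
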
3 changes: 3 additions & 0 deletions shared/libebm/compute/cpu_ebm/cpu_64.cpp
@@ -22,6 +22,7 @@
#include "Registration.hpp"
#include "Objective.hpp"

#include "math.hpp"
#include "approximate_math.hpp"
#include "compute_wrapper.hpp"

@@ -227,6 +228,8 @@ struct Cpu_64_Float final {

friend inline Cpu_64_Float Abs(const Cpu_64_Float& val) noexcept { return Cpu_64_Float(std::abs(val.m_data)); }

friend inline Cpu_64_Float Round(const Cpu_64_Float& val) noexcept { return Cpu_64_Float(std::round(val.m_data)); }

friend inline Cpu_64_Float FastApproxReciprocal(const Cpu_64_Float& val) noexcept {
return Cpu_64_Float(T{1.0} / val.m_data);
}
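
Note on the new Round member in the CPU backend: rounding to the nearest integer is the range-reduction step a vectorized exp needs, writing x = k*ln(2) + r with k = round(x / ln(2)) so that only e^r on a small interval has to be approximated (presumably how Exp32 in math.hpp uses it; that file is not part of this diff). A minimal scalar sketch under that assumption, with a hypothetical ExpSketch helper and a low-order Taylor polynomial standing in for whatever polynomial the real code uses:

#include <cmath>
#include <cstdio>

// Hypothetical scalar analogue of a Round-based vector exp:
// x = k*ln(2) + r, evaluate e^r on |r| <= ln(2)/2, then scale by 2^k.
static double ExpSketch(const double x) noexcept {
   const double ln2 = 0.6931471805599453;
   const double k = std::round(x / ln2); // the role Round plays above
   const double r = x - k * ln2;         // |r| <= ln(2)/2 ~= 0.3466
   // truncated Taylor series for e^r; a production version would use a tuned minimax polynomial
   const double p = 1.0 +
         r * (1.0 + r * (1.0 / 2 + r * (1.0 / 6 + r * (1.0 / 24 + r * (1.0 / 120 + r * (1.0 / 720))))));
   return std::ldexp(p, static_cast<int>(k)); // p * 2^k
}

int main() {
   for (const double x : {-3.0, -0.5, 0.0, 1.0, 4.25}) {
      std::printf("x = %5.2f   sketch = %.6f   std::exp = %.6f\n", x, ExpSketch(x), std::exp(x));
   }
   return 0;
}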