Skip to content

Commit

Permalink
Merge pull request NVIDIA#266 from mani-ananth/master
Browse files Browse the repository at this point in the history
Fixes for public issue NVIDIA#265
  • Loading branch information
hwu36 authored May 19, 2021
2 parents b68113f + da2f110 commit 9cb7d63
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 68 deletions.
11 changes: 7 additions & 4 deletions include/cutlass/epilogue/thread/linear_combination_clamp.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,14 @@ class LinearCombinationClamp {
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X

/// Clamping constant value
ElementCompute const kClamp =
ElementCompute((1U << (sizeof_bits<ElementOutput>::value - 1)) - 1);
ElementCompute const kClampMax =
ElementCompute(platform::numeric_limits<ElementOutput>::max());

intermediate = max_accumulator(intermediate, -kClamp - ElementCompute(1));
intermediate = min_accumulator(intermediate, kClamp);
ElementCompute const kClampMin =
ElementCompute(platform::numeric_limits<ElementOutput>::lowest());

intermediate = max_accumulator(intermediate, kClampMin);
intermediate = min_accumulator(intermediate, kClampMax);

// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
Expand Down
9 changes: 5 additions & 4 deletions include/cutlass/half.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ enum
#include <cuda_fp16.h>

#include "cutlass/cutlass.h"
#include "cutlass/platform/platform.h"

///////////////////////////////////////////////////////////////////////////////////////////////////

Expand Down Expand Up @@ -545,9 +546,9 @@ half_t copysign(half_t const& a, half_t const& b) {
//
///////////////////////////////////////////////////////////////////////////////////////////////////

namespace std {
namespace cutlass {
namespace platform {

#if !defined(__CUDACC_RTC__)
/// Numeric limits
template <>
struct numeric_limits<cutlass::half_t> {
Expand Down Expand Up @@ -593,9 +594,9 @@ struct numeric_limits<cutlass::half_t> {
/// Returns smallest finite value
static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); }
};
#endif

} // namespace std
} // namespace platform
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////
//
Expand Down
21 changes: 21 additions & 0 deletions include/cutlass/integer_subbyte.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,25 @@ struct sizeof_bits<uint4b_t> {

///////////////////////////////////////////////////////////////////////////////////////////////////

namespace platform {

template <>
struct numeric_limits<cutlass::int4b_t> {
CUTLASS_HOST_DEVICE
static cutlass::int4b_t const lowest() noexcept { return -8;}
CUTLASS_HOST_DEVICE
static cutlass::int4b_t const max() noexcept { return 7;}
};

template <>
struct numeric_limits<cutlass::uint4b_t> {
CUTLASS_HOST_DEVICE
static cutlass::uint4b_t const lowest() noexcept { return 0;}
CUTLASS_HOST_DEVICE
static cutlass::uint4b_t const max() noexcept { return 15;}
};

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace platform
} // namespace cutlass
67 changes: 7 additions & 60 deletions include/cutlass/numeric_conversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -483,69 +483,16 @@ struct NumericConverterClamp {
using result_type = T;
using source_type = S;

static_assert((platform::is_same<result_type, int32_t>::value ||
platform::is_same<result_type, int8_t>::value ||
platform::is_same<result_type, cutlass::int4b_t>::value),
"Clamp is only needed for integer types");

CUTLASS_HOST_DEVICE
static result_type convert(source_type const & s) {
NumericConverter<result_type, source_type> convert_op;
result_type const kClamp_max =
(0x1U << (sizeof_bits<result_type>::value - 1)) - 1;
result_type const kClamp_min = -kClamp_max - 1;
bool is_int_min = !(s > kClamp_min);
bool is_int_max = !(s < kClamp_max);
return is_int_min ? kClamp_min : (is_int_max ? kClamp_max : convert_op(s));
}

CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
return convert(s);
}
};

/// Partial specialization for clamping from a single-precision float.
template <
typename T
>
struct NumericConverterClamp<T, float> {

using result_type = T;
using source_type = float;

static_assert((platform::is_same<result_type, int32_t>::value ||
platform::is_same<result_type, int16_t>::value ||
platform::is_same<result_type, uint16_t>::value ||
platform::is_same<result_type, int8_t>::value ||
platform::is_same<result_type, uint8_t>::value ||
platform::is_same<result_type, cutlass::int4b_t>::value ||
platform::is_same<result_type, cutlass::uint4b_t>::value),
"Clamp is only needed for integer types");

CUTLASS_HOST_DEVICE
static result_type convert(source_type const & s) {

NumericConverter<result_type, double> convert_op;
double kClamp_max, kClamp_min;

if (platform::is_same<result_type, int32_t>::value ||
platform::is_same<result_type, int16_t>::value ||
platform::is_same<result_type, int8_t>::value ||
platform::is_same<result_type, cutlass::int4b_t>::value) {
kClamp_max = double((1LLU << (sizeof_bits<result_type>::value - 1)) - 1);
kClamp_min = -kClamp_max - 1;
} else {
kClamp_max = double((1LLU << (sizeof_bits<result_type>::value)) - 1);
kClamp_min = 0;
}

double source = s;

source = fmax(source, kClamp_min);
source = fmin(source, kClamp_max);

return convert_op(source);
result_type const kClamp_max = platform::numeric_limits<result_type>::max();
result_type const kClamp_min = platform::numeric_limits<result_type>::lowest();
if (s < (source_type)kClamp_min)
return kClamp_min;
if (s > (source_type)kClamp_max)
return kClamp_max;
return convert_op(s);
}

CUTLASS_HOST_DEVICE
Expand Down
53 changes: 53 additions & 0 deletions include/cutlass/platform/platform.h
Original file line number Diff line number Diff line change
Expand Up @@ -783,5 +783,58 @@ void swap(unique_ptr<T, Deleter>& lhs, unique_ptr<T, Deleter>& rhs) noexcept {
}
#endif

/// std::numeric_limits
template <class T>
struct numeric_limits;

template <>
struct numeric_limits<int32_t> {
CUTLASS_HOST_DEVICE
static constexpr int32_t lowest() noexcept { return -2147483647 - 1;}
CUTLASS_HOST_DEVICE
static constexpr int32_t max() noexcept { return 2147483647;}
};

template <>
struct numeric_limits<int16_t> {
CUTLASS_HOST_DEVICE
static constexpr int16_t lowest() noexcept { return -32768;}
CUTLASS_HOST_DEVICE
static constexpr int16_t max() noexcept { return 32767;}
};

template <>
struct numeric_limits<int8_t> {
CUTLASS_HOST_DEVICE
static constexpr int8_t lowest() noexcept { return -128;}
CUTLASS_HOST_DEVICE
static constexpr int8_t max() noexcept { return 127;}
};


template <>
struct numeric_limits<uint32_t> {
CUTLASS_HOST_DEVICE
static constexpr uint32_t lowest() noexcept { return 0;}
CUTLASS_HOST_DEVICE
static constexpr uint32_t max() noexcept { return 4294967295;}
};

template <>
struct numeric_limits<uint16_t> {
CUTLASS_HOST_DEVICE
static constexpr uint16_t lowest() noexcept { return 0;}
CUTLASS_HOST_DEVICE
static constexpr uint16_t max() noexcept { return 65535;}
};

template <>
struct numeric_limits<uint8_t> {
CUTLASS_HOST_DEVICE
static constexpr uint8_t lowest() noexcept { return 0;}
CUTLASS_HOST_DEVICE
static constexpr uint8_t max() noexcept { return 255;}
};

} // namespace platform
} // namespace cutlass

0 comments on commit 9cb7d63

Please sign in to comment.