Skip to content

Commit

Permalink
[ntuple] make Real32Quant throws on out of range in Pack/Unpack
Browse files Browse the repository at this point in the history
  • Loading branch information
silverweed committed Sep 24, 2024
1 parent 095babf commit 984bc4a
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 10 deletions.
46 changes: 36 additions & 10 deletions tree/ntuple/v7/src/RColumnElement.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -773,8 +773,9 @@ using Quantized_t = std::uint32_t;
/// The quantized representation will consist of unsigned integers of at most `nQuantBits` (with `nQuantBits <= 8 *
/// sizeof(Quantized_t)`). The unused bits are kept in the LSB of the quantized integers, to allow for easy bit packing
/// of those integers via BitPacking::PackBits().
/// \return The number of values in `src` that were found to be out of range (0 means all values were in range).
template <typename T>
void QuantizeReals(Quantized_t *dst, const T *src, std::size_t count, double min, double max, std::size_t nQuantBits)
int QuantizeReals(Quantized_t *dst, const T *src, std::size_t count, double min, double max, std::size_t nQuantBits)
{
static_assert(std::is_floating_point_v<T>);
static_assert(sizeof(T) <= sizeof(double));
Expand All @@ -784,9 +785,13 @@ void QuantizeReals(Quantized_t *dst, const T *src, std::size_t count, double min
const double scale = quantMax / (max - min);
const std::size_t unusedBits = sizeof(Quantized_t) * 8 - nQuantBits;

int nOutOfRange = 0;

for (std::size_t i = 0; i < count; ++i) {
T elem = src[i];
assert(min <= elem && elem <= max);

nOutOfRange += !(min <= elem && elem <= max);

double e = (elem - min) * scale;
Quantized_t q = static_cast<Quantized_t>(e + 0.5);
ByteSwapIfNecessary(q);
Expand All @@ -798,11 +803,14 @@ void QuantizeReals(Quantized_t *dst, const T *src, std::size_t count, double min
// when bit packing.
dst[i] = q << unusedBits;
}

return nOutOfRange;
}

/// Undoes the transformation performed by QuantizeReals() (assuming the same `count`, `min`, `max` and `nQuantBits`).
/// \return The number of unpacked values that were found to be out of range (0 means all values were in range).
template <typename T>
void UnquantizeReals(T *dst, const Quantized_t *src, std::size_t count, double min, double max, std::size_t nQuantBits)
int UnquantizeReals(T *dst, const Quantized_t *src, std::size_t count, double min, double max, std::size_t nQuantBits)
{
static_assert(std::is_floating_point_v<T>);
static_assert(sizeof(T) <= sizeof(double));
Expand All @@ -813,6 +821,8 @@ void UnquantizeReals(T *dst, const Quantized_t *src, std::size_t count, double m
const double bias = min * quantMax / (max - min);
const std::size_t unusedBits = sizeof(Quantized_t) * 8 - nQuantBits;

int nOutOfRange = 0;

for (std::size_t i = 0; i < count; ++i) {
Quantized_t elem = src[i];
// Undo the LSB-preserving shift performed by QuantizeReals
Expand All @@ -823,8 +833,11 @@ void UnquantizeReals(T *dst, const Quantized_t *src, std::size_t count, double m
double fq = static_cast<double>(elem);
double e = (fq + bias) * scale;
dst[i] = static_cast<T>(e);
assert(min <= dst[i] && dst[i] <= max);

nOutOfRange += !(min <= dst[i] && dst[i] <= max);
}

return nOutOfRange;
}
} // namespace Quantize

Expand Down Expand Up @@ -856,24 +869,37 @@ public:

void Pack(void *dst, const void *src, std::size_t count) const final
{
using namespace ROOT::Experimental;

// TODO(gparolini): see if we can avoid this allocation
auto quantized = std::make_unique<Quantize::Quantized_t[]>(count);
assert(fValueRange);
const auto [min, max] = *fValueRange;
Quantize::QuantizeReals(quantized.get(), reinterpret_cast<const float *>(src), count, min, max, fBitsOnStorage);
ROOT::Experimental::Internal::BitPacking::PackBits(dst, quantized.get(), count, sizeof(Quantize::Quantized_t),
fBitsOnStorage);
const int nOutOfRange = Quantize::QuantizeReals(quantized.get(), reinterpret_cast<const float *>(src), count, min,
max, fBitsOnStorage);
if (nOutOfRange) {
throw RException(R__FAIL(std::to_string(nOutOfRange) +
" values were found of of range for quantization while packing (range is [" +
std::to_string(min) + ", " + std::to_string(max) + "])"));
}
Internal::BitPacking::PackBits(dst, quantized.get(), count, sizeof(Quantize::Quantized_t), fBitsOnStorage);
}

void Unpack(void *dst, const void *src, std::size_t count) const final
{
using namespace ROOT::Experimental;

// TODO(gparolini): see if we can avoid this allocation
auto quantized = std::make_unique<Quantize::Quantized_t[]>(count);
assert(fValueRange);
const auto [min, max] = *fValueRange;
ROOT::Experimental::Internal::BitPacking::UnpackBits(quantized.get(), src, count, sizeof(Quantize::Quantized_t),
fBitsOnStorage);
Quantize::UnquantizeReals(reinterpret_cast<float *>(dst), quantized.get(), count, min, max, fBitsOnStorage);
Internal::BitPacking::UnpackBits(quantized.get(), src, count, sizeof(Quantize::Quantized_t), fBitsOnStorage);
[[maybe_unused]] const int nOutOfRange =
Quantize::UnquantizeReals(reinterpret_cast<float *>(dst), quantized.get(), count, min, max, fBitsOnStorage);
// NOTE: here, differently from Pack(), we don't ever expect to have values out of range, since the quantized
// integers we pass to UnquantizeReals are by construction limited in value to the proper range. In Pack()
// this is not the case, as the user may give us float values that are out of range.
assert(nOutOfRange == 0);
}
};

Expand Down
18 changes: 18 additions & 0 deletions tree/ntuple/v7/test/ntuple_packing.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,24 @@ TEST(Packing, Real32Quant)
EXPECT_NEAR(fout[i], f[i], 0.01f);
}

{
RColumnElement<float, EColumnType::kReal32Quant> element;
element.SetBitsOnStorage(20);
element.SetValueRange(-10.f, 10.f);

float f[5] = { 3.4f, 5.f, -6.f, 10.f, -10.f };
unsigned char out[BitPacking::MinBufSize(std::size(f), 20)];
element.Pack(out, f, std::size(f));
float f2[std::size(f)];
element.Unpack(&f2, out, std::size(f));
for (size_t i = 0; i < std::size(f); ++i)
EXPECT_NEAR(f[i], f2[i], 0.01f);

f[3] = 11.f;
// should throw out of range
EXPECT_THROW(element.Pack(out, f, std::size(f)), RException);
}

{
constexpr auto kBitsOnStorage = 1;
RColumnElement<float, EColumnType::kReal32Quant> element;
Expand Down

0 comments on commit 984bc4a

Please sign in to comment.