Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ endif()
# Find pybind11
find_package(pybind11 REQUIRED)

# Add FetchContent module
include(FetchContent)

# libdivide (header-only) via FetchContent
Expand Down
144 changes: 63 additions & 81 deletions simple_ans/cpp/simple_ans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
#include <cstdint>
#include <cstdio>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <vector>
#include <tsl/robin_map.h>

#include "libdivide.h"
#include <libdivide.h>
#include <tsl/robin_map.h>

namespace simple_ans
{
Expand All @@ -20,7 +21,7 @@ struct EncodedData
};

// Helper function to verify if a number is a power of 2
inline bool is_power_of_2(uint32_t x)
constexpr bool is_power_of_2(uint32_t x)
{
    // A power of two has exactly one bit set. Clearing the lowest set bit
    // (x & (x - 1)) must therefore leave nothing behind. Zero has no bit set
    // and is rejected explicitly.
    if (x == 0)
    {
        return false;
    }
    const uint32_t without_lowest_bit = x & (x - 1);
    return without_lowest_bit == 0;
}
Expand Down Expand Up @@ -62,9 +63,10 @@ constexpr uint64_t MASK_WORD = (1ULL << WORD_BITS) - 1;
template <typename T>
std::tuple<std::vector<T>, std::vector<uint64_t>> unique_with_counts(const T* values, size_t n)
{
// WARNING: This is ONLY a helper function. It doesn't support arrays with a large domain, and will instead fail
// and return empty vectors. It is up to the caller to handle this case separately. numpy.unique() is quite fast, with
// improvements to use vectorized sorts (in 2.x, at least), so I didn't bother to implement a more efficient version here.
// WARNING: This is ONLY a helper function. It doesn't support arrays with a large domain, and
// will instead fail and return empty vectors. It is up to the caller to handle this case
// separately. numpy.unique() is quite fast, with improvements to use vectorized sorts (in 2.x,
// at least), so I didn't bother to implement a more efficient version here.
std::vector<T> unique_values;
std::vector<uint64_t> counts;
if (!n)
Expand Down Expand Up @@ -113,30 +115,19 @@ EncodedData ans_encode_t(const T* signal,
"Value range of T must fit in int64_t for table lookup");

// Calculate L and verify it's a power of 2
uint32_t index_size = 0;
for (size_t i = 0; i < num_symbols; ++i)
{
index_size += symbol_counts[i];
}
const uint32_t index_size = std::accumulate(symbol_counts, symbol_counts + num_symbols, 0u);

if (!is_power_of_2(index_size))
{
throw std::invalid_argument("L must be a power of 2");
}

int PRECISION_BITS = 0;
while ((1U << PRECISION_BITS) < index_size)
{
PRECISION_BITS++;
}
const auto PRECISION_BITS = std::bit_width(index_size - 1);

// Pre-compute cumulative sums
std::vector<uint32_t> C(num_symbols);
C[0] = 0;
for (size_t i = 1; i < num_symbols; ++i)
{
C[i] = C[i - 1] + symbol_counts[i - 1];
}

std::partial_sum(symbol_counts, symbol_counts + num_symbols - 1, C.begin() + 1);
// Precompute libdivide dividers for each symbol count
std::vector<libdivide::divider<uint64_t>> fast_dividers(num_symbols);
for (size_t i = 0; i < num_symbols; ++i)
Expand All @@ -145,27 +136,24 @@ EncodedData ans_encode_t(const T* signal,
}

// Create symbol index lookup
tsl::robin_map<T, size_t> symbol_index_lookup;
tsl::robin_map<T, size_t> symbol_index_lookup{};
symbol_index_lookup.reserve(num_symbols);
int64_t min_symbol = symbol_values[0];
int64_t max_symbol = symbol_values[0];
auto min_symbol = static_cast<int64_t>(symbol_values[0]);
auto max_symbol = min_symbol;
for (size_t i = 0; i < num_symbols; ++i)
{
const auto symbol_val = static_cast<int64_t>(symbol_values[i]);
symbol_index_lookup[symbol_values[i]] = i;
min_symbol = std::min(min_symbol, static_cast<int64_t>(symbol_values[i]));
max_symbol = std::max(max_symbol, static_cast<int64_t>(symbol_values[i]));
min_symbol = std::min(min_symbol, symbol_val);
max_symbol = std::max(max_symbol, symbol_val);
}

// Map lookups can be a bottleneck, so we use a lookup array if the number of symbols is "small"
const bool use_lookup_array = (max_symbol - min_symbol + 1) <= lookup_array_threshold;
std::vector<size_t> symbol_index_lookup_array;
if (use_lookup_array)
{
symbol_index_lookup_array.resize(max_symbol - min_symbol + 1);

std::fill(symbol_index_lookup_array.begin(),
symbol_index_lookup_array.end(),
std::numeric_limits<size_t>::max());
const size_t array_size = max_symbol - min_symbol + 1;
symbol_index_lookup_array.assign(array_size, std::numeric_limits<size_t>::max());

for (size_t i = 0; i < num_symbols; ++i)
{
Expand All @@ -175,56 +163,65 @@ EncodedData ans_encode_t(const T* signal,

// Initialize state and words
uint64_t state = 0;
std::vector<uint32_t> words; // Use dynamic allocation instead of preallocating
words.reserve(signal_size / 8); // Reserve a reasonable estimate to avoid frequent reallocations
// Use dynamic allocation instead of preallocating
std::vector<uint32_t> words;
// Reserve a reasonable estimate to avoid frequent reallocations
words.reserve(signal_size / 8);

// Encode each symbol
auto SHIFT = STATE_BITS - PRECISION_BITS;

for (size_t i = 0; i < signal_size; ++i)
const auto encode_symbol = [&](size_t s_ind) constexpr noexcept
{
// Cache frequently accessed symbol data to avoid repeated array lookups
const uint32_t F_s = symbol_counts[s_ind];
const uint32_t C_s = C[s_ind];

// Check if we need to normalize
if ((state >> SHIFT) >= F_s)
{
words.emplace_back(state & MASK_WORD);
state >>= WORD_BITS;
}

// Update state using libdivide for faster division
const uint64_t prefix = state / fast_dividers[s_ind];
const uint64_t remainder = state - prefix * F_s;
state = (prefix << PRECISION_BITS) | (C_s + remainder);
};

if (use_lookup_array)
{
// Symbol index lookup
size_t s_ind;
if (use_lookup_array)
for (size_t i = 0; i < signal_size; ++i)
{
const int64_t lookup_ind = signal[i] - min_symbol;
if (lookup_ind < 0 || lookup_ind >= lookup_array_threshold)
if (lookup_ind < 0 || lookup_ind >= lookup_array_threshold) [[unlikely]]
{
throw std::invalid_argument("Signal value not found in symbol_values");
}
s_ind = symbol_index_lookup_array[lookup_ind];
if (s_ind == std::numeric_limits<size_t>::max())
const size_t s_ind = symbol_index_lookup_array[lookup_ind];
if (s_ind == std::numeric_limits<size_t>::max()) [[unlikely]]
{
throw std::invalid_argument("Signal value not found in symbol_values");
}
assert(s_ind == symbol_index_lookup[signal[i]]);

encode_symbol(s_ind);
}
else
}
else
{
for (size_t i = 0; i < signal_size; ++i)
{
auto it = symbol_index_lookup.find(signal[i]);
if (it == symbol_index_lookup.end())
const auto it = symbol_index_lookup.find(signal[i]);
if (it == symbol_index_lookup.end()) [[unlikely]]
{
throw std::invalid_argument("Signal value not found in symbol_values");
}
s_ind = it->second;
}

// Cache frequently accessed symbol data to avoid repeated array lookups
const uint32_t F_s = symbol_counts[s_ind];
const uint32_t C_s = C[s_ind];
const auto& divider = fast_dividers[s_ind];
size_t s_ind = it->second;

// Check if we need to normalize
if ((state >> SHIFT) >= F_s)
{
words.push_back(state & MASK_WORD);
state >>= WORD_BITS;
encode_symbol(s_ind);
}

// Update state using libdivide for faster division
const uint64_t prefix = state / divider;
const uint64_t remainder = state - prefix * F_s;
state = (prefix << PRECISION_BITS) | (C_s + remainder);
}

return {state, std::move(words)};
Expand All @@ -241,40 +238,25 @@ void ans_decode_t(T* output,
size_t num_symbols)
{
// very important that this is signed, because it becomes -1
int32_t word_idx = num_words - 1;
auto word_idx = static_cast<int32_t>(num_words) - 1;
// Calculate index size and verify it's a power of 2
uint32_t index_size = 0;
for (size_t i = 0; i < num_symbols; ++i)
{
index_size += symbol_counts[i];
}
const auto index_size = std::accumulate(symbol_counts, symbol_counts + num_symbols, 0u);
if (!is_power_of_2(index_size))
{
throw std::invalid_argument("L must be a power of 2");
}

int PRECISION_BITS = 0;
while ((1U << PRECISION_BITS) < index_size)
{
PRECISION_BITS++;
}

const auto PRECISION_BITS = std::bit_width(index_size - 1);
// Pre-compute cumulative sums
std::vector<uint32_t> C(num_symbols);
std::partial_sum(symbol_counts, symbol_counts + num_symbols - 1, C.begin() + 1);
C[0] = 0;
for (size_t i = 1; i < num_symbols; ++i)
{
C[i] = C[i - 1] + symbol_counts[i - 1];
}

// Create symbol lookup table
std::vector<uint32_t> symbol_lookup(index_size);
for (size_t s = 0; s < num_symbols; ++s)
{
for (uint32_t j = 0; j < symbol_counts[s]; ++j)
{
symbol_lookup[C[s] + j] = s;
}
std::fill_n(symbol_lookup.begin() + C[s], symbol_counts[s], s);
}

// Decode symbols in reverse order
Expand Down