11#pragma once
22
3+ #include < algorithm>
4+ #include < bit>
35#include < cassert>
4- #include < cstdint>
5- #include < cstdio>
6+ #include < execution>
67#include < limits>
8+ #include < numeric>
79#include < stdexcept>
810#include < vector>
11+
912#include < tsl/robin_map.h>
1013
11- #include " libdivide.h"
14+ #include < libdivide.h>
1215
1316namespace simple_ans
1417{
@@ -20,7 +23,7 @@ struct EncodedData
2023};
2124
// Helper function to verify if a number is a power of 2.
//
// x: value to test.
// Returns true when exactly one bit of x is set, and false otherwise
// (including for x == 0, which is not a power of 2).
//
// constexpr so the check can also be evaluated at compile time; constexpr
// functions are implicitly inline, so defining this in a header is ODR-safe.
constexpr bool is_power_of_2(uint32_t x) noexcept
{
    // x & (x - 1) clears the lowest set bit; the result is zero only when
    // that was the sole bit set. The x != 0 guard rejects zero explicitly.
    return x != 0 && (x & (x - 1)) == 0;
}
@@ -113,29 +116,17 @@ EncodedData ans_encode_t(const T* signal,
113116 " Value range of T must fit in int64_t for table lookup" );
114117
115118 // Calculate L and verify it's a power of 2
116- uint32_t index_size = 0 ;
117- for (size_t i = 0 ; i < num_symbols; ++i)
118- {
119- index_size += symbol_counts[i];
120- }
119+ const uint32_t index_size = std::reduce (symbol_counts, symbol_counts + num_symbols, 0u );
121120 if (!is_power_of_2 (index_size))
122121 {
123122 throw std::invalid_argument (" L must be a power of 2" );
124123 }
125124
126- int PRECISION_BITS = 0 ;
127- while ((1U << PRECISION_BITS) < index_size)
128- {
129- PRECISION_BITS++;
130- }
125+ const auto PRECISION_BITS = std::bit_width (index_size - 1 );
131126
132127 // Pre-compute cumulative sums
133128 std::vector<uint32_t > C (num_symbols);
134- C[0 ] = 0 ;
135- for (size_t i = 1 ; i < num_symbols; ++i)
136- {
137- C[i] = C[i - 1 ] + symbol_counts[i - 1 ];
138- }
129+ std::exclusive_scan (symbol_counts, symbol_counts + num_symbols, C.begin (), 0u );
139130
140131 // Precompute libdivide dividers for each symbol count
141132 std::vector<libdivide::divider<uint64_t >> fast_dividers (num_symbols);
@@ -157,15 +148,12 @@ EncodedData ans_encode_t(const T* signal,
157148 }
158149
159150 // Map lookups can be a bottleneck, so we use a lookup array if the number of symbols is "small"
160- const bool use_lookup_array = (max_symbol - min_symbol + 1 ) <= lookup_array_threshold;
161- std::vector<size_t > symbol_index_lookup_array;
151+ const auto array_size = max_symbol - min_symbol + 1 ;
152+ const bool use_lookup_array = array_size <= lookup_array_threshold;
153+ std::vector<size_t > symbol_index_lookup_array (0 );
162154 if (use_lookup_array)
163155 {
164- symbol_index_lookup_array.resize (max_symbol - min_symbol + 1 );
165-
166- std::fill (symbol_index_lookup_array.begin (),
167- symbol_index_lookup_array.end (),
168- std::numeric_limits<size_t >::max ());
156+ symbol_index_lookup_array.resize (array_size);
169157
170158 for (size_t i = 0 ; i < num_symbols; ++i)
171159 {
@@ -241,40 +229,25 @@ void ans_decode_t(T* output,
241229 size_t num_symbols)
242230{
243231 // very important that this is signed, because it becomes -1
244- int32_t word_idx = num_words - 1 ;
232+ auto word_idx = static_cast < int32_t >( num_words) - 1 ;
245233 // Calculate index size and verify it's a power of 2
246- uint32_t index_size = 0 ;
247- for (size_t i = 0 ; i < num_symbols; ++i)
248- {
249- index_size += symbol_counts[i];
250- }
234+ const uint32_t index_size = std::reduce (symbol_counts, symbol_counts + num_symbols, 0u );
251235 if (!is_power_of_2 (index_size))
252236 {
253237 throw std::invalid_argument (" L must be a power of 2" );
254238 }
255239
256- int PRECISION_BITS = 0 ;
257- while ((1U << PRECISION_BITS) < index_size)
258- {
259- PRECISION_BITS++;
260- }
240+ const auto PRECISION_BITS = std::bit_width (index_size - 1 );
261241
262242 // Pre-compute cumulative sums
263243 std::vector<uint32_t > C (num_symbols);
264- C[0 ] = 0 ;
265- for (size_t i = 1 ; i < num_symbols; ++i)
266- {
267- C[i] = C[i - 1 ] + symbol_counts[i - 1 ];
268- }
244+ std::exclusive_scan (symbol_counts, symbol_counts + num_symbols, C.begin (), 0u );
269245
270246 // Create symbol lookup table
271247 std::vector<uint32_t > symbol_lookup (index_size);
272- for (size_t s = 0 ; s < num_symbols; ++s)
248+ for (uint32_t s = 0 ; s < num_symbols; ++s)
273249 {
274- for (uint32_t j = 0 ; j < symbol_counts[s]; ++j)
275- {
276- symbol_lookup[C[s] + j] = s;
277- }
250+ std::fill_n (symbol_lookup.begin () + C[s], symbol_counts[s], s);
278251 }
279252
280253 // Decode symbols in reverse order
0 commit comments