Skip to content

Commit af33a32

Browse files
committed
Improving the use of stdlib where possible and const-correctness for good measure
1 parent 6252263 commit af33a32

File tree

2 files changed

+25
-47
lines changed

2 files changed

+25
-47
lines changed

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,10 @@ if(ipo_supported AND (CMAKE_BUILD_TYPE STREQUAL "Release"))
8484
set_target_properties(_simple_ans PROPERTIES INTERPROCEDURAL_OPTIMIZATION TRUE)
8585
endif()
8686

87+
set_target_properties(_simple_ans PROPERTIES
88+
CXX_VISIBILITY_PRESET hidden
89+
VISIBILITY_INLINES_HIDDEN ON
90+
)
91+
8792
# Install rules
8893
install(TARGETS _simple_ans LIBRARY DESTINATION simple_ans)

simple_ans/cpp/simple_ans.hpp

Lines changed: 20 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
#pragma once
22

3+
#include <algorithm>
4+
#include <bit>
35
#include <cassert>
4-
#include <cstdint>
5-
#include <cstdio>
6+
#include <execution>
67
#include <limits>
8+
#include <numeric>
79
#include <stdexcept>
810
#include <vector>
11+
912
#include <tsl/robin_map.h>
1013

11-
#include "libdivide.h"
14+
#include <libdivide.h>
1215

1316
namespace simple_ans
1417
{
@@ -20,7 +23,7 @@ struct EncodedData
2023
};
2124

2225
// Helper function to verify if a number is a power of 2
23-
inline bool is_power_of_2(uint32_t x)
26+
constexpr bool is_power_of_2(uint32_t x)
2427
{
2528
return x && !(x & (x - 1));
2629
}
@@ -113,29 +116,17 @@ EncodedData ans_encode_t(const T* signal,
113116
"Value range of T must fit in int64_t for table lookup");
114117

115118
// Calculate L and verify it's a power of 2
116-
uint32_t index_size = 0;
117-
for (size_t i = 0; i < num_symbols; ++i)
118-
{
119-
index_size += symbol_counts[i];
120-
}
119+
const uint32_t index_size = std::reduce(symbol_counts, symbol_counts + num_symbols, 0u);
121120
if (!is_power_of_2(index_size))
122121
{
123122
throw std::invalid_argument("L must be a power of 2");
124123
}
125124

126-
int PRECISION_BITS = 0;
127-
while ((1U << PRECISION_BITS) < index_size)
128-
{
129-
PRECISION_BITS++;
130-
}
125+
const auto PRECISION_BITS = std::bit_width(index_size - 1);
131126

132127
// Pre-compute cumulative sums
133128
std::vector<uint32_t> C(num_symbols);
134-
C[0] = 0;
135-
for (size_t i = 1; i < num_symbols; ++i)
136-
{
137-
C[i] = C[i - 1] + symbol_counts[i - 1];
138-
}
129+
std::exclusive_scan(symbol_counts, symbol_counts + num_symbols, C.begin(), 0u);
139130

140131
// Precompute libdivide dividers for each symbol count
141132
std::vector<libdivide::divider<uint64_t>> fast_dividers(num_symbols);
@@ -157,15 +148,12 @@ EncodedData ans_encode_t(const T* signal,
157148
}
158149

159150
// Map lookups can be a bottleneck, so we use a lookup array if the number of symbols is "small"
160-
const bool use_lookup_array = (max_symbol - min_symbol + 1) <= lookup_array_threshold;
161-
std::vector<size_t> symbol_index_lookup_array;
151+
const auto array_size = max_symbol - min_symbol + 1;
152+
const bool use_lookup_array = array_size <= lookup_array_threshold;
153+
std::vector<size_t> symbol_index_lookup_array(0);
162154
if (use_lookup_array)
163155
{
164-
symbol_index_lookup_array.resize(max_symbol - min_symbol + 1);
165-
166-
std::fill(symbol_index_lookup_array.begin(),
167-
symbol_index_lookup_array.end(),
168-
std::numeric_limits<size_t>::max());
156+
symbol_index_lookup_array.resize(array_size);
169157

170158
for (size_t i = 0; i < num_symbols; ++i)
171159
{
@@ -241,40 +229,25 @@ void ans_decode_t(T* output,
241229
size_t num_symbols)
242230
{
243231
// very important that this is signed, because it becomes -1
244-
int32_t word_idx = num_words - 1;
232+
auto word_idx = static_cast<int32_t>(num_words) - 1;
245233
// Calculate index size and verify it's a power of 2
246-
uint32_t index_size = 0;
247-
for (size_t i = 0; i < num_symbols; ++i)
248-
{
249-
index_size += symbol_counts[i];
250-
}
234+
const uint32_t index_size = std::reduce(symbol_counts, symbol_counts + num_symbols, 0u);
251235
if (!is_power_of_2(index_size))
252236
{
253237
throw std::invalid_argument("L must be a power of 2");
254238
}
255239

256-
int PRECISION_BITS = 0;
257-
while ((1U << PRECISION_BITS) < index_size)
258-
{
259-
PRECISION_BITS++;
260-
}
240+
const auto PRECISION_BITS = std::bit_width(index_size - 1);
261241

262242
// Pre-compute cumulative sums
263243
std::vector<uint32_t> C(num_symbols);
264-
C[0] = 0;
265-
for (size_t i = 1; i < num_symbols; ++i)
266-
{
267-
C[i] = C[i - 1] + symbol_counts[i - 1];
268-
}
244+
std::exclusive_scan(symbol_counts, symbol_counts + num_symbols, C.begin(), 0u);
269245

270246
// Create symbol lookup table
271247
std::vector<uint32_t> symbol_lookup(index_size);
272-
for (size_t s = 0; s < num_symbols; ++s)
248+
for (uint32_t s = 0; s < num_symbols; ++s)
273249
{
274-
for (uint32_t j = 0; j < symbol_counts[s]; ++j)
275-
{
276-
symbol_lookup[C[s] + j] = s;
277-
}
250+
std::fill_n(symbol_lookup.begin() + C[s], symbol_counts[s], s);
278251
}
279252

280253
// Decode symbols in reverse order

0 commit comments

Comments
 (0)