Skip to content

Commit e211b5d

Browse files
committed
Alternative API with only a single generic argument for BitTree.
The single-argument that BitTree takes is 1 << NUM_BITS (2 ** NUM_BITS) for the number of bits required in the tree. This is due to restrictions on const generic expressions. The validity of this argument is checked at compile-time with a macro that confirms that the argument P passed is indeed 1 << N for some N using usize::trailing_zeros to calculate floor(log_2(P)). Thus, BitTree<const P: usize> is only valid for any P such that P = 2 ** floor(log_2(P)), where P is the length of the probability array of the BitTree. This maintains the invariant that P = 1 << N.
1 parent 08b6794 commit e211b5d

File tree

3 files changed

+48
-32
lines changed

3 files changed

+48
-32
lines changed

src/decode/lzma.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,8 @@ pub(crate) struct DecoderState {
167167
pub(crate) lzma_props: LzmaProperties,
168168
unpacked_size: Option<u64>,
169169
literal_probs: Vec2D<u16>,
170-
pos_slot_decoder: [BitTree<6, { 1 << 6 }>; 4],
171-
align_decoder: BitTree<4, { 1 << 4 }>,
170+
pos_slot_decoder: [BitTree<{ 1 << 6 }>; 4],
171+
align_decoder: BitTree<{ 1 << 4 }>,
172172
pos_decoders: [u16; 115],
173173
is_match: [u16; 192], // true = LZ, false = literal
174174
is_rep: [u16; 12],

src/decode/rangecoder.rs

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -152,42 +152,51 @@ where
152152
}
153153

154154
#[derive(Debug, Clone)]
155-
pub struct BitTree<const NUM_BITS: usize, const PROBS_ARRAY_LEN: usize> {
155+
pub struct BitTree<const PROBS_ARRAY_LEN: usize> {
156156
probs: [u16; PROBS_ARRAY_LEN],
157157
}
158158

159-
impl<const NUM_BITS: usize, const PROBS_ARRAY_LEN: usize> BitTree<NUM_BITS, PROBS_ARRAY_LEN> {
159+
impl<const PROBS_ARRAY_LEN: usize> BitTree<PROBS_ARRAY_LEN> {
160160
pub fn new() -> Self {
161-
const_assert!(NUM_BITS: usize, PROBS_ARRAY_LEN: usize => PROBS_ARRAY_LEN == 1 << NUM_BITS);
161+
// The validity of PROBS_ARRAY_LEN is checked at compile-time with a macro
162+
// that confirms that the argument P passed is indeed 1 << N for
163+
// some N using usize::trailing_zeros to calculate floor(log_2(P)).
164+
//
165+
// Thus, BitTree<const P: usize> is only valid for any P such that
166+
// P = 2 ** floor(log_2(P)), where P is the length of the probability array
167+
// of the BitTree. This maintains the invariant that P = 1 << N.
168+
const_assert!(PROBS_ARRAY_LEN: usize => (1 << (PROBS_ARRAY_LEN.trailing_zeros() as usize)) == PROBS_ARRAY_LEN);
162169
BitTree {
163170
probs: [0x400; PROBS_ARRAY_LEN],
164171
}
165172
}
166173

174+
const NUM_BITS: usize = PROBS_ARRAY_LEN.trailing_zeros() as usize;
175+
167176
pub fn parse<R: io::BufRead>(
168177
&mut self,
169178
rangecoder: &mut RangeDecoder<R>,
170179
update: bool,
171180
) -> io::Result<u32> {
172-
rangecoder.parse_bit_tree(NUM_BITS, &mut self.probs, update)
181+
rangecoder.parse_bit_tree(Self::NUM_BITS, &mut self.probs, update)
173182
}
174183

175184
pub fn parse_reverse<R: io::BufRead>(
176185
&mut self,
177186
rangecoder: &mut RangeDecoder<R>,
178187
update: bool,
179188
) -> io::Result<u32> {
180-
rangecoder.parse_reverse_bit_tree(NUM_BITS, &mut self.probs, 0, update)
189+
rangecoder.parse_reverse_bit_tree(Self::NUM_BITS, &mut self.probs, 0, update)
181190
}
182191
}
183192

184193
#[derive(Debug)]
185194
pub struct LenDecoder {
186195
choice: u16,
187196
choice2: u16,
188-
low_coder: [BitTree<3, { 1 << 3 }>; 16],
189-
mid_coder: [BitTree<3, { 1 << 3 }>; 16],
190-
high_coder: BitTree<8, { 1 << 8 }>,
197+
low_coder: [BitTree<{ 1 << 3 }>; 16],
198+
mid_coder: [BitTree<{ 1 << 3 }>; 16],
199+
high_coder: BitTree<{ 1 << 8 }>,
191200
}
192201

193202
impl LenDecoder {

src/encode/rangecoder.rs

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -145,43 +145,52 @@ where
145145

146146
#[cfg(test)]
147147
#[derive(Debug, Clone)]
148-
pub struct BitTree<const NUM_BITS: usize, const PROBS_ARRAY_LEN: usize> {
148+
pub struct BitTree<const PROBS_ARRAY_LEN: usize> {
149149
probs: [u16; PROBS_ARRAY_LEN],
150150
}
151151

152152
#[cfg(test)]
153-
impl<const NUM_BITS: usize, const PROBS_ARRAY_LEN: usize> BitTree<NUM_BITS, PROBS_ARRAY_LEN> {
153+
impl<const PROBS_ARRAY_LEN: usize> BitTree<PROBS_ARRAY_LEN> {
154154
pub fn new() -> Self {
155-
const_assert!(NUM_BITS: usize, PROBS_ARRAY_LEN: usize => PROBS_ARRAY_LEN == 1 << NUM_BITS);
155+
// The validity of PROBS_ARRAY_LEN is checked at compile-time with a macro
156+
// that confirms that the argument P passed is indeed 1 << N for
157+
// some N using usize::trailing_zeros to calculate floor(log_2(P)).
158+
//
159+
// Thus, BitTree<const P: usize> is only valid for any P such that
160+
// P = 2 ** floor(log_2(P)), where P is the length of the probability array
161+
// of the BitTree. This maintains the invariant that P = 1 << N.
162+
const_assert!(PROBS_ARRAY_LEN: usize => (1 << (PROBS_ARRAY_LEN.trailing_zeros() as usize)) == PROBS_ARRAY_LEN);
156163
BitTree {
157164
probs: [0x400; PROBS_ARRAY_LEN],
158165
}
159166
}
160167

168+
const NUM_BITS: usize = PROBS_ARRAY_LEN.trailing_zeros() as usize;
169+
161170
pub fn encode<W: io::Write>(
162171
&mut self,
163172
rangecoder: &mut RangeEncoder<W>,
164173
value: u32,
165174
) -> io::Result<()> {
166-
rangecoder.encode_bit_tree(NUM_BITS, self.probs.as_mut_slice(), value)
175+
rangecoder.encode_bit_tree(Self::NUM_BITS, self.probs.as_mut_slice(), value)
167176
}
168177

169178
pub fn encode_reverse<W: io::Write>(
170179
&mut self,
171180
rangecoder: &mut RangeEncoder<W>,
172181
value: u32,
173182
) -> io::Result<()> {
174-
rangecoder.encode_reverse_bit_tree(NUM_BITS, self.probs.as_mut_slice(), 0, value)
183+
rangecoder.encode_reverse_bit_tree(Self::NUM_BITS, self.probs.as_mut_slice(), 0, value)
175184
}
176185
}
177186

178187
#[cfg(test)]
179188
pub struct LenEncoder {
180189
choice: u16,
181190
choice2: u16,
182-
low_coder: [BitTree<3, { 1 << 3 }>; 16],
183-
mid_coder: [BitTree<3, { 1 << 3 }>; 16],
184-
high_coder: BitTree<8, { 1 << 8 }>,
191+
low_coder: [BitTree<{ 1 << 3 }>; 16],
192+
mid_coder: [BitTree<{ 1 << 3 }>; 16],
193+
high_coder: BitTree<{ 1 << 8 }>,
185194
}
186195

187196
#[cfg(test)]
@@ -289,19 +298,19 @@ mod test {
289298
encode_decode(0x400, &[true; 10000]);
290299
}
291300

292-
fn encode_decode_bittree<const NUM_BITS: usize, const PROBS_LEN: usize>(values: &[u32]) {
301+
fn encode_decode_bittree<const PROBS_LEN: usize>(values: &[u32]) {
293302
let mut buf: Vec<u8> = Vec::new();
294303

295304
let mut encoder = RangeEncoder::new(&mut buf);
296-
let mut tree = encode::rangecoder::BitTree::<NUM_BITS, PROBS_LEN>::new();
305+
let mut tree = encode::rangecoder::BitTree::<PROBS_LEN>::new();
297306
for &v in values {
298307
tree.encode(&mut encoder, v).unwrap();
299308
}
300309
encoder.finish().unwrap();
301310

302311
let mut bufread = BufReader::new(buf.as_slice());
303312
let mut decoder = RangeDecoder::new(&mut bufread).unwrap();
304-
let mut tree = decode::rangecoder::BitTree::<NUM_BITS, PROBS_LEN>::new();
313+
let mut tree = decode::rangecoder::BitTree::<PROBS_LEN>::new();
305314
for &v in values {
306315
assert_eq!(tree.parse(&mut decoder, true).unwrap(), v);
307316
}
@@ -311,15 +320,15 @@ mod test {
311320
#[test]
312321
fn test_encode_decode_bittree_zeros() {
313322
seq!(NUM_BITS in 0..16 {
314-
encode_decode_bittree::<NUM_BITS, {1 << NUM_BITS}>
323+
encode_decode_bittree::<{1 << NUM_BITS}>
315324
(&[0; 10000]);
316325
});
317326
}
318327

319328
#[test]
320329
fn test_encode_decode_bittree_ones() {
321330
seq!(NUM_BITS in 0..16 {
322-
encode_decode_bittree::<NUM_BITS, {1 << NUM_BITS}>
331+
encode_decode_bittree::<{1 << NUM_BITS}>
323332
(&[(1 << NUM_BITS) - 1; 10000]);
324333
});
325334
}
@@ -329,26 +338,24 @@ mod test {
329338
seq!(NUM_BITS in 0..16 {
330339
let max = 1 << NUM_BITS;
331340
let values: Vec<u32> = (0..max).collect();
332-
encode_decode_bittree::<NUM_BITS, {1 << NUM_BITS}>
341+
encode_decode_bittree::<{1 << NUM_BITS}>
333342
(&values);
334343
});
335344
}
336345

337-
fn encode_decode_reverse_bittree<const NUM_BITS: usize, const PROBS_LEN: usize>(
338-
values: &[u32],
339-
) {
346+
fn encode_decode_reverse_bittree<const PROBS_LEN: usize>(values: &[u32]) {
340347
let mut buf: Vec<u8> = Vec::new();
341348

342349
let mut encoder = RangeEncoder::new(&mut buf);
343-
let mut tree = encode::rangecoder::BitTree::<NUM_BITS, PROBS_LEN>::new();
350+
let mut tree = encode::rangecoder::BitTree::<PROBS_LEN>::new();
344351
for &v in values {
345352
tree.encode_reverse(&mut encoder, v).unwrap();
346353
}
347354
encoder.finish().unwrap();
348355

349356
let mut bufread = BufReader::new(buf.as_slice());
350357
let mut decoder = RangeDecoder::new(&mut bufread).unwrap();
351-
let mut tree = decode::rangecoder::BitTree::<NUM_BITS, PROBS_LEN>::new();
358+
let mut tree = decode::rangecoder::BitTree::<PROBS_LEN>::new();
352359
for &v in values {
353360
assert_eq!(tree.parse_reverse(&mut decoder, true).unwrap(), v);
354361
}
@@ -358,15 +365,15 @@ mod test {
358365
#[test]
359366
fn test_encode_decode_reverse_bittree_zeros() {
360367
seq!(NUM_BITS in 0..16 {
361-
encode_decode_reverse_bittree::<NUM_BITS, {1 << NUM_BITS}>
368+
encode_decode_reverse_bittree::<{1 << NUM_BITS}>
362369
(&[0; 10000]);
363370
});
364371
}
365372

366373
#[test]
367374
fn test_encode_decode_reverse_bittree_ones() {
368375
seq!(NUM_BITS in 0..16 {
369-
encode_decode_reverse_bittree::<NUM_BITS, {1 << NUM_BITS}>
376+
encode_decode_reverse_bittree::<{1 << NUM_BITS}>
370377
(&[(1 << NUM_BITS) - 1; 10000]);
371378
});
372379
}
@@ -376,7 +383,7 @@ mod test {
376383
seq!(NUM_BITS in 0..16 {
377384
let max = 1 << NUM_BITS;
378385
let values: Vec<u32> = (0..max).collect();
379-
encode_decode_reverse_bittree::<NUM_BITS, {1 << NUM_BITS}>
386+
encode_decode_reverse_bittree::<{1 << NUM_BITS}>
380387
(&values);
381388
});
382389
}

0 commit comments

Comments
 (0)