@@ -145,43 +145,52 @@ where
145
145
146
146
#[ cfg( test) ]
147
147
#[ derive( Debug , Clone ) ]
148
- pub struct BitTree < const NUM_BITS : usize , const PROBS_ARRAY_LEN : usize > {
148
+ pub struct BitTree < const PROBS_ARRAY_LEN : usize > {
149
149
probs : [ u16 ; PROBS_ARRAY_LEN ] ,
150
150
}
151
151
152
152
#[ cfg( test) ]
153
- impl < const NUM_BITS : usize , const PROBS_ARRAY_LEN : usize > BitTree < NUM_BITS , PROBS_ARRAY_LEN > {
153
+ impl < const PROBS_ARRAY_LEN : usize > BitTree < PROBS_ARRAY_LEN > {
154
154
pub fn new ( ) -> Self {
155
- const_assert ! ( NUM_BITS : usize , PROBS_ARRAY_LEN : usize => PROBS_ARRAY_LEN == 1 << NUM_BITS ) ;
155
+ // The validity of PROBS_ARRAY_LEN is checked at compile-time with a macro
156
+ // that confirms that the argument P passed is indeed 1 << N for
157
+ // some N using usize::trailing_zeros to calculate floor(log_2(P)).
158
+ //
159
+ // Thus, BitTree<const P: usize> is only valid for any P such that
160
+ // P = 2 ** floor(log_2(P)), where P is the length of the probability array
161
+ // of the BitTree. This maintains the invariant that P = 1 << N.
162
+ const_assert ! ( PROBS_ARRAY_LEN : usize => ( 1 << ( PROBS_ARRAY_LEN . trailing_zeros( ) as usize ) ) == PROBS_ARRAY_LEN ) ;
156
163
BitTree {
157
164
probs : [ 0x400 ; PROBS_ARRAY_LEN ] ,
158
165
}
159
166
}
160
167
168
+ const NUM_BITS : usize = PROBS_ARRAY_LEN . trailing_zeros ( ) as usize ;
169
+
161
170
pub fn encode < W : io:: Write > (
162
171
& mut self ,
163
172
rangecoder : & mut RangeEncoder < W > ,
164
173
value : u32 ,
165
174
) -> io:: Result < ( ) > {
166
- rangecoder. encode_bit_tree ( NUM_BITS , self . probs . as_mut_slice ( ) , value)
175
+ rangecoder. encode_bit_tree ( Self :: NUM_BITS , self . probs . as_mut_slice ( ) , value)
167
176
}
168
177
169
178
pub fn encode_reverse < W : io:: Write > (
170
179
& mut self ,
171
180
rangecoder : & mut RangeEncoder < W > ,
172
181
value : u32 ,
173
182
) -> io:: Result < ( ) > {
174
- rangecoder. encode_reverse_bit_tree ( NUM_BITS , self . probs . as_mut_slice ( ) , 0 , value)
183
+ rangecoder. encode_reverse_bit_tree ( Self :: NUM_BITS , self . probs . as_mut_slice ( ) , 0 , value)
175
184
}
176
185
}
177
186
178
187
#[ cfg( test) ]
179
188
pub struct LenEncoder {
180
189
choice : u16 ,
181
190
choice2 : u16 ,
182
- low_coder : [ BitTree < 3 , { 1 << 3 } > ; 16 ] ,
183
- mid_coder : [ BitTree < 3 , { 1 << 3 } > ; 16 ] ,
184
- high_coder : BitTree < 8 , { 1 << 8 } > ,
191
+ low_coder : [ BitTree < { 1 << 3 } > ; 16 ] ,
192
+ mid_coder : [ BitTree < { 1 << 3 } > ; 16 ] ,
193
+ high_coder : BitTree < { 1 << 8 } > ,
185
194
}
186
195
187
196
#[ cfg( test) ]
@@ -289,19 +298,19 @@ mod test {
289
298
encode_decode ( 0x400 , & [ true ; 10000 ] ) ;
290
299
}
291
300
292
- fn encode_decode_bittree < const NUM_BITS : usize , const PROBS_LEN : usize > ( values : & [ u32 ] ) {
301
+ fn encode_decode_bittree < const PROBS_LEN : usize > ( values : & [ u32 ] ) {
293
302
let mut buf: Vec < u8 > = Vec :: new ( ) ;
294
303
295
304
let mut encoder = RangeEncoder :: new ( & mut buf) ;
296
- let mut tree = encode:: rangecoder:: BitTree :: < NUM_BITS , PROBS_LEN > :: new ( ) ;
305
+ let mut tree = encode:: rangecoder:: BitTree :: < PROBS_LEN > :: new ( ) ;
297
306
for & v in values {
298
307
tree. encode ( & mut encoder, v) . unwrap ( ) ;
299
308
}
300
309
encoder. finish ( ) . unwrap ( ) ;
301
310
302
311
let mut bufread = BufReader :: new ( buf. as_slice ( ) ) ;
303
312
let mut decoder = RangeDecoder :: new ( & mut bufread) . unwrap ( ) ;
304
- let mut tree = decode:: rangecoder:: BitTree :: < NUM_BITS , PROBS_LEN > :: new ( ) ;
313
+ let mut tree = decode:: rangecoder:: BitTree :: < PROBS_LEN > :: new ( ) ;
305
314
for & v in values {
306
315
assert_eq ! ( tree. parse( & mut decoder, true ) . unwrap( ) , v) ;
307
316
}
@@ -311,15 +320,15 @@ mod test {
311
320
#[ test]
312
321
fn test_encode_decode_bittree_zeros ( ) {
313
322
seq ! ( NUM_BITS in 0 ..16 {
314
- encode_decode_bittree:: <NUM_BITS , { 1 << NUM_BITS } >
323
+ encode_decode_bittree:: <{ 1 << NUM_BITS } >
315
324
( & [ 0 ; 10000 ] ) ;
316
325
} ) ;
317
326
}
318
327
319
328
#[ test]
320
329
fn test_encode_decode_bittree_ones ( ) {
321
330
seq ! ( NUM_BITS in 0 ..16 {
322
- encode_decode_bittree:: <NUM_BITS , { 1 << NUM_BITS } >
331
+ encode_decode_bittree:: <{ 1 << NUM_BITS } >
323
332
( & [ ( 1 << NUM_BITS ) - 1 ; 10000 ] ) ;
324
333
} ) ;
325
334
}
@@ -329,26 +338,24 @@ mod test {
329
338
seq ! ( NUM_BITS in 0 ..16 {
330
339
let max = 1 << NUM_BITS ;
331
340
let values: Vec <u32 > = ( 0 ..max) . collect( ) ;
332
- encode_decode_bittree:: <NUM_BITS , { 1 << NUM_BITS } >
341
+ encode_decode_bittree:: <{ 1 << NUM_BITS } >
333
342
( & values) ;
334
343
} ) ;
335
344
}
336
345
337
- fn encode_decode_reverse_bittree < const NUM_BITS : usize , const PROBS_LEN : usize > (
338
- values : & [ u32 ] ,
339
- ) {
346
+ fn encode_decode_reverse_bittree < const PROBS_LEN : usize > ( values : & [ u32 ] ) {
340
347
let mut buf: Vec < u8 > = Vec :: new ( ) ;
341
348
342
349
let mut encoder = RangeEncoder :: new ( & mut buf) ;
343
- let mut tree = encode:: rangecoder:: BitTree :: < NUM_BITS , PROBS_LEN > :: new ( ) ;
350
+ let mut tree = encode:: rangecoder:: BitTree :: < PROBS_LEN > :: new ( ) ;
344
351
for & v in values {
345
352
tree. encode_reverse ( & mut encoder, v) . unwrap ( ) ;
346
353
}
347
354
encoder. finish ( ) . unwrap ( ) ;
348
355
349
356
let mut bufread = BufReader :: new ( buf. as_slice ( ) ) ;
350
357
let mut decoder = RangeDecoder :: new ( & mut bufread) . unwrap ( ) ;
351
- let mut tree = decode:: rangecoder:: BitTree :: < NUM_BITS , PROBS_LEN > :: new ( ) ;
358
+ let mut tree = decode:: rangecoder:: BitTree :: < PROBS_LEN > :: new ( ) ;
352
359
for & v in values {
353
360
assert_eq ! ( tree. parse_reverse( & mut decoder, true ) . unwrap( ) , v) ;
354
361
}
@@ -358,15 +365,15 @@ mod test {
358
365
#[ test]
359
366
fn test_encode_decode_reverse_bittree_zeros ( ) {
360
367
seq ! ( NUM_BITS in 0 ..16 {
361
- encode_decode_reverse_bittree:: <NUM_BITS , { 1 << NUM_BITS } >
368
+ encode_decode_reverse_bittree:: <{ 1 << NUM_BITS } >
362
369
( & [ 0 ; 10000 ] ) ;
363
370
} ) ;
364
371
}
365
372
366
373
#[ test]
367
374
fn test_encode_decode_reverse_bittree_ones ( ) {
368
375
seq ! ( NUM_BITS in 0 ..16 {
369
- encode_decode_reverse_bittree:: <NUM_BITS , { 1 << NUM_BITS } >
376
+ encode_decode_reverse_bittree:: <{ 1 << NUM_BITS } >
370
377
( & [ ( 1 << NUM_BITS ) - 1 ; 10000 ] ) ;
371
378
} ) ;
372
379
}
@@ -376,7 +383,7 @@ mod test {
376
383
seq ! ( NUM_BITS in 0 ..16 {
377
384
let max = 1 << NUM_BITS ;
378
385
let values: Vec <u32 > = ( 0 ..max) . collect( ) ;
379
- encode_decode_reverse_bittree:: <NUM_BITS , { 1 << NUM_BITS } >
386
+ encode_decode_reverse_bittree:: <{ 1 << NUM_BITS } >
380
387
( & values) ;
381
388
} ) ;
382
389
}
0 commit comments