@@ -341,25 +341,30 @@ def test_setitem_array(self):
341
341
arr [:] = orig_arr
342
342
check_arrays (arr , orig_arr )
343
343
344
- def test_narrow_code_storage (self ):
345
- def check_roundtrip (arr ):
346
- assert_equal (
344
@staticmethod
def check_roundtrip(arr):
    """Assert that ``arr`` survives a trip through its string form.

    A new ``LabelArray`` is built from ``arr.as_string_array()`` and
    ``arr.missing_value``; its string array must compare equal to the
    string array of ``arr`` itself.
    """
    expected_strings = arr.as_string_array()
    rebuilt = LabelArray(arr.as_string_array(), arr.missing_value)
    assert_equal(expected_strings, rebuilt.as_string_array())
353
+
354
@staticmethod
def create_categories(width, plus_one):
    """Build a list of label strings sized around a ``width``-bit code.

    ``2 ** width + plus_one`` labels are requested, each made of
    ``int(width / 8) + plus_one`` characters drawn from ``chr(0)`` through
    ``chr(255)``. ``plus_one`` is used arithmetically, so callers pass a
    bool (or 0/1) to ask for one label more than ``2 ** width``.
    """
    chars_per_label = int(width / 8) + plus_one
    label_count = 2 ** width + plus_one
    alphabet = [chr(code) for code in range(256)]
    # NOTE(review): assumes take(n, seq) yields the first n items of seq.
    candidates = product(alphabet, repeat=chars_per_label)
    return [''.join(chars) for chars in take(label_count, candidates)]
353
364
354
- def create_categories (width , plus_one ):
355
- length = int (width / 8 ) + plus_one
356
- return [
357
- '' .join (cs )
358
- for cs in take (
359
- 2 ** width + plus_one ,
360
- product ([chr (c ) for c in range (256 )], repeat = length ),
361
- )
362
- ]
365
+ def test_narrow_code_storage (self ):
366
+ create_categories = self .create_categories
367
+ check_roundtrip = self .check_roundtrip
363
368
364
369
# uint8
365
370
categories = create_categories (8 , plus_one = False )
@@ -386,11 +391,6 @@ def create_categories(width, plus_one):
386
391
self .assertEqual (arr .itemsize , 2 )
387
392
check_roundtrip (arr )
388
393
389
- # uint16 inference
390
- arr = LabelArray (categories , missing_value = categories [0 ])
391
- self .assertEqual (arr .itemsize , 2 )
392
- check_roundtrip (arr )
393
-
394
394
# fits in uint16
395
395
categories = create_categories (16 , plus_one = False )
396
396
arr = LabelArray (
@@ -422,3 +422,26 @@ def create_categories(width, plus_one):
422
422
423
423
# NOTE: we could do this for 32 and 64; however, no one has enough RAM
424
424
# or time for that.
425
+
426
def test_narrow_condense_back_to_valid_size(self):
    """Check ``itemsize`` when the input holds duplicate categories.

    Both inputs are one entry longer than a power-of-two boundary but
    contain a repeated label; the asserted ``itemsize`` values are 1 and 2.
    """
    # 2 ** 8 + 1 copies of the same single label
    repeated = ['a'] * (2 ** 8 + 1)
    narrow = LabelArray(repeated, missing_value=repeated[0])
    assert_equal(narrow.itemsize, 1)
    self.check_roundtrip(narrow)

    # one entry more than 2 ** 16, but the extra entry repeats the first
    # label, so this should still fit once duplicates are collapsed
    labels = self.create_categories(16, plus_one=False)
    first_label = labels[0]
    labels.append(first_label)
    wide = LabelArray(labels, missing_value=labels[0])
    assert_equal(wide.itemsize, 2)
    self.check_roundtrip(wide)
438
+
439
def manual_narrow_condense_back_to_valid_size_slow(self):
    """Very slow check, so it is deliberately not run by default.
    """
    # guards against trying to build an 'int24', which is not a meaningful
    # size: 2 ** 24 generated labels plus one repeat of the first label
    labels = self.create_categories(24, plus_one=False)
    first_label = labels[0]
    labels.append(first_label)
    label_array = LabelArray(labels, missing_value=labels[0])
    assert_equal(label_array.itemsize, 4)
    self.check_roundtrip(label_array)
0 commit comments