Skip to content

Commit 153f663

Browse files
Joe Jevnikllllllllll
authored and committed
BUG: fix label array code dtype condense
1 parent fcfc06e commit 153f663

File tree

2 files changed

+65
-28
lines changed

2 files changed

+65
-28
lines changed

tests/test_labelarray.py

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -341,25 +341,30 @@ def test_setitem_array(self):
341341
arr[:] = orig_arr
342342
check_arrays(arr, orig_arr)
343343

344-
def test_narrow_code_storage(self):
345-
def check_roundtrip(arr):
346-
assert_equal(
344+
@staticmethod
345+
def check_roundtrip(arr):
346+
assert_equal(
347+
arr.as_string_array(),
348+
LabelArray(
347349
arr.as_string_array(),
348-
LabelArray(
349-
arr.as_string_array(),
350-
arr.missing_value,
351-
).as_string_array(),
350+
arr.missing_value,
351+
).as_string_array(),
352+
)
353+
354+
@staticmethod
355+
def create_categories(width, plus_one):
356+
length = int(width / 8) + plus_one
357+
return [
358+
''.join(cs)
359+
for cs in take(
360+
2 ** width + plus_one,
361+
product([chr(c) for c in range(256)], repeat=length),
352362
)
363+
]
353364

354-
def create_categories(width, plus_one):
355-
length = int(width / 8) + plus_one
356-
return [
357-
''.join(cs)
358-
for cs in take(
359-
2 ** width + plus_one,
360-
product([chr(c) for c in range(256)], repeat=length),
361-
)
362-
]
365+
def test_narrow_code_storage(self):
366+
create_categories = self.create_categories
367+
check_roundtrip = self.check_roundtrip
363368

364369
# uint8
365370
categories = create_categories(8, plus_one=False)
@@ -386,11 +391,6 @@ def create_categories(width, plus_one):
386391
self.assertEqual(arr.itemsize, 2)
387392
check_roundtrip(arr)
388393

389-
# uint16 inference
390-
arr = LabelArray(categories, missing_value=categories[0])
391-
self.assertEqual(arr.itemsize, 2)
392-
check_roundtrip(arr)
393-
394394
# fits in uint16
395395
categories = create_categories(16, plus_one=False)
396396
arr = LabelArray(
@@ -422,3 +422,26 @@ def create_categories(width, plus_one):
422422

423423
# NOTE: we could do this for 32 and 64; however, no one has enough RAM
424424
# or time for that.
425+
426+
def test_narrow_condense_back_to_valid_size(self):
427+
categories = ['a'] * (2 ** 8 + 1)
428+
arr = LabelArray(categories, missing_value=categories[0])
429+
assert_equal(arr.itemsize, 1)
430+
self.check_roundtrip(arr)
431+
432+
# more entries than uint16 codes can index, but still fits once deduped
433+
categories = self.create_categories(16, plus_one=False)
434+
categories.append(categories[0])
435+
arr = LabelArray(categories, missing_value=categories[0])
436+
assert_equal(arr.itemsize, 2)
437+
self.check_roundtrip(arr)
438+
439+
def manual_narrow_condense_back_to_valid_size_slow(self):
440+
"""This test is really slow so we don't want it run by default.
441+
"""
442+
# tests that we don't try to create an 'int24' (which is meaningless)
443+
categories = self.create_categories(24, plus_one=False)
444+
categories.append(categories[0])
445+
arr = LabelArray(categories, missing_value=categories[0])
446+
assert_equal(arr.itemsize, 4)
447+
self.check_roundtrip(arr)

zipline/lib/_factorize.pyx

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Factorization algorithms.
33
"""
4-
from libc.math cimport floor, log
4+
from libc.math cimport log
55
cimport numpy as np
66
import numpy as np
77

@@ -144,6 +144,9 @@ cdef factorize_strings_impl(np.ndarray[object] values,
144144
return codes, categories_array, reverse_categories
145145

146146

147+
cdef list _int_sizes = [1, 1, 2, 4, 4, 8, 8, 8, 8]
148+
149+
147150
cpdef factorize_strings(np.ndarray[object] values,
148151
object missing_value,
149152
int sort):
@@ -209,11 +212,22 @@ cpdef factorize_strings(np.ndarray[object] values,
209212
# unreachable
210213
raise ValueError('nvalues larger than uint64')
211214

212-
if len(categories_array) < 2 ** codes.dtype.itemsize:
213-
# if there are a lot of duplicates in the values we may need to shrink
214-
# the width of the ``codes`` array
215-
codes = codes.astype(unsigned_int_dtype_with_size_in_bytes(
216-
floor(log2(len(categories_array))),
217-
))
215+
length = len(categories_array)
216+
if length < 1:
217+
# lim x -> 0 log2(x) == -infinity so we floor at uint8
218+
narrowest_dtype = np.uint8
219+
else:
220+
# The number of bits required to hold the codes up to ``length`` is
221+
# log2(length). The number of bits per byte is 8. We cannot have
222+
# fractional bytes so we need to round up. Finally, we can only have
223+
# integers with byte widths 1, 2, 4, or 8 so we need to round up to the
224+
# next value by looking up the next largest size in ``_int_sizes``.
225+
narrowest_dtype = unsigned_int_dtype_with_size_in_bytes(
226+
_int_sizes[int(np.ceil(log2(length) / 8))]
227+
)
228+
229+
if codes.dtype != narrowest_dtype:
230+
# condense the codes down to the narrowest dtype possible
231+
codes = codes.astype(narrowest_dtype)
218232

219233
return codes, categories_array, reverse_categories

0 commit comments

Comments
 (0)