BUG: fix label array code dtype condense
Joe Jevnik authored and llllllllll committed Mar 9, 2017
1 parent fcfc06e commit 153f663
Showing 2 changed files with 65 additions and 28 deletions.
65 changes: 44 additions & 21 deletions tests/test_labelarray.py
@@ -341,25 +341,30 @@ def test_setitem_array(self):
         arr[:] = orig_arr
         check_arrays(arr, orig_arr)
 
-    def test_narrow_code_storage(self):
-        def check_roundtrip(arr):
-            assert_equal(
-                arr.as_string_array(),
-                LabelArray(
-                    arr.as_string_array(),
-                    arr.missing_value,
-                ).as_string_array(),
-            )
-
-        def create_categories(width, plus_one):
-            length = int(width / 8) + plus_one
-            return [
-                ''.join(cs)
-                for cs in take(
-                    2 ** width + plus_one,
-                    product([chr(c) for c in range(256)], repeat=length),
-                )
-            ]
+    @staticmethod
+    def check_roundtrip(arr):
+        assert_equal(
+            arr.as_string_array(),
+            LabelArray(
+                arr.as_string_array(),
+                arr.missing_value,
+            ).as_string_array(),
+        )
+
+    @staticmethod
+    def create_categories(width, plus_one):
+        length = int(width / 8) + plus_one
+        return [
+            ''.join(cs)
+            for cs in take(
+                2 ** width + plus_one,
+                product([chr(c) for c in range(256)], repeat=length),
+            )
+        ]
+
+    def test_narrow_code_storage(self):
+        create_categories = self.create_categories
+        check_roundtrip = self.check_roundtrip
 
         # uint8
         categories = create_categories(8, plus_one=False)
@@ -386,11 +391,6 @@ def create_categories(width, plus_one):
         self.assertEqual(arr.itemsize, 2)
         check_roundtrip(arr)
 
-        # uint16 inference
-        arr = LabelArray(categories, missing_value=categories[0])
-        self.assertEqual(arr.itemsize, 2)
-        check_roundtrip(arr)
-
         # fits in uint16
         categories = create_categories(16, plus_one=False)
         arr = LabelArray(
@@ -422,3 +422,26 @@ def create_categories(width, plus_one):
 
         # NOTE: we could do this for 32 and 64; however, no one has enough RAM
         # or time for that.
+
+    def test_narrow_condense_back_to_valid_size(self):
+        categories = ['a'] * (2 ** 8 + 1)
+        arr = LabelArray(categories, missing_value=categories[0])
+        assert_equal(arr.itemsize, 1)
+        self.check_roundtrip(arr)
+
+        # longer than int16 but still fits when deduped
+        categories = self.create_categories(16, plus_one=False)
+        categories.append(categories[0])
+        arr = LabelArray(categories, missing_value=categories[0])
+        assert_equal(arr.itemsize, 2)
+        self.check_roundtrip(arr)
+
+    def manual_narrow_condense_back_to_valid_size_slow(self):
+        """This test is really slow so we don't want it run by default.
+        """
+        # tests that we don't try to create an 'int24' (which is meaningless)
+        categories = self.create_categories(24, plus_one=False)
+        categories.append(categories[0])
+        arr = LabelArray(categories, missing_value=categories[0])
+        assert_equal(arr.itemsize, 4)
+        self.check_roundtrip(arr)
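
The new tests above exercise the condensing path end to end: many duplicate labels deduplicate to a small category set, so the codes array should drop back to a narrow dtype. As a rough standalone illustration only (not part of this commit, and assuming zipline is installed with LabelArray importable from zipline.lib.labelarray as in this repository):

    # Sketch only, not part of this commit; assumes zipline is installed and
    # that LabelArray behaves as the tests above describe.
    from zipline.lib.labelarray import LabelArray

    # 2 ** 8 + 1 values but only one unique label: after deduplication a
    # single category remains, so the codes stay in uint8 (itemsize == 1).
    labels = ['a'] * (2 ** 8 + 1)
    arr = LabelArray(labels, missing_value='a')
    assert arr.itemsize == 1

    # Round-tripping through as_string_array() preserves the values, which is
    # what check_roundtrip verifies above.
    roundtripped = LabelArray(arr.as_string_array(), 'a').as_string_array()
    assert (arr.as_string_array() == roundtripped).all()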
28 changes: 21 additions & 7 deletions zipline/lib/_factorize.pyx
@@ -1,7 +1,7 @@
 """
 Factorization algorithms.
 """
-from libc.math cimport floor, log
+from libc.math cimport log
 cimport numpy as np
 import numpy as np
 
@@ -144,6 +144,9 @@ cdef factorize_strings_impl(np.ndarray[object] values,
     return codes, categories_array, reverse_categories
 
 
+cdef list _int_sizes = [1, 1, 2, 4, 4, 8, 8, 8, 8]
+
+
 cpdef factorize_strings(np.ndarray[object] values,
                         object missing_value,
                         int sort):
@@ -209,11 +212,22 @@ cpdef factorize_strings(np.ndarray[object] values,
         # unreachable
         raise ValueError('nvalues larger than uint64')
 
-    if len(categories_array) < 2 ** codes.dtype.itemsize:
-        # if there are a lot of duplicates in the values we may need to shrink
-        # the width of the ``codes`` array
-        codes = codes.astype(unsigned_int_dtype_with_size_in_bytes(
-            floor(log2(len(categories_array))),
-        ))
+    length = len(categories_array)
+    if length < 1:
+        # lim x -> 0 log2(x) == -infinity so we floor at uint8
+        narrowest_dtype = np.uint8
+    else:
+        # The number of bits required to hold the codes up to ``length`` is
+        # log2(length). The number of bits per byte is 8. We cannot have
+        # fractional bytes so we need to round up. Finally, we can only have
+        # integers with widths 1, 2, 4, or 8 so we need to round up to the
+        # next value by looking up the next largest size in ``_int_sizes``.
+        narrowest_dtype = unsigned_int_dtype_with_size_in_bytes(
+            _int_sizes[int(np.ceil(log2(length) / 8))]
+        )
+
+    if codes.dtype != narrowest_dtype:
+        # condense the codes down to the narrowest dtype possible
+        codes = codes.astype(narrowest_dtype)
 
     return codes, categories_array, reverse_categories
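
Worked through concretely, the width lookup in the added branch behaves like this: 256 categories produce codes 0 through 255 and fit in uint8, 257 categories need two bytes, and anything past 2 ** 16 unique values jumps straight to a four-byte dtype because there is no 24-bit integer. A plain NumPy sketch of the same lookup (not part of the commit; narrowest_code_dtype is a hypothetical name used only for illustration, standing in for unsigned_int_dtype_with_size_in_bytes):

    import numpy as np

    # Same width table as the Cython module above.
    _int_sizes = [1, 1, 2, 4, 4, 8, 8, 8, 8]

    def narrowest_code_dtype(n_categories):
        """Smallest unsigned dtype whose codes can index n_categories."""
        if n_categories < 1:
            # log2 is undefined here, so floor at uint8.
            return np.dtype(np.uint8)
        n_bytes = _int_sizes[int(np.ceil(np.log2(n_categories) / 8))]
        return np.dtype('uint%d' % (8 * n_bytes))

    assert narrowest_code_dtype(2 ** 8) == np.uint8        # codes 0..255
    assert narrowest_code_dtype(2 ** 8 + 1) == np.uint16
    assert narrowest_code_dtype(2 ** 16 + 1) == np.uint32  # skips "uint24"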
