Skip to content

TST/PERF: Re-write assert_almost_equal() in cython #4398 #5219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 21, 2013
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
TST/PERF: Re-write assert_almost_equal() in cython #4398
Add a testing.pyx cython file, and port assert_almost_equal() from
python to cython.

On my machine this brings a modest gain to the suite of "not slow" tests
(160s -> 140s), but on assert_almost_equal() heavy tests, like
test_expressions.py, it shows a large improvement (14s -> 4s).
  • Loading branch information
danbirken committed Oct 21, 2013
commit 3b50b52f657a1b406ef335ced9e86adc9d145c46
1 change: 1 addition & 0 deletions doc/source/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ See :ref:`Internal Refactoring<whatsnew_0130.refactoring>`
compatible. (:issue:`5213`, :issue:`5214`)
- Unify ``dropna`` for Series/DataFrame signature (:issue:`5250`),
tests from :issue:`5234`, courtesy of @rockg
- Rewrite assert_almost_equal() in cython for performance (:issue:`4398`)

.. _release.bug_fixes-0.13.0:

Expand Down
86 changes: 86 additions & 0 deletions pandas/src/testing.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import numpy as np

from pandas import compat
from pandas.core.common import isnull

cdef bint isiterable(obj):
    # An object counts as iterable here exactly when it exposes an
    # ``__iter__`` attribute; nothing else about the object is inspected.
    cdef bint has_iter_attr = hasattr(obj, '__iter__')
    return has_iter_attr

cdef bint decimal_almost_equal(double desired, double actual, int decimal):
    # Tolerance rule taken from numpy.testing.assert_almost_equal, see
    # http://docs.scipy.org/doc/numpy/reference/generated
    # /numpy.testing.assert_almost_equal.html
    #
    # Two values match when their absolute difference is strictly below
    # half a unit in the ``decimal``-th decimal place.
    cdef double tolerance = 0.5 * 10.0 ** -decimal
    cdef double gap = abs(desired - actual)
    return gap < tolerance

cpdef assert_dict_equal(a, b, bint compare_keys=True):
    """
    Assert that two dict-like objects hold almost-equal values.

    Every key of ``a`` is looked up in both ``a`` and ``b`` and the two
    values are compared with ``assert_almost_equal``.

    Parameters
    ----------
    a, b : objects providing ``keys()`` and ``[]`` lookup
    compare_keys : bool, default True
        When true, additionally assert that both key sets are identical.

    Returns
    -------
    True when no assertion failed.
    """
    keys_of_a = frozenset(a.keys())
    keys_of_b = frozenset(b.keys())

    if compare_keys:
        assert keys_of_a == keys_of_b

    # Only the keys of ``a`` drive the value comparison.
    for key in keys_of_a:
        assert_almost_equal(a[key], b[key])

    return True

cpdef assert_almost_equal(a, b, bint check_less_precise=False):
    """
    Assert that ``a`` and ``b`` are equal, comparing numbers approximately.

    The first rule below that matches decides how the pair is compared:

    * either argument is a dict: delegate to ``assert_dict_equal``
    * ``a`` is a string: exact ``==``
    * ``a`` has ``__iter__``: ``b`` must have it too and the lengths must
      match; then either both are ndarrays for which ``np.array_equal``
      holds, or every pair ``a[i], b[i]`` is compared recursively
    * ``isnull(a)``: ``b`` must be null as well
    * ``a`` is a bool, float, int or np.float32: approximate comparison
    * anything else: exact ``==``

    Parameters
    ----------
    a, b : objects to compare
    check_less_precise : bool, default False
        Compare to 3 decimals instead of 5 when both values are floating
        point scalars of at most 4 bytes each.

    Returns
    -------
    True when the comparison succeeds.

    Raises
    ------
    AssertionError
        If the two objects are not (almost) equal.
    """
    cdef:
        int decimal
        Py_ssize_t i, na, nb
        double fa, fb

    if isinstance(a, dict) or isinstance(b, dict):
        return assert_dict_equal(a, b)

    if isinstance(a, compat.string_types):
        assert a == b, "%r != %r" % (a, b)
        return True

    if isiterable(a):
        assert isiterable(b), "First object is iterable, second isn't"
        na, nb = len(a), len(b)
        assert na == nb, "%s != %s" % (na, nb)
        if (isinstance(a, np.ndarray) and
                isinstance(b, np.ndarray) and
                np.array_equal(a, b)):
            # Exactly equal arrays need no element-wise pass.
            return True
        else:
            for i in xrange(na):
                assert_almost_equal(a[i], b[i], check_less_precise)
        return True

    if isnull(a):
        assert isnull(b), "First object is null, second isn't"
        return True

    if isinstance(a, (bool, float, int, np.float32)):
        decimal = 5

        # deal with differing dtypes
        if check_less_precise:
            dtype_a = np.dtype(type(a))
            dtype_b = np.dtype(type(b))
            # Both sides are tested by dtype *kind*. The previous check
            # compared the dtype object of ``b`` itself against 'f', which
            # was inconsistent with the check on ``a``.
            if dtype_a.kind == 'f' and dtype_b.kind == 'f':
                if dtype_a.itemsize <= 4 and dtype_b.itemsize <= 4:
                    decimal = 3

        if np.isinf(a):
            # NOTE(review): only infiniteness is asserted here, the sign of
            # the two infinities is not compared.
            assert np.isinf(b), "First object is inf, second isn't"
        else:
            # Coerce to C doubles once; the comparisons below are pure C.
            fa, fb = a, b

            # case for zero
            if abs(fa) < 1e-5:
                # ``a`` is too close to zero for a ratio to be meaningful,
                # so compare the absolute difference instead.
                if not decimal_almost_equal(fa, fb, decimal):
                    assert False, (
                        '(very low values) expected %.5f but got %.5f' % (b, a)
                    )
            else:
                # Relative comparison: the ratio b / a must be close to 1.
                if not decimal_almost_equal(1, fb / fa, decimal):
                    assert False, 'expected %.5f but got %.5f' % (b, a)

    else:
        assert a == b, "%s != %s" % (a, b)

    return True
73 changes: 7 additions & 66 deletions pandas/util/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from pandas.tseries.index import DatetimeIndex
from pandas.tseries.period import PeriodIndex

from pandas import _testing

from pandas.io.common import urlopen

Index = index.Index
Expand All @@ -50,6 +52,11 @@
K = 4
_RAISE_NETWORK_ERROR_DEFAULT = False

# NOTE: don't pass an NDFrame or index to this function - may not handle it
# well.
# Module-level alias: the implementation is taken from ``_testing``
# (imported above via ``from pandas import _testing``).
assert_almost_equal = _testing.assert_almost_equal

# Likewise an alias for the ``_testing`` implementation.
assert_dict_equal = _testing.assert_dict_equal

def randbool(size=(), p=0.5):
    """Return ``rand(*size) <= p``.

    ``size`` is unpacked into positional arguments for ``rand`` and the
    result is compared element-wise against the threshold ``p``.
    """
    draws = rand(*size)
    return draws <= p
Expand Down Expand Up @@ -374,75 +381,9 @@ def assert_attr_equal(attr, left, right):
def isiterable(obj):
    """Return True if ``obj`` has an ``__iter__`` attribute, else False."""
    has_iter_attr = hasattr(obj, '__iter__')
    return has_iter_attr


# NOTE: don't pass an NDFrame or index to this function - may not handle it
# well.
def assert_almost_equal(a, b, check_less_precise=False):
    """
    Assert that ``a`` and ``b`` are equal, comparing numbers approximately.

    The first rule below that matches decides how the pair is compared:

    * either argument is a dict: delegate to ``assert_dict_equal``
    * ``a`` is a string: exact ``==``
    * ``a`` has ``__iter__``: ``b`` must have it too and the lengths must
      match; then either both are ndarrays for which ``np.array_equal``
      holds, or every pair ``a[i], b[i]`` is compared recursively
    * ``isnull(a)``: ``b`` must be null as well
    * ``a`` is a bool, float, int or np.float32: approximate comparison
      through ``np.testing.assert_almost_equal``
    * anything else: exact ``==``

    Parameters
    ----------
    a, b : objects to compare
    check_less_precise : bool, default False
        Compare to 3 decimals instead of 5 when both values are floating
        point scalars of at most 4 bytes each.

    Returns
    -------
    True on every successful path.

    Raises
    ------
    AssertionError
        If the two objects are not (almost) equal.
    """
    def _err_msg(actual, expected):
        return 'expected %.5f but got %.5f' % (expected, actual)

    if isinstance(a, dict) or isinstance(b, dict):
        assert_dict_equal(a, b)
        return True

    if isinstance(a, compat.string_types):
        assert a == b, "%r != %r" % (a, b)
        return True

    if isiterable(a):
        np.testing.assert_(isiterable(b))
        na, nb = len(a), len(b)
        assert na == nb, "%s != %s" % (na, nb)
        if (isinstance(a, np.ndarray) and isinstance(b, np.ndarray) and
                np.array_equal(a, b)):
            # Exactly equal arrays need no element-wise pass.
            return True
        for i in range(na):
            assert_almost_equal(a[i], b[i], check_less_precise)
        return True

    if isnull(a):
        np.testing.assert_(isnull(b))
        return True

    if isinstance(a, (bool, float, int, np.float32)):
        decimal = 5

        # deal with differing dtypes
        if check_less_precise:
            dtype_a = np.dtype(type(a))
            dtype_b = np.dtype(type(b))
            # Both sides are tested by dtype *kind*; comparing the dtype
            # object of ``b`` itself against 'f' was inconsistent with the
            # check on ``a``.
            if dtype_a.kind == 'f' and dtype_b.kind == 'f':
                if dtype_a.itemsize <= 4 and dtype_b.itemsize <= 4:
                    decimal = 3

        if np.isinf(a):
            assert np.isinf(b), _err_msg(a, b)

        # case for zero
        elif abs(a) < 1e-5:
            np.testing.assert_almost_equal(
                a, b, decimal=decimal, err_msg=_err_msg(a, b), verbose=False)
        else:
            # ``a`` is not close to zero here, so a zero ``b`` is a genuine
            # mismatch: report it as an assertion failure rather than
            # letting ZeroDivisionError escape.
            try:
                ratio = a / b
            except ZeroDivisionError:
                raise AssertionError(_err_msg(a, b))
            np.testing.assert_almost_equal(
                1, ratio, decimal=decimal, err_msg=_err_msg(a, b),
                verbose=False)
    else:
        assert a == b, "%s != %s" % (a, b)

    return True


def is_sorted(seq):
    """Compare ``seq`` with a sorted array copy of itself.

    The comparison is delegated to ``assert_almost_equal`` and its result
    is returned unchanged.
    """
    sorted_copy = np.sort(np.array(seq))
    return assert_almost_equal(seq, sorted_copy)


def assert_dict_equal(a, b, compare_keys=True):
    """
    Assert that two dict-like objects hold almost-equal values.

    Every key of ``a`` is looked up in both ``a`` and ``b`` and the two
    values are compared with ``assert_almost_equal``.

    Parameters
    ----------
    a, b : objects providing ``keys()`` and ``[]`` lookup
    compare_keys : bool, default True
        When true, additionally assert that both key sets are identical.
    """
    keys_of_a = frozenset(a.keys())
    keys_of_b = frozenset(b.keys())

    if compare_keys:
        assert(keys_of_a == keys_of_b)

    # Only the keys of ``a`` drive the value comparison.
    for key in keys_of_a:
        assert_almost_equal(a[key], b[key])


def assert_series_equal(left, right, check_dtype=True,
check_index_type=False,
check_series_type=False,
Expand Down
10 changes: 9 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,8 @@ class CheckSDist(sdist):
'pandas/index.pyx',
'pandas/algos.pyx',
'pandas/parser.pyx',
'pandas/src/sparse.pyx']
'pandas/src/sparse.pyx',
'pandas/src/testing.pyx']

def initialize_options(self):
sdist.initialize_options(self)
Expand Down Expand Up @@ -464,6 +465,13 @@ def pxd(name):

extensions.extend([sparse_ext])

# Extension module ``pandas._testing``; its single source is whatever
# ``srcpath('testing', suffix=suffix)`` resolves to (presumably the
# pandas/src/testing.pyx file listed for the sdist check - confirm).
testing_ext = Extension('pandas._testing',
sources=[srcpath('testing', suffix=suffix)],
include_dirs=[],
libraries=libraries)

# Register the new extension alongside the ones collected so far.
extensions.extend([testing_ext])

#----------------------------------------------------------------------
# msgpack stuff here

Expand Down