Skip to content

Commit

Permalink
refactor container core tests
Browse files Browse the repository at this point in the history
  • Loading branch information
toddrjen committed Mar 28, 2014
1 parent 978bba4 commit 6ca3712
Show file tree
Hide file tree
Showing 16 changed files with 2,049 additions and 1,689 deletions.
2 changes: 1 addition & 1 deletion neo/core/epocharray.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __repr__(self):

objs = ['%s@%s for %s' % (label, time, dur) for
label, time, dur in zip(labels, self.times, self.durations)]
return '<EventArray: %s>' % ', '.join(objs)
return '<EpochArray: %s>' % ', '.join(objs)

def merge(self, other):
'''
Expand Down
113 changes: 102 additions & 11 deletions neo/test/coretest/test_analogsignal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import os
import pickle
from pprint import pformat

try:
import unittest2 as unittest
Expand All @@ -18,11 +17,90 @@
import numpy as np
import quantities as pq

try:
from IPython.lib.pretty import pretty
except ImportError as err:
HAVE_IPYTHON = False
else:
HAVE_IPYTHON = True

from neo.core.analogsignal import AnalogSignal, _get_sampling_rate
from neo.core import Segment, RecordingChannel
from neo.test.tools import (assert_arrays_almost_equal, assert_arrays_equal,
assert_neo_object_is_compliant,
assert_same_sub_schema)
from neo.test.generate_datasets import (get_fake_value, get_fake_values,
fake_neo, TEST_ANNOTATIONS)


class Test__generate_datasets(unittest.TestCase):
def setUp(self):
np.random.seed(0)
self.annotations = dict([(str(x), TEST_ANNOTATIONS[x]) for x in
range(len(TEST_ANNOTATIONS))])

def test__get_fake_values(self):
self.annotations['seed'] = 0
signal = get_fake_value('signal', pq.Quantity, seed=0, dim=1)
sampling_rate = get_fake_value('sampling_rate', pq.Quantity,
seed=1, dim=0)
t_start = get_fake_value('t_start', pq.Quantity, seed=2, dim=0)
channel_index = get_fake_value('channel_index', int, seed=3)
name = get_fake_value('name', str, seed=4, obj=AnalogSignal)
description = get_fake_value('description', str, seed=5,
obj='AnalogSignal')
file_origin = get_fake_value('file_origin', str)
attrs1 = {'channel_index': channel_index,
'name': name,
'description': description,
'file_origin': file_origin}
attrs2 = attrs1.copy()
attrs2.update(self.annotations)

res11 = get_fake_values(AnalogSignal, annotate=False, seed=0)
res12 = get_fake_values('AnalogSignal', annotate=False, seed=0)
res21 = get_fake_values(AnalogSignal, annotate=True, seed=0)
res22 = get_fake_values('AnalogSignal', annotate=True, seed=0)

assert_arrays_equal(res11.pop('signal'), signal)
assert_arrays_equal(res12.pop('signal'), signal)
assert_arrays_equal(res21.pop('signal'), signal)
assert_arrays_equal(res22.pop('signal'), signal)

assert_arrays_equal(res11.pop('sampling_rate'), sampling_rate)
assert_arrays_equal(res12.pop('sampling_rate'), sampling_rate)
assert_arrays_equal(res21.pop('sampling_rate'), sampling_rate)
assert_arrays_equal(res22.pop('sampling_rate'), sampling_rate)

assert_arrays_equal(res11.pop('t_start'), t_start)
assert_arrays_equal(res12.pop('t_start'), t_start)
assert_arrays_equal(res21.pop('t_start'), t_start)
assert_arrays_equal(res22.pop('t_start'), t_start)

self.assertEqual(res11, attrs1)
self.assertEqual(res12, attrs1)
self.assertEqual(res21, attrs2)
self.assertEqual(res22, attrs2)

def test__fake_neo__cascade(self):
self.annotations['seed'] = None
obj_type = AnalogSignal
cascade = True
res = fake_neo(obj_type=obj_type, cascade=cascade)

self.assertTrue(isinstance(res, AnalogSignal))
assert_neo_object_is_compliant(res)
self.assertEqual(res.annotations, self.annotations)

def test__fake_neo__nocascade(self):
self.annotations['seed'] = None
obj_type = 'AnalogSignal'
cascade = False
res = fake_neo(obj_type=obj_type, cascade=cascade)

self.assertTrue(isinstance(res, AnalogSignal))
assert_neo_object_is_compliant(res)
self.assertEqual(res.annotations, self.annotations)


class TestAnalogSignalConstructor(unittest.TestCase):
Expand Down Expand Up @@ -190,16 +268,6 @@ def test__times_getter(self):
assert_neo_object_is_compliant(signal)
assert_arrays_almost_equal(signal.times, targ, 1e-12*pq.ms)

def test__pprint(self):
for i, signal in enumerate(self.signals):
prepr = pformat(signal)
targ = '<AnalogSignal(%s, [%s, %s], sampling rate: %s)>' % \
(pformat(self.data[i]),
self.t_start[i],
self.t_start[i] + len(self.data[i])/self.rates[i],
self.rates[i])
self.assertEqual(prepr, targ)

def test__duplicate_with_new_array(self):
signal1 = self.signals[1]
signal2 = self.signals[2]
Expand Down Expand Up @@ -256,6 +324,29 @@ def test__children(self):
signal.create_relationship()
assert_neo_object_is_compliant(signal)

def test__repr(self):
for i, signal in enumerate(self.signals):
prepr = repr(signal)
targ = '<AnalogSignal(%s, [%s, %s], sampling rate: %s)>' % \
(repr(self.data[i]),
self.t_start[i],
self.t_start[i] + len(self.data[i])/self.rates[i],
self.rates[i])
self.assertEqual(prepr, targ)

@unittest.skipUnless(HAVE_IPYTHON, "requires IPython")
def test__pretty(self):
for i, signal in enumerate(self.signals):
prepr = pretty(signal)
targ = (('AnalogSignal in %s with %s %s values\n' %
(signal.units, len(signal), signal.dtype)) +
('annotations: %s\n' % signal.annotations) +
('channel index: %s\n' % signal.channel_index) +
('sampling rate: %s\n' % signal.sampling_rate) +
('time: %s to %s' % (signal.t_start, signal.t_stop)))

self.assertEqual(prepr, targ)


class TestAnalogSignalArrayMethods(unittest.TestCase):
def setUp(self):
Expand Down
107 changes: 107 additions & 0 deletions neo/test/coretest/test_analogsignalarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,95 @@
import numpy as np
import quantities as pq

try:
from IPython.lib.pretty import pretty
except ImportError as err:
HAVE_IPYTHON = False
else:
HAVE_IPYTHON = True

from neo.core.analogsignalarray import AnalogSignalArray
from neo.core import AnalogSignal, Segment, RecordingChannelGroup
from neo.test.tools import (assert_arrays_almost_equal, assert_arrays_equal,
assert_neo_object_is_compliant,
assert_same_sub_schema)
from neo.test.generate_datasets import (get_fake_value, get_fake_values,
fake_neo, TEST_ANNOTATIONS)


class Test__generate_datasets(unittest.TestCase):
def setUp(self):
np.random.seed(0)
self.annotations = dict([(str(x), TEST_ANNOTATIONS[x]) for x in
range(len(TEST_ANNOTATIONS))])

def test__get_fake_values(self):
self.annotations['seed'] = 0
signal = get_fake_value('signal', pq.Quantity, seed=0, dim=2)
sampling_rate = get_fake_value('sampling_rate', pq.Quantity,
seed=1, dim=0)
t_start = get_fake_value('t_start', pq.Quantity, seed=2, dim=0)
channel_index = get_fake_value('channel_index', np.ndarray, seed=3,
dim=1, dtype='i')
name = get_fake_value('name', str, seed=4, obj=AnalogSignalArray)
description = get_fake_value('description', str, seed=5,
obj='AnalogSignalArray')
file_origin = get_fake_value('file_origin', str)
attrs1 = {'name': name,
'description': description,
'file_origin': file_origin}
attrs2 = attrs1.copy()
attrs2.update(self.annotations)

res11 = get_fake_values(AnalogSignalArray, annotate=False, seed=0)
res12 = get_fake_values('AnalogSignalArray', annotate=False, seed=0)
res21 = get_fake_values(AnalogSignalArray, annotate=True, seed=0)
res22 = get_fake_values('AnalogSignalArray', annotate=True, seed=0)

assert_arrays_equal(res11.pop('signal'), signal)
assert_arrays_equal(res12.pop('signal'), signal)
assert_arrays_equal(res21.pop('signal'), signal)
assert_arrays_equal(res22.pop('signal'), signal)

assert_arrays_equal(res11.pop('sampling_rate'), sampling_rate)
assert_arrays_equal(res12.pop('sampling_rate'), sampling_rate)
assert_arrays_equal(res21.pop('sampling_rate'), sampling_rate)
assert_arrays_equal(res22.pop('sampling_rate'), sampling_rate)

assert_arrays_equal(res11.pop('t_start'), t_start)
assert_arrays_equal(res12.pop('t_start'), t_start)
assert_arrays_equal(res21.pop('t_start'), t_start)
assert_arrays_equal(res22.pop('t_start'), t_start)

assert_arrays_equal(res11.pop('channel_index'), channel_index)
assert_arrays_equal(res12.pop('channel_index'), channel_index)
assert_arrays_equal(res21.pop('channel_index'), channel_index)
assert_arrays_equal(res22.pop('channel_index'), channel_index)

self.assertEqual(res11, attrs1)
self.assertEqual(res12, attrs1)
self.assertEqual(res21, attrs2)
self.assertEqual(res22, attrs2)

def test__fake_neo__cascade(self):
self.annotations['seed'] = None
obj_type = 'AnalogSignalArray'
cascade = True
res = fake_neo(obj_type=obj_type, cascade=cascade)

self.assertTrue(isinstance(res, AnalogSignalArray))
assert_neo_object_is_compliant(res)
self.assertEqual(res.annotations, self.annotations)

def test__fake_neo__nocascade(self):
self.annotations['seed'] = None
obj_type = AnalogSignalArray
cascade = False
res = fake_neo(obj_type=obj_type, cascade=cascade)

self.assertTrue(isinstance(res, AnalogSignalArray))
assert_neo_object_is_compliant(res)
self.assertEqual(res.annotations, self.annotations)


class TestAnalogSignalArrayConstructor(unittest.TestCase):
Expand Down Expand Up @@ -158,6 +242,29 @@ def test__children(self):
signal.create_relationship()
assert_neo_object_is_compliant(signal)

def test__repr(self):
for i, signal in enumerate(self.signals):
prepr = repr(signal)
targ = '<AnalogSignalArray(%s, [%s, %s], sampling rate: %s)>' % \
(repr(self.data[i]),
self.t_start[i],
self.t_start[i] + len(self.data[i])/self.rates[i],
self.rates[i])
self.assertEqual(prepr, targ)

@unittest.skipUnless(HAVE_IPYTHON, "requires IPython")
def test__pretty(self):
for signal in self.signals:
prepr = pretty(signal)
targ = (('AnalogSignalArray in %s with %sx%s %s values\n' %
(signal.units, signal.shape[0], signal.shape[1],
signal.dtype)) +
('channel index: %s\n' % signal.channel_index) +
('sampling rate: %s\n' % signal.sampling_rate) +
('time: %s to %s' % (signal.t_start, signal.t_stop)))

self.assertEqual(prepr, targ)


class TestAnalogSignalArrayArrayMethods(unittest.TestCase):
def setUp(self):
Expand Down
24 changes: 24 additions & 0 deletions neo/test/coretest/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
import numpy as np
import quantities as pq

try:
from IPython.lib.pretty import pretty
except ImportError as err:
HAVE_IPYTHON = False
else:
HAVE_IPYTHON = True

from neo.core.baseneo import (BaseNeo, _check_annotations,
merge_annotation, merge_annotations)
from neo.test.tools import assert_arrays_equal
Expand Down Expand Up @@ -89,19 +96,25 @@ def test_annotate(self):
base = BaseNeo()
base.annotate(test1=1, test2=1)
result1 = {'test1': 1, 'test2': 1}

self.assertDictEqual(result1, base.annotations)

base.annotate(test3=2, test4=3)
result2 = {'test3': 2, 'test4': 3}
result2a = dict(list(result1.items()) + list(result2.items()))

self.assertDictContainsSubset(result1, base.annotations)
self.assertDictContainsSubset(result2, base.annotations)
self.assertDictEqual(result2a, base.annotations)

base.annotate(test1=5, test2=8)
result3 = {'test1': 5, 'test2': 8}
result3a = dict(list(result3.items()) + list(result2.items()))

self.assertDictContainsSubset(result2, base.annotations)
self.assertDictContainsSubset(result3, base.annotations)
self.assertDictEqual(result3a, base.annotations)

self.assertNotEqual(base.annotations['test1'], result1['test1'])
self.assertNotEqual(base.annotations['test2'], result1['test2'])

Expand Down Expand Up @@ -911,5 +924,16 @@ class Foo(object):
self.assertRaises(ValueError, self.base.annotate, data=value)


@unittest.skipUnless(HAVE_IPYTHON, "requires IPython")
class Test_pprint(unittest.TestCase):
def test__pretty(self):
name = 'an object'
description = 'this is a test'
obj = BaseNeo(name=name, description=description)
res = pretty(obj)
targ = "BaseNeo name: '%s' description: '%s'" % (name, description)
self.assertEqual(res, targ)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 6ca3712

Please sign in to comment.