Skip to content

Commit 053a780

Browse files
committed
Fix channel reordering
1 parent c67345d commit 053a780

File tree

2 files changed

+21
-15
lines changed

2 files changed

+21
-15
lines changed

src/dhn_med_py/med_session.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -722,9 +722,8 @@ def read_by_time(self, start_time, end_time, channels=None):
722722

723723
# Make sure the data is in requested channels order
724724
if type(channels) is list:
725-
channel_indices = [channel_names.index(x) for x in channels]
726-
reordered_pos = sorted(range(len(channel_indices)), key=lambda x: channel_indices[x])
727-
data = [data[i] for i in reordered_pos]
725+
channel_indices = [channel_names.index(ch) for ch in channels]
726+
data = [data[i] for i in channel_indices]
728727

729728
self.set_channel_active(channel_names, False)
730729
self.set_channel_active(curr_active_channels, True)
@@ -801,9 +800,8 @@ def read_by_index(self, start_idx, end_idx, channels=None):
801800

802801
# Make sure the data is in requested channels order
803802
if type(channels) is list:
804-
channel_indices = [channel_names.index(x) for x in channels]
805-
reordered_pos = sorted(range(len(channel_indices)), key=lambda x: channel_indices[x])
806-
data = [data[i] for i in reordered_pos]
803+
channel_indices = [channel_names.index(ch) for ch in channels]
804+
data = [data[i] for i in channel_indices]
807805

808806
self.set_channel_active(channel_names, False)
809807
self.set_channel_active(curr_active_channels, True)

test/dhn_med_py_test.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"""
2020

2121
# Standard library imports
22-
import unittest
22+
import unittest, random
2323

2424
# Third party imports
2525
import numpy as np
@@ -261,12 +261,16 @@ def test_read_session(self):
261261
assert len(data) == len(channel_names)
262262
assert len(data[ref_index]) == int(10 * chan_fs)
263263

264-
# Read by index reverted order channels specified
265-
rev_data = ms.read_by_index(0, int(10*chan_fs), channel_names[::-1])
264+
# Read by index mixed order channels specified
265+
mixed_order = channel_names.copy()
266+
random.shuffle(mixed_order)
267+
268+
rev_data = ms.read_by_index(0, int(10*chan_fs), mixed_order)
266269

267270
assert len(rev_data) == len(channel_names)
268-
assert len(rev_data[-1]) == len(data[0])
269-
np.testing.assert_array_equal(rev_data[-1], data[0])
271+
for i, ch in enumerate(mixed_order):
272+
ch_idx = channel_names.index(ch)
273+
np.testing.assert_array_equal(rev_data[i], data[ch_idx])
270274

271275
# Read by index - no channels specified
272276
data = ms.read_by_index(0, int(10*chan_fs), None)
@@ -302,12 +306,16 @@ def test_read_session(self):
302306
assert len(data) == len(channel_names)
303307
assert len(data[ref_index]) == int(10 * ref_fs) + 5 # TODO: why +5?
304308

305-
# Read by time reverted order channels specified
306-
rev_data = ms.read_by_time(start_time, start_time + 10 * 1000000, channel_names[::-1])
309+
# Read by time mixed order channels specified
310+
mixed_order = channel_names.copy()
311+
random.shuffle(mixed_order)
312+
313+
rev_data = ms.read_by_time(start_time, start_time + 10 * 1000000, mixed_order)
307314

308315
assert len(rev_data) == len(channel_names)
309-
assert len(rev_data[-1]) == len(data[0])
310-
np.testing.assert_array_equal(rev_data[-1], data[0])
316+
for i, ch in enumerate(mixed_order):
317+
ch_idx = channel_names.index(ch)
318+
np.testing.assert_array_equal(rev_data[i], data[ch_idx])
311319

312320
# Read by time - no channels specified
313321
data = ms.read_by_time(start_time, start_time + 10 * 1000000)

0 commit comments

Comments
 (0)