Skip to content

Commit 2d1e211

Browse files
committed
NF: allow truncated last track in trackvis file
From report on mailing list by Carolyn D. Langen - trackvis reader was giving TypeError when trying to read file where last track was shorter than declared in the 'n_points' data. Allow this situation for the last track with False argument to new `strict` kwarg of the trackvis `read` function. This is a slight API break because: * We are now raising a DataError if there are too few streamlines in the file, instead of a HeaderError; * We are raising a DataError if the track is truncated, rather than a TypeError when trying to create the points array.
1 parent 7fe2ebc commit 2d1e211

File tree

3 files changed

+79
-13
lines changed

3 files changed

+79
-13
lines changed

Changelog

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@ and Stephan Gerhard (SG).
2424

2525
References like "pr/298" refer to github pull request numbers.
2626

27+
* Upcoming
28+
29+
* Trackvis reader will now allow final streamline to have fewer points than
30+
the number declared in the header, with ``strict=False`` argument to
31+
``read`` function;
32+
* Minor API breakage in trackvis reader. We are now raising a DataError if
33+
there are too few streamlines in the file, instead of a HeaderError. We
34+
are raising a DataError if the track is truncated when ``strict=True``
35+
(the default), rather than a TypeError when trying to create the points
36+
array.
37+
2738
* 2.0.1 (Saturday 27 June 2015)
2839

2940
Contributions from Ben Cipollini, Chris Markiewicz, Alexandre Gramfort,

nibabel/tests/test_trackvis.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,3 +591,39 @@ def f(pts): # from vx to mm
591591
out_f.seek(0)
592592
tvf2 = tv.TrackvisFile.from_file(out_f, points_space='rasmm')
593593
assert_true(streamlist_equal(fancy_rasmm_streams, tvf2.streamlines))
594+
595+
596+
def test_read_truncated():
597+
# Test behavior when last track contains fewer points than specified
598+
out_f = BytesIO()
599+
xyz0 = np.tile(np.arange(5).reshape(5, 1), (1, 3))
600+
xyz1 = np.tile(np.arange(5).reshape(5, 1) + 10, (1, 3))
601+
streams = [(xyz0, None, None), (xyz1, None, None)]
602+
tv.write(out_f, streams, {})
603+
# Truncate the last stream by one point
604+
value = out_f.getvalue()[:-(3 * 4)]
605+
new_f = BytesIO(value)
606+
# By default, raises a DataError
607+
assert_raises(tv.DataError, tv.read, new_f)
608+
# This corresponds to strict mode
609+
new_f.seek(0)
610+
assert_raises(tv.DataError, tv.read, new_f, strict=True)
611+
# lenient error mode lets this error pass, with truncated track
612+
short_streams = [(xyz0, None, None), (xyz1[:-1], None, None)]
613+
new_f.seek(0)
614+
streams2, hdr = tv.read(new_f, strict=False)
615+
assert_true(streamlist_equal(streams2, short_streams))
616+
# Check that lenient works when number of tracks is 0, where 0 signals to
617+
# the reader to read until the end of the file.
618+
again_hdr = hdr.copy()
619+
assert_equal(again_hdr['n_count'], 2)
620+
again_hdr['n_count'] = 0
621+
again_bytes = again_hdr.tostring() + value[again_hdr.itemsize:]
622+
again_f = BytesIO(again_bytes)
623+
streams2, _ = tv.read(again_f, strict=False)
624+
assert_true(streamlist_equal(streams2, short_streams))
625+
# Set count to one above actual number of tracks, always raise error
626+
again_hdr['n_count'] = 3
627+
again_bytes = again_hdr.tostring() + value[again_hdr.itemsize:]
628+
again_f = BytesIO(again_bytes)
629+
assert_raises(tv.DataError, tv.read, again_f, strict=False)

nibabel/trackvis.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ class DataError(Exception):
9595
"""
9696

9797

98-
def read(fileobj, as_generator=False, points_space=None):
99-
''' Read trackvis file, return streamlines, header
98+
def read(fileobj, as_generator=False, points_space=None, strict=True):
99+
''' Read trackvis file from `fileobj`, return `streamlines`, `header`
100100
101101
Parameters
102102
----------
@@ -116,6 +116,9 @@ def read(fileobj, as_generator=False, points_space=None):
116116
voxel size. If 'rasmm' we'll convert the points to RAS mm space (real
117117
space). For 'rasmm' we check if the affine is set and matches the voxel
118118
sizes and voxel order.
119+
strict : {True, False}, optional
120+
If True, raise error on read for badly-formed file. If False, let pass
121+
files with last track having too few points.
119122
120123
Returns
121124
-------
@@ -192,22 +195,35 @@ def read(fileobj, as_generator=False, points_space=None):
192195
raise HeaderError('Unexpected negative n_count')
193196

194197
def track_gen():
195-
n_streams = 0
196198
# For case where there are no scalars or no properties
197199
scalars = None
198200
ps = None
199-
while True:
201+
n_streams = 0
202+
# stream_count == 0 signals read to end of file
203+
n_streams_required = stream_count if stream_count != 0 else np.inf
204+
end_of_file = False
205+
while not end_of_file and n_streams < n_streams_required:
200206
n_str = fileobj.read(4)
201207
if len(n_str) < 4:
202-
if stream_count:
203-
raise HeaderError(
204-
'Expecting %s points, found only %s' % (
205-
stream_count, n_streams))
206208
break
207209
n_pts = struct.unpack(i_fmt, n_str)[0]
208-
pts_str = fileobj.read(n_pts * pt_size)
210+
# Check if we got as many bytes as we expect for these points
211+
exp_len = n_pts * pt_size
212+
pts_str = fileobj.read(exp_len)
213+
if len(pts_str) != exp_len:
214+
# Short of bytes, should we raise an error or continue?
215+
actual_n_pts = int(len(pts_str) / pt_size)
216+
if actual_n_pts != n_pts:
217+
if strict == True:
218+
raise DataError('Expecting {0} points for stream {1}, '
219+
'found {2}'.format(
220+
n_pts, n_streams, actual_n_pts))
221+
n_pts = actual_n_pts
222+
end_of_file = True
223+
# Cast bytes to points array
209224
pts = np.ndarray(shape=(n_pts, pt_cols), dtype=f4dt,
210225
buffer=pts_str)
226+
# Add properties
211227
if n_p:
212228
ps_str = fileobj.read(ps_size)
213229
ps = np.ndarray(shape=(n_p,), dtype=f4dt, buffer=ps_str)
@@ -220,11 +236,14 @@ def track_gen():
220236
scalars = pts[:, 3:]
221237
yield (xyz, scalars, ps)
222238
n_streams += 1
223-
# deliberately misses case where stream_count is 0
224-
if n_streams == stream_count:
225-
fileobj.close_if_mine()
226-
raise StopIteration
239+
# Always close file if we opened it
227240
fileobj.close_if_mine()
241+
# Raise error if we didn't get as many streams as claimed
242+
if n_streams_required != np.inf and n_streams < n_streams_required:
243+
raise DataError(
244+
'Expecting {0} streamlines, found only {1}'.format(
245+
stream_count, n_streams))
246+
228247
streamlines = track_gen()
229248
if not as_generator:
230249
streamlines = list(streamlines)

0 commit comments

Comments
 (0)