
Commit d0fe3c3

infer signal length for headers without the field. Fixes MIT-LCP#129
1 parent ea29e8f commit d0fe3c3

File tree

5 files changed, +151 -65 lines changed


tests/test_record.py

Lines changed: 20 additions & 2 deletions
@@ -6,7 +6,7 @@
 import wfdb
 
 
-class test_record():
+class TestRecord():
     """
     Testing read and write of wfdb records, including Physionet
     streaming.
@@ -511,7 +511,25 @@ def tearDownClass(cls):
            os.remove(file)
 
 
-class test_download():
+class TestSignal():
+    """
+    For lower level signal tests
+
+    """
+    def test_infer_sig_len(self):
+        """
+        Infer the signal length of a record without the sig_len header
+        Read two headers. The records should be the same.
+        """
+
+        record = wfdb.rdrecord('sample-data/100')
+        record_2 = wfdb.rdrecord('sample-data/100-no-len')
+        record_2.record_name = '100'
+
+        assert record_2.__eq__(record)
+
+
+class TestDownload():
     # Test that we can download records with no "dat" file
     # Regression test for https://github.com/MIT-LCP/wfdb-python/issues/118
     def test_dl_database_no_dat_file(self):
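The new test depends on a second sample header, sample-data/100-no-len.hea, whose contents are not shown in this diff. In WFDB header syntax the record line is `record_name n_sig fs sig_len`, with sig_len optional, so the fixture plausibly differs from 100.hea only in its record line. The -no-len contents below are an assumption for illustration, not part of the commit:

```
# sample-data/100.hea: record line includes the optional sig_len (650000)
100 2 360 650000
... two signal specification lines ...

# assumed sample-data/100-no-len.hea: identical, but sig_len omitted
100 2 360
... the same two signal specification lines ...
```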

wfdb/io/_signal.py

Lines changed: 39 additions & 5 deletions
@@ -21,10 +21,14 @@
 DAT_FMTS = ALIGNED_FMTS + UNALIGNED_FMTS
 
 # Bytes required to hold each sample (including wasted space) for each
-# wfdb formats
+# wfdb dat formats
 BYTES_PER_SAMPLE = {'8': 1, '16': 2, '24': 3, '32': 4, '61': 2, '80': 1,
                     '160': 2, '212': 1.5, '310': 4 / 3., '311': 4 / 3.}
 
+# The bit resolution of each wfdb dat format
+BIT_RES = {'8': 8, '16': 16, '24': 24, '32': 32, '61': 16, '80': 8,
+           '160': 16, '212': 12, '310': 10, '311': 10}
+
 # Numpy dtypes used to load dat files of each format.
 DATA_LOAD_TYPES = {'8': '<i1', '16': '<i2', '24': '<i3', '32': '<i4',
                    '61': '>i2', '80': '<u1', '160': '<u2', '212': '<u1',
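The new table exists because bit resolution cannot be derived from BYTES_PER_SAMPLE for the packed formats: the old `BYTES_PER_SAMPLE[fmt] * 8` rule yields 10.67 for fmts 310/311, whereas the true sample resolution is 10 bits. A standalone sanity check of the relationship between the two tables (a sketch, not code from the commit):

```python
# Storage footprint vs. true sample resolution for WFDB dat formats
BYTES_PER_SAMPLE = {'8': 1, '16': 2, '24': 3, '32': 4, '61': 2, '80': 1,
                    '160': 2, '212': 1.5, '310': 4 / 3., '311': 4 / 3.}
BIT_RES = {'8': 8, '16': 16, '24': 24, '32': 32, '61': 16, '80': 8,
           '160': 16, '212': 12, '310': 10, '311': 10}

for fmt, n_bytes in BYTES_PER_SAMPLE.items():
    # Resolution never exceeds the bits actually stored; formats 310/311
    # pack three 10-bit samples into 4 bytes, wasting bits per frame.
    assert BIT_RES[fmt] <= n_bytes * 8
```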
@@ -860,7 +864,7 @@ def _rd_segment(file_name, dir_name, pb_dir, fmt, n_sig, sig_len, byte_offset,
     # Return uniform numpy array
     if smooth_frames or sum(samps_per_frame) == n_sig:
         # Figure out the largest required dtype for the segment to minimize memory usage
-        max_dtype = np_dtype(_fmt_res(fmt, max_res=True), discrete=True)
+        max_dtype = _np_dtype(_fmt_res(fmt, max_res=True), discrete=True)
         # Allocate signal array. Minimize dtype
         signals = np.zeros([sampto-sampfrom, len(channels)], dtype=max_dtype)
 
@@ -1560,7 +1564,8 @@ def _wfdb_fmt(bit_res, single_fmt=True):
 
 def _fmt_res(fmt, max_res=False):
     """
-    Return the resolution of the WFDB dat format(s).
+    Return the resolution of the WFDB dat format(s). Uses the BIT_RES
+    dictionary, but accepts lists and other options.
 
     Parameters
     ----------
@@ -1585,10 +1590,10 @@ def _fmt_res(fmt, max_res=False):
         bit_res = [_fmt_res(f) for f in fmt]
         return bit_res
 
-    return BYTES_PER_SAMPLE[fmt] * 8
+    return BIT_RES[fmt]
 
 
-def np_dtype(bit_res, discrete):
+def _np_dtype(bit_res, discrete):
     """
     Given the bit resolution of a signal, return the minimum numpy dtype
     used to store it.
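Together the renamed helpers reduce dtype selection to two steps: take the coarsest resolution across the record's formats (`_fmt_res` with max_res=True), then map it to the smallest dtype that can hold it (`_np_dtype`). Since `_np_dtype`'s body is not shown in this diff, the rounding rule below is an assumption about its behavior:

```python
import numpy as np

BIT_RES = {'8': 8, '16': 16, '24': 24, '32': 32, '61': 16, '80': 8,
           '160': 16, '212': 12, '310': 10, '311': 10}

def fmt_res(fmt, max_res=False):
    # Mirrors _fmt_res: resolution of one format, or of a list of formats
    if isinstance(fmt, list):
        res = [fmt_res(f) for f in fmt]
        return max(res) if max_res else res
    return BIT_RES[fmt]

def np_dtype(bit_res, discrete=True):
    # Assumed rule: round up to the next integer width numpy supports
    for width in (8, 16, 32, 64):
        if bit_res <= width:
            return np.dtype('int%d' % width if discrete else 'float64')
    raise ValueError('Unsupported bit resolution: %d' % bit_res)

# A record mixing fmt 212 (12-bit) and fmt 16 (16-bit) loads as int16
assert np_dtype(fmt_res(['212', '16'], max_res=True)) == np.dtype('int16')
```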
@@ -1787,6 +1792,35 @@ def describe_list_indices(full_list):
     return unique_elements, element_indices
 
 
+def _infer_sig_len(file_name, fmt, n_sig, dir_name, pb_dir=None):
+    """
+    Infer the length of a signal from a dat file.
+
+    Parameters
+    ----------
+    file_name : str
+        Name of the dat file
+    fmt : str
+        WFDB fmt of the dat file
+    n_sig : int
+        Number of signals contained in the dat file
+
+    Notes
+    -----
+    sig_len * n_sig * bytes_per_sample == file_size
+
+    """
+    if pb_dir is None:
+        file_size = os.path.getsize(os.path.join(dir_name, file_name))
+    else:
+        url = posixpath.join(db_index_url, pb_dir, file_name)
+        file_size = download._remote_file_size(file_name=file_name,
+                                               pb_dir=pb_dir)
+
+    sig_len = int(file_size / (BYTES_PER_SAMPLE[fmt] * n_sig))
+
+    return sig_len
+
 def downround(x, base):
     """
     Round <x> down to nearest <base>
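The Notes relation makes the inference one division. A worked check using the standard MIT-BIH record 100 layout (2 signals, fmt 212, 650000 samples per signal); this is a sketch of the arithmetic rather than the library call:

```python
BYTES_PER_SAMPLE = {'212': 1.5}  # fmt 212: two 12-bit samples per 3 bytes

def infer_sig_len(file_size, fmt, n_sig):
    # Invert sig_len * n_sig * bytes_per_sample == file_size
    return int(file_size / (BYTES_PER_SAMPLE[fmt] * n_sig))

# 100.dat should be 650000 samples * 2 signals * 1.5 bytes = 1950000 bytes
assert infer_sig_len(1950000, '212', n_sig=2) == 650000
```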

wfdb/io/download.py

Lines changed: 75 additions & 40 deletions
@@ -6,9 +6,44 @@
 import requests
 
 
-db_index_url = 'http://physionet.org/physiobank/database/'
+DB_INDEX_URL = 'http://physionet.org/physiobank/database/'
 
 
+def _remote_file_size(url=None, file_name=None, pb_dir=None):
+    """
+    Get the remote file size in bytes
+
+    Parameters
+    ----------
+    url : str, optional
+        The full url of the file. Use this option to explicitly
+        state the full url.
+    file_name : str, optional
+        The base file name. Use this argument along with pb_dir if you
+        want the full url to be constructed.
+    pb_dir : str, optional
+        The physiobank directory. Use this argument along with file_name
+        if you want the full url to be constructed.
+
+    Returns
+    -------
+    remote_file_size : int
+        Size of the file in bytes
+
+    """
+
+    # Option to construct the url
+    if file_name and pb_dir:
+        url = posixpath.join(DB_INDEX_URL, pb_dir, file_name)
+
+    response = requests.head(url, headers={'Accept-Encoding': 'identity'})
+    # Raise HTTPError if invalid url
+    response.raise_for_status()
+
+    # Supposed size of the file
+    remote_file_size = int(response.headers['content-length'])
+
+    return remote_file_size
 
 def _stream_header(file_name, pb_dir):
     """
@@ -25,14 +60,14 @@ def _stream_header(file_name, pb_dir):
 
     """
     # Full url of header location
-    url = posixpath.join(db_index_url, pb_dir, file_name)
-    r = requests.get(url)
+    url = posixpath.join(DB_INDEX_URL, pb_dir, file_name)
+    response = requests.get(url)
 
     # Raise HTTPError if invalid url
-    r.raise_for_status()
+    response.raise_for_status()
 
     # Get each line as a string
-    filelines = r.content.decode('iso-8859-1').splitlines()
+    filelines = response.content.decode('iso-8859-1').splitlines()
 
     # Separate content into header and comment lines
     header_lines = []
@@ -82,7 +117,7 @@ def _stream_dat(file_name, pb_dir, byte_count, start_byte, dtype):
     """
 
     # Full url of dat file
-    url = posixpath.join(db_index_url, pb_dir, file_name)
+    url = posixpath.join(DB_INDEX_URL, pb_dir, file_name)
 
     # Specify the byte range
     end_byte = start_byte + byte_count - 1
@@ -114,7 +149,7 @@ def _stream_annotation(file_name, pb_dir):
 
     """
     # Full url of annotation file
-    url = posixpath.join(db_index_url, pb_dir, file_name)
+    url = posixpath.join(DB_INDEX_URL, pb_dir, file_name)
 
     # Get the content
     response = requests.get(url)
@@ -136,10 +171,10 @@ def get_dbs():
     >>> dbs = get_dbs()
 
     """
-    url = posixpath.join(db_index_url, 'DBS')
-    r = requests.get(url)
+    url = posixpath.join(DB_INDEX_URL, 'DBS')
+    response = requests.get(url)
 
-    dbs = r.content.decode('ascii').splitlines()
+    dbs = response.content.decode('ascii').splitlines()
     dbs = [re.sub('\t{2,}', '\t', line).split('\t') for line in dbs]
 
     return dbs
@@ -166,7 +201,7 @@ def get_record_list(db_dir, records='all'):
 
     """
     # Full url physiobank database
-    db_url = posixpath.join(db_index_url, db_dir)
+    db_url = posixpath.join(DB_INDEX_URL, db_dir)
 
     # Check for a RECORDS file
     if records == 'all':
@@ -175,18 +210,18 @@
             raise ValueError('The database %s has no WFDB files to download' % db_url)
 
         # Get each line as a string
-        recordlist = response.content.decode('ascii').splitlines()
+        record_list = response.content.decode('ascii').splitlines()
     # Otherwise the records are input manually
     else:
-        recordlist = records
+        record_list = records
 
-    return recordlist
+    return record_list
 
 
 def get_annotators(db_dir, annotators):
 
     # Full url physiobank database
-    db_url = posixpath.join(db_index_url, db_dir)
+    db_url = posixpath.join(DB_INDEX_URL, db_dir)
 
     if annotators is not None:
         # Check for an ANNOTATORS file
@@ -197,61 +232,61 @@
         else:
             raise ValueError('The database %s has no annotation files to download' % db_url)
         # Make sure the input annotators are present in the database
-        annlist = r.content.decode('ascii').splitlines()
-        annlist = [a.split('\t')[0] for a in annlist]
+        ann_list = r.content.decode('ascii').splitlines()
+        ann_list = [a.split('\t')[0] for a in ann_list]
 
         # Get the annotation file types required
         if annotators == 'all':
             # all possible ones
-            annotators = annlist
+            annotators = ann_list
         else:
             # In case they didn't input a list
             if type(annotators) == str:
                 annotators = [annotators]
             # user input ones. Check validity.
             for a in annotators:
-                if a not in annlist:
+                if a not in ann_list:
                     raise ValueError('The database contains no annotators with extension: %s' % a)
 
     return annotators
 
 
-def make_local_dirs(dl_dir, dlinputs, keep_subdirs):
+def make_local_dirs(dl_dir, dl_inputs, keep_subdirs):
     """
     Make any required local directories to prepare for downloading
     """
 
     # Make the local download dir if it doesn't exist
     if not os.path.isdir(dl_dir):
         os.makedirs(dl_dir)
-        print("Created local base download directory: ", dl_dir)
+        print('Created local base download directory: %s' % dl_dir)
     # Create all required local subdirectories
     # This must be out of dl_pb_file to
     # avoid clash in multiprocessing
     if keep_subdirs:
-        dldirs = set([os.path.join(dl_dir, d[1]) for d in dlinputs])
-        for d in dldirs:
+        dl_dirs = set([os.path.join(dl_dir, d[1]) for d in dl_inputs])
+        for d in dl_dirs:
             if not os.path.isdir(d):
                 os.makedirs(d)
     return
 
 
 def dl_pb_file(inputs):
-    # Download a file from physiobank
-    # The input args are to be unpacked for the use of multiprocessing
+    """
+    Download a file from physiobank.
+
+    The input args are to be unpacked for the use of multiprocessing
+    map, because python2 doesn't have starmap...
+
+    """
 
     basefile, subdir, db, dl_dir, keep_subdirs, overwrite = inputs
 
     # Full url of file
-    url = posixpath.join(db_index_url, db, subdir, basefile)
-
-    # Send a head request
-    response = requests.head(url, headers={'Accept-Encoding': 'identity'})
-    # Raise HTTPError if invalid url
-    response.raise_for_status()
+    url = posixpath.join(DB_INDEX_URL, db, subdir, basefile)
 
     # Supposed size of the file
-    remote_file_size = int(response.headers['content-length'])
+    remote_file_size = _remote_file_size(url)
 
     # Figure out where the file should be locally
     if keep_subdirs:
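The single `inputs` tuple in dl_pb_file is a workaround for multiprocessing.Pool.map passing exactly one argument per task; Python 3's starmap would unpack tuples itself, but Python 2 support rules it out. The pattern in miniature (a toy worker, not the real downloader):

```python
import multiprocessing

def worker(inputs):
    # Unpack the single tuple argument, since pool.map cannot starmap
    basefile, subdir, overwrite = inputs
    return '%s/%s (overwrite=%s)' % (subdir, basefile, overwrite)

if __name__ == '__main__':
    tasks = [('100.dat', 'mitdb', False), ('100.hea', 'mitdb', False)]
    pool = multiprocessing.Pool(processes=2)
    print(pool.map(worker, tasks))
```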
@@ -276,7 +311,7 @@ def dl_pb_file(inputs):
             r = requests.get(url, headers=headers, stream=True)
             print('headers: ', headers)
             print('r content length: ', len(r.content))
-            with open(local_file, "ba") as writefile:
+            with open(local_file, 'ba') as writefile:
                 writefile.write(r.content)
             print('Done appending.')
         # Local file is larger than it should be. Redownload.
@@ -304,7 +339,7 @@ def dl_full_file(url, save_file_name):
 
     """
     response = requests.get(url)
-    with open(save_file_name, "wb") as writefile:
+    with open(save_file_name, 'wb') as writefile:
         writefile.write(response.content)
 
     return
@@ -346,22 +381,22 @@ def dl_files(db, dl_dir, files, keep_subdirs=True, overwrite=False):
     """
 
     # Full url physiobank database
-    db_url = posixpath.join(db_index_url, db)
+    db_url = posixpath.join(DB_INDEX_URL, db)
     # Check if the database is valid
-    r = requests.get(db_url)
-    r.raise_for_status()
+    response = requests.get(db_url)
+    response.raise_for_status()
 
     # Construct the urls to download
-    dlinputs = [(os.path.split(file)[1], os.path.split(file)[0], db, dl_dir, keep_subdirs, overwrite) for file in files]
+    dl_inputs = [(os.path.split(file)[1], os.path.split(file)[0], db, dl_dir, keep_subdirs, overwrite) for file in files]
 
     # Make any required local directories
-    make_local_dirs(dl_dir, dlinputs, keep_subdirs)
+    make_local_dirs(dl_dir, dl_inputs, keep_subdirs)
 
     print('Downloading files...')
     # Create multiple processes to download files.
     # Limit to 2 connections to avoid overloading the server
     pool = multiprocessing.Pool(processes=2)
-    pool.map(dl_pb_file, dlinputs)
+    pool.map(dl_pb_file, dl_inputs)
     print('Finished downloading files')
 
     return
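For context, a plausible call into the refactored entry point, matching the signature above; the file names are illustrative:

```python
from wfdb.io import download

# Fetch two files from the mitdb database into ./mitdb-local,
# keeping the remote subdirectory layout
download.dl_files(db='mitdb', dl_dir='mitdb-local',
                  files=['100.dat', '100.hea'], keep_subdirs=True)
```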
