Skip to content

Commit 356b028

Browse files
committed
Merge pull request #1234 from oesteban/enh/DivisionWarningsTSNR
[ENH] Remove warnings in tSNR calculation
2 parents 267e42a + 7fe9d3c commit 356b028

File tree

1 file changed

+27
-23
lines changed

1 file changed

+27
-23
lines changed

nipype/algorithms/misc.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,14 @@ class TSNRInputSpec(BaseInterfaceInputSpec):
263263
in_file = InputMultiPath(File(exists=True), mandatory=True,
264264
desc='realigned 4D file or a list of 3D files')
265265
regress_poly = traits.Range(low=1, desc='Remove polynomials')
266+
tsnr_file = File('tsnr.nii.gz', usedefault=True, hash_files=False,
267+
desc='output tSNR file')
268+
mean_file = File('mean.nii.gz', usedefault=True, hash_files=False,
269+
desc='output mean file')
270+
stddev_file = File('stdev.nii.gz', usedefault=True, hash_files=False,
271+
desc='output tSNR file')
272+
detrended_file = File('detrend.nii.gz', usedefault=True, hash_files=False,
273+
desc='input file after detrending')
266274

267275

268276
class TSNROutputSpec(TraitedSpec):
@@ -288,24 +296,18 @@ class TSNR(BaseInterface):
288296
input_spec = TSNRInputSpec
289297
output_spec = TSNROutputSpec
290298

291-
def _gen_output_file_name(self, suffix=None):
292-
_, base, ext = split_filename(self.inputs.in_file[0])
293-
if suffix in ['mean', 'stddev']:
294-
return os.path.abspath(base + "_tsnr_" + suffix + ext)
295-
elif suffix in ['detrended']:
296-
return os.path.abspath(base + "_" + suffix + ext)
297-
else:
298-
return os.path.abspath(base + "_tsnr" + ext)
299-
300299
def _run_interface(self, runtime):
301300
img = nb.load(self.inputs.in_file[0])
302301
header = img.header.copy()
303302
vollist = [nb.load(filename) for filename in self.inputs.in_file]
304303
data = np.concatenate([vol.get_data().reshape(
305-
vol.shape[:3] + (-1,)) for vol in vollist], axis=3)
304+
vol.get_shape()[:3] + (-1,)) for vol in vollist], axis=3)
305+
data = data.nan_to_num()
306+
306307
if data.dtype.kind == 'i':
307308
header.set_data_dtype(np.float32)
308309
data = data.astype(np.float32)
310+
309311
if isdefined(self.inputs.regress_poly):
310312
timepoints = img.shape[-1]
311313
X = np.ones((timepoints, 1))
@@ -318,26 +320,28 @@ def _run_interface(self, runtime):
318320
betas[1:, :, :, :], 0, 3)),
319321
0, 4)
320322
data = data - datahat
321-
img = nb.Nifti1Image(data, img.affine, header)
322-
nb.save(img, self._gen_output_file_name('detrended'))
323+
img = nb.Nifti1Image(data, img.get_affine(), header)
324+
nb.save(img, op.abspath(self.inputs.detrended_file))
325+
323326
meanimg = np.mean(data, axis=3)
324327
stddevimg = np.std(data, axis=3)
325-
tsnr = meanimg / stddevimg
326-
img = nb.Nifti1Image(tsnr, img.affine, header)
327-
nb.save(img, self._gen_output_file_name())
328-
img = nb.Nifti1Image(meanimg, img.affine, header)
329-
nb.save(img, self._gen_output_file_name('mean'))
330-
img = nb.Nifti1Image(stddevimg, img.affine, header)
331-
nb.save(img, self._gen_output_file_name('stddev'))
328+
tsnr = np.zeros_like(meanimg)
329+
tsnr[stddevimg > 1.e-3] = meanimg[stddevimg > 1.e-3] / stddevimg[stddevimg > 1.e-3]
330+
img = nb.Nifti1Image(tsnr, img.get_affine(), header)
331+
nb.save(img, op.abspath(self.inputs.tsnr_file))
332+
img = nb.Nifti1Image(meanimg, img.get_affine(), header)
333+
nb.save(img, op.abspath(self.inputs.mean_file))
334+
img = nb.Nifti1Image(stddevimg, img.get_affine(), header)
335+
nb.save(img, op.abspath(self.inputs.stddev_file))
332336
return runtime
333337

334338
def _list_outputs(self):
335339
outputs = self._outputs().get()
336-
outputs['tsnr_file'] = self._gen_output_file_name()
337-
outputs['mean_file'] = self._gen_output_file_name('mean')
338-
outputs['stddev_file'] = self._gen_output_file_name('stddev')
340+
for k in ['tsnr_file', 'mean_file', 'stddev_file']:
341+
outputs[k] = op.abspath(getattr(self.inputs, k))
342+
339343
if isdefined(self.inputs.regress_poly):
340-
outputs['detrended_file'] = self._gen_output_file_name('detrended')
344+
outputs['detrended_file'] = op.abspath(self.inputs.detrended_file)
341345
return outputs
342346

343347

0 commit comments

Comments
 (0)