@@ -263,6 +263,14 @@ class TSNRInputSpec(BaseInterfaceInputSpec):
263
263
in_file = InputMultiPath (File (exists = True ), mandatory = True ,
264
264
desc = 'realigned 4D file or a list of 3D files' )
265
265
regress_poly = traits .Range (low = 1 , desc = 'Remove polynomials' )
266
+ tsnr_file = File ('tsnr.nii.gz' , usedefault = True , hash_files = False ,
267
+ desc = 'output tSNR file' )
268
+ mean_file = File ('mean.nii.gz' , usedefault = True , hash_files = False ,
269
+ desc = 'output mean file' )
270
+ stddev_file = File ('stdev.nii.gz' , usedefault = True , hash_files = False ,
271
+ desc = 'output tSNR file' )
272
+ detrended_file = File ('detrend.nii.gz' , usedefault = True , hash_files = False ,
273
+ desc = 'input file after detrending' )
266
274
267
275
268
276
class TSNROutputSpec (TraitedSpec ):
@@ -288,24 +296,18 @@ class TSNR(BaseInterface):
288
296
input_spec = TSNRInputSpec
289
297
output_spec = TSNROutputSpec
290
298
291
- def _gen_output_file_name (self , suffix = None ):
292
- _ , base , ext = split_filename (self .inputs .in_file [0 ])
293
- if suffix in ['mean' , 'stddev' ]:
294
- return os .path .abspath (base + "_tsnr_" + suffix + ext )
295
- elif suffix in ['detrended' ]:
296
- return os .path .abspath (base + "_" + suffix + ext )
297
- else :
298
- return os .path .abspath (base + "_tsnr" + ext )
299
-
300
299
def _run_interface (self , runtime ):
301
300
img = nb .load (self .inputs .in_file [0 ])
302
301
header = img .header .copy ()
303
302
vollist = [nb .load (filename ) for filename in self .inputs .in_file ]
304
303
data = np .concatenate ([vol .get_data ().reshape (
305
- vol .shape [:3 ] + (- 1 ,)) for vol in vollist ], axis = 3 )
304
+ vol .get_shape ()[:3 ] + (- 1 ,)) for vol in vollist ], axis = 3 )
305
+ data = data .nan_to_num ()
306
+
306
307
if data .dtype .kind == 'i' :
307
308
header .set_data_dtype (np .float32 )
308
309
data = data .astype (np .float32 )
310
+
309
311
if isdefined (self .inputs .regress_poly ):
310
312
timepoints = img .shape [- 1 ]
311
313
X = np .ones ((timepoints , 1 ))
@@ -318,26 +320,28 @@ def _run_interface(self, runtime):
318
320
betas [1 :, :, :, :], 0 , 3 )),
319
321
0 , 4 )
320
322
data = data - datahat
321
- img = nb .Nifti1Image (data , img .affine , header )
322
- nb .save (img , self ._gen_output_file_name ('detrended' ))
323
+ img = nb .Nifti1Image (data , img .get_affine (), header )
324
+ nb .save (img , op .abspath (self .inputs .detrended_file ))
325
+
323
326
meanimg = np .mean (data , axis = 3 )
324
327
stddevimg = np .std (data , axis = 3 )
325
- tsnr = meanimg / stddevimg
326
- img = nb .Nifti1Image (tsnr , img .affine , header )
327
- nb .save (img , self ._gen_output_file_name ())
328
- img = nb .Nifti1Image (meanimg , img .affine , header )
329
- nb .save (img , self ._gen_output_file_name ('mean' ))
330
- img = nb .Nifti1Image (stddevimg , img .affine , header )
331
- nb .save (img , self ._gen_output_file_name ('stddev' ))
328
+ tsnr = np .zeros_like (meanimg )
329
+ tsnr [stddevimg > 1.e-3 ] = meanimg [stddevimg > 1.e-3 ] / stddevimg [stddevimg > 1.e-3 ]
330
+ img = nb .Nifti1Image (tsnr , img .get_affine (), header )
331
+ nb .save (img , op .abspath (self .inputs .tsnr_file ))
332
+ img = nb .Nifti1Image (meanimg , img .get_affine (), header )
333
+ nb .save (img , op .abspath (self .inputs .mean_file ))
334
+ img = nb .Nifti1Image (stddevimg , img .get_affine (), header )
335
+ nb .save (img , op .abspath (self .inputs .stddev_file ))
332
336
return runtime
333
337
334
338
def _list_outputs (self ):
335
339
outputs = self ._outputs ().get ()
336
- outputs ['tsnr_file' ] = self . _gen_output_file_name ()
337
- outputs ['mean_file' ] = self . _gen_output_file_name ( 'mean' )
338
- outputs [ 'stddev_file' ] = self . _gen_output_file_name ( 'stddev' )
340
+ for k in ['tsnr_file' , 'mean_file' , 'stddev_file' ]:
341
+ outputs [k ] = op . abspath ( getattr ( self . inputs , k ) )
342
+
339
343
if isdefined (self .inputs .regress_poly ):
340
- outputs ['detrended_file' ] = self . _gen_output_file_name ( 'detrended' )
344
+ outputs ['detrended_file' ] = op . abspath ( self . inputs . detrended_file )
341
345
return outputs
342
346
343
347
0 commit comments