Skip to content

Commit f6c2195

Browse files
committed
enh: support for bayesian estimation outputs
1 parent 936cf91 commit f6c2195

File tree

2 files changed

+52
-34
lines changed

2 files changed

+52
-34
lines changed

nipype/interfaces/spm/model.py

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from ..base import (Bunch, traits, TraitedSpec, File, Directory,
3030
OutputMultiPath, InputMultiPath, isdefined)
3131
from .base import (SPMCommand, SPMCommandInputSpec,
32-
scans_for_fnames)
32+
scans_for_fnames, ImageFileSPM)
3333

3434
__docformat__ = 'restructuredtext'
3535
logger = logging.getLogger('interface')
@@ -192,16 +192,25 @@ class EstimateModelInputSpec(SPMCommandInputSpec):
192192

193193

194194
class EstimateModelOutputSpec(TraitedSpec):
195-
mask_image = File(exists=True,
195+
mask_image = ImageFileSPM(exists=True,
196196
desc='binary mask to constrain estimation')
197-
beta_images = OutputMultiPath(File(exists=True),
197+
beta_images = OutputMultiPath(ImageFileSPM(exists=True),
198198
desc='design parameter estimates')
199-
residual_image = File(exists=True,
199+
residual_image = ImageFileSPM(exists=True,
200200
desc='Mean-squared image of the residuals')
201-
residual_images = OutputMultiPath(File(exists=True),
201+
residual_images = OutputMultiPath(ImageFileSPM(exists=True),
202202
desc="individual residual images (requires `write_residuals`")
203-
RPVimage = File(exists=True, desc='Resels per voxel image')
203+
RPVimage = ImageFileSPM(exists=True, desc='Resels per voxel image')
204204
spm_mat_file = File(exists=True, desc='Updated SPM mat file')
205+
labels = ImageFileSPM(exists=True, desc="label file")
206+
SDerror = OutputMultiPath(ImageFileSPM(exists=True),
207+
desc="Images of the standard deviation of the error")
208+
ARcoef = OutputMultiPath(ImageFileSPM(exists=True),
209+
desc="Images of the AR coefficient")
210+
Cbetas = OutputMultiPath(ImageFileSPM(exists=True),
211+
desc="Images of the parameter posteriors")
212+
SDbetas = OutputMultiPath(ImageFileSPM(exists=True),
213+
desc="Images of the standard deviation of parameter posteriors")
205214

206215

207216
class EstimateModel(SPMCommand):
@@ -213,6 +222,7 @@ class EstimateModel(SPMCommand):
213222
--------
214223
>>> est = EstimateModel()
215224
>>> est.inputs.spm_mat_file = 'SPM.mat'
225+
>>> est.inputs.estimation_method = {'Classical': 1}
216226
>>> est.run() # doctest: +SKIP
217227
"""
218228
input_spec = EstimateModelInputSpec
@@ -243,34 +253,37 @@ def _parse_inputs(self):
243253

244254
def _list_outputs(self):
245255
outputs = self._outputs().get()
246-
pth, _ = os.path.split(self.inputs.spm_mat_file)
247-
spm12 = '12' in self.version.split('.')[0]
248-
if spm12:
249-
mask = os.path.join(pth, 'mask.nii')
250-
else:
251-
mask = os.path.join(pth, 'mask.img')
252-
outputs['mask_image'] = mask
256+
pth = os.path.dirname(self.inputs.spm_mat_file)
257+
outtype = 'nii' if '12' in self.version.split('.')[0] else 'img'
253258
spm = sio.loadmat(self.inputs.spm_mat_file, struct_as_record=False)
254-
betas = []
255-
for vbeta in spm['SPM'][0, 0].Vbeta[0]:
256-
betas.append(str(os.path.join(pth, vbeta.fname[0])))
257-
if betas:
258-
outputs['beta_images'] = betas
259-
if spm12:
260-
resms = os.path.join(pth, 'ResMS.nii')
261-
else:
262-
resms = os.path.join(pth, 'ResMS.img')
263-
outputs['residual_image'] = resms
264-
if spm12:
265-
rpv = os.path.join(pth, 'RPV.nii')
266-
else:
267-
rpv = os.path.join(pth, 'RPV.img')
268-
if self.inputs.write_residuals:
269-
outres = [x for x in glob(os.path.join(pth, 'Res_*'))]
270-
outputs['residual_images'] = outres
271-
outputs['RPVimage'] = rpv
272-
spm = os.path.join(pth, 'SPM.mat')
273-
outputs['spm_mat_file'] = spm
259+
260+
betas = [vbeta.fname[0] for vbeta in spm['SPM'][0, 0].Vbeta[0]]
261+
if ('Bayesian' in self.inputs.estimation_method.keys() or
262+
'Bayesian2' in self.inputs.estimation_method.keys()):
263+
outputs['labels'] = os.path.join(pth,
264+
'labels.{}'.format(outtype))
265+
outputs['SDerror'] = glob(os.path.join(pth, 'Sess*_SDerror*'))
266+
outputs['ARcoef'] = glob(os.path.join(pth, 'Sess*_AR_*'))
267+
if betas:
268+
outputs['Cbetas'] = [os.path.join(pth, 'C{}'.format(beta))
269+
for beta in betas]
270+
outputs['SDbetas'] = [os.path.join(pth, 'SD{}'.format(beta))
271+
for beta in betas]
272+
273+
if 'Classical' in self.inputs.estimation_method.keys():
274+
outputs['residual_image'] = os.path.join(pth,
275+
'ResMS.{}'.format(outtype))
276+
outputs['RPVimage'] = os.path.join(pth,
277+
'RPV.{}'.format(outtype))
278+
if self.inputs.write_residuals:
279+
outputs['residual_images'] = glob(os.path.join(pth, 'Res_*'))
280+
if betas:
281+
outputs['beta_images'] = [os.path.join(pth, beta)
282+
for beta in betas]
283+
284+
outputs['mask_image'] = os.path.join(pth,
285+
'mask.{}'.format(outtype))
286+
outputs['spm_mat_file'] = os.path.join(pth, 'SPM.mat')
274287
return outputs
275288

276289

nipype/interfaces/spm/tests/test_auto_EstimateModel.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,13 @@ def test_EstimateModel_inputs():
3434

3535

3636
def test_EstimateModel_outputs():
37-
output_map = dict(RPVimage=dict(),
37+
output_map = dict(ARcoef=dict(),
38+
Cbetas=dict(),
39+
RPVimage=dict(),
40+
SDbetas=dict(),
41+
SDerror=dict(),
3842
beta_images=dict(),
43+
labels=dict(),
3944
mask_image=dict(),
4045
residual_image=dict(),
4146
residual_images=dict(),

0 commit comments

Comments
 (0)