Skip to content

Commit a06fa8a

Browse files
authored
Merge pull request #2030 from mgxd/enh/estmodel
fix+enh: support for Bayesian spm.ModelEstimate, add write_residuals for Classical closes #2029 , closes #1786 .
2 parents b76f677 + f6c2195 commit a06fa8a

File tree

2 files changed

+72
-45
lines changed

2 files changed

+72
-45
lines changed

nipype/interfaces/spm/model.py

Lines changed: 63 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from ..base import (Bunch, traits, TraitedSpec, File, Directory,
3030
OutputMultiPath, InputMultiPath, isdefined)
3131
from .base import (SPMCommand, SPMCommandInputSpec,
32-
scans_for_fnames)
32+
scans_for_fnames, ImageFileSPM)
3333

3434
__docformat__ = 'restructuredtext'
3535
logger = logging.getLogger('interface')
@@ -179,27 +179,38 @@ def _list_outputs(self):
179179

180180
class EstimateModelInputSpec(SPMCommandInputSpec):
181181
spm_mat_file = File(exists=True, field='spmmat',
182-
desc='absolute path to SPM.mat',
183-
copyfile=True,
184-
mandatory=True)
185-
estimation_method = traits.Dict(traits.Enum('Classical', 'Bayesian2',
186-
'Bayesian'),
187-
field='method',
188-
desc=('Classical, Bayesian2, '
189-
'Bayesian (dict)'),
190-
mandatory=True)
191-
flags = traits.Str(desc='optional arguments (opt)')
182+
copyfile=True, mandatory=True,
183+
desc='Absolute path to SPM.mat')
184+
estimation_method = traits.Dict(
185+
traits.Enum('Classical', 'Bayesian2', 'Bayesian'),
186+
field='method', mandatory=True,
187+
desc=('Dictionary of either Classical: 1, Bayesian: 1, '
188+
'or Bayesian2: 1 (dict)'))
189+
write_residuals = traits.Bool(field='write_residuals',
190+
desc="Write individual residual images")
191+
flags = traits.Dict(desc='Additional arguments')
192192

193193

194194
class EstimateModelOutputSpec(TraitedSpec):
195-
mask_image = File(exists=True,
196-
desc='binary mask to constrain estimation')
197-
beta_images = OutputMultiPath(File(exists=True),
198-
desc='design parameter estimates')
199-
residual_image = File(exists=True,
200-
desc='Mean-squared image of the residuals')
201-
RPVimage = File(exists=True, desc='Resels per voxel image')
195+
mask_image = ImageFileSPM(exists=True,
196+
desc='binary mask to constrain estimation')
197+
beta_images = OutputMultiPath(ImageFileSPM(exists=True),
198+
desc='design parameter estimates')
199+
residual_image = ImageFileSPM(exists=True,
200+
desc='Mean-squared image of the residuals')
201+
residual_images = OutputMultiPath(ImageFileSPM(exists=True),
202+
desc="individual residual images (requires `write_residuals`")
203+
RPVimage = ImageFileSPM(exists=True, desc='Resels per voxel image')
202204
spm_mat_file = File(exists=True, desc='Updated SPM mat file')
205+
labels = ImageFileSPM(exists=True, desc="label file")
206+
SDerror = OutputMultiPath(ImageFileSPM(exists=True),
207+
desc="Images of the standard deviation of the error")
208+
ARcoef = OutputMultiPath(ImageFileSPM(exists=True),
209+
desc="Images of the AR coefficient")
210+
Cbetas = OutputMultiPath(ImageFileSPM(exists=True),
211+
desc="Images of the parameter posteriors")
212+
SDbetas = OutputMultiPath(ImageFileSPM(exists=True),
213+
desc="Images of the standard deviation of parameter posteriors")
203214

204215

205216
class EstimateModel(SPMCommand):
@@ -211,6 +222,7 @@ class EstimateModel(SPMCommand):
211222
--------
212223
>>> est = EstimateModel()
213224
>>> est.inputs.spm_mat_file = 'SPM.mat'
225+
>>> est.inputs.estimation_method = {'Classical': 1}
214226
>>> est.run() # doctest: +SKIP
215227
"""
216228
input_spec = EstimateModelInputSpec
@@ -225,7 +237,7 @@ def _format_arg(self, opt, spec, val):
225237
return np.array([str(val)], dtype=object)
226238
if opt == 'estimation_method':
227239
if isinstance(val, (str, bytes)):
228-
return {'%s' % val: 1}
240+
return {'{}'.format(val): 1}
229241
else:
230242
return val
231243
return super(EstimateModel, self)._format_arg(opt, spec, val)
@@ -235,36 +247,43 @@ def _parse_inputs(self):
235247
"""
236248
einputs = super(EstimateModel, self)._parse_inputs(skip=('flags'))
237249
if isdefined(self.inputs.flags):
238-
einputs[0].update(self.inputs.flags)
250+
einputs[0].update({flag: val for (flag, val) in
251+
self.inputs.flags.items()})
239252
return einputs
240253

241254
def _list_outputs(self):
242255
outputs = self._outputs().get()
243-
pth, _ = os.path.split(self.inputs.spm_mat_file)
244-
spm12 = '12' in self.version.split('.')[0]
245-
if spm12:
246-
mask = os.path.join(pth, 'mask.nii')
247-
else:
248-
mask = os.path.join(pth, 'mask.img')
249-
outputs['mask_image'] = mask
256+
pth = os.path.dirname(self.inputs.spm_mat_file)
257+
outtype = 'nii' if '12' in self.version.split('.')[0] else 'img'
250258
spm = sio.loadmat(self.inputs.spm_mat_file, struct_as_record=False)
251-
betas = []
252-
for vbeta in spm['SPM'][0, 0].Vbeta[0]:
253-
betas.append(str(os.path.join(pth, vbeta.fname[0])))
254-
if betas:
255-
outputs['beta_images'] = betas
256-
if spm12:
257-
resms = os.path.join(pth, 'ResMS.nii')
258-
else:
259-
resms = os.path.join(pth, 'ResMS.img')
260-
outputs['residual_image'] = resms
261-
if spm12:
262-
rpv = os.path.join(pth, 'RPV.nii')
263-
else:
264-
rpv = os.path.join(pth, 'RPV.img')
265-
outputs['RPVimage'] = rpv
266-
spm = os.path.join(pth, 'SPM.mat')
267-
outputs['spm_mat_file'] = spm
259+
260+
betas = [vbeta.fname[0] for vbeta in spm['SPM'][0, 0].Vbeta[0]]
261+
if ('Bayesian' in self.inputs.estimation_method.keys() or
262+
'Bayesian2' in self.inputs.estimation_method.keys()):
263+
outputs['labels'] = os.path.join(pth,
264+
'labels.{}'.format(outtype))
265+
outputs['SDerror'] = glob(os.path.join(pth, 'Sess*_SDerror*'))
266+
outputs['ARcoef'] = glob(os.path.join(pth, 'Sess*_AR_*'))
267+
if betas:
268+
outputs['Cbetas'] = [os.path.join(pth, 'C{}'.format(beta))
269+
for beta in betas]
270+
outputs['SDbetas'] = [os.path.join(pth, 'SD{}'.format(beta))
271+
for beta in betas]
272+
273+
if 'Classical' in self.inputs.estimation_method.keys():
274+
outputs['residual_image'] = os.path.join(pth,
275+
'ResMS.{}'.format(outtype))
276+
outputs['RPVimage'] = os.path.join(pth,
277+
'RPV.{}'.format(outtype))
278+
if self.inputs.write_residuals:
279+
outputs['residual_images'] = glob(os.path.join(pth, 'Res_*'))
280+
if betas:
281+
outputs['beta_images'] = [os.path.join(pth, beta)
282+
for beta in betas]
283+
284+
outputs['mask_image'] = os.path.join(pth,
285+
'mask.{}'.format(outtype))
286+
outputs['spm_mat_file'] = os.path.join(pth, 'SPM.mat')
268287
return outputs
269288

270289

nipype/interfaces/spm/tests/test_auto_EstimateModel.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ def test_EstimateModel_inputs():
2323
use_v8struct=dict(min_ver='8',
2424
usedefault=True,
2525
),
26+
write_residuals=dict(field='write_residuals',
27+
),
2628
)
2729
inputs = EstimateModel.input_spec()
2830

@@ -32,10 +34,16 @@ def test_EstimateModel_inputs():
3234

3335

3436
def test_EstimateModel_outputs():
35-
output_map = dict(RPVimage=dict(),
37+
output_map = dict(ARcoef=dict(),
38+
Cbetas=dict(),
39+
RPVimage=dict(),
40+
SDbetas=dict(),
41+
SDerror=dict(),
3642
beta_images=dict(),
43+
labels=dict(),
3744
mask_image=dict(),
3845
residual_image=dict(),
46+
residual_images=dict(),
3947
spm_mat_file=dict(),
4048
)
4149
outputs = EstimateModel.output_spec()

0 commit comments

Comments
 (0)