Skip to content

Commit f1725c1

Browse files
committed
add save fn, add test #1578
1 parent 43bd933 commit f1725c1

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

nipype/interfaces/base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,13 +1149,23 @@ def version(self):
11491149
return self._version
11501150

11511151
def load_inputs_from_json(self, json_file):
1152+
"""
1153+
A convenient way to load pre-set inputs from a JSON file.
1154+
"""
1155+
11521156
with open(json_file) as fhandle:
11531157
inputs_dict = json.load(fhandle)
11541158

11551159
for key, val in list(inputs_dict.items()):
11561160
if not isdefined(getattr(self.inputs, key, Undefined)):
11571161
setattr(self.inputs, key, val)
11581162

1163+
def save_inputs_to_json(self, json_file):
1164+
"""
1165+
A convenient way to save current inputs to a JSON file.
1166+
"""
1167+
with open(json_file, 'w') as fhandle:
1168+
json.dump(self.inputs.get(), fhandle, indent=4)
11591169

11601170

11611171
class Stream(object):

nipype/interfaces/tests/test_base.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,31 @@ def _run_interface(self, runtime):
455455
nib.BaseInterface.input_spec = None
456456
yield assert_raises, Exception, nib.BaseInterface
457457

458+
def test_BaseInterface_load_save_inputs():
459+
tmp_dir = tempfile.mkdtemp()
460+
tmp_json = os.path.join(tmp_dir, 'settings.json')
461+
462+
463+
class InputSpec(nib.TraitedSpec):
464+
input1 = nib.traits.Int()
465+
input2 = nib.traits.Float()
466+
input3 = nib.traits.Bool()
467+
input4 = nib.traits.Str()
468+
469+
class DerivedInterface(nib.BaseInterface):
470+
input_spec = InputSpec
471+
472+
def __init__(self, **inputs):
473+
super(DerivedInterface, self).__init__(**inputs)
474+
475+
inputs_dict = {'input1': 12, 'input2': 3.4, 'input3': True,
476+
'input4': 'some string'}
477+
bif = DerivedInterface(**inputs_dict)
478+
bif.save_inputs_to_json(tmp_json)
479+
bif2 = DerivedInterface()
480+
bif2.load_inputs_from_json(tmp_json)
481+
yield assert_equal, inputs_dict, bif2.inputs.get()
482+
458483

459484
def assert_not_raises(fn, *args, **kwargs):
460485
fn(*args, **kwargs)

0 commit comments

Comments
 (0)