Skip to content

ENH: Allow chained name_source #938

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 15, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
Next release
============

* ENH: Inputs with name_source can be now chained in cascade (https://github.com/nipy/nipype/pull/938)
* ENH: Improve JSON interfaces: default settings when reading and consistent output creation
when writing (https://github.com/nipy/nipype/pull/1047)
* FIX: AddCSVRow problems when using infields (https://github.com/nipy/nipype/pull/1028)
Expand Down
3 changes: 3 additions & 0 deletions doc/devel/interface_specs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,9 @@ CommandLine
``name_source``
Indicates the list of input fields from which the value of the current File
output variable will be drawn. This input field must be the name of a File.
Chaining is allowed, meaning that an input field can point to another as
``name_source``, which also points as ``name_source`` to a third field.
In this situation, the templates for substitutions are also accumulated.

``name_template``
By default a ``%s_generated`` template is used to create the output
Expand Down
58 changes: 42 additions & 16 deletions nipype/interfaces/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@

__docformat__ = 'restructuredtext'

class NipypeInterfaceError(Exception):
def __init__(self, value):
self.value = value
def __str__(self):
return repr(self.value)

def _lock_files():
tmpdir = '/tmp'
pattern = '.X*-lock'
Expand Down Expand Up @@ -1510,9 +1516,13 @@ def _format_arg(self, name, trait_spec, value):
# Append options using format string.
return argstr % value

def _filename_from_source(self, name):
def _filename_from_source(self, name, chain=None):
if chain is None:
chain = []

trait_spec = self.inputs.trait(name)
retval = getattr(self.inputs, name)

if not isdefined(retval) or "%s" in retval:
if not trait_spec.name_source:
return retval
Expand All @@ -1522,26 +1532,42 @@ def _filename_from_source(self, name):
name_template = trait_spec.name_template
if not name_template:
name_template = "%s_generated"
if isinstance(trait_spec.name_source, list):
for ns in trait_spec.name_source:
if isdefined(getattr(self.inputs, ns)):
name_source = ns
break

ns = trait_spec.name_source
while isinstance(ns, list):
if len(ns) > 1:
iflogger.warn('Only one name_source per trait is allowed')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is only one allowed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, we have a somewhat difficult decision here.

Former implementation iterated the name_source list until it found the first defined name_source.
Current implementation may accept undefined name_sources, but they must have a name_source as well.

In both implementations only one name_source is actually used, right?.

ns = ns[0]

if not isinstance(ns, six.string_types):
raise ValueError(('name_source of \'%s\' trait sould be an '
'input trait name') % name)

if isdefined(getattr(self.inputs, ns)):
name_source = ns
source = getattr(self.inputs, name_source)
while isinstance(source, list):
source = source[0]

# special treatment for files
try:
_, base, _ = split_filename(source)
except AttributeError:
base = source
else:
name_source = trait_spec.name_source
source = getattr(self.inputs, name_source)
while isinstance(source, list):
source = source[0]
#special treatment for files
try:
_, base, _ = split_filename(source)
except AttributeError:
base = source
if name in chain:
raise NipypeInterfaceError('Mutually pointing name_sources')

chain.append(name)
base = self._filename_from_source(ns, chain)

chain = None
retval = name_template % base
_, _, ext = split_filename(retval)
if trait_spec.keep_extension and ext:
return retval
return self._overload_extension(retval, name)

return retval

def _gen_filename(self, name):
Expand All @@ -1557,7 +1583,7 @@ def _list_outputs(self):
outputs = self.output_spec().get()
for name, trait_spec in traits.iteritems():
out_name = name
if trait_spec.output_name != None:
if trait_spec.output_name is not None:
out_name = trait_spec.output_name
outputs[out_name] = \
os.path.abspath(self._filename_from_source(name))
Expand Down
102 changes: 101 additions & 1 deletion nipype/interfaces/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,13 @@ class DeprecationSpec3(nib.TraitedSpec):
yield assert_equal, spec_instance.foo, Undefined
yield assert_equal, spec_instance.bar, 1


def test_namesource():
tmp_infile = setup_file()
tmpd, nme, ext = split_filename(tmp_infile)
pwd = os.getcwd()
os.chdir(tmpd)

class spec2(nib.CommandLineInputSpec):
moo = nib.File(name_source=['doo'], hash_files=False, argstr="%s",
position=2)
Expand All @@ -196,6 +198,104 @@ class TestName(nib.CommandLine):
os.chdir(pwd)
teardown_file(tmpd)


def test_chained_namesource():
tmp_infile = setup_file()
tmpd, nme, ext = split_filename(tmp_infile)
pwd = os.getcwd()
os.chdir(tmpd)

class spec2(nib.CommandLineInputSpec):
doo = nib.File(exists=True, argstr="%s", position=1)
moo = nib.File(name_source=['doo'], hash_files=False, argstr="%s",
position=2, name_template='%s_mootpl')
poo = nib.File(name_source=['moo'], hash_files=False,
argstr="%s", position=3)

class TestName(nib.CommandLine):
_cmd = "mycommand"
input_spec = spec2

testobj = TestName()
testobj.inputs.doo = tmp_infile
res = testobj.cmdline
yield assert_true, '%s' % tmp_infile in res
yield assert_true, '%s_mootpl ' % nme in res
yield assert_true, '%s_mootpl_generated' % nme in res

os.chdir(pwd)
teardown_file(tmpd)


def test_cycle_namesource1():
tmp_infile = setup_file()
tmpd, nme, ext = split_filename(tmp_infile)
pwd = os.getcwd()
os.chdir(tmpd)

class spec3(nib.CommandLineInputSpec):
moo = nib.File(name_source=['doo'], hash_files=False, argstr="%s",
position=1, name_template='%s_mootpl')
poo = nib.File(name_source=['moo'], hash_files=False,
argstr="%s", position=2)
doo = nib.File(name_source=['poo'], hash_files=False,
argstr="%s", position=3)

class TestCycle(nib.CommandLine):
_cmd = "mycommand"
input_spec = spec3

# Check that an exception is raised
to0 = TestCycle()
not_raised = True
try:
to0.cmdline
except nib.NipypeInterfaceError:
not_raised = False
yield assert_false, not_raised

os.chdir(pwd)
teardown_file(tmpd)

def test_cycle_namesource2():
tmp_infile = setup_file()
tmpd, nme, ext = split_filename(tmp_infile)
pwd = os.getcwd()
os.chdir(tmpd)


class spec3(nib.CommandLineInputSpec):
moo = nib.File(name_source=['doo'], hash_files=False, argstr="%s",
position=1, name_template='%s_mootpl')
poo = nib.File(name_source=['moo'], hash_files=False,
argstr="%s", position=2)
doo = nib.File(name_source=['poo'], hash_files=False,
argstr="%s", position=3)

class TestCycle(nib.CommandLine):
_cmd = "mycommand"
input_spec = spec3

# Check that loop can be broken by setting one of the inputs
to1 = TestCycle()
to1.inputs.poo = tmp_infile

not_raised = True
try:
res = to1.cmdline
except nib.NipypeInterfaceError:
not_raised = False
print res

yield assert_true, not_raised
yield assert_true, '%s' % tmp_infile in res
yield assert_true, '%s_generated' % nme in res
yield assert_true, '%s_generated_mootpl' % nme in res

os.chdir(pwd)
teardown_file(tmpd)


def checknose():
"""check version of nose for known incompatability"""
mod = __import__('nose')
Expand Down Expand Up @@ -536,4 +636,4 @@ def test_global_CommandLine_output():
res = ci.run()
yield assert_equal, res.runtime.stdout, ''
os.chdir(pwd)
teardown_file(tmpd)
teardown_file(tmpd)