Skip to content

Commit 56b2a62

Browse files
authored
Merge pull request #1937 from satra/fix/merge
fix: closes #1920 to revert to original behavior for input names
2 parents 4176856 + 5b44037 commit 56b2a62

File tree

2 files changed

+39
-18
lines changed

2 files changed

+39
-18
lines changed

nipype/interfaces/utility/base.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,29 @@ def _list_outputs(self):
9999
class MergeInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec):
100100
axis = traits.Enum('vstack', 'hstack', usedefault=True,
101101
desc='direction in which to merge, hstack requires same number of elements in each input')
102-
no_flatten = traits.Bool(False, usedefault=True, desc='append to outlist instead of extending in vstack mode')
102+
no_flatten = traits.Bool(False, usedefault=True,
103+
desc='append to outlist instead of extending in vstack mode')
104+
ravel_inputs = traits.Bool(False, usedefault=True,
105+
desc='ravel inputs when no_flatten is False')
103106

104107

105108
class MergeOutputSpec(TraitedSpec):
106109
out = traits.List(desc='Merged output')
107110

108111

112+
def _ravel(in_val):
113+
if not isinstance(in_val, list):
114+
return in_val
115+
flat_list = []
116+
for val in in_val:
117+
raveled_val = _ravel(val)
118+
if isinstance(raveled_val, list):
119+
flat_list.extend(raveled_val)
120+
else:
121+
flat_list.append(raveled_val)
122+
return flat_list
123+
124+
109125
class Merge(IOBase):
110126
"""Basic interface class to merge inputs into a single list
111127
@@ -123,23 +139,34 @@ class Merge(IOBase):
123139
>>> out.outputs.out
124140
[1, 2, 5, 3]
125141
126-
>>> merge = Merge() # Or Merge(1)
127-
>>> merge.inputs.in_lists = [1, [2, 5], 3]
142+
>>> merge = Merge(1)
143+
>>> merge.inputs.in1 = [1, [2, 5], 3]
144+
>>> out = merge.run()
145+
>>> out.outputs.out
146+
[1, [2, 5], 3]
147+
148+
>>> merge = Merge(1)
149+
>>> merge.inputs.in1 = [1, [2, 5], 3]
150+
>>> merge.inputs.ravel_inputs = True
128151
>>> out = merge.run()
129152
>>> out.outputs.out
130153
[1, 2, 5, 3]
131154
155+
>>> merge = Merge(1)
156+
>>> merge.inputs.in1 = [1, [2, 5], 3]
157+
>>> merge.inputs.no_flatten = True
158+
>>> out = merge.run()
159+
>>> out.outputs.out
160+
[[1, [2, 5], 3]]
132161
"""
133162
input_spec = MergeInputSpec
134163
output_spec = MergeOutputSpec
135164

136-
def __init__(self, numinputs=1, **inputs):
165+
def __init__(self, numinputs=0, **inputs):
137166
super(Merge, self).__init__(**inputs)
138167
self._numinputs = numinputs
139-
if numinputs > 1:
168+
if numinputs >= 1:
140169
input_names = ['in%d' % (i + 1) for i in range(numinputs)]
141-
elif numinputs == 1:
142-
input_names = ['in_lists']
143170
else:
144171
input_names = []
145172
add_traits(self.inputs, input_names)
@@ -150,8 +177,6 @@ def _list_outputs(self):
150177

151178
if self._numinputs < 1:
152179
return outputs
153-
elif self._numinputs == 1:
154-
values = self.inputs.in_lists
155180
else:
156181
getval = lambda idx: getattr(self.inputs, 'in%d' % (idx + 1))
157182
values = [getval(idx) for idx in range(self._numinputs)
@@ -160,7 +185,8 @@ def _list_outputs(self):
160185
if self.inputs.axis == 'vstack':
161186
for value in values:
162187
if isinstance(value, list) and not self.inputs.no_flatten:
163-
out.extend(value)
188+
out.extend(_ravel(value) if self.inputs.ravel_inputs else
189+
value)
164190
else:
165191
out.append(value)
166192
else:

nipype/interfaces/utility/tests/test_base.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,25 +56,20 @@ def test_split(tmpdir, args, expected):
5656
([3], {}, [0, [1, 2], [3, 4, 5]], [0, 1, 2, 3, 4, 5]),
5757
([0], {}, None, None),
5858
([], {}, [], []),
59-
([], {}, [0, [1, 2], [3, 4, 5]], [0, 1, 2, 3, 4, 5]),
59+
([], {}, [0, [1, 2], [3, 4, 5]], [0, [1, 2], [3, 4, 5]]),
6060
([3], {'axis': 'hstack'}, [[0], [1, 2], [3, 4, 5]], [[0, 1, 3]]),
6161
([3], {'axis': 'hstack'}, [[0, 1], [2, 3], [4, 5]],
6262
[[0, 2, 4], [1, 3, 5]]),
6363
([3], {'axis': 'hstack'}, [[0, 1], [2, 3], [4, 5]],
6464
[[0, 2, 4], [1, 3, 5]]),
65-
([1], {'axis': 'hstack'}, [[0], [1, 2], [3, 4, 5]], [[0, 1, 3]]),
66-
([1], {'axis': 'hstack'}, [[0, 1], [2, 3], [4, 5]],
67-
[[0, 2, 4], [1, 3, 5]]),
6865
])
6966
def test_merge(tmpdir, args, kwargs, in_lists, expected):
7067
os.chdir(str(tmpdir))
7168

7269
node = pe.Node(utility.Merge(*args, **kwargs), name='merge')
7370

74-
numinputs = args[0] if args else 1
75-
if numinputs == 1:
76-
node.inputs.in_lists = in_lists
77-
elif numinputs > 1:
71+
numinputs = args[0] if args else 0
72+
if numinputs >= 1:
7873
for i in range(1, numinputs + 1):
7974
setattr(node.inputs, 'in{:d}'.format(i), in_lists[i - 1])
8075

0 commit comments

Comments
 (0)