Skip to content

Commit 423d092

Browse files
committed
correct arg passing
1 parent aecf630 commit 423d092

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

python/tvm/topi/reduction.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def min(data, axis=None, keepdims=False):
167167
return cpp.min(data, axis, keepdims)
168168

169169

170-
def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=False):
170+
def argmax(data, axis=None, keepdims=False, select_last_index=False):
171171
"""Returns the indices of the maximum values along an axis.
172172
173173
Parameters
@@ -185,14 +185,18 @@ def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=Fal
185185
with size one.
186186
With this option, the result will broadcast correctly against the input array.
187187
188+
select_last_index: bool
189+
Whether to select the last index if the maximum element appears multiple times, else
190+
select the first index.
191+
188192
Returns
189193
-------
190194
ret : tvm.te.Tensor
191195
"""
192-
return cpp.argmax(data, axis, keepdims, exclude=exclude, select_last_index=select_last_index)
196+
return cpp.argmax(data, axis, keepdims, select_last_index)
193197

194198

195-
def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=False):
199+
def argmin(data, axis=None, keepdims=False, select_last_index=False):
196200
"""Returns the indices of the minimum values along an axis.
197201
198202
Parameters
@@ -210,11 +214,15 @@ def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=Fal
210214
with size one.
211215
With this option, the result will broadcast correctly against the input array.
212216
217+
select_last_index: bool
218+
Whether to select the last index if the minimum element appears multiple times, else
219+
select the first index.
220+
213221
Returns
214222
-------
215223
ret : tvm.te.Tensor
216224
"""
217-
return cpp.argmin(data, axis, keepdims, exclude, select_last_index)
225+
return cpp.argmin(data, axis, keepdims, select_last_index)
218226

219227

220228
def prod(data, axis=None, keepdims=False):

0 commit comments

Comments
 (0)