This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Control flow operators do not work in imperative (NDArray) API #12825

Closed
samskalicky opened this issue Oct 15, 2018 · 4 comments

@samskalicky
Contributor

Description

The example for the cond operator on the website (https://mxnet.incubator.apache.org/api/python/ndarray/contrib.html#mxnet.ndarray.contrib.cond) does not work with the NDArray API:

a, b = mx.nd.array([1]), mx.nd.array([2])
pred = a * b < 5
then_func = lambda a, b: (a + 5) * (b + 5)
else_func = lambda a, b: (a - 5) * (b - 5)
outputs = mx.nd.contrib.cond(pred, then_func, else_func)
outputs[0]

Running this with v1.3.x produces the following output:

>>> a, b = mx.nd.array([1]), mx.nd.array([2])
>>> pred = a * b < 5
>>> then_func = lambda a, b: (a + 5) * (b + 5)
>>> else_func = lambda a, b: (a - 5) * (b - 5)
>>> outputs = mx.nd.contrib.cond(pred, then_func, else_func)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python2.7/site-packages/mxnet/ndarray/contrib.py", line 460, in cond
    return then_func()
TypeError: <lambda>() takes exactly 2 arguments (0 given)

The symbolic API example does work for the cond operator. Is it supposed to work for the NDArray API as well? If so, this is a bug. If not, the documentation needs to be fixed, since it's not supported in the NDArray API.

@samskalicky
Contributor Author

@zheng-da What is your take on this? Is this supposed to work, or does the documentation need to be removed for cond (and other control flow operators) in the NDArray API?

@piyushghai
Contributor

@mxnet-label-bot [Question, NDArray]

@wkcn
Member

wkcn commented Oct 24, 2018

It seems the example is wrong: then_func and else_func must take no arguments.

The following example works.

import mxnet as mx
a, b = mx.nd.array([1]), mx.nd.array([2])
pred = a * b < 5
then_func = lambda : (a + 5) * (b + 5)
else_func = lambda : (a - 5) * (b - 5)
outputs = mx.nd.contrib.cond(pred, then_func, else_func)
print(outputs[0])
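
To illustrate why the zero-argument lambdas work, here is a minimal pure-Python sketch. It assumes only what the traceback above shows, namely that the imperative cond calls the chosen branch function with no arguments. The name cond_sketch is hypothetical and plain integers stand in for NDArrays; this is not the MXNet implementation.

```python
def cond_sketch(pred, then_func, else_func):
    """Hypothetical stand-in for the imperative branch selection."""
    # The selected branch is invoked with no arguments,
    # mirroring `return then_func()` from the traceback.
    return then_func() if pred else else_func()


a, b = 1, 2
pred = a * b < 5

# Zero-argument lambdas read a and b from the enclosing scope.
result = cond_sketch(pred,
                     lambda: (a + 5) * (b + 5),
                     lambda: (a - 5) * (b - 5))
print(result)

# Two-argument lambdas fail because nothing is passed to them.
raised = False
try:
    cond_sketch(pred,
                lambda a, b: (a + 5) * (b + 5),
                lambda a, b: (a - 5) * (b - 5))
except TypeError as err:
    raised = True
    print("TypeError:", err)
```

Under this assumption, any inputs the branches need have to be captured from the surrounding scope rather than declared as lambda parameters.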

@ThomasDelteil
Contributor

@samskalicky can you please close the issue? @sandeep-krishnamurthy can you please close the issue? The PR has been merged.
