-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve TF backend's Switch function #7958
Conversation
@abhaikollara Tests :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code looks good to me. Please add unit tests for the new cases.
keras/backend/tensorflow_backend.py
Outdated
if callable(else_expression): | ||
else_expression = else_expression() | ||
expr_ndim = ndim(then_expression) | ||
assert cond_ndim <= expr_ndim, 'Rank of condition should be less' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Raise a ValueError
instead of an assert (and add it to the docstring).
keras/backend/tensorflow_backend.py
Outdated
else_expression = else_expression() | ||
expr_ndim = ndim(then_expression) | ||
assert cond_ndim <= expr_ndim, 'Rank of condition should be less' | ||
' than or equal to rank of then and else expressions.' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error message should mention the rank of the condition that was passed and the rank of the then/else expressions
Seems need to make changes to CNTK backend too. |
The CNTK issue seems straightforward: |
Seems there is another bug in cntk's tile. |
Theano was broadcasting the wrong way: |
@fchollet Done (finally). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks. Note we are currently have an issue with Travis CI so I don't expect the tests to pass.
keras/backend/cntk_backend.py
Outdated
ndim_cond = ndim(condition) | ||
ndim_expr = ndim(then_expression) | ||
if ndim_cond > ndim_expr: | ||
raise ValueError('Rank of condition should be less' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reduce indent (should be 4 spaces)
keras/backend/cntk_backend.py
Outdated
ndim_expr = ndim(then_expression) | ||
if ndim_cond > ndim_expr: | ||
raise ValueError('Rank of condition should be less' | ||
' than or equal to rank of then and' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wrap then
and else
with ` to make the sentence easier to parse.
keras/backend/tensorflow_backend.py
Outdated
if cond_ndim > expr_ndim: | ||
raise ValueError('Rank of condition should be less' | ||
' than or equal to rank of then and' | ||
' else expressions. ndim(condition)=' + |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wrap then
and else
with ` to make the sentence easier to parse.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
~ is used for actual stuff in code right? So shouldn't it be then_expression
and else_expression
? Or simply then
and else
expressions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, better to use then_expression
etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
* improve k.switch * update doc * handle case when ndim is None * pep8 * use tf.reshape instead of tf.expand_dims * unit tests + better error msg * remove faulty cases * add broadcasting to cntk backend * typo * typo * update ds * update ds * cntk doesnt support symbolic tile * cntk doesnt support symbolic tile * allow int n for cntk tile * int->tuple * bug fix cntk tile * fix broadcasting in theano backend * formatting fixes
…-outputs * master: (68 commits) Change default value of shuffle parameter of Sequential.fit_generator() from True to False. (keras-team#8075) Fix off-by-one bug in predict/evaluate progress bar (keras-team#8071) Revert "Faster sequence" (keras-team#8060) Support NCHW for conv2d. (keras-team#8021) Change compute_accuracy() argument order and names (keras-team#8049) Replace literal constant 10 with variable num_classes in example/ (keras-team#8041) Faster sequence (keras-team#8039) Improve RNN docs. Enable accuracy reporting during training in examples/mnist_siamese_graph.py (keras-team#7997) Bug fix: Models with shared layers shouldn't be considered Sequential like (keras-team#8025) Add 'subtract' merge layer documentation (keras-team#8038) Update inference in seq2seq script to be more efficient Remove lstm_benchmark from examples/README.md (keras-team#8024) Add shuffle to the Model API (keras-team#8023) Add seq2seq example script. fix travis failure (keras-team#8014) Improve TF backend's Switch function (keras-team#7958) Added support for dynamic noise_shape in Dropout (keras-team#7999) Make on_epoch_end optional (keras-team#8007) Incremental tests speed ups. ...
As of now, K.switch supports only scalar condition on TF backend. This PR makes TF backend's switch function behave more like Theano's switch; with element-wise selection and broadcasting when required.
condition
== 0 (scalar), calltf.cond
like beforecondition
== 1, calltf.where
condition
> 1, make sure shape ofcondition
is same as that of the expressions by applyingexpand_dims
andtile
ops; then calltf.where