Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve TF backend's Switch function #7958

Merged
merged 19 commits into from
Sep 29, 2017
Merged

Conversation

farizrahman4u
Copy link
Contributor

@farizrahman4u farizrahman4u commented Sep 22, 2017

As of now, K.switch supports only scalar condition on TF backend. This PR makes TF backend's switch function behave more like Theano's switch; with element-wise selection and broadcasting when required.

  • If rank of condition == 0 (scalar), call tf.cond like before
  • If rank of condition == 1, call tf.where
  • If rank of condition > 1, make sure shape of condition is same as that of the expressions by applying expand_dims and tile ops; then call tf.where

@farizrahman4u
Copy link
Contributor Author

@abhaikollara Tests :)

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code looks good to me. Please add unit tests for the new cases.

if callable(else_expression):
else_expression = else_expression()
expr_ndim = ndim(then_expression)
assert cond_ndim <= expr_ndim, 'Rank of condition should be less'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Raise a ValueError instead of an assert (and add it to the docstring).

else_expression = else_expression()
expr_ndim = ndim(then_expression)
assert cond_ndim <= expr_ndim, 'Rank of condition should be less'
' than or equal to rank of then and else expressions.'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message should mention the rank of the condition that was passed and the rank of the then/else expressions

@farizrahman4u
Copy link
Contributor Author

Seems need to make changes to CNTK backend too.

@fchollet
Copy link
Member

The CNTK issue seems straightforward: def tile(x, n): needs to be fixed to support integer n.

@farizrahman4u
Copy link
Contributor Author

Seems there is another bug in cntk's tile.

@farizrahman4u
Copy link
Contributor Author

Theano was broadcasting the wrong way: [(4,), (4, 3), (4, 3)] -> [(1, 4), (4, 3), (4, 3)] instead of [(4, 1), (4, 3), (4, 3)]

@farizrahman4u
Copy link
Contributor Author

@fchollet Done (finally).

fchollet
fchollet previously approved these changes Sep 28, 2017
Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks. Note we are currently have an issue with Travis CI so I don't expect the tests to pass.

ndim_cond = ndim(condition)
ndim_expr = ndim(then_expression)
if ndim_cond > ndim_expr:
raise ValueError('Rank of condition should be less'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reduce indent (should be 4 spaces)

ndim_expr = ndim(then_expression)
if ndim_cond > ndim_expr:
raise ValueError('Rank of condition should be less'
' than or equal to rank of then and'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrap then and else with ` to make the sentence easier to parse.

if cond_ndim > expr_ndim:
raise ValueError('Rank of condition should be less'
' than or equal to rank of then and'
' else expressions. ndim(condition)=' +
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrap then and else with ` to make the sentence easier to parse.

Copy link
Contributor Author

@farizrahman4u farizrahman4u Sep 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

~ is used for actual stuff in code right? So shouldn't it be then_expression and else_expression? Or simply then and else expressions?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, better to use then_expression etc.

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@fchollet fchollet merged commit 1bbfdb6 into keras-team:master Sep 29, 2017
@farizrahman4u farizrahman4u deleted the k_switch branch September 29, 2017 09:02
ozabluda pushed a commit to ozabluda/keras that referenced this pull request Oct 2, 2017
* improve k.switch

* update doc

* handle case when ndim is None

* pep8

* use tf.reshape instead of tf.expand_dims

* unit tests + better error msg

* remove faulty cases

* add broadcasting to cntk backend

* typo

* typo

* update ds

* update ds

* cntk doesnt support symbolic tile

* cntk doesnt support symbolic tile

* allow int n for cntk tile

* int->tuple

* bug fix cntk tile

* fix broadcasting in theano backend

* formatting fixes
dschwertfeger added a commit to dschwertfeger/keras that referenced this pull request Oct 8, 2017
…-outputs

* master: (68 commits)
  Change default value of shuffle parameter of Sequential.fit_generator() from True to False. (keras-team#8075)
  Fix off-by-one bug in predict/evaluate progress bar (keras-team#8071)
  Revert "Faster sequence" (keras-team#8060)
  Support NCHW for conv2d. (keras-team#8021)
  Change compute_accuracy() argument order and names (keras-team#8049)
  Replace literal constant 10 with variable num_classes in example/ (keras-team#8041)
  Faster sequence (keras-team#8039)
  Improve RNN docs.
  Enable accuracy reporting during training in examples/mnist_siamese_graph.py (keras-team#7997)
  Bug fix: Models with shared layers shouldn't be considered Sequential like (keras-team#8025)
  Add 'subtract' merge layer documentation (keras-team#8038)
  Update inference in seq2seq script to be more efficient
  Remove lstm_benchmark from examples/README.md (keras-team#8024)
  Add shuffle to the Model API (keras-team#8023)
  Add seq2seq example script.
  fix travis failure (keras-team#8014)
  Improve TF backend's Switch function (keras-team#7958)
  Added support for dynamic noise_shape in Dropout (keras-team#7999)
  Make on_epoch_end optional (keras-team#8007)
  Incremental tests speed ups.
  ...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants