
Conversation

TimZaman
Contributor

TensorFlow (at least versions 1.3 and below) does not support NCHW computation for many ops when run on the CPU. That creates compatibility issues. Keras originally mitigated this with an "NHWC roundtrip": any op requested to run in "channels_first" (NCHW) would actually be transposed to and from NHWC.
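
For illustration, the roundtrip for a channels_first conv2d looks roughly like this (a minimal sketch, not the exact Keras backend code; x is an NCHW tensor and kernel an HWIO kernel):

import tensorflow as tf

def conv2d_nchw_via_roundtrip(x, kernel, strides=(1, 1), padding='SAME'):
    # CPU conv2d kernels only accept NHWC, so transpose in, convolve, transpose back.
    x = tf.transpose(x, (0, 2, 3, 1))                 # NCHW -> NHWC
    x = tf.nn.conv2d(x, kernel, (1,) + strides + (1,), padding)
    return tf.transpose(x, (0, 3, 1, 2))              # NHWC -> NCHW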

NCHW often has performance benefits; for example, cuDNN is often fastest when the largest dimensions come last, which favours NCHW for images since the spatial dimensions end up last.

The current implementation probes whether NCHW can be used by checking:

  1. Whether NCHW (channels_first) is requested at all.
  2. Whether NCHW is supported, i.e. (a) a GPU is available at all and (b) the op is not explicitly placed on the CPU.

If NCHW is requested but the op is explicitly placed on the CPU, or there is no GPU device, we fall back to the transpose roundtrip as Keras has always done.
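
In pseudocode, that probe amounts to something like the following (a sketch of the logic, using the helper names from this PR's diff):

def has_nchw_support():
    # NCHW is only worth using when the current scope is not explicitly
    # pinned to the CPU and TensorFlow can see at least one GPU.
    explicitly_on_cpu = is_current_explicit_device('CPU')
    gpus_available = len(get_available_gpus()) > 0
    return not explicitly_on_cpu and gpus_available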

This PR only patches the conv2d op. There are more (and more significant) ops to be patched: any op that uses _postprocess_conv2d_output or _postprocess_conv3d_output, and bias_add too. The latter is actually Keras's bottleneck when it comes to NCHW. Let's save those for separate PRs.

@TimZaman TimZaman force-pushed the tzaman/nchw-conv2d branch 3 times, most recently from 32b0c6f to 2a12f9d Compare September 29, 2017 17:44
Collaborator

@fchollet fchollet left a comment

Thanks for the PR!



def get_current_tf_device():
"""Return device string of current graph context that's explicitly, otherwise returns `None`."""
Collaborator

Please fix docstring typos

self.device = device


def get_current_tf_device():
Collaborator

I think we should make this function private, as well as is_current_explicit_device, and get_available_gpus.


def is_current_explicit_device(device_type):
"""
Check if the current device is explicitly set on the device type.
Collaborator

Please put the docstring description on the first line

return [x.name for x in LOCAL_DEVICES if x.device_type == 'GPU']


def has_nchw_support():
Collaborator

This sounds too specific. Why not something like "_running_on_gpu()", or something to that effect?

Contributor Author

Hmm. What it checks is whether the current scope is "not explicitly on CPU, and has GPUs available". It's named like this to anticipate a TF 1.x release that no longer needs the roundtrip, so we can test for the TF version inside the function.

@Dref360
Contributor

Dref360 commented Oct 1, 2017

Could we have some tests where we mock being on a GPU machine, so that some tests run using NHWC and some run using NCHW? We could test with tf.device('/cpu:0'): etc.
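
Something along these lines, perhaps (a rough pytest-style sketch, not an actual test from the PR; names and shapes are illustrative):

import numpy as np
import tensorflow as tf
from keras import backend as K

def test_conv2d_channels_first_on_cpu():
    # Pinning the op to the CPU should force the NHWC transpose roundtrip.
    x = K.variable(np.random.random((2, 3, 5, 5)))       # NCHW input
    kernel = K.variable(np.random.random((3, 3, 3, 4)))  # HWIO kernel
    with tf.device('/cpu:0'):
        y = K.conv2d(x, kernel, data_format='channels_first')
    assert K.int_shape(y) == (2, 4, 3, 3)                 # channels stay first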

@TimZaman TimZaman force-pushed the tzaman/nchw-conv2d branch from 2a12f9d to ecbae72 Compare October 1, 2017 18:38
@TimZaman
Contributor Author

TimZaman commented Oct 1, 2017

Sure @Dref360, though that requires skipping those tests when people (or Travis) don't have a GPU. Do we already have tests in Keras that are skipped in environments without a GPU?
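
For example, GPU-only tests could be skipped roughly like this (a sketch using device_lib as in this PR's diff; the marker and test names are just illustrative):

import pytest
from tensorflow.python.client import device_lib

def _gpus_available():
    return any(d.device_type == 'GPU' for d in device_lib.list_local_devices())

requires_gpu = pytest.mark.skipif(not _gpus_available(),
                                  reason='Test requires a GPU.')

@requires_gpu
def test_conv2d_channels_first_native_nchw():
    # Exercise the native NCHW path here; no transpose roundtrip expected.
    ...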

return op.device


def is_current_explicit_device(device_type):
Collaborator

Please make this method private.

return (device is not None and device.device_type == device_type.upper())


def get_available_gpus():
Collaborator

Please make this method private (unless there is a rationale for making it part of the public API).

return [x.name for x in LOCAL_DEVICES if x.device_type == 'GPU']


def has_nchw_support():
Collaborator

Please make this method private.

@@ -44,6 +45,13 @@
# Change its value via `manual_variable_initialization(value)`.
_MANUAL_VAR_INIT = False

# This map is for converting the keras data format string to that of TF.
DATA_FORMAT_MAP = {'channels_first': 'NCHW', 'channels_last': 'NHWC'}
Collaborator

Please make this global variable private.


# This list queries the devices.
# We assume our devices don't change during our lifetime.
LOCAL_DEVICES = device_lib.list_local_devices()
Collaborator

Please make this global variable private.

@TimZaman TimZaman force-pushed the tzaman/nchw-conv2d branch from ecbae72 to ee13a1b Compare October 2, 2017 23:17
fchollet
fchollet previously approved these changes Oct 3, 2017
Collaborator

@fchollet fchollet left a comment

LGTM, thanks!

@fchollet
Collaborator

fchollet commented Oct 3, 2017

Please fix the docstrings and PEP8 issue reported in CI: https://travis-ci.org/fchollet/keras/builds/282490074


# Returns
bool: if the current device scope is explicitly set on the device type.
"""
Collaborator

Per the failing docstring test, this docstring needs a Raises section mentioning the ValueError: https://travis-ci.org/fchollet/keras/jobs/282558708

@TimZaman TimZaman force-pushed the tzaman/nchw-conv2d branch from 0a1063b to 87df398 Compare October 4, 2017 03:36
Collaborator

@fchollet fchollet left a comment

LGTM, many thanks!

@fchollet fchollet merged commit 80dcdcd into keras-team:master Oct 4, 2017
@TimZaman
Contributor Author

TimZaman commented Oct 4, 2017

OK! We're not done yet; we need to patch the other ops too. Most importantly (the biggest bottleneck) is bias_add's transpose on NCHW. I'll get to that.
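
For context, that follow-up would presumably replace the transpose roundtrip around the bias addition with a native NCHW call on the GPU path, along these lines (a sketch assuming tf.nn.bias_add's data_format argument; not the eventual Keras implementation):

import tensorflow as tf

def bias_add_nchw(x, bias):
    # tf.nn.bias_add broadcasts the bias over the channel axis directly when
    # told the layout is NCHW, avoiding the two transposes.
    return tf.nn.bias_add(x, bias, data_format='NCHW')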

Dref360 pushed a commit to Dref360/keras that referenced this pull request Oct 5, 2017
bdwyer2 added a commit to bdwyer2/keras that referenced this pull request Oct 5, 2017
dschwertfeger added a commit to dschwertfeger/keras that referenced this pull request Oct 8, 2017
…-outputs

* master: (68 commits)
  Change default value of shuffle parameter of Sequential.fit_generator() from True to False. (keras-team#8075)
  Fix off-by-one bug in predict/evaluate progress bar (keras-team#8071)
  Revert "Faster sequence" (keras-team#8060)
  Support NCHW for conv2d. (keras-team#8021)
  Change compute_accuracy() argument order and names (keras-team#8049)
  Replace literal constant 10 with variable num_classes in example/ (keras-team#8041)
  Faster sequence (keras-team#8039)
  Improve RNN docs.
  Enable accuracy reporting during training in examples/mnist_siamese_graph.py (keras-team#7997)
  Bug fix: Models with shared layers shouldn't be considered Sequential like (keras-team#8025)
  Add 'subtract' merge layer documentation (keras-team#8038)
  Update inference in seq2seq script to be more efficient
  Remove lstm_benchmark from examples/README.md (keras-team#8024)
  Add shuffle to the Model API (keras-team#8023)
  Add seq2seq example script.
  fix travis failure (keras-team#8014)
  Improve TF backend's Switch function (keras-team#7958)
  Added support for dynamic noise_shape in Dropout (keras-team#7999)
  Make on_epoch_end optional (keras-team#8007)
  Incremental tests speed ups.
  ...