WIP: Support NCHW (for conv2d). #8021
Conversation
Force-pushed from 32b0c6f to 2a12f9d.
Thanks for the PR!
keras/backend/tensorflow_backend.py (Outdated)

    def get_current_tf_device():
        """Return device string of current graph context that's explicitly, otherwise returns `None`."""
Please fix docstring typos
keras/backend/tensorflow_backend.py (Outdated)

        self.device = device

    def get_current_tf_device():
I think we should make this function private, as well as `is_current_explicit_device` and `get_available_gpus`.
keras/backend/tensorflow_backend.py (Outdated)

    def is_current_explicit_device(device_type):
        """
        Check if the current device is explicitly set on the device type.
Please put the docstring description on the first line
keras/backend/tensorflow_backend.py (Outdated)

        return [x.name for x in LOCAL_DEVICES if x.device_type == 'GPU']

    def has_nchw_support():
This sounds too specific. Why not something like `_running_on_gpu()`, or something to that effect?
Hmm.. What it checks for is whether the current scope is "not explicit on CPU, and has GPUs available". It's named like this to anticipate any TF 1.x support that does not need the roundtrip, so we can test for the TF version inside the function.
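In code, the check described above amounts to roughly the following (a sketch only; `_is_current_explicit_device` and `_get_available_gpus` follow the private names suggested in this review and may not match the merged code exactly):

    def _has_nchw_support():
        """Return whether the current scope can run NCHW ops natively."""
        # Usable when the current device scope is not explicitly pinned to
        # the CPU and at least one GPU is visible to TensorFlow.
        explicitly_on_cpu = _is_current_explicit_device('CPU')
        gpus_available = len(_get_available_gpus()) > 0
        return not explicitly_on_cpu and gpus_available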
Could we have some tests where we mock being on a GPU machine, so that some tests run using NHWC and some run using NCHW? We could test
Force-pushed from 2a12f9d to ecbae72.
Sure @Dref360, that does require skipping tests when people/Travis don't have a GPU though. Do we already have tests in Keras that are skipped in environments without a GPU?
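One possible shape for such tests, assuming `pytest`, `unittest.mock`, and the private helpers discussed above (a sketch, not code from this PR):

    import pytest
    from unittest import mock  # the external `mock` package on Python 2

    from keras.backend import tensorflow_backend as KTF

    # Skip on machines (e.g. Travis workers) that have no GPU at all.
    requires_gpu = pytest.mark.skipif(
        len(KTF._get_available_gpus()) == 0,
        reason='Requires a GPU.')

    def test_no_nchw_support_without_gpu():
        # Pretend no GPU is visible, forcing the NHWC transpose roundtrip.
        with mock.patch.object(KTF, '_get_available_gpus', return_value=[]):
            assert not KTF._has_nchw_support()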
keras/backend/tensorflow_backend.py (Outdated)

        return op.device

    def is_current_explicit_device(device_type):
Please make this method private.
keras/backend/tensorflow_backend.py (Outdated)

        return (device is not None and device.device_type == device_type.upper())

    def get_available_gpus():
Please make this method private (unless there is a rationale for making it part of the public API).
keras/backend/tensorflow_backend.py (Outdated)

        return [x.name for x in LOCAL_DEVICES if x.device_type == 'GPU']

    def has_nchw_support():
Please make this method private.
keras/backend/tensorflow_backend.py (Outdated)

    @@ -44,6 +45,13 @@
    # Change its value via `manual_variable_initialization(value)`.
    _MANUAL_VAR_INIT = False

    # This map is for converting the keras data format string to that of TF.
    DATA_FORMAT_MAP = {'channels_first': 'NCHW', 'channels_last': 'NHWC'}
Please make this global variable private.
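For illustration, a privately named map could be consumed through a small helper like this (a sketch; `_to_tf_data_format` is a hypothetical name, not something from this PR):

    _DATA_FORMAT_MAP = {'channels_first': 'NCHW', 'channels_last': 'NHWC'}

    def _to_tf_data_format(data_format):
        """Translate a Keras data format string into TensorFlow's convention."""
        if data_format not in _DATA_FORMAT_MAP:
            raise ValueError('Unknown data_format: ' + str(data_format))
        return _DATA_FORMAT_MAP[data_format]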
keras/backend/tensorflow_backend.py (Outdated)

    # This list queries the devices.
    # We assume our devices don't change during our lifetime.
    LOCAL_DEVICES = device_lib.list_local_devices()
Please make this global variable private.
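Made private, the device query and the GPU filter from the diff would look roughly like this (a sketch under the naming suggested in this review):

    from tensorflow.python.client import device_lib

    # Queried once at import time; we assume the set of local devices does
    # not change during the lifetime of the process.
    _LOCAL_DEVICES = device_lib.list_local_devices()

    def _get_available_gpus():
        """Return the names of the locally available GPU devices."""
        return [x.name for x in _LOCAL_DEVICES if x.device_type == 'GPU']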
Force-pushed from ecbae72 to ee13a1b.
LGTM, thanks!
Please fix the docstrings and PEP8 issue reported in CI: https://travis-ci.org/fchollet/keras/builds/282490074
Force-pushed from ee13a1b to 0a1063b.
keras/backend/tensorflow_backend.py (Outdated)

        # Returns
            bool: if the current device scope is explicitly set on the device type.
        """
Per the failing docstring test, this docstring needs a `Raises` section mentioning the `ValueError`: https://travis-ci.org/fchollet/keras/jobs/282558708
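For example, in the Keras docstring style that could look like this (a sketch of the shape only; the exact wording is up to the author):

    def _is_current_explicit_device(device_type):
        """Check whether the current device is explicitly set to the given type.

        # Arguments
            device_type: A string, either `'CPU'` or `'GPU'`.

        # Returns
            A boolean indicating whether the current device scope is
            explicitly set on the given device type.

        # Raises
            ValueError: If `device_type` is neither `'CPU'` nor `'GPU'`.
        """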
Force-pushed from 0a1063b to 87df398.
LGTM, many thanks!
OK! We're not done yet; we need to patch other ops too. Most important (the biggest bottleneck) is `bias_add`'s transpose on NCHW. I'll get to that.
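For context, the roundtrip being referred to looks roughly like this for a 4D `channels_first` input (a sketch of the cost being described, not the patch itself):

    import tensorflow as tf

    def _bias_add_with_roundtrip(x, bias):
        # Transpose to NHWC, add the bias on the last axis, then transpose
        # back to NCHW -- two data-layout conversions per call.
        x = tf.transpose(x, (0, 2, 3, 1))                # NCHW -> NHWC
        x = tf.nn.bias_add(x, bias, data_format='NHWC')
        return tf.transpose(x, (0, 3, 1, 2))             # NHWC -> NCHW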
This reverts commit 80dcdcd.
…-outputs

* master: (68 commits)
  Change default value of shuffle parameter of Sequential.fit_generator() from True to False. (keras-team#8075)
  Fix off-by-one bug in predict/evaluate progress bar (keras-team#8071)
  Revert "Faster sequence" (keras-team#8060)
  Support NCHW for conv2d. (keras-team#8021)
  Change compute_accuracy() argument order and names (keras-team#8049)
  Replace literal constant 10 with variable num_classes in example/ (keras-team#8041)
  Faster sequence (keras-team#8039)
  Improve RNN docs.
  Enable accuracy reporting during training in examples/mnist_siamese_graph.py (keras-team#7997)
  Bug fix: Models with shared layers shouldn't be considered Sequential like (keras-team#8025)
  Add 'subtract' merge layer documentation (keras-team#8038)
  Update inference in seq2seq script to be more efficient
  Remove lstm_benchmark from examples/README.md (keras-team#8024)
  Add shuffle to the Model API (keras-team#8023)
  Add seq2seq example script.
  fix travis failure (keras-team#8014)
  Improve TF backend's Switch function (keras-team#7958)
  Added support for dynamic noise_shape in Dropout (keras-team#7999)
  Make on_epoch_end optional (keras-team#8007)
  Incremental tests speed ups.
  ...
TensorFlow (at least versions 1.3 and below) does not support `NCHW` computation for many ops when running on the CPU, which creates compatibility issues. Keras originally mitigated this with an "NHWC roundtrip": any op requested in `channels_first` (`NCHW`) would actually be transposed to and from `NHWC`. `NCHW` often has performance benefits; for example, cuDNN is often fastest when the largest dimension is last (favouring `NCHW` for images).

The current implementation probes whether `NCHW` can be used by checking:

- whether `NCHW` (`channels_first`) is requested at all;
- whether `NCHW` is supported, by checking (1) whether a GPU is available at all and (2) whether the op is not explicitly placed on the CPU.

If the op is explicitly placed on the CPU, or there is no GPU device, and `NCHW` is requested, we proceed with the transpose roundtrip as we always did in Keras.

This PR only patches the `conv2d` op. There are more (and more significant) ops to be patched: any op that uses `_postprocess_conv2d_output` or `_postprocess_conv3d_output`, and `bias_add` too. The latter is actually Keras's bottleneck when it comes to `NCHW`. Let's save those for separate PRs.
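Put together, the patched `conv2d` flow amounts to roughly the following (a sketch of the logic described above, using the `_has_nchw_support` helper sketched earlier; details such as dilation handling are omitted and names may not match the merged code exactly):

    import tensorflow as tf

    def conv2d(x, kernel, strides=(1, 1), padding='valid',
               data_format='channels_last'):
        if data_format == 'channels_first' and _has_nchw_support():
            # Run natively in NCHW; no transposes needed.
            tf_data_format = 'NCHW'
            tf_strides = (1, 1) + strides
        else:
            tf_data_format = 'NHWC'
            tf_strides = (1,) + strides + (1,)
            if data_format == 'channels_first':
                # Fall back to the old roundtrip on CPU or without a GPU.
                x = tf.transpose(x, (0, 2, 3, 1))        # NCHW -> NHWC
        tf_padding = 'SAME' if padding == 'same' else 'VALID'
        x = tf.nn.conv2d(x, kernel, tf_strides, tf_padding,
                         data_format=tf_data_format)
        if data_format == 'channels_first' and tf_data_format == 'NHWC':
            x = tf.transpose(x, (0, 3, 1, 2))            # NHWC -> NCHW
        return x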