Skip to content

Commit

Permalink
Merge pull request tensorflow#6608 from davidzchen/versions
Browse files Browse the repository at this point in the history
Detect and match against full cuda and cudnn versions.
  • Loading branch information
rohan100jain authored Jan 3, 2017
2 parents 55b0159 + 90d3b00 commit bb5f900
Showing 1 changed file with 91 additions and 17 deletions.
108 changes: 91 additions & 17 deletions third_party/gpus/cuda_configure.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,42 @@ def _cudnn_install_basedir(repository_ctx):
return cudnn_install_path


def _matches_version(environ_version, detected_version):
"""Checks whether the user-specified version matches the detected version.
This function performs a weak matching so that if the user specifies only the
major or major and minor versions, the versions are still considered matching
if the version parts match. To illustrate:
environ_version detected_version result
-----------------------------------------
5.1.3 5.1.3 True
5.1 5.1.3 True
5 5.1 True
5.1.3 5.1 False
5.2.3 5.1.3 False
Args:
environ_version: The version specified by the user via environment
variables.
detected_version: The version autodetected from the CUDA installation on
the system.
Returns: True if user-specified version matches detected version and False
otherwise.
"""
environ_version_parts = environ_version.split(".")
detected_version_parts = detected_version.split(".")
if len(detected_version_parts) < len(environ_version_parts):
return False
for i, part in enumerate(detected_version_parts):
if i >= len(environ_version_parts):
break
if part != environ_version_parts[i]:
return False
return True


_NVCC_VERSION_PREFIX = "Cuda compilation tools, release "


Expand Down Expand Up @@ -179,69 +215,107 @@ def _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value):
# Parse the CUDA version from the line containing the CUDA version.
prefix_removed = version_line.replace(_NVCC_VERSION_PREFIX, '')
parts = prefix_removed.split(",")
if len(parts) != 2 or len(parts[0]) == 0:
if len(parts) != 2 or len(parts[0]) < 2:
auto_configure_fail(
"Could not parse CUDA version from nvcc --version. Got: %s" %
result.stdout)
version = parts[0].strip()
full_version = parts[1].strip()
if full_version.startswith('V'):
full_version = full_version[1:]

# Check whether TF_CUDA_VERSION was set by the user and fail if it does not
# match the detected version.
environ_version = ""
if _TF_CUDA_VERSION in repository_ctx.os.environ:
environ_version = repository_ctx.os.environ[_TF_CUDA_VERSION].strip()
if environ_version and version != environ_version:
if environ_version and not _matches_version(environ_version, full_version):
auto_configure_fail(
("CUDA version detected from nvcc (%s) does not match " +
"TF_CUDA_VERSION (%s)") % (version, environ_version))
"TF_CUDA_VERSION (%s)") % (full_version, environ_version))

# We only use the version consisting of the major and minor version numbers.
version_parts = full_version.split('.')
if len(version_parts) < 2:
auto_configure_fail("CUDA version detected from nvcc (%s) is incomplete.")
if cpu_value == "Windows":
version = "64_" + version.replace(".", "")
version = "64_%s%s" % (version_parts[0], version_parts[1])
else:
version = "%s.%s" % (version_parts[0], version_parts[1])
return version


_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"
_DEFINE_CUDNN_MINOR = "#define CUDNN_MINOR"
_DEFINE_CUDNN_PATCHLEVEL = "#define CUDNN_PATCHLEVEL"


def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
"""Detects the version of cuDNN installed on the system.
def _find_cuda_define(repository_ctx, cudnn_install_basedir, define):
"""Returns the value of a #define in cudnn.h
Greps through cudnn.h and returns the value of the specified #define. If the
#define is not found, then raise an error.
Args:
repository_ctx: The repository context.
cpu_value: The name of the host operating system.
cudnn_install_basedir: The cuDNN install directory.
cudnn_install_basedir: The install directory for cuDNN on the system.
define: The #define to search for.
Returns:
A string containing the version of cuDNN.
The value of the #define found in cudnn.h.
"""
# Find cudnn.h and grep for the line defining CUDNN_MAJOR.
cudnn_h_path = repository_ctx.path("%s/include/cudnn.h" %
cudnn_install_basedir)
if not cudnn_h_path.exists:
auto_configure_fail("Cannot find cudnn.h at %s" % str(cudnn_h_path))
result = repository_ctx.execute([
"grep", "-E", _DEFINE_CUDNN_MAJOR, str(cudnn_h_path)])
result = repository_ctx.execute(["grep", "-E", define, str(cudnn_h_path)])
if result.stderr:
auto_configure_fail("Error reading %s: %s" %
(result.stderr, str(cudnn_h_path)))

# Parse the cuDNN major version from the line defining CUDNN_MAJOR
lines = result.stdout.splitlines()
if len(lines) == 0 or lines[0].find(_DEFINE_CUDNN_MAJOR) == -1:
if len(lines) == 0 or lines[0].find(define) == -1:
auto_configure_fail("Cannot find line containing '%s' in %s" %
(_DEFINE_CUDNN_MAJOR, str(cudnn_h_path)))
version = lines[0].replace(_DEFINE_CUDNN_MAJOR, "").strip()
(define, str(cudnn_h_path)))
return lines[0].replace(define, "").strip()


def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
"""Detects the version of cuDNN installed on the system.
Args:
repository_ctx: The repository context.
cpu_value: The name of the host operating system.
cudnn_install_basedir: The cuDNN install directory.
Returns:
A string containing the version of cuDNN.
"""
major_version = _find_cuda_define(repository_ctx, cudnn_install_basedir,
_DEFINE_CUDNN_MAJOR)
minor_version = _find_cuda_define(repository_ctx, cudnn_install_basedir,
_DEFINE_CUDNN_MINOR)
patch_version = _find_cuda_define(repository_ctx, cudnn_install_basedir,
_DEFINE_CUDNN_PATCHLEVEL)
full_version = "%s.%s.%s" % (major_version, minor_version, patch_version)

# Check whether TF_CUDNN_VERSION was set by the user and fail if it does not
# match the detected version.
environ_version = ""
if _TF_CUDNN_VERSION in repository_ctx.os.environ:
environ_version = repository_ctx.os.environ[_TF_CUDNN_VERSION].strip()
if environ_version and version != environ_version:
if environ_version and not _matches_version(environ_version, full_version):
cudnn_h_path = repository_ctx.path("%s/include/cudnn.h" %
cudnn_install_basedir)
auto_configure_fail(
("cuDNN version detected from %s (%s) does not match " +
"TF_CUDNN_VERSION (%s)") % (str(cudnn_h_path), version, environ_version))
"TF_CUDNN_VERSION (%s)") %
(str(cudnn_h_path), full_version, environ_version))

# We only use the major version since we use the libcudnn libraries that are
# only versioned with the major version (e.g. libcudnn.so.5).
version = major_version
if cpu_value == "Windows":
version = "64_" + version
return version
Expand Down

0 comments on commit bb5f900

Please sign in to comment.