Skip to content

Commit

Permalink
Merge pull request tensorflow#6047 from meteorcloudy/fix_cuda_configure
Browse files Browse the repository at this point in the history
Fix cuda version detect on Windows
  • Loading branch information
gunan authored Dec 3, 2016
2 parents d5d654c + f7ff20e commit 48fb73a
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions third_party/gpus/cuda_configure.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,16 @@ def _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value):
auto_configure_fail(
"CUDA version detected from nvcc (%s) does not match " +
"TF_CUDA_VERSION (%s)" % (version, environ_version))

if cpu_value == "Windows":
version = "64_" + version.replace(".", "")
return version


_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"


def _cudnn_version(repository_ctx, cudnn_install_basedir):
def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
"""Detects the version of cuDNN installed on the system.
Args:
Expand Down Expand Up @@ -236,8 +239,11 @@ def _cudnn_version(repository_ctx, cudnn_install_basedir):
environ_version = repository_ctx.os.environ[_TF_CUDNN_VERSION].strip()
if environ_version and version != environ_version:
auto_configure_fail(
"cuDNN version detected from %s (%s) does not match " +
"TF_CUDNN_VERSION (%s)" % (str(cudnn_h_path), version, environ_version))
("cuDNN version detected from %s (%s) does not match " +
"TF_CUDNN_VERSION (%s)") % (str(cudnn_h_path), version, environ_version))

if cpu_value == "Windows":
version = "64_" + version
return version


Expand Down Expand Up @@ -501,7 +507,7 @@ def _get_cuda_config(repository_ctx):
cuda_toolkit_path = _cuda_toolkit_path(repository_ctx)
cuda_version = _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value)
cudnn_install_basedir = _cudnn_install_basedir(repository_ctx)
cudnn_version = _cudnn_version(repository_ctx, cudnn_install_basedir)
cudnn_version = _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value)
return struct(
cuda_toolkit_path = cuda_toolkit_path,
cudnn_install_basedir = cudnn_install_basedir,
Expand Down

0 comments on commit 48fb73a

Please sign in to comment.