
Build issues with local CUDA installation #23689

Open
adamjstewart opened this issue Sep 17, 2024 · 58 comments
Labels
bug Something isn't working

Comments

@adamjstewart
Contributor

Description

When building jaxlib against an externally installed copy of CUDA (a requirement for package managers and HPC systems), I see the following error:

gcc: error: unrecognized command-line option '--cuda-path=external/cuda_nvcc'

It's possible I'm passing the wrong flags somewhere. I'm using:

> python3 build/build.py --enable_cuda --cuda_compute_capabilities=8.0 --bazel_options=--repo_env=LOCAL_CUDA_PATH=... --bazel_options=--repo_env=LOCAL_CUDNN_PATH=... --bazel_options=--repo_env=LOCAL_NCCL_PATH=...

(of course, with ... replaced by the actual paths)

System info (python version, jaxlib version, accelerator, etc.)

  • Python: 3.11.9
  • Jaxlib: 0.4.32
  • GCC: 11.4.0

Build log

@adamjstewart adamjstewart added the bug Something isn't working label Sep 17, 2024
@ybaturina
Contributor

Hi @adamjstewart, the GCC compiler is not officially supported by JAX. I recommend using Clang instead. You can pass the Clang path via the --clang_path option.
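For illustration, a Clang-based invocation along those lines might look like the following (the Clang path and the CUDA/CUDNN/NCCL paths are placeholders; adjust them to your system):

```
python3 build/build.py --enable_cuda --cuda_compute_capabilities=8.0 \
    --clang_path=/usr/bin/clang-18 \
    --bazel_options=--repo_env=LOCAL_CUDA_PATH=/opt/cuda \
    --bazel_options=--repo_env=LOCAL_CUDNN_PATH=/usr \
    --bazel_options=--repo_env=LOCAL_NCCL_PATH=/usr
```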

@ybaturina
Contributor

If you absolutely need to use GCC, we have experimental support that can be enabled like this:

--bazel_options=--action_env=CUDA_NVCC="1" --bazel_options=--@local_config_cuda//:cuda_compiler=nvcc

@adamjstewart
Contributor Author

I tried adding these flags but I still see the exact same error:

gcc: error: unrecognized command-line option '--cuda-path=external/cuda_nvcc'

@ybaturina
Contributor

Would you paste the full stack trace here, please? I'd like to make sure that the CUDA_NVCC value is recognized by Bazel.

@adamjstewart
Contributor Author

Here you go:

@ybaturina
Contributor

Hmm, one more suggestion: try this:

```
--bazel_options=--action_env=TF_NVCC_CLANG="1" --bazel_options=--@local_config_cuda//:cuda_compiler=nvcc
```

The reason your build fails is that GCC is unable to compile the CUDA dependencies; that has to be done with the NVCC compiler.

@adamjstewart
Contributor Author

Still the same issue:

gcc: error: unrecognized command-line option ‘--cuda-path=external/cuda_nvcc’

@ybaturina
Contributor

This is what I've tried:

python3.10 build/build.py --enable_cuda --use_clang=false --bazel_options=--repo_env=CC="/dt9/usr/bin/gcc" --bazel_options=--repo_env=TF_SYSROOT="/dt9" --bazel_options=--action_env=CUDA_NVCC="1" --bazel_options=--@local_config_cuda//:cuda_compiler=nvcc

The subcommand I got:

SUBCOMMAND: # //jaxlib:cpu_feature_guard.so [action 'Compiling jaxlib/cpu_feature_guard.c', configuration: 988f5a730e2bd9c88c71efcc9c7f0d36ad2ec3c5f71c922aabaf7614ff994b0f, execution platform: @local_execution_config_platform//:platform]
(cd /home/ybaturina/.cache/bazel/_bazel_ybaturina/ead9107e8e47a1c42911a02736d63d03/execroot/__main__ && \
  exec env - \
    CUDA_NVCC=1 \
    PATH=/home/kbuilder/.local/bin:/usr/local/bin/python3.10:/home/ybaturina/.local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin \
    PWD=/proc/self/cwd \
  external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -MD -MF bazel-out/k8-opt/bin/jaxlib/_objs/cpu_feature_guard.so/cpu_feature_guard.pic.d '-frandom-seed=bazel-out/k8-opt/bin/jaxlib/_objs/cpu_feature_guard.so/cpu_feature_guard.pic.o' '-DBAZEL_CURRENT_REPOSITORY=""' -iquote . -iquote bazel-out/k8-opt/bin -iquote external/python_x86_64-unknown-linux-gnu -iquote bazel-out/k8-opt/bin/external/python_x86_64-unknown-linux-gnu -isystem external/python_x86_64-unknown-linux-gnu/include -isystem bazel-out/k8-opt/bin/external/python_x86_64-unknown-linux-gnu/include -isystem external/python_x86_64-unknown-linux-gnu/include/python3.10 -isystem bazel-out/k8-opt/bin/external/python_x86_64-unknown-linux-gnu/include/python3.10 -isystem external/python_x86_64-unknown-linux-gnu/include/python3.10m -isystem bazel-out/k8-opt/bin/external/python_x86_64-unknown-linux-gnu/include/python3.10m -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' -fPIC -U_FORTIFY_SOURCE '-D_FORTIFY_SOURCE=1' -fstack-protector -Wall -fno-omit-frame-pointer -no-canonical-prefixes -DNDEBUG -g0 -O2 -ffunction-sections -fdata-sections '-fvisibility=hidden' -Wno-sign-compare -Wno-unknown-warning-option -Wno-stringop-truncation -Wno-array-parameter '-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.' -mavx -fno-strict-aliasing -fexceptions '-fvisibility=hidden' '--sysroot=/dt9' -c jaxlib/cpu_feature_guard.c -o bazel-out/k8-opt/bin/jaxlib/_objs/cpu_feature_guard.so/cpu_feature_guard.pic.o)

I didn't get the --cuda-path option passed to the compiler.

I assume that something in the environment variables on your machine messes up the subcommand configuration.
Since JAX doesn't support GCC compilation officially, I strongly recommend using clang for the compilation.

@daskol
Contributor

daskol commented Oct 22, 2024

I'm hitting the --cuda-path issue with GCC as well.

Alternatively, I tried building with Clang and local CUDA, CUDNN, and NCCL, but other issues occur.

In file included from external/xla/xla/tsl/cuda/cudnn_stub.cc:16:
In file included from external/com_google_absl/absl/container/flat_hash_map.h:38:
In file included from external/com_google_absl/absl/algorithm/container.h:43:
In file included from /usr/bin/../lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/algorithm:61:
In file included from /usr/bin/../lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/bits/stl_algo.h:71:
/usr/bin/../lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/cstdlib:79:15: fatal error: 'stdlib.h' file not found
   79 | #include_next <stdlib.h>
      |               ^~~~~~~~~~
1 error generated.

Specifically, I run bazel directly as follows.

build/bazel-6.5.0-linux-x86_64 run --verbose_failures=true \
    --repo_env=LOCAL_CUDA_PATH=/opt/cuda \
    --repo_env=LOCAL_CUDNN_PATH=/usr \
    --repo_env=LOCAL_NCCL_PATH=/usr \
    //jaxlib/tools:build_wheel -- \
    --output_path=$PWD/dist --cpu=x86_64 \
    --jaxlib_git_hash=78ade74d695407306461718a6d73cfed89b4d972

Also, I add the following .bazelrc.user to the repository root.

# .bazelrc.user
build --strategy=Genrule=standalone
build --action_env CLANG_COMPILER_PATH="/usr/bin/clang-18"
build --repo_env CC="/usr/bin/clang-18"
build --repo_env BAZEL_COMPILER="/usr/bin/clang-18"
build --copt=-Wno-error=unused-command-line-argument
build --copt=-Wno-gnu-offsetof-extensions
build --config=avx_posix
build --config=mkl_open_source_only
build --config=cuda
build --config=nvcc_clang
build --action_env=CLANG_CUDA_COMPILER_PATH=/usr/bin/clang-18
build --repo_env HERMETIC_PYTHON_VERSION="3.12"

Dependency versions follow.

$ pacman -Qs '(cuda|cudnn|clang)'
local/clang 18.1.8-4
    C language family frontend for LLVM
local/compiler-rt 18.1.8-1
    Compiler runtime libraries for clang
local/cuda 12.6.2-2
    NVIDIA's GPU programming toolkit
local/cudnn 9.2.1.18-1
    NVIDIA CUDA Deep Neural Network library

@ybaturina
Contributor

This looks like a problem with the GCC installation.
If you run clang -v, you'll see something like this:
Selected GCC installation: /usr/bin/../lib/gcc/x86_64-linux-gnu/14

Looking at the error above, I suggest running this command:
find /usr/bin/../include -name "stdlib.h"
If the file is not found in GCC v14, you'll need to install the missing headers by running sudo apt install g++-14.

@daskol
Contributor

daskol commented Oct 22, 2024

I can reproduce the issue for jaxlib 0.4.32, 0.4.33, and 0.4.34 with clang-14 and clang-18 (which depend on gcc and gcc-libs 14.2.1+r134+gab884fffe3fc-1). Also, the cuda package depends on gcc-13 (this is Arch).

$ /usr/lib/llvm14/bin/clang-14 -v
clang version 14.0.6
Target: x86_64-pc-linux-gnu
Thread model: posix
InstalledDir: /usr/lib/llvm14/bin
Found candidate GCC installation: /usr/lib/gcc/x86_64-pc-linux-gnu/13.3.0
Found candidate GCC installation: /usr/lib/gcc/x86_64-pc-linux-gnu/14.2.1
Found candidate GCC installation: /usr/lib64/gcc/x86_64-pc-linux-gnu/13.3.0
Found candidate GCC installation: /usr/lib64/gcc/x86_64-pc-linux-gnu/14.2.1
Selected GCC installation: /usr/lib64/gcc/x86_64-pc-linux-gnu/14.2.1
Candidate multilib: .;@m64
Candidate multilib: 32;@m32
Selected multilib: .;@m64

I have also appended the -v option to the failing crosstool_wrapper_driver_is_not_gcc command. It displays the system search paths, which do contain stdlib.h.

$ (cd ... && .../crosstool_wrapper_driver_is_not_gcc ... -v)
Selected GCC installation: /usr/lib64/gcc/x86_64-pc-linux-gnu/14.2.1
 ...
 /usr/lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1
 /usr/lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/x86_64-pc-linux-gnu
 /usr/lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/backward
 /usr/lib/llvm14/lib/clang/14.0.6/include
 /usr/local/include
 /usr/include
End of search list.
external/xla/xla/tsl/cuda/cupti_stub.cc:16:10: fatal error: 'third_party/gpus/cuda/extras/CUPTI/include/cupti.h' file not found
#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
         ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1 error generated.
$ ls -l /usr/lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/stdlib.h 
-rw-r--r-- 1 root root 2.3K Sep 10 13:07 /usr/lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/stdlib.h

But it is a bit odd that now third_party/gpus/cuda/extras/CUPTI/include/cupti.h is not found. 🤯

@daskol
Contributor

daskol commented Oct 22, 2024

There is indeed no directory third_party/gpus. I didn't find third_party/gpus/cuda/extras/CUPTI with the command below. 🤯

find -L bazel-jax-jax-v0.4.34 -name 'CUPTI'

UPD: Is this an upstream (XLA) issue?

@ybaturina
Contributor

ybaturina commented Oct 22, 2024

Would you check if your local CUDA installation has CUPTI headers please? Specifically, the following headers should be present:
https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl#L21-L58

Also please check that the structure of the local CUDA/CUDNN/NCCL dirs is exactly the same as described here.

@daskol
Contributor

daskol commented Oct 22, 2024

Sure. I checked, and CUPTI is where it should be (i.e. /opt/cuda/extras/CUPTI; the CUDA root is /opt/cuda). I also checked the include directories for CUPTI, and they look fine.

-Ibazel-out/k8-opt/bin/external/cuda_cupti/_virtual_includes/headers
-iquote external/cuda_cupti
-isystem bazel-out/k8-opt/bin/external/cuda_cupti/include

However, all these directories are empty. I compared them with the directories of another CUDA library (e.g. cufft), and those are not empty. Then I manually symlinked the include directory in several places, like

mkdir -p bazel-out/k8-opt/bin/external/cuda_cupti/_virtual_includes/headers/third_party/gpus/cuda/extras/CUPTI
ln -s /opt/cuda/extras/CUPTI/include \
    bazel-out/k8-opt/bin/external/cuda_cupti/_virtual_includes/headers/third_party/gpus/cuda/extras/CUPTI
ln -s /opt/cuda/extras/CUPTI/include \
    bazel-out/k8-opt/bin/external/cuda_cupti/include

and ran build ... @xla//xla/tsl/cuda:cupti_stub. Compilation still fails.

external/xla/xla/tsl/cuda/BUILD.bazel:240:11: Compiling xla/tsl/cuda/cupti_stub.cc failed: undeclared inclusion(s) in rule '@xla//xla/tsl/cuda:cupti_stub':
this rule is missing dependency declarations for the following files included by 'xla/tsl/cuda/cupti_stub.cc':
  'bazel-out/k8-opt/bin/external/cuda_cupti/_virtual_includes/headers/third_party/gpus/cuda/extras/CUPTI/include/cupti.h'
  'bazel-out/k8-opt/bin/external/cuda_cupti/include/cupti_result.h'
  'bazel-out/k8-opt/bin/external/cuda_cupti/include/cupti_version.h'
  ...

It seems that Bazel does not copy or recreate the header library for cupti, while it does so for cufft and the others.

Is this trailing slash important? The other BUILD.tpl files don't have it. https://github.com/openxla/xla/blob/3740d0854106f32a89687484b05fd8947c89ef91/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl#L60

UPD Manual editing of cuda_cupti.BUILD.tpl does not work out. 😔

@ybaturina
Contributor

The issue is that /opt/cuda/extras/CUPTI is not an acceptable location (see here).

This is how the CUDA folder should look:

<LOCAL_CUDA_PATH>/
    include/
    bin/
    lib/
    nvvm/

So all headers should be located in <LOCAL_CUDA_PATH>/include, and all libraries should be in <LOCAL_CUDA_PATH>/lib.

Also please note that local CUDA installation is not a recommended approach for building from sources.
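If you only have a system installation like /opt/cuda (where CUPTI lives under extras/CUPTI/), one way to get such a layout without touching the system is to stage it with symlinks. A minimal sketch, assuming /opt/cuda as the source and a scratch directory as the destination (both paths are illustrative, not verified):

```shell
# Sketch: stage an existing CUDA install into the flat layout expected for
# LOCAL_CUDA_PATH. CUDA_SRC and CUDA_DST are illustrative; adjust to your system.
CUDA_SRC="${CUDA_SRC:-/opt/cuda}"
CUDA_DST="${CUDA_DST:-$HOME/cuda-local}"

mkdir -p "$CUDA_DST/include" "$CUDA_DST/lib"

# bin/ and nvvm/ can be symlinked wholesale.
for d in bin nvvm; do
  if [ -e "$CUDA_SRC/$d" ]; then
    ln -sfn "$CUDA_SRC/$d" "$CUDA_DST/$d"
  fi
done

# All headers and libraries -- including CUPTI's, which normally live under
# extras/CUPTI/ -- must end up directly under include/ and lib/.
for f in "$CUDA_SRC"/include/* "$CUDA_SRC"/extras/CUPTI/include/*; do
  if [ -e "$f" ]; then
    ln -sfn "$f" "$CUDA_DST/include/"
  fi
done
for f in "$CUDA_SRC"/lib/* "$CUDA_SRC"/lib64/* "$CUDA_SRC"/extras/CUPTI/lib64/*; do
  if [ -e "$f" ]; then
    ln -sfn "$f" "$CUDA_DST/lib/"
  fi
done
```

You would then point --repo_env=LOCAL_CUDA_PATH at the staging directory; copying the files instead of symlinking should work just as well.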

@daskol
Contributor

daskol commented Oct 22, 2024

I have already tried that. I copied everything from extras/CUPTI to the CUDA root, but it doesn't help. Moreover, include_prefix = "third_party/gpus/cuda/extras/CUPTI/include" in cuda_cupti.BUILD.tpl differs from the prefixes in the other cuda_*.BUILD.tpl files.

@ybaturina
Contributor

ybaturina commented Oct 22, 2024

include_prefix corresponds to the import prefix used in the source files, e.g. this one.

As far as I understand, you use the command below:

build/bazel-6.5.0-linux-x86_64 run --verbose_failures=true \
    --repo_env=LOCAL_CUDA_PATH=/opt/cuda \
    --repo_env=LOCAL_CUDNN_PATH=/usr \
    --repo_env=LOCAL_NCCL_PATH=/usr \
    //jaxlib/tools:build_wheel -- \
    --output_path=$PWD/dist --cpu=x86_64 \
    --jaxlib_git_hash=78ade74d695407306461718a6d73cfed89b4d972

Would you confirm that all CUDA headers are located in /opt/cuda/include, and all NCCL/CUDNN headers are in /usr/include? If so, please clean Bazel cache via bazel clean --expunge and run the command again. If it fails, I would appreciate it if you post the full log here.

@daskol
Contributor

daskol commented Oct 22, 2024

Would you confirm that all CUDA headers are located in /opt/cuda/include, and all NCCL/CUDNN headers are in /usr/include?

Absolutely.

If so, please clean Bazel cache via bazel clean --expunge and run the command again. If it fails, I would appreciate it if you post the full log here.

Link.

Since the target @xla//xla/tsl/cuda:cudnn_stub fails first this time (due to the missing <stdlib.h>), I built @xla//xla/tsl/cuda:cupti_stub directly, and it fails too because of "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" (logs).

build/bazel-6.5.0-linux-x86_64 build --verbose_failures=true \
    --repo_env=LOCAL_CUDA_PATH=/opt/cuda \
    --repo_env=LOCAL_CUDNN_PATH=/usr \
    --repo_env=LOCAL_NCCL_PATH=/usr \
    @xla//xla/tsl/cuda:cupti_stub

@ybaturina
Contributor

Can you check this folder please?
/home/bershatsky/.cache/bazel/_bazel_bershatsky/3be6d6eea05ac1cf650a152f41829d38/external/cuda_cupti
Does it have a symlink include pointing to /opt/cuda/include?

Please don't build @xla//xla/tsl/cuda:cupti_stub, try @cuda_cupti//:headers instead - this is the dependency used in Bazel tests for CUDA.

@daskol
Contributor

daskol commented Oct 23, 2024

Does it have symlink include pointing to /opt/cuda/include?

Yes, it has include and the others. I checked the BUILD file in this directory: all the headers are commented out (link).

Please don't build @xla//xla/tsl/cuda:cupti_stub, try @cuda_cupti//:headers instead - this is the dependency used in Bazel tests for CUDA.

Target @cuda_cupti//:headers has been successfully built. Isn't target @xla//xla/tsl/cuda:cupti_stub a dependency of //jaxlib/tools:build_wheel?

$ build/bazel-6.5.0-linux-x86_64 query \
    --repo_env=LOCAL_CUDA_PATH=/opt/cuda \
    --repo_env=LOCAL_CUDNN_PATH=/usr \
    --repo_env=LOCAL_NCCL_PATH=/usr \
    "deps(kind(rule, deps(//jaxlib/tools:build_wheel)))" | grep cupti
...
@xla//xla/tsl/cuda:cupti_stub
...

@daskol
Contributor

daskol commented Oct 23, 2024

I also noticed one important thing: you execute bazel run and bazel build without passing --config=cuda.

I put all the auxiliary options in .bazelrc.user. I believe this is equivalent.

build --strategy=Genrule=standalone
build --action_env CLANG_COMPILER_PATH="/usr/lib/llvm14/bin/clang-14"
build --repo_env CC="/usr/lib/llvm14/bin/clang-14"
build --repo_env BAZEL_COMPILER="/usr/lib/llvm14/bin/clang-14"
build --copt=-Wno-error=unused-command-line-argument
build --copt=-Wno-gnu-offsetof-extensions
build --config=avx_posix
build --config=mkl_open_source_only
build --config=cuda
build --config=cuda_nvcc
build --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm14/bin/clang-14"
build --repo_env HERMETIC_PYTHON_VERSION="3.12"

@ybaturina
Contributor

ybaturina commented Oct 23, 2024

The headers are commented out in two cases:

  1. The Bazel command didn't receive the instruction to use the --config=cuda option.
  2. The CUDA repository rule was unable to find CUPTI libraries in /opt/cuda/lib and assumed that the CUPTI redistribution is absent, hence it commented out the headers.

bazel query doesn't recognize Bazel options (including those provided via --config=cuda). To find the true dependencies, you can use bazel cquery:

Here are my results:

bazel cquery --repo_env=LOCAL_CUDA_PATH="/home/ybaturina/cuda" --repo_env=LOCAL_CUDNN_PATH="/home/ybaturina/cudnn" --repo_env=LOCAL_NCCL_PATH="/home/ybaturina/Downloads/dists/nvidia/nccl" --repo_env=HERMETIC_PYTHON_VERSION=3.10 --config=cuda 'somepath(//jaxlib/tools:build_wheel, @xla//xla/tsl/cuda:cupti_stub)' - returns nothing

bazel cquery --repo_env=LOCAL_CUDA_PATH="/home/ybaturina/cuda" --repo_env=LOCAL_CUDNN_PATH="/home/ybaturina/cudnn" --repo_env=LOCAL_NCCL_PATH="/home/ybaturina/Downloads/dists/nvidia/nccl" --repo_env=HERMETIC_PYTHON_VERSION=3.10 --config=cuda 'somepath(//jaxlib/tools:build_wheel, @cuda_cupti//:headers)' - returns the result below:

INFO: Found 2 targets...
//jaxlib/tools:build_wheel (e356dec)
//jaxlib/cuda:cuda_gpu_support (e356dec)
//jaxlib/mosaic/gpu:mosaic_gpu (e356dec)
//jaxlib/mosaic/gpu:_mosaic_gpu_ext (e356dec)
//jaxlib/mosaic/gpu:_mosaic_gpu_ext.so (e356dec)
//jaxlib/cuda:cuda_vendor (e356dec)
@xla//xla/tsl/cuda:cupti (e356dec)
@cuda_cupti//:cupti (e356dec)
@cuda_cupti//:cupti_shared_library (e356dec)
@cuda_cupti//:headers (e356dec)

@daskol
Contributor

daskol commented Oct 23, 2024

It seems that the missing header error is caused by the #include_next GNU extension and the ordering of the -isystem search paths in Bazel(?). The actual search path ordering follows.

 ...
#include <...> search starts here:
 ...
 /usr/include/c++/14.2.1
 /usr/include/c++/14.2.1/x86_64-pc-linux-gnu
 /usr/include/c++/14.2.1/backward
 /usr/lib/llvm14/lib/clang/14.0.6/include
 /usr/local/include
End of search list.

The #include_next directive passes inclusion of stdlib.h on to the next match in the search path, which is supposedly the stdlib.h in /usr/include. But /usr/include is not in the search list. The stdlib.h files available on the system:

$ find /usr -iname stdlib.h
/usr/include/bits/stdlib.h
/usr/include/c++/14.2.1/stdlib.h
/usr/include/c++/14.2.1/tr1/stdlib.h
/usr/include/stdlib.h

I have no idea how to fix the issue easily. Adding --cxxopt=-isystem/usr/include to the build options does not help. It also seems that Bazel sorts search paths alphabetically.
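The failure mode can be reproduced outside of Bazel. A minimal sketch (assuming a C compiler is available as cc; all file names are made up): a wrapper header that uses #include_next compiles only when a later entry in the search path also provides a header of the same name, mirroring the C library's include directory being absent from the Bazel-generated search list.

```shell
# Reproduce the #include_next failure in isolation.
tmp=$(mktemp -d)
mkdir -p "$tmp/wrap" "$tmp/real"

# wrap/myhdr.h defers to "the next" myhdr.h on the search path,
# just like libstdc++'s cstdlib defers to the C library's stdlib.h.
printf '#include_next <myhdr.h>\n' > "$tmp/wrap/myhdr.h"
printf '#define FROM_REAL 1\n' > "$tmp/real/myhdr.h"
printf '#include <myhdr.h>\nint main(void) { return FROM_REAL - 1; }\n' > "$tmp/main.c"

# With both directories on the path, include_next resolves to real/myhdr.h.
cc -I"$tmp/wrap" -I"$tmp/real" "$tmp/main.c" -o "$tmp/ok" && echo "builds with both dirs"

# Without the second directory (analogous to the missing fallback include dir),
# the same code fails to compile.
if cc -I"$tmp/wrap" "$tmp/main.c" -o "$tmp/bad" 2>/dev/null; then
  echo "unexpectedly built"
else
  echo "fails without the fallback dir"
fi
```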

@actionless

actionless commented Nov 5, 2024

When I run the command below on a pretty much default Arch Linux box,

    JAXLIB_RELEASE=$pkgver python build/build.py \
        --bazel_startup_options="--output_user_root=$srcdir/bazel"\
        --bazel_options='--action_env=JAXLIB_RELEASE' \
        --enable_cuda \
        --target_cpu_features=release

I don't have a problem locating the stdlib header, but I do have another problem related to it:

ERROR: /home/lie/.cache/pikaur/build/python-jaxlib-cuda/src/jax-jaxlib-v0.4.32/jaxlib/cuda/BUILD:75:13: Compiling jaxlib/gpu/make_batch_pointers.cu.cc failed: (Exit 2): crosstool_wrapper_driver_is_not_gcc failed: error executing command (from target //jaxlib/cuda:cuda_make_batch_pointers)
  (cd /home/lie/.cache/pikaur/build/python-jaxlib-cuda/src/bazel/8a8ca1fd42886b1189093a9473f4da62/execroot/__main__ && \
  exec env - \
    CLANG_COMPILER_PATH=/usr/bin/clang-18 \
    CLANG_CUDA_COMPILER_PATH=/usr/bin/clang-18 \
    CUDA_TOOLKIT_PATH=/opt/cuda \
    GCC_HOST_COMPILER_PATH=/usr/bin/gcc-13 \
    JAXLIB_RELEASE=0.4.32 \
    NCCL_INSTALL_PATH=/usr \
    PATH=/usr/local/sbin:/usr/local/bin:/usr/bin \
    PWD=/proc/self/cwd \
    TF_CUDA_COMPUTE_CAPABILITIES=sm_70,sm_75,sm_80,sm_86,sm_89,sm_90,compute_90 \
    TF_CUDA_PATHS=/opt/cuda,/usr/lib,/usr \
    TF_NVCC_CLANG=1 \
  external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -MD -MF bazel-out/k8-opt/bin/jaxlib/cuda/_objs/cuda_make_batch_pointers/make_batch_pointers.cu.pic.d '-frandom-seed=bazel-out/k8-opt/bin/jaxlib/cuda/_objs/cuda_make_batch_pointers/make_batch_pointers.cu.pic.o' '-DEIGEN_MAX_ALIGN_BYTES=64' -DEIGEN_ALLOW_UNALIGNED_SCALARS '-DEIGEN_USE_AVX512_GEMM_KERNELS=0' '-DJAX_GPU_CUDA=1' '-DBAZEL_CURRENT_REPOSITORY=""' -iquote . -iquote bazel-out/k8-opt/bin -iquote external/local_config_cuda -iquote bazel-out/k8-opt/bin/external/local_config_cuda -iquote external/cuda_cudart -iquote bazel-out/k8-opt/bin/external/cuda_cudart -iquote external/cuda_cublas -iquote bazel-out/k8-opt/bin/external/cuda_cublas -iquote external/cuda_cccl -iquote bazel-out/k8-opt/bin/external/cuda_cccl -iquote external/cuda_nvtx -iquote bazel-out/k8-opt/bin/external/cuda_nvtx -iquote external/cuda_nvcc -iquote bazel-out/k8-opt/bin/external/cuda_nvcc -iquote external/cuda_cusolver -iquote bazel-out/k8-opt/bin/external/cuda_cusolver -iquote external/cuda_cufft -iquote bazel-out/k8-opt/bin/external/cuda_cufft -iquote external/cuda_cusparse -iquote bazel-out/k8-opt/bin/external/cuda_cusparse -iquote external/cuda_curand -iquote bazel-out/k8-opt/bin/external/cuda_curand -iquote external/cuda_cupti -iquote bazel-out/k8-opt/bin/external/cuda_cupti -iquote external/cuda_nvml -iquote bazel-out/k8-opt/bin/external/cuda_nvml -iquote external/cuda_nvjitlink -iquote bazel-out/k8-opt/bin/external/cuda_nvjitlink -iquote external/cuda_cudnn -iquote bazel-out/k8-opt/bin/external/cuda_cudnn -iquote external/xla -iquote bazel-out/k8-opt/bin/external/xla -iquote external/tsl -iquote bazel-out/k8-opt/bin/external/tsl -iquote external/eigen_archive -iquote bazel-out/k8-opt/bin/external/eigen_archive -iquote external/ml_dtypes -iquote bazel-out/k8-opt/bin/external/ml_dtypes -iquote external/com_google_absl -iquote bazel-out/k8-opt/bin/external/com_google_absl -iquote external/nsync -iquote 
bazel-out/k8-opt/bin/external/nsync -iquote external/double_conversion -iquote bazel-out/k8-opt/bin/external/double_conversion -iquote external/com_google_protobuf -iquote bazel-out/k8-opt/bin/external/com_google_protobuf -iquote external/zlib -iquote bazel-out/k8-opt/bin/external/zlib -iquote external/local_config_rocm -iquote bazel-out/k8-opt/bin/external/local_config_rocm -iquote external/local_config_tensorrt -iquote bazel-out/k8-opt/bin/external/local_config_tensorrt -iquote external/nccl_archive -iquote bazel-out/k8-opt/bin/external/nccl_archive -Ibazel-out/k8-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers -Ibazel-out/k8-opt/bin/external/cuda_cudart/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cublas/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cccl/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_nvtx/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_nvcc/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cusolver/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cufft/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cusparse/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_curand/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cupti/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_nvml/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_nvjitlink/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cudnn/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/ml_dtypes/_virtual_includes/float8 -Ibazel-out/k8-opt/bin/external/ml_dtypes/_virtual_includes/intn -Ibazel-out/k8-opt/bin/external/local_config_tensorrt/_virtual_includes/tensorrt_headers -Ibazel-out/k8-opt/bin/external/nccl_archive/_virtual_includes/nccl_config -isystem external/local_config_cuda/cuda -isystem bazel-out/k8-opt/bin/external/local_config_cuda/cuda -isystem external/cuda_cudart/include -isystem 
bazel-out/k8-opt/bin/external/cuda_cudart/include -isystem external/cuda_cublas/include -isystem bazel-out/k8-opt/bin/external/cuda_cublas/include -isystem external/cuda_cccl/include -isystem bazel-out/k8-opt/bin/external/cuda_cccl/include -isystem external/cuda_nvtx/include -isystem bazel-out/k8-opt/bin/external/cuda_nvtx/include -isystem external/cuda_nvcc/include -isystem bazel-out/k8-opt/bin/external/cuda_nvcc/include -isystem external/cuda_cusolver/include -isystem bazel-out/k8-opt/bin/external/cuda_cusolver/include -isystem external/cuda_cufft/include -isystem bazel-out/k8-opt/bin/external/cuda_cufft/include -isystem external/cuda_cusparse/include -isystem bazel-out/k8-opt/bin/external/cuda_cusparse/include -isystem external/cuda_curand/include -isystem bazel-out/k8-opt/bin/external/cuda_curand/include -isystem external/cuda_cupti/include -isystem bazel-out/k8-opt/bin/external/cuda_cupti/include -isystem external/cuda_nvml/include -isystem bazel-out/k8-opt/bin/external/cuda_nvml/include -isystem external/cuda_nvjitlink/include -isystem bazel-out/k8-opt/bin/external/cuda_nvjitlink/include -isystem external/cuda_cudnn/include -isystem bazel-out/k8-opt/bin/external/cuda_cudnn/include -isystem external/eigen_archive -isystem bazel-out/k8-opt/bin/external/eigen_archive -isystem external/eigen_archive/mkl_include -isystem bazel-out/k8-opt/bin/external/eigen_archive/mkl_include -isystem external/ml_dtypes -isystem bazel-out/k8-opt/bin/external/ml_dtypes -isystem external/ml_dtypes/ml_dtypes -isystem bazel-out/k8-opt/bin/external/ml_dtypes/ml_dtypes -isystem external/nsync/public -isystem bazel-out/k8-opt/bin/external/nsync/public -isystem external/com_google_protobuf/src -isystem bazel-out/k8-opt/bin/external/com_google_protobuf/src -isystem external/zlib -isystem bazel-out/k8-opt/bin/external/zlib -isystem external/local_config_rocm/rocm -isystem bazel-out/k8-opt/bin/external/local_config_rocm/rocm -isystem external/local_config_rocm/rocm/rocm/include -isystem 
bazel-out/k8-opt/bin/external/local_config_rocm/rocm/rocm/include -isystem external/local_config_rocm/rocm/rocm/include/rocrand -isystem bazel-out/k8-opt/bin/external/local_config_rocm/rocm/rocm/include/rocrand -isystem external/local_config_rocm/rocm/rocm/include/roctracer -isystem bazel-out/k8-opt/bin/external/local_config_rocm/rocm/rocm/include/roctracer -fmerge-all-constants -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' -fPIC -U_FORTIFY_SOURCE '-D_FORTIFY_SOURCE=1' -fstack-protector -Wall -fno-omit-frame-pointer -no-canonical-prefixes -DNDEBUG -g0 -O2 -ffunction-sections -fdata-sections '--cuda-path=external/cuda_nvcc' '-fvisibility=hidden' -Wno-sign-compare -Wno-unknown-warning-option -Wno-stringop-truncation -Wno-array-parameter '-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.' '-Wno-error=unused-command-line-argument' -Wno-gnu-offsetof-extensions -mavx -Wno-gnu-offsetof-extensions -Qunused-arguments '-std=c++17' -x cuda '-DGOOGLE_CUDA=1' '--no-cuda-include-ptx=all' '--cuda-gpu-arch=sm_50' '--cuda-gpu-arch=sm_60' '--cuda-gpu-arch=sm_70' '--cuda-gpu-arch=sm_80' '--cuda-include-ptx=sm_90' '--cuda-gpu-arch=sm_90' '-Xcuda-fatbinary=--compress-all' '-nvcc_options=expt-relaxed-constexpr' -c jaxlib/gpu/make_batch_pointers.cu.cc -o bazel-out/k8-opt/bin/jaxlib/cuda/_objs/cuda_make_batch_pointers/make_batch_pointers.cu.pic.o)
# Configuration: 40116f3bac97303e7dcbac2b0176d8c2300ec77420ba81e4743a9a16e63d74ec
# Execution platform: @local_execution_config_platform//:platform
/home/lie/.cache/pikaur/build/python-jaxlib-cuda/src/bazel/8a8ca1fd42886b1189093a9473f4da62/execroot/__main__/external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc:225: SyntaxWarning: invalid escape sequence '\.'
  re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)]
/usr/include/bits/stdlib.h(37): error: linkage specification is incompatible with previous "realpath" (declared at line 940 of /usr/include/stdlib.h)
   realpath (const char *__restrict __name, char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __resolved) noexcept (true)
   ^

/usr/include/bits/stdlib.h(72): error: linkage specification is incompatible with previous "ptsname_r" (declared at line 1134 of /usr/include/stdlib.h)
   ptsname_r (int __fd, char * const __attribute__ ((__pass_object_size__ (1 > 1))) __buf, size_t __buflen) noexcept (true)
   ^

/usr/include/bits/stdlib.h(91): error: linkage specification is incompatible with previous "wctomb" (declared at line 1069 of /usr/include/stdlib.h)
   wctomb (char * const __attribute__ ((__pass_object_size__ (1 > 1))) __s, wchar_t __wchar) noexcept (true)
   ^

/usr/include/bits/stdlib.h(129): error: linkage specification is incompatible with previous "mbstowcs" (declared at line 1073 of /usr/include/stdlib.h)
   mbstowcs (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dst, const char *__restrict __src, size_t __len) noexcept (true)
   ^

/usr/include/bits/stdlib.h(159): error: linkage specification is incompatible with previous "wcstombs" (declared at line 1077 of /usr/include/stdlib.h)
   wcstombs (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dst, const wchar_t *__restrict __src, size_t __len) noexcept (true)
   ^

/usr/include/bits/string_fortified.h(77): error: linkage specification is incompatible with previous "strcpy" (declared at line 141 of /usr/include/string.h)
   strcpy (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const char *__restrict __src) noexcept (true)
   ^

/usr/include/bits/string_fortified.h(86): error: linkage specification is incompatible with previous "stpcpy" (declared at line 491 of /usr/include/string.h)
   stpcpy (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const char *__restrict __src) noexcept (true)
   ^

/usr/include/bits/string_fortified.h(96): error: linkage specification is incompatible with previous "strncpy" (declared at line 144 of /usr/include/string.h)
   strncpy (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const char *__restrict __src, size_t __len) noexcept (true)
   ^

/usr/include/bits/string_fortified.h(107): error: linkage specification is incompatible with previous "stpncpy" (declared at line 499 of /usr/include/string.h)
   stpncpy (char * const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const char *__src, size_t __n) noexcept (true)
   ^

/usr/include/bits/string_fortified.h(136): error: linkage specification is incompatible with previous "strcat" (declared at line 149 of /usr/include/string.h)
   strcat (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const char *__restrict __src) noexcept (true)
   ^

/usr/include/bits/string_fortified.h(145): error: linkage specification is incompatible with previous "strncat" (declared at line 152 of /usr/include/string.h)
   strncat (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const char *__restrict __src, size_t __len) noexcept (true)
   ^

/usr/include/bits/string_fortified.h(161): error: linkage specification is incompatible with previous "strlcpy" (declared at line 506 of /usr/include/string.h)
   strlcpy (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const char *__restrict __src, size_t __n) noexcept (true)
   ^

/usr/include/bits/string_fortified.h(179): error: linkage specification is incompatible with previous "strlcat" (declared at line 512 of /usr/include/string.h)
   strlcat (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const char *__restrict __src, size_t __n) noexcept (true)
   ^

/usr/bin/../lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/tuple(2962): error: type name is not allowed
     static_assert(!__reference_constructs_from_temporary(_Tp, _Elt));
                                                          ^

/usr/bin/../lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/tuple(2962): error: type name is not allowed
     static_assert(!__reference_constructs_from_temporary(_Tp, _Elt));
                                                               ^

/usr/bin/../lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/tuple(2962): error: identifier "__reference_constructs_from_temporary" is undefined
     static_assert(!__reference_constructs_from_temporary(_Tp, _Elt));
                    ^

/usr/include/bits/wchar2.h(24): error: linkage specification is incompatible with previous "wmemcpy" (declared at line 287 of /usr/include/wchar.h)
   wmemcpy (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __s1, const wchar_t *__restrict __s2, size_t __n) noexcept (true)
   ^

/usr/include/bits/wchar2.h(36): error: linkage specification is incompatible with previous "wmemmove" (declared at line 292 of /usr/include/wchar.h)
   wmemmove (wchar_t * const __attribute__ ((__pass_object_size__ (1 > 1))) __s1, const wchar_t *__s2, size_t __n) noexcept (true)
   ^

/usr/include/bits/wchar2.h(49): error: linkage specification is incompatible with previous "wmempcpy" (declared at line 301 of /usr/include/wchar.h)
   wmempcpy (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __s1, const wchar_t *__restrict __s2, size_t __n) noexcept (true)
   ^

/usr/include/bits/wchar2.h(62): error: linkage specification is incompatible with previous "wmemset" (declared at line 296 of /usr/include/wchar.h)
   wmemset (wchar_t * const __attribute__ ((__pass_object_size__ (1 > 1))) __s, wchar_t __c, size_t __n) noexcept (true)
   ^

/usr/include/bits/wchar2.h(74): error: linkage specification is incompatible with previous "wcscpy" (declared at line 98 of /usr/include/wchar.h)
   wcscpy (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const wchar_t *__restrict __src) noexcept (true)
   ^

/usr/include/bits/wchar2.h(84): error: linkage specification is incompatible with previous "wcpcpy" (declared at line 689 of /usr/include/wchar.h)
   wcpcpy (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const wchar_t *__restrict __src) noexcept (true)
   ^

/usr/include/bits/wchar2.h(94): error: linkage specification is incompatible with previous "wcsncpy" (declared at line 103 of /usr/include/wchar.h)
   wcsncpy (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const wchar_t *__restrict __src, size_t __n) noexcept (true)
   ^

/usr/include/bits/wchar2.h(106): error: linkage specification is incompatible with previous "wcpncpy" (declared at line 694 of /usr/include/wchar.h)
   wcpncpy (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const wchar_t *__restrict __src, size_t __n) noexcept (true)
   ^

/usr/include/bits/wchar2.h(118): error: linkage specification is incompatible with previous "wcscat" (declared at line 121 of /usr/include/wchar.h)
   wcscat (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const wchar_t *__restrict __src) noexcept (true)
   ^

/usr/include/bits/wchar2.h(128): error: linkage specification is incompatible with previous "wcsncat" (declared at line 125 of /usr/include/wchar.h)
   wcsncat (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const wchar_t *__restrict __src, size_t __n) noexcept (true)
   ^

/usr/include/bits/wchar2.h(139): error: linkage specification is incompatible with previous "wcslcpy" (declared at line 109 of /usr/include/wchar.h)
   wcslcpy (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const wchar_t *__restrict __src, size_t __n) noexcept (true)
   ^

/usr/include/bits/wchar2.h(155): error: linkage specification is incompatible with previous "wcslcat" (declared at line 115 of /usr/include/wchar.h)
   wcslcat (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const wchar_t *__restrict __src, size_t __n) noexcept (true)
   ^

/usr/include/bits/wchar2.h(254): error: linkage specification is incompatible with previous "fgetws" (declared at line 964 of /usr/include/wchar.h)
  fgetws (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __s, int __n,
  ^

/usr/include/bits/wchar2.h(272): error: linkage specification is incompatible with previous "fgetws_unlocked" (declared at line 1026 of /usr/include/wchar.h)
  fgetws_unlocked (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __s,
  ^

/usr/include/bits/wchar2.h(291): error: linkage specification is incompatible with previous "wcrtomb" (declared at line 326 of /usr/include/wchar.h)
   wcrtomb (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __s, wchar_t __wchar, mbstate_t *__restrict __ps) noexcept (true)
   ^

/usr/include/bits/wchar2.h(308): error: linkage specification is incompatible with previous "mbsrtowcs" (declared at line 362 of /usr/include/wchar.h)
   mbsrtowcs (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dst, const char **__restrict __src, size_t __len, mbstate_t *__restrict __ps) noexcept (true)
   ^

/usr/include/bits/wchar2.h(321): error: linkage specification is incompatible with previous "wcsrtombs" (declared at line 368 of /usr/include/wchar.h)
   wcsrtombs (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dst, const wchar_t **__restrict __src, size_t __len, mbstate_t *__restrict __ps) noexcept (true)
   ^

/usr/include/bits/wchar2.h(336): error: linkage specification is incompatible with previous "mbsnrtowcs" (declared at line 376 of /usr/include/wchar.h)
   mbsnrtowcs (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dst, const char **__restrict __src, size_t __nmc, size_t __len, mbstate_t *__restrict __ps) noexcept (true)
   ^

/usr/include/bits/wchar2.h(349): error: linkage specification is incompatible with previous "wcsnrtombs" (declared at line 382 of /usr/include/wchar.h)
   wcsnrtombs (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dst, const wchar_t **__restrict __src, size_t __nwc, size_t __len, mbstate_t *__restrict __ps) noexcept (true)
   ^

/usr/include/bits/unistd.h(26): error: linkage specification is incompatible with previous "read" (declared at line 371 of /usr/include/unistd.h)
  read (int __fd, void * const __attribute__ ((__pass_object_size__ (0))) __buf, size_t __nbytes)
  ^

/usr/include/bits/unistd.h(40): error: linkage specification is incompatible with previous "pread" (declared at line 389 of /usr/include/unistd.h)
  pread (int __fd, void * const __attribute__ ((__pass_object_size__ (0))) __buf,
  ^

/usr/include/bits/unistd.h(66): error: linkage specification is incompatible with previous "pread64" (declared at line 422 of /usr/include/unistd.h)
  pread64 (int __fd, void * const __attribute__ ((__pass_object_size__ (0))) __buf,
  ^

/usr/include/bits/unistd.h(81): error: linkage specification is incompatible with previous "readlink" (declared at line 838 of /usr/include/unistd.h)
   readlink (const char *__restrict __path, char * __restrict const __attribute__ ((__pass_object_size__ (0))) __buf, size_t __len) noexcept (true)
   ^

/usr/include/bits/unistd.h(97): error: linkage specification is incompatible with previous "readlinkat" (declared at line 851 of /usr/include/unistd.h)
   readlinkat (int __fd, const char *__restrict __path, char * __restrict const __attribute__ ((__pass_object_size__ (0))) __buf, size_t __len) noexcept (true)
   ^

/usr/include/bits/unistd.h(111): error: linkage specification is incompatible with previous "getcwd" (declared at line 531 of /usr/include/unistd.h)
   getcwd (char * const __attribute__ ((__pass_object_size__ (1 > 1))) __buf, size_t __size) noexcept (true)
   ^

/usr/include/bits/unistd.h(124): error: linkage specification is incompatible with previous "getwd" (declared at line 545 of /usr/include/unistd.h)
   getwd (char * const __attribute__ ((__pass_object_size__ (1 > 1))) __buf) noexcept (true)
   ^

/usr/include/bits/unistd.h(133): error: linkage specification is incompatible with previous "confstr" (declared at line 644 of /usr/include/unistd.h)
   confstr (int __name, char * const __attribute__ ((__pass_object_size__ (1 > 1))) __buf, size_t __len) noexcept (true)
   ^

/usr/include/bits/unistd.h(146): error: linkage specification is incompatible with previous "getgroups" (declared at line 711 of /usr/include/unistd.h)
   getgroups (int __size, __gid_t * const __attribute__ ((__pass_object_size__ (1 > 1))) __list) noexcept (true)
   ^

/usr/include/bits/unistd.h(160): error: linkage specification is incompatible with previous "ttyname_r" (declared at line 803 of /usr/include/unistd.h)
   ttyname_r (int __fd, char * const __attribute__ ((__pass_object_size__ (1 > 1))) __buf, size_t __buflen) noexcept (true)
   ^

/usr/include/bits/unistd.h(175): error: linkage specification is incompatible with previous "getlogin_r" (declared at line 889 of /usr/include/unistd.h)
  getlogin_r (char * const __attribute__ ((__pass_object_size__ (1 > 1))) __buf, size_t __buflen)
  ^

/usr/include/bits/unistd.h(189): error: linkage specification is incompatible with previous "gethostname" (declared at line 911 of /usr/include/unistd.h)
   gethostname (char * const __attribute__ ((__pass_object_size__ (1 > 1))) __buf, size_t __buflen) noexcept (true)
   ^

/usr/include/bits/unistd.h(204): error: linkage specification is incompatible with previous "getdomainname" (declared at line 930 of /usr/include/unistd.h)
   getdomainname (char * const __attribute__ ((__pass_object_size__ (1 > 1))) __buf, size_t __buflen) noexcept (true)
   ^

/usr/include/bits/stdio2.h(55): error: linkage specification is incompatible with previous "vsprintf" (declared at line 380 of /usr/include/stdio.h)
   vsprintf (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __s, const char *__restrict __fmt, __gnuc_va_list __ap) noexcept (true)
   ^

/usr/include/bits/stdio2.h(93): error: linkage specification is incompatible with previous "vsnprintf" (declared at line 389 of /usr/include/stdio.h)
   vsnprintf (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __s, size_t __n, const char *__restrict __fmt, __gnuc_va_list __ap) noexcept (true)
   ^

/usr/include/bits/stdio2.h(305): error: linkage specification is incompatible with previous "fgets" (declared at line 654 of /usr/include/stdio.h)
  fgets (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __s, int __n,
  ^

/usr/include/bits/stdio2.h(322): error: linkage specification is incompatible with previous "fread" (declared at line 728 of /usr/include/stdio.h)
  fread (void * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __ptr,
  ^

/usr/include/bits/stdio2.h(342): error: linkage specification is incompatible with previous "fgets_unlocked" (declared at line 677 of /usr/include/stdio.h)
  fgets_unlocked (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __s,
  ^

/usr/include/bits/stdio2.h(362): error: linkage specification is incompatible with previous "fread_unlocked" (declared at line 756 of /usr/include/stdio.h)
  fread_unlocked (void * __restrict const __attribute__ ((__pass_object_size__ (0))) __ptr,
  ^

54 errors detected in the compilation of "jaxlib/gpu/make_batch_pointers.cu.cc".
Target //jaxlib/tools:build_wheel failed to build

@ybaturina
Copy link
Contributor

Can you try this command please?

python build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 --clang_path=<absolute clang compiler path> --use_cuda_nvcc=false

Please note that clang_path should be a real path, not a symlink.

We are planning to update the instructions on how to build JAX from source.
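Since `--clang_path` must be a real path rather than a symlink, `readlink -f` can resolve it. A minimal sketch (the `/tmp` paths below are throwaway illustrations, not part of the build):

```shell
# Many distros install /usr/bin/clang as a symlink to a versioned binary;
# build.py wants the real file. Demonstrate resolution with a throwaway symlink:
mkdir -p /tmp/clang-demo
touch /tmp/clang-demo/clang-18
ln -sf /tmp/clang-demo/clang-18 /tmp/clang-demo/clang

# readlink -f follows the symlink chain to the real path.
readlink -f /tmp/clang-demo/clang
```

In a real build you would pass something like `--clang_path="$(readlink -f "$(command -v clang)")"`.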

@actionless
Copy link

Then I have the same problem as daskol:

clang-18: error: cannot find CUDA installation; provide its path via '--cuda-path', or pass '-nocudainc' to build without CUDA includes
Target //jaxlib/tools:build_wheel failed to build

If adding:

--bazel_options='--repo_env=LOCAL_CUDA_PATH=/opt/cuda' \

then I get an error about a missing "third_party/gpus/cuda/extras/CUPTI/include/cupti.h".

I've also tried several hacks, like

mkdir -p 'third_party/gpus/'
ln -s /opt/cuda/ 'third_party/gpus/cuda'

and rebuilding after cleaning the Bazel cache, but I still get one or the other of the two errors above.

@ybaturina
Copy link
Contributor

May I ask you to describe your use case? Why is it necessary to use the local CUDA path in your scenario?
Also, would you please attach the full log from running the build script?

copybara-service bot pushed a commit to google/tsl that referenced this issue Nov 8, 2024
…ath.

`cc.endswith("clang")` didn't work for the cases when the clang compiler path is like `/usr/bin/clang-18`.

This change addresses [Github issue](jax-ml/jax#23689).

PiperOrigin-RevId: 693735256
copybara-service bot pushed a commit to openxla/xla that referenced this issue Nov 8, 2024
…ath.

`cc.endswith("clang")` didn't work for the cases when the clang compiler path is like `/usr/bin/clang-18`.

This change addresses [Github issue](jax-ml/jax#23689).

PiperOrigin-RevId: 693735256
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Nov 8, 2024
…ath.

`cc.endswith("clang")` didn't work for the cases when the clang compiler path is like `/usr/bin/clang-18`.

This change addresses [Github issue](jax-ml/jax#23689).

PiperOrigin-RevId: 693735256
copybara-service bot pushed a commit to google/tsl that referenced this issue Nov 8, 2024
…ath.

`cc.endswith("clang")` didn't work for the cases when the clang compiler path is like `/usr/bin/clang-18`.

This change addresses [Github issue](jax-ml/jax#23689).

PiperOrigin-RevId: 694536448
copybara-service bot pushed a commit to openxla/xla that referenced this issue Nov 8, 2024
…ath.

`cc.endswith("clang")` didn't work for the cases when the clang compiler path is like `/usr/bin/clang-18`.

This change addresses [Github issue](jax-ml/jax#23689).

PiperOrigin-RevId: 694536448
@ybaturina
Copy link
Contributor

openxla/xla#19113 is merged now.

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Nov 8, 2024
…ath.

`cc.endswith("clang")` didn't work for the cases when the clang compiler path is like `/usr/bin/clang-18`.

This change addresses [Github issue](jax-ml/jax#23689).

PiperOrigin-RevId: 694536448
@adamjstewart
Copy link
Contributor Author

Tried upgrading to jaxlib 0.4.37 but now I have a different error even earlier in the build: #25488

@adamjstewart
Copy link
Contributor Author

Okay, on jaxlib 0.4.38 with the patch from #25531 applied, the GCC x86_64 CPU-only build works, but the GCC x86_64 CUDA build still fails. The error message is now different:

ERROR: /root/.cache/bazel/_bazel_root/7ac32083e9ddf32a54f80bfb0375dcee/external/xla/xla/stream_executor/cuda/BUILD:192:11: Compiling xla/stream_executor/cuda/cuda_status.cc failed: (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing command (from target @xla//xla/stream_executor/cuda:cuda_status)
...
In file included from external/xla/xla/stream_executor/cuda/cuda_status.cc:16:
external/xla/xla/stream_executor/cuda/cuda_status.h:22:10: fatal error: third_party/gpus/cuda/include/cuda.h: No such file or directory
   22 | #include "third_party/gpus/cuda/include/cuda.h"
      |          ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
compilation terminated.
Target //jaxlib/tools:build_gpu_kernels_wheel failed to build

I'm not sure why it's hard-coded to search for cuda.h in third_party/gpus/cuda. I'm trying to use an externally installed CUDA. Maybe I'm missing a magic flag? Here is the build log. I could also try manually building XLA instead of relying on the vendored copy if that's the recommended approach.

@ybaturina
Copy link
Contributor

ybaturina commented Dec 20, 2024

Hi @adamjstewart, would you please check whether the include symlink exists at the following path?
/root/.cache/bazel/_bazel_root/7ac32083e9ddf32a54f80bfb0375dcee/external/cuda_cudart/include

It should point to the include dir in your local CUDA installation (which should contain cuda.h).

Also, the content of /root/.cache/bazel/_bazel_root/7ac32083e9ddf32a54f80bfb0375dcee/external/cuda_cudart/BUILD should have a cc_headers target which is not commented out (see the template here).

The reason why it searches for third_party/gpus/cuda is that this prefix is defined in include_prefix.
https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl#L128

In the log I can see that cuda_cudart/include is passed to NVCC compiler:

external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -MD -MF bazel-out/k8-opt/bin/external/xla/xla/stream_executor/cuda/_objs/cuda_status/cuda_status.pic.d '-frandom-seed=bazel-out/k8-opt/bin/external/xla/xla/stream_executor/cuda/_objs/cuda_status/cuda_status.pic.o' '-DBAZEL_CURRENT_REPOSITORY="xla"' -iquote external/xla -iquote bazel-out/k8-opt/bin/external/xla -iquote external/com_google_absl -iquote bazel-out/k8-opt/bin/external/com_google_absl -iquote external/local_config_cuda -iquote bazel-out/k8-opt/bin/external/local_config_cuda -iquote external/cuda_cudart -iquote bazel-out/k8-opt/bin/external/cuda_cudart -iquote external/cuda_cublas -iquote bazel-out/k8-opt/bin/external/cuda_cublas -iquote external/cuda_cccl -iquote bazel-out/k8-opt/bin/external/cuda_cccl -iquote external/cuda_nvtx -iquote bazel-out/k8-opt/bin/external/cuda_nvtx -iquote external/cuda_nvcc -iquote bazel-out/k8-opt/bin/external/cuda_nvcc -iquote external/cuda_cusolver -iquote bazel-out/k8-opt/bin/external/cuda_cusolver -iquote external/cuda_cufft -iquote bazel-out/k8-opt/bin/external/cuda_cufft -iquote external/cuda_cusparse -iquote bazel-out/k8-opt/bin/external/cuda_cusparse -iquote external/cuda_curand -iquote bazel-out/k8-opt/bin/external/cuda_curand -iquote external/cuda_cupti -iquote bazel-out/k8-opt/bin/external/cuda_cupti -iquote external/cuda_nvml -iquote bazel-out/k8-opt/bin/external/cuda_nvml -iquote external/cuda_nvjitlink -iquote bazel-out/k8-opt/bin/external/cuda_nvjitlink -iquote external/tsl -iquote bazel-out/k8-opt/bin/external/tsl -Ibazel-out/k8-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers -Ibazel-out/k8-opt/bin/external/cuda_cudart/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cublas/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cccl/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_nvtx/_virtual_includes/headers 
-Ibazel-out/k8-opt/bin/external/cuda_nvcc/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cusolver/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cufft/virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cusparse/virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_curand/virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cupti/virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_nvml/virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_nvjitlink/virtual_includes/headers -isystem external/local_config_cuda/cuda -isystem bazel-out/k8-opt/bin/external/local_config_cuda/cuda -isystem external/cuda_cudart/include -isystem bazel-out/k8-opt/bin/external/cuda_cudart/include -isystem external/cuda_cublas/include -isystem bazel-out/k8-opt/bin/external/cuda_cublas/include -isystem external/cuda_cccl/include -isystem bazel-out/k8-opt/bin/external/cuda_cccl/include -isystem external/cuda_nvtx/include -isystem bazel-out/k8-opt/bin/external/cuda_nvtx/include -isystem external/cuda_nvcc/include -isystem bazel-out/k8-opt/bin/external/cuda_nvcc/include -isystem external/cuda_cusolver/include -isystem bazel-out/k8-opt/bin/external/cuda_cusolver/include -isystem external/cuda_cufft/include -isystem bazel-out/k8-opt/bin/external/cuda_cufft/include -isystem external/cuda_cusparse/include -isystem bazel-out/k8-opt/bin/external/cuda_cusparse/include -isystem external/cuda_curand/include -isystem bazel-out/k8-opt/bin/external/cuda_curand/include -isystem external/cuda_cupti/include -isystem bazel-out/k8-opt/bin/external/cuda_cupti/include -isystem external/cuda_nvml/include -isystem bazel-out/k8-opt/bin/external/cuda_nvml/include -isystem external/cuda_nvjitlink/include -isystem bazel-out/k8-opt/bin/external/cuda_nvjitlink/include -Wno-builtin-macro-redefined '-D__DATE="redacted"' '-D__TIMESTAMP="redacted"' '-D__TIME="redacted"' -fPIC -U_FORTIFY_SOURCE '-D_FORTIFY_SOURCE=1' -fstack-protector -Wall 
-fno-omit-frame-pointer -no-canonical-prefixes -DNDEBUG -g0 -O2 -ffunction-sections -fdata-sections '-fvisibility=hidden' -Wno-sign-compare -Wno-unknown-warning-option -Wno-stringop-truncation -Wno-array-parameter '-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.' -mavx -mavx '-std=c++17' -c external/xla/xla/stream_executor/cuda/cuda_status.cc -o bazel-out/k8-opt/bin/external/xla/xla/stream_executor/cuda/_objs/cuda_status/cuda_status.pic.o)

@adamjstewart
Copy link
Contributor Author

This is in CI so I can't easily reproduce in that specific environment, let me try building locally and get back to you.

@adamjstewart
Copy link
Contributor Author

Okay, I was able to reproduce the build failure locally. Here is the new build log for posterity.

Yes, the <bazel>/external/cuda_cudart/include symlink exists, and it points to the <cuda>/include directory, which does contain cuda.h.

However, the headers cc_library you mention is mostly commented out:

cc_library(
    name = "headers",
    #hdrs = glob([
        #...
        #"include/cuda.h",
        #...
    #]),
    include_prefix = "third_party/gpus/cuda/include",
    includes = ["include"],
    strip_include_prefix = "include",
    visibility = ["@local_config_cuda//cuda:__pkg__"],
)

I'm guessing this is the source of the issue. Any idea why this would be commented out?

@ybaturina
Copy link
Contributor

The lines can be commented out when the repository rule can't find the .so files corresponding to CUDA redistributions.
https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl#L182

In particular, if libcudart.so.{major_version} is not found in external/cuda_cudart/lib, some lines in external/cuda_cudart/BUILD will be commented out.

Can you check this path please? Does it exist?
/root/.cache/bazel/_bazel_root/7ac32083e9ddf32a54f80bfb0375dcee/external/cuda_cudart/lib/libcudart.so.12
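For reference, a quick way to check is (the cache path below is copied from the log above; yours will differ):

```shell
# Illustrative Bazel cache path taken from the build log; adjust as needed.
CUDART_REPO=/root/.cache/bazel/_bazel_root/7ac32083e9ddf32a54f80bfb0375dcee/external/cuda_cudart

# If the glob matches nothing, the repository rule will comment out the targets.
ls "$CUDART_REPO"/lib/libcudart.so.* 2>/dev/null || echo "libcudart.so not found"
```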

@adamjstewart
Copy link
Contributor Author

Okay, here is the issue. The external/cuda_cudart/lib symlink exists and points to <cuda>/lib. However, <cuda>/lib does not exist; my installation only has <cuda>/lib64, which is itself a symlink to <cuda>/targets/x86_64-linux/lib. That directory does contain libcudart.so.*. So we either add symlinks to both lib and lib64, or we check for the presence of both and symlink the correct one.
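To illustrate the mismatch and the workaround, here is a minimal sketch using a throwaway /tmp prefix (the real fix is the same `ln -s` inside the actual CUDA prefix; all paths here are illustrative):

```shell
# Recreate the runfile layout: only lib64 exists, as a symlink into targets/.
CUDA_HOME=/tmp/fake-cuda
mkdir -p "$CUDA_HOME/targets/x86_64-linux/lib"
touch "$CUDA_HOME/targets/x86_64-linux/lib/libcudart.so.12"
ln -sfn targets/x86_64-linux/lib "$CUDA_HOME/lib64"

# The hermetic repository rules look under <cuda>/lib, so add a compat symlink.
ln -sfn lib64 "$CUDA_HOME/lib"

# The library is now reachable where the rules expect it.
ls "$CUDA_HOME/lib/libcudart.so.12"
```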

@ybaturina
Copy link
Contributor

The usage of local installations is not recommended in general. The guidance for the folder structure is provided in this paragraph.

<LOCAL_CUDA_PATH>/
    include/
    bin/
    lib/
    nvvm/

This structure corresponds to the content of the redistribution archives that can be downloaded from NVIDIA.

I suggest adding a lib symlink in your local installation.

@adamjstewart
Copy link
Contributor Author

adamjstewart commented Dec 21, 2024

Ah, my problem is that I'm downloading and installing CUDA using the official runfiles provided by https://developer.nvidia.com/cuda-downloads, which seems to default to lib64 on my system. It would be nice if XLA could also support the runfile layout instead of only the redistribution layout, but I can bring that up with them another day. I will confirm whether adding this symlink solves my issue and get back to you.

@adamjstewart
Copy link
Contributor Author

And we're on to the next issue:

ERROR: /tmp/adam/spack-stage/spack-stage-py-jaxlib-0.4.38-rslu55c6zbasrg6zeizf3p6cs7q3he53/spack-src/jaxlib/cuda/BUILD:75:13: Compiling jaxlib/gpu/make_batch_pointers.cu.cc failed: (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing command (from target //jaxlib/cuda:cuda_make_batch_pointers)
...
In file included from ./jaxlib/gpu/make_batch_pointers.h:21,
                 from jaxlib/gpu/make_batch_pointers.cu.cc:16:
./jaxlib/gpu/vendor.h:26:10: fatal error: third_party/gpus/cuda/extras/CUPTI/include/cupti.h: No such file or directory
   26 | #include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
      |          ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
compilation terminated.
[2,803 / 3,725] checking cached actions
Target //jaxlib/tools:build_gpu_kernels_wheel failed to build
INFO: Elapsed time: 74.382s, Critical Path: 15.13s
INFO: 1150 processes: 162 internal, 988 local.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target

As expected, cuda_cupti/BUILD has the same issue:

cc_library(
    name = "headers",
    #hdrs = glob([
        #...
        #"include/cupti.h",
        #...
#]),
    include_prefix = "third_party/gpus/cuda/extras/CUPTI/include",
    includes = ["include/"],
    strip_include_prefix = "include",
    visibility = ["@local_config_cuda//cuda:__pkg__"],
)

Looks like the exact same issue: there is an extras/CUPTI/lib64 folder but no extras/CUPTI/lib in my installation. Will try symlinking and rebuilding.

@ybaturina
Copy link
Contributor

Please note that libcupti.so.12 should be located in the <local_cuda_path>/lib dir.
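For illustration, assuming a runfile-style layout where CUPTI lives under extras/CUPTI/lib64, the symlink would look like this (throwaway /tmp prefix; all paths are illustrative):

```shell
# Recreate a runfile-style layout where CUPTI sits under extras/CUPTI/lib64.
CUDA_HOME=/tmp/fake-cuda-cupti
mkdir -p "$CUDA_HOME/extras/CUPTI/lib64" "$CUDA_HOME/lib"
touch "$CUDA_HOME/extras/CUPTI/lib64/libcupti.so.12"

# Place the library where the repository rule expects it: <local_cuda_path>/lib.
ln -sfn ../extras/CUPTI/lib64/libcupti.so.12 "$CUDA_HOME/lib/libcupti.so.12"
ls "$CUDA_HOME/lib/libcupti.so.12"
```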

@adamjstewart
Copy link
Contributor Author

Oh. Well that's not where it lives. Am I supposed to symlink every single file in my CUDA installation to get XLA building, or should I just change XLA to support the default CUDA installation scheme?

@ybaturina
Copy link
Contributor

Hi @adamjstewart. If there are .so files used by the repository rules and they are not located in <local_cuda_path>/lib, then yes, you need to create symlinks for them.
Please note that a local path for CUDA redistributions is not recommended; use the standard flow if possible.

@adamjstewart
Copy link
Contributor Author

I'm using the standard layout generated by the official CUDA runfile. I tried adding symlinks in many places but still can't manage to get things to build. I may open an issue with XLA if I have time.

@daskol
Copy link
Contributor

daskol commented Dec 24, 2024

FYI, I am still having the issue building jax-cuda-plugin with a local CUDA with JAX v0.4.38. The ordering of system includes and the include_next directive causes the wrong cmath header to be included, and there are no CUPTI issues (the compiler is Clang v18.1 and the standard library is provided by GCC v14.1).

@ybaturina What do you think about renaming the issue so it accurately reflects build issues with local CUDA across multiple JAX versions?

@adamjstewart adamjstewart changed the title jaxlib 0.4.32, external CUDA, gcc: error: unrecognized command-line option '--cuda-path=external/cuda_nvcc' Build issues with local CUDA installation Dec 24, 2024
@vam-google
Copy link
Contributor

I don't think building JAX with the gcc 14 standard library works as of now. You may pass --action_env=CCC_OVERRIDE_OPTIONS="^--gcc-install-dir=/usr/lib/gcc/x86_64-linux-gnu/13" to Bazel to override clang's gcc selection (in the example above it downgrades it to gcc 13, assuming gcc 13 is located in /usr/lib/gcc/x86_64-linux-gnu/13 on your machine).

Regarding using a local CUDA installation for builds: the only reason it still exists is to support the workflow where you need to test two unreleased versions of the components together, i.e. unreleased JAX and unreleased CUDA; it is mainly for NVIDIA developers who work on CUDA directly. Unless you are actually developing CUDA itself, there should be no reason to ever depend on a local installation; it will just make life harder for you. Please do not use it unless you have a very very specific reason to.

@daskol
Copy link
Contributor

daskol commented Dec 26, 2024

there should be no reason to ever depend on local installation, it will just make life harder for you, please do not use it unless you have a very very specific reason to.

The issue is that JAX is distributed mainly as Python wheels, but this conflicts with Linux distributions like Arch or Nix, which strive to vendor all dependencies for better reproducibility, among other reasons. From my perspective, it is indeed strange to have multiple CUDA runtimes on one system.

@adamjstewart
Copy link
Contributor Author

there should be no reason to ever depend on local installation, it will just make life harder for you, please do not use it unless you have a very very specific reason to.

The issue is that JAX is distributed mainly as Python wheels, but this conflicts with Linux distributions like Arch or Nix, which strive to vendor all dependencies for better reproducibility, among other reasons. From my perspective, it is indeed strange to have multiple CUDA runtimes on one system.

My "very very specific reason" is that I'm building for a supercomputer where CUDA has already been installed by the system administrators and I don't want to install an incompatible version. I'm also packaging JAX in a package manager (Spack) so that users can install the entire software stack they need without installing multiple incompatible versions of CUDA. Packages that automagically install their own dependencies are a fundamental challenge for any secure system, especially on an air-gapped network where the only source available is what is manually copied to the server.

@vam-google
Copy link
Contributor

vam-google commented Dec 27, 2024

@daskol @adamjstewart

tl;dr: nothing is installed on your system during the build.

With hermetic CUDA, nothing is installed on your system, and this is the main point here. CUDA is pulled into an isolated Bazel cache during your build, together with many other unrelated dependencies (which you have always been pulling during the build, with or without CUDA), keeping your machine's environment clean and intact. All build dependencies, including CUDA, are pulled and checked against their sha256 sums from trusted sources (CUDA in particular is pulled from the official NVIDIA source).

JAX's build is complex, and the CUDA dependencies are on the more complex side of it. Having a specific version of CUDA on your system and trying to build against it will not make your build more robust (it will have exactly the opposite effect), as there are many other "players" in the game: compatibility of JAX itself with a specific CUDA version, wiring of CUDA headers (compile time) and libraries (link time) into the rest of the build, packaging of the final wheel, etc., and all the custom Bazel logic which makes that possible and makes assumptions about what exactly is in your CUDA deps.

especially on an air-gapped network where the only source available is what is manually copied to the server.

There are many other dependencies besides CUDA that get downloaded during the build, so in the air-gapped case the build would fail with or without CUDA. If you are already providing many dozens of other dependencies from a custom HTTPS source, just add the dozen CUDA dependencies to the already long list you are mirroring.
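For what it's worth, Bazel has standard mechanisms for exactly this air-gapped workflow: pre-fetch everything on a connected machine and point the offline build at the resulting cache. A sketch using Bazel's built-in flags (paths are placeholders):

```shell
# On a machine with network access: pre-download every external dependency
# (including the hermetic CUDA archives) into a reusable cache.
bazel fetch //... --repository_cache="$HOME/bazel-repo-cache"

# Copy the cache to the air-gapped server, then build without fetching:
bazel build //... --repository_cache=/path/to/bazel-repo-cache --nofetch
```

Because every archive is checked against its pinned SHA-256, the cache copied to the air-gapped host is verified the same way a networked fetch would be.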

In other words, non-hermetic CUDA dependencies have always been a non-idiomatic, hard-to-maintain, insecure (the old logic actually crawled your system to figure out what was installed where), extremely fragile build hack, which we finally fixed. Please let it go. Controlling your own installation does not let you control the build, because you still rely on all the custom wiring around it, which makes assumptions about what it wires together. Nothing is installed on your machine during the build.

@apivovarov
Copy link
Contributor

I don't think building JAX against the GCC 14 standard library works as of now. You can pass --action_env=CCC_OVERRIDE_OPTIONS="^--gcc-install-dir=/usr/lib/gcc/x86_64-linux-gnu/13" to Bazel to override Clang's GCC toolchain selection (in the example above it downgrades to GCC 13, assuming it is located in /usr/lib/gcc/x86_64-linux-gnu/13 on your machine).
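As a concrete sketch, one way to thread that override through the build/build.py wrapper used earlier in this thread (the GCC install directory is machine-specific; check what exists under /usr/lib/gcc/x86_64-linux-gnu/ on your host):

```shell
python3 build/build.py --enable_cuda --cuda_compute_capabilities=8.0 \
  --bazel_options=--action_env=CCC_OVERRIDE_OPTIONS="^--gcc-install-dir=/usr/lib/gcc/x86_64-linux-gnu/13"
```

The leading ^ in CCC_OVERRIDE_OPTIONS tells Clang's driver to add the option at the start of the command line rather than append it.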

Regarding using a local CUDA installation for builds: the only reason it still exists is to support the workflow where you need to test two unreleased components together, i.e. unreleased JAX against unreleased CUDA. It is mainly for NVIDIA developers who work on CUDA directly. Unless you are actually developing CUDA itself, there should be no reason to ever depend on a local installation; it will just make life harder for you, so please do not use it unless you have a very, very specific reason to.

Ubuntu 24.04 + CUDA 12.6 depends on GCC 13.
I tried to build the latest XLA and hit a similar issue:

/usr/lib/gcc/x86_64-linux-gnu/13/../../../../include/c++/13/tuple(2335): error: identifier "__reference_constructs_from_temporary" is undefined
     static_assert(!__reference_constructs_from_temporary(_Tp, _Elt));
                    ^

openxla/xla#20915

@apivovarov
Copy link
Contributor

apivovarov commented Dec 31, 2024

We can resolve the __reference_constructs_from_temporary issue by switching the CUDA compiler from nvcc to Clang 18.

For example, in XLA, you can add the --cuda_compiler CLANG flag to the configure.py command:

python3 configure.py --backend CUDA --cuda_compiler CLANG

Key Changes in xla_configure.bazelrc
before:

build --config cuda_nvcc

now:

build --config cuda_clang

For more details, see the related discussion:
openxla/xla#20915 (comment)
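Putting the two steps together, a CUDA + Clang build of XLA might look like this (the target is illustrative; any XLA target should work):

```shell
# Regenerate xla_configure.bazelrc with Clang as the CUDA compiler.
python3 configure.py --backend CUDA --cuda_compiler CLANG

# The generated config now selects cuda_clang instead of cuda_nvcc.
bazel build --config cuda_clang //xla/tools:run_hlo_module
```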
