Skip to content

[CUDA][build] error: identifier "__reference_constructs_from_temporary" is undefined #20915

Open
@apivovarov

Description

TL;DR Solution

We can resolve the __reference_constructs_from_temporary issue by switching the CUDA compiler from nvcc to clang-18.

For example, in XLA, you can add the --cuda_compiler CLANG flag to the configure.py command:

python3 configure.py --backend CUDA --cuda_compiler CLANG

OS: Ubuntu 24.04
compiler: clang-18
cuda-12-6
branch: main 79ada3d Dec 27

To reproduce

python3 configure.py --backend CUDA

bazel build //xla/service/gpu/kernels:topk_kernel_gpu

Error:

ERROR: /home/ubuntu/workspace/xla/xla/service/gpu/kernels/BUILD:173:19: Compiling xla/service/gpu/kernels/topk_kernel_float.cu.cc failed: (Exit 2): crosstool_wrapper_driver_is_not_gcc failed: error executing command (from target //xla/service/gpu/kernels:topk_kernel_gpu_cuda) external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -MD -MF bazel-out/k8-opt/bin/xla/service/gpu/kernels/_objs/topk_kernel_gpu_cuda/topk_kernel_float.cu.pic.d ... (remaining 206 arguments skipped)
/home/ubuntu/.cache/bazel/_bazel_ubuntu/0176aac02e77df836cc7203737c26784/execroot/xla/external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc:227: SyntaxWarning: invalid escape sequence '\.'
  re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)]
/usr/lib/gcc/x86_64-linux-gnu/13/../../../../include/c++/13/tuple(2335): error: type name is not allowed
     static_assert(!__reference_constructs_from_temporary(_Tp, _Elt));
                                                          ^

/usr/lib/gcc/x86_64-linux-gnu/13/../../../../include/c++/13/tuple(2335): error: type name is not allowed
     static_assert(!__reference_constructs_from_temporary(_Tp, _Elt));
                                                               ^

/usr/lib/gcc/x86_64-linux-gnu/13/../../../../include/c++/13/tuple(2335): error: identifier "__reference_constructs_from_temporary" is undefined
     static_assert(!__reference_constructs_from_temporary(_Tp, _Elt));
                    ^

3 errors detected in the compilation of "xla/service/gpu/kernels/topk_kernel_float.cu.cc".
Target //xla/service/gpu/kernels:topk_kernel_gpu_cuda failed to build
SUBCOMMAND: # //xla/service/gpu/kernels:topk_kernel_gpu_cuda [action 'Compiling xla/service/gpu/kernels/topk_kernel_bfloat16.cu.cc', configuration: a23d58a7decd97872ba4fb28d7f36f33537d5e63c34655e2e999fb4c17de5668, execution platform: @local_execution_config_platform//:platform]
(cd /home/ubuntu/.cache/bazel/_bazel_ubuntu/0176aac02e77df836cc7203737c26784/execroot/xla && \
  exec env - \
    CLANG_COMPILER_PATH=/usr/lib/llvm-18/bin/clang \
    CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-18/bin/clang \
    PATH=/home/ubuntu/.cache/bazelisk/downloads/sha256/a40ac69263440761199fcb8da47ad4e3f328cbe79ffbf4ecc14e5ba252857307/bin:/home/ubuntu/.local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin:/usr/local/cuda/bin \
    PWD=/proc/self/cwd \
    PYTHON_BIN_PATH=/usr/bin/python3 \
    TF2_BEHAVIOR=1 \
    TF_NVCC_CLANG=1 \
  external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -MD -MF bazel-out/k8-opt/bin/xla/service/gpu/kernels/_objs/topk_kernel_gpu_cuda/topk_kernel_bfloat16.cu.pic.d '-frandom-seed=bazel-out/k8-opt/bin/xla/service/gpu/kernels/_objs/topk_kernel_gpu_cuda/topk_kernel_bfloat16.cu.pic.o' '-DEIGEN_MAX_ALIGN_BYTES=64' -DEIGEN_ALLOW_UNALIGNED_SCALARS '-DEIGEN_USE_AVX512_GEMM_KERNELS=0' '-DGOOGLE_CUDA=1' '-DBAZEL_CURRENT_REPOSITORY=""' -iquote . -iquote bazel-out/k8-opt/bin -iquote external/com_google_absl -iquote bazel-out/k8-opt/bin/external/com_google_absl -iquote external/eigen_archive -iquote bazel-out/k8-opt/bin/external/eigen_archive -iquote external/tsl -iquote bazel-out/k8-opt/bin/external/tsl -iquote external/ml_dtypes -iquote bazel-out/k8-opt/bin/external/ml_dtypes -iquote external/local_config_cuda -iquote bazel-out/k8-opt/bin/external/local_config_cuda -iquote external/cuda_cudart -iquote bazel-out/k8-opt/bin/external/cuda_cudart -iquote external/cuda_cublas -iquote bazel-out/k8-opt/bin/external/cuda_cublas -iquote external/cuda_cccl -iquote bazel-out/k8-opt/bin/external/cuda_cccl -iquote external/cuda_nvtx -iquote bazel-out/k8-opt/bin/external/cuda_nvtx -iquote external/cuda_nvcc -iquote bazel-out/k8-opt/bin/external/cuda_nvcc -iquote external/cuda_cusolver -iquote bazel-out/k8-opt/bin/external/cuda_cusolver -iquote external/cuda_cufft -iquote bazel-out/k8-opt/bin/external/cuda_cufft -iquote external/cuda_cusparse -iquote bazel-out/k8-opt/bin/external/cuda_cusparse -iquote external/cuda_curand -iquote bazel-out/k8-opt/bin/external/cuda_curand -iquote external/cuda_cupti -iquote bazel-out/k8-opt/bin/external/cuda_cupti -iquote external/cuda_nvml -iquote bazel-out/k8-opt/bin/external/cuda_nvml -iquote external/cuda_nvjitlink -iquote bazel-out/k8-opt/bin/external/cuda_nvjitlink -Ibazel-out/k8-opt/bin/external/ml_dtypes/_virtual_includes/float8 -Ibazel-out/k8-opt/bin/external/ml_dtypes/_virtual_includes/intn 
-Ibazel-out/k8-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers -Ibazel-out/k8-opt/bin/external/cuda_cudart/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cublas/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cccl/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_nvtx/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_nvcc/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cusolver/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cufft/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cusparse/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_curand/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cupti/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_nvml/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_nvjitlink/_virtual_includes/headers -isystem external/eigen_archive -isystem bazel-out/k8-opt/bin/external/eigen_archive -isystem external/eigen_archive/mkl_include -isystem bazel-out/k8-opt/bin/external/eigen_archive/mkl_include -isystem external/ml_dtypes -isystem bazel-out/k8-opt/bin/external/ml_dtypes -isystem external/ml_dtypes/ml_dtypes -isystem bazel-out/k8-opt/bin/external/ml_dtypes/ml_dtypes -isystem external/local_config_cuda/cuda -isystem bazel-out/k8-opt/bin/external/local_config_cuda/cuda -isystem external/cuda_cudart/include -isystem bazel-out/k8-opt/bin/external/cuda_cudart/include -isystem external/cuda_cublas/include -isystem bazel-out/k8-opt/bin/external/cuda_cublas/include -isystem external/cuda_cccl/include -isystem bazel-out/k8-opt/bin/external/cuda_cccl/include -isystem external/cuda_nvtx/include -isystem bazel-out/k8-opt/bin/external/cuda_nvtx/include -isystem external/cuda_nvcc/include -isystem bazel-out/k8-opt/bin/external/cuda_nvcc/include -isystem external/cuda_cusolver/include -isystem bazel-out/k8-opt/bin/external/cuda_cusolver/include -isystem 
external/cuda_cufft/include -isystem bazel-out/k8-opt/bin/external/cuda_cufft/include -isystem external/cuda_cusparse/include -isystem bazel-out/k8-opt/bin/external/cuda_cusparse/include -isystem external/cuda_curand/include -isystem bazel-out/k8-opt/bin/external/cuda_curand/include -isystem external/cuda_cupti/include -isystem bazel-out/k8-opt/bin/external/cuda_cupti/include -isystem external/cuda_nvml/include -isystem bazel-out/k8-opt/bin/external/cuda_nvml/include -isystem external/cuda_nvjitlink/include -isystem bazel-out/k8-opt/bin/external/cuda_nvjitlink/include -fmerge-all-constants -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' -fPIC -U_FORTIFY_SOURCE '-D_FORTIFY_SOURCE=1' -fstack-protector -Wall -fno-omit-frame-pointer -no-canonical-prefixes -DNDEBUG -g0 -O2 -ffunction-sections -fdata-sections '--cuda-path=external/cuda_nvcc' -Wno-all -Wno-extra -Wno-deprecated -Wno-deprecated-declarations -Wno-ignored-attributes -Wno-array-bounds -Wunused-result '-Werror=unused-result' -Wswitch '-Werror=switch' -DAUTOLOAD_DYNAMIC_KERNELS -Wno-sign-compare '-Wno-error=unused-command-line-argument' -Wno-gnu-offsetof-extensions '-std=c++17' -x cuda '-DGOOGLE_CUDA=1' '--no-cuda-include-ptx=all' '--cuda-include-ptx=sm_80' '--cuda-gpu-arch=sm_80' '-Xcuda-fatbinary=--compress-all' '-nvcc_options=expt-relaxed-constexpr' -c xla/service/gpu/kernels/topk_kernel_bfloat16.cu.cc -o bazel-out/k8-opt/bin/xla/service/gpu/kernels/_objs/topk_kernel_gpu_cuda/topk_kernel_bfloat16.cu.pic.o)
# Configuration: a23d58a7decd97872ba4fb28d7f36f33537d5e63c34655e2e999fb4c17de5668
# Execution platform: @local_execution_config_platform//:platform
ERROR: /home/ubuntu/workspace/xla/xla/service/gpu/kernels/BUILD:173:19: Compiling xla/service/gpu/kernels/topk_kernel_float.cu.cc failed: (Exit 2): crosstool_wrapper_driver_is_not_gcc failed: error executing command (from target //xla/service/gpu/kernels:topk_kernel_gpu_cuda) external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -MD -MF bazel-out/k8-opt/bin/xla/service/gpu/kernels/_objs/topk_kernel_gpu_cuda/topk_kernel_float.cu.d ... (remaining 206 arguments skipped)
/home/ubuntu/.cache/bazel/_bazel_ubuntu/0176aac02e77df836cc7203737c26784/execroot/xla/external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc:227: SyntaxWarning: invalid escape sequence '\.'
  re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)]
/usr/lib/gcc/x86_64-linux-gnu/13/../../../../include/c++/13/tuple(2335): error: type name is not allowed
     static_assert(!__reference_constructs_from_temporary(_Tp, _Elt));
                                                          ^

/usr/lib/gcc/x86_64-linux-gnu/13/../../../../include/c++/13/tuple(2335): error: type name is not allowed
     static_assert(!__reference_constructs_from_temporary(_Tp, _Elt));
                                                               ^

/usr/lib/gcc/x86_64-linux-gnu/13/../../../../include/c++/13/tuple(2335): error: identifier "__reference_constructs_from_temporary" is undefined
     static_assert(!__reference_constructs_from_temporary(_Tp, _Elt));
                    ^

3 errors detected in the compilation of "xla/service/gpu/kernels/topk_kernel_float.cu.cc".

Workaround:

Edit the libstdc++ tuple header:

sudo vi /usr/lib/gcc/x86_64-linux-gnu/13/../../../../include/c++/13/tuple

Comment out the static_assert on line 2335 (the line reported in the errors above), so that the block reads:

#if __has_builtin(__reference_constructs_from_temporary)
      if constexpr (__n == 1)
        {
          using _Elt = decltype(std::get<0>(std::declval<_Tuple>()));
          // Disabled for XLA compatibility static_assert(!__reference_constructs_from_temporary(_Tp, _Elt));
        }
#endif

Contents of xla_configure.bazelrc:

build --action_env CLANG_COMPILER_PATH=/usr/lib/llvm-18/bin/clang
build --repo_env CC=/usr/lib/llvm-18/bin/clang
build --repo_env BAZEL_COMPILER=/usr/lib/llvm-18/bin/clang
build --config cuda_nvcc
build --action_env CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-18/bin/clang
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=8.0
build --config nonccl
build --action_env PYTHON_BIN_PATH=/usr/bin/python3
build --python_path /usr/bin/python3
test --test_env LD_LIBRARY_PATH
test --test_size_filters small,medium
build --copt -Wno-sign-compare
build --copt -Wno-error=unused-command-line-argument
build --copt -Wno-gnu-offsetof-extensions
build --build_tag_filters -no_oss,-rocm-only,-sycl-only
build --test_tag_filters -no_oss,-rocm-only,-sycl-only
test --build_tag_filters -no_oss,-rocm-only,-sycl-only
test --test_tag_filters -no_oss,-rocm-only,-sycl-only

Related issues:

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions