Build on Jetson AGX Orin failed #28989

Open
@wjz-kidding

Description

I'm trying to build jaxlib==0.3.25 on a Jetson AGX Orin dev kit with Ubuntu 22.04, GCC 11.3 (which I strictly need), CUDA 12.2 and cuDNN 8.9.
Precondition: in my CUDA 12 install, libcupti.so.12 has SONAME version 12, but Bazel strictly requires 12.2, so I created a libcupti.so.12.2 with patchelf whose SONAME is 12.2 (this might cause a problem).
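To check which SONAME the original and the patched libcupti actually advertise, `readelf -d` prints it in the library's dynamic section. The sketch below assumes binutils' `readelf` is on `PATH`; `parse_soname` and `read_soname` are hypothetical helpers, and the library path you pass in is up to you.

```python
import re
import shutil
import subprocess
from typing import Optional

# `readelf -d` prints the SONAME entry as, for example:
#  0x000000000000000e (SONAME)             Library soname: [libfoo.so.1]
SONAME_RE = re.compile(r"\(SONAME\)\s+Library soname: \[([^\]]+)\]")


def parse_soname(readelf_output: str) -> Optional[str]:
    """Return the SONAME found in `readelf -d` output, or None if there is none."""
    match = SONAME_RE.search(readelf_output)
    return match.group(1) if match else None


def read_soname(path: str) -> Optional[str]:
    """Run `readelf -d` on a shared library and return its SONAME."""
    if shutil.which("readelf") is None:
        raise RuntimeError("readelf not found on PATH (install binutils)")
    result = subprocess.run(
        ["readelf", "-d", path], capture_output=True, text=True, check=True
    )
    return parse_soname(result.stdout)
```

Calling `read_soname()` on the patched copy should report whatever SONAME was written with patchelf's `--set-soname` option, which makes it easy to compare against what the original `libcupti.so.12` declares.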

The steps I ran:

source /path/to/python/bin/activate
git clone jax
cd jax
git checkout 0.3.25

After checking `python ./build/build.py --help` and the older build documentation, I ran the build like this:

build_args=(
  --bazel_path /usr/bin/bazel                                          # Path to Bazel binary
  --python_bin_path /home/cuda12onnvme/venv/dalleminipy310/bin/python  # Python binary path
  --enable_cuda                                                        # Enable CUDA support
  --noenable_mkl_dnn                                                   # Disable MKL-DNN
  --noenable_rocm                                                      # Disable ROCm
  --noenable_tpu                                                       # Disable TPU
  --noenable_remote_tpu                                                # Disable remote TPU
  --noenable_nccl                                                      # Disable NCCL
  --cuda_path=/usr/local/cuda/                                         # CUDA installation path
  --cudnn_path=/usr/lib/aarch64-linux-gnu                              # cuDNN installation path
  --cuda_version=12                                                    # CUDA version
  --cudnn_version=8                                                    # cuDNN version
  --cuda_compute_capabilities=8.7                                      # CUDA compute capabilities
  --output_path=./jaxlib_dist                                          # Output directory
)
python build/build.py "${build_args[@]}" > build_log.txt 2>&1          # Redirect output to log file

but the build fails with these errors:

[3,028 / 4,634] Compiling tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc; 39s local ... (12 actions, 11 running)
ERROR: /home/cuda12onnvme/.cache/bazel/_bazel_cuda12onnvme/430fe01fe46ae5863ffdb4c3cc5d2f3d/external/org_tensorflow/tensorflow/tsl/cuda/BUILD:75:11: Compiling tensorflow/tsl/cuda/cudart_stub.cc failed: (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing command 
  (cd /home/cuda12onnvme/.cache/bazel/_bazel_cuda12onnvme/430fe01fe46ae5863ffdb4c3cc5d2f3d/execroot/__main__ && \
  exec env - \
    CUDA_TOOLKIT_PATH=/usr/local/cuda/ \
    CUDNN_INSTALL_PATH=/usr/lib/aarch64-linux-gnu \
    LD_LIBRARY_PATH=/usr/local/cuda-12.2/lib64:/usr/local/cuda-12.2/lib64: \
    PATH=/home/cuda12onnvme/.cache/bazelisk/downloads/sha256/a590a28608772e779efc0c29bb678cd2a150deb27a9f8c557cc1d2b131a779ef/bin:/home/cuda12onnvme/venv/dalleminipy310/bin:/home/cuda12onnvme/bin:/usr/local/cuda-12.2/bin:/home/cuda12onnvme/.vscode-server/cli/servers/Stable-f1a4fb101478ce6ec82fe9627c43efbf9e98c813/server/bin/remote-cli:/usr/local/cuda-12.2/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin:/home/cuda12onnvme/.vscode-server/extensions/ms-python.debugpy-2025.8.0-linux-arm64/bundled/scripts/noConfigScripts \
    PWD=/proc/self/cwd \
    TF_CUDA_COMPUTE_CAPABILITIES=8.7 \
    TF_CUDA_PATHS=/usr/local/cuda/,/usr/lib/aarch64-linux-gnu \
    TF_CUDA_VERSION=12 \
    TF_CUDNN_VERSION=8 \
  external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -MD -MF bazel-out/aarch64-opt/bin/external/org_tensorflow/tensorflow/tsl/cuda/_objs/cudart_stub/cudart_stub.pic.d '-frandom-seed=bazel-out/aarch64-opt/bin/external/org_tensorflow/tensorflow/tsl/cuda/_objs/cudart_stub/cudart_stub.pic.o' -DEIGEN_MPL2_ONLY '-DEIGEN_MAX_ALIGN_BYTES=64' -DHAVE_SYS_UIO_H -DTF_USE_SNAPPY -iquote external/org_tensorflow -iquote bazel-out/aarch64-opt/bin/external/org_tensorflow -iquote external/local_config_cuda -iquote bazel-out/aarch64-opt/bin/external/local_config_cuda -iquote external/eigen_archive -iquote bazel-out/aarch64-opt/bin/external/eigen_archive -iquote external/com_google_absl -iquote bazel-out/aarch64-opt/bin/external/com_google_absl -iquote external/nsync -iquote bazel-out/aarch64-opt/bin/external/nsync -iquote external/snappy -iquote bazel-out/aarch64-opt/bin/external/snappy -iquote external/double_conversion -iquote bazel-out/aarch64-opt/bin/external/double_conversion -iquote external/com_google_protobuf -iquote bazel-out/aarch64-opt/bin/external/com_google_protobuf -iquote external/zlib -iquote bazel-out/aarch64-opt/bin/external/zlib -iquote external/com_googlesource_code_re2 -iquote bazel-out/aarch64-opt/bin/external/com_googlesource_code_re2 -iquote external/local_config_rocm -iquote bazel-out/aarch64-opt/bin/external/local_config_rocm -iquote external/local_config_tensorrt -iquote bazel-out/aarch64-opt/bin/external/local_config_tensorrt -Ibazel-out/aarch64-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers_virtual -Ibazel-out/aarch64-opt/bin/external/local_config_tensorrt/_virtual_includes/tensorrt_headers -isystem external/local_config_cuda/cuda -isystem bazel-out/aarch64-opt/bin/external/local_config_cuda/cuda -isystem external/local_config_cuda/cuda/cuda/include -isystem bazel-out/aarch64-opt/bin/external/local_config_cuda/cuda/cuda/include -isystem external/eigen_archive -isystem 
bazel-out/aarch64-opt/bin/external/eigen_archive -isystem external/nsync/public -isystem bazel-out/aarch64-opt/bin/external/nsync/public -isystem external/com_google_protobuf/src -isystem bazel-out/aarch64-opt/bin/external/com_google_protobuf/src -isystem external/zlib -isystem bazel-out/aarch64-opt/bin/external/zlib -isystem external/local_config_rocm/rocm -isystem bazel-out/aarch64-opt/bin/external/local_config_rocm/rocm -isystem external/local_config_rocm/rocm/rocm/include -isystem bazel-out/aarch64-opt/bin/external/local_config_rocm/rocm/rocm/include -isystem external/local_config_rocm/rocm/rocm/include/rocrand -isystem bazel-out/aarch64-opt/bin/external/local_config_rocm/rocm/rocm/include/rocrand -isystem external/local_config_rocm/rocm/rocm/include/roctracer -isystem bazel-out/aarch64-opt/bin/external/local_config_rocm/rocm/rocm/include/roctracer -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' -fPIC -U_FORTIFY_SOURCE '-D_FORTIFY_SOURCE=1' -fstack-protector -Wall -fno-omit-frame-pointer -no-canonical-prefixes -fno-canonical-system-headers -DNDEBUG -g0 -O2 -ffunction-sections -fdata-sections '-fvisibility=hidden' -Wno-sign-compare -Wno-unknown-warning-option -Wno-stringop-truncation -Wno-array-parameter '-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.' '-std=c++17' -c external/org_tensorflow/tensorflow/tsl/cuda/cudart_stub.cc -o bazel-out/aarch64-opt/bin/external/org_tensorflow/tensorflow/tsl/cuda/_objs/cudart_stub/cudart_stub.pic.o)
# Configuration: 41f72ef3e92e6f87aabb202fd6e8b2733d72f3ee8b1f76bfc17cff4c48f3c241
# Execution platform: @local_execution_config_platform//:platform
In file included from external/org_tensorflow/tensorflow/tsl/cuda/cudart_stub.cc:19:
bazel-out/aarch64-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers_virtual/third_party/gpus/cuda/include/cuda_runtime_api.h:166:62: error: conflicting declaration of C function ‘cudaError_t cudaStreamGetCaptureInfo_v2(cudaStream_t, cudaStreamCaptureStatus*, long long unsigned int*)’
  166 |     #define cudaStreamGetCaptureInfo       __CUDART_API_PTSZ(cudaStreamGetCaptureInfo_v2)
      |                                                              ^~~~~~~~~~~~~~~~~~~~~~~~~~~
bazel-out/aarch64-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers_virtual/third_party/gpus/cuda/include/cuda_runtime_api.h:160:36: note: in definition of macro ‘__CUDART_API_PTSZ’
  160 |     #define __CUDART_API_PTSZ(api) api
      |                                    ^~~
external/org_tensorflow/tensorflow/tsl/cuda/cuda_runtime_11_2.inc:487:39: note: in expansion of macro ‘cudaStreamGetCaptureInfo’
  487 | extern __host__ cudaError_t CUDARTAPI cudaStreamGetCaptureInfo(
      |                                       ^~~~~~~~~~~~~~~~~~~~~~~~
bazel-out/aarch64-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers_virtual/third_party/gpus/cuda/include/cuda_runtime_api.h:166:62: note: previous declaration ‘cudaError_t cudaStreamGetCaptureInfo_v2(cudaStream_t, cudaStreamCaptureStatus*, long long unsigned int*, CUgraph_st**, CUgraphNode_st* const**, size_t*)’
  166 |     #define cudaStreamGetCaptureInfo       __CUDART_API_PTSZ(cudaStreamGetCaptureInfo_v2)
      |                                                              ^~~~~~~~~~~~~~~~~~~~~~~~~~~
bazel-out/aarch64-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers_virtual/third_party/gpus/cuda/include/cuda_runtime_api.h:160:36: note: in definition of macro ‘__CUDART_API_PTSZ’
  160 |     #define __CUDART_API_PTSZ(api) api
      |                                    ^~~
bazel-out/aarch64-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers_virtual/third_party/gpus/cuda/include/cuda_runtime_api.h:3107:39: note: in expansion of macro ‘cudaStreamGetCaptureInfo’
 3107 | extern __host__ cudaError_t CUDARTAPI cudaStreamGetCaptureInfo(cudaStream_t stream, enum cudaStreamCaptureStatus *captureStatus_out, unsigned long long *id_out __dv(0), cudaGraph_t *graph_out __dv(0), const cudaGraphNode_t **dependencies_out __dv(0), size_t *numDependencies_out __dv(0));
      |                                       ^~~~~~~~~~~~~~~~~~~~~~~~
In file included from external/org_tensorflow/tensorflow/tsl/cuda/cudart_stub.cc:64:
external/org_tensorflow/tensorflow/tsl/cuda/cuda_runtime_11_2.inc:2073:39: error: conflicting declaration of C function ‘cudaError_t cudaGraphInstantiate(CUgraphExec_st**, cudaGraph_t, CUgraphNode_st**, char*, size_t)’
 2073 | extern __host__ cudaError_t CUDARTAPI cudaGraphInstantiate(
      |                                       ^~~~~~~~~~~~~~~~~~~~
In file included from external/org_tensorflow/tensorflow/tsl/cuda/cudart_stub.cc:19:
bazel-out/aarch64-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers_virtual/third_party/gpus/cuda/include/cuda_runtime_api.h:11650:39: note: previous declaration ‘cudaError_t cudaGraphInstantiate(CUgraphExec_st**, cudaGraph_t, long long unsigned int)’
11650 | extern __host__ cudaError_t CUDARTAPI cudaGraphInstantiate(cudaGraphExec_t *pGraphExec, cudaGraph_t graph, unsigned long long flags __dv(0));
      |                                       ^~~~~~~~~~~~~~~~~~~~
In file included from external/org_tensorflow/tensorflow/tsl/cuda/cudart_stub.cc:64:
external/org_tensorflow/tensorflow/tsl/cuda/cuda_runtime_11_2.inc:2127:1: error: conflicting declaration of C function ‘cudaError_t cudaGraphExecUpdate(cudaGraphExec_t, cudaGraph_t, CUgraphNode_st**, cudaGraphExecUpdateResult*)’
 2127 | cudaGraphExecUpdate(cudaGraphExec_t hGraphExec, cudaGraph_t hGraph,
      | ^~~~~~~~~~~~~~~~~~~
In file included from external/org_tensorflow/tensorflow/tsl/cuda/cudart_stub.cc:19:
bazel-out/aarch64-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers_virtual/third_party/gpus/cuda/include/cuda_runtime_api.h:12620:39: note: previous declaration ‘cudaError_t cudaGraphExecUpdate(cudaGraphExec_t, cudaGraph_t, cudaGraphExecUpdateResultInfo*)’
12620 | extern __host__ cudaError_t CUDARTAPI cudaGraphExecUpdate(cudaGraphExec_t hGraphExec, cudaGraph_t hGraph, cudaGraphExecUpdateResultInfo *resultInfo);
      |                                       ^~~~~~~~~~~~~~~~~~~
In file included from external/org_tensorflow/tensorflow/tsl/cuda/cudart_stub.cc:64:
external/org_tensorflow/tensorflow/tsl/cuda/cuda_runtime_11_2.inc:2250:39: error: conflicting declaration of C function ‘cudaError_t cudaGetDriverEntryPoint(const char*, void**, long long unsigned int)’
 2250 | extern __host__ cudaError_t CUDARTAPI cudaGetDriverEntryPoint(
      |                                       ^~~~~~~~~~~~~~~~~~~~~~~
In file included from external/org_tensorflow/tensorflow/tsl/cuda/cudart_stub.cc:19:
bazel-out/aarch64-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers_virtual/third_party/gpus/cuda/include/cuda_runtime_api.h:13077:39: note: previous declaration ‘cudaError_t cudaGetDriverEntryPoint(const char*, void**, long long unsigned int, cudaDriverEntryPointQueryResult*)’
13077 | extern __host__ cudaError_t CUDARTAPI cudaGetDriverEntryPoint(const char *symbol, void **funcPtr, unsigned long long flags, enum cudaDriverEntryPointQueryResult *driverStatus = NULL);
      |                                       ^~~~~~~~~~~~~~~~~~~~~~~
cc1plus: note: unrecognized command-line option ‘-Wno-unknown-warning-option’ may have been intended to silence earlier diagnostics
Target //build:build_wheel failed to build
INFO: Elapsed time: 1542.624s, Critical Path: 247.76s
INFO: 3275 processes: 575 internal, 2700 local.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target
FAILED: Build did NOT complete successfully

     _   _  __  __
    | | / \ \ \/ /
 _  | |/ _ \ \  /
| |_| / ___ \/  \
 \___/_/   \/_/\_\


Bazel binary path: /usr/bin/bazel
Bazel version: 5.1.1
Python binary path: /home/cuda12onnvme/venv/dalleminipy310/bin/python
Python version: 3.10
NumPy version: 1.26.4
MKL-DNN enabled: no
Target CPU: aarch64
Target CPU features: release
CUDA enabled: yes
CUDA toolkit path: /usr/local/cuda/
CUDNN library path: /usr/lib/aarch64-linux-gnu
CUDA compute capabilities: 8.7
CUDA version: 12
CUDNN version: 8
NCCL enabled: no
TPU enabled: no
Remote TPU enabled: no
ROCm enabled: no
Plugin device enabled: no

Building XLA and installing it in the jaxlib source tree...
/usr/bin/bazel run --verbose_failures=true :build_wheel -- --output_path=/home/cuda12onnvme/dalle-mini/jax/jaxlib_dist --cpu=aarch64
b''
Traceback (most recent call last):
  File "/home/cuda12onnvme/dalle-mini/jax/build/build.py", line 567, in <module>
    main()
  File "/home/cuda12onnvme/dalle-mini/jax/build/build.py", line 562, in main
    shell(command)
  File "/home/cuda12onnvme/dalle-mini/jax/build/build.py", line 53, in shell
    output = subprocess.check_output(cmd)
  File "/usr/lib/python3.10/subprocess.py", line 421, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "/usr/lib/python3.10/subprocess.py", line 526, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['/usr/bin/bazel', 'run', '--verbose_failures=true', ':build_wheel', '--', '--output_path=/home/cuda12onnvme/dalle-mini/jax/jaxlib_dist', '--cpu=aarch64']' returned non-zero exit status 1.
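Every error in the log above has the same shape: a prototype in TSL's `cuda_runtime_11_2.inc` conflicts with the declaration of the same function in the CUDA 12.2 `cuda_runtime_api.h`. As a triage aid, here is a minimal sketch that lists the affected functions from `build_log.txt`; it assumes the GCC diagnostic wording shown above, and `conflicting_functions` is a hypothetical helper.

```python
import re
from typing import Dict, List

# Matches GCC diagnostics such as
#   ...: error: conflicting declaration of C function 'cudaError_t cudaGraphInstantiate(...)'
# and captures the last identifier before the first "(", i.e. the function name.
CONFLICT_RE = re.compile(r"error: conflicting declaration of C function [^(\n]*\b(\w+)\(")


def conflicting_functions(log_text: str) -> List[str]:
    """Return the C functions reported as conflicting, in order of first appearance."""
    seen: Dict[str, None] = {}
    for match in CONFLICT_RE.finditer(log_text):
        seen.setdefault(match.group(1), None)
    return list(seen)
```

Run as `conflicting_functions(open("build_log.txt").read())` against the log in this report, it should return `cudaStreamGetCaptureInfo_v2`, `cudaGraphInstantiate`, `cudaGraphExecUpdate` and `cudaGetDriverEntryPoint`, the functions whose CUDA 12.2 declarations differ from the prototypes in the 11.2 stub file.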

What's more, the same build-from-source process worked on Ubuntu 20.04 (L4T 35) with CUDA 11.4: building jaxlib==0.3.25 against CUDA 11.4 succeeded, even without using Bazelisk to install Bazel.

source /path/to/python/bin/activate
git clone jax
cd jax
git checkout 0.3.25

python build/build.py --python_bin_path /home/usr_name/venv/dellaPy38/bin/python --enable_cuda --noenable_mkl_dnn --noenable_rocm --noenable_tpu --noenable_remote_tpu --noenable_nccl --cuda_path=/usr/local/cuda/ --cudnn_path=/usr/lib/aarch64-linux-gnu --cuda_version=11.4 --cudnn_version=8 --cuda_compute_capabilities=8.7 --output_path=./jaxlib_dist > build_log.txt 2>&1
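One difference between the working and the failing setup is the CUDA runtime headers Bazel compiles against (11.4 vs 12.2). `cuda_runtime_api.h` defines `CUDART_VERSION` as `1000 * major + 10 * minor` (12.2 is `12020`), so a quick way to confirm which toolkit `--cuda_path` really resolves to is to read that macro. A sketch; the default path comes from the commands above, and `decode_cudart_version`, `parse_cudart_version` and `read_cudart_version` are hypothetical helpers.

```python
import re
from pathlib import Path
from typing import Optional, Tuple

# cuda_runtime_api.h carries a line like "#define CUDART_VERSION  12020",
# encoded as 1000 * major + 10 * minor.
_CUDART_RE = re.compile(r"^\s*#define\s+CUDART_VERSION\s+(\d+)", re.MULTILINE)


def decode_cudart_version(value: int) -> Tuple[int, int]:
    """Split an encoded CUDART_VERSION (e.g. 12020) into (major, minor)."""
    return value // 1000, (value % 1000) // 10


def parse_cudart_version(header_text: str) -> Optional[Tuple[int, int]]:
    """Find CUDART_VERSION in header text and decode it, or return None."""
    match = _CUDART_RE.search(header_text)
    return decode_cudart_version(int(match.group(1))) if match else None


def read_cudart_version(cuda_path: str = "/usr/local/cuda/") -> Optional[Tuple[int, int]]:
    """Read the runtime version from <cuda_path>/include/cuda_runtime_api.h."""
    header = Path(cuda_path) / "include" / "cuda_runtime_api.h"
    if not header.is_file():
        return None
    return parse_cudart_version(header.read_text(errors="replace"))
```

On the failing machine `read_cudart_version()` should report `(12, 2)`, while the L4T 35 setup that worked would be expected to report `(11, 4)`.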

Unfortunately, that Jetson Linux image (L4T 35) ships GCC 8, and I need GCC >= 10, so it is not an option for me.
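Since the hard requirement here is GCC >= 10, a small pre-flight check can rule out a toolchain before spending 25 minutes in Bazel. This is a sketch that assumes `gcc` is on `PATH` and is GCC 7 or newer (needed for `-dumpfullversion`); `parse_gcc_version`, `gcc_version` and `gcc_is_new_enough` are hypothetical helpers.

```python
import re
import shutil
import subprocess
from typing import Tuple

MIN_GCC = (10,)  # minimum version stated in this report


def parse_gcc_version(text: str) -> Tuple[int, ...]:
    """Turn compiler output such as '11.3.0' into a tuple like (11, 3, 0)."""
    match = re.search(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"no version number in {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def gcc_version(compiler: str = "gcc") -> Tuple[int, ...]:
    """Ask the compiler for its full version; -dumpfullversion needs GCC 7 or newer."""
    if shutil.which(compiler) is None:
        raise FileNotFoundError(f"{compiler} not found on PATH")
    result = subprocess.run(
        [compiler, "-dumpfullversion"], capture_output=True, text=True, check=True
    )
    return parse_gcc_version(result.stdout)


def gcc_is_new_enough(version: Tuple[int, ...], minimum: Tuple[int, ...] = MIN_GCC) -> bool:
    """Compare version tuples lexicographically, e.g. (8, 4, 0) < (10,)."""
    return version >= minimum
```

`gcc_is_new_enough(gcc_version())` should be `False` on the GCC 8 image described above and `True` with the GCC 11.3 used for the failing build.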

System info (python version, jaxlib version, accelerator, etc.)

python==3.10
jaxlib==0.3.25
gcc 11.3
Jetson AGX Orin, an aarch64 device with a CUDA-capable GPU
compute capability 8.7 (sm_87)
JetPack 6.0 (CUDA 12.2) with L4T 36.3 (the Jetson Linux image based on Ubuntu 22.04)
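Most of the versions listed above can be gathered with the standard library, which keeps reports like this one consistent. A sketch; reading `/etc/nv_tegra_release` assumes that file holds the L4T release line on Jetson images, and `collect_system_info` is a hypothetical helper.

```python
import platform
from importlib import metadata
from pathlib import Path
from typing import Dict

# Assumed location of the L4T release string on Jetson Linux images.
TEGRA_RELEASE = Path("/etc/nv_tegra_release")


def collect_system_info() -> Dict[str, str]:
    """Gather Python, CPU architecture, jaxlib and L4T versions using only the stdlib."""
    info = {
        "python": platform.python_version(),
        "machine": platform.machine(),  # "aarch64" on Jetson AGX Orin
    }
    try:
        info["jaxlib"] = metadata.version("jaxlib")
    except metadata.PackageNotFoundError:
        info["jaxlib"] = "not installed"
    if TEGRA_RELEASE.is_file():
        lines = TEGRA_RELEASE.read_text(errors="replace").splitlines()
        info["l4t"] = lines[0] if lines else "unknown"
    else:
        info["l4t"] = "unknown"
    return info
```

Printing `collect_system_info()` key by key gives a block that can be pasted straight into the system info section.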

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
