jaxlib.xla_extension.XlaRuntimeError #50

Open
Xiaojun928 opened this issue Nov 21, 2024 · 0 comments

Xiaojun928 commented Nov 21, 2024

Hi Dr. Zhong,

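Thank you for developing Parafold, such a great tool!
Installing Parafold has gone smoothly so far, and the first step (feature generation) completed successfully. However, when I tried to run the second step (structure prediction), I ran into a JAX-related error. I'd be grateful for any suggestions.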

GPU configuration:

| NVIDIA-SMI 550.90.12              Driver Version: 550.90.12      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H800 PCIe               Off |   00000000:34:00.0 Off |                    0 |
| N/A   52C    P0             89W /  350W |       1MiB /  81559MiB |      3%      Default |
|                                         |                        |             Disabled |
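To confirm the GPU's compute capability directly instead of inferring it from the model name, `nvidia-smi` can report it through its `compute_cap` query field (available in reasonably recent drivers). The sketch below is not part of ParallelFold; the helper names `parse_compute_caps` and `query_compute_caps` are hypothetical. It runs `nvidia-smi --query-gpu=name,compute_cap --format=csv,noheader` and parses each line into a `(name, (major, minor))` pair, with the parsing split out so it can be checked without a GPU.

```python
import subprocess
from typing import List, Tuple

# Hypothetical helper; assumes the installed driver supports the compute_cap query field.
QUERY_CMD = ["nvidia-smi", "--query-gpu=name,compute_cap", "--format=csv,noheader"]


def parse_compute_caps(csv_text: str) -> List[Tuple[str, Tuple[int, int]]]:
    """Parse lines like '<gpu name>, <major>.<minor>' into (name, (major, minor))."""
    gpus = []
    for line in csv_text.splitlines():
        if not line.strip():
            continue  # skip blank lines
        # Split on the last comma so the GPU name itself is never cut.
        name, cap = line.rsplit(",", 1)
        major, minor = cap.strip().split(".")
        gpus.append((name.strip(), (int(major), int(minor))))
    return gpus


def query_compute_caps() -> List[Tuple[str, Tuple[int, int]]]:
    """Run nvidia-smi and parse its output (requires an NVIDIA driver on PATH)."""
    out = subprocess.run(QUERY_CMD, check=True, capture_output=True, text=True).stdout
    return parse_compute_caps(out)
```

On the machine above, `query_compute_caps()` would be expected to report 9.0 for the H800 (Hopper), which is the "CC 9.0" that the ptxas warning in the log refers to.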

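I installed following the "How to install" section of the README, and JAX is version 0.3.25, as the README specifies.
In addition, following the suggestion in issue #39, I installed cuda-nvcc, but the same error persists.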
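The log line `Used ptxas at ptxas` suggests XLA picked up whichever `ptxas` was first on `PATH`, so installing cuda-nvcc can only help if that binary is the one found and it is new enough to target Hopper. The sketch below checks both. It assumes CUDA 11.8 is the earliest toolkit whose ptxas accepts sm_90, and it parses the `release X.Y` text that `ptxas --version` prints; the helper names are hypothetical.

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple

# Assumption: ptxas from CUDA 11.8 or newer is needed to target sm_90 (Hopper).
MIN_PTXAS_FOR_SM90 = (11, 8)


def parse_ptxas_release(version_text: str) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from text such as 'Cuda compilation tools, release 11.8, V11.8.89'."""
    match = re.search(r"release (\d+)\.(\d+)", version_text)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def ptxas_supports_sm90(version_text: str) -> bool:
    """True if the reported ptxas release is at least MIN_PTXAS_FOR_SM90."""
    release = parse_ptxas_release(version_text)
    return release is not None and release >= MIN_PTXAS_FOR_SM90


def check_ptxas_on_path() -> None:
    """Print which ptxas is first on PATH and whether its release looks new enough."""
    path = shutil.which("ptxas")
    if path is None:
        print("no ptxas found on PATH")
        return
    out = subprocess.run([path, "--version"], check=True,
                         capture_output=True, text=True).stdout
    verdict = "new enough for sm_90" if ptxas_supports_sm90(out) else "too old for sm_90"
    print(f"{path}: release {parse_ptxas_release(out)} ({verdict})")
```

Calling `check_ptxas_on_path()` inside the activated `parafold` environment shows whether the cuda-nvcc install actually changed the ptxas that XLA sees.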

The error message I get is:

2024-11-21 11:37:59.402301: W external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:231] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 9.0
2024-11-21 11:37:59.402323: W external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:234] Used ptxas at ptxas
2024-11-21 11:37:59.404084: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:628] failed to get PTX kernel "shift_right_logical" from module: CUDA_ERROR_NOT_FOUND: named symbol not found
2024-11-21 11:37:59.404116: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2153] Execution of replica 0 failed: INTERNAL: Could not find the corresponding function
Traceback (most recent call last):
  File "/home/software/ParallelFold/run_alphafold.py", line 491, in <module>
    app.run(main)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/software/ParallelFold/run_alphafold.py", line 464, in main
    predict_structure(
  File "/home/software/ParallelFold/run_alphafold.py", line 239, in predict_structure
    prediction_result = model_runner.predict(processed_feature_dict,
  File "/home/software/ParallelFold/alphafold/model/model.py", line 167, in predict
    result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/random.py", line 132, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 267, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 580, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 329, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 332, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 712, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 592, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 597, in random_seed_impl_base
    return seed(seeds)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 832, in threefry_seed
    lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 515, in shift_right_logical
    return shift_right_logical_p.bind(x, y)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 329, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 332, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 712, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/dispatch.py", line 115, in apply_primitive
    return compiled_fun(*args)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/dispatch.py", line 200, in <lambda>
    return lambda *args, **kw: compiled(*args, **kw)[0]
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/dispatch.py", line 895, in _execute_compiled
    out_flat = compiled.execute(in_flat)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Could not find the corresponding function

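This looks like an incompatibility between the H800 (a Hopper GPU with compute capability 9.0, which the ptxas warning above reports as unsupported) and JAX 0.3.25. Would upgrading JAX help?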
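`jax` and `jaxlib` are separate packages that have to be upgraded together, and the exception above (`jaxlib.xla_extension.XlaRuntimeError`) is raised from `jaxlib`. A quick way to confirm which versions the `parafold` conda environment actually has installed, before and after any upgrade, is the standard-library `importlib.metadata`. This is a minimal sketch with hypothetical helper names, not part of ParallelFold.

```python
from importlib import metadata
from typing import Dict, Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def jax_stack_versions() -> Dict[str, Optional[str]]:
    """Versions of the two packages involved in the GPU error above."""
    return {name: installed_version(name) for name in ("jax", "jaxlib")}
```

Running `print(jax_stack_versions())` with the environment's Python makes it easy to spot a `jax` upgrade that left an older `jaxlib` (and its bundled XLA) in place.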

Many thanks!
