@@ -1,9 +1,12 @@
 # Adapted from https://github.com/pytorch/pytorch/blob/v2.7.0/torch/utils/cpp_extension.py
 
+import functools
 import os
+import re
 import subprocess
 import sys
 import sysconfig
+from packaging.version import Version
 from pathlib import Path
 from typing import List, Optional
 
@@ -19,6 +22,21 @@
 from . import env as jit_env
 
 
+@functools.cache
+def _get_cuda_version() -> Version:
+    if CUDA_HOME is None:
+        nvcc = "nvcc"
+    else:
+        nvcc = os.path.join(CUDA_HOME, "bin/nvcc")
+    txt = subprocess.check_output([nvcc, "--version"], text=True)
+    matches = re.findall(r"release (\d+\.\d+),", txt)
+    if not matches:
+        raise RuntimeError(
+            f"Could not parse CUDA version from nvcc --version output: {txt}"
+        )
+    return Version(matches[0])
+
+
 def _get_glibcxx_abi_build_flags() -> List[str]:
     glibcxx_abi_cflags = [
         "-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))
@@ -78,8 +96,13 @@ def generate_ninja_build_for_op(
         "$common_cflags",
         "--compiler-options=-fPIC",
         "--expt-relaxed-constexpr",
-        "-static-global-template-stub=false",
     ]
+    cuda_version = _get_cuda_version()
+    # only pass -static-global-template-stub=false when the CUDA version is >= 12.8
+    if cuda_version >= Version("12.8"):
+        cuda_cflags += [
+            "-static-global-template-stub=false",
+        ]
     cuda_cflags += _get_cuda_arch_flags(extra_cuda_cflags)
     if extra_cuda_cflags is not None:
         cuda_cflags += extra_cuda_cflags
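
Below is a minimal standalone sketch of the version gate this patch adds, for reference; the sample `nvcc --version` line is an assumed example, not output captured from this build. It also shows why `packaging.version.Version` is used for the comparison: release components compare numerically, so `12.10` sorts above `12.8`, which a plain string comparison would get wrong.

```python
import re

from packaging.version import Version

# Assumed example of the relevant line in `nvcc --version` output.
sample_output = "Cuda compilation tools, release 12.8, V12.8.93"

# Same parsing approach as _get_cuda_version(): pull "12.8" out of the text.
matches = re.findall(r"release (\d+\.\d+),", sample_output)
cuda_version = Version(matches[0])  # Version("12.8")

# Version compares components numerically; plain strings compare lexically.
assert Version("12.10") > Version("12.8")
assert "12.10" < "12.8"  # the string comparison gets this wrong

cuda_cflags = []
if cuda_version >= Version("12.8"):
    # Only toolchains reporting CUDA 12.8 or newer get the extra nvcc flag.
    cuda_cflags.append("-static-global-template-stub=false")
```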