
Commit 2b4c975

add cuda version check for jit (#1526)
## 📌 Description

Add a CUDA version check for the option `-static-global-template-stub`. The flag is only enabled for CUDA >= 12.8.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

cc: @zhyncs
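A side note on the comparison itself: the patch parses the version into `packaging.version.Version` rather than comparing raw strings, since lexicographic string comparison mis-orders multi-digit minor versions (e.g. "12.10" vs "12.8"). A minimal sketch of the gate; the helper name `should_enable_stub_flag` is illustrative, not part of the PR:

```python
from packaging.version import Version

def should_enable_stub_flag(cuda_version: str) -> bool:
    # Mirror the PR's gate: the flag is only valid for CUDA >= 12.8.
    return Version(cuda_version) >= Version("12.8")

# String comparison gets "12.10" vs "12.8" wrong; Version does not.
print("12.10" < "12.8")                    # True  (wrong order, as strings)
print(Version("12.10") > Version("12.8"))  # True  (correct, as versions)
```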
1 parent 56afe8f commit 2b4c975

File tree

1 file changed: +24 −1 lines changed

flashinfer/jit/cpp_ext.py

Lines changed: 24 additions & 1 deletion
```diff
@@ -1,9 +1,12 @@
 # Adapted from https://github.com/pytorch/pytorch/blob/v2.7.0/torch/utils/cpp_extension.py
 
+import functools
 import os
+import re
 import subprocess
 import sys
 import sysconfig
+from packaging.version import Version
 from pathlib import Path
 from typing import List, Optional
 
@@ -19,6 +22,21 @@
 from . import env as jit_env
 
 
+@functools.cache
+def _get_cuda_version() -> Version:
+    if CUDA_HOME is None:
+        nvcc = "nvcc"
+    else:
+        nvcc = os.path.join(CUDA_HOME, "bin/nvcc")
+    txt = subprocess.check_output([nvcc, "--version"], text=True)
+    matches = re.findall(r"release (\d+\.\d+),", txt)
+    if not matches:
+        raise RuntimeError(
+            f"Could not parse CUDA version from nvcc --version output: {txt}"
+        )
+    return Version(matches[0])
+
+
 def _get_glibcxx_abi_build_flags() -> List[str]:
     glibcxx_abi_cflags = [
         "-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))
@@ -78,8 +96,13 @@ def generate_ninja_build_for_op(
         "$common_cflags",
         "--compiler-options=-fPIC",
         "--expt-relaxed-constexpr",
-        "-static-global-template-stub=false",
     ]
+    cuda_version = _get_cuda_version()
+    # enable -static-global-template-stub when cuda version >= 12.8
+    if cuda_version >= Version("12.8"):
+        cuda_cflags += [
+            "-static-global-template-stub=false",
+        ]
     cuda_cflags += _get_cuda_arch_flags(extra_cuda_cflags)
     if extra_cuda_cflags is not None:
         cuda_cflags += extra_cuda_cflags
```
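For readers who want to exercise the parsing path without an NVCC install, here is a standalone sketch using the same regex against a fabricated sample banner. The `SAMPLE` text and the helper name `parse_cuda_version` are illustrative only; real `nvcc --version` output varies by toolkit:

```python
import re
from packaging.version import Version

# Fabricated example of an `nvcc --version` banner (illustrative only).
SAMPLE = (
    "nvcc: NVIDIA (R) Cuda compiler driver\n"
    "Cuda compilation tools, release 12.8, V12.8.61\n"
)

def parse_cuda_version(txt: str) -> Version:
    # Same pattern as the patch: capture "release <major>.<minor>," .
    matches = re.findall(r"release (\d+\.\d+),", txt)
    if not matches:
        raise RuntimeError(f"Could not parse CUDA version from: {txt}")
    return Version(matches[0])

print(parse_cuda_version(SAMPLE))  # 12.8
```

Raising on a non-matching banner (rather than silently defaulting) mirrors the patch's choice: a hard failure is easier to diagnose than a build silently configured for the wrong toolkit.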

0 commit comments
