From 644ccd7f1d5a27d5bfdbd43ee67b552d4ba4081c Mon Sep 17 00:00:00 2001 From: zifeitong Date: Thu, 2 Jan 2025 22:46:17 -0800 Subject: [PATCH] ci: Improve compatibility with pytorch 2.5 (#711) - Fix release workflow for pytorch 2.5 wheel by dropping python 3.8. - Relax pytorch version requirements. - Right now, install flashinfer will downgrade existing pytorch versions since it's set the version requirement to exact match. --- .github/workflows/release_wheel.yml | 2 ++ scripts/run-ci-build-wheel.sh | 2 +- setup.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release_wheel.yml b/.github/workflows/release_wheel.yml index 3e33b10d..e646eadc 100644 --- a/.github/workflows/release_wheel.yml +++ b/.github/workflows/release_wheel.yml @@ -33,6 +33,8 @@ jobs: torch: "2.2" - cuda: "12.4" torch: "2.3" + - python: "3.8" # torch 2.5+ drops python 3.8 + torch: "2.5" runs-on: [self-hosted] steps: diff --git a/scripts/run-ci-build-wheel.sh b/scripts/run-ci-build-wheel.sh index 7e048683..91a6550c 100644 --- a/scripts/run-ci-build-wheel.sh +++ b/scripts/run-ci-build-wheel.sh @@ -41,7 +41,7 @@ else fi echo "::group::Install PyTorch" -pip install torch==$FLASHINFER_CI_TORCH_VERSION --index-url "https://download.pytorch.org/whl/cu${CUDA_MAJOR}${CUDA_MINOR}" +pip install torch==${FLASHINFER_CI_TORCH_VERSION}.* --index-url "https://download.pytorch.org/whl/cu${CUDA_MAJOR}${CUDA_MINOR}" echo "::endgroup::" echo "::group::Install build system" diff --git a/setup.py b/setup.py index f41315e7..6ad370ae 100644 --- a/setup.py +++ b/setup.py @@ -135,7 +135,7 @@ def __init__(self, *args, **kwargs) -> None: torch_full_version = Version(torch.__version__) torch_version = f"{torch_full_version.major}.{torch_full_version.minor}" cmdclass["build_ext"] = NinjaBuildExtension - install_requires = [f"torch == {torch_version}"] + install_requires = [f"torch == {torch_version}.*"] aot_build_meta = {} aot_build_meta["cuda_major"] = cuda_version.major