Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 110 additions & 37 deletions .github/workflows/tpu-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ jobs:
exit "${code:-0}"
fi

run_dev:
run_dev_others:
if: ${{ github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository }}
runs-on: [linux-x86-ct5lp-224-8tpu]
environment: testing
Expand Down Expand Up @@ -181,6 +181,79 @@ jobs:
--train-mini-batch-size=2 \
--train-micro-batch-size=2 \
--rollout-engine=vanilla
- name: Run tunix SFT integration tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
# Reinstall Tunix with prod dependencies
pip install -e .[prod] --force-reinstall

# Loading tfds requires tensorflow.
pip install tensorflow

export JAX_PLATFORMS=tpu,cpu
./tests/sft/sft_tpu_smoke_test.sh
- name: Run Smoke tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
KAGGLE_USERNAME: ${{ secrets.KAGGLE_USERNAME }}
KAGGLE_KEY: ${{ secrets.KAGGLE_KEY }}
run: |
echo "Running Smoke tests..."
python -m pytest tests/smoke_tests/model_creation_test.py -v --tb=short
- name: Run tunix cli tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
KAGGLE_USERNAME: ${{ secrets.KAGGLE_USERNAME }}
KAGGLE_KEY: ${{ secrets.KAGGLE_KEY }}
run: |
# Config tests that passed
python -m pytest tests/cli/ -v --tb=short \
--ignore=tests/cli/utils/model_test.py
- name: Run model alignment tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
python -m pip install torch
JAX_PLATFORMS=cpu python -m pytest tests/model_alignment/ -v --tb=short
unset JAX_PLATFORMS

run_dev_vllm:
if: ${{ github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository }}
runs-on: [linux-x86-ct5lp-224-8tpu]
environment: testing
container:
image: vllm/vllm-tpu:nightly-20260406-581c4d4-f6983f0
options: --privileged
env:
CLOUD_TPU_ACCELERATOR: v5e-8
JAX_PLATFORMS: tpu,cpu
steps:
# Cache Hugging Face hub
- name: Cache HF hub
uses: actions/cache@v4
with:
path: ~/.cache/huggingface
key: hf-${{ runner.os }}-${{ hashFiles('pyproject.toml', 'requirements*.txt', 'constraints*.txt') }}
restore-keys: |
hf-${{ runner.os }}-

- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0

- name: Setup Tunix , tpu-inference and dependencies
run: |
echo "Current directory:"
pwd
pip install --upgrade pip setuptools wheel

# Install Tunix with dev and test dependencies without overwriting the vLLM dependencies.
pip install -e .[dev,test]
pip install transformers==4.57.1 --force-reinstall # Issue: https://github.com/google/tunix/pull/795
# tpu-inference/Numba needs NumPy 2.3 or less.
pip install numpy==2.3.5 --force-reinstall
- name: Run vllm tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
Expand All @@ -193,6 +266,42 @@ jobs:
pytest -s "$test" -v --tb=short
done < test_collections.txt

run_dev_sglang:
if: ${{ github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository }}
runs-on: [linux-x86-ct5lp-224-8tpu]
environment: testing
container:
image: vllm/vllm-tpu:nightly-20260406-581c4d4-f6983f0
options: --privileged
env:
CLOUD_TPU_ACCELERATOR: v5e-8
JAX_PLATFORMS: tpu,cpu
steps:
# Cache Hugging Face hub
- name: Cache HF hub
uses: actions/cache@v4
with:
path: ~/.cache/huggingface
key: hf-${{ runner.os }}-${{ hashFiles('pyproject.toml', 'requirements*.txt', 'constraints*.txt') }}
restore-keys: |
hf-${{ runner.os }}-

- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0

- name: Setup Tunix , tpu-inference and dependencies
run: |
echo "Current directory:"
pwd
pip install --upgrade pip setuptools wheel

# Install Tunix with dev and test dependencies without overwriting the vLLM dependencies.
pip install -e .[dev,test]
pip install transformers==4.57.1 --force-reinstall # Issue: https://github.com/google/tunix/pull/795
# tpu-inference/Numba needs NumPy 2.3 or less.
pip install numpy==2.3.5 --force-reinstall
- name: Run install sglang-jax && test
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
Expand All @@ -218,39 +327,3 @@ jobs:
apt-get update; apt-get install -y less

cd tunix && python tests/generate/sglang_jax_sampler_test.py && python tests/generate/sglang_jax_lora_test.py
- name: Run tunix SFT integration tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
# Reinstall Tunix with prod dependencies
pip install -e .[prod] --force-reinstall

# Loading tfds requires tensorflow.
pip install tensorflow

export JAX_PLATFORMS=tpu,cpu
./tests/sft/sft_tpu_smoke_test.sh
- name: Run Smoke tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
KAGGLE_USERNAME: ${{ secrets.KAGGLE_USERNAME }}
KAGGLE_KEY: ${{ secrets.KAGGLE_KEY }}
run: |
echo "Running Smoke tests..."
python -m pytest tests/smoke_tests/model_creation_test.py -v --tb=short
- name: Run tunix cli tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
KAGGLE_USERNAME: ${{ secrets.KAGGLE_USERNAME }}
KAGGLE_KEY: ${{ secrets.KAGGLE_KEY }}
run: |
# Config tests that passed
python -m pytest tests/cli/ -v --tb=short \
--ignore=tests/cli/utils/model_test.py
- name: Run model alignment tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
python -m pip install torch
JAX_PLATFORMS=cpu python -m pytest tests/model_alignment/ -v --tb=short
unset JAX_PLATFORMS
Loading