forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CI] Add triton wheels build workflow (pytorch#87234)
Also, add `torchtriton` and `jinja2` as extra `dynamo` dependency to PyTorch wheels, Version packages as first 10 characters of pinned repo hash and make `torch[dynamo]` wheel depend on the exact version it was build against. TODO: Automate uploading to nightly wheels storage Pull Request resolved: pytorch#87234 Approved by: https://github.com/msaroufim
- Loading branch information
1 parent
c413a32
commit dfe3fc0
Showing
3 changed files
with
157 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
#!/usr/bin/env python3 | ||
from subprocess import check_call | ||
from pathlib import Path | ||
from tempfile import TemporaryDirectory | ||
import sys | ||
import shutil | ||
SCRIPT_DIR = Path(__file__).parent | ||
|
||
def read_triton_pin() -> str: | ||
with open(SCRIPT_DIR.parent / "ci_commit_pins" / "triton.txt") as f: | ||
return f.read().strip() | ||
|
||
|
||
def check_and_replace(inp: str, src: str, dst: str) -> str: | ||
""" Checks that `src` can be found in `input` and replaces it with `dst` """ | ||
if src not in inp: | ||
raise RuntimeError(f"Can't find ${src} in the input") | ||
return inp.replace(src, dst) | ||
|
||
|
||
def patch_setup_py(path: Path, *, version: str = "2.0.0", name: str = "triton") -> None: | ||
with open(path) as f: | ||
orig = f.read() | ||
# Replace name | ||
orig = check_and_replace(orig, "name=\"triton\",", f"name=\"{name}\",") | ||
# Replace version | ||
orig = check_and_replace(orig, "version=\"2.0.0\",", f"version=\"{version}\",") | ||
with open(path, "w") as f: | ||
f.write(orig) | ||
|
||
|
||
def build_triton(commit_hash: str) -> Path: | ||
with TemporaryDirectory() as tmpdir: | ||
triton_basedir = Path(tmpdir) / "triton" | ||
triton_pythondir = triton_basedir / "python" | ||
check_call(["git", "clone", "https://github.com/openai/triton"], cwd=tmpdir) | ||
check_call(["git", "checkout", commit_hash], cwd=triton_basedir) | ||
patch_setup_py(triton_pythondir / "setup.py", name="torchtriton", version=f"2.0.0+{commit_hash[:10]}") | ||
check_call([sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir) | ||
whl_path = list((triton_pythondir / "dist").glob("*.whl"))[0] | ||
shutil.copy(whl_path, Path.cwd()) | ||
return Path.cwd() / whl_path.name | ||
|
||
|
||
def main() -> None: | ||
pin = read_triton_pin() | ||
build_triton(pin) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
name: Build Triton wheels | ||
|
||
on: | ||
push: | ||
branches: | ||
main | ||
paths: | ||
- .github/workflows/build-triton-wheel.yml | ||
- .github/scripts/build_triton_wheel.py | ||
- .github/ci_commit_pins/triton.txt | ||
pull_request: | ||
paths: | ||
- .github/workflows/build-triton-wheel.yml | ||
- .github/scripts/build_triton_wheel.py | ||
- .github/ci_commit_pins/triton.txt | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
build-wheel: | ||
runs-on: [self-hosted, linux.2xlarge] | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
py_vers: [ "3.7", "3.8", "3.9", "3.10", "3.11" ] | ||
timeout-minutes: 40 | ||
env: | ||
DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.6 | ||
PY_VERS: ${{ matrix.py_vers }} | ||
steps: | ||
- name: Checkout PyTorch | ||
uses: pytorch/pytorch/.github/actions/checkout-pytorch@master | ||
with: | ||
submodules: false | ||
|
||
- name: Setup Linux | ||
uses: ./.github/actions/setup-linux | ||
|
||
- name: Setup SSH (Click me for login details) | ||
uses: pytorch/test-infra/.github/actions/setup-ssh@main | ||
with: | ||
github-secret: ${{ secrets.GITHUB_TOKEN }} | ||
|
||
- name: Pull Docker image | ||
uses: pytorch/test-infra/.github/actions/pull-docker-image@main | ||
with: | ||
docker-image: ${{ env.DOCKER_IMAGE }} | ||
|
||
- name: Build Triton wheel | ||
run: | | ||
set -x | ||
mkdir -p "${RUNNER_TEMP}/artifacts/" | ||
container_name=$(docker run \ | ||
--tty \ | ||
--detach \ | ||
-v "${GITHUB_WORKSPACE}:/pytorch" \ | ||
-v "${RUNNER_TEMP}/artifacts:/artifacts" \ | ||
-w /artifacts/ \ | ||
"${DOCKER_IMAGE}" \ | ||
) | ||
# Determine python executable for given version | ||
case $PY_VERS in | ||
3.7) | ||
PYTHON_EXECUTABLE=/opt/python/cp37-cp37m/bin/python | ||
;; | ||
3.8) | ||
PYTHON_EXECUTABLE=/opt/python/cp38-cp38/bin/python | ||
;; | ||
3.9) | ||
PYTHON_EXECUTABLE=/opt/python/cp39-cp39/bin/python | ||
;; | ||
3.10) | ||
PYTHON_EXECUTABLE=/opt/python/cp310-cp310/bin/python | ||
;; | ||
3.11) | ||
PYTHON_EXECUTABLE=/opt/python/cp311-cp311/bin/python | ||
;; | ||
*) | ||
echo "Unsupported python version ${PY_VERS}" | ||
exit 1 | ||
;; | ||
esac | ||
docker exec -t "${container_name}" yum install -y zlib-devel | ||
docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" /pytorch/.github/scripts/build_triton_wheel.py | ||
docker exec -t "${container_name}" chown -R 1000.1000 /artifacts | ||
- uses: actions/upload-artifact@v3 | ||
with: | ||
name: "pytorch-triton-${{ matrix.py_vers }}" | ||
if-no-files-found: error | ||
path: | ||
${{ runner.temp }}/artifacts/* | ||
|
||
- name: Teardown Linux | ||
uses: pytorch/test-infra/.github/actions/teardown-linux@main | ||
if: always() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters