Skip to content

Commit

Permalink
[CI] Add triton wheels build workflow (pytorch#87234)
Browse files Browse the repository at this point in the history
Also, add `torchtriton` and `jinja2` as extra `dynamo` dependency to PyTorch wheels,

Version packages as first 10 characters of pinned repo hash and make `torch[dynamo]` wheel depend on the exact version it was build against.

TODO: Automate uploading to nightly wheels storage
Pull Request resolved: pytorch#87234
Approved by: https://github.com/msaroufim
  • Loading branch information
malfet authored and pytorchmergebot committed Oct 19, 2022
1 parent c413a32 commit dfe3fc0
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 0 deletions.
51 changes: 51 additions & 0 deletions .github/scripts/build_triton_wheel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
from subprocess import check_call
from pathlib import Path
from tempfile import TemporaryDirectory
import sys
import shutil
SCRIPT_DIR = Path(__file__).parent

def read_triton_pin() -> str:
with open(SCRIPT_DIR.parent / "ci_commit_pins" / "triton.txt") as f:
return f.read().strip()


def check_and_replace(inp: str, src: str, dst: str) -> str:
""" Checks that `src` can be found in `input` and replaces it with `dst` """
if src not in inp:
raise RuntimeError(f"Can't find ${src} in the input")
return inp.replace(src, dst)


def patch_setup_py(path: Path, *, version: str = "2.0.0", name: str = "triton") -> None:
with open(path) as f:
orig = f.read()
# Replace name
orig = check_and_replace(orig, "name=\"triton\",", f"name=\"{name}\",")
# Replace version
orig = check_and_replace(orig, "version=\"2.0.0\",", f"version=\"{version}\",")
with open(path, "w") as f:
f.write(orig)


def build_triton(commit_hash: str) -> Path:
with TemporaryDirectory() as tmpdir:
triton_basedir = Path(tmpdir) / "triton"
triton_pythondir = triton_basedir / "python"
check_call(["git", "clone", "https://github.com/openai/triton"], cwd=tmpdir)
check_call(["git", "checkout", commit_hash], cwd=triton_basedir)
patch_setup_py(triton_pythondir / "setup.py", name="torchtriton", version=f"2.0.0+{commit_hash[:10]}")
check_call([sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir)
whl_path = list((triton_pythondir / "dist").glob("*.whl"))[0]
shutil.copy(whl_path, Path.cwd())
return Path.cwd() / whl_path.name


def main() -> None:
pin = read_triton_pin()
build_triton(pin)


if __name__ == "__main__":
main()
100 changes: 100 additions & 0 deletions .github/workflows/build-triton-wheel.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
name: Build Triton wheels

on:
push:
branches:
main
paths:
- .github/workflows/build-triton-wheel.yml
- .github/scripts/build_triton_wheel.py
- .github/ci_commit_pins/triton.txt
pull_request:
paths:
- .github/workflows/build-triton-wheel.yml
- .github/scripts/build_triton_wheel.py
- .github/ci_commit_pins/triton.txt

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true

jobs:
build-wheel:
runs-on: [self-hosted, linux.2xlarge]
strategy:
fail-fast: false
matrix:
py_vers: [ "3.7", "3.8", "3.9", "3.10", "3.11" ]
timeout-minutes: 40
env:
DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.6
PY_VERS: ${{ matrix.py_vers }}
steps:
- name: Checkout PyTorch
uses: pytorch/pytorch/.github/actions/checkout-pytorch@master
with:
submodules: false

- name: Setup Linux
uses: ./.github/actions/setup-linux

- name: Setup SSH (Click me for login details)
uses: pytorch/test-infra/.github/actions/setup-ssh@main
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}

- name: Pull Docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ env.DOCKER_IMAGE }}

- name: Build Triton wheel
run: |
set -x
mkdir -p "${RUNNER_TEMP}/artifacts/"
container_name=$(docker run \
--tty \
--detach \
-v "${GITHUB_WORKSPACE}:/pytorch" \
-v "${RUNNER_TEMP}/artifacts:/artifacts" \
-w /artifacts/ \
"${DOCKER_IMAGE}" \
)
# Determine python executable for given version
case $PY_VERS in
3.7)
PYTHON_EXECUTABLE=/opt/python/cp37-cp37m/bin/python
;;
3.8)
PYTHON_EXECUTABLE=/opt/python/cp38-cp38/bin/python
;;
3.9)
PYTHON_EXECUTABLE=/opt/python/cp39-cp39/bin/python
;;
3.10)
PYTHON_EXECUTABLE=/opt/python/cp310-cp310/bin/python
;;
3.11)
PYTHON_EXECUTABLE=/opt/python/cp311-cp311/bin/python
;;
*)
echo "Unsupported python version ${PY_VERS}"
exit 1
;;
esac
docker exec -t "${container_name}" yum install -y zlib-devel
docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" /pytorch/.github/scripts/build_triton_wheel.py
docker exec -t "${container_name}" chown -R 1000.1000 /artifacts
- uses: actions/upload-artifact@v3
with:
name: "pytorch-triton-${{ matrix.py_vers }}"
if-no-files-found: error
path:
${{ runner.temp }}/artifacts/*

- name: Teardown Linux
uses: pytorch/test-infra/.github/actions/teardown-linux@main
if: always()
6 changes: 6 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,12 @@ def main():
extras_require = {
'opt-einsum': ['opt-einsum>=3.3']
}
if platform.system() == 'Linux':
triton_pin_file = os.path.join(cwd, ".github", "ci_commit_pins", "triton.txt")
if os.path.exists(triton_pin_file):
with open(triton_pin_file) as f:
triton_pin = f.read().strip()
extras_require['dynamo'] = ['torchtriton==2.0.0+' + triton_pin[:10], 'jinja2']

# Parse the command line and check the arguments before we proceed with
# building deps and setup. We need to set values so `--help` works.
Expand Down

0 comments on commit dfe3fc0

Please sign in to comment.