Skip to content

feat: add Python implementation of accelerated OP #67

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Nov 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ env:
TEST_TORCH_SPECS: "cpu cu116"

jobs:
build-sdist:
build:
name: Build sdist and pure-Python wheel
runs-on: ubuntu-latest
if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
timeout-minutes: 10
timeout-minutes: 60
steps:
- name: Checkout
uses: actions/checkout@v3
Expand All @@ -55,7 +56,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.7 - 3.11" # sync with requires-python in pyproject.toml
python-version: "3.7 - 3.10" # sync with requires-python in pyproject.toml
update-environment: true

- name: Set __release__
Expand All @@ -71,20 +72,34 @@ jobs:
- name: Install dependencies
run: python -m pip install --upgrade pip setuptools wheel build

- name: Build sdist
run: python -m build --sdist
- name: Build sdist and pure-Python wheel
run: python -m build
env:
TORCHOPT_NO_EXTENSIONS: "true"

- name: Upload artifact
uses: actions/upload-artifact@v3
with:
name: sdist
path: dist/*.tar.gz
name: build
path: dist/*
if-no-files-found: error

- name: Install dependencies
run: |
python -m pip install -r tests/requirements.txt

- name: Install TorchOpt
run: |
python -m pip install -vvv dist/*.whl

- name: Test with pytest
run: |
make pytest

build-wheels-py37:
name: Build wheels for Python ${{ matrix.python-version }}
name: Build wheels for Python ${{ matrix.python-version }} on ubuntu-latest
runs-on: ubuntu-latest
needs: [build-sdist]
needs: [build]
if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
strategy:
matrix:
Expand Down Expand Up @@ -132,9 +147,9 @@ jobs:
if-no-files-found: error

build-wheels:
name: Build wheels for Python ${{ matrix.python-version }}
name: Build wheels for Python ${{ matrix.python-version }} on ubuntu-latest
runs-on: ubuntu-latest
needs: [build-sdist, build-wheels-py37]
needs: [build, build-wheels-py37]
if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
strategy:
matrix:
Expand Down Expand Up @@ -183,7 +198,7 @@ jobs:

publish:
runs-on: ubuntu-latest
needs: [build-sdist, build-wheels-py37, build-wheels]
needs: [build, build-wheels-py37, build-wheels]
if: |
github.repository == 'metaopt/torchopt' && github.event_name != 'pull_request' &&
(github.event_name != 'workflow_dispatch' || github.event.inputs.task == 'build-and-publish') &&
Expand Down Expand Up @@ -226,7 +241,7 @@ jobs:
with:
# unpacks default artifact into dist/
# if `name: artifact` is omitted, the action will create extra parent dir
name: sdist
name: build
path: dist

- name: Download built wheels
Expand All @@ -250,7 +265,7 @@ jobs:

- name: Publish to TestPyPI
if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch'
uses: pypa/gh-action-pypi-publish@v1.5.0
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.TESTPYPI_UPLOAD_TOKEN }}
Expand All @@ -261,7 +276,7 @@ jobs:

- name: Publish to PyPI
if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch'
uses: pypa/gh-action-pypi-publish@v1.5.0
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_UPLOAD_TOKEN }}
Expand Down
40 changes: 40 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ concurrency:

jobs:
test:
name: Test with CXX/CUDA extensions on ubuntu-latest
runs-on: ubuntu-latest
timeout-minutes: 60
steps:
Expand Down Expand Up @@ -88,3 +89,42 @@ jobs:
flags: unittests
name: codecov-umbrella
fail_ci_if_error: false

test-pure-python:
name: Test for pure-Python on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
timeout-minutes: 60
strategy:
matrix:
os: [ubuntu-latest, macos-latest] # jaxlib is not available on Windows
fail-fast: false
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: "recursive"
fetch-depth: 1

- name: Set up Python 3.7
uses: actions/setup-python@v4
with:
python-version: "3.7" # the lowest version we support (sync with requires-python in pyproject.toml)
update-environment: true

- name: Upgrade pip
run: |
python -m pip install --upgrade pip setuptools wheel

- name: Install dependencies
run: |
python -m pip install -r tests/requirements.txt

- name: Install TorchOpt
run: |
python -m pip install -vvv -e .
env:
TORCHOPT_NO_EXTENSIONS: "true"

- name: Test with pytest
run: |
make pytest
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
##### Project Specific #####
third-party/

##### Python.gitignore #####
# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down Expand Up @@ -73,7 +76,6 @@ instance/

# Sphinx documentation
docs/_build/
docs/build/
docs/source/_build/

# PyBuilder
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.4.0
hooks:
- id: check-symlinks
- id: destroyed-symlinks
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add Python implementation of accelerated OP and pure-Python wheels by [@XuehaiPan](https://github.com/XuehaiPan) in [#67](https://github.com/metaopt/torchopt/pull/67).
- Add `nan_to_num` hook and gradient transformation by [@XuehaiPan](https://github.com/XuehaiPan) in [#119](https://github.com/metaopt/torchopt/pull/119).
- Add matrix inversion linear solver with neumann series approximation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#98](https://github.com/metaopt/torchopt/pull/98).
- Add if condition of number of threads for CPU OPs by [@JieRen98](https://github.com/JieRen98) in [#105](https://github.com/metaopt/torchopt/pull/105).
Expand Down
38 changes: 30 additions & 8 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
# limitations under the License.
# ==============================================================================

cmake_minimum_required(VERSION 3.8)
cmake_minimum_required(VERSION 3.11) # for FetchContent
project(torchopt LANGUAGES CXX)

include(FetchContent)
set(PYBIND11_VERSION v2.10.1)

if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()
Expand All @@ -26,6 +29,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
find_package(Threads REQUIRED) # -pthread
find_package(OpenMP REQUIRED) # -Xpreprocessor -fopenmp
set(CMAKE_POSITION_INDEPENDENT_CODE ON) # -fPIC
set(CMAKE_CXX_VISIBILITY_PRESET hidden) # -fvisibility=hidden

if(MSVC)
string(APPEND CMAKE_CXX_FLAGS " /Wall")
Expand Down Expand Up @@ -186,7 +190,7 @@ if("${PYTHON_INCLUDE_DIR}" STREQUAL "")
message(FATAL_ERROR "Python include directory not found")
else()
message(STATUS "Detected Python include directory: \"${PYTHON_INCLUDE_DIR}\"")
include_directories(${PYTHON_INCLUDE_DIR})
include_directories("${PYTHON_INCLUDE_DIR}")
endif()

system(
Expand All @@ -195,6 +199,7 @@ system(
)
message(STATUS "Detected Python site packages: \"${PYTHON_SITE_PACKAGES}\"")

# Include pybind11
set(PYBIND11_PYTHON_VERSION "${PYTHON_VERSION}")

if(NOT DEFINED PYBIND11_CMAKE_DIR)
Expand All @@ -206,14 +211,27 @@ if(NOT DEFINED PYBIND11_CMAKE_DIR)
endif()

if("${PYBIND11_CMAKE_DIR}" STREQUAL "")
message(FATAL_ERROR "Pybind11 CMake directory not found")
FetchContent_Declare(
pybind11
GIT_REPOSITORY https://github.com/pybind/pybind11.git
GIT_TAG "${PYBIND11_VERSION}"
GIT_SHALLOW TRUE
SOURCE_DIR "${CMAKE_SOURCE_DIR}/third-party/pybind11"
BINARY_DIR "${CMAKE_SOURCE_DIR}/third-party/.cmake/pybind11/build"
STAMP_DIR "${CMAKE_SOURCE_DIR}/third-party/.cmake/pybind11/stamp"
)
FetchContent_GetProperties(pybind11)
if(NOT pybind11_POPULATED)
message(STATUS "Populating Git repository pybind11@${PYBIND11_VERSION} to third-party/pybind11...")
FetchContent_MakeAvailable(pybind11)
endif()
else()
message(STATUS "Detected Pybind11 CMake directory: \"${PYBIND11_CMAKE_DIR}\"")
find_package(pybind11 CONFIG PATHS "${PYBIND11_CMAKE_DIR}")
endif()

if(NOT DEFINED TORCH_INCLUDE_PATH)
message(STATUS "Auto detecting PyTorch include directory...")
message(STATUS "Auto detecting Torch include directory...")
system(
STRIP OUTPUT_VARIABLE TORCH_INCLUDE_PATH
COMMAND "${PYTHON_EXECUTABLE}" -c "print('\\\;'.join(__import__('torch.utils.cpp_extension', fromlist=[None]).include_paths()))"
Expand All @@ -232,7 +250,7 @@ else()
endif()

if(NOT DEFINED TORCH_LIBRARY_PATH)
message(STATUS "Auto detecting PyTorch library directory...")
message(STATUS "Auto detecting Torch library directory...")
system(
STRIP OUTPUT_VARIABLE TORCH_LIBRARY_PATH
COMMAND "${PYTHON_EXECUTABLE}" -c "print('\\\;'.join(__import__('torch.utils.cpp_extension', fromlist=[None]).library_paths()))"
Expand All @@ -251,19 +269,23 @@ endif()

unset(TORCH_LIBRARIES)

foreach(VAR_PATH ${TORCH_LIBRARY_PATH})
file(GLOB TORCH_LIBRARY "${VAR_PATH}/*")
message(STATUS "Detected Torch libraries: \"${TORCH_LIBRARY}\"")
endforeach()

foreach(VAR_PATH ${TORCH_LIBRARY_PATH})
if(WIN32)
file(GLOB TORCH_LIBRARY "${VAR_PATH}/*.lib")
else()
file(GLOB TORCH_LIBRARY "${VAR_PATH}/libtorch_python.*")
endif()

list(APPEND TORCH_LIBRARIES "${TORCH_LIBRARY}")
endforeach()

message(STATUS "Detected Torch libraries: \"${TORCH_LIBRARIES}\"")
message(STATUS "Detected Torch Python libraries: \"${TORCH_LIBRARIES}\"")

add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)

include_directories(${CMAKE_SOURCE_DIR})
include_directories("${CMAKE_SOURCE_DIR}")
add_subdirectory(src)
6 changes: 3 additions & 3 deletions conda-recipe-minimal.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,16 @@ dependencies:
- nvidia/label/cuda-11.7.1::cuda-toolkit = 11.7

# Build toolchain
- cmake >= 3.4
- cmake >= 3.11
- make
- cxx-compiler
- gxx = 10
- nvidia/label/cuda-11.7.1::cuda-nvcc
- nvidia/label/cuda-11.7.1::cuda-cudart-dev
- pybind11
- pybind11 >= 2.10.1

# Misc
- optree >= 0.4.0
- optree >= 0.4.1
- typing-extensions >= 4.0.0
- numpy
- python-graphviz
6 changes: 3 additions & 3 deletions conda-recipe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,17 @@ dependencies:
- nvidia/label/cuda-11.7.1::cuda-toolkit = 11.7

# Build toolchain
- cmake >= 3.4
- cmake >= 3.11
- make
- cxx-compiler
- gxx = 10
- nvidia/label/cuda-11.7.1::cuda-nvcc
- nvidia/label/cuda-11.7.1::cuda-cudart-dev
- patchelf >= 0.14
- pybind11
- pybind11 >= 2.10.1

# Misc
- optree >= 0.4.0
- optree >= 0.4.1
- typing-extensions >= 4.0.0
- numpy
- matplotlib-base
Expand Down
6 changes: 3 additions & 3 deletions docs/conda-recipe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,16 @@ dependencies:
- sphinxcontrib-katex # for documentation

# Build toolchain
- cmake >= 3.4
- cmake >= 3.11
- make
- cxx-compiler
- gxx = 10
- nvidia/label/cuda-11.7.1::cuda-nvcc
- nvidia/label/cuda-11.7.1::cuda-cudart-dev
- pybind11
- pybind11 >= 2.10.1

# Misc
- optree >= 0.4.0
- optree >= 0.4.1
- typing-extensions >= 4.0.0
- numpy
- matplotlib-base
Expand Down
Loading