Skip to content

Commit

Permalink
feat: half float support (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan authored Sep 3, 2022
1 parent 9b68125 commit 39b533a
Show file tree
Hide file tree
Showing 34 changed files with 744 additions and 456 deletions.
3 changes: 3 additions & 0 deletions .clang-format
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
BasedOnStyle: Google
ColumnLimit: 100
BinPackArguments: false
BinPackParameters: false
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ jobs:
runs-on: ubuntu-latest
needs: [build-sdist]
if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
timeout-minutes: 60
timeout-minutes: 90
steps:
- name: Checkout
uses: actions/checkout@v3
Expand Down
8 changes: 6 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,15 @@ jobs:
- name: Upgrade pip
run: |
python -m pip install --upgrade pip setuptools
python -m pip install --upgrade pip setuptools wheel
- name: Install TorchOpt
env:
USE_FP16: "OFF"
TORCH_CUDA_ARCH_LIST: "Auto"
run: |
python -m pip install -vvv -e '.[lint]'
python -m pip install torch numpy pybind11
python -m pip install -vvv --no-build-isolation --editable '.[lint]'
- name: pre-commit
run: |
Expand Down
12 changes: 4 additions & 8 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,17 @@ jobs:
- name: Upgrade pip
run: |
python -m pip install --upgrade pip setuptools
- name: Install PyTorch and FuncTorch nightly
run: |
export PIP_EXTRA_INDEX_URL="${PIP_EXTRA_INDEX_URL//whl/whl\/nightly}"
python -m pip install 'torch >= 1.13.0dev' ninja
python -m pip install git+https://github.com/pytorch/functorch.git
python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
run: |
python -m pip install -r tests/requirements.txt
- name: Install TorchOpt
env:
USE_FP16: "ON"
TORCH_CUDA_ARCH_LIST: "Common"
run: |
export PIP_EXTRA_INDEX_URL="${PIP_EXTRA_INDEX_URL//whl/whl\/nightly}"
python -m pip install -vvv -e .
- name: Test with pytest
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ repos:
- id: isort
stages: [commit, push, manual]
- repo: https://github.com/psf/black
rev: 22.6.0
rev: 22.8.0
hooks:
- id: black
stages: [commit, push, manual]
Expand Down
6 changes: 2 additions & 4 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@ fail-on=
fail-under=10.0

# Files or directories to be skipped. They should be base names, not paths.
ignore=CVS,.vscode,.history,
examples,
tests
ignore=CVS,.vscode,.history

# Add files or directories matching the regex patterns to the ignore-list. The
# regex matches against paths and can be in Posix or Windows format.
ignore-paths=
ignore-paths=^_C/$,^examples/$,^tests/$

# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths. The default value ignores emacs file
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add half float support for accelerated OPs by [@XuehaiPan](https://github.com/XuehaiPan) in [#67](https://github.com/metaopt/torchopt/pull/67).
- Add MAML example with TorchRL integration by [@vmoens](https://github.com/vmoens) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#12](https://github.com/metaopt/TorchOpt/pull/12).
- Add optional argument `params` to update function in gradient transformations by [@XuehaiPan](https://github.com/XuehaiPan) in [#65](https://github.com/metaopt/torchopt/pull/65).
- Add option `weight_decay` option to optimizers by [@XuehaiPan](https://github.com/XuehaiPan) in [#65](https://github.com/metaopt/torchopt/pull/65).
Expand Down
136 changes: 119 additions & 17 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,98 @@ endif()

set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -pthread -fPIC -fopenmp")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")

find_package(CUDA)
find_package(Threads REQUIRED) # -pthread
find_package(OpenMP REQUIRED) # -Xpreprocessor -fopenmp
set(CMAKE_POSITION_INDEPENDENT_CODE ON) # -fPIC

if(MSVC)
string(APPEND CMAKE_CXX_FLAGS " /Wall")
string(APPEND CMAKE_CXX_FLAGS_DEBUG " /Zi")
string(APPEND CMAKE_CXX_FLAGS_RELEASE " /O2 /Ob2")
else()
string(APPEND CMAKE_CXX_FLAGS " -Wall")
string(APPEND CMAKE_CXX_FLAGS_DEBUG " -g -Og")
string(APPEND CMAKE_CXX_FLAGS_RELEASE " -O3")
endif()

if(NOT DEFINED USE_FP16 AND NOT "$ENV{USE_FP16}" STREQUAL "")
set(USE_FP16 "$ENV{USE_FP16}")
endif()

if(NOT DEFINED USE_FP16)
set(USE_FP16 OFF)
message(WARNING "FP16 support disabled, compiling without torch.HalfTensor. Suppress this warning with -DUSE_FP16=ON or -DUSE_FP16=OFF.")
elseif(USE_FP16)
message(STATUS "FP16 support enabled, compiling with torch.HalfTensor.")
else()
message(STATUS "FP16 support disabled, compiling without torch.HalfTensor.")
endif()

if(USE_FP16)
add_definitions(-DUSE_FP16)
endif()

if(CUDA_FOUND)
find_package(CUDA)
if(CUDA_FOUND AND NOT WIN32)
message(STATUS "Found CUDA, enabling CUDA support.")
enable_language(CUDA)
add_definitions(-D__CUDA_ENABLED__)
set(CMAKE_CUDA_STANDARD "${CMAKE_CXX_STANDARD}")
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
add_definitions(-D__USE_CUDA__)

string(APPEND CMAKE_CUDA_FLAGS " $ENV{TORCH_NVCC_FLAGS}")

if(NOT DEFINED TORCH_CUDA_ARCH_LIST AND NOT "$ENV{TORCH_CUDA_ARCH_LIST}" STREQUAL "")
set(TORCH_CUDA_ARCH_LIST "$ENV{TORCH_CUDA_ARCH_LIST}")
endif()

cuda_select_nvcc_arch_flags(CUDA_ARCH_FLAGS All)
if(NOT TORCH_CUDA_ARCH_LIST)
set(TORCH_CUDA_ARCH_LIST "Auto")
message(WARNING "Torch CUDA arch list is not set, setting to \"Auto\". Suppress this warning with -DTORCH_CUDA_ARCH_LIST=Common.")
endif()

set(CMAKE_CUDA_ARCHITECTURES OFF)
cuda_select_nvcc_arch_flags(CUDA_ARCH_FLAGS ${TORCH_CUDA_ARCH_LIST})
message(STATUS "TORCH_CUDA_ARCH_LIST: \"${TORCH_CUDA_ARCH_LIST}\"")
message(STATUS "CUDA_ARCH_FLAGS: \"${CUDA_ARCH_FLAGS}\"")
list(APPEND CUDA_NVCC_FLAGS ${CUDA_ARCH_FLAGS})
set(CMAKE_CUDA_STANDARD 14)
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -O3")
else()

list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda")
if(CUDA_HAS_FP16 OR NOT "${CUDA_VERSION}" VERSION_LESS "7.5")
if (USE_FP16)
message(STATUS "Found CUDA with FP16 support, compiling with torch.cuda.HalfTensor.")
string(APPEND CMAKE_CUDA_FLAGS " -DCUDA_HAS_FP16=1"
" -D__CUDA_NO_HALF_OPERATORS__"
" -D__CUDA_NO_HALF_CONVERSIONS__"
" -D__CUDA_NO_HALF2_OPERATORS__"
" -D__CUDA_NO_BFLOAT16_CONVERSIONS__")
else()
message(STATUS "Found CUDA with FP16 support, but it is suppressed by the compile options, compiling without torch.cuda.HalfTensor.")
endif()
else()
message(STATUS "Could not find CUDA with FP16 support, compiling without torch.cuda.HalfTensor.")
endif()

foreach(FLAG ${CUDA_NVCC_FLAGS})
string(FIND "${FLAG}" " " flag_space_position)
if(NOT flag_space_position EQUAL -1)
message(FATAL_ERROR "Found spaces in CUDA_NVCC_FLAGS entry '${FLAG}'")
endif()
string(APPEND CMAKE_CUDA_FLAGS " ${FLAG}")
endforeach()
string(STRIP "${CMAKE_CUDA_FLAGS}" CMAKE_CUDA_FLAGS)
message(STATUS "CMAKE_CUDA_FLAGS: \"${CMAKE_CUDA_FLAGS}\"")

if(MSVC)
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} /O2 /Ob2")
else()
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -O3")
endif()
elseif(NOT CUDA_FOUND)
message(STATUS "CUDA not found, build for CPU-only.")
else()
message(STATUS "CUDA found, but build for CPU-only on Windows.")
endif()

function(system)
Expand Down Expand Up @@ -76,13 +152,19 @@ function(system)
endfunction()

if(NOT DEFINED PYTHON_EXECUTABLE)
set(PYTHON_EXECUTABLE python3)
if(WIN32)
set(PYTHON_EXECUTABLE "python.exe")
else()
set(PYTHON_EXECUTABLE "python")
endif()
endif()

system(
STRIP OUTPUT_VARIABLE PYTHON_EXECUTABLE
COMMAND bash -c "type -P '${PYTHON_EXECUTABLE}'"
)
if(UNIX)
system(
STRIP OUTPUT_VARIABLE PYTHON_EXECUTABLE
COMMAND bash -c "type -P '${PYTHON_EXECUTABLE}'"
)
endif()

system(
STRIP OUTPUT_VARIABLE PYTHON_VERSION
Expand All @@ -107,6 +189,12 @@ else()
include_directories(${PYTHON_INCLUDE_DIR})
endif()

system(
STRIP OUTPUT_VARIABLE PYTHON_SITE_PACKAGES
COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('sysconfig') .get_path('purelib'))"
)
message(STATUS "Detected Python site packages: \"${PYTHON_SITE_PACKAGES}\"")

set(PYBIND11_PYTHON_VERSION "${PYTHON_VERSION}")

if(NOT DEFINED PYBIND11_CMAKE_DIR)
Expand All @@ -130,10 +218,14 @@ if(NOT DEFINED TORCH_INCLUDE_PATH)
STRIP OUTPUT_VARIABLE TORCH_INCLUDE_PATH
COMMAND "${PYTHON_EXECUTABLE}" -c "print('\\\;'.join(__import__('torch.utils.cpp_extension', fromlist=[None]).include_paths()))"
)

if("${TORCH_INCLUDE_PATH}" STREQUAL "")
set(TORCH_INCLUDE_PATH "${PYTHON_SITE_PACKAGES}/torch/include")
endif()
endif()

if("${TORCH_INCLUDE_PATH}" STREQUAL "")
message(FATAL_ERROR "Torch include directory not found")
message(FATAL_ERROR "Torch include directory not found. Got: \"${TORCH_INCLUDE_PATH}\"")
else()
message(STATUS "Detected Torch include directory: \"${TORCH_INCLUDE_PATH}\"")
include_directories(${TORCH_INCLUDE_PATH})
Expand All @@ -145,18 +237,28 @@ if(NOT DEFINED TORCH_LIBRARY_PATH)
STRIP OUTPUT_VARIABLE TORCH_LIBRARY_PATH
COMMAND "${PYTHON_EXECUTABLE}" -c "print('\\\;'.join(__import__('torch.utils.cpp_extension', fromlist=[None]).library_paths()))"
)

if("${TORCH_LIBRARY_PATH}" STREQUAL "")
set(TORCH_LIBRARY_PATH "${PYTHON_SITE_PACKAGES}/torch/lib")
endif()
endif()

if("${TORCH_LIBRARY_PATH}" STREQUAL "")
message(FATAL_ERROR "Torch library directory not found")
message(FATAL_ERROR "Torch library directory not found. Got: \"${TORCH_LIBRARY_PATH}\"")
else()
message(STATUS "Detected Torch library directory: \"${TORCH_LIBRARY_PATH}\"")
endif()

unset(TORCH_LIBRARIES)

foreach(VAR_PATH ${TORCH_LIBRARY_PATH})
list(APPEND TORCH_LIBRARIES "${VAR_PATH}/libtorch_python.so")
if(WIN32)
file(GLOB TORCH_LIBRARY "${VAR_PATH}/*.lib")
else()
file(GLOB TORCH_LIBRARY "${VAR_PATH}/libtorch_python.*")
endif()

list(APPEND TORCH_LIBRARIES "${TORCH_LIBRARY}")
endforeach()

message(STATUS "Detected Torch libraries: \"${TORCH_LIBRARIES}\"")
Expand Down
1 change: 1 addition & 0 deletions CPPLINT.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
linelength=100
8 changes: 8 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ default: install
install:
$(PYTHON) -m pip install .

install-editable:
$(PYTHON) -m pip install --upgrade pip
$(PYTHON) -m pip install --upgrade setuptools wheel
$(PYTHON) -m pip install torch numpy pybind11
USE_FP16=ON TORCH_CUDA_ARCH_LIST=Auto $(PYTHON) -m pip install -vvv --no-build-isolation --editable .

install-e: install-editable # alias

build:
$(PYTHON) -m pip install --upgrade pip
$(PYTHON) -m pip install --upgrade setuptools wheel build
Expand Down
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -250,16 +250,19 @@ cd torchopt
CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml

conda activate torchopt
pip3 install --no-build-isolation --editable .
make install-editable # or run `pip3 install --no-build-isolation --editable .`
```

--------------------------------------------------------------------------------

## Future Plan

- [x] CPU-accelerated optimizer
- [ ] Support general implicit differentiation with functional programing.
- [ ] Support general implicit differentiation with functional programing
- [ ] Support more optimizers such as AdamW, RMSProp
- [ ] Zero order optimization
- [ ] Distributed optimizers
- [ ] Support `complex` data type

## Changelog

Expand Down
2 changes: 1 addition & 1 deletion docs/source/_static/css/style.css
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
Expand Down
2 changes: 1 addition & 1 deletion docs/source/developer/contributing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ To install TorchOpt in an "editable" mode, run:

.. code-block:: bash
pip3 install --no-build-isolation --editable .
make install-editable # or run `pip3 install --no-build-isolation --editable .`
in the main directory. This installation is removable by:

Expand Down
Loading

0 comments on commit 39b533a

Please sign in to comment.