Skip to content

Commit

Permalink
refactor: reorganize code and add full type hint (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan authored Sep 25, 2022
1 parent bc7a1c5 commit af6d24c
Show file tree
Hide file tree
Showing 55 changed files with 1,227 additions and 564 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.7 - 3.10"
python-version: "3.7 - 3.10" # sync with requires-python in pyproject.toml
update-environment: true

- name: Set __release__
Expand Down Expand Up @@ -95,7 +95,7 @@ jobs:
id: py
uses: actions/setup-python@v4
with:
python-version: "3.7 - 3.10"
python-version: "3.7 - 3.10" # sync with requires-python in pyproject.toml
update-environment: true

- name: Set __release__
Expand All @@ -107,7 +107,7 @@ jobs:
run: python setup.py --version

- name: Build wheels
uses: pypa/cibuildwheel@v2.10.0
uses: pypa/cibuildwheel@v2.10.2
with:
package-dir: .
output-dir: wheelhouse
Expand Down Expand Up @@ -138,7 +138,7 @@ jobs:
uses: actions/setup-python@v4
if: startsWith(github.ref, 'refs/tags/')
with:
python-version: "3.7 - 3.10"
python-version: "3.7 - 3.10" # sync with requires-python in pyproject.toml
update-environment: true

- name: Check consistency between the package version and release tag
Expand Down
9 changes: 7 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ jobs:
submodules: "recursive"
fetch-depth: 1

- name: Set up Python 3.7 # the lowest version we support
- name: Set up Python 3.7
uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.7" # the lowest version we support (sync with requires-python in pyproject.toml)
update-environment: true

- name: Setup CUDA Toolkit
Expand Down Expand Up @@ -90,6 +90,11 @@ jobs:
run: |
make addlicense
- name: Install dev version of mypy
run: |
python -m pip install git+https://github.com/python/mypy.git
python -m pip install types-setuptools
- name: mypy
run: |
make mypy
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ jobs:
submodules: "recursive"
fetch-depth: 1

- name: Set up Python 3.7 # the lowest version we support
- name: Set up Python 3.7
uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.7" # the lowest version we support (sync with requires-python in pyproject.toml)
update-environment: true

- name: Setup CUDA Toolkit
Expand Down
5 changes: 1 addition & 4 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
##### Project specific #####
!torchopt/_src/
!torchopt/_lib/

##### Python.gitignore #####
# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down Expand Up @@ -31,6 +27,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
*.whl

# PyInstaller
# Usually these files are written by a python script from a template
Expand Down
5 changes: 3 additions & 2 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ persistent=yes

# Minimum Python version to use for version dependent checks. Will default to
# the version used to run pylint.
py-version=3.7
py-version=3.7 # the lowest version we support (sync with requires-python in pyproject.toml)

# Discover python modules and packages in the file system subtree.
recursive=no
Expand Down Expand Up @@ -266,7 +266,8 @@ good-names=i,
t,
lr,
mu,
nu
nu,
x

# Good variable names regexes, separated by a comma. If names match any regex,
# they will always be accepted
Expand Down
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add full type hints by [@XuehaiPan](https://github.com/XuehaiPan) in [#92](https://github.com/metaopt/torchopt/pull/92).
- Add API documentation and tutorial for implicit gradients by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@JieRen98](https://github.com/JieRen98) and [@XuehaiPan](https://github.com/XuehaiPan) in [#73](https://github.com/metaopt/torchopt/pull/73).
- Add wrapper class for functional optimizers and examples of `functorch` integration by [@vmoens](https://github.com/vmoens) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#6](https://github.com/metaopt/torchopt/pull/6).
- Implicit differentiation support by [@JieRen98](https://github.com/JieRen98) and [@waterhorse1](https://github.com/waterhorse1) and [@XuehaiPan](https://github.com/XuehaiPan) in [#41](https://github.com/metaopt/torchopt/pull/41).

### Changed


- Refactor code organization by [@XuehaiPan](https://github.com/XuehaiPan) in [#92](https://github.com/metaopt/torchopt/pull/92).

### Fixed

Expand Down
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ if(NOT DEFINED PYTHON_INCLUDE_DIR)
message(STATUS "Auto detecting Python include directory...")
system(
STRIP OUTPUT_VARIABLE PYTHON_INCLUDE_DIR
COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('sysconfig').get_path('include'))"
COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('sysconfig').get_path('platinclude'))"
)
endif()

Expand All @@ -191,7 +191,7 @@ endif()

system(
STRIP OUTPUT_VARIABLE PYTHON_SITE_PACKAGES
COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('sysconfig') .get_path('purelib'))"
COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('sysconfig').get_path('purelib'))"
)
message(STATUS "Detected Python site packages: \"${PYTHON_SITE_PACKAGES}\"")

Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
recursive-include torchopt *.pyi
recursive-include torchopt *.typed
include LICENSE

# Include source files in sdist
Expand Down
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ flake8: flake8-install
$(PYTHON) -m flake8 $(PYTHON_FILES) --count --select=E9,F63,F7,F82,E225,E251 --show-source --statistics

py-format: py-format-install
$(PYTHON) -m isort --project torchopt --check $(PYTHON_FILES) && \
$(PYTHON) -m isort --project $(PROJECT_NAME) --check $(PYTHON_FILES) && \
$(PYTHON) -m black --check $(PYTHON_FILES)

mypy: mypy-install
Expand Down Expand Up @@ -143,10 +143,10 @@ clean-docs:

# Utility functions

lint: flake8 py-format mypy clang-format cpplint docstyle spelling
lint: flake8 py-format mypy pylint clang-format cpplint docstyle spelling

format: py-format-install clang-format-install addlicense-install
$(PYTHON) -m isort --project torchopt $(PYTHON_FILES)
$(PYTHON) -m isort --project $(PROJECT_NAME) $(PYTHON_FILES)
$(PYTHON) -m black $(PYTHON_FILES)
clang-format -style=file -i $(CXX_FILES)
addlicense -c $(COPYRIGHT) -l apache -y 2022 $(SOURCE_FOLDERS)
Expand Down
6 changes: 3 additions & 3 deletions conda-recipe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ dependencies:
- pip

# Learning
- pytorch::pytorch >= 1.12
- pytorch::pytorch >= 1.12 # sync with project.dependencies
- pytorch::torchvision
- pytorch::pytorch-mutex = *=*cuda*
- pip:
- functorch >= 0.2
- functorch >= 0.2 # sync with project.dependencies
- torchviz
- sphinxcontrib-katex # for documentation
- jax # for tutorials
Expand Down Expand Up @@ -64,7 +64,7 @@ dependencies:
- myst-nb
- ipykernel
- pandoc
- docutils = 0.16
- docutils

# Testing
- pytest
Expand Down
6 changes: 3 additions & 3 deletions docs/conda-recipe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ dependencies:
- pip

# Learning
- pytorch::pytorch >= 1.12
- pytorch::pytorch >= 1.12 # sync with project.dependencies
- pytorch::pytorch-mutex = *=*cpu*
- pip:
- functorch >= 0.2
- functorch >= 0.2 # sync with project.dependencies
- torchviz
- sphinxcontrib-katex # for documentation
- tensorboard
Expand Down Expand Up @@ -68,4 +68,4 @@ dependencies:
- myst-nb
- ipykernel
- pandoc
- docutils = 0.16
- docutils
5 changes: 3 additions & 2 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
--extra-index-url https://download.pytorch.org/whl/cpu
# Sync with project.dependencies
torch >= 1.12
functorch >= 0.2

--requirement ../requirements.txt

sphinx >= 5.0
sphinx >= 5.0, < 5.2.0a0
sphinx-autoapi
sphinx-autobuild
sphinx-copybutton
Expand All @@ -16,5 +17,5 @@ IPython
ipykernel
pandoc
myst_nb
docutils == 0.16
docutils
matplotlib
14 changes: 7 additions & 7 deletions docs/source/api/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ Differentiable Meta-RMSProp Optimizer
Implicit differentiation
========================

.. currentmodule:: torchopt._src.implicit_diff
.. currentmodule:: torchopt.implicit_diff

.. autosummary::

Expand All @@ -150,7 +150,7 @@ Custom solvers
Linear system solving
=====================

.. currentmodule:: torchopt._src.linear_solve
.. currentmodule:: torchopt.linear_solve

.. autosummary::

Expand All @@ -168,7 +168,7 @@ Indirect solvers
Optimizer Hooks
===============

.. currentmodule:: torchopt._src.hook
.. currentmodule:: torchopt.hook

.. autosummary::

Expand All @@ -186,7 +186,7 @@ Hook
Gradient Transformation
=======================

.. currentmodule:: torchopt._src.clip
.. currentmodule:: torchopt.clip

.. autosummary::

Expand All @@ -200,7 +200,7 @@ Transforms
Optimizer Schedules
===================

.. currentmodule:: torchopt._src.schedule
.. currentmodule:: torchopt.schedule

.. autosummary::

Expand Down Expand Up @@ -231,7 +231,7 @@ Apply Updates
Combining Optimizers
====================

.. currentmodule:: torchopt._src.combine
.. currentmodule:: torchopt.combine

.. autosummary::

Expand Down Expand Up @@ -273,7 +273,7 @@ Stop Gradient
Visualizing Gradient Flow
=========================

.. currentmodule:: torchopt._src.visual
.. currentmodule:: torchopt.visual

.. autosummary::

Expand Down
1 change: 1 addition & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,4 @@ optimality
argnums
matvec
Hermitian
deepcopy
1 change: 1 addition & 0 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ seaborn
torchviz
torchrl
pillow
setproctitle
24 changes: 16 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# Package ######################################################################

[build-system]
# Sync with project.dependencies
requires = ["setuptools", "torch >= 1.12", "numpy", "pybind11"]
build-backend = "setuptools.build_meta"

[project]
name = "torchopt"
description = "A Jax-style optimizer for PyTorch."
readme = "README.md"
requires-python = ">= 3.7"
# Change this if wheels for `torch` is available
# Search "requires-python" and update all corresponding items
requires-python = ">= 3.7, < 3.11.0a0"
authors = [
{ name = "TorchOpt Contributors" },
{ name = "Jie Ren", email = "jieren9806@gmail.com" },
Expand All @@ -29,6 +32,7 @@ keywords = [
classifiers = [
"Development Status :: 4 - Beta",
"License :: OSI Approved :: Apache Software License",
# Sync with requires-python
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
Expand All @@ -44,9 +48,9 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
"torch >= 1.12",
"torch >= 1.12", # see also build-system.requires and project.requires-python
"functorch >= 0.2",
"optree",
"optree >= 0.2.0",
"numpy",
"graphviz",
"typing-extensions",
Expand Down Expand Up @@ -88,7 +92,7 @@ include = ["torchopt", "torchopt.*"]
# Wheel builder ################################################################
# Reference: https://cibuildwheel.readthedocs.io
[tool.cibuildwheel]
archs = ["x86_64"]
archs = ["auto64"]
build = "*manylinux*"
skip = "pp*"
build-frontend = "pip"
Expand Down Expand Up @@ -163,6 +167,7 @@ repair-wheel-command = """
safe = true
line-length = 100
skip-string-normalization = true
# Sync with requires-python
target-version = ["py37", "py38", "py39", "py310"]

[tool.isort]
Expand All @@ -174,16 +179,19 @@ lines_after_imports = 2
multi_line_output = 3

[tool.mypy]
# Sync with requires-python
python_version = 3.7
pretty = true
show_error_codes = true
show_error_context = true
show_traceback = true
allow_redefinition = true
enable_recursive_aliases = true
check_untyped_defs = true
disallow_incomplete_defs = false
disallow_untyped_defs = false
ignore_missing_imports = true
no_implicit_optional = true
pretty = true
show_error_codes = true
show_error_context = true
show_traceback = true
strict_equality = true
strict_optional = true
warn_no_return = true
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Sync with project.dependencies
torch >= 1.12
functorch >= 0.2
optree
optree >= 0.2.0
numpy
graphviz
typing-extensions
Loading

0 comments on commit af6d24c

Please sign in to comment.