Skip to content

Commit

Permalink
deps: bump optree to 0.3.0 (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan authored Oct 26, 2022
1 parent aaec9a9 commit 73983d5
Show file tree
Hide file tree
Showing 32 changed files with 232 additions and 142 deletions.
83 changes: 74 additions & 9 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,16 @@ jobs:
path: dist/*.tar.gz
if-no-files-found: error

build-wheels:
build-wheels-py37:
name: Build wheels for Python ${{ matrix.python-version }}
runs-on: ubuntu-latest
needs: [build-sdist]
if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
timeout-minutes: 90
strategy:
matrix:
python-version: ["3.7"] # sync with requires-python in pyproject.toml
fail-fast: false
timeout-minutes: 30
steps:
- name: Checkout
uses: actions/checkout@v3
Expand All @@ -94,24 +99,77 @@ jobs:
fetch-depth: 0

- name: Set up Python
id: py
uses: actions/setup-python@v4
with:
python-version: "3.7 - 3.10" # sync with requires-python in pyproject.toml
python-version: ${{ matrix.python-version }}
update-environment: true

- name: Set __release__
if: |
startsWith(github.ref, 'refs/tags/') ||
(github.event_name == 'workflow_dispatch' && github.event.inputs.task == 'build-and-publish')
run: |
python .github/workflows/set_release.py
run: python .github/workflows/set_release.py

- name: Print version
run: python setup.py --version

- name: Set CIBW_BUILD
run: python .github/workflows/set_cibw_build.py

- name: Build wheels
uses: pypa/cibuildwheel@v2.11.1
env:
CIBW_BUILD: ${{ env.CIBW_BUILD }}
with:
package-dir: .
output-dir: wheelhouse
config-file: "{package}/pyproject.toml"

- uses: actions/upload-artifact@v3
with:
name: wheels-py37
path: wheelhouse/*.whl
if-no-files-found: error

build-wheels:
name: Build wheels for Python ${{ matrix.python-version }}
runs-on: ubuntu-latest
needs: [build-sdist, build-wheels-py37]
if: github.repository == 'metaopt/torchopt' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"] # sync with requires-python in pyproject.toml
fail-fast: false
timeout-minutes: 30
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: "recursive"
fetch-depth: 0

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
update-environment: true

- name: Set __release__
if: |
startsWith(github.ref, 'refs/tags/') ||
(github.event_name == 'workflow_dispatch' && github.event.inputs.task == 'build-and-publish')
run: python .github/workflows/set_release.py

- name: Print version
run: python setup.py --version

- name: Set CIBW_BUILD
run: python .github/workflows/set_cibw_build.py

- name: Build wheels
uses: pypa/cibuildwheel@v2.11.1
env:
CIBW_BUILD: ${{ env.CIBW_BUILD }}
with:
package-dir: .
output-dir: wheelhouse
Expand All @@ -125,7 +183,7 @@ jobs:

publish:
runs-on: ubuntu-latest
needs: [build-sdist, build-wheels]
needs: [build-sdist, build-wheels-py37, build-wheels]
if: |
github.repository == 'metaopt/torchopt' && github.event_name != 'pull_request' &&
(github.event_name != 'workflow_dispatch' || github.event.inputs.task == 'build-and-publish') &&
Expand Down Expand Up @@ -171,6 +229,14 @@ jobs:
name: sdist
path: dist

- name: Download built wheels
uses: actions/download-artifact@v3
with:
# unpacks default artifact into dist/
# if `name: artifact` is omitted, the action will create extra parent dir
name: wheels-py37
path: dist

- name: Download built wheels
uses: actions/download-artifact@v3
with:
Expand All @@ -180,8 +246,7 @@ jobs:
path: dist

- name: List distributions
run:
ls -lh dist/*
run: ls -lh dist/*

- name: Publish to TestPyPI
if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch'
Expand Down
10 changes: 10 additions & 0 deletions .github/workflows/set_cibw_build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import os
import sys


# pylint: disable-next=consider-using-f-string
CIBW_BUILD = 'CIBW_BUILD=*cp%d%d-*manylinux*' % sys.version_info[:2]

print(CIBW_BUILD)
with open(os.getenv('GITHUB_ENV'), mode='at', encoding='UTF-8') as file:
print(CIBW_BUILD, file=file)
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ jobs:
make pytest
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1
uses: codecov/codecov-action@v3
with:
token: ${{ secrets.CODECOV }}
token: ${{ secrets.CODECOV_TOKEN }}
file: ./tests/coverage.xml
flags: unittests
name: codecov-umbrella
Expand Down
12 changes: 9 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ CXX_FILES = $(shell find $(SOURCE_FOLDERS) -type f -name "*.h" -o -name "*.
COMMIT_HASH = $(shell git log -1 --format=%h)
PATH := $(HOME)/go/bin:$(PATH)
PYTHON ?= $(shell command -v python3 || command -v python)
CLANG_FORMAT ?= $(shell command -v clang-format-14 || command -v clang-format)

.PHONY: default
default: install
Expand All @@ -24,6 +25,9 @@ install-editable:

install-e: install-editable # alias

uninstall:
$(PYTHON) -m pip uninstall -y $(PROJECT_NAME)

build:
$(PYTHON) -m pip install --upgrade pip
$(PYTHON) -m pip install --upgrade setuptools wheel build
Expand Down Expand Up @@ -79,7 +83,9 @@ cpplint-install:
$(call check_pip_install,cpplint)

clang-format-install:
command -v clang-format || sudo apt-get install -y clang-format
command -v clang-format-14 || command -v clang-format || \
sudo apt-get install -y clang-format-14 || \
sudo apt-get install -y clang-format

clang-tidy-install:
command -v clang-tidy || sudo apt-get install -y clang-tidy
Expand Down Expand Up @@ -125,7 +131,7 @@ cpplint: cpplint-install
$(PYTHON) -m cpplint $(CXX_FILES)

clang-format: clang-format-install
clang-format --style=file -i $(CXX_FILES) -n --Werror
$(CLANG_FORMAT) --style=file -i $(CXX_FILES) -n --Werror

# Documentation

Expand Down Expand Up @@ -153,7 +159,7 @@ lint: flake8 py-format mypy pylint clang-format cpplint docstyle spelling
format: py-format-install clang-format-install addlicense-install
$(PYTHON) -m isort --project $(PROJECT_NAME) $(PYTHON_FILES)
$(PYTHON) -m black $(PYTHON_FILES) tutorials
clang-format -style=file -i $(CXX_FILES)
$(CLANG_FORMAT) -style=file -i $(CXX_FILES)
addlicense -c $(COPYRIGHT) -l apache -y 2022 $(SOURCE_FOLDERS)

clean-py:
Expand Down
2 changes: 1 addition & 1 deletion docs/source/developer/contributing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ in the main directory. This installation is removable by:

.. code-block:: bash
pip3 uninstall torchopt
make uninstall
Lint Check
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ classifiers = [
dependencies = [
"torch >= 1.12", # see also build-system.requires and project.requires-python
"functorch >= 0.2",
"optree >= 0.2.0",
"optree >= 0.3.0",
"numpy",
"graphviz",
"typing-extensions",
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Sync with project.dependencies
torch >= 1.12
functorch >= 0.2
optree >= 0.2.0
optree >= 0.3.0
numpy
graphviz
typing-extensions
27 changes: 24 additions & 3 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def parametrize(**argvalues) -> pytest.mark.parametrize:
argvalues = list(itertools.product(*tuple(map(argvalues.get, arguments))))
first_product = argvalues[0]
argvalues.extend((dtype,) + first_product[1:] for dtype in dtypes[1:])
else:
argvalues = list(itertools.product(*tuple(map(argvalues.get, arguments))))

ids = tuple(
'-'.join(f'{arg}({val})' for arg, val in zip(arguments, values)) for values in argvalues
Expand All @@ -77,14 +79,33 @@ def seed_everything(seed: int) -> None:
pass


class MyLinear(nn.Module):
def __init__(
self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None
) -> None:
super().__init__()
self.linear = nn.Linear(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
self.unused_module = nn.Linear(1, 1, bias=False)
self.unused_parameter = nn.Parameter(torch.zeros(1, 1), requires_grad=True)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)


@torch.no_grad()
def get_models(
device: Optional[Union[str, torch.device]] = None, dtype: torch.dtype = torch.float32
) -> Tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]:
seed_everything(seed=42)

model_base = nn.Sequential(
nn.Linear(
MyLinear(
in_features=MODEL_NUM_INPUTS,
out_features=MODEL_HIDDEN_SIZE,
bias=True,
Expand Down Expand Up @@ -178,8 +199,8 @@ def assert_all_close(
from torch.testing._comparison import get_tolerances

rtol, atol = get_tolerances(actual, expected, rtol=rtol, atol=atol)
rtol *= 4 * NUM_UPDATES
atol *= 4 * NUM_UPDATES
rtol *= 5 * NUM_UPDATES
atol *= 5 * NUM_UPDATES

torch.testing.assert_close(
actual,
Expand Down
12 changes: 6 additions & 6 deletions tests/test_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_sgd(
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

grads = torch.autograd.grad(loss, params)
grads = torch.autograd.grad(loss, params, allow_unused=True)
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

Expand Down Expand Up @@ -134,7 +134,7 @@ def test_adam(
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

grads = torch.autograd.grad(loss, params)
grads = torch.autograd.grad(loss, params, allow_unused=True)
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

Expand Down Expand Up @@ -192,7 +192,7 @@ def test_adamw(
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

grads = torch.autograd.grad(loss, params)
grads = torch.autograd.grad(loss, params, allow_unused=True)
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

Expand Down Expand Up @@ -251,7 +251,7 @@ def test_adam_accelerated_cpu(
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

grads = torch.autograd.grad(loss, params)
grads = torch.autograd.grad(loss, params, allow_unused=True)
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

Expand Down Expand Up @@ -313,7 +313,7 @@ def test_adam_accelerated_cuda(
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

grads = torch.autograd.grad(loss, params)
grads = torch.autograd.grad(loss, params, allow_unused=True)
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

Expand Down Expand Up @@ -374,7 +374,7 @@ def test_rmsprop(
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

grads = torch.autograd.grad(loss, params)
grads = torch.autograd.grad(loss, params, allow_unused=True)
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_lr_linear_schedule(
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

grads = torch.autograd.grad(loss, params)
grads = torch.autograd.grad(loss, params, allow_unused=True)
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

Expand Down
16 changes: 7 additions & 9 deletions torchopt/alias/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,14 @@ def update_fn(updates, state, *, params=None, inplace=True):
if inplace:

def f(g, p):
if g is not None:
if g.requires_grad:
return g.add_(p, alpha=weight_decay)
return g.add_(p.data, alpha=weight_decay)
return None
if g.requires_grad:
return g.add_(p, alpha=weight_decay)
return g.add_(p.data, alpha=weight_decay)

else:

def f(g, p):
return g.add(p, alpha=weight_decay) if g is not None else None
return g.add(p, alpha=weight_decay)

updates = tree_map_flat(f, updates, params)
return updates, state
Expand All @@ -66,12 +64,12 @@ def update_fn(updates, state, *, params=None, inplace=True):
if inplace:

def f(g):
return g.neg_() if g is not None else None
return g.neg_()

else:

def f(g):
return g.neg() if g is not None else None
return g.neg()

updates = tree_map_flat(f, updates)
return updates, state
Expand All @@ -96,7 +94,7 @@ def f(g, p):
else:

def f(g, p):
return g.neg().add_(p, alpha=weight_decay) if g is not None else None
return g.neg().add_(p, alpha=weight_decay)

updates = tree_map_flat(f, updates, params)
return updates, state
Expand Down
Loading

0 comments on commit 73983d5

Please sign in to comment.