Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: align argument names with PyTorch #65

Merged
merged 33 commits into from
Aug 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
5e4676e
refactor: refactor optimizer bases
XuehaiPan Aug 25, 2022
20d5377
refactor: align Adam options with PyTorch
XuehaiPan Aug 25, 2022
bbc3bdd
refactor: align RMSProp options with PyTorch
XuehaiPan Aug 25, 2022
a45111a
refactor: align SGD options with PyTorch
XuehaiPan Aug 25, 2022
0c30794
feat(alias): check value range
XuehaiPan Aug 25, 2022
4a985ea
feat: add `params` to `update_fn`'s signature
XuehaiPan Aug 25, 2022
a3cb948
feat: add weight decay
XuehaiPan Aug 26, 2022
cbf1a52
test: add weight decay tests
XuehaiPan Aug 26, 2022
a0be5d9
lint: pass lint
XuehaiPan Aug 26, 2022
71a4a63
docs: update docstring
XuehaiPan Aug 26, 2022
2ab1da0
chore: update type hints
XuehaiPan Aug 26, 2022
a7d7643
fix: fix grad tracing for weight decay
XuehaiPan Aug 26, 2022
63d77cc
test: reorganize tests
XuehaiPan Aug 26, 2022
b31e245
chore: add RMSprop aliases for PyTorch compatibility
XuehaiPan Aug 26, 2022
a929c51
test: add module buffers
XuehaiPan Aug 26, 2022
f082d6c
test: update test parameters
XuehaiPan Aug 26, 2022
1c8421d
chore: update .gitignore
XuehaiPan Aug 26, 2022
2723c01
test: update test parameters
XuehaiPan Aug 27, 2022
1ec3d0f
refactor: refactor transform
XuehaiPan Aug 27, 2022
e8bd609
refactor: chain
XuehaiPan Aug 27, 2022
353b628
refactor: identity
XuehaiPan Aug 27, 2022
bee81bb
feat: add with_flattened_tree
XuehaiPan Aug 27, 2022
a8b6dc0
test: update test parameters
XuehaiPan Aug 27, 2022
7d1d20a
feat: add dampening
XuehaiPan Aug 27, 2022
58dcc56
docs: update docstring
XuehaiPan Aug 28, 2022
acde3fd
lint: fix mypy
XuehaiPan Aug 28, 2022
b1e1521
fix: fix grad tracing for initial value
XuehaiPan Aug 29, 2022
6331cdc
test: update test parameters
XuehaiPan Aug 29, 2022
1b39ad5
docs: update docstrings
XuehaiPan Aug 29, 2022
9976f96
chore: rename variables
XuehaiPan Aug 29, 2022
4c726f2
test: update test parameters
XuehaiPan Aug 29, 2022
3eee243
docs(CHANGELOG): update CHANGELOG.md
XuehaiPan Aug 29, 2022
1fded5d
test: test with pre-release of PyTorch
XuehaiPan Aug 29, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ jobs:
- run: |
CUDA_VERSION="${{steps.cuda-toolkit.outputs.cuda}}"
echo "CUDA_VERSION=${CUDA_VERSION}" >> "${GITHUB_ENV}"
TORCH_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_VERSION}" | cut -d'.' -f-2 | tr -d '.')"
echo "TORCH_INDEX_URL=${TORCH_INDEX_URL}" >> "${GITHUB_ENV}"
PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_VERSION}" | cut -d'.' -f-2 | tr -d '.')"
echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" >> "${GITHUB_ENV}"

echo "Installed CUDA version is: ${CUDA_VERSION}"
echo "CUDA install location: ${{steps.cuda-toolkit.outputs.CUDA_PATH}}"
nvcc -V
echo "Torch index URL: ${TORCH_INDEX_URL}"
echo "Torch index URL: ${PIP_EXTRA_INDEX_URL}"

- name: Upgrade pip
run: |
Expand Down Expand Up @@ -92,8 +92,7 @@ jobs:

- name: Install dependencies
run: |
python -m pip install --extra-index-url "${TORCH_INDEX_URL}" \
-r docs/requirements.txt
python -m pip install -r docs/requirements.txt

- name: docstyle
run: |
Expand Down
16 changes: 11 additions & 5 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,31 @@ jobs:
- run: |
CUDA_VERSION="${{steps.cuda-toolkit.outputs.cuda}}"
echo "CUDA_VERSION=${CUDA_VERSION}" >> "${GITHUB_ENV}"
TORCH_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_VERSION}" | cut -d'.' -f-2 | tr -d '.')"
echo "TORCH_INDEX_URL=${TORCH_INDEX_URL}" >> "${GITHUB_ENV}"
PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_VERSION}" | cut -d'.' -f-2 | tr -d '.')"
echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" >> "${GITHUB_ENV}"

echo "Installed CUDA version is: ${CUDA_VERSION}"
echo "CUDA install location: ${{steps.cuda-toolkit.outputs.CUDA_PATH}}"
nvcc -V
echo "Torch index URL: ${TORCH_INDEX_URL}"
echo "Torch index URL: ${PIP_EXTRA_INDEX_URL}"

- name: Upgrade pip
run: |
python -m pip install --upgrade pip setuptools

- name: Install PyTorch and FuncTorch nightly
run: |
export PIP_EXTRA_INDEX_URL="${PIP_EXTRA_INDEX_URL//whl/whl\/nightly}"
python -m pip install 'torch >= 1.13.0dev' ninja
python -m pip install git+https://github.com/pytorch/functorch.git

- name: Install dependencies
run: |
python -m pip install --extra-index-url "${TORCH_INDEX_URL}" \
-r tests/requirements.txt
python -m pip install -r tests/requirements.txt

- name: Install TorchOpt
run: |
export PIP_EXTRA_INDEX_URL="${PIP_EXTRA_INDEX_URL//whl/whl\/nightly}"
python -m pip install -vvv -e .

- name: Test with pytest
Expand Down
10 changes: 5 additions & 5 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -394,12 +394,12 @@ fabric.properties

##### Vim.gitignore #####
# Swap
[._]*.s[a-v][a-z]
.*.s[a-v][a-z]
!*.svg # comment out if you don't need vector files
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]
.*.sw[a-p]
.s[a-rt-v][a-z]
.ss[a-gi-z]
.sw[a-p]

# Session
Session.vim
Expand Down
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add optional argument `params` to update function in gradient transformations by [@XuehaiPan](https://github.com/XuehaiPan) in [#65](https://github.com/metaopt/torchopt/pull/65).
- Add option `weight_decay` option to optimizers by [@XuehaiPan](https://github.com/XuehaiPan) in [#65](https://github.com/metaopt/torchopt/pull/65).
- Add option `maximize` option to optimizers by [@XuehaiPan](https://github.com/XuehaiPan) in [#64](https://github.com/metaopt/torchopt/pull/64).
- Refactor tests using `pytest.mark.parametrize` and enabling parallel testing by [@XuehaiPan](https://github.com/XuehaiPan) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#55](https://github.com/metaopt/torchopt/pull/55).
- Add maml-omniglot few-shot classification example using functorch.vmap by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#39](https://github.com/metaopt/torchopt/pull/39).
Expand All @@ -21,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Align argument names with PyTorch by [@XuehaiPan](https://github.com/XuehaiPan) in [#65](https://github.com/metaopt/torchopt/pull/65).
- Replace JAX PyTrees with OpTree by [@XuehaiPan](https://github.com/XuehaiPan) in [#62](https://github.com/metaopt/torchopt/pull/62).
- Update image link in README to support PyPI rendering by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#56](https://github.com/metaopt/torchopt/pull/56).

Expand Down
9 changes: 4 additions & 5 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ RUN echo "export PS1='[\[\e[1;33m\]\u\[\e[0m\]:\[\e[1;35m\]\w\[\e[0m\]]\$ '" >>

# Setup virtual environment
RUN /usr/bin/python3.9 -m venv --upgrade-deps ~/venv && rm -rf ~/.pip/cache
RUN TORCH_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_VERSION}" | cut -d'.' -f-2 | tr -d '.')" && \
echo "export TORCH_INDEX_URL='${TORCH_INDEX_URL}'" >> ~/venv/bin/activate && \
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_VERSION}" | cut -d'.' -f-2 | tr -d '.')" && \
echo "export PIP_EXTRA_INDEX_URL='${PIP_EXTRA_INDEX_URL}'" >> ~/venv/bin/activate && \
echo "source /home/torchopt/venv/bin/activate" >> ~/.bashrc

# Install dependencies
WORKDIR /home/torchopt/torchopt
COPY --chown=torchopt requirements.txt requirements.txt
RUN source ~/venv/bin/activate && \
python -m pip install --extra-index-url "${TORCH_INDEX_URL}" -r requirements.txt && \
python -m pip install -r requirements.txt && \
rm -rf ~/.pip/cache ~/.cache/pip

####################################################################################################
Expand All @@ -63,8 +63,7 @@ RUN go install github.com/google/addlicense@latest
COPY --chown=torchopt tests/requirements.txt tests/requirements.txt
COPY --chown=torchopt tutorials/requirements.txt tutorials/requirements.txt
RUN source ~/venv/bin/activate && \
python -m pip install --extra-index-url "${TORCH_INDEX_URL}" \
-r tests/requirements.txt -r tutorials/requirements.txt && \
python -m pip install -r tests/requirements.txt -r tutorials/requirements.txt && \
rm -rf ~/.pip/cache ~/.cache/pip

####################################################################################################
Expand Down
23 changes: 20 additions & 3 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from torch.utils import data


BATCH_SIZE = 4
NUM_UPDATES = 3
BATCH_SIZE = 64
NUM_UPDATES = 5

MODEL_NUM_INPUTS = 28 * 28 # MNIST
MODEL_NUM_CLASSES = 10
Expand Down Expand Up @@ -82,13 +82,23 @@ def get_models(
bias=True,
dtype=dtype,
),
nn.BatchNorm1d(
num_features=MODEL_HIDDEN_SIZE,
track_running_stats=True,
dtype=dtype,
),
nn.ReLU(),
nn.Linear(
in_features=MODEL_HIDDEN_SIZE,
out_features=MODEL_HIDDEN_SIZE,
bias=True,
dtype=dtype,
),
nn.BatchNorm1d(
num_features=MODEL_HIDDEN_SIZE,
track_running_stats=True,
dtype=dtype,
),
nn.ReLU(),
nn.Linear(
in_features=MODEL_HIDDEN_SIZE,
Expand All @@ -99,7 +109,7 @@ def get_models(
nn.Softmax(dim=-1),
)
for name, param in model_base.named_parameters(recurse=True):
if name.endswith('weight'):
if name.endswith('weight') and param.ndim >= 2:
nn.init.orthogonal_(param)
if name.endswith('bias'):
param.data.normal_(0, 0.1)
Expand Down Expand Up @@ -160,6 +170,13 @@ def assert_all_close(
actual = actual - base
expected = expected - base

if rtol is None or atol is None:
from torch.testing._comparison import get_tolerances

rtol, atol = get_tolerances(actual, expected, rtol=rtol, atol=atol)
rtol *= 4 * NUM_UPDATES
atol *= 4 * NUM_UPDATES

torch.testing.assert_close(
actual,
expected,
Expand Down
Loading