Examples of parallel training on one GPU using functorch.vmap with torchopt #32

Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Add question/help/support issue template by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#43](https://github.com/metaopt/TorchOpt/pull/43).
- Add example of parallel training on one GPU with `functorch.vmap` and TorchOpt by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#32](https://github.com/metaopt/TorchOpt/pull/32).


### Changed

207 changes: 207 additions & 0 deletions examples/FuncTorch/parallel_train_torchopt.py
@@ -0,0 +1,207 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import argparse
import math

import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchopt


def make_spirals(n_samples, noise_std=0.0, rotations=1.0, device='cpu'):
ts = torch.linspace(0, 1, n_samples, device=device)
rs = ts**0.5
thetas = rs * rotations * 2 * math.pi
signs = torch.randint(0, 2, (n_samples,), device=device) * 2 - 1
labels = (signs > 0).to(torch.long).to(device)

xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples, device=device) * noise_std
ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples, device=device) * noise_std
points = torch.stack([xs, ys], dim=1)
return points, labels


class MLPClassifier(nn.Module):
def __init__(self, hidden_dim=32, n_classes=2):
super().__init__()
self.hidden_dim = hidden_dim
self.n_classes = n_classes

self.fc1 = nn.Linear(2, self.hidden_dim)
self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.log_softmax(x, -1)
return x


class ParallelTrainFunctorchOriginal:
def __init__(self, loss_fn, lr, device):
self.device = device
self.loss_fn = loss_fn
self.lr = lr
self.func_model, _ = functorch.make_functional(MLPClassifier().to(self.device))

def init_fn(self, num_models):
models = [MLPClassifier().to(self.device) for _ in range(num_models)]
_, batched_weights, _ = functorch.combine_state_for_ensemble(models)
return batched_weights

def train_step_fn(self, weights, batch, targets):
def compute_loss(weights, batch, targets):
output = self.func_model(weights, batch)
loss = self.loss_fn(output, targets)
return loss

grad_weights, loss = functorch.grad_and_value(compute_loss)(weights, batch, targets)
# NB: PyTorch is missing a "functional optimizer API" (possibly coming soon)
# so we are going to re-implement SGD here.
new_weights = []
with torch.no_grad():
for grad_weight, weight in zip(grad_weights, weights):
new_weights.append(weight - grad_weight * self.lr)

return loss, new_weights

def test_train_step_fn(self, weights, points, labels):
for i in range(2000):
loss, weights = self.train_step_fn(weights, points, labels)
if i % 100 == 0:
print(loss)

def test_parallel_train_step_fn(self, num_models):
parallel_train_step_fn = functorch.vmap(self.train_step_fn, in_dims=(0, None, None))
batched_weights = self.init_fn(num_models=num_models)
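        # NOTE: `points` and `labels` are the module-level globals defined under `__main__`.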
for i in range(2000):
loss, batched_weights = parallel_train_step_fn(batched_weights, points, labels)
if i % 200 == 0:
print(loss)


class ParallelTrainFunctorchTorchOpt:
def __init__(self, loss_fn, optimizer, device):
self.device = device
self.loss_fn = loss_fn
self.optimizer = optimizer
self.func_model, _ = functorch.make_functional(MLPClassifier().to(self.device))

def init_fn(self, model_idx):
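        # `model_idx` is a dummy vmap argument; its batch dimension determines how many models are initialized.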
_, weights = functorch.make_functional(MLPClassifier().to(self.device))
opt_state = self.optimizer.init(weights)
return weights, opt_state

def train_step_fn(self, training_state, batch, targets):
weights, opt_state = training_state

def compute_loss(weights, batch, targets):
output = self.func_model(weights, batch)
loss = self.loss_fn(output, targets)
return loss

grads, loss = functorch.grad_and_value(compute_loss)(weights, batch, targets)
        # TorchOpt provides the functional optimizer API that was missing above
        updates, new_opt_state = self.optimizer.update(grads, opt_state, inplace=False)
new_weights = torchopt.apply_updates(weights, updates, inplace=False)
return loss, (new_weights, new_opt_state)

def test_train_step_fn(self, weights, opt_state, points, labels):
for i in range(2000):
loss, (weights, opt_state) = self.train_step_fn((weights, opt_state), points, labels)
if i % 100 == 0:
print(loss)

def test_parallel_train_step_fn(self, num_models):
parallel_init_fn = functorch.vmap(self.init_fn, randomness='same')
parallel_train_step_fn = functorch.vmap(self.train_step_fn, in_dims=(0, None, None))
weights, opt_state = parallel_init_fn(torch.ones(num_models, 1))
        # NOTE: `points` and `labels` are the module-level globals defined under `__main__`.
        for i in range(2000):
            loss, (weights, opt_state) = parallel_train_step_fn(
                (weights, opt_state), points, labels
            )
if i % 200 == 0:
print(loss)


if __name__ == '__main__':
    # Adapted from http://willwhitney.com/parallel-training-jax.html, which is a
    # tutorial on model ensembling with JAX by Will Whitney.
#
# The original code comes with the following citation:
# @misc{Whitney2021Parallelizing,
# author = {William F. Whitney},
# title = { {Parallelizing neural networks on one GPU with JAX} },
# year = {2021},
# url = {http://willwhitney.com/parallel-training-jax.html},
# }

    # GOAL: Demonstrate that eager-mode vmap can train an ensemble of models in parallel on one GPU.

parser = argparse.ArgumentParser(description='Functorch Ensembled Models with TorchOpt')
parser.add_argument(
'--device',
type=str,
default='cpu',
help="CPU or GPU ID for this process (default: 'cpu')",
)
args = parser.parse_args()

DEVICE = args.device
# Step 1: Make some spirals
    points, labels = make_spirals(100, noise_std=0.05, device=DEVICE)
# Step 2: Define two-layer MLP and loss function
loss_fn = nn.NLLLoss()
# Step 3: Make the model functional(!!) and define a training function.
func_model, weights = functorch.make_functional(MLPClassifier().to(DEVICE))

# original functorch implementation
functorch_original = ParallelTrainFunctorchOriginal(loss_fn=loss_fn, lr=0.2, device=DEVICE)
# Step 4: Let's verify this actually trains.
# We should see the loss decrease.
functorch_original.test_train_step_fn(weights, points, labels)
    # Step 5: Now, can we train multiple models at the same time?
    # The answer is: yes! `loss` now holds one value per model, and every entry
    # keeps decreasing.
functorch_original.test_parallel_train_step_fn(num_models=2)

    # functorch + TorchOpt implementation
    optimizer = torchopt.adam(lr=0.2)
    opt_state = optimizer.init(weights)
    functorch_torchopt = ParallelTrainFunctorchTorchOpt(
        loss_fn=loss_fn, optimizer=optimizer, device=DEVICE
    )
    # Step 4: Let's verify this actually trains.
    # We should see the loss decrease.
    functorch_torchopt.test_train_step_fn(weights, opt_state, points, labels)
    # Step 5: Now, can we train multiple models at the same time?
    # The answer is: yes! `loss` now holds one value per model, and every entry
    # keeps decreasing.
    functorch_torchopt.test_parallel_train_step_fn(num_models=2)

    # Step 6: The flaw with Step 5 is that every model was trained on exactly the
    # same data, which can lead to all of the models in the ensemble overfitting in
    # the same way. The solution that http://willwhitney.com/parallel-training-jax.html
    # applies is to randomly subset the data so that the models do not receive
    # exactly the same data in each training step.
    # Because the goal of this example is to show that eager-mode vmap can achieve
    # the same things as JAX, the rest is left as an exercise to the reader (a rough
    # sketch of the idea follows below).
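One way to implement the random subsetting described in Step 6 is to draw an independent minibatch for every ensemble member and vmap over that extra batch dimension as well. The snippet below is only a rough sketch and is not part of this PR: `sample_per_model_batches` is a hypothetical helper, and the commented usage assumes the `points`, `labels`, training state, and `train_step_fn` from the example above.

import torch


def sample_per_model_batches(points, labels, num_models, batch_size=32):
    # Draw an independent random minibatch for every ensemble member.
    idx = torch.randint(0, points.shape[0], (num_models, batch_size), device=points.device)
    # Shapes: (num_models, batch_size, 2) and (num_models, batch_size).
    return points[idx], labels[idx]


# Map over the per-model batch dimension too, so each model sees its own data:
#   parallel_train_step_fn = functorch.vmap(train_step_fn, in_dims=(0, 0, 0))
#   batched_points, batched_labels = sample_per_model_batches(points, labels, num_models=2)
#   loss, training_state = parallel_train_step_fn(training_state, batched_points, batched_labels)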
35 changes: 27 additions & 8 deletions torchopt/_src/transform.py
Expand Up @@ -42,13 +42,23 @@


ScaleState = base.EmptyState
INT32_MAX = torch.iinfo(torch.int32).max


def inc_count(updates, count: Tuple[int]) -> Tuple[int]:
"""Increments int counter by one."""
def inc_count(updates, count: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
"""Increments int counter by one.

Returns:
        A counter incremented by one, or INT32_MAX if the maximum precision is reached.
"""
one = torch.ones(1, dtype=torch.int32, device=count[0].device)

def f(c, g):
return c + 1 if g is not None else c
return (
c + (1 - torch.div(c, INT32_MAX, rounding_mode='trunc')) * one
if g is not None
else c
)

return pytree.tree_map(f, count, updates)
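
As a quick illustration (not part of this diff) of the saturating behaviour, assuming the `inc_count` and `INT32_MAX` defined above, the counter stops at `INT32_MAX` instead of overflowing:

import torch

count = (torch.tensor([INT32_MAX - 1], dtype=torch.int32),)
updates = (torch.zeros(1),)        # any non-None leaf triggers an increment
count = inc_count(updates, count)  # -> (tensor([2147483647], dtype=torch.int32),)
count = inc_count(updates, count)  # stays at INT32_MAX rather than wrapping around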

@@ -87,7 +97,7 @@ def f(g):
class ScaleByScheduleState(NamedTuple):
"""Maintains count for scale scheduling."""

count: Tuple[int, ...] # type: ignore
count: Tuple[torch.Tensor, ...] # type: ignore


def scale_by_schedule(step_size_fn: Schedule) -> base.GradientTransformation:
@@ -103,7 +113,10 @@ def scale_by_schedule(step_size_fn: Schedule) -> base.GradientTransformation:
"""

def init_fn(params):
return ScaleByScheduleState(count=tuple(0 for _ in range(len(params))))
zero = pytree.tree_map( # Count init
lambda t: torch.zeros(1, dtype=torch.int32, device=t.device), params
)
return ScaleByScheduleState(count=tuple(zero))

def update_fn(updates, state, inplace=True):
step_size = step_size_fn(state.count)
@@ -149,7 +162,7 @@ def f(g, t):
class ScaleByAdamState(NamedTuple):
"""State for the Adam algorithm."""

count: Tuple[int, ...] # type: ignore
count: Tuple[torch.Tensor, ...] # type: ignore
mu: base.Updates
nu: base.Updates

@@ -199,13 +212,16 @@ def scale_by_adam(
"""

def init_fn(params):
zero = pytree.tree_map( # Count init
lambda t: torch.zeros(1, dtype=torch.int32, device=t.device), params
)
mu = pytree.tree_map( # First moment
lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params
)
nu = pytree.tree_map( # Second moment
lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params
)
return ScaleByAdamState(count=tuple(0 for _ in range(len(mu))), mu=tuple(mu), nu=tuple(nu))
return ScaleByAdamState(count=tuple(zero), mu=tuple(mu), nu=tuple(nu))

def update_fn(updates, state, inplace=True):
mu = _update_moment(updates, state.mu, b1, 1, inplace)
@@ -262,13 +278,16 @@ def scale_by_accelerated_adam(
from torchopt._src.accelerated_op import AdamOp # pylint: disable=import-outside-toplevel

def init_fn(params):
zero = pytree.tree_map( # Count init
lambda t: torch.zeros(1, dtype=torch.int32, device=t.device), params
)
mu = pytree.tree_map( # First moment
lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params
)
nu = pytree.tree_map( # Second moment
lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params
)
return ScaleByAdamState(count=tuple(0 for _ in range(len(params))), mu=mu, nu=nu)
return ScaleByAdamState(count=tuple(zero), mu=mu, nu=nu)

def update_fn(updates, state, inplace=True):
count_inc = inc_count(updates, state.count)