Commit

chore: rename variables
XuehaiPan committed Aug 29, 2022
1 parent ecc8361 commit 169ee5b
Showing 3 changed files with 22 additions and 20 deletions.
20 changes: 10 additions & 10 deletions tests/test_alias.py
@@ -76,8 +76,8 @@ def test_sgd(
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

- grad = torch.autograd.grad(loss, params)
- updates, optim_state = optim.update(grad, optim_state, params=params, inplace=inplace)
+ grads = torch.autograd.grad(loss, params)
+ updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

optim_ref.zero_grad()
@@ -134,8 +134,8 @@ def test_adam(
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

- grad = torch.autograd.grad(loss, params)
- updates, optim_state = optim.update(grad, optim_state, params=params, inplace=inplace)
+ grads = torch.autograd.grad(loss, params)
+ updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

optim_ref.zero_grad()
@@ -193,8 +193,8 @@ def test_adam_accelerated_cpu(
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

- grad = torch.autograd.grad(loss, params)
- updates, optim_state = optim.update(grad, optim_state, params=params, inplace=inplace)
+ grads = torch.autograd.grad(loss, params)
+ updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

optim_ref.zero_grad()
@@ -255,8 +255,8 @@ def test_adam_accelerated_cuda(
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

- grad = torch.autograd.grad(loss, params)
- updates, optim_state = optim.update(grad, optim_state, params=params, inplace=inplace)
+ grads = torch.autograd.grad(loss, params)
+ updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

optim_ref.zero_grad()
@@ -316,8 +316,8 @@ def test_rmsprop(
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

- grad = torch.autograd.grad(loss, params)
- updates, optim_state = optim.update(grad, optim_state, params=params, inplace=inplace)
+ grads = torch.autograd.grad(loss, params)
+ updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

optim_ref.zero_grad()
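Every hunk above exercises the same functional-optimizer pattern, so a standalone sketch may help readers outside the test harness. It is illustrative only: torchopt.sgd(lr=...) as the alias entry point and the toy linear model are assumptions, while the update/apply_updates calls mirror the lines touched by this commit.

# Minimal sketch of the functional API exercised by tests/test_alias.py.
# Assumptions (not taken from the diff): torchopt.sgd(lr=...) and a toy model
# standing in for the test fixtures.
import torch
import torch.nn.functional as F
import torchopt

model = torch.nn.Linear(4, 2)
params = tuple(model.parameters())

optim = torchopt.sgd(lr=0.1)        # functional optimizer with an init/update pair
optim_state = optim.init(params)    # optimizer state is kept outside the model

xs, ys = torch.randn(8, 4), torch.randint(0, 2, (8,))
loss = F.cross_entropy(model(xs), ys)

# torch.autograd.grad returns one gradient per parameter, hence the plural name.
grads = torch.autograd.grad(loss, params)
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=False)
params = torchopt.apply_updates(params, updates, inplace=False)  # new parameter tensors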
4 changes: 2 additions & 2 deletions torchopt/_src/optimizer/base.py
@@ -103,8 +103,8 @@ def f(p):
return p.grad

for i, (params, state) in enumerate(zip(self.param_groups, self.state_groups)):
- grad = pytree.tree_map(f, params)
- updates, new_state = self.impl.update(grad, state, params=params, inplace=True)
+ grads = pytree.tree_map(f, params)
+ updates, new_state = self.impl.update(grads, state, params=params, inplace=True)
self.param_groups[i] = apply_updates(params, updates, inplace=True)
self.state_groups[i] = new_state

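For context, the step() touched here is the eager, non-differentiable path: it collects each parameter's .grad and feeds it to the same impl.update call. A hedged usage sketch follows, assuming torchopt.SGD is the high-level wrapper built on this base Optimizer.

# Sketch of driving the classic Optimizer.step() shown above.
# Assumption: torchopt.SGD wraps this base class; the essential point is that
# step() reads p.grad, so a backward pass must run first.
import torch
import torch.nn.functional as F
import torchopt

model = torch.nn.Linear(4, 2)
optim = torchopt.SGD(model.parameters(), lr=0.1)

xs, ys = torch.randn(8, 4), torch.randint(0, 2, (8,))
loss = F.cross_entropy(model(xs), ys)

loss.backward()   # populates p.grad, which step() gathers via pytree.tree_map(f, params)
optim.step()      # grads -> self.impl.update(...) -> apply_updates(..., inplace=True)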
18 changes: 10 additions & 8 deletions torchopt/_src/optimizer/meta/base.py
@@ -53,24 +53,26 @@ def step(self, loss: torch.Tensor):
loss (torch.Tensor): The loss that is used to compute the gradients to the network
parameters.
""" # pylint: disable=line-too-long
- # step parameter only
+ # Step parameter only
for i, (param_container, new_state) in enumerate(
zip(self.param_containers_groups, self.state_groups)
):
- flattened_params, container_tree = pytree.tree_flatten(param_container)
+ flattened_params, container_treedef = pytree.tree_flatten(param_container)
flattened_params = tuple(flattened_params)
- grad = torch.autograd.grad(loss, flattened_params, create_graph=True, allow_unused=True)
+ grads = torch.autograd.grad(
+ loss, flattened_params, create_graph=True, allow_unused=True
+ )
updates, new_state = self.impl.update(
- grad,
+ grads,
new_state,
params=flattened_params,
inplace=False,
)
self.state_groups[i] = new_state
- new_params = apply_updates(flattened_params, updates, inplace=False)
- unflattened_new_params = container_tree.unflatten(new_params)
- for container, unflatten_param in zip(param_container, unflattened_new_params):
- container.update(unflatten_param)
+ flattened_new_params = apply_updates(flattened_params, updates, inplace=False)
+ new_params = pytree.tree_unflatten(container_treedef, flattened_new_params)
+ for container, new_param in zip(param_container, new_params):
+ container.update(new_param)

def add_param_group(self, net):
"""Add a param group to the optimizer's :attr:`state_groups`."""
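The meta-optimizer path differs from the classic one above in two ways that the renamed code keeps visible: gradients are taken with create_graph=True and the update is applied out of place, so the whole inner step stays differentiable. A rough sketch of that flow, assuming torchopt.MetaSGD as the concrete subclass:

# Sketch of MetaOptimizer.step(loss) staying differentiable.
# Assumption: torchopt.MetaSGD(net, lr=...) is the concrete meta-optimizer; only
# the torch.autograd.grad call is guaranteed by the diff itself.
import torch
import torch.nn.functional as F
import torchopt

net = torch.nn.Linear(4, 2)
inner_optim = torchopt.MetaSGD(net, lr=0.1)
meta_params = tuple(net.parameters())   # leaf tensors before the inner update

xs, ys = torch.randn(8, 4), torch.randint(0, 2, (8,))

inner_loss = F.cross_entropy(net(xs), ys)
inner_optim.step(inner_loss)            # replaces net's parameters with graph-connected, non-leaf tensors

outer_loss = F.cross_entropy(net(xs), ys)
# Differentiating through the inner update works because step() used
# create_graph=True and inplace=False.
meta_grads = torch.autograd.grad(outer_loss, meta_params)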
