From 169ee5b1167f6172906d723d0899625ce1c8dd8b Mon Sep 17 00:00:00 2001
From: Xuehai Pan
Date: Mon, 29 Aug 2022 17:23:23 +0800
Subject: [PATCH] chore: rename variables

---
 tests/test_alias.py                  | 20 ++++++++++----------
 torchopt/_src/optimizer/base.py      |  4 ++--
 torchopt/_src/optimizer/meta/base.py | 18 ++++++++++--------
 3 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/tests/test_alias.py b/tests/test_alias.py
index e0ebe0f4d..b75202a5f 100644
--- a/tests/test_alias.py
+++ b/tests/test_alias.py
@@ -76,8 +76,8 @@ def test_sgd(
         loss = F.cross_entropy(pred, ys)
         loss_ref = F.cross_entropy(pred_ref, ys)
 
-        grad = torch.autograd.grad(loss, params)
-        updates, optim_state = optim.update(grad, optim_state, params=params, inplace=inplace)
+        grads = torch.autograd.grad(loss, params)
+        updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
         params = torchopt.apply_updates(params, updates, inplace=inplace)
 
         optim_ref.zero_grad()
@@ -134,8 +134,8 @@ def test_adam(
         loss = F.cross_entropy(pred, ys)
         loss_ref = F.cross_entropy(pred_ref, ys)
 
-        grad = torch.autograd.grad(loss, params)
-        updates, optim_state = optim.update(grad, optim_state, params=params, inplace=inplace)
+        grads = torch.autograd.grad(loss, params)
+        updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
         params = torchopt.apply_updates(params, updates, inplace=inplace)
 
         optim_ref.zero_grad()
@@ -193,8 +193,8 @@ def test_adam_accelerated_cpu(
         loss = F.cross_entropy(pred, ys)
         loss_ref = F.cross_entropy(pred_ref, ys)
 
-        grad = torch.autograd.grad(loss, params)
-        updates, optim_state = optim.update(grad, optim_state, params=params, inplace=inplace)
+        grads = torch.autograd.grad(loss, params)
+        updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
         params = torchopt.apply_updates(params, updates, inplace=inplace)
 
         optim_ref.zero_grad()
@@ -255,8 +255,8 @@ def test_adam_accelerated_cuda(
         loss = F.cross_entropy(pred, ys)
         loss_ref = F.cross_entropy(pred_ref, ys)
 
-        grad = torch.autograd.grad(loss, params)
-        updates, optim_state = optim.update(grad, optim_state, params=params, inplace=inplace)
+        grads = torch.autograd.grad(loss, params)
+        updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
         params = torchopt.apply_updates(params, updates, inplace=inplace)
 
         optim_ref.zero_grad()
@@ -316,8 +316,8 @@ def test_rmsprop(
         loss = F.cross_entropy(pred, ys)
         loss_ref = F.cross_entropy(pred_ref, ys)
 
-        grad = torch.autograd.grad(loss, params)
-        updates, optim_state = optim.update(grad, optim_state, params=params, inplace=inplace)
+        grads = torch.autograd.grad(loss, params)
+        updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
         params = torchopt.apply_updates(params, updates, inplace=inplace)
 
         optim_ref.zero_grad()
diff --git a/torchopt/_src/optimizer/base.py b/torchopt/_src/optimizer/base.py
index f71dbd953..99e18b366 100644
--- a/torchopt/_src/optimizer/base.py
+++ b/torchopt/_src/optimizer/base.py
@@ -103,8 +103,8 @@ def f(p):
             return p.grad
 
         for i, (params, state) in enumerate(zip(self.param_groups, self.state_groups)):
-            grad = pytree.tree_map(f, params)
-            updates, new_state = self.impl.update(grad, state, params=params, inplace=True)
+            grads = pytree.tree_map(f, params)
+            updates, new_state = self.impl.update(grads, state, params=params, inplace=True)
             self.param_groups[i] = apply_updates(params, updates, inplace=True)
             self.state_groups[i] = new_state
 
diff --git a/torchopt/_src/optimizer/meta/base.py b/torchopt/_src/optimizer/meta/base.py
index 395fa17eb..1acbd1b8d 100644
--- a/torchopt/_src/optimizer/meta/base.py
+++ b/torchopt/_src/optimizer/meta/base.py
@@ -53,24 +53,26 @@ def step(self, loss: torch.Tensor):
             loss (torch.Tensor): The loss that is used to compute the gradients to the network
                 parameters.
         """  # pylint: disable=line-too-long
-        # step parameter only
+        # Step parameter only
         for i, (param_container, new_state) in enumerate(
             zip(self.param_containers_groups, self.state_groups)
         ):
-            flattened_params, container_tree = pytree.tree_flatten(param_container)
+            flattened_params, container_treedef = pytree.tree_flatten(param_container)
             flattened_params = tuple(flattened_params)
-            grad = torch.autograd.grad(loss, flattened_params, create_graph=True, allow_unused=True)
+            grads = torch.autograd.grad(
+                loss, flattened_params, create_graph=True, allow_unused=True
+            )
             updates, new_state = self.impl.update(
-                grad,
+                grads,
                 new_state,
                 params=flattened_params,
                 inplace=False,
             )
             self.state_groups[i] = new_state
-            new_params = apply_updates(flattened_params, updates, inplace=False)
-            unflattened_new_params = container_tree.unflatten(new_params)
-            for container, unflatten_param in zip(param_container, unflattened_new_params):
-                container.update(unflatten_param)
+            flattened_new_params = apply_updates(flattened_params, updates, inplace=False)
+            new_params = pytree.tree_unflatten(container_treedef, flattened_new_params)
+            for container, new_param in zip(param_container, new_params):
+                container.update(new_param)
 
     def add_param_group(self, net):
         """Add a param group to the optimizer's :attr:`state_groups`."""
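
Note (not part of the patch): the rename follows the functional update pattern that the
tests above exercise. Below is a minimal sketch of one such step, assuming torchopt's
optax-style functional API (torchopt.sgd returning a gradient transformation with
init/update); the toy linear model and random data are placeholders, not taken from the
test suite.

    # Minimal sketch of the functional training step exercised by tests/test_alias.py.
    # Assumptions (not confirmed by this patch): torchopt.sgd(lr) returns a gradient
    # transformation with init()/update(); the model and data are placeholders.
    import torch
    import torch.nn.functional as F

    import torchopt

    model = torch.nn.Linear(4, 2)
    params = tuple(model.parameters())

    optim = torchopt.sgd(1e-3)        # assumed functional constructor
    optim_state = optim.init(params)  # per-parameter optimizer state

    xs, ys = torch.randn(8, 4), torch.randint(0, 2, (8,))
    loss = F.cross_entropy(model(xs), ys)

    # torch.autograd.grad returns a tuple of gradients, hence the plural name `grads`.
    grads = torch.autograd.grad(loss, params)
    updates, optim_state = optim.update(grads, optim_state, params=params, inplace=False)
    params = torchopt.apply_updates(params, updates, inplace=False)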