diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
index 96d1672709d20f..86312491b26a14 100644
--- a/.git-blame-ignore-revs
+++ b/.git-blame-ignore-revs
@@ -36,3 +36,5 @@ dd3a77bc965adf9fe8ba582ee13bb7f14c9661b0
 f70844bec783bfce43c950ccf180dc494e86f2bf
 # 2023-07-28 Apply UFMT to all non test/torch files
 e6ec0efaf87703c5f889cfc20b29be455885d58d
+# 2023-07-31 [optim][BE] split test file into logical parts: SWA, LR, optim
+a53cda1ddc15336dc1ff0ce1eff2a49cdc5f882e
diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt
index 4bfec345a744ff..8e4ec07fc0897c 100644
--- a/.github/ci_commit_pins/xla.txt
+++ b/.github/ci_commit_pins/xla.txt
@@ -1 +1 @@
-f5edcb2088195db71bcd36d0f8f1b6a5e663afd8
+ca5eab87a71f80cd3168630511d02549cc7d2516
diff --git a/test/optim/test_optim.py b/test/optim/test_optim.py
index 2cd643d3e94089..e3c528c1a1ee56 100644
--- a/test/optim/test_optim.py
+++ b/test/optim/test_optim.py
@@ -238,8 +238,6 @@ def fn_base(optimizer, weight, bias):
             optimizer_c.step(fn_c)
             self.assertEqual(weight, weight_c)
             self.assertEqual(bias, bias_c)
-            # Make sure state dict wasn't modified
-            self.assertEqual(state_dict, state_dict_c)
             # Make sure state dict is deterministic with equal but not identical parameters
             self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
             # Make sure repeated parameters have identical representation in state dict
@@ -301,7 +299,7 @@ def fn_base(optimizer, weight, bias):
         state_dict_c = deepcopy(optimizer.state_dict())
         optimizer_cuda.load_state_dict(state_dict_c)
 
-        # Make sure state dict wasn't modified
+        # Make sure state_dict_c isn't modified by merely calling load_state_dict
         self.assertEqual(state_dict, state_dict_c)
 
         # Make sure that device of state['step'] is still CPU
@@ -312,7 +310,7 @@
         for state in new_state_dict["state"].values():
             self.assertEqual(state["step"].device.type, "cpu")
 
-        for _i in range(20):
+        for _ in range(20):
             optimizer.step(fn)
             optimizer_cuda.step(fn_cuda)
             self.assertEqual(weight, weight_cuda)
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py
index b2e524f26f5ec6..9115fcf5c4dc9f 100644
--- a/torch/optim/optimizer.py
+++ b/torch/optim/optimizer.py
@@ -712,8 +712,8 @@ def load_state_dict(self, state_dict: StateDict) -> None:
             state_dict (dict): optimizer state. Should be an object returned
                 from a call to :meth:`state_dict`.
         """
-        # deepcopy, to be consistent with module API
-        state_dict = deepcopy(state_dict)
+        # shallow copy, to be consistent with module API
+        state_dict = state_dict.copy()
 
         for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
             hook_result = pre_hook(self, state_dict)
@@ -722,7 +722,9 @@ def load_state_dict(self, state_dict: StateDict) -> None:
 
         # Validate the state_dict
        groups = self.param_groups
-        saved_groups = state_dict['param_groups']
+
+        # Deepcopy as we write into saved_groups later to update state
+        saved_groups = deepcopy(state_dict['param_groups'])
 
         if len(groups) != len(saved_groups):
             raise ValueError("loaded state dict has a different number of "
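
Note on the torch/optim/optimizer.py hunks: below is a minimal sketch of the copy semantics the patch switches to, not the actual PyTorch implementation. The dict layout mirrors what Optimizer.state_dict() returns; the sizes and values are made up for illustration.

from copy import deepcopy

import torch

# Hypothetical state dict shaped like Optimizer.state_dict() output.
sd = {
    "state": {0: {"momentum_buffer": torch.zeros(1024)}},  # potentially huge tensors
    "param_groups": [{"lr": 0.1, "params": [0]}],
}

# Old behavior: deepcopy(state_dict) duplicated everything, including the
# per-parameter state tensors.
old = deepcopy(sd)
assert old["state"][0]["momentum_buffer"] is not sd["state"][0]["momentum_buffer"]

# New behavior: shallow-copy the top level, deepcopy only 'param_groups'
# (which load_state_dict later writes into), and leave 'state' shared.
new = sd.copy()
new["param_groups"] = deepcopy(sd["param_groups"])
assert new["state"] is sd["state"]                    # state tensors not duplicated
assert new["param_groups"] is not sd["param_groups"]  # safe to mutate on load

This also matches the test edits above: the caller's top-level dict is still protected from mutation by load_state_dict (the reworded assertion at the -301 hunk), but the state values are no longer cloned wholesale, so the stricter equality check removed at the -238 hunk no longer applies.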