Commit 48261d3

Ervin T authored and Chris Elion committed
[bug-fix] Fix issue with initialize not resetting step count (#3962)
1 parent 39d5394 commit 48261d3

3 files changed, 6 additions and 0 deletions

com.unity.ml-agents/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ and this project adheres to
 
 ## [1.0.1-preview] - 2020-05-19
 ### Bug Fixes
+- An issue was fixed where using `--initialize-from` would resume from the past step count. (#3962)
 #### com.unity.ml-agents (C#)
 #### ml-agents / ml-agents-envs / gym-unity (Python)
 

ml-agents/mlagents/trainers/policy/tf_policy.py

Lines changed: 1 addition & 0 deletions
@@ -137,6 +137,7 @@ def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None
                     )
                 )
             if reset_global_steps:
+                self._set_step(0)
                 logger.info(
                     "Starting training from step 0 and saving to {}.".format(
                         self.model_path

ml-agents/mlagents/trainers/tests/test_nn_policy.py

Lines changed: 4 additions & 0 deletions
@@ -86,6 +86,7 @@ def test_load_save(dummy_config, tmp_path):
     trainer_params["model_path"] = path1
     policy = create_policy_mock(trainer_params)
     policy.initialize_or_load()
+    policy._set_step(2000)
     policy.save_model(2000)
 
     assert len(os.listdir(tmp_path)) > 0
@@ -94,6 +95,7 @@ def test_load_save(dummy_config, tmp_path):
     policy2 = create_policy_mock(trainer_params, load=True, seed=1)
     policy2.initialize_or_load()
     _compare_two_policies(policy, policy2)
+    assert policy2.get_current_step() == 2000
 
     # Try initialize from path 1
     trainer_params["model_path"] = path2
@@ -102,6 +104,8 @@ def test_load_save(dummy_config, tmp_path):
     policy3.initialize_or_load()
 
     _compare_two_policies(policy2, policy3)
+    # Assert that the steps are 0.
+    assert policy3.get_current_step() == 0
 
 
 def _compare_two_policies(policy1: NNPolicy, policy2: NNPolicy) -> None:
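
The behaviour this commit fixes and tests can be sketched with a toy class. This is only an illustration: `ToyPolicy`, `save`, and `load` are hypothetical names, not part of the ML-Agents API. It also assumes that a checkpoint stores the step counter alongside the weights, which the diff implies but does not show. Restoring a checkpoint then brings back the old step count, so the "initialize from" path has to reset it explicitly, as the added `self._set_step(0)` line does.

```python
from typing import Any, Dict


class ToyPolicy:
    """Hypothetical stand-in for a policy; not the real TFPolicy."""

    def __init__(self) -> None:
        self.weights = [0.0, 0.0]
        self.step = 0

    def save(self) -> Dict[str, Any]:
        # Assumption: the checkpoint carries the step counter with the weights.
        return {"weights": list(self.weights), "step": self.step}

    def load(self, checkpoint: Dict[str, Any], reset_global_steps: bool = False) -> None:
        # Restoring a checkpoint restores everything, including the step.
        self.weights = list(checkpoint["weights"])
        self.step = int(checkpoint["step"])
        if reset_global_steps:
            # Analogue of the added line: without it, the old step survives.
            self.step = 0


# Train-and-save, analogous to `policy` in the test.
trained = ToyPolicy()
trained.weights = [0.5, -1.25]
trained.step = 2000
checkpoint = trained.save()

# Resume: same weights, same step (analogous to `policy2`).
resumed = ToyPolicy()
resumed.load(checkpoint)

# Initialize from: same weights, step reset to 0 (analogous to `policy3`).
initialized = ToyPolicy()
initialized.load(checkpoint, reset_global_steps=True)

print(resumed.step, initialized.step)
```

In this sketch both loaded policies end up with the trained weights, but only the resumed one keeps step 2000. That is the same distinction the three assertions in `test_load_save` check.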
