Skip to content

Commit a5aace5

Browse files
author
Ervin T
authored
[bug-fix] Initialize-from being incorrectly loaded as "None" rather than None (#4175)
1 parent 5e60954 commit a5aace5

File tree

4 files changed

+120
-3
lines changed

4 files changed

+120
-3
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ list of tags was empty, or a tag in the list was invalid (unknown, null, or
3030
empty string). (#4155)
3131

3232
#### ml-agents / ml-agents-envs / gym-unity (Python)
33+
- Fixed an error when setting `initialize_from` in the trainer confiiguration YAML to
34+
`null`. (#4175)
3335

3436
## [1.1.0-preview] - 2020-06-10
3537
### Major Changes

ml-agents/mlagents/trainers/learn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
6969
write_path = os.path.join(base_path, checkpoint_settings.run_id)
7070
maybe_init_path = (
7171
os.path.join(base_path, checkpoint_settings.initialize_from)
72-
if checkpoint_settings.initialize_from
72+
if checkpoint_settings.initialize_from is not None
7373
else None
7474
)
7575
run_logs_dir = os.path.join(write_path, "run_logs")

ml-agents/mlagents/trainers/settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ class MeasureType:
429429
@attr.s(auto_attribs=True)
430430
class CheckpointSettings:
431431
run_id: str = parser.get_default("run_id")
432-
initialize_from: str = parser.get_default("initialize_from")
432+
initialize_from: Optional[str] = parser.get_default("initialize_from")
433433
load_model: bool = parser.get_default("load_model")
434434
resume: bool = parser.get_default("resume")
435435
force: bool = parser.get_default("force")

ml-agents/mlagents/trainers/tests/test_settings.py

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import attr
22
import pytest
3+
import yaml
34

4-
from typing import Dict
5+
from typing import Dict, List, Optional
56

67
from mlagents.trainers.settings import (
78
RunOptions,
@@ -31,6 +32,32 @@ def check_if_different(testobj1: object, testobj2: object) -> None:
3132
check_if_different(val, attr.asdict(testobj2, recurse=False)[key])
3233

3334

35+
def check_dict_is_at_least(
36+
testdict1: Dict, testdict2: Dict, exceptions: Optional[List[str]] = None
37+
) -> None:
38+
"""
39+
Check if everything present in the 1st dict is the same in the second dict.
40+
Excludes things that the second dict has but is not present in the heirarchy of the
41+
1st dict. Used to compare an underspecified config dict structure (e.g. as
42+
would be provided by a user) with a complete one (e.g. as exported by RunOptions).
43+
"""
44+
for key, val in testdict1.items():
45+
if exceptions is not None and key in exceptions:
46+
continue
47+
assert key in testdict2
48+
if isinstance(val, dict):
49+
check_dict_is_at_least(val, testdict2[key])
50+
elif isinstance(val, list):
51+
assert isinstance(testdict2[key], list)
52+
for _el0, _el1 in zip(val, testdict2[key]):
53+
if isinstance(_el0, dict):
54+
check_dict_is_at_least(_el0, _el1)
55+
else:
56+
assert val == testdict2[key]
57+
else: # If not a dict, don't recurse into it
58+
assert val == testdict2[key]
59+
60+
3461
def test_is_new_instance():
3562
"""
3663
Verify that every instance of RunOptions() and its subclasses
@@ -244,3 +271,91 @@ def test_parameter_randomization_structure():
244271
ParameterRandomizationSettings.structure(
245272
"notadict", Dict[str, ParameterRandomizationSettings]
246273
)
274+
275+
276+
@pytest.mark.parametrize("use_defaults", [True, False])
277+
def test_exportable_settings(use_defaults):
278+
"""
279+
Test that structuring and unstructuring a RunOptions object results in the same
280+
configuration representation.
281+
"""
282+
# Try to enable as many features as possible in this test YAML to hit all the
283+
# edge cases. Set as much as possible as non-default values to ensure no flukes.
284+
test_yaml = """
285+
behaviors:
286+
3DBall:
287+
trainer_type: sac
288+
hyperparameters:
289+
learning_rate: 0.0004
290+
learning_rate_schedule: constant
291+
batch_size: 64
292+
buffer_size: 200000
293+
buffer_init_steps: 100
294+
tau: 0.006
295+
steps_per_update: 10.0
296+
save_replay_buffer: true
297+
init_entcoef: 0.5
298+
reward_signal_steps_per_update: 10.0
299+
network_settings:
300+
normalize: false
301+
hidden_units: 256
302+
num_layers: 3
303+
vis_encode_type: nature_cnn
304+
memory:
305+
memory_size: 1288
306+
sequence_length: 12
307+
reward_signals:
308+
extrinsic:
309+
gamma: 0.999
310+
strength: 1.0
311+
curiosity:
312+
gamma: 0.999
313+
strength: 1.0
314+
keep_checkpoints: 5
315+
max_steps: 500000
316+
time_horizon: 1000
317+
summary_freq: 12000
318+
checkpoint_interval: 1
319+
threaded: true
320+
env_settings:
321+
env_path: test_env_path
322+
env_args:
323+
- test_env_args1
324+
- test_env_args2
325+
base_port: 12345
326+
num_envs: 8
327+
seed: 12345
328+
engine_settings:
329+
width: 12345
330+
height: 12345
331+
quality_level: 12345
332+
time_scale: 12345
333+
target_frame_rate: 12345
334+
capture_frame_rate: 12345
335+
no_graphics: true
336+
checkpoint_settings:
337+
run_id: test_run_id
338+
initialize_from: test_directory
339+
load_model: false
340+
resume: true
341+
force: true
342+
train_model: false
343+
inference: false
344+
debug: true
345+
"""
346+
if not use_defaults:
347+
loaded_yaml = yaml.safe_load(test_yaml)
348+
run_options = RunOptions.from_dict(yaml.safe_load(test_yaml))
349+
else:
350+
run_options = RunOptions()
351+
dict_export = run_options.as_dict()
352+
353+
if not use_defaults: # Don't need to check if no yaml
354+
check_dict_is_at_least(loaded_yaml, dict_export)
355+
356+
# Re-import and verify has same elements
357+
run_options2 = RunOptions.from_dict(dict_export)
358+
second_export = run_options2.as_dict()
359+
360+
# Check that the two exports are the same
361+
assert dict_export == second_export

0 commit comments

Comments
 (0)