Skip to content

Commit 4a3c4f0

Browse files
Ervin Tvincentpierre
Ervin T
andauthored
Add test for settings export (#4164)
* Add test for settings export * Update ml-agents/mlagents/trainers/tests/test_settings.py Co-authored-by: Vincent-Pierre BERGES <vincentpierre@unity3d.com> Co-authored-by: Vincent-Pierre BERGES <vincentpierre@unity3d.com>
1 parent ae70965 commit 4a3c4f0

File tree

1 file changed

+117
-1
lines changed

1 file changed

+117
-1
lines changed

ml-agents/mlagents/trainers/tests/test_settings.py

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import attr
22
import pytest
3+
import yaml
34

4-
from typing import Dict
5+
from typing import Dict, List, Optional
56

67
from mlagents.trainers.settings import (
78
RunOptions,
@@ -32,6 +33,32 @@ def check_if_different(testobj1: object, testobj2: object) -> None:
3233
check_if_different(val, attr.asdict(testobj2, recurse=False)[key])
3334

3435

36+
def check_dict_is_at_least(
37+
testdict1: Dict, testdict2: Dict, exceptions: Optional[List[str]] = None
38+
) -> None:
39+
"""
40+
Check if everything present in the 1st dict is the same in the second dict.
41+
Excludes things that the second dict has but is not present in the heirarchy of the
42+
1st dict. Used to compare an underspecified config dict structure (e.g. as
43+
would be provided by a user) with a complete one (e.g. as exported by RunOptions).
44+
"""
45+
for key, val in testdict1.items():
46+
if exceptions is not None and key in exceptions:
47+
continue
48+
assert key in testdict2
49+
if isinstance(val, dict):
50+
check_dict_is_at_least(val, testdict2[key])
51+
elif isinstance(val, list):
52+
assert isinstance(testdict2[key], list)
53+
for _el0, _el1 in zip(val, testdict2[key]):
54+
if isinstance(_el0, dict):
55+
check_dict_is_at_least(_el0, _el1)
56+
else:
57+
assert val == testdict2[key]
58+
else: # If not a dict, don't recurse into it
59+
assert val == testdict2[key]
60+
61+
3562
def test_is_new_instance():
3663
"""
3764
Verify that every instance of RunOptions() and its subclasses
@@ -289,3 +316,92 @@ def test_env_parameter_structure():
289316
EnvironmentParameterSettings.structure(
290317
invalid_curriculum_dict, Dict[str, EnvironmentParameterSettings]
291318
)
319+
320+
321+
def test_exportable_settings():
322+
"""
323+
Test that structuring and unstructuring a RunOptions object results in the same
324+
configuration representation.
325+
"""
326+
# Try to enable as many features as possible in this test YAML to hit all the
327+
# edge cases. Set as much as possible as non-default values to ensure no flukes.
328+
# TODO: Add back in environment_parameters
329+
test_yaml = """
330+
behaviors:
331+
3DBall:
332+
trainer_type: sac
333+
hyperparameters:
334+
learning_rate: 0.0004
335+
learning_rate_schedule: constant
336+
batch_size: 64
337+
buffer_size: 200000
338+
buffer_init_steps: 100
339+
tau: 0.006
340+
steps_per_update: 10.0
341+
save_replay_buffer: true
342+
init_entcoef: 0.5
343+
reward_signal_steps_per_update: 10.0
344+
network_settings:
345+
normalize: false
346+
hidden_units: 256
347+
num_layers: 3
348+
vis_encode_type: nature_cnn
349+
memory:
350+
memory_size: 1288
351+
sequence_length: 12
352+
reward_signals:
353+
extrinsic:
354+
gamma: 0.999
355+
strength: 1.0
356+
curiosity:
357+
gamma: 0.999
358+
strength: 1.0
359+
keep_checkpoints: 5
360+
max_steps: 500000
361+
time_horizon: 1000
362+
summary_freq: 12000
363+
checkpoint_interval: 1
364+
threaded: true
365+
env_settings:
366+
env_path: test_env_path
367+
env_args:
368+
- test_env_args1
369+
- test_env_args2
370+
base_port: 12345
371+
num_envs: 8
372+
seed: 12345
373+
engine_settings:
374+
width: 12345
375+
height: 12345
376+
quality_level: 12345
377+
time_scale: 12345
378+
target_frame_rate: 12345
379+
capture_frame_rate: 12345
380+
no_graphics: true
381+
checkpoint_settings:
382+
run_id: test_run_id
383+
initialize_from: test_directory
384+
load_model: false
385+
resume: true
386+
force: true
387+
train_model: false
388+
inference: false
389+
debug: true
390+
"""
391+
loaded_yaml = yaml.safe_load(test_yaml)
392+
run_options = RunOptions.from_dict(yaml.safe_load(test_yaml))
393+
dict_export = run_options.as_dict()
394+
check_dict_is_at_least(loaded_yaml, dict_export)
395+
396+
# Re-import and verify has same elements
397+
run_options2 = RunOptions.from_dict(dict_export)
398+
second_export = run_options2.as_dict()
399+
400+
check_dict_is_at_least(
401+
dict_export, second_export, exceptions=["environment_parameters"]
402+
)
403+
# Should be able to use equality instead of back-and-forth once environment_parameters
404+
# is working
405+
check_dict_is_at_least(
406+
second_export, dict_export, exceptions=["environment_parameters"]
407+
)

0 commit comments

Comments
 (0)