Skip to content

Commit d5456cc

Browse files
Ervin Tvincentpierre
Ervin T
andcommitted
Add test for settings export (#4164)
* Add test for settings export * Update ml-agents/mlagents/trainers/tests/test_settings.py Co-authored-by: Vincent-Pierre BERGES <vincentpierre@unity3d.com> Co-authored-by: Vincent-Pierre BERGES <vincentpierre@unity3d.com>
1 parent 0f56a87 commit d5456cc

File tree

1 file changed

+117
-1
lines changed

1 file changed

+117
-1
lines changed

ml-agents/mlagents/trainers/tests/test_settings.py

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import attr
22
import pytest
3+
import yaml
34

4-
from typing import Dict
5+
from typing import Dict, List, Optional
56

67
from mlagents.trainers.settings import (
78
RunOptions,
@@ -31,6 +32,32 @@ def check_if_different(testobj1: object, testobj2: object) -> None:
3132
check_if_different(val, attr.asdict(testobj2, recurse=False)[key])
3233

3334

35+
def check_dict_is_at_least(
36+
testdict1: Dict, testdict2: Dict, exceptions: Optional[List[str]] = None
37+
) -> None:
38+
"""
39+
Check if everything present in the 1st dict is the same in the second dict.
40+
Excludes things that the second dict has but is not present in the heirarchy of the
41+
1st dict. Used to compare an underspecified config dict structure (e.g. as
42+
would be provided by a user) with a complete one (e.g. as exported by RunOptions).
43+
"""
44+
for key, val in testdict1.items():
45+
if exceptions is not None and key in exceptions:
46+
continue
47+
assert key in testdict2
48+
if isinstance(val, dict):
49+
check_dict_is_at_least(val, testdict2[key])
50+
elif isinstance(val, list):
51+
assert isinstance(testdict2[key], list)
52+
for _el0, _el1 in zip(val, testdict2[key]):
53+
if isinstance(_el0, dict):
54+
check_dict_is_at_least(_el0, _el1)
55+
else:
56+
assert val == testdict2[key]
57+
else: # If not a dict, don't recurse into it
58+
assert val == testdict2[key]
59+
60+
3461
def test_is_new_instance():
3562
"""
3663
Verify that every instance of RunOptions() and its subclasses
@@ -244,3 +271,92 @@ def test_parameter_randomization_structure():
244271
ParameterRandomizationSettings.structure(
245272
"notadict", Dict[str, ParameterRandomizationSettings]
246273
)
274+
275+
276+
def test_exportable_settings():
277+
"""
278+
Test that structuring and unstructuring a RunOptions object results in the same
279+
configuration representation.
280+
"""
281+
# Try to enable as many features as possible in this test YAML to hit all the
282+
# edge cases. Set as much as possible as non-default values to ensure no flukes.
283+
# TODO: Add back in environment_parameters
284+
test_yaml = """
285+
behaviors:
286+
3DBall:
287+
trainer_type: sac
288+
hyperparameters:
289+
learning_rate: 0.0004
290+
learning_rate_schedule: constant
291+
batch_size: 64
292+
buffer_size: 200000
293+
buffer_init_steps: 100
294+
tau: 0.006
295+
steps_per_update: 10.0
296+
save_replay_buffer: true
297+
init_entcoef: 0.5
298+
reward_signal_steps_per_update: 10.0
299+
network_settings:
300+
normalize: false
301+
hidden_units: 256
302+
num_layers: 3
303+
vis_encode_type: nature_cnn
304+
memory:
305+
memory_size: 1288
306+
sequence_length: 12
307+
reward_signals:
308+
extrinsic:
309+
gamma: 0.999
310+
strength: 1.0
311+
curiosity:
312+
gamma: 0.999
313+
strength: 1.0
314+
keep_checkpoints: 5
315+
max_steps: 500000
316+
time_horizon: 1000
317+
summary_freq: 12000
318+
checkpoint_interval: 1
319+
threaded: true
320+
env_settings:
321+
env_path: test_env_path
322+
env_args:
323+
- test_env_args1
324+
- test_env_args2
325+
base_port: 12345
326+
num_envs: 8
327+
seed: 12345
328+
engine_settings:
329+
width: 12345
330+
height: 12345
331+
quality_level: 12345
332+
time_scale: 12345
333+
target_frame_rate: 12345
334+
capture_frame_rate: 12345
335+
no_graphics: true
336+
checkpoint_settings:
337+
run_id: test_run_id
338+
initialize_from: test_directory
339+
load_model: false
340+
resume: true
341+
force: true
342+
train_model: false
343+
inference: false
344+
debug: true
345+
"""
346+
loaded_yaml = yaml.safe_load(test_yaml)
347+
run_options = RunOptions.from_dict(yaml.safe_load(test_yaml))
348+
dict_export = run_options.as_dict()
349+
check_dict_is_at_least(loaded_yaml, dict_export)
350+
351+
# Re-import and verify has same elements
352+
run_options2 = RunOptions.from_dict(dict_export)
353+
second_export = run_options2.as_dict()
354+
355+
check_dict_is_at_least(
356+
dict_export, second_export, exceptions=["environment_parameters"]
357+
)
358+
# Should be able to use equality instead of back-and-forth once environment_parameters
359+
# is working
360+
check_dict_is_at_least(
361+
second_export, dict_export, exceptions=["environment_parameters"]
362+
)

0 commit comments

Comments
 (0)