|
1 | 1 | import attr
|
2 | 2 | import pytest
|
| 3 | +import yaml |
3 | 4 |
|
4 |
| -from typing import Dict |
| 5 | +from typing import Dict, List, Optional |
5 | 6 |
|
6 | 7 | from mlagents.trainers.settings import (
|
7 | 8 | RunOptions,
|
@@ -31,6 +32,32 @@ def check_if_different(testobj1: object, testobj2: object) -> None:
|
31 | 32 | check_if_different(val, attr.asdict(testobj2, recurse=False)[key])
|
32 | 33 |
|
33 | 34 |
|
| 35 | +def check_dict_is_at_least( |
| 36 | + testdict1: Dict, testdict2: Dict, exceptions: Optional[List[str]] = None |
| 37 | +) -> None: |
| 38 | + """ |
| 39 | + Check if everything present in the 1st dict is the same in the second dict. |
| 40 | + Excludes things that the second dict has but is not present in the heirarchy of the |
| 41 | + 1st dict. Used to compare an underspecified config dict structure (e.g. as |
| 42 | + would be provided by a user) with a complete one (e.g. as exported by RunOptions). |
| 43 | + """ |
| 44 | + for key, val in testdict1.items(): |
| 45 | + if exceptions is not None and key in exceptions: |
| 46 | + continue |
| 47 | + assert key in testdict2 |
| 48 | + if isinstance(val, dict): |
| 49 | + check_dict_is_at_least(val, testdict2[key]) |
| 50 | + elif isinstance(val, list): |
| 51 | + assert isinstance(testdict2[key], list) |
| 52 | + for _el0, _el1 in zip(val, testdict2[key]): |
| 53 | + if isinstance(_el0, dict): |
| 54 | + check_dict_is_at_least(_el0, _el1) |
| 55 | + else: |
| 56 | + assert val == testdict2[key] |
| 57 | + else: # If not a dict, don't recurse into it |
| 58 | + assert val == testdict2[key] |
| 59 | + |
| 60 | + |
34 | 61 | def test_is_new_instance():
|
35 | 62 | """
|
36 | 63 | Verify that every instance of RunOptions() and its subclasses
|
@@ -244,3 +271,92 @@ def test_parameter_randomization_structure():
|
244 | 271 | ParameterRandomizationSettings.structure(
|
245 | 272 | "notadict", Dict[str, ParameterRandomizationSettings]
|
246 | 273 | )
|
| 274 | + |
| 275 | + |
| 276 | +def test_exportable_settings(): |
| 277 | + """ |
| 278 | + Test that structuring and unstructuring a RunOptions object results in the same |
| 279 | + configuration representation. |
| 280 | + """ |
| 281 | + # Try to enable as many features as possible in this test YAML to hit all the |
| 282 | + # edge cases. Set as much as possible as non-default values to ensure no flukes. |
| 283 | + # TODO: Add back in environment_parameters |
| 284 | + test_yaml = """ |
| 285 | + behaviors: |
| 286 | + 3DBall: |
| 287 | + trainer_type: sac |
| 288 | + hyperparameters: |
| 289 | + learning_rate: 0.0004 |
| 290 | + learning_rate_schedule: constant |
| 291 | + batch_size: 64 |
| 292 | + buffer_size: 200000 |
| 293 | + buffer_init_steps: 100 |
| 294 | + tau: 0.006 |
| 295 | + steps_per_update: 10.0 |
| 296 | + save_replay_buffer: true |
| 297 | + init_entcoef: 0.5 |
| 298 | + reward_signal_steps_per_update: 10.0 |
| 299 | + network_settings: |
| 300 | + normalize: false |
| 301 | + hidden_units: 256 |
| 302 | + num_layers: 3 |
| 303 | + vis_encode_type: nature_cnn |
| 304 | + memory: |
| 305 | + memory_size: 1288 |
| 306 | + sequence_length: 12 |
| 307 | + reward_signals: |
| 308 | + extrinsic: |
| 309 | + gamma: 0.999 |
| 310 | + strength: 1.0 |
| 311 | + curiosity: |
| 312 | + gamma: 0.999 |
| 313 | + strength: 1.0 |
| 314 | + keep_checkpoints: 5 |
| 315 | + max_steps: 500000 |
| 316 | + time_horizon: 1000 |
| 317 | + summary_freq: 12000 |
| 318 | + checkpoint_interval: 1 |
| 319 | + threaded: true |
| 320 | + env_settings: |
| 321 | + env_path: test_env_path |
| 322 | + env_args: |
| 323 | + - test_env_args1 |
| 324 | + - test_env_args2 |
| 325 | + base_port: 12345 |
| 326 | + num_envs: 8 |
| 327 | + seed: 12345 |
| 328 | + engine_settings: |
| 329 | + width: 12345 |
| 330 | + height: 12345 |
| 331 | + quality_level: 12345 |
| 332 | + time_scale: 12345 |
| 333 | + target_frame_rate: 12345 |
| 334 | + capture_frame_rate: 12345 |
| 335 | + no_graphics: true |
| 336 | + checkpoint_settings: |
| 337 | + run_id: test_run_id |
| 338 | + initialize_from: test_directory |
| 339 | + load_model: false |
| 340 | + resume: true |
| 341 | + force: true |
| 342 | + train_model: false |
| 343 | + inference: false |
| 344 | + debug: true |
| 345 | + """ |
| 346 | + loaded_yaml = yaml.safe_load(test_yaml) |
| 347 | + run_options = RunOptions.from_dict(yaml.safe_load(test_yaml)) |
| 348 | + dict_export = run_options.as_dict() |
| 349 | + check_dict_is_at_least(loaded_yaml, dict_export) |
| 350 | + |
| 351 | + # Re-import and verify has same elements |
| 352 | + run_options2 = RunOptions.from_dict(dict_export) |
| 353 | + second_export = run_options2.as_dict() |
| 354 | + |
| 355 | + check_dict_is_at_least( |
| 356 | + dict_export, second_export, exceptions=["environment_parameters"] |
| 357 | + ) |
| 358 | + # Should be able to use equality instead of back-and-forth once environment_parameters |
| 359 | + # is working |
| 360 | + check_dict_is_at_least( |
| 361 | + second_export, dict_export, exceptions=["environment_parameters"] |
| 362 | + ) |
0 commit comments