|
1 | 1 | import attr
|
2 | 2 | import pytest
|
| 3 | +import yaml |
3 | 4 |
|
4 |
| -from typing import Dict |
| 5 | +from typing import Dict, List, Optional |
5 | 6 |
|
6 | 7 | from mlagents.trainers.settings import (
|
7 | 8 | RunOptions,
|
@@ -32,6 +33,32 @@ def check_if_different(testobj1: object, testobj2: object) -> None:
|
32 | 33 | check_if_different(val, attr.asdict(testobj2, recurse=False)[key])
|
33 | 34 |
|
34 | 35 |
|
| 36 | +def check_dict_is_at_least( |
| 37 | + testdict1: Dict, testdict2: Dict, exceptions: Optional[List[str]] = None |
| 38 | +) -> None: |
| 39 | + """ |
| 40 | + Check if everything present in the 1st dict is the same in the second dict. |
| 41 | + Excludes things that the second dict has but is not present in the heirarchy of the |
| 42 | + 1st dict. Used to compare an underspecified config dict structure (e.g. as |
| 43 | + would be provided by a user) with a complete one (e.g. as exported by RunOptions). |
| 44 | + """ |
| 45 | + for key, val in testdict1.items(): |
| 46 | + if exceptions is not None and key in exceptions: |
| 47 | + continue |
| 48 | + assert key in testdict2 |
| 49 | + if isinstance(val, dict): |
| 50 | + check_dict_is_at_least(val, testdict2[key]) |
| 51 | + elif isinstance(val, list): |
| 52 | + assert isinstance(testdict2[key], list) |
| 53 | + for _el0, _el1 in zip(val, testdict2[key]): |
| 54 | + if isinstance(_el0, dict): |
| 55 | + check_dict_is_at_least(_el0, _el1) |
| 56 | + else: |
| 57 | + assert val == testdict2[key] |
| 58 | + else: # If not a dict, don't recurse into it |
| 59 | + assert val == testdict2[key] |
| 60 | + |
| 61 | + |
35 | 62 | def test_is_new_instance():
|
36 | 63 | """
|
37 | 64 | Verify that every instance of RunOptions() and its subclasses
|
@@ -289,3 +316,92 @@ def test_env_parameter_structure():
|
289 | 316 | EnvironmentParameterSettings.structure(
|
290 | 317 | invalid_curriculum_dict, Dict[str, EnvironmentParameterSettings]
|
291 | 318 | )
|
| 319 | + |
| 320 | + |
| 321 | +def test_exportable_settings(): |
| 322 | + """ |
| 323 | + Test that structuring and unstructuring a RunOptions object results in the same |
| 324 | + configuration representation. |
| 325 | + """ |
| 326 | + # Try to enable as many features as possible in this test YAML to hit all the |
| 327 | + # edge cases. Set as much as possible as non-default values to ensure no flukes. |
| 328 | + # TODO: Add back in environment_parameters |
| 329 | + test_yaml = """ |
| 330 | + behaviors: |
| 331 | + 3DBall: |
| 332 | + trainer_type: sac |
| 333 | + hyperparameters: |
| 334 | + learning_rate: 0.0004 |
| 335 | + learning_rate_schedule: constant |
| 336 | + batch_size: 64 |
| 337 | + buffer_size: 200000 |
| 338 | + buffer_init_steps: 100 |
| 339 | + tau: 0.006 |
| 340 | + steps_per_update: 10.0 |
| 341 | + save_replay_buffer: true |
| 342 | + init_entcoef: 0.5 |
| 343 | + reward_signal_steps_per_update: 10.0 |
| 344 | + network_settings: |
| 345 | + normalize: false |
| 346 | + hidden_units: 256 |
| 347 | + num_layers: 3 |
| 348 | + vis_encode_type: nature_cnn |
| 349 | + memory: |
| 350 | + memory_size: 1288 |
| 351 | + sequence_length: 12 |
| 352 | + reward_signals: |
| 353 | + extrinsic: |
| 354 | + gamma: 0.999 |
| 355 | + strength: 1.0 |
| 356 | + curiosity: |
| 357 | + gamma: 0.999 |
| 358 | + strength: 1.0 |
| 359 | + keep_checkpoints: 5 |
| 360 | + max_steps: 500000 |
| 361 | + time_horizon: 1000 |
| 362 | + summary_freq: 12000 |
| 363 | + checkpoint_interval: 1 |
| 364 | + threaded: true |
| 365 | + env_settings: |
| 366 | + env_path: test_env_path |
| 367 | + env_args: |
| 368 | + - test_env_args1 |
| 369 | + - test_env_args2 |
| 370 | + base_port: 12345 |
| 371 | + num_envs: 8 |
| 372 | + seed: 12345 |
| 373 | + engine_settings: |
| 374 | + width: 12345 |
| 375 | + height: 12345 |
| 376 | + quality_level: 12345 |
| 377 | + time_scale: 12345 |
| 378 | + target_frame_rate: 12345 |
| 379 | + capture_frame_rate: 12345 |
| 380 | + no_graphics: true |
| 381 | + checkpoint_settings: |
| 382 | + run_id: test_run_id |
| 383 | + initialize_from: test_directory |
| 384 | + load_model: false |
| 385 | + resume: true |
| 386 | + force: true |
| 387 | + train_model: false |
| 388 | + inference: false |
| 389 | + debug: true |
| 390 | + """ |
| 391 | + loaded_yaml = yaml.safe_load(test_yaml) |
| 392 | + run_options = RunOptions.from_dict(yaml.safe_load(test_yaml)) |
| 393 | + dict_export = run_options.as_dict() |
| 394 | + check_dict_is_at_least(loaded_yaml, dict_export) |
| 395 | + |
| 396 | + # Re-import and verify has same elements |
| 397 | + run_options2 = RunOptions.from_dict(dict_export) |
| 398 | + second_export = run_options2.as_dict() |
| 399 | + |
| 400 | + check_dict_is_at_least( |
| 401 | + dict_export, second_export, exceptions=["environment_parameters"] |
| 402 | + ) |
| 403 | + # Should be able to use equality instead of back-and-forth once environment_parameters |
| 404 | + # is working |
| 405 | + check_dict_is_at_least( |
| 406 | + second_export, dict_export, exceptions=["environment_parameters"] |
| 407 | + ) |
0 commit comments