Skip to content

Commit 06a6f4f

Browse files
author
Chris Elion
authored
[MLA-1145] don't allow --num-envs >1 with no --env (#4209)
* don't allow --num-envs >1 with no --env (#4203)
1 parent 1ca557e commit 06a6f4f

File tree

3 files changed

+24
-1
lines changed

3 files changed

+24
-1
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ empty string). (#4155)
4545
- Fixed issue with FoodCollector, Soccer, and WallJump when playing with keyboard. (#4147, #4174)
4646
- Fixed a crash in StatsReporter when using threaded trainers with very frequent summary writes
4747
(#4201)
48+
- `mlagents-learn` will now raise an error immediately if `--num-envs` is greater than 1 without setting the `--env`
49+
argument. (#4203)
4850

4951
## [1.1.0-preview] - 2020-06-10
5052
### Major Changes

ml-agents/mlagents/trainers/settings.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,9 +600,14 @@ class EnvironmentSettings:
600600
env_path: Optional[str] = parser.get_default("env_path")
601601
env_args: Optional[List[str]] = parser.get_default("env_args")
602602
base_port: int = parser.get_default("base_port")
603-
num_envs: int = parser.get_default("num_envs")
603+
num_envs: int = attr.ib(default=parser.get_default("num_envs"))
604604
seed: int = parser.get_default("seed")
605605

606+
@num_envs.validator
607+
def validate_num_envs(self, attribute, value):
608+
if value > 1 and self.env_path is None:
609+
raise ValueError("num_envs must be 1 if env_path is not set.")
610+
606611

607612
@attr.s(auto_attribs=True)
608613
class EngineSettings:

ml-agents/mlagents/trainers/tests/test_settings.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
RewardSignalType,
1414
RewardSignalSettings,
1515
CuriositySettings,
16+
EnvironmentSettings,
1617
EnvironmentParameterSettings,
1718
ConstantSettings,
1819
UniformSettings,
@@ -452,3 +453,18 @@ def test_exportable_settings(use_defaults):
452453
check_dict_is_at_least(second_export, dict_export)
453454
# Check that the two exports are the same
454455
assert dict_export == second_export
456+
457+
458+
def test_environment_settings():
459+
# default args
460+
EnvironmentSettings()
461+
462+
# 1 env is OK if no env_path
463+
EnvironmentSettings(num_envs=1)
464+
465+
# multiple envs is OK if env_path is set
466+
EnvironmentSettings(num_envs=42, env_path="/foo/bar.exe")
467+
468+
# Multiple environments with no env_path is an error
469+
with pytest.raises(ValueError):
470+
EnvironmentSettings(num_envs=2)

0 commit comments

Comments
 (0)