Skip to content

Commit 6af4d21

Browse files
cmard authored and
maryamhonari committed
Return deterministic actions for training (#5615)
* Added more stable test.
* Fix the tests.
* Fix pre-commit
* Fix help line to pass precommit.
1 parent d23a719 commit 6af4d21

File tree

2 files changed

+9
-10
lines changed

2 files changed

+9
-10
lines changed

ml-agents/mlagents/trainers/cli_utils.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,8 @@ def _create_parser() -> argparse.ArgumentParser:
9696
default=False,
9797
dest="deterministic",
9898
action=DetectDefaultStoreTrue,
99-
help="Whether to select actions deterministically in policy. `dist.mean` for continuous action space, and `dist.argmax` for deterministic action space ",
99+
help="Whether to select actions deterministically in policy. `dist.mean` for continuous action "
100+
"space, and `dist.argmax` for deterministic action space ",
100101
)
101102
argparser.add_argument(
102103
"--force",

ml-agents/mlagents/trainers/tests/torch/test_action_model.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,24 +65,22 @@ def test_deterministic_sample_action():
6565
agent_action2 = action_model._sample_action(dists)
6666
agent_action3 = action_model._sample_action(dists)
6767

68-
chance_counter = 0
69-
70-
if not torch.equal(
68+
assert not torch.equal(
7169
agent_action1.continuous_tensor, agent_action2.continuous_tensor
72-
):
73-
chance_counter += 1
70+
)
7471

75-
if not torch.equal(
72+
assert not torch.equal(
7673
agent_action1.continuous_tensor, agent_action3.continuous_tensor
77-
):
78-
chance_counter += 1
74+
)
7975

80-
assert chance_counter > 1
8176
chance_counter = 0
8277
if not torch.equal(agent_action1.discrete_tensor, agent_action2.discrete_tensor):
8378
chance_counter += 1
8479
if not torch.equal(agent_action1.discrete_tensor, agent_action3.discrete_tensor):
8580
chance_counter += 1
81+
if not torch.equal(agent_action2.discrete_tensor, agent_action3.discrete_tensor):
82+
chance_counter += 1
83+
8684
assert chance_counter > 1
8785

8886

0 commit comments

Comments
 (0)