Skip to content

Commit 7f74c08

Browse files
author
Chris Elion
authored
non-IEnumerable interface for action masking (#5060)
1 parent a2d6d79 commit 7f74c08

File tree

12 files changed

+149
-96
lines changed

12 files changed

+149
-96
lines changed

Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,22 @@ public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
4848

4949
if (positionX == 0)
5050
{
51-
actionMask.WriteMask(0, new[] { k_Left });
51+
actionMask.SetActionEnabled(0, k_Left, false);
5252
}
5353

5454
if (positionX == maxPosition)
5555
{
56-
actionMask.WriteMask(0, new[] { k_Right });
56+
actionMask.SetActionEnabled(0, k_Right, false);
5757
}
5858

5959
if (positionZ == 0)
6060
{
61-
actionMask.WriteMask(0, new[] { k_Down });
61+
actionMask.SetActionEnabled(0, k_Down, false);
6262
}
6363

6464
if (positionZ == maxPosition)
6565
{
66-
actionMask.WriteMask(0, new[] { k_Up });
66+
actionMask.SetActionEnabled(0, k_Up, false);
6767
}
6868
}
6969
}

com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ public class Match3Actuator : IActuator, IHeuristicProvider, IBuiltInActuator
3131
/// <param name="agent"></param>
3232
/// <param name="name"></param>
3333
public Match3Actuator(AbstractBoard board,
34-
bool forceHeuristic,
35-
int seed,
36-
Agent agent,
37-
string name)
34+
bool forceHeuristic,
35+
int seed,
36+
Agent agent,
37+
string name)
3838
{
3939
m_Board = board;
4040
m_Rows = board.Rows;
@@ -78,34 +78,27 @@ public void OnActionReceived(ActionBuffers actions)
7878
/// <inheritdoc/>
7979
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
8080
{
81+
const int branch = 0;
82+
bool foundValidMove = false;
8183
using (TimerStack.Instance.Scoped("WriteDiscreteActionMask"))
8284
{
83-
actionMask.WriteMask(0, InvalidMoveIndices());
84-
}
85-
}
86-
87-
/// <inheritdoc/>
88-
public string Name { get; }
85+
var numMoves = m_Board.NumMoves();
8986

90-
/// <inheritdoc/>
91-
public void ResetData()
92-
{
93-
}
94-
95-
/// <inheritdoc/>
96-
public BuiltInActuatorType GetBuiltInActuatorType()
97-
{
98-
return BuiltInActuatorType.Match3Actuator;
99-
}
100-
101-
IEnumerable<int> InvalidMoveIndices()
102-
{
103-
var numValidMoves = m_Board.NumMoves();
87+
var currentMove = Move.FromMoveIndex(0, m_Board.Rows, m_Board.Columns);
88+
for (var i = 0; i < numMoves; i++)
89+
{
90+
if (m_Board.IsMoveValid(currentMove))
91+
{
92+
foundValidMove = true;
93+
}
94+
else
95+
{
96+
actionMask.SetActionEnabled(branch, i, false);
97+
}
98+
currentMove.Next(m_Board.Rows, m_Board.Columns);
99+
}
104100

105-
foreach (var move in m_Board.InvalidMoves())
106-
{
107-
numValidMoves--;
108-
if (numValidMoves == 0)
101+
if (!foundValidMove)
109102
{
110103
// If all the moves are invalid and we mask all the actions out, this will cause an assert
111104
// later on in IDiscreteActionMask. Instead, fire a callback to the user if they provided one,
@@ -122,23 +115,33 @@ IEnumerable<int> InvalidMoveIndices()
122115
"an invalid move will be passed to AbstractBoard.MakeMove()."
123116
);
124117
}
125-
// This means the last move won't be returned as an invalid index.
126-
yield break;
118+
actionMask.SetActionEnabled(branch, numMoves - 1, true);
127119
}
128-
yield return move.MoveIndex;
129120
}
130121
}
131122

123+
/// <inheritdoc/>
124+
public string Name { get; }
125+
126+
/// <inheritdoc/>
127+
public void ResetData()
128+
{
129+
}
130+
131+
/// <inheritdoc/>
132+
public BuiltInActuatorType GetBuiltInActuatorType()
133+
{
134+
return BuiltInActuatorType.Match3Actuator;
135+
}
136+
132137
public void Heuristic(in ActionBuffers actionsOut)
133138
{
134139
var discreteActions = actionsOut.DiscreteActions;
135140
discreteActions[0] = GreedyMove();
136141
}
137142

138-
139143
protected int GreedyMove()
140144
{
141-
142145
var bestMoveIndex = 0;
143146
var bestMovePoints = -1;
144147
var numMovesAtCurrentScore = 0;

com.unity.ml-agents/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ and this project adheres to
1010
### Major Changes
1111
#### com.unity.ml-agents (C#)
1212
- Some methods previously marked as `Obsolete` have been removed. If you were using these methods, you need to replace them with their supported counterpart.
13+
- The interface for disabling discrete actions in `IDiscreteActionMask` has changed.
14+
`WriteMask(int branch, IEnumerable<int> actionIndices)` was replaced with
15+
`SetActionEnabled(int branch, int actionIndex, bool isEnabled)`. See the
16+
[Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_14_docs/docs/Migrating.md) for more
17+
details. (#5060)
1318
#### ml-agents / ml-agents-envs / gym-unity (Python)
1419

1520
### Minor Changes

com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,22 +35,17 @@ internal ActuatorDiscreteActionMask(IList<IActuator> actuators, int sumOfDiscret
3535
}
3636

3737
/// <inheritdoc/>
38-
public void WriteMask(int branch, IEnumerable<int> actionIndices)
38+
public void SetActionEnabled(int branch, int actionIndex, bool isEnabled)
3939
{
4040
LazyInitialize();
41-
42-
// Perform the masking
43-
foreach (var actionIndex in actionIndices)
44-
{
4541
#if DEBUG
46-
if (branch >= m_NumBranches || actionIndex >= m_BranchSizes[CurrentBranchOffset + branch])
47-
{
48-
throw new UnityAgentsException(
49-
"Invalid Action Masking: Action Mask is too large for specified branch.");
50-
}
51-
#endif
52-
m_CurrentMask[actionIndex + m_StartingActionIndices[CurrentBranchOffset + branch]] = true;
42+
if (branch >= m_NumBranches || actionIndex >= m_BranchSizes[CurrentBranchOffset + branch])
43+
{
44+
throw new UnityAgentsException(
45+
"Invalid Action Masking: Action Mask is too large for specified branch.");
5346
}
47+
#endif
48+
m_CurrentMask[actionIndex + m_StartingActionIndices[CurrentBranchOffset + branch]] = !isEnabled;
5449
}
5550

5651
void LazyInitialize()
@@ -83,8 +78,12 @@ void LazyInitialize()
8378
}
8479
}
8580

86-
/// <inheritdoc/>
87-
public bool[] GetMask()
81+
/// <summary>
82+
/// Get the current mask for an agent.
83+
/// </summary>
84+
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
85+
/// actions.</returns>
86+
internal bool[] GetMask()
8887
{
8988
#if DEBUG
9089
if (m_CurrentMask != null)
@@ -116,7 +115,7 @@ void AssertMask()
116115
/// <summary>
117116
/// Resets the current mask for an agent.
118117
/// </summary>
119-
public void ResetMask()
118+
internal void ResetMask()
120119
{
121120
if (m_CurrentMask != null)
122121
{

com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ public interface IActionReceiver
173173
/// </param>
174174
/// <remarks>
175175
/// When using Discrete Control, you can prevent the Agent from using a certain
176-
/// action by masking it with <see cref="IDiscreteActionMask.WriteMask"/>.
176+
/// action by masking it with <see cref="IDiscreteActionMask.SetActionEnabled"/>.
177177
///
178178
/// See [Agents - Actions] for more information on masking actions.
179179
///

com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,20 @@ namespace Unity.MLAgents.Actuators
88
public interface IDiscreteActionMask
99
{
1010
/// <summary>
11-
/// Modifies an action mask for discrete control agents.
11+
/// Set whether or not the action index for the given branch is allowed.
1212
/// </summary>
13-
/// <remarks>
14-
/// When used, the agent will not be able to perform the actions passed as argument
15-
/// at the next decision for the specified action branch. The actionIndices correspond
13+
/// <remarks>By default, all discrete actions are allowed.
14+
/// If isEnabled is false, the agent will not be able to perform the action passed as argument
15+
/// at the next decision for the specified action branch. The actionIndex corresponds
1616
/// to the action options the agent will be unable to perform.
1717
///
1818
/// See [Agents - Actions] for more information on masking actions.
1919
///
2020
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_13_docs/docs/Learning-Environment-Design-Agents.md#actions
2121
/// </remarks>
2222
/// <param name="branch">The branch for which the actions will be masked.</param>
23-
/// <param name="actionIndices">The indices of the masked actions.</param>
24-
void WriteMask(int branch, IEnumerable<int> actionIndices);
25-
26-
/// <summary>
27-
/// Get the current mask for an agent.
28-
/// </summary>
29-
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
30-
/// actions.</returns>
31-
bool[] GetMask();
32-
33-
/// <summary>
34-
/// Resets the current mask for an agent.
35-
/// </summary>
36-
void ResetMask();
23+
/// <param name="actionIndex">Index of the action</param>
24+
/// <param name="isEnabled">Whether the action is allowed or not.</param>
25+
void SetActionEnabled(int branch, int actionIndex, bool isEnabled);
3726
}
3827
}

com.unity.ml-agents/Runtime/Agent.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1178,7 +1178,7 @@ public ReadOnlyCollection<float> GetObservations()
11781178
/// </param>
11791179
/// <remarks>
11801180
/// When using Discrete Control, you can prevent the Agent from using a certain
1181-
/// action by masking it with <see cref="IDiscreteActionMask.WriteMask(int, IEnumerable{int})"/>.
1181+
/// action by masking it with <see cref="IDiscreteActionMask.SetActionEnabled"/>.
11821182
///
11831183
/// See [Agents - Actions] for more information on masking actions.
11841184
///

com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ public void FirstBranchMask()
2929
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
3030
var mask = masker.GetMask();
3131
Assert.IsNull(mask);
32-
masker.WriteMask(0, new[] { 1, 2, 3 });
32+
masker.SetActionEnabled(0, 1, false);
33+
masker.SetActionEnabled(0, 2, false);
34+
masker.SetActionEnabled(0, 3, false);
3335
mask = masker.GetMask();
3436
Assert.IsFalse(mask[0]);
3537
Assert.IsTrue(mask[1]);
@@ -39,12 +41,27 @@ public void FirstBranchMask()
3941
Assert.AreEqual(mask.Length, 15);
4042
}
4143

44+
[Test]
45+
public void CanOverwriteMask()
46+
{
47+
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
48+
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
49+
masker.SetActionEnabled(0, 1, false);
50+
var mask = masker.GetMask();
51+
Assert.IsTrue(mask[1]);
52+
53+
masker.SetActionEnabled(0, 1, true);
54+
Assert.IsFalse(mask[1]);
55+
}
56+
4257
[Test]
4358
public void SecondBranchMask()
4459
{
4560
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
4661
var masker = new ActuatorDiscreteActionMask(new[] { actuator1 }, 15, 3);
47-
masker.WriteMask(1, new[] { 1, 2, 3 });
62+
masker.SetActionEnabled(1, 1, false);
63+
masker.SetActionEnabled(1, 2, false);
64+
masker.SetActionEnabled(1, 3, false);
4865
var mask = masker.GetMask();
4966
Assert.IsFalse(mask[0]);
5067
Assert.IsFalse(mask[4]);
@@ -60,7 +77,9 @@ public void MaskReset()
6077
{
6178
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
6279
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
63-
masker.WriteMask(1, new[] { 1, 2, 3 });
80+
masker.SetActionEnabled(1, 1, false);
81+
masker.SetActionEnabled(1, 2, false);
82+
masker.SetActionEnabled(1, 3, false);
6483
masker.ResetMask();
6584
var mask = masker.GetMask();
6685
for (var i = 0; i < 15; i++)
@@ -75,15 +94,18 @@ public void ThrowsError()
7594
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
7695
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
7796
Assert.Catch<UnityAgentsException>(
78-
() => masker.WriteMask(0, new[] { 5 }));
97+
() => masker.SetActionEnabled(0, 5, false));
7998
Assert.Catch<UnityAgentsException>(
80-
() => masker.WriteMask(1, new[] { 5 }));
81-
masker.WriteMask(2, new[] { 5 });
99+
() => masker.SetActionEnabled(1, 5, false));
100+
masker.SetActionEnabled(2, 5, false);
82101
Assert.Catch<UnityAgentsException>(
83-
() => masker.WriteMask(3, new[] { 1 }));
102+
() => masker.SetActionEnabled(3, 1, false));
84103
masker.GetMask();
85104
masker.ResetMask();
86-
masker.WriteMask(0, new[] { 0, 1, 2, 3 });
105+
masker.SetActionEnabled(0, 0, false);
106+
masker.SetActionEnabled(0, 1, false);
107+
masker.SetActionEnabled(0, 2, false);
108+
masker.SetActionEnabled(0, 3, false);
87109
Assert.Catch<UnityAgentsException>(
88110
() => masker.GetMask());
89111
}
@@ -93,9 +115,10 @@ public void MultipleMaskEdit()
93115
{
94116
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
95117
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
96-
masker.WriteMask(0, new[] { 0, 1 });
97-
masker.WriteMask(0, new[] { 3 });
98-
masker.WriteMask(2, new[] { 1 });
118+
masker.SetActionEnabled(0, 0, false);
119+
masker.SetActionEnabled(0, 1, false);
120+
masker.SetActionEnabled(0, 3, false);
121+
masker.SetActionEnabled(2, 1, false);
99122
var mask = masker.GetMask();
100123
for (var i = 0; i < 15; i++)
101124
{

com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,13 @@ public void OnActionReceived(ActionBuffers actionBuffers)
2222

2323
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
2424
{
25+
2526
for (var i = 0; i < Masks.Length; i++)
2627
{
27-
actionMask.WriteMask(i, Masks[i]);
28+
foreach (var actionIndex in Masks[i])
29+
{
30+
actionMask.SetActionEnabled(i, actionIndex, false);
31+
}
2832
}
2933
}
3034

com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ public void OnActionReceived(ActionBuffers actionBuffers)
2525

2626
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
2727
{
28-
actionMask.WriteMask(Branch, Mask);
28+
foreach (var actionIndex in Mask)
29+
{
30+
actionMask.SetActionEnabled(Branch, actionIndex, false);
31+
}
2932
}
3033

3134
public void Heuristic(in ActionBuffers actionBuffersOut)

0 commit comments

Comments
 (0)