Commit ad27b4d

[Renaming] SetActionMask -> SetDiscreteActionMask + added the virtual method CollectDiscreteActionMasks (#3525)
* Code edits
* Modified the markdowns
* Update com.unity.ml-agents/CHANGELOG.md
  Co-Authored-By: Chris Elion <chris.elion@unity3d.com>
* Update docs/Learning-Environment-Design-Agents.md
  Co-Authored-By: Chris Elion <chris.elion@unity3d.com>
* Update docs/Learning-Environment-Design-Agents.md
  Co-Authored-By: Chris Elion <chris.elion@unity3d.com>
* Renaming files and methods
* Addressing comments
* Update docs/Learning-Environment-Design-Agents.md
  Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

Co-authored-by: Chris Elion <celion@gmail.com>
1 parent c0a2b29 commit ad27b4d

8 files changed: 73 additions and 145 deletions

Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs

Lines changed: 21 additions & 32 deletions
@@ -32,46 +32,35 @@ public override void InitializeAgent()
     {
     }
 
-    public override void CollectObservations(VectorSensor sensor, ActionMasker actionMasker)
+    public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
     {
-        // There are no numeric observations to collect as this environment uses visual
-        // observations.
-
         // Mask the necessary actions if selected by the user.
         if (maskActions)
         {
-            SetMask(actionMasker);
-        }
-    }
-
-    /// <summary>
-    /// Applies the mask for the agents action to disallow unnecessary actions.
-    /// </summary>
-    void SetMask(ActionMasker actionMasker)
-    {
-        // Prevents the agent from picking an action that would make it collide with a wall
-        var positionX = (int)transform.position.x;
-        var positionZ = (int)transform.position.z;
-        var maxPosition = (int)Academy.Instance.FloatProperties.GetPropertyWithDefault("gridSize", 5f) - 1;
+            // Prevents the agent from picking an action that would make it collide with a wall
+            var positionX = (int)transform.position.x;
+            var positionZ = (int)transform.position.z;
+            var maxPosition = (int)Academy.Instance.FloatProperties.GetPropertyWithDefault("gridSize", 5f) - 1;
 
-        if (positionX == 0)
-        {
-            actionMasker.SetActionMask(k_Left);
-        }
+            if (positionX == 0)
+            {
+                actionMasker.SetMask(0, new int[]{ k_Left});
+            }
 
-        if (positionX == maxPosition)
-        {
-            actionMasker.SetActionMask(k_Right);
-        }
+            if (positionX == maxPosition)
+            {
+                actionMasker.SetMask(0, new int[]{k_Right});
+            }
 
-        if (positionZ == 0)
-        {
-            actionMasker.SetActionMask(k_Down);
-        }
+            if (positionZ == 0)
+            {
+                actionMasker.SetMask(0, new int[]{k_Down});
+            }
 
-        if (positionZ == maxPosition)
-        {
-            actionMasker.SetActionMask(k_Up);
+            if (positionZ == maxPosition)
+            {
+                actionMasker.SetMask(0, new int[]{k_Up});
+            }
         }
     }
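
The four wall checks in the new `CollectDiscreteActionMasks` could also be gathered into one list and applied with a single `SetMask` call. The sketch below is a standalone illustration and not code from this commit: `GridMasks.WallActions` is a hypothetical helper, and the numeric values given to `k_Up`, `k_Down`, `k_Left` and `k_Right` are assumptions, since the diff does not show the real constants.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper, not part of this commit.
public static class GridMasks
{
    // Assumed action values; the diff does not show the real constants.
    public const int k_Up = 1;
    public const int k_Down = 2;
    public const int k_Left = 3;
    public const int k_Right = 4;

    // Returns the actions that would move the agent into a wall, given its
    // integer grid position and the largest valid coordinate on each axis.
    public static List<int> WallActions(int positionX, int positionZ, int maxPosition)
    {
        var masked = new List<int>();
        if (positionX == 0) masked.Add(k_Left);
        if (positionX == maxPosition) masked.Add(k_Right);
        if (positionZ == 0) masked.Add(k_Down);
        if (positionZ == maxPosition) masked.Add(k_Up);
        return masked;
    }
}

public static class Program
{
    public static void Main()
    {
        // Corner (0, 0) of a 5x5 grid: moving left or down would hit a wall.
        var masked = GridMasks.WallActions(0, 0, 4);
        Console.WriteLine(string.Join(",", masked)); // prints 3,2

        // Inside the agent, the list could then be applied in one call:
        //   if (masked.Count > 0) actionMasker.SetMask(0, masked);
    }
}
```

Keeping the position logic in a pure function makes it possible to unit test the masking rules without a running Unity scene.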

com.unity.ml-agents/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 
 ## [Unreleased]
 ### Major Changes
- - Agent.CollectObservations now takes a VectorSensor argument. It was also overloaded to optionally take an ActionMasker argument. (#3352, #3389)
+ - `Agent.CollectObservations` now takes a VectorSensor argument. (#3352, #3389)
+ - Added `Agent.CollectDiscreteActionMasks` virtual method with a `DiscreteActionMasker` argument to specify which discrete actions are unavailable to the Agent. (#3525)
  - Beta support for ONNX export was added. If the `tf2onnx` python package is installed, models will be saved to `.onnx` as well as `.nn` format.
  Note that Barracuda 0.6.0 or later is required to import the `.onnx` files properly
  - Multi-GPU training and the `--multi-gpu` option has been removed temporarily. (#3345)

com.unity.ml-agents/Runtime/Agent.cs

Lines changed: 16 additions & 43 deletions
@@ -59,8 +59,7 @@ internal struct AgentAction
     /// an Agent. An agent produces observations and takes actions in the
     /// environment. Observations are determined by the cameras attached
     /// to the agent in addition to the vector observations implemented by the
-    /// user in <see cref="Agent.CollectObservations(VectorSensor)"/> or
-    /// <see cref="Agent.CollectObservations(VectorSensor, ActionMasker)"/>.
+    /// user in <see cref="Agent.CollectObservations(VectorSensor)"/>.
     /// On the other hand, actions are determined by decisions produced by a Policy.
     /// Currently, this class is expected to be extended to implement the desired agent behavior.
     /// </summary>
@@ -173,7 +172,7 @@ internal struct AgentParameters
         bool m_Initialized;
 
         /// Keeps track of the actions that are masked at each step.
-        ActionMasker m_ActionMasker;
+        DiscreteActionMasker m_ActionMasker;
 
         /// <summary>
         /// Set of DemonstrationWriters that the Agent will write its step information to.
@@ -434,7 +433,7 @@ public void RequestAction()
         void ResetData()
         {
             var param = m_PolicyFactory.brainParameters;
-            m_ActionMasker = new ActionMasker(param);
+            m_ActionMasker = new DiscreteActionMasker(param);
             // If we haven't initialized vectorActions, initialize to 0. This should only
             // happen during the creation of the Agent. In subsequent episodes, vectorAction
             // should stay the previous action before the Done(), so that it is properly recorded.
@@ -549,7 +548,14 @@ void SendInfoToBrain()
             UpdateSensors();
             using (TimerStack.Instance.Scoped("CollectObservations"))
             {
-                CollectObservations(collectObservationsSensor, m_ActionMasker);
+                CollectObservations(collectObservationsSensor);
+            }
+            using (TimerStack.Instance.Scoped("CollectDiscreteActionMasks"))
+            {
+                if (m_PolicyFactory.brainParameters.vectorActionSpaceType == SpaceType.Discrete)
+                {
+                    CollectDiscreteActionMasks(m_ActionMasker);
+                }
             }
             m_Info.actionMasks = m_ActionMasker.GetMask();
 
@@ -612,51 +618,18 @@ public virtual void CollectObservations(VectorSensor sensor)
         }
 
         /// <summary>
-        /// Collects the vector observations of the agent alongside the masked actions.
-        /// The agent observation describes the current environment from the
-        /// perspective of the agent.
+        /// Collects the masks for discrete actions.
+        /// When using discrete actions, the agent will not perform the masked action.
         /// </summary>
-        /// <param name="sensor">
-        /// The vector observations for the agent.
-        /// </param>
         /// <param name="actionMasker">
-        /// The masked actions for the agent.
+        /// The action masker for the agent.
         /// </param>
         /// <remarks>
-        /// An agents observation is any environment information that helps
-        /// the Agent achieve its goal. For example, for a fighting Agent, its
-        /// observation could include distances to friends or enemies, or the
-        /// current level of ammunition at its disposal.
-        /// Recall that an Agent may attach vector or visual observations.
-        /// Vector observations are added by calling the provided helper methods
-        /// on the VectorSensor input:
-        /// - <see cref="VectorSensor.AddObservation(int)"/>
-        /// - <see cref="VectorSensor.AddObservation(float)"/>
-        /// - <see cref="VectorSensor.AddObservation(Vector3)"/>
-        /// - <see cref="VectorSensor.AddObservation(Vector2)"/>
-        /// - <see cref="VectorSensor.AddObservation(Quaternion)"/>
-        /// - <see cref="VectorSensor.AddObservation(bool)"/>
-        /// - <see cref="VectorSensor.AddObservation(IEnumerable{float})"/>
-        /// - <see cref="VectorSensor.AddOneHotObservation(int, int)"/>
-        /// Depending on your environment, any combination of these helpers can
-        /// be used. They just need to be used in the exact same order each time
-        /// this method is called and the resulting size of the vector observation
-        /// needs to match the vectorObservationSize attribute of the linked Brain.
-        /// Visual observations are implicitly added from the cameras attached to
-        /// the Agent.
         /// When using Discrete Control, you can prevent the Agent from using a certain
-        /// action by masking it. You can call the following method on the ActionMasker
-        /// input :
-        /// - <see cref="ActionMasker.SetActionMask(int)"/>
-        /// - <see cref="ActionMasker.SetActionMask(int, int)"/>
-        /// - <see cref="ActionMasker.SetActionMask(int, IEnumerable{int})"/>
-        /// - <see cref="ActionMasker.SetActionMask(IEnumerable{int})"/>
-        /// The branch input is the index of the action, actionIndices are the indices of the
-        /// invalid options for that action.
+        /// action by masking it with <see cref="DiscreteActionMasker.SetMask(int, IEnumerable{int})"/>
         /// </remarks>
-        public virtual void CollectObservations(VectorSensor sensor, ActionMasker actionMasker)
+        public virtual void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
         {
-            CollectObservations(sensor);
         }
 
         /// <summary>

com.unity.ml-agents/Runtime/ActionMasker.cs renamed to com.unity.ml-agents/Runtime/DiscreteActionMasker.cs

Lines changed: 3 additions & 39 deletions
@@ -11,7 +11,7 @@ namespace MLAgents
     /// left side of the board). This class represents the set of masked actions and provides
     /// the utilities for setting and retrieving them.
     /// </summary>
-    public class ActionMasker
+    public class DiscreteActionMasker
     {
         /// When using discrete control, is the starting indices of the actions
         /// when all the branches are concatenated with each other.
@@ -21,47 +21,11 @@ public class ActionMasker
 
         readonly BrainParameters m_BrainParameters;
 
-        internal ActionMasker(BrainParameters brainParameters)
+        internal DiscreteActionMasker(BrainParameters brainParameters)
         {
             m_BrainParameters = brainParameters;
         }
 
-        /// <summary>
-        /// Sets an action mask for discrete control agents. When used, the agent will not be
-        /// able to perform the actions passed as argument at the next decision.
-        /// The actionIndices correspond to the actions the agent will be unable to perform
-        /// on the branch 0.
-        /// </summary>
-        /// <param name="actionIndices">The indices of the masked actions on branch 0.</param>
-        public void SetActionMask(IEnumerable<int> actionIndices)
-        {
-            SetActionMask(0, actionIndices);
-        }
-
-        /// <summary>
-        /// Sets an action mask for discrete control agents. When used, the agent will not be
-        /// able to perform the action passed as argument at the next decision for the specified
-        /// action branch. The actionIndex correspond to the action the agent will be unable
-        /// to perform.
-        /// </summary>
-        /// <param name="branch">The branch for which the actions will be masked.</param>
-        /// <param name="actionIndex">The index of the masked action.</param>
-        public void SetActionMask(int branch, int actionIndex)
-        {
-            SetActionMask(branch, new[] { actionIndex });
-        }
-
-        /// <summary>
-        /// Sets an action mask for discrete control agents. When used, the agent will not be
-        /// able to perform the action passed as argument at the next decision. The actionIndex
-        /// correspond to the action the agent will be unable to perform on the branch 0.
-        /// </summary>
-        /// <param name="actionIndex">The index of the masked action on branch 0</param>
-        public void SetActionMask(int actionIndex)
-        {
-            SetActionMask(0, new[] { actionIndex });
-        }
-
         /// <summary>
         /// Modifies an action mask for discrete control agents. When used, the agent will not be
         /// able to perform the actions passed as argument at the next decision for the specified
@@ -70,7 +34,7 @@ public void SetActionMask(int actionIndex)
         /// </summary>
         /// <param name="branch">The branch for which the actions will be masked</param>
         /// <param name="actionIndices">The indices of the masked actions</param>
-        public void SetActionMask(int branch, IEnumerable<int> actionIndices)
+        public void SetMask(int branch, IEnumerable<int> actionIndices)
         {
             // If the branch does not exist, raise an error
             if (branch >= m_BrainParameters.vectorActionSize.Length)

com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs

Lines changed: 20 additions & 20 deletions
@@ -8,7 +8,7 @@ public class EditModeTestActionMasker
         public void Contruction()
         {
             var bp = new BrainParameters();
-            var masker = new ActionMasker(bp);
+            var masker = new DiscreteActionMasker(bp);
             Assert.IsNotNull(masker);
         }
 
@@ -18,8 +18,8 @@ public void FailsWithContinuous()
             var bp = new BrainParameters();
             bp.vectorActionSpaceType = SpaceType.Continuous;
             bp.vectorActionSize = new[] {4};
-            var masker = new ActionMasker(bp);
-            masker.SetActionMask(0, new[] {0});
+            var masker = new DiscreteActionMasker(bp);
+            masker.SetMask(0, new[] {0});
             Assert.Catch<UnityAgentsException>(() => masker.GetMask());
         }
 
@@ -28,7 +28,7 @@ public void NullMask()
         {
             var bp = new BrainParameters();
             bp.vectorActionSpaceType = SpaceType.Discrete;
-            var masker = new ActionMasker(bp);
+            var masker = new DiscreteActionMasker(bp);
             var mask = masker.GetMask();
             Assert.IsNull(mask);
         }
@@ -39,10 +39,10 @@ public void FirstBranchMask()
             var bp = new BrainParameters();
             bp.vectorActionSpaceType = SpaceType.Discrete;
             bp.vectorActionSize = new[] {4, 5, 6};
-            var masker = new ActionMasker(bp);
+            var masker = new DiscreteActionMasker(bp);
             var mask = masker.GetMask();
             Assert.IsNull(mask);
-            masker.SetActionMask(0, new[] {1, 2, 3});
+            masker.SetMask(0, new[] {1, 2, 3});
             mask = masker.GetMask();
             Assert.IsFalse(mask[0]);
             Assert.IsTrue(mask[1]);
@@ -60,8 +60,8 @@ public void SecondBranchMask()
                 vectorActionSpaceType = SpaceType.Discrete,
                 vectorActionSize = new[] { 4, 5, 6 }
             };
-            var masker = new ActionMasker(bp);
-            masker.SetActionMask(1, new[] {1, 2, 3});
+            var masker = new DiscreteActionMasker(bp);
+            masker.SetMask(1, new[] {1, 2, 3});
             var mask = masker.GetMask();
             Assert.IsFalse(mask[0]);
             Assert.IsFalse(mask[4]);
@@ -80,8 +80,8 @@ public void MaskReset()
                 vectorActionSpaceType = SpaceType.Discrete,
                 vectorActionSize = new[] { 4, 5, 6 }
             };
-            var masker = new ActionMasker(bp);
-            masker.SetActionMask(1, new[] {1, 2, 3});
+            var masker = new DiscreteActionMasker(bp);
+            masker.SetMask(1, new[] {1, 2, 3});
             masker.ResetMask();
             var mask = masker.GetMask();
             for (var i = 0; i < 15; i++)
@@ -98,18 +98,18 @@ public void ThrowsError()
                 vectorActionSpaceType = SpaceType.Discrete,
                 vectorActionSize = new[] { 4, 5, 6 }
             };
-            var masker = new ActionMasker(bp);
+            var masker = new DiscreteActionMasker(bp);
 
             Assert.Catch<UnityAgentsException>(
-                () => masker.SetActionMask(0, new[] {5}));
+                () => masker.SetMask(0, new[] {5}));
             Assert.Catch<UnityAgentsException>(
-                () => masker.SetActionMask(1, new[] {5}));
-            masker.SetActionMask(2, new[] {5});
+                () => masker.SetMask(1, new[] {5}));
+            masker.SetMask(2, new[] {5});
             Assert.Catch<UnityAgentsException>(
-                () => masker.SetActionMask(3, new[] {1}));
+                () => masker.SetMask(3, new[] {1}));
             masker.GetMask();
             masker.ResetMask();
-            masker.SetActionMask(0, new[] {0, 1, 2, 3});
+            masker.SetMask(0, new[] {0, 1, 2, 3});
             Assert.Catch<UnityAgentsException>(
                 () => masker.GetMask());
         }
@@ -120,10 +120,10 @@ public void MultipleMaskEdit()
             var bp = new BrainParameters();
             bp.vectorActionSpaceType = SpaceType.Discrete;
             bp.vectorActionSize = new[] {4, 5, 6};
-            var masker = new ActionMasker(bp);
-            masker.SetActionMask(0, new[] {0, 1});
-            masker.SetActionMask(0, new[] {3});
-            masker.SetActionMask(2, new[] {1});
+            var masker = new DiscreteActionMasker(bp);
+            masker.SetMask(0, new[] {0, 1});
+            masker.SetMask(0, new[] {3});
+            masker.SetMask(2, new[] {1});
             var mask = masker.GetMask();
             for (var i = 0; i < 15; i++)
             {
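
The assertions in these tests index one flat `bool[]` that spans every branch: with `vectorActionSize = {4, 5, 6}` the loops run over 15 entries. The sketch below reproduces that layout in standalone code, assuming the mask is the plain concatenation of branches that the class comments describe, so that action 1 on branch 1 would land at index 4 + 1 = 5. `FlatMask.Build` is a hypothetical helper written for illustration; it is not the package's `DiscreteActionMasker` implementation.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper that mirrors the layout the tests rely on.
public static class FlatMask
{
    // branchSizes: number of actions per branch, e.g. {4, 5, 6}.
    // masked: (branch, action) pairs to disable.
    public static bool[] Build(int[] branchSizes, IEnumerable<(int branch, int action)> masked)
    {
        // Starting offset of each branch in the concatenated mask.
        var offsets = new int[branchSizes.Length];
        var total = 0;
        for (var b = 0; b < branchSizes.Length; b++)
        {
            offsets[b] = total;
            total += branchSizes[b];
        }

        var mask = new bool[total]; // false means the action stays available
        foreach (var (branch, action) in masked)
        {
            if (branch < 0 || branch >= branchSizes.Length)
                throw new ArgumentOutOfRangeException(nameof(masked), "unknown branch");
            if (action < 0 || action >= branchSizes[branch])
                throw new ArgumentOutOfRangeException(nameof(masked), "unknown action");
            mask[offsets[branch] + action] = true;
        }
        return mask;
    }
}

public static class Program
{
    public static void Main()
    {
        // Same setup as SecondBranchMask: actions 1, 2 and 3 on branch 1.
        var mask = FlatMask.Build(new[] { 4, 5, 6 }, new[] { (1, 1), (1, 2), (1, 3) });
        Console.WriteLine(mask.Length);               // prints 15
        Console.WriteLine(Array.IndexOf(mask, true)); // prints 5
    }
}
```

Under this assumption, indices 0 to 3 belong to branch 0, 4 to 8 to branch 1, and 9 to 14 to branch 2, which matches `mask[0]` and `mask[4]` being false in `SecondBranchMask`.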

docs/Learning-Environment-Design-Agents.md

Lines changed: 7 additions & 7 deletions
@@ -390,20 +390,20 @@ impossible for the next decision. When the Agent is controlled by a
 neural network, the Agent will be unable to perform the specified action. Note
 that when the Agent is controlled by its Heuristic, the Agent will
 still be able to decide to perform the masked action. In order to mask an
-action, call the method `SetActionMask` on the optional `ActionMasker` argument of the `CollectObservation` method :
+action, override the `Agent.CollectDiscreteActionMasks()` virtual method, and call `DiscreteActionMasker.SetMask()` in it:
 
 ```csharp
-public override void CollectObservations(VectorSensor sensor, ActionMasker actionMasker){
-    actionMasker.SetActionMask(branch, actionIndices)
+public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker){
+    actionMasker.SetMask(branch, actionIndices)
 }
 ```
 
 Where:
 
 * `branch` is the index (starting at 0) of the branch on which you want to mask
   the action
-* `actionIndices` is a list of `int` or a single `int` corresponding to the
-  index of the action that the Agent cannot perform.
+* `actionIndices` is a list of `int` corresponding to the
+  indices of the actions that the Agent cannot perform.
 
 For example, if you have an Agent with 2 branches and on the first branch
 (branch 0) there are 4 possible actions : _"do nothing"_, _"jump"_, _"shoot"_
@@ -412,12 +412,12 @@ nothing"_ or _"change weapon"_ for his next decision (since action index 1 and 2
 are masked)
 
 ```csharp
-SetActionMask(0, new int[2]{1,2})
+SetMask(0, new int[2]{1,2})
 ```
 
 Notes:
 
-* You can call `SetActionMask` multiple times if you want to put masks on
+* You can call `SetMask` multiple times if you want to put masks on
   multiple branches.
 * You cannot mask all the actions of a branch.
 * You cannot mask actions in continuous control.
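
The note that a branch cannot be fully masked can be checked before calling `SetMask`. The sketch below is a standalone illustration under stated assumptions: `MaskGuard.LeavesAnAction` is a hypothetical helper, not an ML-Agents API, and it only looks at the indices passed in one call, so masks accumulated by earlier `SetMask` calls on the same branch would have to be merged by the caller first.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical guard, not part of ML-Agents.
public static class MaskGuard
{
    // True when masking actionIndices still leaves at least one
    // action of a branch with branchSize actions available.
    public static bool LeavesAnAction(int branchSize, IEnumerable<int> actionIndices)
    {
        // Count each in-range index once, so duplicates do not inflate the total.
        var maskedCount = actionIndices
            .Where(i => i >= 0 && i < branchSize)
            .Distinct()
            .Count();
        return maskedCount < branchSize;
    }
}

public static class Program
{
    public static void Main()
    {
        // Branch 0 from the example above has 4 actions.
        Console.WriteLine(MaskGuard.LeavesAnAction(4, new[] { 1, 2 }));       // prints True
        Console.WriteLine(MaskGuard.LeavesAnAction(4, new[] { 0, 1, 2, 3 })); // prints False
    }
}
```

In an agent, such a check could wrap the call, for example `if (MaskGuard.LeavesAnAction(4, indices)) actionMasker.SetMask(0, indices);`.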
