[Renaming] SetActionMask -> SetDiscreteActionMask + added the virtual method CollectDiscreteActionMasks #3525

Merged: 9 commits, Feb 28, 2020
53 changes: 21 additions & 32 deletions Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
```diff
@@ -32,46 +32,35 @@ public override void InitializeAgent()
     {
     }

-    public override void CollectObservations(VectorSensor sensor, ActionMasker actionMasker)
+    public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
     {
-        // There are no numeric observations to collect as this environment uses visual
-        // observations.
-
         // Mask the necessary actions if selected by the user.
         if (maskActions)
         {
-            SetMask(actionMasker);
-        }
-    }
-
-    /// <summary>
-    /// Applies the mask for the agents action to disallow unnecessary actions.
-    /// </summary>
-    void SetMask(ActionMasker actionMasker)
-    {
-        // Prevents the agent from picking an action that would make it collide with a wall
-        var positionX = (int)transform.position.x;
-        var positionZ = (int)transform.position.z;
-        var maxPosition = (int)Academy.Instance.FloatProperties.GetPropertyWithDefault("gridSize", 5f) - 1;
+            // Prevents the agent from picking an action that would make it collide with a wall
+            var positionX = (int)transform.position.x;
+            var positionZ = (int)transform.position.z;
+            var maxPosition = (int)Academy.Instance.FloatProperties.GetPropertyWithDefault("gridSize", 5f) - 1;

-        if (positionX == 0)
-        {
-            actionMasker.SetActionMask(k_Left);
-        }
+            if (positionX == 0)
+            {
+                actionMasker.SetMask(0, new int[]{ k_Left});
+            }

-        if (positionX == maxPosition)
-        {
-            actionMasker.SetActionMask(k_Right);
-        }
+            if (positionX == maxPosition)
+            {
+                actionMasker.SetMask(0, new int[]{k_Right});
+            }

-        if (positionZ == 0)
-        {
-            actionMasker.SetActionMask(k_Down);
-        }
+            if (positionZ == 0)
+            {
+                actionMasker.SetMask(0, new int[]{k_Down});
+            }

-        if (positionZ == maxPosition)
-        {
-            actionMasker.SetActionMask(k_Up);
-        }
+            if (positionZ == maxPosition)
+            {
+                actionMasker.SetMask(0, new int[]{k_Up});
+            }
+        }
     }
```
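Outside Unity, the edge-masking rule this file implements can be sketched as a plain function. This is an illustrative model only: the `K_*` index values below are placeholder assumptions, not the actual `k_*` constants defined elsewhere in GridAgent.cs.

```python
# Discrete actions on branch 0 of the GridWorld agent. The index values are
# assumptions standing in for the k_* constants in GridAgent.cs.
K_UP, K_DOWN, K_LEFT, K_RIGHT = 0, 1, 2, 3

def masked_actions(position_x, position_z, grid_size=5):
    """Return the branch-0 action indices that would move the agent into a wall."""
    max_position = grid_size - 1
    masked = []
    if position_x == 0:
        masked.append(K_LEFT)    # at the left edge, 'left' is invalid
    if position_x == max_position:
        masked.append(K_RIGHT)   # at the right edge, 'right' is invalid
    if position_z == 0:
        masked.append(K_DOWN)    # at the bottom edge, 'down' is invalid
    if position_z == max_position:
        masked.append(K_UP)      # at the top edge, 'up' is invalid
    return masked
```

For example, in a 5x5 grid an agent standing in the corner at (0, 0) gets both 'left' and 'down' masked, while an agent in the interior gets no mask at all.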
3 changes: 2 additions & 1 deletion com.unity.ml-agents/CHANGELOG.md
```diff
@@ -7,7 +7,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

 ## [Unreleased]
 ### Major Changes
-- Agent.CollectObservations now takes a VectorSensor argument. It was also overloaded to optionally take an ActionMasker argument. (#3352, #3389)
+- `Agent.CollectObservations` now takes a VectorSensor argument. (#3352, #3389)
+- Added `Agent.CollectDiscreteActionMasks` virtual method with a `DiscreteActionMasker` argument to specify which discrete actions are unavailable to the Agent. (#3525)
 - Beta support for ONNX export was added. If the `tf2onnx` python package is installed, models will be saved to `.onnx` as well as `.nn` format.
   Note that Barracuda 0.6.0 or later is required to import the `.onnx` files properly
 - Multi-GPU training and the `--multi-gpu` option has been removed temporarily. (#3345)
```
59 changes: 16 additions & 43 deletions com.unity.ml-agents/Runtime/Agent.cs
```diff
@@ -59,8 +59,7 @@ internal struct AgentAction
     /// an Agent. An agent produces observations and takes actions in the
     /// environment. Observations are determined by the cameras attached
     /// to the agent in addition to the vector observations implemented by the
-    /// user in <see cref="Agent.CollectObservations(VectorSensor)"/> or
-    /// <see cref="Agent.CollectObservations(VectorSensor, ActionMasker)"/>.
+    /// user in <see cref="Agent.CollectObservations(VectorSensor)"/>.
     /// On the other hand, actions are determined by decisions produced by a Policy.
     /// Currently, this class is expected to be extended to implement the desired agent behavior.
     /// </summary>
@@ -173,7 +172,7 @@ internal struct AgentParameters
     bool m_Initialized;

     /// Keeps track of the actions that are masked at each step.
-    ActionMasker m_ActionMasker;
+    DiscreteActionMasker m_ActionMasker;

     /// <summary>
     /// Set of DemonstrationWriters that the Agent will write its step information to.
@@ -434,7 +433,7 @@ public void RequestAction()
     void ResetData()
     {
         var param = m_PolicyFactory.brainParameters;
-        m_ActionMasker = new ActionMasker(param);
+        m_ActionMasker = new DiscreteActionMasker(param);
         // If we haven't initialized vectorActions, initialize to 0. This should only
         // happen during the creation of the Agent. In subsequent episodes, vectorAction
         // should stay the previous action before the Done(), so that it is properly recorded.
@@ -549,7 +548,14 @@ void SendInfoToBrain()
         UpdateSensors();
         using (TimerStack.Instance.Scoped("CollectObservations"))
         {
-            CollectObservations(collectObservationsSensor, m_ActionMasker);
+            CollectObservations(collectObservationsSensor);
+        }
+        using (TimerStack.Instance.Scoped("CollectDiscreteActionMasks"))
+        {
+            if (m_PolicyFactory.brainParameters.vectorActionSpaceType == SpaceType.Discrete)
+            {
+                CollectDiscreteActionMasks(m_ActionMasker);
+            }
         }
         m_Info.actionMasks = m_ActionMasker.GetMask();

@@ -612,51 +618,18 @@ public virtual void CollectObservations(VectorSensor sensor)
     }

     /// <summary>
-    /// Collects the vector observations of the agent alongside the masked actions.
-    /// The agent observation describes the current environment from the
-    /// perspective of the agent.
+    /// Collects the masks for discrete actions.
+    /// When using discrete actions, the agent will not perform the masked action.
     /// </summary>
-    /// <param name="sensor">
-    /// The vector observations for the agent.
-    /// </param>
     /// <param name="actionMasker">
-    /// The masked actions for the agent.
+    /// The action masker for the agent.
     /// </param>
     /// <remarks>
-    /// An agents observation is any environment information that helps
-    /// the Agent achieve its goal. For example, for a fighting Agent, its
-    /// observation could include distances to friends or enemies, or the
-    /// current level of ammunition at its disposal.
-    /// Recall that an Agent may attach vector or visual observations.
-    /// Vector observations are added by calling the provided helper methods
-    /// on the VectorSensor input:
-    /// - <see cref="VectorSensor.AddObservation(int)"/>
-    /// - <see cref="VectorSensor.AddObservation(float)"/>
-    /// - <see cref="VectorSensor.AddObservation(Vector3)"/>
-    /// - <see cref="VectorSensor.AddObservation(Vector2)"/>
-    /// - <see cref="VectorSensor.AddObservation(Quaternion)"/>
-    /// - <see cref="VectorSensor.AddObservation(bool)"/>
-    /// - <see cref="VectorSensor.AddObservation(IEnumerable{float})"/>
-    /// - <see cref="VectorSensor.AddOneHotObservation(int, int)"/>
-    /// Depending on your environment, any combination of these helpers can
-    /// be used. They just need to be used in the exact same order each time
-    /// this method is called and the resulting size of the vector observation
-    /// needs to match the vectorObservationSize attribute of the linked Brain.
-    /// Visual observations are implicitly added from the cameras attached to
-    /// the Agent.
     /// When using Discrete Control, you can prevent the Agent from using a certain
-    /// action by masking it. You can call the following method on the ActionMasker
-    /// input :
-    /// - <see cref="ActionMasker.SetActionMask(int)"/>
-    /// - <see cref="ActionMasker.SetActionMask(int, int)"/>
-    /// - <see cref="ActionMasker.SetActionMask(int, IEnumerable{int})"/>
-    /// - <see cref="ActionMasker.SetActionMask(IEnumerable{int})"/>
-    /// The branch input is the index of the action, actionIndices are the indices of the
-    /// invalid options for that action.
+    /// action by masking it with <see cref="DiscreteActionMasker.SetMask(int, IEnumerable{int})"/>
     /// </remarks>
-    public virtual void CollectObservations(VectorSensor sensor, ActionMasker actionMasker)
+    public virtual void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
     {
-        CollectObservations(sensor);
     }

     /// <summary>
```
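The new control flow in `SendInfoToBrain` collects observations for every agent but collects action masks only when the action space is discrete. That dispatch can be modeled outside Unity as follows (illustrative Python with hypothetical names, not the Unity API):

```python
from enum import Enum, auto

class SpaceType(Enum):
    DISCRETE = auto()
    CONTINUOUS = auto()

class AgentModel:
    """Toy stand-in for Agent, recording which collection hooks run per step."""

    def __init__(self, space_type):
        self.space_type = space_type
        self.calls = []

    def collect_observations(self):
        self.calls.append("CollectObservations")

    def collect_discrete_action_masks(self):
        self.calls.append("CollectDiscreteActionMasks")

    def send_info_to_brain(self):
        # Observations are gathered unconditionally...
        self.collect_observations()
        # ...but masks only make sense for discrete action spaces,
        # so the virtual mask hook is skipped for continuous control.
        if self.space_type is SpaceType.DISCRETE:
            self.collect_discrete_action_masks()
```

The design point this illustrates: agents with continuous actions never receive the mask callback, so a user override of the mask hook is simply ignored for them.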
```diff
@@ -11,7 +11,7 @@ namespace MLAgents
     /// left side of the board). This class represents the set of masked actions and provides
     /// the utilities for setting and retrieving them.
     /// </summary>
-    public class ActionMasker
+    public class DiscreteActionMasker
     {
         /// When using discrete control, is the starting indices of the actions
         /// when all the branches are concatenated with each other.
@@ -21,47 +21,11 @@ public class ActionMasker

         readonly BrainParameters m_BrainParameters;

-        internal ActionMasker(BrainParameters brainParameters)
+        internal DiscreteActionMasker(BrainParameters brainParameters)
         {
             m_BrainParameters = brainParameters;
         }

-        /// <summary>
-        /// Sets an action mask for discrete control agents. When used, the agent will not be
-        /// able to perform the actions passed as argument at the next decision.
-        /// The actionIndices correspond to the actions the agent will be unable to perform
-        /// on the branch 0.
-        /// </summary>
-        /// <param name="actionIndices">The indices of the masked actions on branch 0.</param>
-        public void SetActionMask(IEnumerable<int> actionIndices)
-        {
-            SetActionMask(0, actionIndices);
-        }
-
-        /// <summary>
-        /// Sets an action mask for discrete control agents. When used, the agent will not be
-        /// able to perform the action passed as argument at the next decision for the specified
-        /// action branch. The actionIndex correspond to the action the agent will be unable
-        /// to perform.
-        /// </summary>
-        /// <param name="branch">The branch for which the actions will be masked.</param>
-        /// <param name="actionIndex">The index of the masked action.</param>
-        public void SetActionMask(int branch, int actionIndex)
-        {
-            SetActionMask(branch, new[] { actionIndex });
-        }
-
-        /// <summary>
-        /// Sets an action mask for discrete control agents. When used, the agent will not be
-        /// able to perform the action passed as argument at the next decision. The actionIndex
-        /// correspond to the action the agent will be unable to perform on the branch 0.
-        /// </summary>
-        /// <param name="actionIndex">The index of the masked action on branch 0</param>
-        public void SetActionMask(int actionIndex)
-        {
-            SetActionMask(0, new[] { actionIndex });
-        }
-
         /// <summary>
         /// Modifies an action mask for discrete control agents. When used, the agent will not be
         /// able to perform the actions passed as argument at the next decision for the specified
@@ -70,7 +34,7 @@ public void SetActionMask(int actionIndex)
         /// </summary>
         /// <param name="branch">The branch for which the actions will be masked</param>
         /// <param name="actionIndices">The indices of the masked actions</param>
-        public void SetActionMask(int branch, IEnumerable<int> actionIndices)
+        public void SetMask(int branch, IEnumerable<int> actionIndices)
         {
             // If the branch does not exist, raise an error
             if (branch >= m_BrainParameters.vectorActionSize.Length)
```
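The renamed class's contract, as exercised by the tests below, is: one boolean per action across all concatenated branches, cumulative `SetMask` calls, a lazy null mask until something is set, and errors for out-of-range indices or a fully masked branch. A simplified Python model of that contract (not the C# source; `ValueError` stands in for `UnityAgentsException`):

```python
class DiscreteActionMaskerSketch:
    """Simplified model of DiscreteActionMasker: branch sizes come from the
    brain parameters, and the mask is one flat boolean list over all branches."""

    def __init__(self, branch_sizes):
        self.branch_sizes = list(branch_sizes)
        # Starting offset of each branch in the concatenated action space.
        self.offsets = [sum(self.branch_sizes[:i]) for i in range(len(self.branch_sizes))]
        self.mask = None  # stays None until the first SetMask call

    def set_mask(self, branch, action_indices):
        if branch >= len(self.branch_sizes):
            raise ValueError("branch does not exist")
        if self.mask is None:
            self.mask = [False] * sum(self.branch_sizes)
        for idx in action_indices:
            if idx >= self.branch_sizes[branch]:
                raise ValueError("invalid action index for this branch")
            self.mask[self.offsets[branch] + idx] = True

    def get_mask(self):
        # Masking every action of a branch leaves the agent nothing to pick:
        # the real implementation raises in that case, modeled here too.
        if self.mask is not None:
            for b, size in enumerate(self.branch_sizes):
                start = self.offsets[b]
                if all(self.mask[start:start + size]):
                    raise ValueError("all actions of a branch are masked")
        return self.mask

    def reset_mask(self):
        # Cleared between decisions so each step starts unmasked.
        if self.mask is not None:
            self.mask = [False] * len(self.mask)
```

For branch sizes `[4, 5, 6]` (mirroring the unit tests), `set_mask(1, [1, 2, 3])` marks flat indices 5, 6, and 7, since branch 1 starts at offset 4.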
40 changes: 20 additions & 20 deletions com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs
```diff
@@ -8,7 +8,7 @@ public class EditModeTestActionMasker
         public void Contruction()
         {
             var bp = new BrainParameters();
-            var masker = new ActionMasker(bp);
+            var masker = new DiscreteActionMasker(bp);
             Assert.IsNotNull(masker);
         }

@@ -18,8 +18,8 @@ public void FailsWithContinuous()
             var bp = new BrainParameters();
             bp.vectorActionSpaceType = SpaceType.Continuous;
             bp.vectorActionSize = new[] {4};
-            var masker = new ActionMasker(bp);
-            masker.SetActionMask(0, new[] {0});
+            var masker = new DiscreteActionMasker(bp);
+            masker.SetMask(0, new[] {0});
             Assert.Catch<UnityAgentsException>(() => masker.GetMask());
         }

@@ -28,7 +28,7 @@ public void NullMask()
         {
             var bp = new BrainParameters();
             bp.vectorActionSpaceType = SpaceType.Discrete;
-            var masker = new ActionMasker(bp);
+            var masker = new DiscreteActionMasker(bp);
             var mask = masker.GetMask();
             Assert.IsNull(mask);
         }
@@ -39,10 +39,10 @@ public void FirstBranchMask()
             var bp = new BrainParameters();
             bp.vectorActionSpaceType = SpaceType.Discrete;
             bp.vectorActionSize = new[] {4, 5, 6};
-            var masker = new ActionMasker(bp);
+            var masker = new DiscreteActionMasker(bp);
             var mask = masker.GetMask();
             Assert.IsNull(mask);
-            masker.SetActionMask(0, new[] {1, 2, 3});
+            masker.SetMask(0, new[] {1, 2, 3});
             mask = masker.GetMask();
             Assert.IsFalse(mask[0]);
             Assert.IsTrue(mask[1]);
@@ -60,8 +60,8 @@ public void SecondBranchMask()
                 vectorActionSpaceType = SpaceType.Discrete,
                 vectorActionSize = new[] { 4, 5, 6 }
             };
-            var masker = new ActionMasker(bp);
-            masker.SetActionMask(1, new[] {1, 2, 3});
+            var masker = new DiscreteActionMasker(bp);
+            masker.SetMask(1, new[] {1, 2, 3});
             var mask = masker.GetMask();
             Assert.IsFalse(mask[0]);
             Assert.IsFalse(mask[4]);
@@ -80,8 +80,8 @@ public void MaskReset()
                 vectorActionSpaceType = SpaceType.Discrete,
                 vectorActionSize = new[] { 4, 5, 6 }
             };
-            var masker = new ActionMasker(bp);
-            masker.SetActionMask(1, new[] {1, 2, 3});
+            var masker = new DiscreteActionMasker(bp);
+            masker.SetMask(1, new[] {1, 2, 3});
             masker.ResetMask();
             var mask = masker.GetMask();
             for (var i = 0; i < 15; i++)
@@ -98,18 +98,18 @@ public void ThrowsError()
                 vectorActionSpaceType = SpaceType.Discrete,
                 vectorActionSize = new[] { 4, 5, 6 }
             };
-            var masker = new ActionMasker(bp);
+            var masker = new DiscreteActionMasker(bp);

             Assert.Catch<UnityAgentsException>(
-                () => masker.SetActionMask(0, new[] {5}));
+                () => masker.SetMask(0, new[] {5}));
             Assert.Catch<UnityAgentsException>(
-                () => masker.SetActionMask(1, new[] {5}));
-            masker.SetActionMask(2, new[] {5});
+                () => masker.SetMask(1, new[] {5}));
+            masker.SetMask(2, new[] {5});
             Assert.Catch<UnityAgentsException>(
-                () => masker.SetActionMask(3, new[] {1}));
+                () => masker.SetMask(3, new[] {1}));
             masker.GetMask();
             masker.ResetMask();
-            masker.SetActionMask(0, new[] {0, 1, 2, 3});
+            masker.SetMask(0, new[] {0, 1, 2, 3});
             Assert.Catch<UnityAgentsException>(
                 () => masker.GetMask());
         }
@@ -120,10 +120,10 @@ public void MultipleMaskEdit()
             var bp = new BrainParameters();
             bp.vectorActionSpaceType = SpaceType.Discrete;
             bp.vectorActionSize = new[] {4, 5, 6};
-            var masker = new ActionMasker(bp);
-            masker.SetActionMask(0, new[] {0, 1});
-            masker.SetActionMask(0, new[] {3});
-            masker.SetActionMask(2, new[] {1});
+            var masker = new DiscreteActionMasker(bp);
+            masker.SetMask(0, new[] {0, 1});
+            masker.SetMask(0, new[] {3});
+            masker.SetMask(2, new[] {1});
             var mask = masker.GetMask();
             for (var i = 0; i < 15; i++)
             {
```
14 changes: 7 additions & 7 deletions docs/Learning-Environment-Design-Agents.md
````diff
@@ -390,20 +390,20 @@
 impossible for the next decision. When the Agent is controlled by a
 neural network, the Agent will be unable to perform the specified action. Note
 that when the Agent is controlled by its Heuristic, the Agent will
 still be able to decide to perform the masked action. In order to mask an
-action, call the method `SetActionMask` on the optional `ActionMasker` argument of the `CollectObservation` method :
+action, override the `Agent.CollectDiscreteActionMasks()` virtual method and call `DiscreteActionMasker.SetMask()` in it:

 ```csharp
-public override void CollectObservations(VectorSensor sensor, ActionMasker actionMasker){
-    actionMasker.SetActionMask(branch, actionIndices)
+public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker){
+    actionMasker.SetMask(branch, actionIndices)
 }
 ```

 Where:

 * `branch` is the index (starting at 0) of the branch on which you want to mask
   the action
-* `actionIndices` is a list of `int` or a single `int` corresponding to the
-  index of the action that the Agent cannot perform.
+* `actionIndices` is a list of `int` corresponding to the
+  indices of the actions that the Agent cannot perform.

 For example, if you have an Agent with 2 branches and on the first branch
 (branch 0) there are 4 possible actions : _"do nothing"_, _"jump"_, _"shoot"_
@@ -412,12 +412,12 @@
 nothing"_ or _"change weapon"_ for his next decision (since action index 1 and 2
 are masked)

 ```csharp
-SetActionMask(0, new int[2]{1,2})
+SetMask(0, new int[2]{1,2})
 ```

 Notes:

-* You can call `SetActionMask` multiple times if you want to put masks on
+* You can call `SetMask` multiple times if you want to put masks on
   multiple branches.
 * You cannot mask all the actions of a branch.
 * You cannot mask actions in continuous control.
````
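Concretely, the documentation's two-branch example produces a flat mask over the concatenated branches. A small standalone sketch (the second branch's size of 2 is an assumption for illustration; the doc only specifies branch 0):

```python
def flat_mask(branch_sizes, masked):
    """Build the concatenated per-action mask.

    branch_sizes: number of actions per branch, in branch order.
    masked: dict mapping branch index -> iterable of disallowed action indices.
    """
    mask = []
    for branch, size in enumerate(branch_sizes):
        disallowed = set(masked.get(branch, ()))
        mask.extend(i in disallowed for i in range(size))
    return mask

# Branch 0: "do nothing", "jump", "shoot", "change weapon"; branch 1: 2 actions.
# Masking indices 1 and 2 of branch 0 disallows "jump" and "shoot" only.
mask = flat_mask([4, 2], {0: [1, 2]})
# → [False, True, True, False, False, False]
```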