Pass action masker as input to CollectObservations #3389

Merged
merged 4 commits into from
Feb 10, 2020

14 changes: 7 additions & 7 deletions Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
@@ -31,22 +31,22 @@ public override void InitializeAgent()
{
}

-public override void CollectObservations(VectorSensor sensor)
+public override void CollectObservations(VectorSensor sensor, ActionMasker actionMasker)
{
// There are no numeric observations to collect as this environment uses visual
// observations.

// Mask the necessary actions if selected by the user.
if (maskActions)
{
-SetMask();
+SetMask(actionMasker);
}
}

/// <summary>
/// Applies the mask for the agent's actions to disallow unnecessary actions.
/// </summary>
-void SetMask()
+void SetMask(ActionMasker actionMasker)
{
// Prevents the agent from picking an action that would make it collide with a wall
var positionX = (int)transform.position.x;
@@ -55,22 +55,22 @@ void SetMask()

if (positionX == 0)
{
-SetActionMask(k_Left);
+actionMasker.SetActionMask(k_Left);
}

if (positionX == maxPosition)
{
-SetActionMask(k_Right);
+actionMasker.SetActionMask(k_Right);
}

if (positionZ == 0)
{
-SetActionMask(k_Down);
+actionMasker.SetActionMask(k_Down);
}

if (positionZ == maxPosition)
{
-SetActionMask(k_Up);
+actionMasker.SetActionMask(k_Up);
}
}
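
The wall checks in `SetMask` can be read as a pure function from grid position to the set of blocked actions on branch 0. The sketch below is only an illustration of that logic: the class name `GridWallMask` and the numeric values of the action constants are assumptions, since the diff shows only the names `k_Left`, `k_Right`, `k_Down` and `k_Up`.

```csharp
using System.Collections.Generic;

public static class GridWallMask
{
    // Assumed action indices for illustration; the diff does not show the real values.
    public const int Up = 1, Down = 2, Left = 3, Right = 4;

    // Returns the branch-0 actions that would move the agent into a wall,
    // mirroring the four boundary checks in SetMask.
    public static List<int> BlockedActions(int positionX, int positionZ, int maxPosition)
    {
        var blocked = new List<int>();
        if (positionX == 0) blocked.Add(Left);
        if (positionX == maxPosition) blocked.Add(Right);
        if (positionZ == 0) blocked.Add(Down);
        if (positionZ == maxPosition) blocked.Add(Up);
        return blocked;
    }
}
```

A list built this way could presumably be handed to the `SetActionMask(IEnumerable<int> actionIndices)` overload that this PR adds to `ActionMasker`, instead of calling the single-index overload once per wall.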

48 changes: 42 additions & 6 deletions com.unity.ml-agents/Runtime/ActionMasker.cs
@@ -4,7 +4,7 @@

namespace MLAgents
{
-internal class ActionMasker
+public class ActionMasker
{
/// When using discrete control, these are the starting indices of the actions
/// when all the branches are concatenated with each other.
@@ -19,11 +19,47 @@ internal ActionMasker(BrainParameters brainParameters)
m_BrainParameters = brainParameters;
}

+/// <summary>
+/// Sets an action mask for discrete control agents. When used, the agent will not be
+/// able to perform the actions passed as argument at the next decision.
+/// The actionIndices correspond to the actions the agent will be unable to perform
+/// on the branch 0.
+/// </summary>
+/// <param name="actionIndices">The indices of the masked actions on branch 0</param>
+public void SetActionMask(IEnumerable<int> actionIndices)
+{
+SetActionMask(0, actionIndices);
+}
+
+/// <summary>
+/// Sets an action mask for discrete control agents. When used, the agent will not be
+/// able to perform the action passed as argument at the next decision for the specified
+/// action branch. The actionIndex corresponds to the action the agent will be unable
+/// to perform.
+/// </summary>
+/// <param name="branch">The branch for which the actions will be masked</param>
+/// <param name="actionIndex">The index of the masked action</param>
+public void SetActionMask(int branch, int actionIndex)
+{
+SetActionMask(branch, new[] { actionIndex });
+}
+
+/// <summary>
+/// Sets an action mask for discrete control agents. When used, the agent will not be
+/// able to perform the action passed as argument at the next decision. The actionIndex
+/// corresponds to the action the agent will be unable to perform on the branch 0.
+/// </summary>
+/// <param name="actionIndex">The index of the masked action on branch 0</param>
+public void SetActionMask(int actionIndex)
+{
+SetActionMask(0, new[] { actionIndex });
+}
+
/// <summary>
/// Modifies an action mask for discrete control agents. When used, the agent will not be
-/// able to perform the action passed as argument at the next decision. If no branch is
-/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
-/// to the action the agent will be unable to perform.
+/// able to perform the actions passed as argument at the next decision for the specified
+/// action branch. The actionIndices correspond to the action options the agent will
+/// be unable to perform.
/// </summary>
/// <param name="branch">The branch for which the actions will be masked</param>
/// <param name="actionIndices">The indices of the masked actions</param>
@@ -67,7 +103,7 @@ public void SetActionMask(int branch, IEnumerable<int> actionIndices)
/// </summary>
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
/// actions.</returns>
-public bool[] GetMask()
+internal bool[] GetMask()
{
if (m_CurrentMask != null)
{
@@ -103,7 +139,7 @@ void AssertMask()
/// <summary>
/// Resets the current mask for an agent
/// </summary>
-public void ResetMask()
+internal void ResetMask()
{
if (m_CurrentMask != null)
{
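
The doc comments above describe a mask stored as one flat boolean array spanning all branches, with per-branch starting indices, and three convenience overloads that forward to `SetActionMask(int branch, IEnumerable<int> actionIndices)`. The following is a minimal sketch of that layout under stated assumptions. `ToyActionMasker` is a hypothetical stand-in, not the ML-Agents class: it takes the branch sizes directly instead of `BrainParameters`, and it assumes `true` means "masked".

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for illustration only; not the ML-Agents ActionMasker.
public class ToyActionMasker
{
    readonly int[] m_BranchSizes;
    readonly int[] m_StartingIndices; // offset of each branch in the flat mask
    readonly bool[] m_CurrentMask;    // assumed convention: true = action is masked

    public ToyActionMasker(int[] branchSizes)
    {
        m_BranchSizes = branchSizes;
        m_StartingIndices = new int[branchSizes.Length];
        var total = 0;
        for (var i = 0; i < branchSizes.Length; i++)
        {
            m_StartingIndices[i] = total; // cumulative sum of earlier branch sizes
            total += branchSizes[i];
        }
        m_CurrentMask = new bool[total];
    }

    // Convenience overloads: each forwards to the (branch, indices) overload.
    public void SetActionMask(int actionIndex) => SetActionMask(0, new[] { actionIndex });
    public void SetActionMask(IEnumerable<int> actionIndices) => SetActionMask(0, actionIndices);
    public void SetActionMask(int branch, int actionIndex) => SetActionMask(branch, new[] { actionIndex });

    public void SetActionMask(int branch, IEnumerable<int> actionIndices)
    {
        if (branch < 0 || branch >= m_BranchSizes.Length)
            throw new ArgumentOutOfRangeException(nameof(branch));
        foreach (var actionIndex in actionIndices)
        {
            if (actionIndex < 0 || actionIndex >= m_BranchSizes[branch])
                throw new ArgumentOutOfRangeException(nameof(actionIndices));
            m_CurrentMask[m_StartingIndices[branch] + actionIndex] = true;
        }
    }

    public bool[] GetMask() => m_CurrentMask;

    public void ResetMask() => Array.Clear(m_CurrentMask, 0, m_CurrentMask.Length);
}
```

With branch sizes `{ 3, 2 }`, masking action 0 on branch 1 sets flat index 3, because branch 1 starts after the three actions of branch 0.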
89 changes: 37 additions & 52 deletions com.unity.ml-agents/Runtime/Agent.cs
@@ -481,7 +481,7 @@ void SendInfoToBrain()
UpdateSensors();
using (TimerStack.Instance.Scoped("CollectObservations"))
{
-CollectObservations(collectObservationsSensor);
+CollectObservations(collectObservationsSensor, m_ActionMasker);
}
m_Info.actionMasks = m_ActionMasker.GetMask();

@@ -523,12 +523,6 @@ void UpdateSensors()
/// - <see cref="AddObservation(float)"/>
/// - <see cref="AddObservation(Vector3)"/>
/// - <see cref="AddObservation(Vector2)"/>
-/// - <see>
-/// <cref>AddVectorObs(float[])</cref>
-/// </see>
-/// - <see>
-/// <cref>AddVectorObs(List{float})</cref>
-/// </see>
/// - <see cref="AddObservation(Quaternion)"/>
/// - <see cref="AddObservation(bool)"/>
/// - <see cref="AddOneHotObservation(int, int)"/>
@@ -544,53 +538,44 @@ public virtual void CollectObservations(VectorSensor sensor)
}

/// <summary>
-/// Sets an action mask for discrete control agents. When used, the agent will not be
-/// able to perform the action passed as argument at the next decision. If no branch is
-/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
-/// to the action the agent will be unable to perform.
-/// </summary>
-/// <param name="actionIndices">The indices of the masked actions on branch 0</param>
-protected void SetActionMask(IEnumerable<int> actionIndices)
-{
-m_ActionMasker.SetActionMask(0, actionIndices);
-}
-
-/// <summary>
-/// Sets an action mask for discrete control agents. When used, the agent will not be
-/// able to perform the action passed as argument at the next decision. If no branch is
-/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
-/// to the action the agent will be unable to perform.
-/// </summary>
-/// <param name="actionIndex">The index of the masked action on branch 0</param>
-protected void SetActionMask(int actionIndex)
-{
-m_ActionMasker.SetActionMask(0, new[] { actionIndex });
-}
-
-/// <summary>
-/// Sets an action mask for discrete control agents. When used, the agent will not be
-/// able to perform the action passed as argument at the next decision. If no branch is
-/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
-/// to the action the agent will be unable to perform.
-/// </summary>
-/// <param name="branch">The branch for which the actions will be masked</param>
-/// <param name="actionIndex">The index of the masked action</param>
-protected void SetActionMask(int branch, int actionIndex)
-{
-m_ActionMasker.SetActionMask(branch, new[] { actionIndex });
-}
-
-/// <summary>
-/// Modifies an action mask for discrete control agents. When used, the agent will not be
-/// able to perform the action passed as argument at the next decision. If no branch is
-/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
-/// to the action the agent will be unable to perform.
+/// Collects the vector observations of the agent.
+/// The agent observation describes the current environment from the
+/// perspective of the agent.
/// </summary>
-/// <param name="branch">The branch for which the actions will be masked</param>
-/// <param name="actionIndices">The indices of the masked actions</param>
-protected void SetActionMask(int branch, IEnumerable<int> actionIndices)
+/// <remarks>
+/// An agent's observation is any environment information that helps
+/// the Agent achieve its goal. For example, for a fighting Agent, its
+/// observation could include distances to friends or enemies, or the
+/// current level of ammunition at its disposal.
+/// Recall that an Agent may attach vector or visual observations.
+/// Vector observations are added by calling the provided helper methods
+/// on the VectorSensor input:
+/// - <see cref="AddObservation(int)"/>
+/// - <see cref="AddObservation(float)"/>
+/// - <see cref="AddObservation(Vector3)"/>
+/// - <see cref="AddObservation(Vector2)"/>
+/// - <see cref="AddObservation(Quaternion)"/>
+/// - <see cref="AddObservation(bool)"/>
+/// - <see cref="AddOneHotObservation(int, int)"/>
+/// Depending on your environment, any combination of these helpers can
+/// be used. They just need to be used in the exact same order each time
+/// this method is called and the resulting size of the vector observation
+/// needs to match the vectorObservationSize attribute of the linked Brain.
+/// Visual observations are implicitly added from the cameras attached to
+/// the Agent.
+/// When using Discrete Control, you can prevent the Agent from using a certain
+/// action by masking it. You can call the following methods on the ActionMasker
+/// input:
+/// - <see cref="SetActionMask(int branch, IEnumerable<int> actionIndices)"/>
Contributor: Links to other overloads?

Contributor Author: done

+/// - <see cref="SetActionMask(int branch, int actionIndex)"/>
+/// - <see cref="SetActionMask(IEnumerable<int> actionIndices)"/>
+/// - <see cref="SetActionMask(int actionIndex)"/>
+/// The branch input is the index of the action; actionIndices are the indices of the
+/// invalid options for that action.
+/// </remarks>
+public virtual void CollectObservations(VectorSensor sensor, ActionMasker actionMasker)
{
-m_ActionMasker.SetActionMask(branch, actionIndices);
+CollectObservations(sensor);
}

/// <summary>
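
The new two-argument `CollectObservations` is virtual and, by default, forwards to the existing one-argument overload. That suggests agents overriding only the old signature keep working, while new agents can override the two-argument version to reach the masker. The sketch below illustrates that dispatch pattern with stub types. `StubSensor`, `StubMasker`, `StubAgent`, `LegacyAgent`, `MaskingAgent` and `Step` are all hypothetical names, not ML-Agents APIs.

```csharp
using System.Collections.Generic;

// Stub types for illustration; the real VectorSensor and ActionMasker live in ML-Agents.
public class StubSensor { public readonly List<float> Values = new List<float>(); }
public class StubMasker { public readonly List<int> Masked = new List<int>(); }

public class StubAgent
{
    // Old entry point, kept so that existing overrides still compile and run.
    public virtual void CollectObservations(StubSensor sensor) { }

    // New entry point: forwards to the old one unless a subclass overrides it.
    public virtual void CollectObservations(StubSensor sensor, StubMasker masker)
    {
        CollectObservations(sensor);
    }

    // Stands in for the framework call made at each decision step.
    public void Step(StubSensor sensor, StubMasker masker) => CollectObservations(sensor, masker);
}

// Overrides only the old signature; reached through the forwarding default.
public class LegacyAgent : StubAgent
{
    public override void CollectObservations(StubSensor sensor) => sensor.Values.Add(1f);
}

// Overrides the new signature and records a masked action index.
public class MaskingAgent : StubAgent
{
    public override void CollectObservations(StubSensor sensor, StubMasker masker)
    {
        sensor.Values.Add(2f);
        masker.Masked.Add(3);
    }
}
```

One consequence of this design is that a legacy agent has no way to reach the masker, which matches the migration note that `SetActionMask` calls must move onto the `ActionMasker` argument.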
6 changes: 4 additions & 2 deletions docs/Learning-Environment-Design-Agents.md
@@ -391,10 +391,12 @@ impossible for the next decision. When the Agent is controlled by a
neural network, the Agent will be unable to perform the specified action. Note
that when the Agent is controlled by its Heuristic, the Agent will
still be able to decide to perform the masked action. In order to mask an
-action, call the method `SetActionMask` within the `CollectObservation` method :
+action, call the method `SetActionMask` on the optional `ActionMasker` argument of the `CollectObservations` method:

```csharp
-SetActionMask(branch, actionIndices)
+public override void CollectObservations(VectorSensor sensor, ActionMasker actionMasker){
+    actionMasker.SetActionMask(branch, actionIndices);
+}
```

Where:
3 changes: 2 additions & 1 deletion docs/Migrating.md
@@ -13,11 +13,12 @@ The versions can be found in
* The `Agent.CollectObservations()` virtual method now takes a `VectorSensor` sensor as an argument. The `Agent.AddVectorObs()` methods were removed.
* The `Monitor` class has been moved to the Examples Project. (It was prone to errors during testing)
* The `MLAgents.Sensor` namespace has been removed. All sensors now belong to the `MLAgents` namespace.
+* The `SetActionMask` method must now be called on the optional `ActionMasker` argument of the `CollectObservations` method. (We now consider an action mask to be a type of observation.)


### Steps to Migrate
* Replace your Agent's implementation of `CollectObservations()` with `CollectObservations(VectorSensor sensor)`. In addition, replace all calls to `AddVectorObs()` with `sensor.AddObservation()` or `sensor.AddOneHotObservation()` on the `VectorSensor` passed as argument.

+* Replace your calls to `SetActionMask` on your Agent with calls to `ActionMasker.SetActionMask` in `CollectObservations`.


## Migrating from 0.13 to 0.14