Skip to content

Change Agent.Heuristic to take a float[] argument #3765

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,10 @@ public override void OnEpisodeBegin()
SetResetParameters();
}

public override float[] Heuristic()
public override void Heuristic(float[] actionsOut)
{
var action = new float[2];

action[0] = -Input.GetAxis("Horizontal");
action[1] = Input.GetAxis("Vertical");
return action;
actionsOut[0] = -Input.GetAxis("Horizontal");
actionsOut[1] = Input.GetAxis("Vertical");
}

public void SetBall()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,11 @@ void FixedUpdate()
}
}

public override float[] Heuristic()
public override void Heuristic(float[] actionsOut)
{
var action = new float[3];

action[0] = Input.GetAxis("Horizontal");
action[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
action[2] = Input.GetAxis("Vertical");
return action;
actionsOut[0] = Input.GetAxis("Horizontal");
actionsOut[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
actionsOut[2] = Input.GetAxis("Vertical");
}

void Update()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,27 +207,25 @@ public override void OnActionReceived(float[] vectorAction)
MoveAgent(vectorAction);
}

public override float[] Heuristic()
public override void Heuristic(float[] actionsOut)
{
var action = new float[4];
if (Input.GetKey(KeyCode.D))
{
action[2] = 2f;
actionsOut[2] = 2f;
}
if (Input.GetKey(KeyCode.W))
{
action[0] = 1f;
actionsOut[0] = 1f;
}
if (Input.GetKey(KeyCode.A))
{
action[2] = 1f;
actionsOut[2] = 1f;
}
if (Input.GetKey(KeyCode.S))
{
action[0] = 2f;
actionsOut[0] = 2f;
}
action[3] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
return action;
actionsOut[3] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
}

public override void OnEpisodeBegin()
Expand Down
12 changes: 6 additions & 6 deletions Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,25 +108,25 @@ public override void OnActionReceived(float[] vectorAction)
}
}

public override float[] Heuristic()
public override void Heuristic(float[] actionsOut)
{
actionsOut[0] = k_NoAction;
if (Input.GetKey(KeyCode.D))
{
return new float[] { k_Right };
actionsOut[0] = k_Right;
}
if (Input.GetKey(KeyCode.W))
{
return new float[] { k_Up };
actionsOut[0] = k_Up;
}
if (Input.GetKey(KeyCode.A))
{
return new float[] { k_Left };
actionsOut[0] = k_Left;
}
if (Input.GetKey(KeyCode.S))
{
return new float[] { k_Down };
actionsOut[0] = k_Down;
}
return new float[] { k_NoAction };
}

// to be implemented by the developer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,25 +91,25 @@ void OnCollisionEnter(Collision col)
}
}

public override float[] Heuristic()
public override void Heuristic(float[] actionsOut)
{
actionsOut[0] = 0;
if (Input.GetKey(KeyCode.D))
{
return new float[] { 3 };
actionsOut[0] = 3;
}
if (Input.GetKey(KeyCode.W))
else if (Input.GetKey(KeyCode.W))
{
return new float[] { 1 };
actionsOut[0] = 1;
}
if (Input.GetKey(KeyCode.A))
else if (Input.GetKey(KeyCode.A))
{
return new float[] { 4 };
actionsOut[0] = 4;
}
if (Input.GetKey(KeyCode.S))
else if (Input.GetKey(KeyCode.S))
{
return new float[] { 2 };
actionsOut[0] = 2;
}
return new float[] { 0 };
}

public override void OnEpisodeBegin()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,25 +170,25 @@ public override void OnActionReceived(float[] vectorAction)
AddReward(-1f / maxStep);
}

public override float[] Heuristic()
public override void Heuristic(float[] actionsOut)
{
actionsOut[0] = 0;
if (Input.GetKey(KeyCode.D))
{
return new float[] { 3 };
actionsOut[0] = 3;
}
if (Input.GetKey(KeyCode.W))
else if (Input.GetKey(KeyCode.W))
{
return new float[] { 1 };
actionsOut[0] = 1;
}
if (Input.GetKey(KeyCode.A))
else if (Input.GetKey(KeyCode.A))
{
return new float[] { 4 };
actionsOut[0] = 4;
}
if (Input.GetKey(KeyCode.S))
else if (Input.GetKey(KeyCode.S))
{
return new float[] { 2 };
actionsOut[0] = 2;
}
return new float[] { 0 };
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,25 +61,25 @@ public override void OnActionReceived(float[] vectorAction)
MoveAgent(vectorAction);
}

public override float[] Heuristic()
public override void Heuristic(float[] actionsOut)
{
actionsOut[0] = 0;
if (Input.GetKey(KeyCode.D))
{
return new float[] { 3 };
actionsOut[0] = 3;
}
if (Input.GetKey(KeyCode.W))
else if (Input.GetKey(KeyCode.W))
{
return new float[] { 1 };
actionsOut[0] = 1;
}
if (Input.GetKey(KeyCode.A))
else if (Input.GetKey(KeyCode.A))
{
return new float[] { 4 };
actionsOut[0] = 4;
}
if (Input.GetKey(KeyCode.S))
else if (Input.GetKey(KeyCode.S))
{
return new float[] { 2 };
actionsOut[0] = 2;
}
return new float[] { 0 };
}

public override void OnEpisodeBegin()
Expand Down
16 changes: 7 additions & 9 deletions Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,37 +112,35 @@ public override void OnActionReceived(float[] vectorAction)
MoveAgent(vectorAction);
}

public override float[] Heuristic()
public override void Heuristic(float[] actionsOut)
{
var action = new float[3];
//forward
if (Input.GetKey(KeyCode.W))
{
action[0] = 1f;
actionsOut[0] = 1f;
}
if (Input.GetKey(KeyCode.S))
{
action[0] = 2f;
actionsOut[0] = 2f;
}
//rotate
if (Input.GetKey(KeyCode.A))
{
action[2] = 1f;
actionsOut[2] = 1f;
}
if (Input.GetKey(KeyCode.D))
{
action[2] = 2f;
actionsOut[2] = 2f;
}
//right
if (Input.GetKey(KeyCode.E))
{
action[1] = 1f;
actionsOut[1] = 1f;
}
if (Input.GetKey(KeyCode.Q))
{
action[1] = 2f;
actionsOut[1] = 2f;
}
return action;
}
/// <summary>
/// Used to provide a "kick" to the ball.
Expand Down
11 changes: 4 additions & 7 deletions Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,11 @@ public override void OnActionReceived(float[] vectorAction)
m_TextComponent.text = score.ToString();
}

public override float[] Heuristic()
public override void Heuristic(float[] actionsOut)
{
var action = new float[3];

action[0] = Input.GetAxis("Horizontal"); // Racket Movement
action[1] = Input.GetKey(KeyCode.Space) ? 1f : 0f; // Racket Jumping
action[2] = Input.GetAxis("Vertical"); // Racket Rotation
return action;
actionsOut[0] = Input.GetAxis("Horizontal"); // Racket Movement
actionsOut[1] = Input.GetKey(KeyCode.Space) ? 1f : 0f; // Racket Jumping
actionsOut[2] = Input.GetAxis("Vertical"); // Racket Rotation
}

public override void OnEpisodeBegin()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,27 +241,25 @@ public override void OnActionReceived(float[] vectorAction)
}
}

public override float[] Heuristic()
public override void Heuristic(float[] actionsOut)
{
var action = new float[4];
if (Input.GetKey(KeyCode.D))
{
action[1] = 2f;
actionsOut[1] = 2f;
}
if (Input.GetKey(KeyCode.W))
{
action[0] = 1f;
actionsOut[0] = 1f;
}
if (Input.GetKey(KeyCode.A))
{
action[1] = 1f;
actionsOut[1] = 1f;
}
if (Input.GetKey(KeyCode.S))
{
action[0] = 2f;
actionsOut[0] = 2f;
}
action[3] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
return action;
actionsOut[3] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
}

// Detect when the agent hits the goal
Expand Down
1 change: 1 addition & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Added ability to start training (initialize model weights) from a previous run ID. (#3710)
- The internal event `Academy.AgentSetStatus` was renamed to `Academy.AgentPreStep` and made public.
- The offset logic was removed from DecisionRequester.
- The signature of `Agent.Heuristic()` was changed to take a `float[]` as a parameter, instead of returning the array. This was done to prevent a common source of error where users would return arrays of the wrong size.
- The communication API version has been bumped up to 1.0.0 and will use [Semantic Versioning](https://semver.org/) to do compatibility checks for communication between Unity and the Python process.

### Minor Changes
Expand Down
6 changes: 2 additions & 4 deletions com.unity.ml-agents/Runtime/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -547,12 +547,10 @@ public virtual void Initialize()
/// </summary>
/// <returns> A float array corresponding to the next action of the Agent
/// </returns>
public virtual float[] Heuristic()
public virtual void Heuristic(float[] actionsOut)
{
Debug.LogWarning("Heuristic method called but not implemented. Returning placeholder actions.");
var param = m_PolicyFactory.brainParameters;

return new float[param.numActions];
Array.Clear(actionsOut, 0, actionsOut.Length);
}

/// <summary>
Expand Down
8 changes: 4 additions & 4 deletions com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,12 @@ public string fullyQualifiedBehaviorName
get { return m_BehaviorName + "?team=" + TeamId; }
}

internal IPolicy GeneratePolicy(Func<float[]> heuristic)
internal IPolicy GeneratePolicy(HeuristicPolicy.ActionGenerator heuristic)
{
switch (m_BehaviorType)
{
case BehaviorType.HeuristicOnly:
return new HeuristicPolicy(heuristic);
return new HeuristicPolicy(heuristic, m_BrainParameters.numActions);
case BehaviorType.InferenceOnly:
{
if (m_Model == null)
Expand All @@ -164,10 +164,10 @@ internal IPolicy GeneratePolicy(Func<float[]> heuristic)
}
else
{
return new HeuristicPolicy(heuristic);
return new HeuristicPolicy(heuristic, m_BrainParameters.numActions);
}
default:
return new HeuristicPolicy(heuristic);
return new HeuristicPolicy(heuristic, m_BrainParameters.numActions);
}
}

Expand Down
11 changes: 8 additions & 3 deletions com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,20 @@ namespace MLAgents.Policies
/// </summary>
internal class HeuristicPolicy : IPolicy
{
Func<float[]> m_Heuristic;
public delegate void ActionGenerator(float[] actionsOut);
ActionGenerator m_Heuristic;
float[] m_LastDecision;
int m_numActions;

WriteAdapter m_WriteAdapter = new WriteAdapter();
NullList m_NullList = new NullList();


/// <inheritdoc />
public HeuristicPolicy(Func<float[]> heuristic)
public HeuristicPolicy(ActionGenerator heuristic, int numActions)
{
m_Heuristic = heuristic;
m_numActions = numActions;
}

/// <inheritdoc />
Expand All @@ -31,7 +34,9 @@ public void RequestDecision(AgentInfo info, List<ISensor> sensors)
StepSensors(sensors);
if (!info.done)
{
m_LastDecision = m_Heuristic.Invoke();
// Reset m_LastDecision each time.
m_LastDecision = new float[m_numActions];
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't decide whether or not to reuse the array here. I figure this approach is safer, and this isn't performance-critical enough to worry about the allocation/GC

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it depends on the number of agents and the size of the action array. If there is a large action space and a QA person is trying to make a demo file to use in training but they are hitting GC hitches every few seconds that could make it 1. hard to actually make a good demo. 2. frustrate the hell out of that person.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed offline - we need to do a deeper review of array sharing and/or copying around this (and the other policies). Logged as https://jira.unity3d.com/browse/MLA-892 for followup.

m_Heuristic.Invoke(m_LastDecision);
}
}

Expand Down
Loading