Skip to content

Renaming Agent's methods #3557

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Mar 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
7bd6e4e
[skip ci] Renamed methods in the Agent class
vincentpierre Mar 4, 2020
24dac1b
[skip ci] Updated the example environment
vincentpierre Mar 4, 2020
d17789d
[skip ci] Updated migrating and changelog
vincentpierre Mar 4, 2020
dbebc42
[skip ci] Editing the docs
vincentpierre Mar 4, 2020
29409b3
[skip ci] Missing docs
vincentpierre Mar 4, 2020
f0d9a00
:+1
vincentpierre Mar 4, 2020
675c38f
Update docs/Getting-Started-with-Balance-Ball.md
vincentpierre Mar 4, 2020
0da6957
Update docs/Learning-Environment-Create-New.md
vincentpierre Mar 4, 2020
c29ac08
Update docs/Learning-Environment-Create-New.md
vincentpierre Mar 4, 2020
59f8e19
[skip ci] documentation changes
vincentpierre Mar 4, 2020
07464ca
[skip ci] Update docs/Getting-Started-with-Balance-Ball.md
vincentpierre Mar 4, 2020
dcf8b11
[skip ci] Update docs/Getting-Started-with-Balance-Ball.md
vincentpierre Mar 4, 2020
c66535c
[skip ci] Update docs/Getting-Started-with-Balance-Ball.md
vincentpierre Mar 4, 2020
4b9f428
[skip ci] Update docs/Getting-Started-with-Balance-Ball.md
vincentpierre Mar 4, 2020
de5c8c0
Addressing comments
vincentpierre Mar 4, 2020
a45a4ee
Merge branch 'develop-renaming-001' of https://github.com/Unity-Techn…
vincentpierre Mar 4, 2020
d085ad6
Merge branch 'master' into develop-renaming-001
vincentpierre Mar 5, 2020
b7bcb86
[skip ci] Update com.unity.ml-agents/CHANGELOG.md
vincentpierre Mar 6, 2020
7eaefce
[skip ci]Update docs/Getting-Started-with-Balance-Ball.md
vincentpierre Mar 6, 2020
e6bf6dd
[skip ci]Update docs/Learning-Environment-Create-New.md
vincentpierre Mar 6, 2020
f60ea8a
[skip ci]Addressing the comments
vincentpierre Mar 6, 2020
0cf8620
[skip ci] Update docs/Migrating.md
vincentpierre Mar 9, 2020
ea0f0e1
[skip ci] Update docs/Migrating.md
vincentpierre Mar 9, 2020
92eb723
resolving conflicts
vincentpierre Mar 9, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public class Ball3DAgent : Agent
Rigidbody m_BallRb;
FloatPropertiesChannel m_ResetParams;

public override void InitializeAgent()
public override void Initialize()
{
m_BallRb = ball.GetComponent<Rigidbody>();
m_ResetParams = Academy.Instance.FloatProperties;
Expand All @@ -25,7 +25,7 @@ public override void CollectObservations(VectorSensor sensor)
sensor.AddObservation(m_BallRb.velocity);
}

public override void AgentAction(float[] vectorAction)
public override void OnActionReceived(float[] vectorAction)
{
var actionZ = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
var actionX = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);
Expand All @@ -46,15 +46,15 @@ public override void AgentAction(float[] vectorAction)
Mathf.Abs(ball.transform.position.z - gameObject.transform.position.z) > 3f)
{
SetReward(-1f);
Done();
EndEpisode();
}
else
{
SetReward(0.1f);
}
}

public override void AgentReset()
public override void OnEpisodeBegin()
{
gameObject.transform.rotation = new Quaternion(0f, 0f, 0f, 0f);
gameObject.transform.Rotate(new Vector3(1, 0, 0), Random.Range(-10f, 10f));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public class Ball3DHardAgent : Agent
Rigidbody m_BallRb;
FloatPropertiesChannel m_ResetParams;

public override void InitializeAgent()
public override void Initialize()
{
m_BallRb = ball.GetComponent<Rigidbody>();
m_ResetParams = Academy.Instance.FloatProperties;
Expand All @@ -24,7 +24,7 @@ public override void CollectObservations(VectorSensor sensor)
sensor.AddObservation((ball.transform.position - gameObject.transform.position));
}

public override void AgentAction(float[] vectorAction)
public override void OnActionReceived(float[] vectorAction)
{
var actionZ = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
var actionX = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);
Expand All @@ -45,15 +45,15 @@ public override void AgentAction(float[] vectorAction)
Mathf.Abs(ball.transform.position.z - gameObject.transform.position.z) > 3f)
{
SetReward(-1f);
Done();
EndEpisode();
}
else
{
SetReward(0.1f);
}
}

public override void AgentReset()
public override void OnEpisodeBegin()
{
gameObject.transform.rotation = new Quaternion(0f, 0f, 0f, 0f);
gameObject.transform.Rotate(new Vector3(1, 0, 0), Random.Range(-10f, 10f));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ public void ApplyAction(float[] vectorAction)
if (m_Position == k_SmallGoalPosition)
{
m_Agent.AddReward(0.1f);
m_Agent.Done();
m_Agent.EndEpisode();
ResetAgent();
}

if (m_Position == k_LargeGoalPosition)
{
m_Agent.AddReward(1f);
m_Agent.Done();
m_Agent.EndEpisode();
ResetAgent();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public class BouncerAgent : Agent

FloatPropertiesChannel m_ResetParams;

public override void InitializeAgent()
public override void Initialize()
{
m_Rb = gameObject.GetComponent<Rigidbody>();
m_LookDir = Vector3.zero;
Expand All @@ -33,7 +33,7 @@ public override void CollectObservations(VectorSensor sensor)
sensor.AddObservation(target.transform.localPosition);
}

public override void AgentAction(float[] vectorAction)
public override void OnActionReceived(float[] vectorAction)
{
for (var i = 0; i < vectorAction.Length; i++)
{
Expand All @@ -52,7 +52,7 @@ public override void AgentAction(float[] vectorAction)
m_LookDir = new Vector3(x, y, z);
}

public override void AgentReset()
public override void OnEpisodeBegin()
{
gameObject.transform.localPosition = new Vector3(
(1 - 2 * Random.value) * 5, 2, (1 - 2 * Random.value) * 5);
Expand Down Expand Up @@ -85,20 +85,20 @@ void FixedUpdate()
if (gameObject.transform.position.y < -1)
{
AddReward(-1);
Done();
EndEpisode();
return;
}

if (gameObject.transform.localPosition.x < -19 || gameObject.transform.localPosition.x > 19
|| gameObject.transform.localPosition.z < -19 || gameObject.transform.localPosition.z > 19)
{
AddReward(-1);
Done();
EndEpisode();
return;
}
if (m_JumpLeft == 0)
{
Done();
EndEpisode();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public class CrawlerAgent : Agent
Quaternion m_LookRotation;
Matrix4x4 m_TargetDirMatrix;

public override void InitializeAgent()
public override void Initialize()
{
m_JdController = GetComponent<JointDriveController>();
m_DirToTarget = target.position - body.position;
Expand Down Expand Up @@ -147,7 +147,7 @@ public void GetRandomTargetPos()
target.position = newTargetPos + ground.position;
}

public override void AgentAction(float[] vectorAction)
public override void OnActionReceived(float[] vectorAction)
{
// The dictionary with all the body parts in it are in the jdController
var bpDict = m_JdController.bodyPartsDict;
Expand Down Expand Up @@ -251,7 +251,7 @@ void RewardFunctionTimePenalty()
/// <summary>
/// Loop over body parts and reset them to initial conditions.
/// </summary>
public override void AgentReset()
public override void OnEpisodeBegin()
{
if (m_DirToTarget != Vector3.zero)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ public class FoodCollectorAgent : Agent
public bool useVectorObs;


public override void InitializeAgent()
public override void Initialize()
{
base.InitializeAgent();
m_AgentRb = GetComponent<Rigidbody>();
m_MyArea = area.GetComponent<FoodCollectorArea>();
m_FoodCollecterSettings = FindObjectOfType<FoodCollectorSettings>();
Expand Down Expand Up @@ -202,7 +201,7 @@ void Unsatiate()
gameObject.GetComponentInChildren<Renderer>().material = normalMaterial;
}

public override void AgentAction(float[] vectorAction)
public override void OnActionReceived(float[] vectorAction)
{
MoveAgent(vectorAction);
}
Expand Down Expand Up @@ -230,7 +229,7 @@ public override float[] Heuristic()
return action;
}

public override void AgentReset()
public override void OnEpisodeBegin()
{
Unfreeze();
Unpoison();
Expand Down
12 changes: 4 additions & 8 deletions Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@ public class GridAgent : Agent
const int k_Left = 3;
const int k_Right = 4;

public override void InitializeAgent()
{
}

public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
{
// Mask the necessary actions if selected by the user.
Expand Down Expand Up @@ -65,7 +61,7 @@ public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMaske
}

// to be implemented by the developer
public override void AgentAction(float[] vectorAction)
public override void OnActionReceived(float[] vectorAction)
{
AddReward(-0.01f);
var action = Mathf.FloorToInt(vectorAction[0]);
Expand Down Expand Up @@ -101,12 +97,12 @@ public override void AgentAction(float[] vectorAction)
if (hit.Where(col => col.gameObject.CompareTag("goal")).ToArray().Length == 1)
{
SetReward(1f);
Done();
EndEpisode();
}
else if (hit.Where(col => col.gameObject.CompareTag("pit")).ToArray().Length == 1)
{
SetReward(-1f);
Done();
EndEpisode();
}
}
}
Expand All @@ -133,7 +129,7 @@ public override float[] Heuristic()
}

// to be implemented by the developer
public override void AgentReset()
public override void OnEpisodeBegin()
{
area.AreaReset();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ public class HallwayAgent : Agent
HallwaySettings m_HallwaySettings;
int m_Selection;

public override void InitializeAgent()
public override void Initialize()
{
base.InitializeAgent();
m_HallwaySettings = FindObjectOfType<HallwaySettings>();
m_AgentRb = GetComponent<Rigidbody>();
m_GroundRenderer = ground.GetComponent<Renderer>();
Expand Down Expand Up @@ -67,7 +66,7 @@ public void MoveAgent(float[] act)
m_AgentRb.AddForce(dirToGo * m_HallwaySettings.agentRunSpeed, ForceMode.VelocityChange);
}

public override void AgentAction(float[] vectorAction)
public override void OnActionReceived(float[] vectorAction)
{
AddReward(-1f / maxStep);
MoveAgent(vectorAction);
Expand All @@ -88,7 +87,7 @@ void OnCollisionEnter(Collision col)
SetReward(-0.1f);
StartCoroutine(GoalScoredSwapGroundMaterial(m_HallwaySettings.failMaterial, 0.5f));
}
Done();
EndEpisode();
}
}

Expand All @@ -113,7 +112,7 @@ public override float[] Heuristic()
return new float[] { 0 };
}

public override void AgentReset()
public override void OnEpisodeBegin()
{
var agentOffset = -15f;
var blockOffset = 0f;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,8 @@ void Awake()
m_PushBlockSettings = FindObjectOfType<PushBlockSettings>();
}

public override void InitializeAgent()
public override void Initialize()
{
base.InitializeAgent();
goalDetect = block.GetComponent<GoalDetect>();
goalDetect.agent = this;

Expand Down Expand Up @@ -105,7 +104,7 @@ public void ScoredAGoal()
AddReward(5f);

// By ending the episode, OnEpisodeBegin() will be called automatically.
Done();
EndEpisode();

// Swap ground material for a bit to indicate we scored.
StartCoroutine(GoalScoredSwapGroundMaterial(m_PushBlockSettings.goalScoredMaterial, 0.5f));
Expand Down Expand Up @@ -161,7 +160,7 @@ public void MoveAgent(float[] act)
/// <summary>
/// Called every step of the engine. Here the agent takes an action.
/// </summary>
public override void AgentAction(float[] vectorAction)
public override void OnActionReceived(float[] vectorAction)
{
// Move the agent using the action.
MoveAgent(vectorAction);
Expand Down Expand Up @@ -210,7 +209,7 @@ void ResetBlock()
/// In the editor, if "Reset On Done" is checked then OnEpisodeBegin() will be
/// called automatically anytime we call EndEpisode() in an agent script.
/// </summary>
public override void AgentReset()
public override void OnEpisodeBegin()
{
var rotation = Random.Range(0, 4);
var rotationAngle = rotation * 90f;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@ public class PyramidAgent : Agent
public GameObject areaSwitch;
public bool useVectorObs;

public override void InitializeAgent()
public override void Initialize()
{
base.InitializeAgent();
m_AgentRb = GetComponent<Rigidbody>();
m_MyArea = area.GetComponent<PyramidArea>();
m_SwitchLogic = areaSwitch.GetComponent<PyramidSwitch>();
Expand Down Expand Up @@ -56,7 +55,7 @@ public void MoveAgent(float[] act)
m_AgentRb.AddForce(dirToGo * 2f, ForceMode.VelocityChange);
}

public override void AgentAction(float[] vectorAction)
public override void OnActionReceived(float[] vectorAction)
{
AddReward(-1f / maxStep);
MoveAgent(vectorAction);
Expand All @@ -83,7 +82,7 @@ public override float[] Heuristic()
return new float[] { 0 };
}

public override void AgentReset()
public override void OnEpisodeBegin()
{
var enumerable = Enumerable.Range(0, 9).OrderBy(x => Guid.NewGuid()).Take(9);
var items = enumerable.ToArray();
Expand All @@ -108,7 +107,7 @@ void OnCollisionEnter(Collision collision)
if (collision.gameObject.CompareTag("goal"))
{
SetReward(2f);
Done();
EndEpisode();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public class ReacherAgent : Agent
/// Collect the rigidbodies of the reacher in order to reuse them for
/// observations and actions.
/// </summary>
public override void InitializeAgent()
public override void Initialize()
{
m_RbA = pendulumA.GetComponent<Rigidbody>();
m_RbB = pendulumB.GetComponent<Rigidbody>();
Expand Down Expand Up @@ -57,7 +57,7 @@ public override void CollectObservations(VectorSensor sensor)
/// <summary>
/// The agent's four actions correspond to torques on each of the two joints.
/// </summary>
public override void AgentAction(float[] vectorAction)
public override void OnActionReceived(float[] vectorAction)
{
m_GoalDegree += m_GoalSpeed;
UpdateGoalPosition();
Expand Down Expand Up @@ -86,7 +86,7 @@ void UpdateGoalPosition()
/// <summary>
/// Resets the position and velocity of the agent and the goal.
/// </summary>
public override void AgentReset()
public override void OnEpisodeBegin()
{
pendulumA.transform.position = new Vector3(0f, -4f, 0f) + transform.position;
pendulumA.transform.rotation = Quaternion.Euler(180f, 0f, 0f);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void OnCollisionEnter(Collision col)

if (agentDoneOnGroundContact)
{
agent.Done();
agent.EndEpisode();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ void OverrideModel()
var nnModel = GetModelForBehaviorName(name);
Debug.Log($"Overriding behavior {name} for agent with model {nnModel?.name}");
// This might give a null model; that's better because we'll fall back to the Heuristic
m_Agent.GiveModel($"Override_{name}", nnModel);
m_Agent.SetModel($"Override_{name}", nnModel);

}
}
Expand Down
Loading