Remove {text,custom} {action,observations} #2839


Merged: 9 commits, Nov 4, 2019
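For user code, the substance of this PR is a slimmer Agent API: the textAction and customAction parameters disappear from AgentAction, and the text/custom observation plumbing (SetTextObs, SetCustomObservation, and the matching AgentInfo fields) is deleted. Below is a minimal migration sketch, assuming the MLAgents namespace used by this SDK; the class and its body are illustrative only, not part of the diff.

```csharp
using MLAgents;

// Hypothetical user agent, shown only to illustrate the narrowed override surface.
public class MigratedAgent : Agent
{
    // Before this PR (both overloads now removed):
    //   public override void AgentAction(float[] vectorAction, string textAction) { ... }
    //   public override void AgentAction(float[] vectorAction, string textAction,
    //       CommunicatorObjects.CustomActionProto customAction) { ... }
    // Calls to SetTextObs(...) and SetCustomObservation(...) must also be dropped.

    // After this PR, only the vector overload remains.
    public override void AgentAction(float[] vectorAction)
    {
        // Encode everything the policy should act on in the float[] itself.
    }
}
```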
2 changes: 0 additions & 2 deletions UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs
@@ -51,9 +51,7 @@ public void TestStoreInitalize()
id = 5,
maxStepReached = true,
floatObservations = new List<float>() { 1f, 1f, 1f },
storedTextActions = "TestAction",
storedVectorActions = new[] { 0f, 1f },
textObservation = "TestAction",
};

demoStore.Record(agentInfo);

@@ -49,7 +49,7 @@ public override void CollectObservations()
AddVectorObs(0f);
}

public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
agentActionCalls += 1;
AddReward(0.1f);

@@ -24,7 +24,7 @@ public override void CollectObservations()
AddVectorObs(m_BallRb.velocity);
}

public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
var actionZ = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
var actionX = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);

@@ -23,7 +23,7 @@ public override void CollectObservations()
AddVectorObs((ball.transform.position - gameObject.transform.position));
}

public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
var actionZ = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
var actionX = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);

@@ -25,7 +25,7 @@ public override void CollectObservations()
AddVectorObs(m_Position, 20);
}

public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
var movement = (int)vectorAction[0];


@@ -32,7 +32,7 @@ public override void CollectObservations()
AddVectorObs(target.transform.localPosition);
}

public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
for (var i = 0; i < vectorAction.Length; i++)
{

@@ -163,7 +163,7 @@ public void GetRandomTargetPos()
target.position = newTargetPos + ground.position;
}

public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
if (detectTargets)
{

@@ -208,7 +208,7 @@ void Unsatiate()
gameObject.GetComponentInChildren<Renderer>().material = normalMaterial;
}

public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
MoveAgent(vectorAction);
}

@@ -77,7 +77,7 @@ void SetMask()
}

// to be implemented by the developer
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
AddReward(-0.01f);
var action = Mathf.FloorToInt(vectorAction[0]);

@@ -72,7 +72,7 @@ public void MoveAgent(float[] act)
m_AgentRb.AddForce(dirToGo * m_Academy.agentRunSpeed, ForceMode.VelocityChange);
}

public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
AddReward(-1f / agentParameters.maxStep);
MoveAgent(vectorAction);

@@ -177,7 +177,7 @@ public void MoveAgent(float[] act)
/// <summary>
/// Called every step of the engine. Here the agent takes an action.
/// </summary>
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
// Move the agent using the action.
MoveAgent(vectorAction);

@@ -66,7 +66,7 @@ public void MoveAgent(float[] act)
m_AgentRb.AddForce(dirToGo * 2f, ForceMode.VelocityChange);
}

public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
AddReward(-1f / agentParameters.maxStep);
MoveAgent(vectorAction);

@@ -58,7 +58,7 @@ public override void CollectObservations()
/// <summary>
/// The agent's four actions correspond to torques on each of the two joints.
/// </summary>
public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
m_GoalDegree += m_GoalSpeed;
UpdateGoalPosition();

@@ -156,7 +156,7 @@ public void MoveAgent(float[] act)
ForceMode.VelocityChange);
}

public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
// Existential penalty for strikers.
if (agentRole == AgentRole.Striker)

@@ -7,7 +7,7 @@ public override void CollectObservations()
{
}

public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
}


@@ -57,7 +57,7 @@ public override void CollectObservations()
AddVectorObs(m_BallRb.velocity.y);
}

public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
var moveX = Mathf.Clamp(vectorAction[0], -1f, 1f) * m_InvertMult;
var moveY = Mathf.Clamp(vectorAction[1], -1f, 1f);

@@ -103,7 +103,7 @@ public override void CollectObservations()
}
}

public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
m_DirToTarget = target.position - m_JdController.bodyPartsDict[hips].rb.position;


@@ -233,7 +233,7 @@ public void MoveAgent(float[] act)
jumpingTime -= Time.fixedDeltaTime;
}

public override void AgentAction(float[] vectorAction, string textAction)
public override void AgentAction(float[] vectorAction)
{
MoveAgent(vectorAction);
if ((!Physics.Raycast(m_AgentRb.position, Vector3.down, 20))

75 changes: 4 additions & 71 deletions UnitySDK/Assets/ML-Agents/Scripts/Agent.cs
@@ -20,21 +20,11 @@ public struct AgentInfo
// TODO struct?
public List<float> floatObservations;

/// <summary>
/// Most recent text observation.
/// </summary>
public string textObservation;

/// <summary>
/// Keeps track of the last vector action taken by the Brain.
/// </summary>
public float[] storedVectorActions;

/// <summary>
/// Keeps track of the last text action taken by the Brain.
/// </summary>
public string storedTextActions;

/// <summary>
/// For discrete control, specifies the actions that the agent cannot take. Is true if
/// the action is masked.

@@ -61,13 +51,6 @@ public struct AgentInfo
/// to separate between different agents in the environment.
/// </summary>
public int id;

/// <summary>
/// User-customizable object for sending structured output from Unity to Python in response
/// to an action in addition to a scalar reward.
/// TODO(cgoy): All references to protobuf objects should be removed.
/// </summary>
public CommunicatorObjects.CustomObservationProto customObservation;
}

/// <summary>

@@ -77,10 +60,7 @@ public struct AgentInfo
public struct AgentAction
{
public float[] vectorActions;
public string textActions;
public float value;
/// TODO(cgoy): All references to protobuf objects should be removed.
public CommunicatorObjects.CustomActionProto customAction;
}

/// <summary>

@@ -467,16 +447,11 @@ void ResetData()
}
}

if (m_Info.textObservation == null)
m_Info.textObservation = "";
m_Action.textActions = "";

m_Info.compressedObservations = new List<CompressedObservation>();
m_Info.floatObservations = new List<float>();
m_Info.floatObservations.AddRange(
new float[param.vectorObservationSize
* param.numStackedVectorObservations]);
m_Info.customObservation = null;
}

/// <summary>

@@ -561,7 +536,6 @@ void SendInfoToBrain()
}

m_Info.storedVectorActions = m_Action.vectorActions;
m_Info.storedTextActions = m_Action.textActions;
m_Info.compressedObservations.Clear();
m_ActionMasker.ResetMask();
using (TimerStack.Instance.Scoped("CollectObservations"))
Expand Down Expand Up @@ -591,7 +565,6 @@ void SendInfoToBrain()
m_Recorder.WriteExperience(m_Info);
}

m_Info.textObservation = "";
}

/// <summary>

@@ -627,7 +600,7 @@ public void GenerateSensorData()
}

/// <summary>
/// Collects the (vector, visual, text) observations of the agent.
/// Collects the (vector, visual) observations of the agent.
/// The agent observation describes the current environment from the
/// perspective of the agent.
/// </summary>

@@ -636,7 +609,7 @@ public void GenerateSensorData()
/// the Agent acheive its goal. For example, for a fighting Agent, its
/// observation could include distances to friends or enemies, or the
/// current level of ammunition at its disposal.
/// Recall that an Agent may attach vector, visual or textual observations.
/// Recall that an Agent may attach vector or visual observations.
/// Vector observations are added by calling the provided helper methods:
/// - <see cref="AddVectorObs(int)"/>
/// - <see cref="AddVectorObs(float)"/>

@@ -657,8 +630,6 @@ public void GenerateSensorData()
/// needs to match the vectorObservationSize attribute of the linked Brain.
/// Visual observations are implicitly added from the cameras attached to
/// the Agent.
/// Lastly, textual observations are added using
/// <see cref="SetTextObs(string)"/>.
/// </remarks>
public virtual void CollectObservations()
{

@@ -789,28 +760,6 @@ protected void AddVectorObs(int observation, int range)
collectObservationsSensor.AddOneHotObservation(observation, range);
}
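
The remarks above list the AddVectorObs helpers that remain now that text observations are gone, and the hunk just above shows the one-hot variant delegating to the observation sensor. A short sketch of a CollectObservations override using these helpers, assuming the MLAgents namespace; the class, fields, and values are invented for illustration.

```csharp
using MLAgents;
using UnityEngine;

// Hypothetical agent; only the AddVectorObs calls mirror the API documented above.
public class ObservationSketchAgent : Agent
{
    public Transform target;
    Rigidbody m_Body;

    void Start()
    {
        m_Body = GetComponent<Rigidbody>();
    }

    public override void CollectObservations()
    {
        AddVectorObs(target.localPosition);  // Vector3 expands to three floats
        AddVectorObs(m_Body.velocity.x);     // a single float
        AddVectorObs(2, 5);                  // an int encoded one-hot over a range of 5
    }
}
```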

/// <summary>
/// Sets the text observation.
/// </summary>
/// <param name="textObservation">The text observation.</param>
public void SetTextObs(string textObservation)
{
m_Info.textObservation = textObservation;
}

/// <summary>
/// Specifies the agent behavior at every step based on the provided
/// action.
/// </summary>
/// <param name="vectorAction">
/// Vector action. Note that for discrete actions, the provided array
/// will be of length 1.
/// </param>
/// <param name="textAction">Text action.</param>
public virtual void AgentAction(float[] vectorAction, string textAction)
{
}

/// <summary>
/// Specifies the agent behavior at every step based on the provided
/// action.

@@ -819,15 +768,8 @@ public virtual void AgentAction(float[] vectorAction, string textAction)
/// Vector action. Note that for discrete actions, the provided array
/// will be of length 1.
/// </param>
/// <param name="textAction">Text action.</param>
/// <param name="customAction">
/// A custom action, defined by the user as custom protobuf message. Useful if the action is hard to encode
/// as either a flat vector or a single string.
/// </param>
public virtual void AgentAction(float[] vectorAction, string textAction, CommunicatorObjects.CustomActionProto customAction)
public virtual void AgentAction(float[] vectorAction)
{
// We fall back to not using the custom action if the subclassed Agent doesn't override this method.
AgentAction(vectorAction, textAction);
}
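
As the parameter remarks above note, a discrete action arrives as a single-element array. A hedged sketch of overriding the new one-parameter entry point; the subclass name and its movement logic are invented for illustration.

```csharp
using MLAgents;
using UnityEngine;

// Hypothetical subclass with one discrete action branch.
public class DiscreteActionSketchAgent : Agent
{
    public override void AgentAction(float[] vectorAction)
    {
        // Discrete control: the array has length 1 and carries the chosen branch value.
        var action = Mathf.FloorToInt(vectorAction[0]);
        switch (action)
        {
            case 1: transform.Translate(Vector3.forward * 0.1f); break;  // invented "move forward"
            case 2: transform.Rotate(Vector3.up, 10f); break;            // invented "turn right"
            default: AddReward(-0.01f); break;                           // invented idle penalty
        }
    }
}
```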

/// <summary>

@@ -992,7 +934,7 @@ void AgentStep()
if ((m_RequestAction) && (m_Brain != null))
{
m_RequestAction = false;
AgentAction(m_Action.vectorActions, m_Action.textActions, m_Action.customAction);
AgentAction(m_Action.vectorActions);
}

if ((m_StepCount >= agentParameters.maxStep)

@@ -1028,14 +970,5 @@ void DecideAction()
{
m_Brain?.DecideAction();
}

/// <summary>
/// Sets the custom observation for the agent for this episode.
/// </summary>
/// <param name="customObservation">New value of the agent's custom observation.</param>
public void SetCustomObservation(CommunicatorObjects.CustomObservationProto customObservation)
{
m_Info.customObservation = customObservation;
}
}
}