Skip to content

Commit e4826b6

Browse files
Marwan MattarvincentpierreChris Elion
authored
[WIP] Side Channel Design Changes (#3807)
* Make EnvironmentParameters a first-class citizen in the API Missing: Python conterparts and testing. * Minor comment fix to Engine Parameters * A second minor fix. * Make EngineConfigChannel Internal and add a singleton/sealed accessor * Make StatsSideChannel Internal and add a singleton/sealed accessor * Changes to SideChannelUtils - Disallow two sidechannels of the same type to be added - Remove GetSideChannels that return a list as that is now unnecessary - Make most methods except (register/unregister) internal to limit users impacting the “system-level” side channels - Add an improved comment to SideChannel.cs * Added Dispose methods to system-level sidechannel wrappers - Specifically to StatsRecorder, EnvironmentParameters and EngineParameters. - Updated Academy.Dispose to take advantage of these. - Updated Editor tests to cover all three “system-level” side channels. Kudos to Unit Tests (TestAcademy / TestAcademyDispose) for catching these. * Removed debub log. * Back-up commit. * Revert "Back-up commit." This reverts commit f81e835. * key changes to wrapper classes made the wrapper classes non-singleton (but internal constructors) made EngineParameters internal * Re-enabled the option to add multiple side channels of the same type * Fixed example env * Add an enum flag to the EnvParamsChannel * Adding .cs.meta files * Update engine config side channel Removed unnecessary accessors Made capture frame rate a parameter * Rename SideChannelUtils —> SideChannelsManager * PR feedback * Minor PR feedback. * Python side changes to the SideChannel redesign (#3826) * Modified the EngineConfig to send one message per field * Created the Python Environment Parameters Channel and hooked it in * Make OnMessageReceived protected * addressing comments * [Side Channels] Edited the documenation and renamed a few things (#3833) * Edited the documetation and renamed a few things * addressing comments * Update docs/Python-API.md Co-Authored-By: Chris Elion <chris.elion@unity3d.com> * Update com.unity.ml-agents/CHANGELOG.md Co-Authored-By: Chris Elion <chris.elion@unity3d.com> * Removing unecessary migrating line Co-authored-by: Chris Elion <chris.elion@unity3d.com> * Addressing renaming comments * Removing the EngineParameters class Co-authored-by: Vincent-Pierre BERGES <vincentpierre@unity3d.com> Co-authored-by: Chris Elion <chris.elion@unity3d.com>
1 parent cc74f81 commit e4826b6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+724
-305
lines changed

Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ public class Ball3DAgent : Agent
88
[Header("Specific to Ball3D")]
99
public GameObject ball;
1010
Rigidbody m_BallRb;
11-
FloatPropertiesChannel m_ResetParams;
11+
EnvironmentParameters m_ResetParams;
1212

1313
public override void Initialize()
1414
{
1515
m_BallRb = ball.GetComponent<Rigidbody>();
16-
m_ResetParams = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();
16+
m_ResetParams = Academy.Instance.EnvironmentParameters;
1717
SetResetParameters();
1818
}
1919

@@ -75,8 +75,8 @@ public override void Heuristic(float[] actionsOut)
7575
public void SetBall()
7676
{
7777
//Set the attributes of the ball by fetching the information from the academy
78-
m_BallRb.mass = m_ResetParams.GetPropertyWithDefault("mass", 1.0f);
79-
var scale = m_ResetParams.GetPropertyWithDefault("scale", 1.0f);
78+
m_BallRb.mass = m_ResetParams.GetWithDefault("mass", 1.0f);
79+
var scale = m_ResetParams.GetWithDefault("scale", 1.0f);
8080
ball.transform.localScale = new Vector3(scale, scale, scale);
8181
}
8282

Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ public class Ball3DHardAgent : Agent
88
[Header("Specific to Ball3DHard")]
99
public GameObject ball;
1010
Rigidbody m_BallRb;
11-
FloatPropertiesChannel m_ResetParams;
11+
EnvironmentParameters m_ResetParams;
1212

1313
public override void Initialize()
1414
{
1515
m_BallRb = ball.GetComponent<Rigidbody>();
16-
m_ResetParams = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();
16+
m_ResetParams = Academy.Instance.EnvironmentParameters;
1717
SetResetParameters();
1818
}
1919

@@ -66,8 +66,8 @@ public override void OnEpisodeBegin()
6666
public void SetBall()
6767
{
6868
//Set the attributes of the ball by fetching the information from the academy
69-
m_BallRb.mass = m_ResetParams.GetPropertyWithDefault("mass", 1.0f);
70-
var scale = m_ResetParams.GetPropertyWithDefault("scale", 1.0f);
69+
m_BallRb.mass = m_ResetParams.GetWithDefault("mass", 1.0f);
70+
var scale = m_ResetParams.GetWithDefault("scale", 1.0f);
7171
ball.transform.localScale = new Vector3(scale, scale, scale);
7272
}
7373

Project/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@ public class BouncerAgent : Agent
1515
int m_NumberJumps = 20;
1616
int m_JumpLeft = 20;
1717

18-
FloatPropertiesChannel m_ResetParams;
18+
EnvironmentParameters m_ResetParams;
1919

2020
public override void Initialize()
2121
{
2222
m_Rb = gameObject.GetComponent<Rigidbody>();
2323
m_LookDir = Vector3.zero;
2424

25-
m_ResetParams = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();
25+
m_ResetParams = Academy.Instance.EnvironmentParameters;
2626

2727
SetResetParameters();
2828
}
@@ -121,7 +121,7 @@ void Update()
121121

122122
public void SetTargetScale()
123123
{
124-
var targetScale = m_ResetParams.GetPropertyWithDefault("target_scale", 1.0f);
124+
var targetScale = m_ResetParams.GetWithDefault("target_scale", 1.0f);
125125
target.transform.localScale = new Vector3(targetScale, targetScale, targetScale);
126126
}
127127

Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,14 @@ public class FoodCollectorAgent : Agent
2929
public bool contribute;
3030
public bool useVectorObs;
3131

32+
EnvironmentParameters m_ResetParams;
3233

3334
public override void Initialize()
3435
{
3536
m_AgentRb = GetComponent<Rigidbody>();
3637
m_MyArea = area.GetComponent<FoodCollectorArea>();
3738
m_FoodCollecterSettings = FindObjectOfType<FoodCollectorSettings>();
38-
39+
m_ResetParams = Academy.Instance.EnvironmentParameters;
3940
SetResetParameters();
4041
}
4142

@@ -271,12 +272,12 @@ void OnCollisionEnter(Collision collision)
271272

272273
public void SetLaserLengths()
273274
{
274-
m_LaserLength = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>().GetPropertyWithDefault("laser_length", 1.0f);
275+
m_LaserLength = m_ResetParams.GetWithDefault("laser_length", 1.0f);
275276
}
276277

277278
public void SetAgentScale()
278279
{
279-
float agentScale = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>().GetPropertyWithDefault("agent_scale", 1.0f);
280+
float agentScale = m_ResetParams.GetWithDefault("agent_scale", 1.0f);
280281
gameObject.transform.localScale = new Vector3(agentScale, agentScale, agentScale);
281282
}
282283

Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorSettings.cs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
using System;
21
using UnityEngine;
32
using UnityEngine.UI;
43
using MLAgents;
5-
using MLAgents.SideChannels;
64

75
public class FoodCollectorSettings : MonoBehaviour
86
{
@@ -14,15 +12,15 @@ public class FoodCollectorSettings : MonoBehaviour
1412
public int totalScore;
1513
public Text scoreText;
1614

17-
StatsSideChannel m_statsSideChannel;
15+
StatsRecorder m_Recorder;
1816

1917
public void Awake()
2018
{
2119
Academy.Instance.OnEnvironmentReset += EnvironmentReset;
22-
m_statsSideChannel = SideChannelUtils.GetSideChannel<StatsSideChannel>();
20+
m_Recorder = Academy.Instance.StatsRecorder;
2321
}
2422

25-
public void EnvironmentReset()
23+
private void EnvironmentReset()
2624
{
2725
ClearObjects(GameObject.FindGameObjectsWithTag("food"));
2826
ClearObjects(GameObject.FindGameObjectsWithTag("badFood"));
@@ -54,7 +52,7 @@ public void Update()
5452
// need to send every Update() call.
5553
if ((Time.frameCount % 100)== 0)
5654
{
57-
m_statsSideChannel?.AddStat("TotalScore", totalScore);
55+
m_Recorder.Add("TotalScore", totalScore);
5856
}
5957
}
6058
}

Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ public class GridAgent : Agent
2929
const int k_Left = 3;
3030
const int k_Right = 4;
3131

32+
EnvironmentParameters m_ResetParams;
33+
34+
public override void Initialize()
35+
{
36+
m_ResetParams = Academy.Instance.EnvironmentParameters;
37+
}
38+
3239
public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
3340
{
3441
// Mask the necessary actions if selected by the user.
@@ -37,7 +44,7 @@ public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMaske
3744
// Prevents the agent from picking an action that would make it collide with a wall
3845
var positionX = (int)transform.position.x;
3946
var positionZ = (int)transform.position.z;
40-
var maxPosition = (int)SideChannelUtils.GetSideChannel<FloatPropertiesChannel>().GetPropertyWithDefault("gridSize", 5f) - 1;
47+
var maxPosition = (int)m_ResetParams.GetWithDefault("gridSize", 5f) - 1;
4148

4249
if (positionX == 0)
4350
{

Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridArea.cs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ public class GridArea : MonoBehaviour
1414

1515
public GameObject trueAgent;
1616

17-
FloatPropertiesChannel m_ResetParameters;
18-
1917
Camera m_AgentCam;
2018

2119
public GameObject goalPref;
@@ -30,9 +28,11 @@ public class GridArea : MonoBehaviour
3028

3129
Vector3 m_InitialPosition;
3230

31+
EnvironmentParameters m_ResetParams;
32+
3333
public void Start()
3434
{
35-
m_ResetParameters = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();
35+
m_ResetParams = Academy.Instance.EnvironmentParameters;
3636

3737
m_Objects = new[] { goalPref, pitPref };
3838

@@ -50,23 +50,23 @@ public void Start()
5050
m_InitialPosition = transform.position;
5151
}
5252

53-
public void SetEnvironment()
53+
private void SetEnvironment()
5454
{
55-
transform.position = m_InitialPosition * (m_ResetParameters.GetPropertyWithDefault("gridSize", 5f) + 1);
55+
transform.position = m_InitialPosition * (m_ResetParams.GetWithDefault("gridSize", 5f) + 1);
5656
var playersList = new List<int>();
5757

58-
for (var i = 0; i < (int)m_ResetParameters.GetPropertyWithDefault("numObstacles", 1); i++)
58+
for (var i = 0; i < (int)m_ResetParams.GetWithDefault("numObstacles", 1); i++)
5959
{
6060
playersList.Add(1);
6161
}
6262

63-
for (var i = 0; i < (int)m_ResetParameters.GetPropertyWithDefault("numGoals", 1f); i++)
63+
for (var i = 0; i < (int)m_ResetParams.GetWithDefault("numGoals", 1f); i++)
6464
{
6565
playersList.Add(0);
6666
}
6767
players = playersList.ToArray();
6868

69-
var gridSize = (int)m_ResetParameters.GetPropertyWithDefault("gridSize", 5f);
69+
var gridSize = (int)m_ResetParams.GetWithDefault("gridSize", 5f);
7070
m_Plane.transform.localScale = new Vector3(gridSize / 10.0f, 1f, gridSize / 10.0f);
7171
m_Plane.transform.localPosition = new Vector3((gridSize - 1) / 2f, -0.5f, (gridSize - 1) / 2f);
7272
m_Sn.transform.localScale = new Vector3(1, 1, gridSize + 2);
@@ -84,7 +84,7 @@ public void SetEnvironment()
8484

8585
public void AreaReset()
8686
{
87-
var gridSize = (int)m_ResetParameters.GetPropertyWithDefault("gridSize", 5f);
87+
var gridSize = (int)m_ResetParams.GetWithDefault("gridSize", 5f);
8888
foreach (var actor in actorObjs)
8989
{
9090
DestroyImmediate(actor);
@@ -98,7 +98,7 @@ public void AreaReset()
9898
{
9999
numbers.Add(Random.Range(0, gridSize * gridSize));
100100
}
101-
var numbersA = Enumerable.ToArray(numbers);
101+
var numbersA = numbers.ToArray();
102102

103103
for (var i = 0; i < players.Length; i++)
104104
{

Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridSettings.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ public class GridSettings : MonoBehaviour
88

99
public void Awake()
1010
{
11-
SideChannelUtils.GetSideChannel<FloatPropertiesChannel>().RegisterCallback("gridSize", f =>
11+
Academy.Instance.EnvironmentParameters.RegisterCallback("gridSize", f =>
1212
{
1313
MainCamera.transform.position = new Vector3(-(f - 1) / 2f, f * 1.25f, -(f - 1) / 2f);
1414
MainCamera.orthographicSize = (f + 5f) / 2f;

Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ public class PushAgentBasic : Agent
4949
/// </summary>
5050
Renderer m_GroundRenderer;
5151

52+
private EnvironmentParameters m_ResetParams;
53+
5254
void Awake()
5355
{
5456
m_PushBlockSettings = FindObjectOfType<PushBlockSettings>();
@@ -70,6 +72,8 @@ public override void Initialize()
7072
// Starting material
7173
m_GroundMaterial = m_GroundRenderer.material;
7274

75+
m_ResetParams = Academy.Instance.EnvironmentParameters;
76+
7377
SetResetParameters();
7478
}
7579

@@ -226,27 +230,23 @@ public override void OnEpisodeBegin()
226230

227231
public void SetGroundMaterialFriction()
228232
{
229-
var resetParams = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();
230-
231233
var groundCollider = ground.GetComponent<Collider>();
232234

233-
groundCollider.material.dynamicFriction = resetParams.GetPropertyWithDefault("dynamic_friction", 0);
234-
groundCollider.material.staticFriction = resetParams.GetPropertyWithDefault("static_friction", 0);
235+
groundCollider.material.dynamicFriction = m_ResetParams.GetWithDefault("dynamic_friction", 0);
236+
groundCollider.material.staticFriction = m_ResetParams.GetWithDefault("static_friction", 0);
235237
}
236238

237239
public void SetBlockProperties()
238240
{
239-
var resetParams = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();
240-
241-
var scale = resetParams.GetPropertyWithDefault("block_scale", 2);
241+
var scale = m_ResetParams.GetWithDefault("block_scale", 2);
242242
//Set the scale of the block
243243
m_BlockRb.transform.localScale = new Vector3(scale, 0.75f, scale);
244244

245245
// Set the drag of the block
246-
m_BlockRb.drag = resetParams.GetPropertyWithDefault("block_drag", 0.5f);
246+
m_BlockRb.drag = m_ResetParams.GetWithDefault("block_drag", 0.5f);
247247
}
248248

249-
public void SetResetParameters()
249+
private void SetResetParameters()
250250
{
251251
SetGroundMaterialFriction();
252252
SetBlockProperties();

Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ public class ReacherAgent : Agent
2121
// Frequency of the cosine deviation of the goal along the vertical dimension
2222
float m_DeviationFreq;
2323

24+
private EnvironmentParameters m_ResetParams;
25+
2426
/// <summary>
2527
/// Collect the rigidbodies of the reacher in order to resue them for
2628
/// observations and actions.
@@ -30,6 +32,8 @@ public override void Initialize()
3032
m_RbA = pendulumA.GetComponent<Rigidbody>();
3133
m_RbB = pendulumB.GetComponent<Rigidbody>();
3234

35+
m_ResetParams = Academy.Instance.EnvironmentParameters;
36+
3337
SetResetParameters();
3438
}
3539

@@ -110,10 +114,9 @@ public override void OnEpisodeBegin()
110114

111115
public void SetResetParameters()
112116
{
113-
var fp = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();
114-
m_GoalSize = fp.GetPropertyWithDefault("goal_size", 5);
115-
m_GoalSpeed = Random.Range(-1f, 1f) * fp.GetPropertyWithDefault("goal_speed", 1);
116-
m_Deviation = fp.GetPropertyWithDefault("deviation", 0);
117-
m_DeviationFreq = fp.GetPropertyWithDefault("deviation_freq", 0);
117+
m_GoalSize = m_ResetParams.GetWithDefault("goal_size", 5);
118+
m_GoalSpeed = Random.Range(-1f, 1f) * m_ResetParams.GetWithDefault("goal_speed", 1);
119+
m_Deviation = m_ResetParams.GetWithDefault("deviation", 0);
120+
m_DeviationFreq = m_ResetParams.GetWithDefault("deviation_freq", 0);
118121
}
119122
}

Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ProjectSettingsOverrides.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ public void Awake()
4444
Physics.defaultSolverVelocityIterations = solverVelocityIterations;
4545

4646
// Make sure the Academy singleton is initialized first, since it will create the SideChannels.
47-
var academy = Academy.Instance;
48-
SideChannelUtils.GetSideChannel<FloatPropertiesChannel>().RegisterCallback("gravity", f => { Physics.gravity = new Vector3(0, -f, 0); });
47+
Academy.Instance.EnvironmentParameters.RegisterCallback("gravity", f => { Physics.gravity = new Vector3(0, -f, 0); });
4948
}
5049

5150
public void OnDestroy()

Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ public enum Position
4949
BehaviorParameters m_BehaviorParameters;
5050
Vector3 m_Transform;
5151

52+
private EnvironmentParameters m_ResetParams;
53+
5254
public override void Initialize()
5355
{
5456
m_Existential = 1f / MaxStep;
@@ -73,7 +75,7 @@ public override void Initialize()
7375
m_LateralSpeed = 0.3f;
7476
m_ForwardSpeed = 1.3f;
7577
}
76-
else
78+
else
7779
{
7880
m_LateralSpeed = 0.3f;
7981
m_ForwardSpeed = 1.0f;
@@ -91,6 +93,8 @@ public override void Initialize()
9193
area.playerStates.Add(playerState);
9294
m_PlayerIndex = area.playerStates.IndexOf(playerState);
9395
playerState.playerIndex = m_PlayerIndex;
96+
97+
m_ResetParams = Academy.Instance.EnvironmentParameters;
9498
}
9599

96100
public void MoveAgent(float[] act)
@@ -214,7 +218,7 @@ public override void OnEpisodeBegin()
214218
{
215219

216220
timePenalty = 0;
217-
m_BallTouch = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>().GetPropertyWithDefault("ball_touch", 0);
221+
m_BallTouch = m_ResetParams.GetWithDefault("ball_touch", 0);
218222
if (team == Team.Purple)
219223
{
220224
transform.rotation = Quaternion.Euler(0f, -90f, 0f);

0 commit comments

Comments
 (0)