Skip to content

Commit 4c2d2bb

Browse files
vincentpierreeshvk
authored andcommitted
Fixed the Training, ZeroK now trains properly again
Modified the API, not satisfied with the offset and size arguments
1 parent c149f6e commit 4c2d2bb

File tree

10 files changed

+52
-48
lines changed

10 files changed

+52
-48
lines changed

Assets/ECS_MLAgents_v0/Core/AgentSystem.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ protected override JobHandle OnUpdate(JobHandle inputDeps)
167167

168168
handle.Complete();
169169

170-
Decision.BatchProcess(ref _sensorTensor, ref _actuatorTensor);
170+
Decision.BatchProcess(ref _sensorTensor, ref _actuatorTensor, 0, nAgents);
171171

172172
/*
173173
* Copy the data from the actuator NativeArray<float> to the actuators of each entity.
@@ -201,7 +201,7 @@ public void Execute(int i)
201201
/*
202202
* This IJobParallelFor copies the Actuator data to the appropriate IComponentData
203203
*/
204-
// [BurstCompile]
204+
[BurstCompile]
205205
private struct CopyActuatorsJob : IJobParallelFor
206206
{
207207
public ComponentDataArray<TA> Actuators;

Assets/ECS_MLAgents_v0/Core/ExternalDecision.cs

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public class ExternalDecision<TS, TA> : IAgentDecision<TS, TA>
2929
private const int ACTUATOR_DATA_POSITION = 100001;
3030

3131

32-
private float[] actuatorData = new float[0];
32+
private TA[] actuatorData = new TA[0];
3333

3434

3535

@@ -58,15 +58,21 @@ public ExternalDecision()
5858
Debug.Log("Is Ready to Communicate");
5959
}
6060

61-
public void BatchProcess(ref NativeArray<TS> sensors, ref NativeArray<TA> actuators )
61+
public void BatchProcess(ref NativeArray<TS> sensors, ref NativeArray<TA> actuators, int offset = 0, int size = -1)
6262
{
63-
Profiler.BeginSample("Communicating");
63+
Profiler.BeginSample("__Communicating");
6464

65+
Profiler.BeginSample("__TypeCheck");
6566
VerifySensor(typeof(TS));
6667
VerifyActuator(typeof(TA));
68+
if (size ==-1){
69+
size = sensors.Length - offset;
70+
}
71+
Profiler.EndSample();
6772

68-
int batch = sensors.Length;
69-
if (batch != actuators.Length)
73+
Profiler.BeginSample("__VerifyLength");
74+
int batch = size;
75+
if (sensors.Length != actuators.Length)
7076
{
7177
throw new Exception("Error in the length of the sensors and actuators");
7278
}
@@ -78,21 +84,23 @@ public void BatchProcess(ref NativeArray<TS> sensors, ref NativeArray<TA> actuat
7884

7985
if (actuatorData.Length < _actuatorSize* batch)
8086
{
81-
actuatorData = new float[_actuatorSize * batch];
87+
actuatorData = new TA[batch];
8288
}
89+
Profiler.EndSample();
8390

84-
91+
Profiler.BeginSample("__Write");
8592
accessor.Write(NUMBER_AGENTS_POSITION, batch);
8693
accessor.Write(SENSOR_SIZE_POSITION, _sensorSize);
8794
accessor.Write(ACTUATOR_SIZE_POSITION, _actuatorSize);
8895

89-
accessor.WriteArray(SENSOR_DATA_POSITION, sensors.ToArray(), 0, batch);
96+
accessor.WriteArray(SENSOR_DATA_POSITION, sensors.Slice(offset, batch).ToArray(), 0, batch);
9097

9198
accessor.Write(PYTHON_READY_POSITION, false);
9299

93100
accessor.Write(UNITY_READY_POSITION, true);
101+
Profiler.EndSample();
94102

95-
103+
Profiler.BeginSample("__Wait");
96104
var readyToContinue = false;
97105
int loopIter = 0;
98106
while (!readyToContinue)
@@ -105,22 +113,20 @@ public void BatchProcess(ref NativeArray<TS> sensors, ref NativeArray<TA> actuat
105113
Debug.Log("Missed Communication");
106114
}
107115
}
116+
Profiler.EndSample();
108117

109-
accessor.ReadArray(ACTUATOR_DATA_POSITION, actuatorData, 0, batch * _actuatorSize);
118+
Profiler.BeginSample("__Read");
119+
accessor.ReadArray(ACTUATOR_DATA_POSITION, actuatorData, 0, batch);
110120

111-
// actuator.CopyFrom(actuatorData);
121+
actuators.Slice(offset, batch).CopyFrom(actuatorData);
112122

113-
var tmpA = new NativeArray<float>(batch * _actuatorSize, Allocator.Persistent);
114-
tmpA.CopyFrom(actuatorData);
115-
for(var i = 0; i< batch; i++){
116-
var act = new TA();
117-
TensorUtility.CopyFromNativeArray(tmpA, out act, i * _sensorSize * 4);
118-
actuators[i] = act;
119-
}
120-
tmpA.Dispose();
123+
// for(var i = 0; i< batch; i++){
124+
// actuators[i] = actuatorData[i];
125+
// }
121126

127+
Profiler.EndSample();
128+
Profiler.EndSample();
122129

123-
Profiler.BeginSample("Communicating");
124130
}
125131

126132
private void VerifySensor(System.Type t){

Assets/ECS_MLAgents_v0/Core/IAgentDecision.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ namespace ECS_MLAgents_v0.Core
88
* based on the information present in the sensor.
99
*/
1010
public interface IAgentDecision<TS, TA>
11-
where TS : struct, IComponentData
12-
where TA : struct, IComponentData
11+
where TS : struct
12+
where TA : struct
1313
{
1414
/// <summary>
1515
/// DecideBatch updates the actuators of the agents present in the batch from
@@ -19,8 +19,9 @@ public interface IAgentDecision<TS, TA>
1919
/// batch. T.</param>
2020
/// <param name="actuators">The aggregated data for the actuator information present in the
2121
/// batch. </param>
22-
void BatchProcess(ref NativeArray<TS> sensors, ref NativeArray<TA> actuators);
22+
void BatchProcess(ref NativeArray<TS> sensors, ref NativeArray<TA> actuators, int offset = 0, int size = -1);
2323

24+
// TODO : It is debatable wether or not we want to enforce the type here
2425
}
2526

2627
}

Assets/ECS_MLAgents_v0/Core/NNDecision.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public NNDecision(NNModel model){
4848

4949
}
5050

51-
public void BatchProcess(ref NativeArray<TS> sensors, ref NativeArray<TA> actuators )
51+
public void BatchProcess(ref NativeArray<TS> sensors, ref NativeArray<TA> actuators , int offset = 0, int size = -1)
5252
{
5353
VerifySensor();
5454
VerifyActuator();

Assets/ECS_MLAgents_v0/Example/SpaceMagic/Scripts/HeuristicSpace.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,15 @@ public HeuristicSpace(float3 center, float strength)
2828
_strength = strength;
2929
}
3030

31-
public void BatchProcess(ref NativeArray<Position> sensors, ref NativeArray<Acceleration> actuators )
31+
public void BatchProcess(ref NativeArray<Position> sensors, ref NativeArray<Acceleration> actuators, int offset = 0, int size = -1)
3232
{
33+
3334
var nAgents = sensors.Length;
35+
if (size ==-1){
36+
size = sensors.Length - offset;
37+
}
3438
float3 pos = new float3();
35-
for (int n = 0; n < nAgents; n++)
39+
for (int n = offset; n < size + offset; n++)
3640
{
3741
pos = sensors[n].Value;
3842

Assets/ECS_MLAgents_v0/Example/SpaceWars/Scripts/HumanDecision.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ public class HumanDecision<TS> : IAgentDecision<TS, Steering>
1111
where TS : struct, IComponentData
1212
{
1313

14-
public void BatchProcess(ref NativeArray<TS> sensors, ref NativeArray<Steering> actuators )
14+
public void BatchProcess(ref NativeArray<TS> sensors, ref NativeArray<Steering> actuators, int offset = 0, int size = -1)
1515
{
1616
var input = new float3();
1717
if (Input.GetKey(KeyCode.LeftArrow))
@@ -35,12 +35,14 @@ public void BatchProcess(ref NativeArray<TS> sensors, ref NativeArray<Steering>
3535
{
3636
input.z = 1;
3737
}
38-
38+
if (size ==-1){
39+
size = sensors.Length - offset;
40+
}
3941
for (int n = 0; n < actuators.Length; n++)
4042
{
4143
actuators[n] = new Steering{
42-
XAxis = input.x,
43-
YAxis = input.y,
44+
YAxis = input.x,
45+
XAxis = input.y,
4446
Shoot = input.z
4547
};
4648
}

Assets/ECS_MLAgents_v0/Example/SpaceWars/Scripts/Manager.cs

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ void Start()
7070
_shipSystemA.Decision =new NNDecision<ShipSensor, Steering>(model);
7171
// _shipSystemA.Decision = new ExternalDecision();
7272
_playerSystem = World.Active.GetExistingManager<PlayerShipSystem>();
73-
_playerSystem.Decision = new NNDecision<ShipSensor, Steering>(model);
73+
// _playerSystem.Decision = new NNDecision<ShipSensor, Steering>(model);
74+
_playerSystem.Decision = new HumanDecision<ShipSensor>();
7475
_playerSystem.SetNewComponentGroup(typeof(PlayerFlag));
7576
_shipSystemA.DecisionRequester = new FixedTimeRequester(0.1f);
7677

@@ -99,19 +100,7 @@ void FixedUpdate(){
99100
void Update()
100101
{
101102

102-
var decision = new NNDecision<ShipSensor, Steering>(model);
103-
104-
var ss = new NativeArray<ShipSensor>(2, Allocator.Temp);
105-
ss[0] = new ShipSensor();
106-
ss[1] = new ShipSensor();
107-
108-
var aa = new NativeArray<Steering>(2, Allocator.Temp);
109-
aa[0] = new Steering();
110-
aa[1] = new Steering();
111-
112-
decision.BatchProcess(ref ss, ref aa);
113-
114-
Debug.Log(aa[0].XAxis+" "+ aa[1].XAxis);
103+
115104

116105
// for (var i = 0; i < 10; i++){
117106
// foreach(var behavior in World.Active.BehaviourManagers)

Assets/ECS_MLAgents_v0/Example/SpaceWars/Scripts/WarsHeuristic.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace ECS_MLAgents_v0.Example.SpaceWars.Scripts
99
{
1010
public class WarsHeuristic : IAgentDecision<ShipSensor, Steering>
1111
{
12-
public void BatchProcess(ref NativeArray<ShipSensor> sensors, ref NativeArray<Steering> actuators )
12+
public void BatchProcess(ref NativeArray<ShipSensor> sensors, ref NativeArray<Steering> actuators, int offset = 0, int size = -1)
1313
{
1414
for (int i = 0; i < sensors.Length; i++)
1515
{

Assets/ECS_MLAgents_v0/Example/ZeroK/Manager.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ void Start()
8484
// sC.Decision = new NNDecision(modelC);
8585
sA.DecisionRequester = new FixedCountRequester(1);
8686

87-
Time.captureFramerate = 60;
87+
// Time.captureFramerate = 60;
8888

8989
Spawn(100);
9090
}

Assets/ECS_MLAgents_v0/Example/ZeroK/System.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ public void Execute(
7272
{
7373
sensor.Position = pos.Value / 50;
7474
sensor.Timer += 0.001f;
75+
// sensor.Timer += 0.01f;
7576
if (sensor.Timer > 1f)
7677
{
7778
sensor.Done = 1f;
@@ -89,6 +90,7 @@ private struct ResetPositionsJob : IJobProcessComponentData<Position, Sensor>
8990
public void Execute(ref Position position, ref Sensor sensor)
9091
{
9192
if (sensor.Timer > 1.0015f)
93+
// if (sensor.Timer > 1.02f)
9294
{
9395
sensor.Done = 0;
9496
sensor.Timer = 0;

0 commit comments

Comments
 (0)