Skip to content

Commit bdb69f1

Browse files
committed
Minor changes
1 parent 9aca6cc commit bdb69f1

File tree

12 files changed

+79
-39
lines changed

12 files changed

+79
-39
lines changed

Editor/MLAgentsWorldSpecsDrawer.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ namespace Unity.AI.MLAgents.Editor
1212
internal static class SpecsPropertyNames
1313
{
1414
public const string k_Name = "Name";
15+
public const string k_WorldProcessorType = "WorldProcessorType";
1516
public const string k_NumberAgents = "NumberAgents";
1617
public const string k_ActionType = "ActionType";
1718
public const string k_ObservationShapes = "ObservationShapes";
@@ -47,7 +48,7 @@ public override float GetPropertyHeight(SerializedProperty property, GUIContent
4748
var nbLines = 0;
4849
nbLines += GetHeightObservationShape(property);
4950
nbLines += GetHeightDiscreteAction(property);
50-
nbLines += 7; // TODO : COMPUTE
51+
nbLines += 8; // TODO : COMPUTE
5152
m_TotalHeight = k_LineHeight * nbLines + k_WarningLineHeight * GetHeightWarnings(property);
5253

5354
return m_TotalHeight + 6f;
@@ -76,6 +77,12 @@ public override void OnGUI(Rect position, SerializedProperty property, GUIConten
7677
new GUIContent("World Name", "The name of the World"));
7778
position.y += k_LineHeight;
7879

80+
// WorldProcessorType
81+
EditorGUI.PropertyField(position,
82+
property.FindPropertyRelative(SpecsPropertyNames.k_WorldProcessorType),
83+
new GUIContent("Processor Type", "The Policy for the World"));
84+
position.y += k_LineHeight;
85+
7986
// Number of Agents
8087
EditorGUI.PropertyField(position,
8188
property.FindPropertyRelative(SpecsPropertyNames.k_NumberAgents),

Runtime/Academy.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ internal void UpdateWorld(MLAgentsWorld world)
171171
throw new MLAgentsException("TODO : Null processor");
172172
}
173173

174-
var command = RemoteCommand.DEFAULT;
174+
var command = WorldCommand.DEFAULT;
175175

176176
if (com != null && processor.IsConnected)
177177
{
@@ -185,7 +185,7 @@ internal void UpdateWorld(MLAgentsWorld world)
185185
SideChannelUtils.ProcessSideChannelData(m_SideChannels, com.ReadAndClearSideChannelData());
186186
FirstMessageReceived = true;
187187
}
188-
if (command == RemoteCommand.DEFAULT)
188+
if (command == WorldCommand.DEFAULT)
189189
{
190190
com.WriteSideChannelData(SideChannelUtils.GetSideChannelMessage(m_SideChannels));
191191
command = processor.ProcessWorld();
@@ -203,15 +203,15 @@ internal void UpdateWorld(MLAgentsWorld world)
203203
}
204204
switch (command)
205205
{
206-
case RemoteCommand.RESET:
207-
Debug.Log("RESET");
206+
case WorldCommand.RESET:
208207
World.DefaultGameObjectInjectionWorld.EntityManager.CompleteAllJobs(); // This is problematic because it completes only for the active world
209208
ResetAllWorlds();
210209
OnEnvironmentReset?.Invoke();
211210
// TODO : RESET logic
212211
break;
213212

214-
case RemoteCommand.CLOSE:
213+
case WorldCommand.CLOSE:
214+
Debug.LogError("Communication was closed.");
215215
#if UNITY_EDITOR
216216
EditorApplication.isPlaying = false;
217217
#else
@@ -220,7 +220,7 @@ internal void UpdateWorld(MLAgentsWorld world)
220220
com = null;
221221
break;
222222

223-
case RemoteCommand.DEFAULT:
223+
case WorldCommand.DEFAULT:
224224
break;
225225

226226
default:

Runtime/Remote/RemoteCommand.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
namespace Unity.AI.MLAgents
22
{
3-
public enum RemoteCommand : sbyte
3+
internal enum RemoteCommand : sbyte
44
{
55
DEFAULT = 0,
66
RESET = 1,
77
CHANGE_FILE = 2,
88
CLOSE = 3
99
}
10+
11+
public enum WorldCommand
12+
{
13+
DEFAULT,
14+
RESET,
15+
CLOSE
16+
}
1017
}

Runtime/Remote/SharedMemoryCom.cs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -468,34 +468,33 @@ public void SetUnityReady()
468468
accessor.Write(k_CommandOffset, (sbyte)RemoteCommand.DEFAULT);
469469
accessor.Write(k_MutexOffset, false);
470470
}
471-
else
472-
{
473-
throw new MLAgentsException("Communication has stopped.");
474-
}
475471
}
476472

477-
public RemoteCommand Advance()
473+
public WorldCommand Advance()
478474
{
475+
if (!accessor.CanWrite)
476+
{
477+
return WorldCommand.CLOSE;
478+
}
479479
var pythonAlive = WaitOnPython();
480480
RemoteCommand commandReceived = (RemoteCommand)accessor.ReadSByte(k_CommandOffset);
481481
if (!pythonAlive)
482482
{
483-
// commandReceived = RemoteCommand.CLOSE;
484-
return RemoteCommand.CLOSE;
483+
return WorldCommand.CLOSE;
485484
}
486485

487486
switch (commandReceived)
488487
{
489488
case RemoteCommand.RESET:
490-
return commandReceived;
489+
return WorldCommand.RESET;
491490
case RemoteCommand.CLOSE:
492491
OnCloseCommand();
493-
return commandReceived;
492+
return WorldCommand.CLOSE;
494493
case RemoteCommand.CHANGE_FILE:
495494
OnChangeFileCommand();
496495
return Advance();
497496
default:
498-
return commandReceived;
497+
return WorldCommand.DEFAULT;
499498
}
500499
}
501500

Runtime/SideChannels/EngineConfigurationChannel.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ public override void OnMessageReceived(byte[] data)
3636

3737
var simGroup = World.DefaultGameObjectInjectionWorld.GetOrCreateSystem<SimulationSystemGroup>();
3838

39-
// #if UNITY_EDITOR
40-
// FixedRateUtils.EnableFixedRateSimple(simGroup, 1 / 60f);
41-
// #else
39+
#if UNITY_EDITOR
40+
TimeUtils.EnableFixedRateWithCatchUp(simGroup, 1 / 60f, 1f);
41+
#else
4242
TimeUtils.EnableFixedRateWithCatchUp(simGroup, 1 / 60f, timeScale);
43-
// #endif
43+
#endif
4444
}
4545
}
4646
}

Runtime/UI/MLAgentsWorldSpecs.cs

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,28 @@
66

77
namespace Unity.AI.MLAgents
88
{
9+
internal enum WorldProcessorType
10+
{
11+
Default,
12+
InferenceOnly,
13+
None
14+
}
15+
916
[Serializable]
1017
public struct MLAgentsWorldSpecs
1118
{
1219
[SerializeField] internal string Name;
1320

14-
[SerializeField] public int NumberAgents;
15-
[SerializeField] public ActionType ActionType;
16-
[SerializeField] public int3[] ObservationShapes;
17-
[SerializeField] public int ActionSize;
18-
[SerializeField] public int[] DiscreteActionBranches;
21+
[SerializeField] internal WorldProcessorType WorldProcessorType;
22+
23+
[SerializeField] internal int NumberAgents;
24+
[SerializeField] internal ActionType ActionType;
25+
[SerializeField] internal int3[] ObservationShapes;
26+
[SerializeField] internal int ActionSize;
27+
[SerializeField] internal int[] DiscreteActionBranches;
1928

20-
[SerializeField] public NNModel Model;
21-
[SerializeField] public InferenceDevice InferenceDevice;
29+
[SerializeField] internal NNModel Model;
30+
[SerializeField] internal InferenceDevice InferenceDevice;
2231

2332
private MLAgentsWorld m_World;
2433

@@ -35,7 +44,25 @@ public MLAgentsWorld GetWorld()
3544
ActionSize,
3645
DiscreteActionBranches
3746
);
38-
m_World.RegisterWorldWithBarracudaModel(Name, Model, InferenceDevice);
47+
switch (WorldProcessorType)
48+
{
49+
case WorldProcessorType.Default:
50+
m_World.RegisterWorldWithBarracudaModel(Name, Model, InferenceDevice);
51+
break;
52+
case WorldProcessorType.InferenceOnly:
53+
if (Model == null)
54+
{
55+
throw new MLAgentsException($"No model specified for {Name}");
56+
}
57+
m_World.RegisterWorldWithBarracudaModelForceNoCommunication(Name, Model, InferenceDevice);
58+
break;
59+
case WorldProcessorType.None:
60+
Academy.Instance.RegisterWorld(Name, m_World, new NullWorldProcessor(m_World), false);
61+
break;
62+
default:
63+
throw new MLAgentsException($"Unknown WorldProcessor Type");
64+
}
65+
3966
return m_World;
4067
}
4168
}

Runtime/WorldProcessor/BarracudaWorldProcessor.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public static void RegisterWorldWithBarracudaModel(
3131
}
3232
}
3333

34-
public static void RegisterWorldWithBarracudaModelForceNoCommunication<TH>(
34+
public static void RegisterWorldWithBarracudaModelForceNoCommunication(
3535
this MLAgentsWorld world,
3636
string policyId,
3737
NNModel model,
@@ -76,7 +76,7 @@ internal BarracudaWorldProcessor(MLAgentsWorld world, NNModel model, InferenceDe
7676
executionDevice, _barracudaModel, _verbose);
7777
}
7878

79-
public RemoteCommand ProcessWorld()
79+
public WorldCommand ProcessWorld()
8080
{
8181
// FOR VECTOR OBS ONLY
8282
// For Continuous control only
@@ -136,7 +136,7 @@ public RemoteCommand ProcessWorld()
136136
}
137137
actuatorT.Dispose();
138138

139-
return RemoteCommand.DEFAULT;
139+
return WorldCommand.DEFAULT;
140140
}
141141

142142
public void Dispose()

Runtime/WorldProcessor/HeuristicWorldProcessor.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ internal HeuristicWorldProcessor(MLAgentsWorld world, Func<T> heuristic)
4747
}
4848
}
4949

50-
public RemoteCommand ProcessWorld()
50+
public WorldCommand ProcessWorld()
5151
{
5252
T action = heuristic.Invoke();
5353
var totalCount = world.AgentCounter.Count;
@@ -69,7 +69,7 @@ public RemoteCommand ProcessWorld()
6969
s[i] = action;
7070
}
7171
}
72-
return RemoteCommand.DEFAULT;
72+
return WorldCommand.DEFAULT;
7373
}
7474

7575
public void Dispose()

Runtime/WorldProcessor/IWorldProcessor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ namespace Unity.AI.MLAgents
88
public interface IWorldProcessor : IDisposable
99
{
1010
bool IsConnected {get;}
11-
RemoteCommand ProcessWorld();
11+
WorldCommand ProcessWorld();
1212
}
1313
}

Runtime/WorldProcessor/NullWorldProcessor.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ internal NullWorldProcessor(MLAgentsWorld world)
1616
this.world = world;
1717
}
1818

19-
public RemoteCommand ProcessWorld()
19+
public WorldCommand ProcessWorld()
2020
{
21-
return RemoteCommand.DEFAULT;
21+
return WorldCommand.DEFAULT;
2222
}
2323

2424
public void Dispose()

0 commit comments

Comments
 (0)