Skip to content

Commit 8f2f73f

Browse files
vincentpierreChris Elion
and
Chris Elion
authored
V2 staging new model version (#5080)
* Make modelCheck have flavors of error messages * ONNX exporter v3 * Using a better CheckType and a switch statement * Removing unused message * More tests * Use an enum for valid versions and use GetVersion on model directly * Maybe the model export version a static constant in Python * Use static constructor for FailedCheck * Use static constructor for FailedCheck * Modifying the docstrings * renaming LegacyDiscreteActionOutputApplier * removing testing code * better warning message * Nest the CheckTypeEnum into the FailedCheck class * Update com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs * Adding a line explaining that legacy tensor checks are for versions 1.X only * Modifying the changelog * Exporting all the branches size instead of omly the sum (#5092) * addressing comments * Update com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs Co-authored-by: Chris Elion <chris.elion@unity3d.com> * readding tests * Adding a comment around the new DiscreteOutputSize method * Clearer warning : Model contains unexpected input > Model requires unknown input * Fixing a bug in the case where the discrete action tensor does not exist Co-authored-by: Chris Elion <chris.elion@unity3d.com>
1 parent c8d1b5f commit 8f2f73f

12 files changed

+697
-215
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ details. (#5060)
1919

2020
### Minor Changes
2121
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
22+
- The `.onnx` models input names have changed. All input placeholders will now use the prefix `obs_` removing the distinction between visual and vector observations. Models created with this version will not be usable with previous versions of the package (#5080)
23+
- The `.onnx` models discrete action output now contains the discrete actions values and not the logits. Models created with this version will not be usable with previous versions of the package (#5080)
2224
#### ml-agents / ml-agents-envs / gym-unity (Python)
2325

2426
### Bug Fixes

com.unity.ml-agents/Editor/BehaviorParametersEditor.cs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using Unity.MLAgents.Policies;
66
using Unity.MLAgents.Sensors;
77
using Unity.MLAgents.Sensors.Reflection;
8+
using CheckTypeEnum = Unity.MLAgents.Inference.BarracudaModelParamLoader.FailedCheck.CheckTypeEnum;
89

910
namespace Unity.MLAgents.Editor
1011
{
@@ -147,7 +148,20 @@ void DisplayFailedModelChecks()
147148
{
148149
if (check != null)
149150
{
150-
EditorGUILayout.HelpBox(check, MessageType.Warning);
151+
switch (check.CheckType)
152+
{
153+
case CheckTypeEnum.Info:
154+
EditorGUILayout.HelpBox(check.Message, MessageType.Info);
155+
break;
156+
case CheckTypeEnum.Warning:
157+
EditorGUILayout.HelpBox(check.Message, MessageType.Warning);
158+
break;
159+
case CheckTypeEnum.Error:
160+
EditorGUILayout.HelpBox(check.Message, MessageType.Error);
161+
break;
162+
default:
163+
break;
164+
}
151165
}
152166
}
153167
}

com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,51 @@ public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int,
4646
}
4747
}
4848

49+
/// <summary>
50+
/// The Applier for the Discrete Action output tensor.
51+
/// </summary>
52+
internal class DiscreteActionOutputApplier : TensorApplier.IApplier
53+
{
54+
readonly ActionSpec m_ActionSpec;
55+
56+
57+
public DiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
58+
{
59+
m_ActionSpec = actionSpec;
60+
}
61+
62+
public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
63+
{
64+
var agentIndex = 0;
65+
var actionSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
66+
for (var i = 0; i < actionIds.Count; i++)
67+
{
68+
var agentId = actionIds[i];
69+
if (lastActions.ContainsKey(agentId))
70+
{
71+
var actionBuffer = lastActions[agentId];
72+
if (actionBuffer.IsEmpty())
73+
{
74+
actionBuffer = new ActionBuffers(m_ActionSpec);
75+
lastActions[agentId] = actionBuffer;
76+
}
77+
var discreteBuffer = actionBuffer.DiscreteActions;
78+
for (var j = 0; j < actionSize; j++)
79+
{
80+
discreteBuffer[j] = (int)tensorProxy.data[agentIndex, j];
81+
}
82+
}
83+
agentIndex++;
84+
}
85+
}
86+
}
87+
88+
4989
/// <summary>
5090
/// The Applier for the Discrete Action output tensor. Uses multinomial to sample discrete
5191
/// actions from the logits contained in the tensor.
5292
/// </summary>
53-
internal class DiscreteActionOutputApplier : TensorApplier.IApplier
93+
internal class LegacyDiscreteActionOutputApplier : TensorApplier.IApplier
5494
{
5595
readonly int[] m_ActionSize;
5696
readonly Multinomial m_Multinomial;
@@ -59,7 +99,7 @@ internal class DiscreteActionOutputApplier : TensorApplier.IApplier
5999
readonly float[] m_CdfBuffer;
60100

61101

62-
public DiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
102+
public LegacyDiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
63103
{
64104
m_ActionSize = actionSpec.BranchSizes;
65105
m_Multinomial = new Multinomial(seed);

com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System.Collections.Generic;
22
using System.Linq;
33
using Unity.Barracuda;
4+
using FailedCheck = Unity.MLAgents.Inference.BarracudaModelParamLoader.FailedCheck;
45

56
namespace Unity.MLAgents.Inference
67
{
@@ -38,6 +39,18 @@ public static string[] GetInputNames(this Model model)
3839
return names.ToArray();
3940
}
4041

42+
/// <summary>
43+
/// Get the version of the model.
44+
/// </summary>
45+
/// <param name="model">
46+
/// The Barracuda engine model for loading static parameters.
47+
/// </param>
48+
/// <returns>The api version of the model</returns>
49+
public static int GetVersion(this Model model)
50+
{
51+
return (int)model.GetTensorByName(TensorNames.VersionNumber)[0];
52+
}
53+
4154
/// <summary>
4255
/// Generates the Tensor inputs that are expected to be present in the Model.
4356
/// </summary>
@@ -226,12 +239,20 @@ public static bool HasDiscreteOutputs(this Model model)
226239
else
227240
{
228241
return model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
229-
(int)model.GetTensorByName(TensorNames.DiscreteActionOutputShape)[0] > 0;
242+
(int)model.DiscreteOutputSize() > 0;
230243
}
231244
}
232245

233246
/// <summary>
234247
/// Discrete action output size of the model. This is equal to the sum of the branch sizes.
248+
/// This method gets the tensor representing the list of branch size and returns the
249+
/// sum of all the elements in the Tensor.
250+
/// - In version 1.X this tensor contains a single number, the sum of all branch
251+
/// size values.
252+
/// - In version 2.X this tensor contains a 1D Tensor with each element corresponding
253+
/// to a branch size.
254+
/// Since this method does the sum of all elements in the tensor, the output
255+
/// will be the same on both 1.X and 2.X.
235256
/// </summary>
236257
/// <param name="model">
237258
/// The Barracuda engine model for loading static parameters.
@@ -249,7 +270,19 @@ public static int DiscreteOutputSize(this Model model)
249270
else
250271
{
251272
var discreteOutputShape = model.GetTensorByName(TensorNames.DiscreteActionOutputShape);
252-
return discreteOutputShape == null ? 0 : (int)discreteOutputShape[0];
273+
if (discreteOutputShape == null)
274+
{
275+
return 0;
276+
}
277+
else
278+
{
279+
int result = 0;
280+
for (int i = 0; i < discreteOutputShape.length; i++)
281+
{
282+
result += (int)discreteOutputShape[i];
283+
}
284+
return result;
285+
}
253286
}
254287
}
255288

@@ -298,21 +331,25 @@ public static bool SupportsContinuousAndDiscrete(this Model model)
298331
/// <param name="failedModelChecks">Output list of failure messages</param>
299332
///
300333
/// <returns>True if the model contains all the expected tensors.</returns>
301-
public static bool CheckExpectedTensors(this Model model, List<string> failedModelChecks)
334+
public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks)
302335
{
303336
// Check the presence of model version
304337
var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
305338
if (modelApiVersionTensor == null)
306339
{
307-
failedModelChecks.Add($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.");
340+
failedModelChecks.Add(
341+
FailedCheck.Warning($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.")
342+
);
308343
return false;
309344
}
310345

311346
// Check the presence of memory size
312347
var memorySizeTensor = model.GetTensorByName(TensorNames.MemorySize);
313348
if (memorySizeTensor == null)
314349
{
315-
failedModelChecks.Add($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.");
350+
failedModelChecks.Add(
351+
FailedCheck.Warning($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.")
352+
);
316353
return false;
317354
}
318355

@@ -321,7 +358,9 @@ public static bool CheckExpectedTensors(this Model model, List<string> failedMod
321358
!model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
322359
!model.outputs.Contains(TensorNames.DiscreteActionOutput))
323360
{
324-
failedModelChecks.Add("The model does not contain any Action Output Node.");
361+
failedModelChecks.Add(
362+
FailedCheck.Warning("The model does not contain any Action Output Node.")
363+
);
325364
return false;
326365
}
327366

@@ -330,13 +369,18 @@ public static bool CheckExpectedTensors(this Model model, List<string> failedMod
330369
{
331370
if (model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated) == null)
332371
{
333-
failedModelChecks.Add("The model does not contain any Action Output Shape Node.");
372+
failedModelChecks.Add(
373+
FailedCheck.Warning("The model does not contain any Action Output Shape Node.")
374+
);
334375
return false;
335376
}
336377
if (model.GetTensorByName(TensorNames.IsContinuousControlDeprecated) == null)
337378
{
338-
failedModelChecks.Add($"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was not found in the model file. " +
339-
"This is only required for model that uses a deprecated model format.");
379+
failedModelChecks.Add(
380+
FailedCheck.Warning($"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was " +
381+
"not found in the model file. " +
382+
"This is only required for model that uses a deprecated model format.")
383+
);
340384
return false;
341385
}
342386
}
@@ -345,13 +389,17 @@ public static bool CheckExpectedTensors(this Model model, List<string> failedMod
345389
if (model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
346390
model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
347391
{
348-
failedModelChecks.Add("The model uses continuous action but does not contain Continuous Action Output Shape Node.");
392+
failedModelChecks.Add(
393+
FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
394+
);
349395
return false;
350396
}
351397
if (model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
352398
model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
353399
{
354-
failedModelChecks.Add("The model uses discrete action but does not contain Discrete Action Output Shape Node.");
400+
failedModelChecks.Add(
401+
FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
402+
);
355403
return false;
356404
}
357405
}

0 commit comments

Comments
 (0)