Skip to content

Exporting all the branches size instead of omly the sum #5092

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ public static bool HasDiscreteOutputs(this Model model)
else
{
return model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
(int)model.GetTensorByName(TensorNames.DiscreteActionOutputShape)[0] > 0;
(int)model.DiscreteOutputSize() > 0;
}
}

Expand All @@ -262,7 +262,19 @@ public static int DiscreteOutputSize(this Model model)
else
{
var discreteOutputShape = model.GetTensorByName(TensorNames.DiscreteActionOutputShape);
return discreteOutputShape == null ? 0 : (int)discreteOutputShape[0];
if (discreteOutputShape == null)
{
return 0;
}
else
{
int result = 0;
for (int i = 0; i < discreteOutputShape.length; i++)
{
result += (int)discreteOutputShape[i];
}
return result;
}
}
}

Expand Down
61 changes: 58 additions & 3 deletions com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -716,8 +716,19 @@ static IEnumerable<FailedCheck> CheckOutputTensorShape(
{
failedModelChecks.Add(continuousError);
}
var modelSumDiscreteBranchSizes = model.DiscreteOutputSize();
var discreteError = CheckDiscreteActionOutputShape(brainParameters, actuatorComponents, modelSumDiscreteBranchSizes);
FailedCheck discreteError = null;
var modelApiVersion = model.GetVersion();
if (modelApiVersion == (int)ModelApiVersion.MLAgents1_0)
{
var modelSumDiscreteBranchSizes = model.DiscreteOutputSize();
discreteError = CheckDiscreteActionOutputShapeLegacy(brainParameters, actuatorComponents, modelSumDiscreteBranchSizes);
}
if (modelApiVersion == (int)ModelApiVersion.MLAgents2_0)
{
var modeDiscreteBranches = model.GetTensorByName(TensorNames.DiscreteActionOutputShape);
discreteError = CheckDiscreteActionOutputShape(brainParameters, actuatorComponents, modeDiscreteBranches);
}

if (discreteError != null)
{
failedModelChecks.Add(discreteError);
Expand All @@ -733,14 +744,58 @@ static IEnumerable<FailedCheck> CheckOutputTensorShape(
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="actuatorComponents">Array of attached actuator components.</param>
/// <param name="modelDiscreteBranches"> The Tensor of branch sizes.
/// </param>
/// <returns>
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static FailedCheck CheckDiscreteActionOutputShape(
BrainParameters brainParameters, ActuatorComponent[] actuatorComponents, Tensor modelDiscreteBranches)
{

var discreteActionBranches = brainParameters.ActionSpec.BranchSizes.ToList();
foreach (var actuatorComponent in actuatorComponents)
{
var actionSpec = actuatorComponent.ActionSpec;
discreteActionBranches.AddRange(actionSpec.BranchSizes);
}

if (modelDiscreteBranches.length != discreteActionBranches.Count)
{
return FailedCheck.Warning("Discrete Action Size of the model does not match. The BrainParameters expect " +
$"{discreteActionBranches.Count} branches but the model contains {modelDiscreteBranches.length}."
);
}

for (int i = 0; i < modelDiscreteBranches.length; i++)
{
if (modelDiscreteBranches[i] != discreteActionBranches[i])
{
return FailedCheck.Warning($"The number of Discrete Actions of branch {i} does not match. " +
$"Was expecting {discreteActionBranches[i]} but the model contains {modelDiscreteBranches[i]} "
);
}
}
return null;
}

/// <summary>
/// Checks that the shape of the discrete action output is the same in the
/// model and in the Brain Parameters. Tests the models created with the API of version 1.X
/// </summary>
/// <param name="brainParameters">
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="actuatorComponents">Array of attached actuator components.</param>
/// <param name="modelSumDiscreteBranchSizes">
/// The size of the discrete action output that is expected by the model.
/// </param>
/// <returns>
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static FailedCheck CheckDiscreteActionOutputShape(
static FailedCheck CheckDiscreteActionOutputShapeLegacy(
BrainParameters brainParameters, ActuatorComponent[] actuatorComponents, int modelSumDiscreteBranchSizes)
{
// TODO: check each branch size instead of sum of branch sizes
Expand Down
3 changes: 1 addition & 2 deletions ml-agents/mlagents/trainers/torch/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,8 @@ def __init__(
self.continuous_act_size_vector = torch.nn.Parameter(
torch.Tensor([int(self.action_spec.continuous_size)]), requires_grad=False
)
# TODO: export list of branch sizes instead of sum
self.discrete_act_size_vector = torch.nn.Parameter(
torch.Tensor([sum(self.action_spec.discrete_branches)]), requires_grad=False
torch.Tensor([self.action_spec.discrete_branches]), requires_grad=False
)
self.act_size_vector_deprecated = torch.nn.Parameter(
torch.Tensor(
Expand Down