Add multiAgentGroup capabilities flag #5096

Merged · 6 commits · Mar 12, 2021 (showing changes from all commits)
2 changes: 1 addition & 1 deletion com.unity.ml-agents/Runtime/Academy.cs
@@ -97,7 +97,7 @@ public class Academy : IDisposable
/// </item>
/// <item>
/// <term>1.5.0</term>
- /// <description>Support variable length observation training.</description>
+ /// <description>Support variable length observation training and multi-agent groups.</description>
/// </item>
/// </list>
/// </remarks>
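Note that the PR amends the existing 1.5.0 entry rather than adding a new API version: availability of the feature is negotiated through the MultiAgentGroups capability flag introduced below, not through a version bump, so older trainers can still connect.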
23 changes: 23 additions & 0 deletions com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
@@ -20,6 +20,11 @@ namespace Unity.MLAgents
internal static class GrpcExtensions
{
#region AgentInfo
/// <summary>
/// Static flag to make sure that we only fire the warning once.
/// </summary>
private static bool s_HaveWarnedTrainerCapabilitiesAgentGroup = false;

/// <summary>
/// Converts an AgentInfo to a protobuf-generated AgentInfoActionPairProto
/// </summary>
@@ -55,6 +60,22 @@ public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai)
/// <returns>The protobuf version of the AgentInfo.</returns>
public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai)
{
if (ai.groupId > 0)
{
var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.MultiAgentGroups;
if (!trainerCanHandle)
{
if (!s_HaveWarnedTrainerCapabilitiesAgentGroup)
{
Debug.LogWarning(
$"Attached trainer doesn't support Multi Agent Groups; group rewards will be ignored." +
"Please find the versions that work best together from our release page: " +
"https://github.com/Unity-Technologies/ml-agents/releases"
);
s_HaveWarnedTrainerCapabilitiesAgentGroup = true;
}
}
}
var agentInfoProto = new AgentInfoProto
{
Reward = ai.reward,
@@ -457,6 +478,7 @@ public static UnityRLCapabilities ToRLCapabilities(this UnityRLCapabilitiesProto
HybridActions = proto.HybridActions,
TrainingAnalytics = proto.TrainingAnalytics,
VariableLengthObservation = proto.VariableLengthObservation,
MultiAgentGroups = proto.MultiAgentGroups,
};
}

@@ -470,6 +492,7 @@ public static UnityRLCapabilitiesProto ToProto(this UnityRLCapabilities rlCaps)
HybridActions = rlCaps.HybridActions,
TrainingAnalytics = rlCaps.TrainingAnalytics,
VariableLengthObservation = rlCaps.VariableLengthObservation,
MultiAgentGroups = rlCaps.MultiAgentGroups,
};
}

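For context, ai.groupId becomes non-zero once an agent is registered with a multi-agent group, which is what routes execution into the warning branch above. A minimal usage sketch, assuming the SimpleMultiAgentGroup API that shipped alongside this capability (the sketch is illustrative and not part of this diff):

using UnityEngine;
using Unity.MLAgents;

public class TeamController : MonoBehaviour
{
    SimpleMultiAgentGroup m_Group;

    void Start()
    {
        m_Group = new SimpleMultiAgentGroup();
        // Registration is what gives each agent a non-zero groupId.
        foreach (var agent in GetComponentsInChildren<Agent>())
        {
            m_Group.RegisterAgent(agent);
        }
    }

    void OnGoalScored()
    {
        // If the attached trainer reports MultiAgentGroups = false,
        // this reward is ignored and the one-time warning above is logged.
        m_Group.AddGroupReward(1.0f);
        m_Group.EndGroupEpisode();
    }
}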
com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
@@ -10,6 +10,7 @@ internal class UnityRLCapabilities
public bool HybridActions;
public bool TrainingAnalytics;
public bool VariableLengthObservation;
public bool MultiAgentGroups;

/// <summary>
/// A class holding the capabilities flags for Reinforcement Learning across C# and the Trainer codebase. This
@@ -21,14 +22,16 @@ public UnityRLCapabilities(
bool compressedChannelMapping = true,
bool hybridActions = true,
bool trainingAnalytics = true,
- bool variableLengthObservation = true)
+ bool variableLengthObservation = true,
+ bool multiAgentGroups = true)
{
BaseRLCapabilities = baseRlCapabilities;
ConcatenatedPngObservations = concatenatedPngObservations;
CompressedChannelMapping = compressedChannelMapping;
HybridActions = hybridActions;
TrainingAnalytics = trainingAnalytics;
VariableLengthObservation = variableLengthObservation;
MultiAgentGroups = multiAgentGroups;
}

/// <summary>
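Every flag defaults to true, so a default-constructed instance describes what this C# package supports; the handshake then replaces those values with the ones the Python trainer actually reports (see ToRLCapabilities above). A hypothetical sketch of how the new flag gates behavior (UnityRLCapabilities is internal, so code like this only compiles inside the package assembly):

// Simulate a trainer that predates the multiAgentGroups capability:
// proto3 bools default to false, so an older peer never sets the field.
var oldTrainer = new UnityRLCapabilities(multiAgentGroups: false);

// Mirrors the check in GrpcExtensions.ToAgentInfoProto.
bool trainerCanHandle = oldTrainer.MultiAgentGroups;
Debug.Assert(!trainerCanHandle); // group rewards would be ignored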
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs (generated)
@@ -25,17 +25,18 @@ static CapabilitiesReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp",
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi0gEKGFVuaXR5UkxD",
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi7AEKGFVuaXR5UkxD",
"YXBhYmlsaXRpZXNQcm90bxIaChJiYXNlUkxDYXBhYmlsaXRpZXMYASABKAgS",
"IwobY29uY2F0ZW5hdGVkUG5nT2JzZXJ2YXRpb25zGAIgASgIEiAKGGNvbXBy",
"ZXNzZWRDaGFubmVsTWFwcGluZxgDIAEoCBIVCg1oeWJyaWRBY3Rpb25zGAQg",
"ASgIEhkKEXRyYWluaW5nQW5hbHl0aWNzGAUgASgIEiEKGXZhcmlhYmxlTGVu",
- "Z3RoT2JzZXJ2YXRpb24YBiABKAhCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11",
- "bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
+ "Z3RoT2JzZXJ2YXRpb24YBiABKAgSGAoQbXVsdGlBZ2VudEdyb3VwcxgHIAEo",
+ "CEIlqgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJv",
+ "dG8z"));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
- new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions", "TrainingAnalytics", "VariableLengthObservation" }, null, null, null)
+ new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions", "TrainingAnalytics", "VariableLengthObservation", "MultiAgentGroups" }, null, null, null)
}));
}
#endregion
@@ -78,6 +79,7 @@ public UnityRLCapabilitiesProto(UnityRLCapabilitiesProto other) : this() {
hybridActions_ = other.hybridActions_;
trainingAnalytics_ = other.trainingAnalytics_;
variableLengthObservation_ = other.variableLengthObservation_;
multiAgentGroups_ = other.multiAgentGroups_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

@@ -170,6 +172,20 @@ public bool VariableLengthObservation {
}
}

/// <summary>Field number for the "multiAgentGroups" field.</summary>
public const int MultiAgentGroupsFieldNumber = 7;
private bool multiAgentGroups_;
/// <summary>
/// Support for multi-agent groups and group rewards
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool MultiAgentGroups {
get { return multiAgentGroups_; }
set {
multiAgentGroups_ = value;
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLCapabilitiesProto);
@@ -189,6 +205,7 @@ public bool Equals(UnityRLCapabilitiesProto other) {
if (HybridActions != other.HybridActions) return false;
if (TrainingAnalytics != other.TrainingAnalytics) return false;
if (VariableLengthObservation != other.VariableLengthObservation) return false;
if (MultiAgentGroups != other.MultiAgentGroups) return false;
return Equals(_unknownFields, other._unknownFields);
}

@@ -201,6 +218,7 @@ public override int GetHashCode() {
if (HybridActions != false) hash ^= HybridActions.GetHashCode();
if (TrainingAnalytics != false) hash ^= TrainingAnalytics.GetHashCode();
if (VariableLengthObservation != false) hash ^= VariableLengthObservation.GetHashCode();
if (MultiAgentGroups != false) hash ^= MultiAgentGroups.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
@@ -238,6 +256,10 @@ public void WriteTo(pb::CodedOutputStream output) {
output.WriteRawTag(48);
output.WriteBool(VariableLengthObservation);
}
if (MultiAgentGroups != false) {
output.WriteRawTag(56);
output.WriteBool(MultiAgentGroups);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
@@ -264,6 +286,9 @@ public int CalculateSize() {
if (VariableLengthObservation != false) {
size += 1 + 1;
}
if (MultiAgentGroups != false) {
size += 1 + 1;
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
@@ -293,6 +318,9 @@ public void MergeFrom(UnityRLCapabilitiesProto other) {
if (other.VariableLengthObservation != false) {
VariableLengthObservation = other.VariableLengthObservation;
}
if (other.MultiAgentGroups != false) {
MultiAgentGroups = other.MultiAgentGroups;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

@@ -328,6 +356,10 @@ public void MergeFrom(pb::CodedInputStream input) {
VariableLengthObservation = input.ReadBool();
break;
}
case 56: {
MultiAgentGroups = input.ReadBool();
break;
}
}
}
}
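The magic numbers 48 and 56 in WriteTo and MergeFrom are wire-format tags: proto3 encodes each field tag as (field_number << 3) | wire_type, and bools use the varint wire type 0. A standalone sanity check, not part of the generated file:

using System;

class TagCheck
{
    const int WireTypeVarint = 0; // bools are varint-encoded

    static int FieldTag(int fieldNumber, int wireType) =>
        (fieldNumber << 3) | wireType;

    static void Main()
    {
        Console.WriteLine(FieldTag(6, WireTypeVarint)); // 48 -> variableLengthObservation
        Console.WriteLine(FieldTag(7, WireTypeVarint)); // 56 -> multiAgentGroups
        // CalculateSize adds 1 + 1 per set bool: one byte for the tag,
        // one byte for the varint-encoded value.
    }
}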


ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi (generated stub)
@@ -31,6 +31,7 @@ class UnityRLCapabilitiesProto(google___protobuf___message___Message):
hybridActions = ... # type: builtin___bool
trainingAnalytics = ... # type: builtin___bool
variableLengthObservation = ... # type: builtin___bool
multiAgentGroups = ... # type: builtin___bool

def __init__(self,
*,
@@ -40,12 +41,13 @@ class UnityRLCapabilitiesProto(google___protobuf___message___Message):
hybridActions : typing___Optional[builtin___bool] = None,
trainingAnalytics : typing___Optional[builtin___bool] = None,
variableLengthObservation : typing___Optional[builtin___bool] = None,
multiAgentGroups : typing___Optional[builtin___bool] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> UnityRLCapabilitiesProto: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
- def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations",u"hybridActions",u"trainingAnalytics",u"variableLengthObservation"]) -> None: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations",u"hybridActions",u"multiAgentGroups",u"trainingAnalytics",u"variableLengthObservation"]) -> None: ...
else:
- def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations",u"hybridActions",b"hybridActions",u"trainingAnalytics",b"trainingAnalytics",u"variableLengthObservation",b"variableLengthObservation"]) -> None: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations",u"hybridActions",b"hybridActions",u"multiAgentGroups",b"multiAgentGroups",u"trainingAnalytics",b"trainingAnalytics",u"variableLengthObservation",b"variableLengthObservation"]) -> None: ...
3 changes: 2 additions & 1 deletion ml-agents-envs/mlagents_envs/environment.py
@@ -63,7 +63,7 @@ class UnityEnvironment(BaseEnv):
# * 1.2.0 - support compression mapping for stacked compressed observations.
# * 1.3.0 - support action spaces with both continuous and discrete actions.
# * 1.4.0 - support training analytics sent from python trainer to the editor.
- # * 1.5.0 - support variable length observation training.
+ # * 1.5.0 - support variable length observation training and multi-agent groups.
API_VERSION = "1.5.0"

# Default port that the editor listens on. If an environment executable
@@ -124,6 +124,7 @@ def _get_capabilities_proto() -> UnityRLCapabilitiesProto:
capabilities.hybridActions = True
capabilities.trainingAnalytics = True
capabilities.variableLengthObservation = True
capabilities.multiAgentGroups = True
return capabilities

@staticmethod
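Both sides of the handshake advertise the flag unconditionally: the Python trainer sets multiAgentGroups = True here, and the C# UnityRLCapabilities constructor defaults it to true. Because proto3 bools default to false on the wire, an older peer that never writes field 7 is automatically read as lacking the capability, which is exactly what triggers the one-time warning in GrpcExtensions.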
protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto
@@ -25,4 +25,7 @@ message UnityRLCapabilitiesProto {

// Support for variable length observations of rank 2
bool variableLengthObservation = 6;

// Support for multi-agent groups and group rewards
bool multiAgentGroups = 7;
}
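Field number 7 is simply the next unused slot in UnityRLCapabilitiesProto. Adding a field this way is backward compatible in proto3: a reader built before this change stores the unrecognized field in its unknown-field set (the _unknownFields handling visible in the generated C# above) rather than failing to parse.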