Skip to content

Commit 192b5e6

Browse files
author
Ervin T
authored
Add multiAgentGroup capabilities flag (#5096)
* Add multiAgentGroup capabilities flag
* Add proto
* Fix compiler error
* Add warning for multiagent group
* Add comment
* Fix spelling mistake
1 parent f67ed30 commit 192b5e6

File tree

8 files changed

+82
-11
lines changed

8 files changed

+82
-11
lines changed

com.unity.ml-agents/Runtime/Academy.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ public class Academy : IDisposable
9797
/// </item>
9898
/// <item>
9999
/// <term>1.5.0</term>
100-
/// <description>Support variable length observation training.</description>
100+
/// <description>Support variable length observation training and multi-agent groups.</description>
101101
/// </item>
102102
/// </list>
103103
/// </remarks>

com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ namespace Unity.MLAgents
2020
internal static class GrpcExtensions
2121
{
2222
#region AgentInfo
23+
/// <summary>
24+
/// Static flag to make sure that we only fire the warning once.
25+
/// </summary>
26+
private static bool s_HaveWarnedTrainerCapabilitiesAgentGroup = false;
27+
2328
/// <summary>
2429
/// Converts a AgentInfo to a protobuf generated AgentInfoActionPairProto
2530
/// </summary>
@@ -55,6 +60,22 @@ public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai)
5560
/// <returns>The protobuf version of the AgentInfo.</returns>
5661
public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai)
5762
{
63+
if(ai.groupId > 0)
64+
{
65+
var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.MultiAgentGroups;
66+
if (!trainerCanHandle)
67+
{
68+
if (!s_HaveWarnedTrainerCapabilitiesAgentGroup)
69+
{
70+
Debug.LogWarning(
71+
$"Attached trainer doesn't support Multi Agent Groups; group rewards will be ignored." +
72+
"Please find the versions that work best together from our release page: " +
73+
"https://github.com/Unity-Technologies/ml-agents/releases"
74+
);
75+
s_HaveWarnedTrainerCapabilitiesAgentGroup = true;
76+
}
77+
}
78+
}
5879
var agentInfoProto = new AgentInfoProto
5980
{
6081
Reward = ai.reward,
@@ -457,6 +478,7 @@ public static UnityRLCapabilities ToRLCapabilities(this UnityRLCapabilitiesProto
457478
HybridActions = proto.HybridActions,
458479
TrainingAnalytics = proto.TrainingAnalytics,
459480
VariableLengthObservation = proto.VariableLengthObservation,
481+
MultiAgentGroups = proto.MultiAgentGroups,
460482
};
461483
}
462484

@@ -470,6 +492,7 @@ public static UnityRLCapabilitiesProto ToProto(this UnityRLCapabilities rlCaps)
470492
HybridActions = rlCaps.HybridActions,
471493
TrainingAnalytics = rlCaps.TrainingAnalytics,
472494
VariableLengthObservation = rlCaps.VariableLengthObservation,
495+
MultiAgentGroups = rlCaps.MultiAgentGroups,
473496
};
474497
}
475498

com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ internal class UnityRLCapabilities
1010
public bool HybridActions;
1111
public bool TrainingAnalytics;
1212
public bool VariableLengthObservation;
13+
public bool MultiAgentGroups;
1314

1415
/// <summary>
1516
/// A class holding the capabilities flags for Reinforcement Learning across C# and the Trainer codebase. This
@@ -21,14 +22,16 @@ public UnityRLCapabilities(
2122
bool compressedChannelMapping = true,
2223
bool hybridActions = true,
2324
bool trainingAnalytics = true,
24-
bool variableLengthObservation = true)
25+
bool variableLengthObservation = true,
26+
bool multiAgentGroups = true)
2527
{
2628
BaseRLCapabilities = baseRlCapabilities;
2729
ConcatenatedPngObservations = concatenatedPngObservations;
2830
CompressedChannelMapping = compressedChannelMapping;
2931
HybridActions = hybridActions;
3032
TrainingAnalytics = trainingAnalytics;
3133
VariableLengthObservation = variableLengthObservation;
34+
MultiAgentGroups = multiAgentGroups;
3235
}
3336

3437
/// <summary>

com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,18 @@ static CapabilitiesReflection() {
2525
byte[] descriptorData = global::System.Convert.FromBase64String(
2626
string.Concat(
2727
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp",
28-
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi0gEKGFVuaXR5UkxD",
28+
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi7AEKGFVuaXR5UkxD",
2929
"YXBhYmlsaXRpZXNQcm90bxIaChJiYXNlUkxDYXBhYmlsaXRpZXMYASABKAgS",
3030
"IwobY29uY2F0ZW5hdGVkUG5nT2JzZXJ2YXRpb25zGAIgASgIEiAKGGNvbXBy",
3131
"ZXNzZWRDaGFubmVsTWFwcGluZxgDIAEoCBIVCg1oeWJyaWRBY3Rpb25zGAQg",
3232
"ASgIEhkKEXRyYWluaW5nQW5hbHl0aWNzGAUgASgIEiEKGXZhcmlhYmxlTGVu",
33-
"Z3RoT2JzZXJ2YXRpb24YBiABKAhCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11",
34-
"bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
33+
"Z3RoT2JzZXJ2YXRpb24YBiABKAgSGAoQbXVsdGlBZ2VudEdyb3VwcxgHIAEo",
34+
"CEIlqgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJv",
35+
"dG8z"));
3536
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
3637
new pbr::FileDescriptor[] { },
3738
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
38-
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions", "TrainingAnalytics", "VariableLengthObservation" }, null, null, null)
39+
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions", "TrainingAnalytics", "VariableLengthObservation", "MultiAgentGroups" }, null, null, null)
3940
}));
4041
}
4142
#endregion
@@ -78,6 +79,7 @@ public UnityRLCapabilitiesProto(UnityRLCapabilitiesProto other) : this() {
7879
hybridActions_ = other.hybridActions_;
7980
trainingAnalytics_ = other.trainingAnalytics_;
8081
variableLengthObservation_ = other.variableLengthObservation_;
82+
multiAgentGroups_ = other.multiAgentGroups_;
8183
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
8284
}
8385

@@ -170,6 +172,20 @@ public bool VariableLengthObservation {
170172
}
171173
}
172174

175+
/// <summary>Field number for the "multiAgentGroups" field.</summary>
176+
public const int MultiAgentGroupsFieldNumber = 7;
177+
private bool multiAgentGroups_;
178+
/// <summary>
179+
/// Support for multi agent groups and group rewards
180+
/// </summary>
181+
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
182+
public bool MultiAgentGroups {
183+
get { return multiAgentGroups_; }
184+
set {
185+
multiAgentGroups_ = value;
186+
}
187+
}
188+
173189
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
174190
public override bool Equals(object other) {
175191
return Equals(other as UnityRLCapabilitiesProto);
@@ -189,6 +205,7 @@ public bool Equals(UnityRLCapabilitiesProto other) {
189205
if (HybridActions != other.HybridActions) return false;
190206
if (TrainingAnalytics != other.TrainingAnalytics) return false;
191207
if (VariableLengthObservation != other.VariableLengthObservation) return false;
208+
if (MultiAgentGroups != other.MultiAgentGroups) return false;
192209
return Equals(_unknownFields, other._unknownFields);
193210
}
194211

@@ -201,6 +218,7 @@ public override int GetHashCode() {
201218
if (HybridActions != false) hash ^= HybridActions.GetHashCode();
202219
if (TrainingAnalytics != false) hash ^= TrainingAnalytics.GetHashCode();
203220
if (VariableLengthObservation != false) hash ^= VariableLengthObservation.GetHashCode();
221+
if (MultiAgentGroups != false) hash ^= MultiAgentGroups.GetHashCode();
204222
if (_unknownFields != null) {
205223
hash ^= _unknownFields.GetHashCode();
206224
}
@@ -238,6 +256,10 @@ public void WriteTo(pb::CodedOutputStream output) {
238256
output.WriteRawTag(48);
239257
output.WriteBool(VariableLengthObservation);
240258
}
259+
if (MultiAgentGroups != false) {
260+
output.WriteRawTag(56);
261+
output.WriteBool(MultiAgentGroups);
262+
}
241263
if (_unknownFields != null) {
242264
_unknownFields.WriteTo(output);
243265
}
@@ -264,6 +286,9 @@ public int CalculateSize() {
264286
if (VariableLengthObservation != false) {
265287
size += 1 + 1;
266288
}
289+
if (MultiAgentGroups != false) {
290+
size += 1 + 1;
291+
}
267292
if (_unknownFields != null) {
268293
size += _unknownFields.CalculateSize();
269294
}
@@ -293,6 +318,9 @@ public void MergeFrom(UnityRLCapabilitiesProto other) {
293318
if (other.VariableLengthObservation != false) {
294319
VariableLengthObservation = other.VariableLengthObservation;
295320
}
321+
if (other.MultiAgentGroups != false) {
322+
MultiAgentGroups = other.MultiAgentGroups;
323+
}
296324
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
297325
}
298326

@@ -328,6 +356,10 @@ public void MergeFrom(pb::CodedInputStream input) {
328356
VariableLengthObservation = input.ReadBool();
329357
break;
330358
}
359+
case 56: {
360+
MultiAgentGroups = input.ReadBool();
361+
break;
362+
}
331363
}
332364
}
333365
}

ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py

Lines changed: 9 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class UnityRLCapabilitiesProto(google___protobuf___message___Message):
3131
hybridActions = ... # type: builtin___bool
3232
trainingAnalytics = ... # type: builtin___bool
3333
variableLengthObservation = ... # type: builtin___bool
34+
multiAgentGroups = ... # type: builtin___bool
3435

3536
def __init__(self,
3637
*,
@@ -40,12 +41,13 @@ class UnityRLCapabilitiesProto(google___protobuf___message___Message):
4041
hybridActions : typing___Optional[builtin___bool] = None,
4142
trainingAnalytics : typing___Optional[builtin___bool] = None,
4243
variableLengthObservation : typing___Optional[builtin___bool] = None,
44+
multiAgentGroups : typing___Optional[builtin___bool] = None,
4345
) -> None: ...
4446
@classmethod
4547
def FromString(cls, s: builtin___bytes) -> UnityRLCapabilitiesProto: ...
4648
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
4749
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
4850
if sys.version_info >= (3,):
49-
def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations",u"hybridActions",u"trainingAnalytics",u"variableLengthObservation"]) -> None: ...
51+
def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations",u"hybridActions",u"multiAgentGroups",u"trainingAnalytics",u"variableLengthObservation"]) -> None: ...
5052
else:
51-
def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations",u"hybridActions",b"hybridActions",u"trainingAnalytics",b"trainingAnalytics",u"variableLengthObservation",b"variableLengthObservation"]) -> None: ...
53+
def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations",u"hybridActions",b"hybridActions",u"multiAgentGroups",b"multiAgentGroups",u"trainingAnalytics",b"trainingAnalytics",u"variableLengthObservation",b"variableLengthObservation"]) -> None: ...

ml-agents-envs/mlagents_envs/environment.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class UnityEnvironment(BaseEnv):
6363
# * 1.2.0 - support compression mapping for stacked compressed observations.
6464
# * 1.3.0 - support action spaces with both continuous and discrete actions.
6565
# * 1.4.0 - support training analytics sent from python trainer to the editor.
66-
# * 1.5.0 - support variable length observation training.
66+
# * 1.5.0 - support variable length observation training and multi-agent groups.
6767
API_VERSION = "1.5.0"
6868

6969
# Default port that the editor listens on. If an environment executable
@@ -124,6 +124,7 @@ def _get_capabilities_proto() -> UnityRLCapabilitiesProto:
124124
capabilities.hybridActions = True
125125
capabilities.trainingAnalytics = True
126126
capabilities.variableLengthObservation = True
127+
capabilities.multiAgentGroups = True
127128
return capabilities
128129

129130
@staticmethod

protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,7 @@ message UnityRLCapabilitiesProto {
2525

2626
// Support for variable length observations of rank 2
2727
bool variableLengthObservation = 6;
28+
29+
// Support for multi agent groups and group rewards
30+
bool multiAgentGroups = 7;
2831
}

0 commit comments

Comments (0)