Skip to content

Commit b6d46b6

Browse files
committed
Extend TrainingAnalytics side channel to expose configuration details
Update python files to adhere to linting rules Fix typing check
1 parent 0fd6b96 commit b6d46b6

File tree

8 files changed

+208
-28
lines changed

8 files changed

+208
-28
lines changed

com.unity.ml-agents/Runtime/Analytics/Events.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ internal struct TrainingEnvironmentInitializedEvent
156156
public string TorchDeviceType;
157157
public int NumEnvironments;
158158
public int NumEnvironmentParameters;
159+
public string RunOptions;
159160
}
160161

161162
[Flags]
@@ -188,5 +189,6 @@ internal struct TrainingBehaviorInitializedEvent
188189
public string VisualEncoder;
189190
public int NumNetworkLayers;
190191
public int NumNetworkHiddenUnits;
192+
public string Config;
191193
}
192194
}

com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,14 @@ public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent
211211
return;
212212
}
213213

214-
// Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
215214
tbiEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString();
216-
tbiEvent.BehaviorName = AnalyticsUtils.Hash(k_VendorKey, tbiEvent.BehaviorName);
217215

216+
if(tbiEvent.Config.Length == 0 || tbiEvent.BehaviorName.Length != 64) {
217+
// Hash the behavior name if the message version is from an older version of ml-agents that doesn't do trainer-side hashing.
218+
// We'll also, for extra safety, verify that the BehaviorName is the size of the expected SHA256 hash.
219+
// Context: The config field was added at the same time as trainer side hashing, so messages including it should already be hashed.
220+
tbiEvent.BehaviorName = AnalyticsUtils.Hash(k_VendorKey, tbiEvent.BehaviorName);
221+
}
218222
// Note - to debug, use JsonUtility.ToJson on the event.
219223
// Debug.Log(
220224
// $"Would send event {k_TrainingBehaviorInitializedEventName} with body {JsonUtility.ToJson(tbiEvent, true)}"
@@ -236,7 +240,7 @@ IList<IActuator> actuators
236240
var remotePolicyEvent = new RemotePolicyInitializedEvent();
237241

238242
// Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
239-
remotePolicyEvent.BehaviorName = AnalyticsUtils.Hash(behaviorName);
243+
remotePolicyEvent.BehaviorName = AnalyticsUtils.Hash(k_VendorKey, behaviorName);
240244

241245
remotePolicyEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString();
242246
remotePolicyEvent.ActionSpec = EventActionSpec.FromActionSpec(actionSpec);

com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,7 @@ internal static TrainingEnvironmentInitializedEvent ToTrainingEnvironmentInitial
501501
TorchDeviceType = inputProto.TorchDeviceType,
502502
NumEnvironments = inputProto.NumEnvs,
503503
NumEnvironmentParameters = inputProto.NumEnvironmentParameters,
504+
RunOptions = inputProto.RunOptions,
504505
};
505506
}
506507

@@ -530,6 +531,7 @@ internal static TrainingBehaviorInitializedEvent ToTrainingBehaviorInitializedEv
530531
VisualEncoder = inputProto.VisualEncoder,
531532
NumNetworkLayers = inputProto.NumNetworkLayers,
532533
NumNetworkHiddenUnits = inputProto.NumNetworkHiddenUnits,
534+
Config = inputProto.Config,
533535
};
534536
}
535537

com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/TrainingAnalytics.cs

Lines changed: 72 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,28 +25,29 @@ static TrainingAnalyticsReflection() {
2525
byte[] descriptorData = global::System.Convert.FromBase64String(
2626
string.Concat(
2727
"CjttbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3RyYWluaW5n",
28-
"X2FuYWx5dGljcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi2QEKHlRy",
28+
"X2FuYWx5dGljcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi7gEKHlRy",
2929
"YWluaW5nRW52aXJvbm1lbnRJbml0aWFsaXplZBIYChBtbGFnZW50c192ZXJz",
3030
"aW9uGAEgASgJEh0KFW1sYWdlbnRzX2VudnNfdmVyc2lvbhgCIAEoCRIWCg5w",
3131
"eXRob25fdmVyc2lvbhgDIAEoCRIVCg10b3JjaF92ZXJzaW9uGAQgASgJEhkK",
3232
"EXRvcmNoX2RldmljZV90eXBlGAUgASgJEhAKCG51bV9lbnZzGAYgASgFEiIK",
33-
"Gm51bV9lbnZpcm9ubWVudF9wYXJhbWV0ZXJzGAcgASgFIq0DChtUcmFpbmlu",
34-
"Z0JlaGF2aW9ySW5pdGlhbGl6ZWQSFQoNYmVoYXZpb3JfbmFtZRgBIAEoCRIU",
35-
"Cgx0cmFpbmVyX3R5cGUYAiABKAkSIAoYZXh0cmluc2ljX3Jld2FyZF9lbmFi",
36-
"bGVkGAMgASgIEhsKE2dhaWxfcmV3YXJkX2VuYWJsZWQYBCABKAgSIAoYY3Vy",
37-
"aW9zaXR5X3Jld2FyZF9lbmFibGVkGAUgASgIEhoKEnJuZF9yZXdhcmRfZW5h",
38-
"YmxlZBgGIAEoCBIiChpiZWhhdmlvcmFsX2Nsb25pbmdfZW5hYmxlZBgHIAEo",
39-
"CBIZChFyZWN1cnJlbnRfZW5hYmxlZBgIIAEoCBIWCg52aXN1YWxfZW5jb2Rl",
40-
"chgJIAEoCRIaChJudW1fbmV0d29ya19sYXllcnMYCiABKAUSIAoYbnVtX25l",
41-
"dHdvcmtfaGlkZGVuX3VuaXRzGAsgASgFEhgKEHRyYWluZXJfdGhyZWFkZWQY",
42-
"DCABKAgSGQoRc2VsZl9wbGF5X2VuYWJsZWQYDSABKAgSGgoSY3VycmljdWx1",
43-
"bV9lbmFibGVkGA4gASgIQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0",
44-
"b3JPYmplY3RzYgZwcm90bzM="));
33+
"Gm51bV9lbnZpcm9ubWVudF9wYXJhbWV0ZXJzGAcgASgFEhMKC3J1bl9vcHRp",
34+
"b25zGAggASgJIr0DChtUcmFpbmluZ0JlaGF2aW9ySW5pdGlhbGl6ZWQSFQoN",
35+
"YmVoYXZpb3JfbmFtZRgBIAEoCRIUCgx0cmFpbmVyX3R5cGUYAiABKAkSIAoY",
36+
"ZXh0cmluc2ljX3Jld2FyZF9lbmFibGVkGAMgASgIEhsKE2dhaWxfcmV3YXJk",
37+
"X2VuYWJsZWQYBCABKAgSIAoYY3VyaW9zaXR5X3Jld2FyZF9lbmFibGVkGAUg",
38+
"ASgIEhoKEnJuZF9yZXdhcmRfZW5hYmxlZBgGIAEoCBIiChpiZWhhdmlvcmFs",
39+
"X2Nsb25pbmdfZW5hYmxlZBgHIAEoCBIZChFyZWN1cnJlbnRfZW5hYmxlZBgI",
40+
"IAEoCBIWCg52aXN1YWxfZW5jb2RlchgJIAEoCRIaChJudW1fbmV0d29ya19s",
41+
"YXllcnMYCiABKAUSIAoYbnVtX25ldHdvcmtfaGlkZGVuX3VuaXRzGAsgASgF",
42+
"EhgKEHRyYWluZXJfdGhyZWFkZWQYDCABKAgSGQoRc2VsZl9wbGF5X2VuYWJs",
43+
"ZWQYDSABKAgSGgoSY3VycmljdWx1bV9lbmFibGVkGA4gASgIEg4KBmNvbmZp",
44+
"ZxgPIAEoCUIlqgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0",
45+
"c2IGcHJvdG8z"));
4546
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
4647
new pbr::FileDescriptor[] { },
4748
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
48-
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized.Parser, new[]{ "MlagentsVersion", "MlagentsEnvsVersion", "PythonVersion", "TorchVersion", "TorchDeviceType", "NumEnvs", "NumEnvironmentParameters" }, null, null, null),
49-
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized.Parser, new[]{ "BehaviorName", "TrainerType", "ExtrinsicRewardEnabled", "GailRewardEnabled", "CuriosityRewardEnabled", "RndRewardEnabled", "BehavioralCloningEnabled", "RecurrentEnabled", "VisualEncoder", "NumNetworkLayers", "NumNetworkHiddenUnits", "TrainerThreaded", "SelfPlayEnabled", "CurriculumEnabled" }, null, null, null)
49+
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized.Parser, new[]{ "MlagentsVersion", "MlagentsEnvsVersion", "PythonVersion", "TorchVersion", "TorchDeviceType", "NumEnvs", "NumEnvironmentParameters", "RunOptions" }, null, null, null),
50+
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized.Parser, new[]{ "BehaviorName", "TrainerType", "ExtrinsicRewardEnabled", "GailRewardEnabled", "CuriosityRewardEnabled", "RndRewardEnabled", "BehavioralCloningEnabled", "RecurrentEnabled", "VisualEncoder", "NumNetworkLayers", "NumNetworkHiddenUnits", "TrainerThreaded", "SelfPlayEnabled", "CurriculumEnabled", "Config" }, null, null, null)
5051
}));
5152
}
5253
#endregion
@@ -85,6 +86,7 @@ public TrainingEnvironmentInitialized(TrainingEnvironmentInitialized other) : th
8586
torchDeviceType_ = other.torchDeviceType_;
8687
numEnvs_ = other.numEnvs_;
8788
numEnvironmentParameters_ = other.numEnvironmentParameters_;
89+
runOptions_ = other.runOptions_;
8890
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
8991
}
9092

@@ -170,6 +172,17 @@ public int NumEnvironmentParameters {
170172
}
171173
}
172174

175+
/// <summary>Field number for the "run_options" field.</summary>
176+
public const int RunOptionsFieldNumber = 8;
177+
private string runOptions_ = "";
178+
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
179+
public string RunOptions {
180+
get { return runOptions_; }
181+
set {
182+
runOptions_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
183+
}
184+
}
185+
173186
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
174187
public override bool Equals(object other) {
175188
return Equals(other as TrainingEnvironmentInitialized);
@@ -190,6 +203,7 @@ public bool Equals(TrainingEnvironmentInitialized other) {
190203
if (TorchDeviceType != other.TorchDeviceType) return false;
191204
if (NumEnvs != other.NumEnvs) return false;
192205
if (NumEnvironmentParameters != other.NumEnvironmentParameters) return false;
206+
if (RunOptions != other.RunOptions) return false;
193207
return Equals(_unknownFields, other._unknownFields);
194208
}
195209

@@ -203,6 +217,7 @@ public override int GetHashCode() {
203217
if (TorchDeviceType.Length != 0) hash ^= TorchDeviceType.GetHashCode();
204218
if (NumEnvs != 0) hash ^= NumEnvs.GetHashCode();
205219
if (NumEnvironmentParameters != 0) hash ^= NumEnvironmentParameters.GetHashCode();
220+
if (RunOptions.Length != 0) hash ^= RunOptions.GetHashCode();
206221
if (_unknownFields != null) {
207222
hash ^= _unknownFields.GetHashCode();
208223
}
@@ -244,6 +259,10 @@ public void WriteTo(pb::CodedOutputStream output) {
244259
output.WriteRawTag(56);
245260
output.WriteInt32(NumEnvironmentParameters);
246261
}
262+
if (RunOptions.Length != 0) {
263+
output.WriteRawTag(66);
264+
output.WriteString(RunOptions);
265+
}
247266
if (_unknownFields != null) {
248267
_unknownFields.WriteTo(output);
249268
}
@@ -273,6 +292,9 @@ public int CalculateSize() {
273292
if (NumEnvironmentParameters != 0) {
274293
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumEnvironmentParameters);
275294
}
295+
if (RunOptions.Length != 0) {
296+
size += 1 + pb::CodedOutputStream.ComputeStringSize(RunOptions);
297+
}
276298
if (_unknownFields != null) {
277299
size += _unknownFields.CalculateSize();
278300
}
@@ -305,6 +327,9 @@ public void MergeFrom(TrainingEnvironmentInitialized other) {
305327
if (other.NumEnvironmentParameters != 0) {
306328
NumEnvironmentParameters = other.NumEnvironmentParameters;
307329
}
330+
if (other.RunOptions.Length != 0) {
331+
RunOptions = other.RunOptions;
332+
}
308333
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
309334
}
310335

@@ -344,6 +369,10 @@ public void MergeFrom(pb::CodedInputStream input) {
344369
NumEnvironmentParameters = input.ReadInt32();
345370
break;
346371
}
372+
case 66: {
373+
RunOptions = input.ReadString();
374+
break;
375+
}
347376
}
348377
}
349378
}
@@ -389,6 +418,7 @@ public TrainingBehaviorInitialized(TrainingBehaviorInitialized other) : this() {
389418
trainerThreaded_ = other.trainerThreaded_;
390419
selfPlayEnabled_ = other.selfPlayEnabled_;
391420
curriculumEnabled_ = other.curriculumEnabled_;
421+
config_ = other.config_;
392422
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
393423
}
394424

@@ -551,6 +581,17 @@ public bool CurriculumEnabled {
551581
}
552582
}
553583

584+
/// <summary>Field number for the "config" field.</summary>
585+
public const int ConfigFieldNumber = 15;
586+
private string config_ = "";
587+
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
588+
public string Config {
589+
get { return config_; }
590+
set {
591+
config_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
592+
}
593+
}
594+
554595
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
555596
public override bool Equals(object other) {
556597
return Equals(other as TrainingBehaviorInitialized);
@@ -578,6 +619,7 @@ public bool Equals(TrainingBehaviorInitialized other) {
578619
if (TrainerThreaded != other.TrainerThreaded) return false;
579620
if (SelfPlayEnabled != other.SelfPlayEnabled) return false;
580621
if (CurriculumEnabled != other.CurriculumEnabled) return false;
622+
if (Config != other.Config) return false;
581623
return Equals(_unknownFields, other._unknownFields);
582624
}
583625

@@ -598,6 +640,7 @@ public override int GetHashCode() {
598640
if (TrainerThreaded != false) hash ^= TrainerThreaded.GetHashCode();
599641
if (SelfPlayEnabled != false) hash ^= SelfPlayEnabled.GetHashCode();
600642
if (CurriculumEnabled != false) hash ^= CurriculumEnabled.GetHashCode();
643+
if (Config.Length != 0) hash ^= Config.GetHashCode();
601644
if (_unknownFields != null) {
602645
hash ^= _unknownFields.GetHashCode();
603646
}
@@ -667,6 +710,10 @@ public void WriteTo(pb::CodedOutputStream output) {
667710
output.WriteRawTag(112);
668711
output.WriteBool(CurriculumEnabled);
669712
}
713+
if (Config.Length != 0) {
714+
output.WriteRawTag(122);
715+
output.WriteString(Config);
716+
}
670717
if (_unknownFields != null) {
671718
_unknownFields.WriteTo(output);
672719
}
@@ -717,6 +764,9 @@ public int CalculateSize() {
717764
if (CurriculumEnabled != false) {
718765
size += 1 + 1;
719766
}
767+
if (Config.Length != 0) {
768+
size += 1 + pb::CodedOutputStream.ComputeStringSize(Config);
769+
}
720770
if (_unknownFields != null) {
721771
size += _unknownFields.CalculateSize();
722772
}
@@ -770,6 +820,9 @@ public void MergeFrom(TrainingBehaviorInitialized other) {
770820
if (other.CurriculumEnabled != false) {
771821
CurriculumEnabled = other.CurriculumEnabled;
772822
}
823+
if (other.Config.Length != 0) {
824+
Config = other.Config;
825+
}
773826
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
774827
}
775828

@@ -837,6 +890,10 @@ public void MergeFrom(pb::CodedInputStream input) {
837890
CurriculumEnabled = input.ReadBool();
838891
break;
839892
}
893+
case 122: {
894+
Config = input.ReadString();
895+
break;
896+
}
840897
}
841898
}
842899
}

ml-agents-envs/mlagents_envs/communicator_objects/training_analytics_pb2.py

Lines changed: 18 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)