Skip to content

Add capabilities checks to C# and Python. #3831

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ and this project adheres to
- `num_updates` and `train_interval` for SAC were replaced with `steps_per_update`. (#3690)
- `WriteAdapter` was renamed to `ObservationWriter`. If you have a custom `ISensor` implementation,
you will need to change the signature of its `Write()` method. (#3834)
- `UnityRLCapabilities` was added to help inform users when RL features are mismatched between C# and Python packages. (#3831)

### Bug Fixes

Expand Down
11 changes: 10 additions & 1 deletion com.unity.ml-agents/Runtime/Academy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void FixedUpdate()
/// Access the Academy singleton through the <see cref="Instance"/>
/// property. The Academy instance is initialized the first time it is accessed (which will
/// typically be by the first <see cref="Agent"/> initialized in a scene).
///
///
/// At initialization, the Academy attempts to connect to the Python training process through
/// the external communicator. If successful, the training process can train <see cref="Agent"/>
/// instances. When you set an agent's <see cref="BehaviorParameters.behaviorType"/> setting
Expand Down Expand Up @@ -141,6 +141,12 @@ public int InferenceSeed
set { m_InferenceSeed = value; }
}

/// <summary>
/// Gets or sets the <see cref="UnityRLCapabilities"/> of the Python trainer that this Unity
/// process is connected to.
/// </summary>
internal UnityRLCapabilities TrainerCapabilities { get; set; }


// The Academy uses a series of events to communicate with agents
// to facilitate synchronization. More specifically, it ensures
// that all the agents perform their steps in a consistent order (i.e. no
Expand Down Expand Up @@ -350,10 +356,13 @@ void InitializeEnvironment()
unityCommunicationVersion = k_ApiVersion,
unityPackageVersion = k_PackageVersion,
name = "AcademySingleton",
CSharpCapabilities = new UnityRLCapabilities()
});
UnityEngine.Random.InitState(unityRlInitParameters.seed);
// We might have inference-only Agents, so set the seed for them too.
m_InferenceSeed = unityRlInitParameters.seed;
TrainerCapabilities = unityRlInitParameters.TrainerCapabilities;
TrainerCapabilities.WarnOnPythonMissingBaseRLCapabilities();
}
catch
{
Expand Down
17 changes: 17 additions & 0 deletions com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ public static UnityRLInitParameters ToUnityRLInitParameters(this UnityRLInitiali
seed = inputProto.Seed,
pythonLibraryVersion = inputProto.PackageVersion,
pythonCommunicationVersion = inputProto.CommunicationVersion,
TrainerCapabilities = inputProto.Capabilities.ToRLCapabilities()
};
}

Expand Down Expand Up @@ -280,5 +281,21 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat
return observationProto;
}
#endregion

/// <summary>
/// Converts a <see cref="UnityRLCapabilitiesProto"/> into a <see cref="UnityRLCapabilities"/>.
/// </summary>
/// <param name="proto">
/// The capabilities message to convert. May be <c>null</c>: an unset message-typed field
/// on a protobuf message is exposed as <c>null</c> in C#.
/// </param>
/// <returns>
/// A <see cref="UnityRLCapabilities"/> carrying the proto's base capabilities flag, or one with
/// the base capabilities flag set to <c>false</c> when <paramref name="proto"/> is <c>null</c>.
/// </returns>
public static UnityRLCapabilities ToRLCapabilities(this UnityRLCapabilitiesProto proto)
{
    if (proto == null)
    {
        // No capabilities message was provided. Report the base capabilities as missing
        // instead of throwing a NullReferenceException, so the mismatch can be surfaced
        // through WarnOnPythonMissingBaseRLCapabilities().
        return new UnityRLCapabilities(false);
    }

    return new UnityRLCapabilities
    {
        m_BaseRLCapabilities = proto.BaseRLCapabilities
    };
}

/// <summary>
/// Builds the protobuf representation of a <see cref="UnityRLCapabilities"/>.
/// </summary>
/// <param name="rlCaps">The capabilities whose base flag is copied into the message.</param>
/// <returns>A new <see cref="UnityRLCapabilitiesProto"/> carrying the same base capabilities flag.</returns>
public static UnityRLCapabilitiesProto ToProto(this UnityRLCapabilities rlCaps)
{
    var capabilitiesProto = new UnityRLCapabilitiesProto();
    capabilitiesProto.BaseRLCapabilities = rlCaps.m_BaseRLCapabilities;
    return capabilitiesProto;
}
}
}
10 changes: 10 additions & 0 deletions com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ internal struct CommunicatorInitParameters
/// The version of the communication API.
/// </summary>
public string unityCommunicationVersion;

/// <summary>
/// The RL capabilities of the C# codebase.
/// </summary>
public UnityRLCapabilities CSharpCapabilities;
}
internal struct UnityRLInitParameters
{
Expand All @@ -44,6 +49,11 @@ internal struct UnityRLInitParameters
/// The version of the communication API that python is using.
/// </summary>
public string pythonCommunicationVersion;

/// <summary>
/// The RL capabilities of the Trainer codebase.
/// </summary>
public UnityRLCapabilities TrainerCapabilities;
}
internal struct UnityRLInputParameters
{
Expand Down
3 changes: 2 additions & 1 deletion com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ public UnityRLInitParameters Initialize(CommunicatorInitParameters initParameter
{
Name = initParameters.name,
PackageVersion = initParameters.unityPackageVersion,
CommunicationVersion = initParameters.unityCommunicationVersion
CommunicationVersion = initParameters.unityCommunicationVersion,
Capabilities = initParameters.CSharpCapabilities.ToProto()
};

UnityInputProto input;
Expand Down
36 changes: 36 additions & 0 deletions com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using UnityEngine;

namespace MLAgents
{
    /// <summary>
    /// A class holding the capabilities flags for Reinforcement Learning across C# and the Trainer
    /// codebase. It is used to inform users if and when they are using C# / Trainer features that
    /// are mismatched.
    /// </summary>
    internal class UnityRLCapabilities
    {
        /// <summary>
        /// Whether the base Reinforcement Learning capabilities are supported.
        /// </summary>
        internal bool m_BaseRLCapabilities;

        /// <summary>
        /// Creates a capabilities object.
        /// </summary>
        /// <param name="baseRlCapabilities">
        /// Initial value of the base capabilities flag. Defaults to <c>true</c>.
        /// </param>
        public UnityRLCapabilities(bool baseRlCapabilities = true)
        {
            m_BaseRLCapabilities = baseRlCapabilities;
        }

        /// <summary>
        /// Prints a warning to the console if the base capabilities flag is not set, i.e. the
        /// Python training process does not support the base Reinforcement Learning capabilities.
        /// </summary>
        /// <returns>
        /// <c>true</c> if the warning was printed; <c>false</c> if the base capabilities are
        /// supported and nothing was logged.
        /// </returns>
        public bool WarnOnPythonMissingBaseRLCapabilities()
        {
            if (m_BaseRLCapabilities)
            {
                return false;
            }

            // Each fragment carries its own separating space so the concatenated message
            // does not run words together ("support Base", not "supportBase").
            Debug.LogWarning("Unity has connected to a Training process that does not support " +
                "Base Reinforcement Learning Capabilities. Please make sure you have the " +
                "latest training codebase installed for this version of the ML-Agents package.");
            return true;
        }
    }
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

182 changes: 182 additions & 0 deletions com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
// <auto-generated>
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: mlagents_envs/communicator_objects/capabilities.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#region Designer generated code

using pb = global::Google.Protobuf;
using pbc = global::Google.Protobuf.Collections;
using pbr = global::Google.Protobuf.Reflection;
using scg = global::System.Collections.Generic;
namespace MLAgents.CommunicatorObjects {

/// <summary>Holder for reflection information generated from mlagents_envs/communicator_objects/capabilities.proto</summary>
internal static partial class CapabilitiesReflection {

#region Descriptor
/// <summary>File descriptor for mlagents_envs/communicator_objects/capabilities.proto</summary>
public static pbr::FileDescriptor Descriptor {
get { return descriptor; }
}
private static pbr::FileDescriptor descriptor;

// Static constructor: decodes the embedded descriptor data and builds the file descriptor
// once, when the type is initialized.
static CapabilitiesReflection() {
// Base64-encoded descriptor data emitted by the protocol buffer compiler.
// Regenerate from the .proto file rather than editing these strings by hand.
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp",
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMiNgoYVW5pdHlSTENh",
"cGFiaWxpdGllc1Byb3RvEhoKEmJhc2VSTENhcGFiaWxpdGllcxgBIAEoCEIf",
"qgIcTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
// No file dependencies are passed (empty FileDescriptor array). The type info registers
// UnityRLCapabilitiesProto with its parser and its single property, "BaseRLCapabilities".
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities" }, null, null, null)
}));
}
#endregion

}
#region Messages
/// <summary>
/// A Capabilities message that communicates to both C# and Python
/// which features are available on each side.
/// </summary>
internal sealed partial class UnityRLCapabilitiesProto : pb::IMessage<UnityRLCapabilitiesProto> {
private static readonly pb::MessageParser<UnityRLCapabilitiesProto> _parser = new pb::MessageParser<UnityRLCapabilitiesProto>(() => new UnityRLCapabilitiesProto());
// Fields read from the wire that this class does not recognize; kept so they can be
// written back out, cloned and merged along with the known field.
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<UnityRLCapabilitiesProto> Parser { get { return _parser; } }

// This message is the first (index 0) message type declared in the capabilities file descriptor.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::MLAgents.CommunicatorObjects.CapabilitiesReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public UnityRLCapabilitiesProto() {
OnConstruction();
}

// Partial-method hook invoked by the parameterless constructor; no implementation is
// provided in this file.
partial void OnConstruction();

// Copy constructor: copies the capability flag and clones the unknown-field set.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public UnityRLCapabilitiesProto(UnityRLCapabilitiesProto other) : this() {
baseRLCapabilities_ = other.baseRLCapabilities_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public UnityRLCapabilitiesProto Clone() {
return new UnityRLCapabilitiesProto(this);
}

/// <summary>Field number for the "baseRLCapabilities" field.</summary>
public const int BaseRLCapabilitiesFieldNumber = 1;
private bool baseRLCapabilities_;
/// <summary>
/// These are the 1.0 capabilities.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool BaseRLCapabilities {
get { return baseRLCapabilities_; }
set {
baseRLCapabilities_ = value;
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLCapabilitiesProto);
}

// Value equality: the capability flag and the unknown-field sets must both match.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(UnityRLCapabilitiesProto other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (BaseRLCapabilities != other.BaseRLCapabilities) return false;
return Equals(_unknownFields, other._unknownFields);
}

// The flag contributes to the hash only when it is true (i.e. not the default value).
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (BaseRLCapabilities != false) hash ^= BaseRLCapabilities.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

// Serializes the message. The flag is written only when true; a false value is omitted.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (BaseRLCapabilities != false) {
// Tag 8 == (field number 1 << 3) | wire type 0 (varint).
output.WriteRawTag(8);
output.WriteBool(BaseRLCapabilities);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
}

// Mirrors WriteTo: the size counts the flag only when it would actually be written.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (BaseRLCapabilities != false) {
// One byte for the tag plus one byte for the bool value.
size += 1 + 1;
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}

// Merges another message into this one. The flag is overwritten only when the other
// message has it set to true; a false value in `other` leaves this instance unchanged.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(UnityRLCapabilitiesProto other) {
if (other == null) {
return;
}
if (other.BaseRLCapabilities != false) {
BaseRLCapabilities = other.BaseRLCapabilities;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

// Parses fields from the input stream until ReadTag() returns 0.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
// Any tag other than 8 is preserved in the unknown-field set rather than dropped.
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
// Tag 8: the baseRLCapabilities field (field number 1).
case 8: {
BaseRLCapabilities = input.ReadBool();
break;
}
}
}
}

}

#endregion

}

#endregion Designer generated code

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading