Skip to content

Custom raycast sensor #5411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 57 additions & 18 deletions com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,10 @@ public struct RayPerceptionInput
public int LayerMask;

/// <summary>
/// Returns the expected number of floats in the output.
/// (Deprecated) Returns the expected number of floats in the output.
/// </summary>
/// <returns></returns>
[Obsolete("RayPerceptionInput.OutputSize() has been deprecated, please use RayPerceptionSensor.OutputSize instead.")]
public int OutputSize()
{
return ((DetectableTags?.Count ?? 0) + 2) * (Angles?.Count ?? 0);
Expand Down Expand Up @@ -201,7 +202,7 @@ public float ScaledRayLength
public float ScaledCastRadius;

/// <summary>
/// Writes the ray output information to a subset of the float array. Each element in the rayAngles array
/// (Deprecated) Writes the ray output information to a subset of the float array. Each element in the rayAngles array
/// determines a sublist of data to the observation. The sublist contains the observation data for a single cast.
/// The list is composed of the following:
/// 1. A one-hot encoding for detectable tags. For example, if DetectableTags.Length = n, the
Expand All @@ -215,6 +216,7 @@ public float ScaledRayLength
/// <param name="numDetectableTags"></param>
/// <param name="rayIndex"></param>
/// <param name="buffer">Output buffer. The size must be equal to (numDetectableTags+2) * RayOutputs.Length</param>
[Obsolete("RayPerceptionOutput.ToFloatArray() has been deprecated, please use RayPerceptionSensor.RayOutputToArray() instead.")]
public void ToFloatArray(int numDetectableTags, int rayIndex, float[] buffer)
{
var bufferOffset = (numDetectableTags + 2) * rayIndex;
Expand Down Expand Up @@ -265,12 +267,20 @@ public RayPerceptionSensor(string name, RayPerceptionInput rayInput)
m_Name = name;
m_RayPerceptionInput = rayInput;

SetNumObservations(rayInput.OutputSize());
SetNumObservations(GetObservationSizePerRay(), GetNumberOfRays());

m_DebugLastFrameCount = Time.frameCount;
m_RayPerceptionOutput = new RayPerceptionOutput();
}

/// <summary>
/// The ray perception input configurations.
/// </summary>
public RayPerceptionInput RayPerceptionInput
{
get { return m_RayPerceptionInput; }
}

/// <summary>
/// The most recent raycast results.
/// </summary>
Expand All @@ -279,27 +289,56 @@ public RayPerceptionOutput RayPerceptionOutput
get { return m_RayPerceptionOutput; }
}

void SetNumObservations(int numObservations)
/// <summary>
/// The observation size per ray.
/// Override this method for custom observations.
/// </summary>
public virtual int GetObservationSizePerRay()
{
return (RayPerceptionInput.DetectableTags?.Count ?? 0) + 2;
}

/// <summary>
/// The number of rays in the sensor.
/// </summary>
public int GetNumberOfRays()
{
return RayPerceptionInput.Angles?.Count ?? 0;
}

void SetNumObservations(int observationsSizePerRay, int numRays)
{
m_ObservationSpec = ObservationSpec.Vector(numObservations);
m_Observations = new float[numObservations];
m_ObservationSpec = ObservationSpec.Vector(observationsSizePerRay * numRays);
m_Observations = new float[observationsSizePerRay];
}

internal void SetRayPerceptionInput(RayPerceptionInput rayInput)
{
// Note that change the number of rays or tags doesn't directly call this,
// but changing them and then changing another field will.
if (m_RayPerceptionInput.OutputSize() != rayInput.OutputSize())
var oldObservationSize = GetObservationSizePerRay() * GetNumberOfRays();
m_RayPerceptionInput = rayInput;
if (GetObservationSizePerRay() * GetNumberOfRays() != oldObservationSize)
{
Debug.Log(
"Changing the number of tags or rays at runtime is not " +
"supported and may cause errors in training or inference."
);
// Changing the shape will probably break things downstream, but we can at least
// keep this consistent.
SetNumObservations(rayInput.OutputSize());
SetNumObservations(GetObservationSizePerRay(), GetNumberOfRays());
}
m_RayPerceptionInput = rayInput;
}

public virtual void RayOutputToArray(RayPerceptionOutput.RayOutput rayOutput, int rayIndex, float[] buffer)
{
if (rayOutput.HitTaggedObject)
{
buffer[rayOutput.HitTagIndex] = 1f;
}
var numDetectableTags = RayPerceptionInput.DetectableTags.Count;
buffer[numDetectableTags] = rayOutput.HasHit ? 0f : 1f;
buffer[numDetectableTags + 1] = rayOutput.HitFraction;
}

/// <summary>
Expand All @@ -310,22 +349,22 @@ internal void SetRayPerceptionInput(RayPerceptionInput rayInput)
/// <returns></returns>
public int Write(ObservationWriter writer)
{
var numWritten = 0;
using (TimerStack.Instance.Scoped("RayPerceptionSensor.Perceive"))
{
Array.Clear(m_Observations, 0, m_Observations.Length);
var numRays = m_RayPerceptionInput.Angles.Count;
var numDetectableTags = m_RayPerceptionInput.DetectableTags.Count;
var rayObservartionSize = GetObservationSizePerRay();

// For each ray, write the information to the observation buffer
for (var rayIndex = 0; rayIndex < numRays; rayIndex++)
for (var rayIndex = 0; rayIndex < GetNumberOfRays(); rayIndex++)
{
m_RayPerceptionOutput.RayOutputs?[rayIndex].ToFloatArray(numDetectableTags, rayIndex, m_Observations);
}
Array.Clear(m_Observations, 0, rayObservartionSize);
RayOutputToArray(m_RayPerceptionOutput.RayOutputs[rayIndex], rayIndex, m_Observations);

// Finally, add the observations to the ObservationWriter
writer.AddList(m_Observations);
writer.AddList(m_Observations, numWritten);
numWritten += rayObservartionSize;
}
}
return m_Observations.Length;
return numWritten;
}

/// <inheritdoc/>
Expand Down