Skip to content

Commit 14fad19

Browse files
author
Chris Elion
authored
[MLA-1138] joint observations (#4224)
1 parent 02ad5c0 commit 14fad19

15 files changed

+424
-31
lines changed
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
#if UNITY_2020_1_OR_NEWER
2+
3+
using System.Collections.Generic;
4+
using UnityEngine;
5+
using Unity.MLAgents.Sensors;
6+
7+
namespace Unity.MLAgents.Extensions.Sensors
8+
{
9+
public class ArticulationBodyJointExtractor : IJointExtractor
10+
{
11+
ArticulationBody m_Body;
12+
13+
public ArticulationBodyJointExtractor(ArticulationBody body)
14+
{
15+
m_Body = body;
16+
}
17+
18+
public int NumObservations(PhysicsSensorSettings settings)
19+
{
20+
return NumObservations(m_Body, settings);
21+
}
22+
23+
public static int NumObservations(ArticulationBody body, PhysicsSensorSettings settings)
24+
{
25+
if (body == null || body.isRoot)
26+
{
27+
return 0;
28+
}
29+
30+
var totalCount = 0;
31+
if (settings.UseJointPositionsAndAngles)
32+
{
33+
switch (body.jointType)
34+
{
35+
case ArticulationJointType.RevoluteJoint:
36+
case ArticulationJointType.SphericalJoint:
37+
// Both RevoluteJoint and SphericalJoint have all angular components.
38+
// We use sine and cosine of the angles for the observations.
39+
totalCount += 2 * body.dofCount;
40+
break;
41+
case ArticulationJointType.FixedJoint:
42+
// Since FixedJoint can't moved, there aren't any interesting observations for it.
43+
break;
44+
case ArticulationJointType.PrismaticJoint:
45+
// One linear component
46+
totalCount += body.dofCount;
47+
break;
48+
}
49+
}
50+
51+
if (settings.UseJointForces)
52+
{
53+
totalCount += body.dofCount;
54+
}
55+
56+
return totalCount;
57+
}
58+
59+
public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset)
60+
{
61+
if (m_Body == null || m_Body.isRoot)
62+
{
63+
return 0;
64+
}
65+
66+
var currentOffset = offset;
67+
68+
// Write joint positions
69+
if (settings.UseJointPositionsAndAngles)
70+
{
71+
switch (m_Body.jointType)
72+
{
73+
case ArticulationJointType.RevoluteJoint:
74+
case ArticulationJointType.SphericalJoint:
75+
// All joint positions are angular
76+
for (var dofIndex = 0; dofIndex < m_Body.dofCount; dofIndex++)
77+
{
78+
var jointRotationRads = m_Body.jointPosition[dofIndex];
79+
writer[currentOffset++] = Mathf.Sin(jointRotationRads);
80+
writer[currentOffset++] = Mathf.Cos(jointRotationRads);
81+
}
82+
break;
83+
case ArticulationJointType.FixedJoint:
84+
// No observations
85+
break;
86+
case ArticulationJointType.PrismaticJoint:
87+
writer[currentOffset++] = GetPrismaticValue();
88+
break;
89+
}
90+
}
91+
92+
if (settings.UseJointForces)
93+
{
94+
for (var dofIndex = 0; dofIndex < m_Body.dofCount; dofIndex++)
95+
{
96+
// take tanh to keep in [-1, 1]
97+
writer[currentOffset++] = (float) System.Math.Tanh(m_Body.jointForce[dofIndex]);
98+
}
99+
}
100+
101+
return currentOffset - offset;
102+
}
103+
104+
float GetPrismaticValue()
105+
{
106+
// Prismatic joints should have at most one free axis.
107+
bool limited = false;
108+
var drive = m_Body.xDrive;
109+
if (m_Body.linearLockX == ArticulationDofLock.LimitedMotion)
110+
{
111+
drive = m_Body.xDrive;
112+
limited = true;
113+
}
114+
else if (m_Body.linearLockY == ArticulationDofLock.LimitedMotion)
115+
{
116+
drive = m_Body.yDrive;
117+
limited = true;
118+
}
119+
else if (m_Body.linearLockZ == ArticulationDofLock.LimitedMotion)
120+
{
121+
drive = m_Body.zDrive;
122+
limited = true;
123+
}
124+
125+
var jointPos = m_Body.jointPosition[0];
126+
if (limited)
127+
{
128+
// If locked, interpolate between the limits.
129+
var upperLimit = drive.upperLimit;
130+
var lowerLimit = drive.lowerLimit;
131+
if (upperLimit <= lowerLimit)
132+
{
133+
// Invalid limits (probably equal), so don't try to lerp
134+
return 0;
135+
}
136+
var invLerped = Mathf.InverseLerp(lowerLimit, upperLimit, jointPos);
137+
138+
// Convert [0, 1] -> [-1, 1]
139+
var normalized = 2.0f * invLerped - 1.0f;
140+
return normalized;
141+
}
142+
// take tanh() to keep in [-1, 1]
143+
return (float) System.Math.Tanh(jointPos);
144+
}
145+
}
146+
}
147+
#endif

com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs.meta

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ protected override Pose GetPoseAt(int index)
7171
var t = go.transform;
7272
return new Pose { rotation = t.rotation, position = t.position };
7373
}
74+
75+
internal ArticulationBody[] Bodies => m_Bodies;
7476
}
7577
}
7678
#endif // UNITY_2020_1_OR_NEWER

com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,16 @@ public override int[] GetObservationShape()
3232
// TODO static method in PhysicsBodySensor?
3333
// TODO only update PoseExtractor when body changes?
3434
var poseExtractor = new ArticulationBodyPoseExtractor(RootBody);
35-
var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses);
36-
return new[] { numTransformObservations };
35+
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);
36+
var numJointObservations = 0;
37+
// Start from i=1 to ignore the root
38+
for (var i = 1; i < poseExtractor.Bodies.Length; i++)
39+
{
40+
numJointObservations += ArticulationBodyJointExtractor.NumObservations(
41+
poseExtractor.Bodies[i], Settings
42+
);
43+
}
44+
return new[] { numPoseObservations + numJointObservations };
3745
}
3846
}
3947

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using Unity.MLAgents.Sensors;
2+
3+
namespace Unity.MLAgents.Extensions.Sensors
4+
{
5+
/// <summary>
6+
/// Interface for generating observations from a physical joint or constraint.
7+
/// </summary>
8+
public interface IJointExtractor
9+
{
10+
/// <summary>
11+
/// Determine the number of observations that would be generated for the particular joint
12+
/// using the provided PhysicsSensorSettings.
13+
/// </summary>
14+
/// <param name="settings"></param>
15+
/// <returns>Number of floats that will be written.</returns>
16+
int NumObservations(PhysicsSensorSettings settings);
17+
18+
/// <summary>
19+
/// Write the observations to the ObservationWriter, starting at the specified offset.
20+
/// </summary>
21+
/// <param name="settings"></param>
22+
/// <param name="writer"></param>
23+
/// <param name="offset"></param>
24+
/// <returns>Number of floats that were written.</returns>
25+
int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset);
26+
}
27+
}

com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs.meta

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ public class PhysicsBodySensor : ISensor
1212
string m_SensorName;
1313

1414
PoseExtractor m_PoseExtractor;
15+
IJointExtractor[] m_JointExtractors;
1516
PhysicsSensorSettings m_Settings;
1617

1718
/// <summary>
@@ -22,23 +23,59 @@ public class PhysicsBodySensor : ISensor
2223
/// <param name="sensorName"></param>
2324
public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsSensorSettings settings, string sensorName=null)
2425
{
25-
m_PoseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject);
26+
var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject);
27+
m_PoseExtractor = poseExtractor;
2628
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName;
2729
m_Settings = settings;
2830

29-
var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
30-
m_Shape = new[] { numTransformObservations };
31+
var numJointExtractorObservations = 0;
32+
var rigidBodies = poseExtractor.Bodies;
33+
if (rigidBodies != null)
34+
{
35+
m_JointExtractors = new IJointExtractor[rigidBodies.Length - 1]; // skip the root
36+
for (var i = 1; i < rigidBodies.Length; i++)
37+
{
38+
var jointExtractor = new RigidBodyJointExtractor(rigidBodies[i]);
39+
numJointExtractorObservations += jointExtractor.NumObservations(settings);
40+
m_JointExtractors[i - 1] = jointExtractor;
41+
}
42+
}
43+
else
44+
{
45+
m_JointExtractors = new IJointExtractor[0];
46+
}
47+
48+
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
49+
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
3150
}
3251

3352
#if UNITY_2020_1_OR_NEWER
3453
public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName=null)
3554
{
36-
m_PoseExtractor = new ArticulationBodyPoseExtractor(rootBody);
55+
var poseExtractor = new ArticulationBodyPoseExtractor(rootBody);
56+
m_PoseExtractor = poseExtractor;
3757
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"ArticulationBodySensor:{rootBody?.name}" : sensorName;
3858
m_Settings = settings;
3959

40-
var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
41-
m_Shape = new[] { numTransformObservations };
60+
var numJointExtractorObservations = 0;
61+
var articBodies = poseExtractor.Bodies;
62+
if (articBodies != null)
63+
{
64+
m_JointExtractors = new IJointExtractor[articBodies.Length - 1]; // skip the root
65+
for (var i = 1; i < articBodies.Length; i++)
66+
{
67+
var jointExtractor = new ArticulationBodyJointExtractor(articBodies[i]);
68+
numJointExtractorObservations += jointExtractor.NumObservations(settings);
69+
m_JointExtractors[i - 1] = jointExtractor;
70+
}
71+
}
72+
else
73+
{
74+
m_JointExtractors = new IJointExtractor[0];
75+
}
76+
77+
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
78+
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
4279
}
4380
#endif
4481

@@ -52,6 +89,10 @@ public int[] GetObservationShape()
5289
public int Write(ObservationWriter writer)
5390
{
5491
var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor);
92+
foreach (var jointExtractor in m_JointExtractors)
93+
{
94+
numWritten += jointExtractor.Write(m_Settings, writer, numWritten);
95+
}
5596
return numWritten;
5697
}
5798

com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,16 @@ public struct PhysicsSensorSettings
4040
/// </summary>
4141
public bool UseLocalSpaceLinearVelocity;
4242

43+
/// <summary>
44+
/// Whether to use joint-specific positions and angles as observations.
45+
/// </summary>
46+
public bool UseJointPositionsAndAngles;
47+
48+
/// <summary>
49+
/// Whether to use the joint forces and torques that are applied by the solver as observations.
50+
/// </summary>
51+
public bool UseJointForces;
52+
4353
/// <summary>
4454
/// Creates a PhysicsSensorSettings with reasonable default values.
4555
/// </summary>
@@ -68,26 +78,6 @@ public bool UseLocalSpace
6878
{
6979
get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity; }
7080
}
71-
72-
73-
/// <summary>
74-
/// The number of floats needed to represent a given number of transforms.
75-
/// </summary>
76-
/// <param name="numTransforms"></param>
77-
/// <returns></returns>
78-
public int TransformSize(int numTransforms)
79-
{
80-
int obsPerTransform = 0;
81-
obsPerTransform += UseModelSpaceTranslations ? 3 : 0;
82-
obsPerTransform += UseModelSpaceRotations ? 4 : 0;
83-
obsPerTransform += UseLocalSpaceTranslations ? 3 : 0;
84-
obsPerTransform += UseLocalSpaceRotations ? 4 : 0;
85-
86-
obsPerTransform += UseModelSpaceLinearVelocity ? 3 : 0;
87-
obsPerTransform += UseLocalSpaceLinearVelocity ? 3 : 0;
88-
89-
return numTransforms * obsPerTransform;
90-
}
9181
}
9282

9383
internal static class ObservationWriterPhysicsExtensions

com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,24 @@ public void UpdateLocalSpacePoses()
167167
}
168168
}
169169

170+
/// <summary>
171+
/// Compute the number of floats needed to represent the poses for the given PhysicsSensorSettings.
172+
/// </summary>
173+
/// <param name="settings"></param>
174+
/// <returns></returns>
175+
public int GetNumPoseObservations(PhysicsSensorSettings settings)
176+
{
177+
int obsPerPose = 0;
178+
obsPerPose += settings.UseModelSpaceTranslations ? 3 : 0;
179+
obsPerPose += settings.UseModelSpaceRotations ? 4 : 0;
180+
obsPerPose += settings.UseLocalSpaceTranslations ? 3 : 0;
181+
obsPerPose += settings.UseLocalSpaceRotations ? 4 : 0;
182+
183+
obsPerPose += settings.UseModelSpaceLinearVelocity ? 3 : 0;
184+
obsPerPose += settings.UseLocalSpaceLinearVelocity ? 3 : 0;
185+
186+
return NumPoses * obsPerPose;
187+
}
170188

171189
internal void DrawModelSpace(Vector3 offset)
172190
{

0 commit comments

Comments
 (0)