Skip to content

Commit 399ad3c

Browse files
author
Chris Elion
authored
Stats SideChannel (for custom TensorBoard metrics) (#3660)
1 parent 312a439 commit 399ad3c

File tree

16 files changed

+316
-15
lines changed

16 files changed

+316
-15
lines changed

Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorSettings.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using UnityEngine;
33
using UnityEngine.UI;
44
using MLAgents;
5+
using MLAgents.SideChannels;
56

67
public class FoodCollectorSettings : MonoBehaviour
78
{
@@ -13,9 +14,12 @@ public class FoodCollectorSettings : MonoBehaviour
1314
public int totalScore;
1415
public Text scoreText;
1516

17+
StatsSideChannel m_statsSideChannel;
18+
1619
/// <summary>
/// Subscribes <see cref="EnvironmentReset"/> to the Academy's reset event and caches the
/// StatsSideChannel that <see cref="Update"/> uses to report the score.
/// </summary>
public void Awake()
{
    var academy = Academy.Instance;
    academy.OnEnvironmentReset += EnvironmentReset;

    // The lookup can yield null (e.g. no channel registered); Update() tolerates that.
    m_statsSideChannel = academy.GetSideChannel<StatsSideChannel>();
}
2024

2125
public void EnvironmentReset()
@@ -44,5 +48,13 @@ void ClearObjects(GameObject[] objects)
4448
/// <summary>
/// Refreshes the on-screen score label and, on every 100th frame, reports the score
/// through the stats side channel.
/// </summary>
public void Update()
{
    scoreText.text = $"Score: {totalScore}";

    // Stats sent over the SideChannel show up in TensorBoard. They are averaged over
    // each summary_frequency window, so reporting on every Update() call is unnecessary.
    var isReportingFrame = Time.frameCount % 100 == 0;
    if (!isReportingFrame)
    {
        return;
    }

    m_statsSideChannel?.AddStat("TotalScore", totalScore);
}
4860
}

com.unity.ml-agents/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
1111
### Minor Changes
1212
- Format of console output has changed slightly and now matches the name of the model/summary directory. (#3630, #3616)
1313
- Raise the wall in CrawlerStatic scene to prevent Agent from falling off. (#3650)
14+
- Added a feature to allow sending stats from C# environments to TensorBoard (and other python StatsWriters). To do this from your code, use `Academy.Instance.GetSideChannel<StatsSideChannel>().AddStat(key, value)` (#3660)
1415
- Renamed 'Generalization' feature to 'Environment Parameter Randomization'.
1516
- Fixed an issue where specifying `vis_encode_type` was required only for SAC. (#3677)
1617
- The way that UnityEnvironment decides the port was changed. If no port is specified, the behavior will depend on the `file_name` parameter. If it is `None`, 5004 (the editor port) will be used; otherwise 5005 (the base environment port) will be used.

com.unity.ml-agents/Runtime/Academy.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,33 @@ public void UnregisterSideChannel(SideChannel channel)
235235
Communicator?.UnregisterSideChannel(channel);
236236
}
237237

238+
/// <summary>
/// Returns the registered SideChannel of type T, or null if none is registered
/// (including when there is no Communicator).
/// If several SideChannels of the same type are registered, which instance is returned is arbitrary.
/// </summary>
/// <typeparam name="T">The SideChannel type to look up.</typeparam>
/// <returns>A registered SideChannel of type T, or null.</returns>
public T GetSideChannel<T>() where T : SideChannel
{
    if (Communicator == null)
    {
        return null;
    }
    return Communicator.GetSideChannel<T>();
}
248+
249+
/// <summary>
250+
/// Returns all SideChannels of Type T that are registered. Use <see cref="GetSideChannel{T}()"/> if possible,
251+
/// as that does not make any memory allocations.
252+
/// </summary>
253+
/// <typeparam name="T"></typeparam>
254+
/// <returns></returns>
255+
public List<T> GetSideChannels<T>() where T: SideChannel
256+
{
257+
if (Communicator == null)
258+
{
259+
// Make sure we return a non-null List.
260+
return new List<T>();
261+
}
262+
return Communicator.GetSideChannels<T>();
263+
}
264+
238265
/// <summary>
239266
/// Disable stepping of the Academy during the FixedUpdate phase. If this is called, the Academy must be
240267
/// stepped manually by the user by calling Academy.EnvironmentStep().
@@ -334,6 +361,7 @@ void InitializeEnvironment()
334361
{
335362
Communicator.RegisterSideChannel(new EngineConfigurationChannel());
336363
Communicator.RegisterSideChannel(floatProperties);
364+
Communicator.RegisterSideChannel(new StatsSideChannel());
337365
// We try to exchange the first message with Python. If this fails, it means
338366
// no Python Process is ready to train the environment. In this case, the
339367
//environment must use Inference.

com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,5 +167,21 @@ internal interface ICommunicator : IDisposable
167167
/// </summary>
168168
/// <param name="sideChannel"> The side channel to be unregistered.</param>
169169
void UnregisterSideChannel(SideChannel sideChannel);
170+
171+
/// <summary>
172+
/// Returns the SideChannel of Type T if one is registered, or null if none is.
173+
/// If there are multiple SideChannels of the same type registered, the returned instance is arbitrary.
174+
/// </summary>
175+
/// <typeparam name="T"></typeparam>
176+
/// <returns></returns>
177+
T GetSideChannel<T>() where T : SideChannel;
178+
179+
/// <summary>
180+
/// Returns all SideChannels of Type T that are registered. Use <see cref="GetSideChannel{T}()"/> if possible,
181+
/// as that does not make any memory allocations.
182+
/// </summary>
183+
/// <typeparam name="T"></typeparam>
184+
/// <returns></returns>
185+
List<T> GetSideChannels<T>() where T : SideChannel;
170186
}
171187
}

com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,34 @@ public void UnregisterSideChannel(SideChannel sideChannel)
544544
}
545545
}
546546

547+
/// <inheritdoc/>
548+
public T GetSideChannel<T>() where T: SideChannel
549+
{
550+
foreach (var sc in m_SideChannels.Values)
551+
{
552+
if (sc.GetType() == typeof(T))
553+
{
554+
return (T) sc;
555+
}
556+
}
557+
return null;
558+
}
559+
560+
/// <inheritdoc/>
561+
public List<T> GetSideChannels<T>() where T: SideChannel
562+
{
563+
var output = new List<T>();
564+
565+
foreach (var sc in m_SideChannels.Values)
566+
{
567+
if (sc.GetType() == typeof(T))
568+
{
569+
output.Add((T) sc);
570+
}
571+
}
572+
return output;
573+
}
574+
547575
/// <summary>
548576
/// Grabs the messages that the registered side channels will send to Python at the current step
549577
/// into a single byte array.

com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@ namespace MLAgents.SideChannels
88
/// </summary>
99
public class EngineConfigurationChannel : SideChannel
1010
{
11-
private const string k_EngineConfigId = "e951342c-4f7e-11ea-b238-784f4387d1f7";
11+
const string k_EngineConfigId = "e951342c-4f7e-11ea-b238-784f4387d1f7";
1212

1313
/// <summary>
14-
/// Initializes the side channel.
14+
/// Initializes the side channel. The constructor is internal because only one instance is
15+
/// supported at a time, and is created by the Academy.
1516
/// </summary>
16-
public EngineConfigurationChannel()
17+
internal EngineConfigurationChannel()
1718
{
1819
ChannelId = new Guid(k_EngineConfigId);
1920
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
using System;
2+
namespace MLAgents.SideChannels
3+
{
4+
/// <summary>
/// Selects how several values reported for the same stat within one summary period are combined.
/// The numeric value of each member is what gets written into the outgoing message,
/// so existing values should stay stable.
/// </summary>
public enum StatAggregationMethod
{
    /// <summary>
    /// All values within the summary period are averaged before being reported.
    /// Note that values sent from the same C# environment during the same step may replace each other.
    /// </summary>
    Average = 0,

    /// <summary>
    /// Only the latest value is reported.
    /// To avoid conflicts between multiple environments, the ML Agents environment
    /// only keeps stats coming from worker index 0.
    /// </summary>
    MostRecent = 1
}
22+
23+
/// <summary>
24+
/// Add stats (key-value pairs) for reporting. The ML Agents environment will send these to a StatsReporter
25+
/// instance, which means the values will appear in the Tensorboard summary, as well as trainer gauges.
26+
/// Note that stats are only written every summary_frequency steps; See <see cref="StatAggregationMethod"/>
27+
/// for options on how multiple values are handled.
28+
/// </summary>
29+
public class StatsSideChannel : SideChannel
30+
{
31+
const string k_StatsSideChannelDefaultId = "a1d8f7b7-cec8-50f9-b78b-d3e165a78520";
32+
33+
/// <summary>
34+
/// Initializes the side channel with the provided channel ID.
35+
/// The constructor is internal because only one instance is
36+
/// supported at a time, and is created by the Academy.
37+
/// </summary>
38+
internal StatsSideChannel()
39+
{
40+
ChannelId = new Guid(k_StatsSideChannelDefaultId);
41+
}
42+
43+
/// <summary>
44+
/// Add a stat value for reporting. This will appear in the Tensorboard summary and trainer gauges.
45+
/// You can nest stats in Tensorboard with "/".
46+
/// Note that stats are only written to Tensorboard each summary_frequency steps; if a stat is
47+
/// received multiple times, only the most recent version is used.
48+
/// To avoid conflicts between multiple environments, only stats from worker index 0 are used.
49+
/// </summary>
50+
/// <param name="key">The stat name.</param>
51+
/// <param name="value">The stat value. You can nest stats in Tensorboard by using "/". </param>
52+
/// <param name="aggregationMethod">How multiple values should be treated.</param>
53+
public void AddStat(
54+
string key, float value, StatAggregationMethod aggregationMethod = StatAggregationMethod.Average
55+
)
56+
{
57+
using (var msg = new OutgoingMessage())
58+
{
59+
msg.WriteString(key);
60+
msg.WriteFloat32(value);
61+
msg.WriteInt32((int)aggregationMethod);
62+
QueueMessageToSend(msg);
63+
}
64+
}
65+
66+
/// <inheritdoc/>
67+
public override void OnMessageReceived(IncomingMessage msg)
68+
{
69+
throw new UnityAgentsException("StatsSideChannel should never receive messages.");
70+
}
71+
}
72+
}

com.unity.ml-agents/Runtime/SideChannels/StatsSideChannel.cs.meta

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/Using-Tensorboard.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,10 @@ The ML-Agents training program saves the following statistics:
8787
taken between two observations.
8888
8989
* `Losses/Cloning Loss` (BC) - The mean magnitude of the behavioral cloning loss. Corresponds to how well the model imitates the demonstration data.
90+
91+
## Custom Metrics from C#
92+
To get custom metrics from a C# environment into Tensorboard, you can use the StatsSideChannel:
93+
```csharp
94+
var statsSideChannel = Academy.Instance.GetSideChannel<StatsSideChannel>();
95+
statsSideChannel.AddStat("MyMetric", 1.0f);
96+
```
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from mlagents_envs.side_channel import SideChannel, IncomingMessage
2+
import uuid
3+
from typing import Dict, Tuple
4+
from enum import Enum
5+
6+
7+
class StatsAggregationMethod(Enum):
    """Determines how multiple stats within the same summary period are combined."""

    # Values within the summary period are averaged before reporting.
    AVERAGE = 0

    # Only the most recent value is reported.
    MOST_RECENT = 1
15+
16+
class StatsSideChannel(SideChannel):
17+
"""
18+
Side channel that receives (string, float) pairs from the environment, so that they can eventually
19+
be passed to a StatsReporter.
20+
"""
21+
22+
def __init__(self) -> None:
23+
# >>> uuid.uuid5(uuid.NAMESPACE_URL, "com.unity.ml-agents/StatsSideChannel")
24+
# UUID('a1d8f7b7-cec8-50f9-b78b-d3e165a78520')
25+
super().__init__(uuid.UUID("a1d8f7b7-cec8-50f9-b78b-d3e165a78520"))
26+
27+
self.stats: Dict[str, Tuple[float, StatsAggregationMethod]] = {}
28+
29+
def on_message_received(self, msg: IncomingMessage) -> None:
30+
"""
31+
Receive the message from the environment, and save it for later retrieval.
32+
:param msg:
33+
:return:
34+
"""
35+
key = msg.read_string()
36+
val = msg.read_float32()
37+
agg_type = StatsAggregationMethod(msg.read_int32())
38+
39+
self.stats[key] = (val, agg_type)
40+
41+
def get_and_reset_stats(self) -> Dict[str, Tuple[float, StatsAggregationMethod]]:
42+
"""
43+
Returns the current stats, and resets the internal storage of the stats.
44+
:return:
45+
"""
46+
s = self.stats
47+
self.stats = {}
48+
return s

ml-agents/mlagents/trainers/agent_processor.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections import defaultdict, Counter, deque
44

55
from mlagents_envs.base_env import BatchedStepResult, StepResult
6+
from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
67
from mlagents.trainers.trajectory import Trajectory, AgentExperience
78
from mlagents.trainers.policy.tf_policy import TFPolicy
89
from mlagents.trainers.policy import Policy
@@ -267,3 +268,23 @@ def __init__(
267268
self.behavior_id
268269
)
269270
self.publish_trajectory_queue(self.trajectory_queue)
271+
272+
def record_environment_stats(
273+
self, env_stats: Dict[str, Tuple[float, StatsAggregationMethod]], worker_id: int
274+
) -> None:
275+
"""
276+
Pass stats from the environment to the StatsReporter.
277+
Depending on the StatsAggregationMethod, either StatsReporter.add_stat or StatsReporter.set_stat is used.
278+
The worker_id is used to determin whether StatsReporter.set_stat should be used.
279+
:param env_stats:
280+
:param worker_id:
281+
:return:
282+
"""
283+
for stat_name, (val, agg_type) in env_stats.items():
284+
if agg_type == StatsAggregationMethod.AVERAGE:
285+
self.stats_reporter.add_stat(stat_name, val)
286+
elif agg_type == StatsAggregationMethod.MOST_RECENT:
287+
# In order to prevent conflicts between multiple environments,
288+
# only stats from the first environment are recorded.
289+
if worker_id == 0:
290+
self.stats_reporter.set_stat(stat_name, val)

ml-agents/mlagents/trainers/env_manager.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from abc import ABC, abstractmethod
22
import logging
3-
from typing import List, Dict, NamedTuple, Iterable
3+
from typing import List, Dict, NamedTuple, Iterable, Tuple
44
from mlagents_envs.base_env import BatchedStepResult, AgentGroupSpec, AgentGroup
5+
from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
56
from mlagents.trainers.brain import BrainParameters
67
from mlagents.trainers.policy.tf_policy import TFPolicy
78
from mlagents.trainers.agent_processor import AgentManager, AgentManagerQueue
@@ -17,14 +18,15 @@ class EnvironmentStep(NamedTuple):
1718
current_all_step_result: AllStepResult
1819
worker_id: int
1920
brain_name_to_action_info: Dict[AgentGroup, ActionInfo]
21+
environment_stats: Dict[str, Tuple[float, StatsAggregationMethod]]
2022

2123
@property
2224
def name_behavior_ids(self) -> Iterable[AgentGroup]:
2325
return self.current_all_step_result.keys()
2426

2527
@staticmethod
2628
def empty(worker_id: int) -> "EnvironmentStep":
27-
return EnvironmentStep({}, worker_id, {})
29+
return EnvironmentStep({}, worker_id, {}, {})
2830

2931

3032
class EnvManager(ABC):
@@ -108,4 +110,8 @@ def _process_step_infos(self, step_infos: List[EnvironmentStep]) -> int:
108110
name_behavior_id, ActionInfo.empty()
109111
),
110112
)
113+
114+
self.agent_managers[name_behavior_id].record_environment_stats(
115+
step_info.environment_stats, step_info.worker_id
116+
)
111117
return len(step_infos)

ml-agents/mlagents/trainers/simple_env_manager.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ def _step(self) -> List[EnvironmentStep]:
3131
self.env.step()
3232
all_step_result = self._generate_all_results()
3333

34-
step_info = EnvironmentStep(all_step_result, 0, self.previous_all_action_info)
34+
step_info = EnvironmentStep(
35+
all_step_result, 0, self.previous_all_action_info, {}
36+
)
3537
self.previous_step = step_info
3638
return [step_info]
3739

@@ -43,7 +45,7 @@ def _reset_env(
4345
self.shared_float_properties.set_property(k, v)
4446
self.env.reset()
4547
all_step_result = self._generate_all_results()
46-
self.previous_step = EnvironmentStep(all_step_result, 0, {})
48+
self.previous_step = EnvironmentStep(all_step_result, 0, {}, {})
4749
return [self.previous_step]
4850

4951
@property

0 commit comments

Comments
 (0)