Skip to content

Commit f81e835

Browse files
author
Marwan Mattar
committed
Back-up commit.
1 parent 194cea4 commit f81e835

File tree

4 files changed

+92
-13
lines changed

4 files changed

+92
-13
lines changed

com.unity.ml-agents/Runtime/SideChannels/EnvironmentParametersChannel.cs

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@
33

44
namespace MLAgents.SideChannels
55
{
6+
/// <summary>
7+
/// Lists the different data types supported.
8+
/// </summary>
9+
internal enum EnvironmentDataTypes
10+
{
11+
Float = 0
12+
}
13+
614
/// <summary>
715
/// A side channel that manages the environment parameter values from Python. Currently
816
/// limited to parameters of type float.
@@ -28,13 +36,21 @@ internal EnvironmentParametersChannel()
2836
public override void OnMessageReceived(IncomingMessage msg)
2937
{
3038
var key = msg.ReadString();
31-
var value = msg.ReadFloat32();
39+
var type = msg.ReadInt32();
40+
if ((int)EnvironmentDataTypes.Float == type)
41+
{
42+
var value = msg.ReadFloat32();
3243

33-
m_Parameters[key] = value;
44+
m_Parameters[key] = value;
3445

35-
Action<float> action;
36-
m_RegisteredActions.TryGetValue(key, out action);
37-
action?.Invoke(value);
46+
Action<float> action;
47+
m_RegisteredActions.TryGetValue(key, out action);
48+
action?.Invoke(value);
49+
}
50+
else
51+
{
52+
throw new UnityAgentsException("EnvironmentParametersChannel only supports floats.");
53+
}
3854
}
3955

4056
/// <summary>
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage
2+
import uuid
3+
from typing import Dict, Optional, List
4+
5+
6+
class EnvironmentParametersChannel(SideChannel):
7+
"""
8+
This is the SideChannel for float properties shared with Unity.
9+
You can modify the float properties of an environment with the commands
10+
set_property, get_property and list_properties.
11+
"""
12+
13+
def __init__(self, channel_id: uuid.UUID = None) -> None:
14+
self._float_properties: Dict[str, float] = {}
15+
if channel_id is None:
16+
channel_id = uuid.UUID(("534c891e-810f-11ea-a9d0-822485860400"))
17+
super().__init__(channel_id)
18+
19+
def on_message_received(self, msg: IncomingMessage) -> None:
20+
"""
21+
Is called by the environment to the side channel. Can be called
22+
multiple times per step if multiple messages are meant for that
23+
SideChannel.
24+
Note that Python should never receive an environment parameters from
25+
Unity
26+
"""
27+
raise UnityCommunicationException(
28+
"The EnvironmentParametersChannel received a message from Unity, "
29+
+ "this should not have happend."
30+
)
31+
32+
def set_float_property(self, key: str, value: float) -> None:
33+
"""
34+
Sets a property in the Unity Environment.
35+
:param key: The string identifier of the property.
36+
:param value: The float value of the property.
37+
"""
38+
self._float_properties[key] = value
39+
msg = OutgoingMessage()
40+
msg.write_string(key)
41+
msg.write_float32(value)
42+
super().queue_message_to_send(msg)
43+
44+
def get_float_property(self, key: str) -> Optional[float]:
45+
"""
46+
Gets a property in the Unity Environment. If the property was not
47+
found, will return None.
48+
:param key: The string identifier of the property.
49+
:return: The float value of the property or None.
50+
"""
51+
return self._float_properties.get(key)
52+
53+
def list_properties(self) -> List[str]:
54+
"""
55+
Returns a list of all the string identifiers of the properties
56+
currently present in the Unity Environment.
57+
"""
58+
return list(self._float_properties.keys())
59+
60+
def get_property_dict_copy(self) -> Dict[str, float]:
61+
"""
62+
Returns a copy of the float properties.
63+
:return:
64+
"""
65+
return dict(self._float_properties)

ml-agents/mlagents/trainers/simple_env_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mlagents_envs.timers import timed
66
from mlagents.trainers.action_info import ActionInfo
77
from mlagents.trainers.brain import BrainParameters
8-
from mlagents_envs.side_channel.float_properties_channel import FloatPropertiesChannel
8+
from mlagents_envs.side_channel.environment_parameters_channel import EnvironmentParameters
99
from mlagents.trainers.brain_conversion_utils import behavior_spec_to_brain_parameters
1010

1111

@@ -15,9 +15,9 @@ class SimpleEnvManager(EnvManager):
1515
This is generally only useful for testing; see SubprocessEnvManager for a production-quality implementation.
1616
"""
1717

18-
def __init__(self, env: BaseEnv, float_prop_channel: FloatPropertiesChannel):
18+
def __init__(self, env: BaseEnv, environment_parameters_channel: EnvironmentParameters):
1919
super().__init__()
20-
self.shared_float_properties = float_prop_channel
20+
self.environment_parameters_channel = environment_parameters_channel
2121
self.env = env
2222
self.previous_step: EnvironmentStep = EnvironmentStep.empty(0)
2323
self.previous_all_action_info: Dict[str, ActionInfo] = {}

ml-agents/mlagents/trainers/subprocess_env_manager.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def worker(
113113
env_factory: Callable[
114114
[int, List[SideChannel]], UnityEnvironment
115115
] = cloudpickle.loads(pickled_env_factory)
116-
shared_float_properties = FloatPropertiesChannel()
116+
environment_parameters_channel = EnvironmentParametersChannel()
117117
engine_configuration_channel = EngineConfigurationChannel()
118118
engine_configuration_channel.set_configuration(engine_configuration)
119119
stats_channel = StatsSideChannel()
@@ -139,7 +139,7 @@ def external_brains():
139139
try:
140140
env = env_factory(
141141
worker_id,
142-
[shared_float_properties, engine_configuration_channel, stats_channel],
142+
[environment_parameters_channel, engine_configuration_channel, stats_channel],
143143
)
144144
while True:
145145
req: EnvironmentRequest = parent_conn.recv()
@@ -168,11 +168,9 @@ def external_brains():
168168
elif req.cmd == EnvironmentCommand.EXTERNAL_BRAINS:
169169
_send_response(EnvironmentCommand.EXTERNAL_BRAINS, external_brains())
170170
elif req.cmd == EnvironmentCommand.GET_PROPERTIES:
171-
reset_params = shared_float_properties.get_property_dict_copy()
171+
reset_params = environment_parameters_channel.get_property_dict_copy()
172172
_send_response(EnvironmentCommand.GET_PROPERTIES, reset_params)
173173
elif req.cmd == EnvironmentCommand.RESET:
174-
for k, v in req.payload.items():
175-
shared_float_properties.set_property(k, v)
176174
env.reset()
177175
all_step_result = _generate_all_results()
178176
_send_response(EnvironmentCommand.RESET, all_step_result)

0 commit comments

Comments
 (0)