Skip to content

Commit cda8843

Browse files
author
Jonathan Harper
committed
Add enum for environment commands
1 parent 020c200 commit cda8843

File tree

2 files changed

+56
-35
lines changed

2 files changed

+56
-35
lines changed

ml-agents/mlagents/trainers/subprocess_env_manager.py

Lines changed: 45 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import logging
22
from typing import Dict, NamedTuple, List, Any, Optional, Callable, Set
33
import cloudpickle
4+
import enum
45

56
from mlagents_envs.environment import UnityEnvironment
67
from mlagents_envs.exception import (
@@ -33,13 +34,22 @@
3334
logger = logging.getLogger("mlagents.trainers")
3435

3536

36-
class EnvironmentCommand(NamedTuple):
37-
name: str
37+
class EnvironmentCommand(enum.Enum):
38+
STEP = 1
39+
EXTERNAL_BRAINS = 2
40+
GET_PROPERTIES = 3
41+
RESET = 4
42+
CLOSE = 5
43+
ENV_EXITED = 6
44+
45+
46+
class EnvironmentRequest(NamedTuple):
47+
cmd: EnvironmentCommand
3848
payload: Any = None
3949

4050

4151
class EnvironmentResponse(NamedTuple):
42-
name: str
52+
cmd: EnvironmentCommand
4353
worker_id: int
4454
payload: Any
4555

@@ -58,17 +68,17 @@ def __init__(self, process: Process, worker_id: int, conn: Connection):
5868
self.previous_all_action_info: Dict[str, ActionInfo] = {}
5969
self.waiting = False
6070

61-
def send(self, name: str, payload: Any = None) -> None:
71+
def send(self, cmd: EnvironmentCommand, payload: Any = None) -> None:
6272
try:
63-
cmd = EnvironmentCommand(name, payload)
64-
self.conn.send(cmd)
73+
req = EnvironmentRequest(cmd, payload)
74+
self.conn.send(req)
6575
except (BrokenPipeError, EOFError):
6676
raise UnityCommunicationException("UnityEnvironment worker: send failed.")
6777

6878
def recv(self) -> EnvironmentResponse:
6979
try:
7080
response: EnvironmentResponse = self.conn.recv()
71-
if response.name == "env_close":
81+
if response.cmd == EnvironmentCommand.ENV_EXITED:
7282
env_exception: Exception = response.payload
7383
raise env_exception
7484
return response
@@ -77,7 +87,7 @@ def recv(self) -> EnvironmentResponse:
7787

7888
def close(self):
7989
try:
80-
self.conn.send(EnvironmentCommand("close"))
90+
self.conn.send(EnvironmentRequest(EnvironmentCommand.CLOSE))
8191
except (BrokenPipeError, EOFError):
8292
logger.debug(
8393
f"UnityEnvWorker {self.worker_id} got exception trying to close."
@@ -102,7 +112,7 @@ def worker(
102112
engine_configuration_channel.set_configuration(engine_configuration)
103113
env: BaseEnv = None
104114

105-
def _send_response(cmd_name, payload):
115+
def _send_response(cmd_name: EnvironmentCommand, payload: Any) -> None:
106116
parent_conn.send(EnvironmentResponse(cmd_name, worker_id, payload))
107117

108118
def _generate_all_results() -> AllStepResult:
@@ -124,9 +134,9 @@ def external_brains():
124134
worker_id, [shared_float_properties, engine_configuration_channel]
125135
)
126136
while True:
127-
cmd: EnvironmentCommand = parent_conn.recv()
128-
if cmd.name == "step":
129-
all_action_info = cmd.payload
137+
req: EnvironmentRequest = parent_conn.recv()
138+
if req.cmd == EnvironmentCommand.STEP:
139+
all_action_info = req.payload
130140
for brain_name, action_info in all_action_info.items():
131141
if len(action_info.action) != 0:
132142
env.set_actions(brain_name, action_info.action)
@@ -138,20 +148,24 @@ def external_brains():
138148
# the data transferred.
139149
# TODO get gauges from the workers and merge them in the main process too.
140150
step_response = StepResponse(all_step_result, get_timer_root())
141-
step_queue.put(EnvironmentResponse("step", worker_id, step_response))
151+
step_queue.put(
152+
EnvironmentResponse(
153+
EnvironmentCommand.STEP, worker_id, step_response
154+
)
155+
)
142156
reset_timers()
143-
elif cmd.name == "external_brains":
144-
_send_response("external_brains", external_brains())
145-
elif cmd.name == "get_properties":
157+
elif req.cmd == EnvironmentCommand.EXTERNAL_BRAINS:
158+
_send_response(EnvironmentCommand.EXTERNAL_BRAINS, external_brains())
159+
elif req.cmd == EnvironmentCommand.GET_PROPERTIES:
146160
reset_params = shared_float_properties.get_property_dict_copy()
147-
_send_response("get_properties", reset_params)
148-
elif cmd.name == "reset":
149-
for k, v in cmd.payload.items():
161+
_send_response(EnvironmentCommand.GET_PROPERTIES, reset_params)
162+
elif req.cmd == EnvironmentCommand.RESET:
163+
for k, v in req.payload.items():
150164
shared_float_properties.set_property(k, v)
151165
env.reset()
152166
all_step_result = _generate_all_results()
153-
_send_response("reset", all_step_result)
154-
elif cmd.name == "close":
167+
_send_response(EnvironmentCommand.RESET, all_step_result)
168+
elif req.cmd == EnvironmentCommand.CLOSE:
155169
break
156170
except (
157171
KeyboardInterrupt,
@@ -160,8 +174,10 @@ def external_brains():
160174
UnityEnvironmentException,
161175
) as ex:
162176
logger.info(f"UnityEnvironment worker {worker_id}: environment stopping.")
163-
step_queue.put(EnvironmentResponse("env_close", worker_id, ex))
164-
_send_response("env_close", ex)
177+
step_queue.put(
178+
EnvironmentResponse(EnvironmentCommand.ENV_EXITED, worker_id, ex)
179+
)
180+
_send_response(EnvironmentCommand.ENV_EXITED, ex)
165181
finally:
166182
# If this worker has put an item in the step queue that hasn't been processed by the EnvManager, the process
167183
# will hang until the item is processed. We avoid this behavior by using Queue.cancel_join_thread()
@@ -222,7 +238,7 @@ def _queue_steps(self) -> None:
222238
if not env_worker.waiting:
223239
env_action_info = self._take_step(env_worker.previous_step)
224240
env_worker.previous_all_action_info = env_action_info
225-
env_worker.send("step", env_action_info)
241+
env_worker.send(EnvironmentCommand.STEP, env_action_info)
226242
env_worker.waiting = True
227243

228244
def _step(self) -> List[EnvironmentStep]:
@@ -236,8 +252,8 @@ def _step(self) -> List[EnvironmentStep]:
236252
while len(worker_steps) < 1:
237253
try:
238254
while True:
239-
step = self.step_queue.get_nowait()
240-
if step.name == "env_close":
255+
step: EnvironmentResponse = self.step_queue.get_nowait()
256+
if step.cmd == EnvironmentCommand.ENV_EXITED:
241257
env_exception: Exception = step.payload
242258
raise env_exception
243259
self.env_workers[step.worker_id].waiting = False
@@ -257,20 +273,20 @@ def _reset_env(self, config: Optional[Dict] = None) -> List[EnvironmentStep]:
257273
self.env_workers[step.worker_id].waiting = False
258274
# First enqueue reset commands for all workers so that they reset in parallel
259275
for ew in self.env_workers:
260-
ew.send("reset", config)
276+
ew.send(EnvironmentCommand.RESET, config)
261277
# Next (synchronously) collect the reset observations from each worker in sequence
262278
for ew in self.env_workers:
263279
ew.previous_step = EnvironmentStep(ew.recv().payload, ew.worker_id, {})
264280
return list(map(lambda ew: ew.previous_step, self.env_workers))
265281

266282
@property
267283
def external_brains(self) -> Dict[AgentGroup, BrainParameters]:
268-
self.env_workers[0].send("external_brains")
284+
self.env_workers[0].send(EnvironmentCommand.EXTERNAL_BRAINS)
269285
return self.env_workers[0].recv().payload
270286

271287
@property
272288
def get_properties(self) -> Dict[AgentGroup, float]:
273-
self.env_workers[0].send("get_properties")
289+
self.env_workers[0].send(EnvironmentCommand.GET_PROPERTIES)
274290
return self.env_workers[0].recv().payload
275291

276292
def close(self) -> None:

ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
SubprocessEnvManager,
99
EnvironmentResponse,
1010
StepResponse,
11+
EnvironmentCommand,
1112
)
1213
from mlagents.trainers.env_manager import EnvironmentStep
1314
from mlagents_envs.base_env import BaseEnv
@@ -38,7 +39,9 @@ def __init__(self, worker_id, resp=None):
3839

3940

4041
def create_worker_mock(worker_id, step_queue, env_factor, engine_c):
41-
return MockEnvWorker(worker_id, EnvironmentResponse("reset", worker_id, worker_id))
42+
return MockEnvWorker(
43+
worker_id, EnvironmentResponse(EnvironmentCommand.RESET, worker_id, worker_id)
44+
)
4245

4346

4447
class SubprocessEnvManagerTest(unittest.TestCase):
@@ -71,7 +74,9 @@ def test_reset_passes_reset_params(self, mock_create_worker):
7174
)
7275
params = {"test": "params"}
7376
manager._reset_env(params)
74-
manager.env_workers[0].send.assert_called_with("reset", (params))
77+
manager.env_workers[0].send.assert_called_with(
78+
EnvironmentCommand.RESET, (params)
79+
)
7580

7681
@mock.patch(
7782
"mlagents.trainers.subprocess_env_manager.SubprocessEnvManager.create_worker"
@@ -85,7 +90,7 @@ def test_reset_collects_results_from_all_envs(self, mock_create_worker):
8590
params = {"test": "params"}
8691
res = manager._reset_env(params)
8792
for i, env in enumerate(manager.env_workers):
88-
env.send.assert_called_with("reset", (params))
93+
env.send.assert_called_with(EnvironmentCommand.RESET, (params))
8994
env.recv.assert_called()
9095
# Check that the "last steps" are set to the value returned for each step
9196
self.assertEqual(
@@ -103,8 +108,8 @@ def test_step_takes_steps_for_all_non_waiting_envs(self, mock_create_worker):
103108
)
104109
manager.step_queue = Mock()
105110
manager.step_queue.get_nowait.side_effect = [
106-
EnvironmentResponse("step", 0, StepResponse(0, None)),
107-
EnvironmentResponse("step", 1, StepResponse(1, None)),
111+
EnvironmentResponse(EnvironmentCommand.STEP, 0, StepResponse(0, None)),
112+
EnvironmentResponse(EnvironmentCommand.STEP, 1, StepResponse(1, None)),
108113
EmptyQueue(),
109114
]
110115
step_mock = Mock()
@@ -117,7 +122,7 @@ def test_step_takes_steps_for_all_non_waiting_envs(self, mock_create_worker):
117122
res = manager._step()
118123
for i, env in enumerate(manager.env_workers):
119124
if i < 2:
120-
env.send.assert_called_with("step", step_mock)
125+
env.send.assert_called_with(EnvironmentCommand.STEP, step_mock)
121126
manager.step_queue.get_nowait.assert_called()
122127
# Check that the "last steps" are set to the value returned for each step
123128
self.assertEqual(

0 commit comments

Comments
 (0)