1
1
import logging
2
2
from typing import Dict , NamedTuple , List , Any , Optional , Callable , Set
3
3
import cloudpickle
4
+ import enum
4
5
5
6
from mlagents_envs .environment import UnityEnvironment
6
7
from mlagents_envs .exception import (
33
34
logger = logging .getLogger ("mlagents.trainers" )
34
35
35
36
36
- class EnvironmentCommand (NamedTuple ):
37
- name : str
37
+ class EnvironmentCommand (enum .Enum ):
38
+ STEP = 1
39
+ EXTERNAL_BRAINS = 2
40
+ GET_PROPERTIES = 3
41
+ RESET = 4
42
+ CLOSE = 5
43
+ ENV_EXITED = 6
44
+
45
+
46
+ class EnvironmentRequest (NamedTuple ):
47
+ cmd : EnvironmentCommand
38
48
payload : Any = None
39
49
40
50
41
51
class EnvironmentResponse (NamedTuple ):
42
- name : str
52
+ cmd : EnvironmentCommand
43
53
worker_id : int
44
54
payload : Any
45
55
@@ -58,17 +68,17 @@ def __init__(self, process: Process, worker_id: int, conn: Connection):
58
68
self .previous_all_action_info : Dict [str , ActionInfo ] = {}
59
69
self .waiting = False
60
70
61
- def send (self , name : str , payload : Any = None ) -> None :
71
+ def send (self , cmd : EnvironmentCommand , payload : Any = None ) -> None :
62
72
try :
63
- cmd = EnvironmentCommand ( name , payload )
64
- self .conn .send (cmd )
73
+ req = EnvironmentRequest ( cmd , payload )
74
+ self .conn .send (req )
65
75
except (BrokenPipeError , EOFError ):
66
76
raise UnityCommunicationException ("UnityEnvironment worker: send failed." )
67
77
68
78
def recv (self ) -> EnvironmentResponse :
69
79
try :
70
80
response : EnvironmentResponse = self .conn .recv ()
71
- if response .name == "env_close" :
81
+ if response .cmd == EnvironmentCommand . ENV_EXITED :
72
82
env_exception : Exception = response .payload
73
83
raise env_exception
74
84
return response
@@ -77,7 +87,7 @@ def recv(self) -> EnvironmentResponse:
77
87
78
88
def close (self ):
79
89
try :
80
- self .conn .send (EnvironmentCommand ( "close" ))
90
+ self .conn .send (EnvironmentRequest ( EnvironmentCommand . CLOSE ))
81
91
except (BrokenPipeError , EOFError ):
82
92
logger .debug (
83
93
f"UnityEnvWorker { self .worker_id } got exception trying to close."
@@ -102,7 +112,7 @@ def worker(
102
112
engine_configuration_channel .set_configuration (engine_configuration )
103
113
env : BaseEnv = None
104
114
105
- def _send_response (cmd_name , payload ) :
115
+ def _send_response (cmd_name : EnvironmentCommand , payload : Any ) -> None :
106
116
parent_conn .send (EnvironmentResponse (cmd_name , worker_id , payload ))
107
117
108
118
def _generate_all_results () -> AllStepResult :
@@ -124,9 +134,9 @@ def external_brains():
124
134
worker_id , [shared_float_properties , engine_configuration_channel ]
125
135
)
126
136
while True :
127
- cmd : EnvironmentCommand = parent_conn .recv ()
128
- if cmd . name == "step" :
129
- all_action_info = cmd .payload
137
+ req : EnvironmentRequest = parent_conn .recv ()
138
+ if req . cmd == EnvironmentCommand . STEP :
139
+ all_action_info = req .payload
130
140
for brain_name , action_info in all_action_info .items ():
131
141
if len (action_info .action ) != 0 :
132
142
env .set_actions (brain_name , action_info .action )
@@ -138,20 +148,24 @@ def external_brains():
138
148
# the data transferred.
139
149
# TODO get gauges from the workers and merge them in the main process too.
140
150
step_response = StepResponse (all_step_result , get_timer_root ())
141
- step_queue .put (EnvironmentResponse ("step" , worker_id , step_response ))
151
+ step_queue .put (
152
+ EnvironmentResponse (
153
+ EnvironmentCommand .STEP , worker_id , step_response
154
+ )
155
+ )
142
156
reset_timers ()
143
- elif cmd . name == "external_brains" :
144
- _send_response ("external_brains" , external_brains ())
145
- elif cmd . name == "get_properties" :
157
+ elif req . cmd == EnvironmentCommand . EXTERNAL_BRAINS :
158
+ _send_response (EnvironmentCommand . EXTERNAL_BRAINS , external_brains ())
159
+ elif req . cmd == EnvironmentCommand . GET_PROPERTIES :
146
160
reset_params = shared_float_properties .get_property_dict_copy ()
147
- _send_response ("get_properties" , reset_params )
148
- elif cmd . name == "reset" :
149
- for k , v in cmd .payload .items ():
161
+ _send_response (EnvironmentCommand . GET_PROPERTIES , reset_params )
162
+ elif req . cmd == EnvironmentCommand . RESET :
163
+ for k , v in req .payload .items ():
150
164
shared_float_properties .set_property (k , v )
151
165
env .reset ()
152
166
all_step_result = _generate_all_results ()
153
- _send_response ("reset" , all_step_result )
154
- elif cmd . name == "close" :
167
+ _send_response (EnvironmentCommand . RESET , all_step_result )
168
+ elif req . cmd == EnvironmentCommand . CLOSE :
155
169
break
156
170
except (
157
171
KeyboardInterrupt ,
@@ -160,8 +174,10 @@ def external_brains():
160
174
UnityEnvironmentException ,
161
175
) as ex :
162
176
logger .info (f"UnityEnvironment worker { worker_id } : environment stopping." )
163
- step_queue .put (EnvironmentResponse ("env_close" , worker_id , ex ))
164
- _send_response ("env_close" , ex )
177
+ step_queue .put (
178
+ EnvironmentResponse (EnvironmentCommand .ENV_EXITED , worker_id , ex )
179
+ )
180
+ _send_response (EnvironmentCommand .ENV_EXITED , ex )
165
181
finally :
166
182
# If this worker has put an item in the step queue that hasn't been processed by the EnvManager, the process
167
183
# will hang until the item is processed. We avoid this behavior by using Queue.cancel_join_thread()
@@ -222,7 +238,7 @@ def _queue_steps(self) -> None:
222
238
if not env_worker .waiting :
223
239
env_action_info = self ._take_step (env_worker .previous_step )
224
240
env_worker .previous_all_action_info = env_action_info
225
- env_worker .send ("step" , env_action_info )
241
+ env_worker .send (EnvironmentCommand . STEP , env_action_info )
226
242
env_worker .waiting = True
227
243
228
244
def _step (self ) -> List [EnvironmentStep ]:
@@ -236,8 +252,8 @@ def _step(self) -> List[EnvironmentStep]:
236
252
while len (worker_steps ) < 1 :
237
253
try :
238
254
while True :
239
- step = self .step_queue .get_nowait ()
240
- if step .name == "env_close" :
255
+ step : EnvironmentResponse = self .step_queue .get_nowait ()
256
+ if step .cmd == EnvironmentCommand . ENV_EXITED :
241
257
env_exception : Exception = step .payload
242
258
raise env_exception
243
259
self .env_workers [step .worker_id ].waiting = False
@@ -257,20 +273,20 @@ def _reset_env(self, config: Optional[Dict] = None) -> List[EnvironmentStep]:
257
273
self .env_workers [step .worker_id ].waiting = False
258
274
# First enqueue reset commands for all workers so that they reset in parallel
259
275
for ew in self .env_workers :
260
- ew .send ("reset" , config )
276
+ ew .send (EnvironmentCommand . RESET , config )
261
277
# Next (synchronously) collect the reset observations from each worker in sequence
262
278
for ew in self .env_workers :
263
279
ew .previous_step = EnvironmentStep (ew .recv ().payload , ew .worker_id , {})
264
280
return list (map (lambda ew : ew .previous_step , self .env_workers ))
265
281
266
282
@property
267
283
def external_brains (self ) -> Dict [AgentGroup , BrainParameters ]:
268
- self .env_workers [0 ].send ("external_brains" )
284
+ self .env_workers [0 ].send (EnvironmentCommand . EXTERNAL_BRAINS )
269
285
return self .env_workers [0 ].recv ().payload
270
286
271
287
@property
272
288
def get_properties (self ) -> Dict [AgentGroup , float ]:
273
- self .env_workers [0 ].send ("get_properties" )
289
+ self .env_workers [0 ].send (EnvironmentCommand . GET_PROPERTIES )
274
290
return self .env_workers [0 ].recv ().payload
275
291
276
292
def close (self ) -> None :
0 commit comments