Skip to content

Commit 30bcb01

Browse files
author
Chris Elion
authored
[MLA-1712] Make UnityEnvironment fail fast if the env crashes (#4880)
1 parent 2f71f72 commit 30bcb01

File tree

11 files changed

+168
-42
lines changed

11 files changed

+168
-42
lines changed

Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Linq;
44
using Unity.MLAgents;
55
using Unity.MLAgents.Actuators;
6+
using UnityEngine.Rendering;
67
using UnityEngine.Serialization;
78

89
public class GridAgent : Agent
@@ -150,7 +151,7 @@ public void FixedUpdate()
150151

151152
void WaitTimeInference()
152153
{
153-
if (renderCamera != null)
154+
if (renderCamera != null && SystemInfo.graphicsDeviceType != GraphicsDeviceType.Null)
154155
{
155156
renderCamera.Render();
156157
}

Project/ProjectSettings/UnityConnectSettings.asset

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
UnityConnectSettings:
55
m_ObjectHideFlags: 0
66
serializedVersion: 1
7-
m_Enabled: 1
7+
m_Enabled: 0
88
m_TestMode: 0
99
m_EventOldUrl: https://api.uca.cloud.unity3d.com/v1/events
1010
m_EventUrl: https://cdp.cloud.unity3d.com/v1/events

com.unity.ml-agents/CHANGELOG.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,22 @@ removed when training with a player. The Editor still requires it to be clamped
2424
Updated the Basic example and the Match3 Example to use Actuators.
2525
Changed the namespace and file names of classes in com.unity.ml-agents.extensions. (#4849)
2626

27-
2827
#### ml-agents / ml-agents-envs / gym-unity (Python)
2928

3029
### Bug Fixes
3130
#### com.unity.ml-agents (C#)
3231
- Fix a compile warning about using an obsolete enum in `GrpcExtensions.cs`. (#4812)
32+
- CameraSensor now logs an error if the GraphicsDevice is null. (#4880)
3333
#### ml-agents / ml-agents-envs / gym-unity (Python)
3434
- Fixed a bug that would cause an exception when `RunOptions` was deserialized via `pickle`. (#4842)
3535
- Fixed a bug that can cause a crash if a behavior can appear during training in multi-environment training. (#4872)
3636
- Fixed the computation of entropy for continuous actions. (#4869)
37+
- Fixed a bug that would cause `UnityEnvironment` to wait the full timeout
38+
period and report a misleading error message if the executable crashed
39+
without closing the connection. It now periodically checks the process status
40+
while waiting for a connection, and raises a better error message if it crashes. (#4880)
41+
- Passing a `-logfile` option in the `--env-args` option to `mlagents-learn` is
42+
no longer overwritten. (#4880)
3743

3844

3945
## [1.7.2-preview] - 2020-12-22

com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using UnityEngine;
2+
using UnityEngine.Rendering;
23

34
namespace Unity.MLAgents.Sensors
45
{
@@ -128,6 +129,11 @@ public SensorCompressionType GetCompressionType()
128129
/// <returns name="texture2D">Texture2D to render to.</returns>
129130
public static Texture2D ObservationToTexture(Camera obsCamera, int width, int height)
130131
{
132+
if (SystemInfo.graphicsDeviceType == GraphicsDeviceType.Null)
133+
{
134+
Debug.LogError("GraphicsDeviceType is Null. This will likely crash when trying to render.");
135+
}
136+
131137
var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false);
132138
var oldRec = obsCamera.rect;
133139
obsCamera.rect = new Rect(0f, 0f, 1f, 1f);

ml-agents-envs/mlagents_envs/communicator.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1-
from typing import Optional
1+
from typing import Callable, Optional
22
from mlagents_envs.communicator_objects.unity_output_pb2 import UnityOutputProto
33
from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto
44

55

6+
# Function to call while waiting for a connection timeout.
7+
# This should raise an exception if it needs to break from waiting for the timeout.
8+
PollCallback = Callable[[], None]
9+
10+
611
class Communicator:
712
def __init__(self, worker_id=0, base_port=5005):
813
"""
@@ -12,17 +17,23 @@ def __init__(self, worker_id=0, base_port=5005):
1217
:int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this.
1318
"""
1419

15-
def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
20+
def initialize(
21+
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
22+
) -> UnityOutputProto:
1623
"""
1724
Used to exchange initialization parameters between Python and the Environment
1825
:param inputs: The initialization input that will be sent to the environment.
26+
:param poll_callback: Optional callback to be used while polling the connection.
1927
:return: UnityOutput: The initialization output sent by Unity
2028
"""
2129

22-
def exchange(self, inputs: UnityInputProto) -> Optional[UnityOutputProto]:
30+
def exchange(
31+
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
32+
) -> Optional[UnityOutputProto]:
2333
"""
2434
Used to send an input and receive an output from the Environment
2535
:param inputs: The UnityInput that needs to be sent the Environment
36+
:param poll_callback: Optional callback to be used while polling the connection.
2637
:return: The UnityOutputs generated by the Environment
2738
"""
2839

ml-agents-envs/mlagents_envs/env_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from mlagents_envs.exception import UnityEnvironmentException
88

99

10+
logger = get_logger(__name__)
11+
12+
1013
def get_platform():
1114
"""
1215
returns the platform of the operating system : linux, darwin or win32
@@ -27,7 +30,7 @@ def validate_environment_path(env_path: str) -> Optional[str]:
2730
.replace(".x86", "")
2831
)
2932
true_filename = os.path.basename(os.path.normpath(env_path))
30-
get_logger(__name__).debug(f"The true file name is {true_filename}")
33+
logger.debug(f"The true file name is {true_filename}")
3134

3235
if not (glob.glob(env_path) or glob.glob(env_path + ".*")):
3336
return None
@@ -99,7 +102,8 @@ def launch_executable(file_name: str, args: List[str]) -> subprocess.Popen:
99102
f"Couldn't launch the {file_name} environment. Provided filename does not match any environments."
100103
)
101104
else:
102-
get_logger(__name__).debug(f"This is the launch string {launch_string}")
105+
logger.debug(f"The launch string is {launch_string}")
106+
logger.debug(f"Running with args {args}")
103107
# Launch Unity environment
104108
subprocess_args = [launch_string] + args
105109
try:

ml-agents-envs/mlagents_envs/environment.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def __init__(
177177
# If true, this means the environment was successfully loaded
178178
self._loaded = False
179179
# The process that is started. If None, no process was started
180-
self._proc1 = None
180+
self._process: Optional[subprocess.Popen] = None
181181
self._timeout_wait: int = timeout_wait
182182
self._communicator = self._get_communicator(worker_id, base_port, timeout_wait)
183183
self._worker_id = worker_id
@@ -194,7 +194,7 @@ def __init__(
194194
)
195195
if file_name is not None:
196196
try:
197-
self._proc1 = env_utils.launch_executable(
197+
self._process = env_utils.launch_executable(
198198
file_name, self._executable_args()
199199
)
200200
except UnityEnvironmentException:
@@ -249,7 +249,11 @@ def _executable_args(self) -> List[str]:
249249
if self._no_graphics:
250250
args += ["-nographics", "-batchmode"]
251251
args += [UnityEnvironment._PORT_COMMAND_LINE_ARG, str(self._port)]
252-
if self._log_folder:
252+
253+
# If the logfile arg isn't already set in the env args,
254+
# try to set it to an output directory
255+
logfile_set = "-logfile" in (arg.lower() for arg in self._additional_args)
256+
if self._log_folder and not logfile_set:
253257
log_file_path = os.path.join(
254258
self._log_folder, f"Player-{self._worker_id}.log"
255259
)
@@ -289,7 +293,9 @@ def _update_state(self, output: UnityRLOutputProto) -> None:
289293

290294
def reset(self) -> None:
291295
if self._loaded:
292-
outputs = self._communicator.exchange(self._generate_reset_input())
296+
outputs = self._communicator.exchange(
297+
self._generate_reset_input(), self._poll_process
298+
)
293299
if outputs is None:
294300
raise UnityCommunicatorStoppedException("Communicator has exited.")
295301
self._update_behavior_specs(outputs)
@@ -317,7 +323,7 @@ def step(self) -> None:
317323
].action_spec.empty_action(n_agents)
318324
step_input = self._generate_step_input(self._env_actions)
319325
with hierarchical_timer("communicator.exchange"):
320-
outputs = self._communicator.exchange(step_input)
326+
outputs = self._communicator.exchange(step_input, self._poll_process)
321327
if outputs is None:
322328
raise UnityCommunicatorStoppedException("Communicator has exited.")
323329
self._update_behavior_specs(outputs)
@@ -377,6 +383,18 @@ def get_steps(
377383
self._assert_behavior_exists(behavior_name)
378384
return self._env_state[behavior_name]
379385

386+
def _poll_process(self) -> None:
387+
"""
388+
Check the status of the subprocess. If it has exited, raise a UnityEnvironmentException
389+
:return: None
390+
"""
391+
if not self._process:
392+
return
393+
poll_res = self._process.poll()
394+
if poll_res is not None:
395+
exc_msg = self._returncode_to_env_message(self._process.returncode)
396+
raise UnityEnvironmentException(exc_msg)
397+
380398
def close(self):
381399
"""
382400
Sends a shutdown signal to the unity environment, and closes the socket connection.
@@ -397,19 +415,16 @@ def _close(self, timeout: Optional[int] = None) -> None:
397415
timeout = self._timeout_wait
398416
self._loaded = False
399417
self._communicator.close()
400-
if self._proc1 is not None:
418+
if self._process is not None:
401419
# Wait a bit for the process to shutdown, but kill it if it takes too long
402420
try:
403-
self._proc1.wait(timeout=timeout)
404-
signal_name = self._returncode_to_signal_name(self._proc1.returncode)
405-
signal_name = f" ({signal_name})" if signal_name else ""
406-
return_info = f"Environment shut down with return code {self._proc1.returncode}{signal_name}."
407-
logger.info(return_info)
421+
self._process.wait(timeout=timeout)
422+
logger.info(self._returncode_to_env_message(self._process.returncode))
408423
except subprocess.TimeoutExpired:
409424
logger.info("Environment timed out shutting down. Killing...")
410-
self._proc1.kill()
425+
self._process.kill()
411426
# Set to None so we don't try to close multiple times.
412-
self._proc1 = None
427+
self._process = None
413428

414429
@timed
415430
def _generate_step_input(
@@ -452,7 +467,7 @@ def _send_academy_parameters(
452467
) -> UnityOutputProto:
453468
inputs = UnityInputProto()
454469
inputs.rl_initialization_input.CopyFrom(init_parameters)
455-
return self._communicator.initialize(inputs)
470+
return self._communicator.initialize(inputs, self._poll_process)
456471

457472
@staticmethod
458473
def _wrap_unity_input(rl_input: UnityRLInputProto) -> UnityInputProto:
@@ -473,3 +488,9 @@ def _returncode_to_signal_name(returncode: int) -> Optional[str]:
473488
except Exception:
474489
# Should generally be a ValueError, but catch everything just in case.
475490
return None
491+
492+
@staticmethod
493+
def _returncode_to_env_message(returncode: int) -> str:
494+
signal_name = UnityEnvironment._returncode_to_signal_name(returncode)
495+
signal_name = f" ({signal_name})" if signal_name else ""
496+
return f"Environment shut down with return code {returncode}{signal_name}."

ml-agents-envs/mlagents_envs/mock_communicator.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from .communicator import Communicator
1+
from typing import Optional
2+
3+
from .communicator import Communicator, PollCallback
24
from .environment import UnityEnvironment
35
from mlagents_envs.communicator_objects.unity_rl_output_pb2 import UnityRLOutputProto
46
from mlagents_envs.communicator_objects.brain_parameters_pb2 import (
@@ -39,7 +41,9 @@ def __init__(
3941
self.brain_name = brain_name
4042
self.vec_obs_size = vec_obs_size
4143

42-
def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
44+
def initialize(
45+
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
46+
) -> UnityOutputProto:
4347
if self.is_discrete:
4448
action_spec = ActionSpecProto(
4549
num_discrete_actions=2, discrete_branch_sizes=[3, 2]
@@ -94,7 +98,9 @@ def _get_agent_infos(self):
9498
)
9599
return dict_agent_info
96100

97-
def exchange(self, inputs: UnityInputProto) -> UnityOutputProto:
101+
def exchange(
102+
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
103+
) -> UnityOutputProto:
98104
result = UnityRLOutputProto(agentInfos=self._get_agent_infos())
99105
return UnityOutputProto(rl_output=result)
100106

ml-agents-envs/mlagents_envs/rpc_communicator.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import grpc
22
from typing import Optional
33

4+
from multiprocessing import Pipe
45
from sys import platform
56
import socket
6-
from multiprocessing import Pipe
7+
import time
78
from concurrent.futures import ThreadPoolExecutor
89

9-
from .communicator import Communicator
10+
from .communicator import Communicator, PollCallback
1011
from mlagents_envs.communicator_objects.unity_to_external_pb2_grpc import (
1112
UnityToExternalProtoServicer,
1213
add_UnityToExternalProtoServicer_to_server,
@@ -86,22 +87,38 @@ def check_port(self, port):
8687
finally:
8788
s.close()
8889

89-
def poll_for_timeout(self):
90+
def poll_for_timeout(self, poll_callback: Optional[PollCallback] = None) -> None:
9091
"""
9192
Polls the GRPC parent connection for data, to be used before calling recv. This prevents
9293
us from hanging indefinitely in the case where the environment process has died or was not
9394
launched.
94-
"""
95-
if not self.unity_to_external.parent_conn.poll(self.timeout_wait):
96-
raise UnityTimeOutException(
97-
"The Unity environment took too long to respond. Make sure that :\n"
98-
"\t The environment does not need user interaction to launch\n"
99-
'\t The Agents\' Behavior Parameters > Behavior Type is set to "Default"\n'
100-
"\t The environment and the Python interface have compatible versions."
101-
)
10295
103-
def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
104-
self.poll_for_timeout()
96+
Additionally, a callback can be passed to periodically check the state of the environment.
97+
This is used to detect the case when the environment dies without cleaning up the connection,
98+
so that we can stop sooner and raise a more appropriate error.
99+
"""
100+
deadline = time.monotonic() + self.timeout_wait
101+
callback_timeout_wait = self.timeout_wait // 10
102+
while time.monotonic() < deadline:
103+
if self.unity_to_external.parent_conn.poll(callback_timeout_wait):
104+
# Got an acknowledgment from the connection
105+
return
106+
if poll_callback:
107+
# Fire the callback - if it detects something wrong, it should raise an exception.
108+
poll_callback()
109+
110+
# Got this far without reading any data from the connection, so it must be dead.
111+
raise UnityTimeOutException(
112+
"The Unity environment took too long to respond. Make sure that :\n"
113+
"\t The environment does not need user interaction to launch\n"
114+
'\t The Agents\' Behavior Parameters > Behavior Type is set to "Default"\n'
115+
"\t The environment and the Python interface have compatible versions."
116+
)
117+
118+
def initialize(
119+
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
120+
) -> UnityOutputProto:
121+
self.poll_for_timeout(poll_callback)
105122
aca_param = self.unity_to_external.parent_conn.recv().unity_output
106123
message = UnityMessageProto()
107124
message.header.status = 200
@@ -110,12 +127,14 @@ def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
110127
self.unity_to_external.parent_conn.recv()
111128
return aca_param
112129

113-
def exchange(self, inputs: UnityInputProto) -> Optional[UnityOutputProto]:
130+
def exchange(
131+
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
132+
) -> Optional[UnityOutputProto]:
114133
message = UnityMessageProto()
115134
message.header.status = 200
116135
message.unity_input.CopyFrom(inputs)
117136
self.unity_to_external.parent_conn.send(message)
118-
self.poll_for_timeout()
137+
self.poll_for_timeout(poll_callback)
119138
output = self.unity_to_external.parent_conn.recv()
120139
if output.header.status != 200:
121140
return None

0 commit comments

Comments
 (0)