Skip to content

Commit 77c2504

Browse files
author
Chris Elion
authored
allow inference test timeouts (and pass from commandline) (#4932)
2 parents b948f87 + e93c316 commit 77c2504

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ModelOverrider.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ public class ModelOverrider : MonoBehaviour
2828
const string k_CommandLineModelOverrideDirectoryFlag = "--mlagents-override-model-directory";
2929
const string k_CommandLineModelOverrideExtensionFlag = "--mlagents-override-model-extension";
3030
const string k_CommandLineQuitAfterEpisodesFlag = "--mlagents-quit-after-episodes";
31+
const string k_CommandLineQuitAfterSeconds = "--mlagents-quit-after-seconds";
3132
const string k_CommandLineQuitOnLoadFailure = "--mlagents-quit-on-load-failure";
3233

3334
// The attached Agent
@@ -45,6 +46,9 @@ public class ModelOverrider : MonoBehaviour
4546
// Will default to 1 if override models are specified, otherwise 0.
4647
int m_MaxEpisodes;
4748

49+
// Deadline - quit the application once the current time reaches this. Defaults to DateTime.MaxValue, i.e. no deadline.
50+
DateTime m_Deadline = DateTime.MaxValue;
51+
4852
int m_NumSteps;
4953
int m_PreviousNumSteps;
5054
int m_PreviousAgentCompletedEpisodes;
@@ -89,6 +93,8 @@ public static string GetOverrideBehaviorName(string originalBehaviorName)
8993
void GetAssetPathFromCommandLine()
9094
{
9195
var maxEpisodes = 0;
96+
var timeoutSeconds = 0;
97+
9298
string[] commandLineArgsOverride = null;
9399
if (!string.IsNullOrEmpty(debugCommandLineOverride) && Application.isEditor)
94100
{
@@ -120,6 +126,10 @@ void GetAssetPathFromCommandLine()
120126
{
121127
Int32.TryParse(args[i + 1], out maxEpisodes);
122128
}
129+
else if (args[i] == k_CommandLineQuitAfterSeconds && i < args.Length - 1)
130+
{
131+
Int32.TryParse(args[i + 1], out timeoutSeconds);
132+
}
123133
else if (args[i] == k_CommandLineQuitOnLoadFailure)
124134
{
125135
m_QuitOnLoadFailure = true;
@@ -132,6 +142,13 @@ void GetAssetPathFromCommandLine()
132142
m_MaxEpisodes = maxEpisodes > 0 ? maxEpisodes : 1;
133143
Debug.Log($"setting m_MaxEpisodes to {maxEpisodes}");
134144
}
145+
146+
if (timeoutSeconds > 0)
147+
{
148+
m_Deadline = DateTime.Now + TimeSpan.FromSeconds(timeoutSeconds);
149+
Debug.Log($"setting deadline to {timeoutSeconds} from now.");
150+
151+
}
135152
}
136153

137154
void OnEnable()
@@ -172,9 +189,21 @@ void FixedUpdate()
172189
Application.Quit(0);
173190
#if UNITY_EDITOR
174191
EditorApplication.isPlaying = false;
192+
#endif
193+
}
194+
else if (DateTime.Now >= m_Deadline)
195+
{
196+
Debug.Log(
197+
$"Deadline exceeded. " +
198+
$"{TotalCompletedEpisodes}/{m_MaxEpisodes} episodes and " +
199+
$"{TotalNumSteps}/{m_MaxEpisodes * m_Agent.MaxStep} steps completed. Exiting.");
200+
Application.Quit(0);
201+
#if UNITY_EDITOR
202+
EditorApplication.isPlaying = false;
175203
#endif
176204
}
177205
}
206+
178207
m_NumSteps++;
179208
}
180209

ml-agents/tests/yamato/training_int_tests.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,11 @@ def run_inference(env_path: str, output_path: str, model_extension: str) -> bool
133133

134134
log_output_path = f"{get_base_output_path()}/inference.{model_extension}.txt"
135135

136+
# 10 minutes for inference is more than enough
137+
process_timeout = 10 * 60
138+
# Ask the executable to quit on its own 15 seconds before the process timeout, so it can exit gracefully.
139+
model_override_timeout = process_timeout - 15
140+
136141
exe_path = exes[0]
137142
args = [
138143
exe_path,
@@ -147,10 +152,11 @@ def run_inference(env_path: str, output_path: str, model_extension: str) -> bool
147152
"1",
148153
"--mlagents-override-model-extension",
149154
model_extension,
155+
"--mlagents-quit-after-seconds",
156+
str(model_override_timeout),
150157
]
151158
print(f"Starting inference with args {' '.join(args)}")
152-
timeout = 15 * 60 # 15 minutes for inference is more than enough
153-
res = subprocess.run(args, timeout=timeout)
159+
res = subprocess.run(args, timeout=process_timeout)
154160
end_time = time.time()
155161
if res.returncode != 0:
156162
print("Error running inference!")
@@ -166,7 +172,7 @@ def run_inference(env_path: str, output_path: str, model_extension: str) -> bool
166172
timer_data = json.load(f)
167173

168174
gauges = timer_data.get("gauges", {})
169-
rewards = gauges.get("Override_3DBall.CumulativeReward")
175+
rewards = gauges.get("Override_3DBall.CumulativeReward", {})
170176
max_reward = rewards.get("max")
171177
if max_reward is None:
172178
print(

0 commit comments

Comments
 (0)