Skip to content

Commit e2885b0

Browse files
Update sagemaker debugger TF actions notebook to use built-in actions (aws#1964)
* Update TF actions notebook to use built-in actions
1 parent 0e57a28 commit e2885b0

File tree

1 file changed

+61
-20
lines changed

1 file changed

+61
-20
lines changed

sagemaker-debugger/tensorflow_action_on_rule/detect_stalled_training_job_and_stop.ipynb renamed to sagemaker-debugger/tensorflow_action_on_rule/detect_stalled_training_job_and_actions.ipynb

Lines changed: 61 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
"cell_type": "markdown",
55
"metadata": {},
66
"source": [
7-
"# Detect Stalled Training and Stop Training Job Using SageMaker Debugger Rule\n",
7+
"# Detect Stalled Training and Invoke Actions Using SageMaker Debugger Rule\n",
88
" \n",
9-
"This notebook shows you how to use the `StalledTrainingRule` built-in rule. This rule can take an action to stop your training job, when the rule detects an inactivity in your training job for a certain time period. This functionality helps you monitor the training job status and reduces redundant resource usage.\n",
9+
"This notebook shows you how to use the `StalledTrainingRule` built-in rule. This rule can take an action to stop your training job or send you an email/SMS, when the rule detects an inactivity in your training job for a certain time period. This functionality helps you monitor the training job status and reduces redundant resource usage.\n",
1010
"\n",
1111
"## How the StalledTrainingRule Built-in Rule Works\n",
1212
"\n",
@@ -17,6 +17,23 @@
1717
"The Debugger `StalledTrainingRule` watches tensor updates from your training job. If the rule doesn't find new tensors updated to the default S3 URI for a threshold period of time, it takes an action to trigger the `StopTrainingJob` API operation. The following code cells set up a SageMaker TensorFlow estimator with the Debugger `StalledTrainingRule` to watch the `losses` pre-built tensor collection."
1818
]
1919
},
20+
{
21+
"cell_type": "markdown",
22+
"metadata": {},
23+
"source": [
24+
"### Install custom packages\n",
25+
"These packages were built manually with the changes needed to run rules with actions, since the changes have not been released yet. Remember to refresh the kernel after installing these packages"
26+
]
27+
},
28+
{
29+
"cell_type": "code",
30+
"execution_count": null,
31+
"metadata": {},
32+
"outputs": [],
33+
"source": [
34+
"! pip install -q -U sagemaker"
35+
]
36+
},
2037
{
2138
"cell_type": "markdown",
2239
"metadata": {},
@@ -55,15 +72,15 @@
5572
"cell_type": "markdown",
5673
"metadata": {},
5774
"source": [
58-
"### Create a unique training job prefix\n",
59-
"A unique prefix must be specified for `StalledTrainingRule` to identify the exact training job name that you want to monitor and stop when the rule triggers the stalled training job issue.\n",
60-
"If there are multiple training jobs sharing the same prefix, this rule may react to other training jobs. If the rule cannot find the exact training job name with a provided prefix, it falls back to safe mode and does not stop the training job. The rule evaluation process goes on in parallel while the training jobs are running. If you want to access the rule job logs, you will later find how to get the information at [Get a direct Amazon CloudWatch URL to find the current rule processing job log](#cw-url).\n",
75+
"### Create the actions to be used in the rules\n",
6176
"\n",
62-
"The following code cell includes:\n",
63-
"* a code line to create a unique `base_job_name_prefix`\n",
64-
"* a stalled training job rule configuration object\n",
77+
"The following code cells include:\n",
78+
"* a code line to create the action objects\n",
79+
"* a stalled training job rule configuration object that uses these actions\n",
6580
"* a SageMaker TensorFlow estimator configuration with the Debugger `rules` parameter to run the built-in rule\n",
6681
"\n",
82+
"Valid action objects are individual actions (`StopTraining`, `Email`, `SMS`) or an `ActionList` with a combination of these.\n",
83+
"\n",
6784
"**Note**: Debugger collects `loss` tensors by default every 500 steps."
6885
]
6986
},
@@ -73,32 +90,46 @@
7390
"metadata": {},
7491
"outputs": [],
7592
"source": [
76-
"# Append current time to your training job name to generate a unique base_job_name_prefix\n",
77-
"import time\n",
78-
"base_job_name_prefix= 'smdebug-stalled-demo-' + str(int(time.time()))\n",
79-
"\n",
93+
"training_job_prefix = None # Feel free to customize this if desired."
94+
]
95+
},
96+
{
97+
"cell_type": "code",
98+
"execution_count": null,
99+
"metadata": {},
100+
"outputs": [],
101+
"source": [
102+
"stop_training_action = rule_configs.StopTraining() # or specify a training job prefix with StopTraining(\"prefix\")\n",
103+
"actions = stop_training_action"
104+
]
105+
},
106+
{
107+
"cell_type": "code",
108+
"execution_count": null,
109+
"metadata": {},
110+
"outputs": [],
111+
"source": [
80112
"# Configure a StalledTrainingRule rule parameter object\n",
81113
"stalled_training_job_rule = [\n",
82114
" Rule.sagemaker(\n",
83115
" base_config=rule_configs.stalled_training_rule(),\n",
84116
" rule_parameters={\n",
85-
" \"threshold\": \"120\", \n",
86-
" \"stop_training_on_fire\": \"True\",\n",
87-
" \"training_job_name_prefix\": base_job_name_prefix\n",
88-
" }\n",
117+
" \"threshold\": \"60\", \n",
118+
" },\n",
119+
" actions=actions\n",
89120
" )\n",
90121
"]\n",
91122
"\n",
92123
"# Configure a SageMaker TensorFlow estimator\n",
93124
"estimator = TensorFlow(\n",
94125
" role=sagemaker.get_execution_role(),\n",
95-
" base_job_name=base_job_name_prefix,\n",
96-
" train_instance_count=1,\n",
97-
" train_instance_type='ml.m5.4xlarge',\n",
126+
" base_job_name=\"stalled-training-test\",\n",
127+
" instance_count=1,\n",
128+
" instance_type='ml.m5.4xlarge',\n",
98129
" entry_point='src/simple_stalled_training.py', # This sample script forces the training job to sleep for 10 minutes\n",
99130
" framework_version='1.15.0',\n",
100131
" py_version='py3',\n",
101-
" train_max_run=3600,\n",
132+
" max_run=3600,\n",
102133
" ## Debugger-specific parameter\n",
103134
" rules = stalled_training_job_rule\n",
104135
")"
@@ -177,6 +208,16 @@
177208
" time.sleep(15)"
178209
]
179210
},
211+
{
212+
"cell_type": "code",
213+
"execution_count": null,
214+
"metadata": {},
215+
"outputs": [],
216+
"source": [
217+
"description = client.describe_training_job(TrainingJobName=job_name)\n",
218+
"print(description)"
219+
]
220+
},
180221
{
181222
"cell_type": "markdown",
182223
"metadata": {},

0 commit comments

Comments
 (0)