Skip to content

Commit

Permalink
[AWS SageMaker] Add more unit tests (#3783)
Browse files Browse the repository at this point in the history
* add more tests for deploy and ground_truth components

* add more tests for workteam component

* add unit tests for model component

* add more unit tests for batchTransform component

* add more tests

* add 'request' function tests

* add more unit tests for ground truth
  • Loading branch information
akartsky committed May 27, 2020
1 parent 695573d commit b503050
Show file tree
Hide file tree
Showing 13 changed files with 648 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import argparse
import logging
from pathlib2 import Path
Expand Down Expand Up @@ -54,7 +55,7 @@ def create_parser():

def main(argv=None):
parser = create_parser()
args = parser.parse_args()
args = parser.parse_args(argv)

logging.getLogger().setLevel(logging.INFO)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
Expand All @@ -64,10 +65,11 @@ def main(argv=None):
_utils.wait_for_transform_job(client, batch_job_name)

Path(args.output_location_file).parent.mkdir(parents=True, exist_ok=True)
Path(args.output_location_file).write_text(unicode(args.output_location))
with open(args.output_location_file, 'w') as f:
f.write(unicode(args.output_location))

logging.info('Batch Transformation creation completed.')


if __name__== "__main__":
main()
main(sys.argv[1:])
40 changes: 18 additions & 22 deletions components/aws/sagemaker/common/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,21 +254,17 @@ def create_model_request(args):
else:
request['PrimaryContainer'].pop('ContainerHostname')

if (args['image'] or args['model_artifact_url']) and args['model_package']:
logging.error("Please specify an image AND model artifact url, OR a model package name.")
raise Exception("Could not make create model request.")
elif args['model_package']:
if args['model_package']:
request['PrimaryContainer']['ModelPackageName'] = args['model_package']
request['PrimaryContainer'].pop('Image')
request['PrimaryContainer'].pop('ModelDataUrl')
elif args['image'] and args['model_artifact_url']:
request['PrimaryContainer']['Image'] = args['image']
request['PrimaryContainer']['ModelDataUrl'] = args['model_artifact_url']
request['PrimaryContainer'].pop('ModelPackageName')
else:
if args['image'] and args['model_artifact_url']:
request['PrimaryContainer']['Image'] = args['image']
request['PrimaryContainer']['ModelDataUrl'] = args['model_artifact_url']
request['PrimaryContainer'].pop('ModelPackageName')
else:
logging.error("Please specify an image AND model artifact url.")
raise Exception("Could not make create model request.")
logging.error("Please specify an image AND model artifact url, OR a model package name.")
raise Exception("Could not make create model request.")

request['ExecutionRoleArn'] = args['role']
request['EnableNetworkIsolation'] = args['network_isolation']
Expand Down Expand Up @@ -296,6 +292,10 @@ def create_endpoint_config_request(args):
with open(os.path.join(__cwd__, 'endpoint_config.template.yaml'), 'r') as f:
request = yaml.safe_load(f)

if not args['model_name_1']:
logging.error("Must specify at least one model (model name) to host.")
raise Exception("Could not create endpoint config.")

endpoint_config_name = args['endpoint_config_name'] if args['endpoint_config_name'] else 'EndpointConfig' + args['model_name_1'][args['model_name_1'].index('-'):]
request['EndpointConfigName'] = endpoint_config_name

Expand All @@ -304,10 +304,6 @@ def create_endpoint_config_request(args):
else:
request.pop('KmsKeyId')

if not args['model_name_1']:
logging.error("Must specify at least one model (model name) to host.")
raise Exception("Could not create endpoint config.")

for i in range(len(request['ProductionVariants']), 0, -1):
if args['model_name_' + str(i)]:
request['ProductionVariants'][i-1]['ModelName'] = args['model_name_' + str(i)]
Expand Down Expand Up @@ -377,14 +373,14 @@ def wait_for_endpoint_creation(client, endpoint_name):
finally:
resp = client.describe_endpoint(EndpointName=endpoint_name)
status = resp['EndpointStatus']
logging.info("Endpoint Arn: " + resp['EndpointArn'])
logging.info("Create endpoint ended with status: " + status)

if status != 'InService':
message = client.describe_endpoint(EndpointName=endpoint_name)['FailureReason']
message = resp['FailureReason']
logging.info('Create endpoint failed with the following error: {}'.format(message))
raise Exception('Endpoint creation did not succeed')

logging.info("Endpoint Arn: " + resp['EndpointArn'])
logging.info("Create endpoint ended with status: " + status)

def create_transform_job_request(args):
### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_transform_job
Expand Down Expand Up @@ -462,7 +458,7 @@ def create_transform_job(client, args):
raise Exception(e.response['Error']['Message'])


def wait_for_transform_job(client, batch_job_name):
def wait_for_transform_job(client, batch_job_name, poll_interval=30):
### Wait until the job finishes
while(True):
response = client.describe_transform_job(TransformJobName=batch_job_name)
Expand All @@ -475,7 +471,7 @@ def wait_for_transform_job(client, batch_job_name):
logging.info('Transform failed with the following error: {}'.format(message))
raise Exception('Transform job failed')
logging.info("Transform job is still in status: " + status)
time.sleep(30)
time.sleep(poll_interval)


def create_hyperparameter_tuning_job_request(args):
Expand Down Expand Up @@ -809,7 +805,7 @@ def create_labeling_job(client, args):
raise Exception(e.response['Error']['Message'])


def wait_for_labeling_job(client, labeling_job_name):
def wait_for_labeling_job(client, labeling_job_name, poll_interval=30):
### Wait until the job finishes
status = 'InProgress'
while(status == 'InProgress'):
Expand All @@ -820,7 +816,7 @@ def wait_for_labeling_job(client, labeling_job_name):
logging.info('Labeling failed with the following error: {}'.format(message))
raise Exception('Labeling job failed')
logging.info("Labeling job is still in status: " + status)
time.sleep(30)
time.sleep(poll_interval)

if status == 'Completed':
logging.info("Labeling job ended with status: " + status)
Expand Down
5 changes: 3 additions & 2 deletions components/aws/sagemaker/deploy/src/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import argparse
import logging

Expand Down Expand Up @@ -47,7 +48,7 @@ def create_parser():

def main(argv=None):
parser = create_parser()
args = parser.parse_args()
args = parser.parse_args(argv)

logging.getLogger().setLevel(logging.INFO)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
Expand All @@ -63,4 +64,4 @@ def main(argv=None):


if __name__== "__main__":
main()
main(sys.argv[1:])
5 changes: 3 additions & 2 deletions components/aws/sagemaker/ground_truth/src/ground_truth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import argparse
import logging

Expand Down Expand Up @@ -53,7 +54,7 @@ def create_parser():

def main(argv=None):
parser = create_parser()
args = parser.parse_args()
args = parser.parse_args(argv)

logging.getLogger().setLevel(logging.INFO)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
Expand All @@ -72,4 +73,4 @@ def main(argv=None):


if __name__== "__main__":
main()
main(sys.argv[1:])
5 changes: 3 additions & 2 deletions components/aws/sagemaker/model/src/create_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import argparse
import logging

Expand All @@ -36,7 +37,7 @@ def create_parser():

def main(argv=None):
parser = create_parser()
args = parser.parse_args()
args = parser.parse_args(argv)

logging.getLogger().setLevel(logging.INFO)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
Expand All @@ -50,4 +51,4 @@ def main(argv=None):


if __name__== "__main__":
main()
main(sys.argv[1:])
143 changes: 133 additions & 10 deletions components/aws/sagemaker/tests/unit_tests/tests/test_batch_transform.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import unittest

from unittest.mock import patch, Mock, MagicMock
from unittest.mock import patch, call, Mock, MagicMock, mock_open
from botocore.exceptions import ClientError
from datetime import datetime

Expand All @@ -19,7 +19,7 @@
'--input_location', 's3://fake-bucket/data',
'--output_location', 's3://fake-bucket/output',
'--instance_type', 'ml.c5.18xlarge',
'--output_location_file', 'tmp/'
'--output_location_file', 'tmp/output.txt'
]

class BatchTransformTestCase(unittest.TestCase):
Expand All @@ -28,12 +28,135 @@ def setUpClass(cls):
parser = batch_transform.create_parser()
cls.parser = parser

def test_sample(self):
    """Check that the parsed output location ends up in the transform job request."""
    parsed = vars(self.parser.parse_args(required_args))
    request = _utils.create_transform_job_request(parsed)
    s3_output_path = request['TransformOutput']['S3OutputPath']
    self.assertEqual(s3_output_path, 's3://fake-bucket/output')

def test_empty_string(self):
    """Run the shared empty-string check over a request built from the required arguments."""
    parsed = vars(self.parser.parse_args(required_args))
    request = _utils.create_transform_job_request(parsed)
    test_utils.check_empty_string_values(request)
def test_create_parser(self):
    """The parser stored on the test class must not be None."""
    parser = self.parser
    self.assertIsNotNone(parser)


def test_main(self):
    """Run ``batch_transform.main`` with ``_utils`` mocked and check its outputs.

    Verifies that a transform job is created and waited on, and that the
    output location is written to the file named by ``--output_location_file``.
    """
    # Mock out all of utils except the parser helper. patch.object restores the
    # real module attribute when the block exits, so the mock cannot leak into
    # other tests (a plain ``batch_transform._utils = MagicMock()`` would
    # stay in place for the rest of the test run).
    with patch.object(batch_transform, '_utils') as mock_utils:
        mock_utils.add_default_client_arguments = _utils.add_default_client_arguments

        # Set some static returns
        mock_utils.create_transform_job.return_value = 'test-batch-job'

        with patch('builtins.open', mock_open()) as file_open:
            batch_transform.main(required_args)

    # Check if correct requests were created and triggered
    mock_utils.create_transform_job.assert_called()
    mock_utils.wait_for_transform_job.assert_called()

    # Check the file outputs
    file_open.assert_has_calls([
        call('tmp/output.txt', 'w')
    ])

    file_open().write.assert_has_calls([
        call('s3://fake-bucket/output')
    ])


def test_batch_transform(self):
    """Create a transform job from the required arguments plus a job name.

    Checks the exact keyword arguments passed to the client's
    ``create_transform_job`` and that the job name is returned.
    """
    sagemaker_client = MagicMock()
    cli_args = required_args + ['--job_name', 'test-batch-job']
    parsed = self.parser.parse_args(cli_args)

    job_name = _utils.create_transform_job(sagemaker_client, vars(parsed))

    expected_request = {
        'DataProcessing': {'InputFilter': '', 'OutputFilter': '', 'JoinSource': 'None'},
        'Environment': {},
        'MaxConcurrentTransforms': 0,
        'MaxPayloadInMB': 6,
        'ModelName': 'model-test',
        'Tags': [],
        'TransformInput': {
            'DataSource': {
                'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://fake-bucket/data'},
            },
            'ContentType': '',
            'CompressionType': 'None',
            'SplitType': 'None',
        },
        'TransformJobName': 'test-batch-job',
        'TransformOutput': {'S3OutputPath': 's3://fake-bucket/output', 'Accept': None, 'KmsKeyId': ''},
        'TransformResources': {'InstanceType': 'ml.c5.18xlarge', 'InstanceCount': None, 'VolumeKmsKeyId': ''},
    }
    sagemaker_client.create_transform_job.assert_called_once_with(**expected_request)

    self.assertEqual(job_name, 'test-batch-job')


def test_pass_all_arguments(self):
    """Create a transform job with the optional arguments set as well.

    Checks that each optional command-line value appears in the keyword
    arguments passed to the client's ``create_transform_job``.
    """
    sagemaker_client = MagicMock()
    optional_args = [
        '--job_name', 'test-batch-job',
        '--max_concurrent', '5',
        '--max_payload', '100',
        '--batch_strategy', 'MultiRecord',
        '--data_type', 'S3Prefix',
        '--compression_type', 'Gzip',
        '--split_type', 'RecordIO',
        '--assemble_with', 'Line',
        '--join_source', 'Input',
        '--tags', '{"fake_key": "fake_value"}',
    ]
    parsed = self.parser.parse_args(required_args + optional_args)

    _utils.create_transform_job(sagemaker_client, vars(parsed))

    s3_data_source = {'S3DataType': 'S3Prefix', 'S3Uri': 's3://fake-bucket/data'}
    expected_request = {
        'BatchStrategy': 'MultiRecord',
        'DataProcessing': {'InputFilter': '', 'OutputFilter': '', 'JoinSource': 'Input'},
        'Environment': {},
        'MaxConcurrentTransforms': 5,
        'MaxPayloadInMB': 100,
        'ModelName': 'model-test',
        'Tags': [{'Key': 'fake_key', 'Value': 'fake_value'}],
        'TransformInput': {
            'DataSource': {'S3DataSource': s3_data_source},
            'ContentType': '',
            'CompressionType': 'Gzip',
            'SplitType': 'RecordIO',
        },
        'TransformJobName': 'test-batch-job',
        'TransformOutput': {
            'S3OutputPath': 's3://fake-bucket/output',
            'Accept': None,
            'AssembleWith': 'Line',
            'KmsKeyId': '',
        },
        'TransformResources': {
            'InstanceType': 'ml.c5.18xlarge',
            'InstanceCount': None,
            'VolumeKmsKeyId': '',
        },
    }
    sagemaker_client.create_transform_job.assert_called_once_with(**expected_request)


def test_sagemaker_exception_in_batch_transform(self):
    """A ClientError raised by the client must surface as an exception."""
    failing_client = MagicMock()
    failing_client.create_transform_job.side_effect = ClientError(
        {"Error": {"Message": "SageMaker broke"}}, "batch_transform"
    )
    parsed = vars(self.parser.parse_args(required_args))

    self.assertRaises(Exception, _utils.create_transform_job, failing_client, parsed)


def test_wait_for_transform_job_creation(self):
    """The wait must stop describing the job once it reports "Completed"."""
    # The third status is a sentinel: if it is consumed, the call count check fails.
    statuses = ["InProgress", "Completed", "Should not be called"]
    sagemaker_client = MagicMock()
    sagemaker_client.describe_transform_job.side_effect = [
        {"TransformJobStatus": status} for status in statuses
    ]

    _utils.wait_for_transform_job(sagemaker_client, 'test-batch', 0)

    describe_calls = sagemaker_client.describe_transform_job.call_count
    self.assertEqual(describe_calls, 2)


def test_wait_for_failed_job(self):
    """A "Failed" status must raise and stop any further describe calls."""
    in_progress = {"TransformJobStatus": "InProgress"}
    failed = {"TransformJobStatus": "Failed", "FailureReason": "SYSTEM FAILURE"}
    # Sentinel entry: if it is consumed, the call count check below fails.
    sentinel = {"TransformJobStatus": "Should not be called"}

    sagemaker_client = MagicMock()
    sagemaker_client.describe_transform_job.side_effect = [in_progress, failed, sentinel]

    self.assertRaises(
        Exception, _utils.wait_for_transform_job, sagemaker_client, 'test-batch', 0
    )

    describe_calls = sagemaker_client.describe_transform_job.call_count
    self.assertEqual(describe_calls, 2)


Loading

0 comments on commit b503050

Please sign in to comment.