Skip to content

Commit

Permalink
[AWS SageMaker] Unit tests for Training component (#3722)
Browse files Browse the repository at this point in the history
* Added additional training unit tests

* Add main training function tests

* Add full training test coverage

* Fix import sys

* Fix poorly named test
  • Loading branch information
RedbackThomson committed May 13, 2020
1 parent d418f57 commit ddd1969
Show file tree
Hide file tree
Showing 7 changed files with 306 additions and 26 deletions.
36 changes: 23 additions & 13 deletions components/aws/sagemaker/common/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def get_component_version():
"""Get component version from the first line of License file"""
component_version = 'NULL'

with open('THIRD-PARTY-LICENSES.txt', 'r') as license_file:
# Get license file using known common directory
license_file_path = os.path.abspath(os.path.join(__cwd__, '../THIRD-PARTY-LICENSES.txt'))
with open(license_file_path, 'r') as license_file:
version_match = re.search('Amazon SageMaker Components for Kubeflow Pipelines; version (([0-9]+[.])+[0-9]+)',
license_file.readline())
if version_match is not None:
Expand All @@ -99,7 +101,7 @@ def create_training_job_request(args):

request['TrainingJobName'] = job_name
request['RoleArn'] = args['role']
request['HyperParameters'] = args['hyperparameters']
request['HyperParameters'] = create_hyperparameters(args['hyperparameters'])
request['AlgorithmSpecification']['TrainingInputMode'] = args['training_input_mode']

### Update training image (for BYOC and built-in algorithms) or algorithm resource name
Expand Down Expand Up @@ -137,8 +139,8 @@ def create_training_job_request(args):

### Update or pop VPC configs
if args['vpc_security_group_ids'] and args['vpc_subnets']:
request['VpcConfig']['SecurityGroupIds'] = [args['vpc_security_group_ids']]
request['VpcConfig']['Subnets'] = [args['vpc_subnets']]
request['VpcConfig']['SecurityGroupIds'] = args['vpc_security_group_ids'].split(',')
request['VpcConfig']['Subnets'] = args['vpc_subnets'].split(',')
else:
request.pop('VpcConfig')

Expand Down Expand Up @@ -191,7 +193,7 @@ def create_training_job(client, args):
raise Exception(e.response['Error']['Message'])


def wait_for_training_job(client, training_job_name):
def wait_for_training_job(client, training_job_name, poll_interval=30):
while(True):
response = client.describe_training_job(TrainingJobName=training_job_name)
status = response['TrainingJobStatus']
Expand All @@ -203,7 +205,7 @@ def wait_for_training_job(client, training_job_name):
logging.info('Training failed with the following error: {}'.format(message))
raise Exception('Training job failed')
logging.info("Training job is still in status: " + status)
time.sleep(30)
time.sleep(poll_interval)


def get_model_artifacts_from_job(client, job_name):
Expand All @@ -214,9 +216,9 @@ def get_model_artifacts_from_job(client, job_name):

def get_image_from_job(client, job_name):
info = client.describe_training_job(TrainingJobName=job_name)
try:
if 'TrainingImage' in info['AlgorithmSpecification']:
image = info['AlgorithmSpecification']['TrainingImage']
except:
else:
algorithm_name = info['AlgorithmSpecification']['AlgorithmName']
image = client.describe_algorithm(AlgorithmName=algorithm_name)['TrainingSpecification']['TrainingImage']

Expand Down Expand Up @@ -273,8 +275,8 @@ def create_model_request(args):

### Update or pop VPC configs
if args['vpc_security_group_ids'] and args['vpc_subnets']:
request['VpcConfig']['SecurityGroupIds'] = [args['vpc_security_group_ids']]
request['VpcConfig']['Subnets'] = [args['vpc_subnets']]
request['VpcConfig']['SecurityGroupIds'] = args['vpc_security_group_ids'].split(',')
request['VpcConfig']['Subnets'] = args['vpc_subnets'].split(',')
else:
request.pop('VpcConfig')

Expand Down Expand Up @@ -494,7 +496,7 @@ def create_hyperparameter_tuning_job_request(args):
request['HyperParameterTuningJobConfig']['ParameterRanges']['CategoricalParameterRanges'] = args['categorical_parameters']
request['HyperParameterTuningJobConfig']['TrainingJobEarlyStoppingType'] = args['early_stopping_type']

request['TrainingJobDefinition']['StaticHyperParameters'] = args['static_parameters']
request['TrainingJobDefinition']['StaticHyperParameters'] = create_hyperparameters(args['static_parameters'])
request['TrainingJobDefinition']['AlgorithmSpecification']['TrainingInputMode'] = args['training_input_mode']

### Update training image (for BYOC) or algorithm resource name
Expand Down Expand Up @@ -532,8 +534,8 @@ def create_hyperparameter_tuning_job_request(args):

### Update or pop VPC configs
if args['vpc_security_group_ids'] and args['vpc_subnets']:
request['TrainingJobDefinition']['VpcConfig']['SecurityGroupIds'] = [args['vpc_security_group_ids']]
request['TrainingJobDefinition']['VpcConfig']['Subnets'] = [args['vpc_subnets']]
request['TrainingJobDefinition']['VpcConfig']['SecurityGroupIds'] = args['vpc_security_group_ids'].split(',')
request['TrainingJobDefinition']['VpcConfig']['Subnets'] = args['vpc_subnets'].split(',')
else:
request['TrainingJobDefinition'].pop('VpcConfig')

Expand Down Expand Up @@ -836,6 +838,14 @@ def get_labeling_job_outputs(client, labeling_job_name, auto_labeling):
active_learning_model_arn = ' '
return output_manifest, active_learning_model_arn

def create_hyperparameters(hyperparam_args):
    """Validate a hyperparameter dictionary and return it unchanged.

    SageMaker expects every hyperparameter value to be a string; any
    non-string value is rejected with an error naming the offending key.

    Args:
        hyperparam_args: dict mapping hyperparameter names to string values.

    Returns:
        The same dict, untouched, once all values pass validation.

    Raises:
        Exception: if any value is not a string.
    """
    bad_keys = [name for name, value in hyperparam_args.items()
                if not isinstance(value, str)]
    if bad_keys:
        # Report the first offending key (same key the original loop would hit).
        raise Exception(f"Could not parse hyperparameters. Value for {bad_keys[0]} was not a string.")

    return hyperparam_args

def enable_spot_instance_support(training_job_config, args):
if args['spot_instance']:
training_job_config['EnableManagedSpotTraining'] = args['spot_instance']
Expand Down
4 changes: 2 additions & 2 deletions components/aws/sagemaker/hyperparameter_tuning/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ max_num_jobs | The maximum number of training jobs that a hyperparameter tuning
max_parallel_jobs | The maximum number of concurrent training jobs that a hyperparameter tuning job can launch | No | No | Int | [1, 10] | |
max_run_time | The maximum run time in seconds per training job | Yes | Yes | Int | ≤ 432000 (5 days) | 86400 (1 day) |
resource_encryption_key | The AWS KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s) | Yes | Yes | String | | |
vpc_security_group_ids | The VPC security group IDs, in the form sg-xxxxxxxx | Yes | Yes | String | | |
vpc_subnets | The ID of the subnets in the VPC to which you want to connect your hpo job | Yes | Yes | String | | |
vpc_security_group_ids | A comma-delimited list of security group IDs, in the form sg-xxxxxxxx | Yes | Yes | String | | |
vpc_subnets | A comma-delimited list of subnet IDs in the VPC to which you want to connect your hpo job | Yes | Yes | String | | |
network_isolation | Isolates the training container if true | Yes | No | Boolean | False, True | True |
traffic_encryption | Encrypts all communications between ML compute instances in distributed training if true | Yes | No | Boolean | False, True | False |
spot_instance | Use managed spot training if true | Yes | No | Boolean | False, True | False |
Expand Down
2 changes: 2 additions & 0 deletions components/aws/sagemaker/tests/unit_tests/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
addopts = -rA
2 changes: 1 addition & 1 deletion components/aws/sagemaker/tests/unit_tests/run_all_tests.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export PYTHONPATH=../../

coverage run -m pytest --ignore=tests/test_utils.py --junitxml ./unit_tests.log
coverage report --omit "*/usr/*,tests/*,*__init__*,*/Python/*"
coverage report -m --omit "*/usr/*,tests/*,*__init__*,*/Python/*"
Loading

0 comments on commit ddd1969

Please sign in to comment.