Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AWS SageMaker] Unit tests for Training component #3722

Merged
merged 6 commits into from
May 13, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions components/aws/sagemaker/common/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def get_component_version():
"""Get component version from the first line of License file"""
component_version = 'NULL'

with open('/THIRD-PARTY-LICENSES.txt', 'r') as license_file:
# Get license file using known common directory
license_file_path = os.path.abspath(os.path.join(__cwd__, '../THIRD-PARTY-LICENSES.txt'))
with open(license_file_path, 'r') as license_file:
version_match = re.search('Amazon SageMaker Components for Kubeflow Pipelines; version (([0-9]+[.])+[0-9]+)',
license_file.readline())
if version_match is not None:
Expand All @@ -98,7 +100,7 @@ def create_training_job_request(args):

request['TrainingJobName'] = job_name
request['RoleArn'] = args['role']
request['HyperParameters'] = args['hyperparameters']
request['HyperParameters'] = create_hyperparameters(args['hyperparameters'])
request['AlgorithmSpecification']['TrainingInputMode'] = args['training_input_mode']

### Update training image (for BYOC and built-in algorithms) or algorithm resource name
Expand Down Expand Up @@ -136,8 +138,8 @@ def create_training_job_request(args):

### Update or pop VPC configs
if args['vpc_security_group_ids'] and args['vpc_subnets']:
request['VpcConfig']['SecurityGroupIds'] = [args['vpc_security_group_ids']]
request['VpcConfig']['Subnets'] = [args['vpc_subnets']]
request['VpcConfig']['SecurityGroupIds'] = args['vpc_security_group_ids'].split(',')
request['VpcConfig']['Subnets'] = args['vpc_subnets'].split(',')
else:
request.pop('VpcConfig')

Expand Down Expand Up @@ -190,7 +192,7 @@ def create_training_job(client, args):
raise Exception(e.response['Error']['Message'])


def wait_for_training_job(client, training_job_name):
def wait_for_training_job(client, training_job_name, poll_interval=30):
while(True):
response = client.describe_training_job(TrainingJobName=training_job_name)
status = response['TrainingJobStatus']
Expand All @@ -202,7 +204,7 @@ def wait_for_training_job(client, training_job_name):
logging.info('Training failed with the following error: {}'.format(message))
raise Exception('Training job failed')
logging.info("Training job is still in status: " + status)
time.sleep(30)
time.sleep(poll_interval)


def get_model_artifacts_from_job(client, job_name):
Expand All @@ -213,9 +215,9 @@ def get_model_artifacts_from_job(client, job_name):

def get_image_from_job(client, job_name):
info = client.describe_training_job(TrainingJobName=job_name)
try:
if 'TrainingImage' in info['AlgorithmSpecification']:
RedbackThomson marked this conversation as resolved.
Show resolved Hide resolved
image = info['AlgorithmSpecification']['TrainingImage']
except:
else:
algorithm_name = info['AlgorithmSpecification']['AlgorithmName']
image = client.describe_algorithm(AlgorithmName=algorithm_name)['TrainingSpecification']['TrainingImage']
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, but do check if the README or docs need any updates. We might have mentioned this as an unhandled case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The README actually claims that the image name is optional. So this was an outstanding bug.


Expand Down Expand Up @@ -272,8 +274,8 @@ def create_model_request(args):

### Update or pop VPC configs
if args['vpc_security_group_ids'] and args['vpc_subnets']:
request['VpcConfig']['SecurityGroupIds'] = [args['vpc_security_group_ids']]
request['VpcConfig']['Subnets'] = [args['vpc_subnets']]
request['VpcConfig']['SecurityGroupIds'] = args['vpc_security_group_ids'].split(',')
request['VpcConfig']['Subnets'] = args['vpc_subnets'].split(',')
else:
request.pop('VpcConfig')

Expand Down Expand Up @@ -493,7 +495,7 @@ def create_hyperparameter_tuning_job_request(args):
request['HyperParameterTuningJobConfig']['ParameterRanges']['CategoricalParameterRanges'] = args['categorical_parameters']
request['HyperParameterTuningJobConfig']['TrainingJobEarlyStoppingType'] = args['early_stopping_type']

request['TrainingJobDefinition']['StaticHyperParameters'] = args['static_parameters']
request['TrainingJobDefinition']['StaticHyperParameters'] = create_hyperparameters(args['static_parameters'])
request['TrainingJobDefinition']['AlgorithmSpecification']['TrainingInputMode'] = args['training_input_mode']

### Update training image (for BYOC) or algorithm resource name
Expand Down Expand Up @@ -531,8 +533,8 @@ def create_hyperparameter_tuning_job_request(args):

### Update or pop VPC configs
if args['vpc_security_group_ids'] and args['vpc_subnets']:
request['TrainingJobDefinition']['VpcConfig']['SecurityGroupIds'] = [args['vpc_security_group_ids']]
request['TrainingJobDefinition']['VpcConfig']['Subnets'] = [args['vpc_subnets']]
request['TrainingJobDefinition']['VpcConfig']['SecurityGroupIds'] = args['vpc_security_group_ids'].split(',')
request['TrainingJobDefinition']['VpcConfig']['Subnets'] = args['vpc_subnets'].split(',')
else:
request['TrainingJobDefinition'].pop('VpcConfig')

Expand Down Expand Up @@ -835,6 +837,14 @@ def get_labeling_job_outputs(client, labeling_job_name, auto_labeling):
active_learning_model_arn = ' '
return output_manifest, active_learning_model_arn

def create_hyperparameters(hyperparam_args):
    """Validate a hyperparameter dictionary and return it unchanged.

    SageMaker expects every hyperparameter value to be a string, so the
    dict is scanned and the first non-string value aborts the request
    with an Exception naming the offending key.

    Args:
        hyperparam_args: Mapping of hyperparameter name to value.

    Returns:
        The same mapping, untouched, once every value passes the check.

    Raises:
        Exception: If any value is not a string.
    """
    # Find the first key whose value fails the string check (dict order).
    offending_key = next(
        (name for name, value in hyperparam_args.items() if not isinstance(value, str)),
        None,
    )
    if offending_key is not None:
        raise Exception(f"Could not parse hyperparameters. Value for {offending_key} was not a string.")

    return hyperparam_args

def enable_spot_instance_support(training_job_config, args):
if args['spot_instance']:
training_job_config['EnableManagedSpotTraining'] = args['spot_instance']
Expand Down
4 changes: 2 additions & 2 deletions components/aws/sagemaker/hyperparameter_tuning/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ max_num_jobs | The maximum number of training jobs that a hyperparameter tuning
max_parallel_jobs | The maximum number of concurrent training jobs that a hyperparameter tuning job can launch | No | No | Int | [1, 10] | |
max_run_time | The maximum run time in seconds per training job | Yes | Yes | Int | ≤ 432000 (5 days) | 86400 (1 day) |
resource_encryption_key | The AWS KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s) | Yes | Yes | String | | |
vpc_security_group_ids | The VPC security group IDs, in the form sg-xxxxxxxx | Yes | Yes | String | | |
vpc_subnets | The ID of the subnets in the VPC to which you want to connect your hpo job | Yes | Yes | String | | |
vpc_security_group_ids | A comma-delimited list of security group IDs, in the form sg-xxxxxxxx | Yes | Yes | String | | |
vpc_subnets | A comma-delimited list of subnet IDs in the VPC to which you want to connect your hpo job | Yes | Yes | String | | |
network_isolation | Isolates the training container if true | Yes | No | Boolean | False, True | True |
traffic_encryption | Encrypts all communications between ML compute instances in distributed training if true | Yes | No | Boolean | False, True | False |
spot_instance | Use managed spot training if true | Yes | No | Boolean | False, True | False |
Expand Down
2 changes: 2 additions & 0 deletions components/aws/sagemaker/tests/unit_tests/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
addopts = -rA
2 changes: 1 addition & 1 deletion components/aws/sagemaker/tests/unit_tests/run_all_tests.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export PYTHONPATH=../../

coverage run -m pytest --ignore=tests/test_utils.py --junitxml ./unit_tests.log
coverage report --omit "*/usr/*,tests/*,*__init__*,*/Python/*"
coverage report -m --omit "*/usr/*,tests/*,*__init__*,*/Python/*"
Loading