Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AWS SageMaker] Unit tests for Training component #3722

Merged
merged 6 commits into from
May 13, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Added additional training unit tests
  • Loading branch information
Nicholas Thomson committed May 7, 2020
commit 27fcce524d2b15bffb2bcb74a1dedc60c970bfb6
24 changes: 16 additions & 8 deletions components/aws/sagemaker/common/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def create_training_job_request(args):

request['TrainingJobName'] = job_name
request['RoleArn'] = args['role']
request['HyperParameters'] = args['hyperparameters']
request['HyperParameters'] = create_hyperparameters(args['hyperparameters'])
request['AlgorithmSpecification']['TrainingInputMode'] = args['training_input_mode']

### Update training image (for BYOC and built-in algorithms) or algorithm resource name
Expand Down Expand Up @@ -136,8 +136,8 @@ def create_training_job_request(args):

### Update or pop VPC configs
if args['vpc_security_group_ids'] and args['vpc_subnets']:
request['VpcConfig']['SecurityGroupIds'] = [args['vpc_security_group_ids']]
request['VpcConfig']['Subnets'] = [args['vpc_subnets']]
request['VpcConfig']['SecurityGroupIds'] = args['vpc_security_group_ids'].split(',')
request['VpcConfig']['Subnets'] = args['vpc_subnets'].split(',')
else:
request.pop('VpcConfig')

Expand Down Expand Up @@ -272,8 +272,8 @@ def create_model_request(args):

### Update or pop VPC configs
if args['vpc_security_group_ids'] and args['vpc_subnets']:
request['VpcConfig']['SecurityGroupIds'] = [args['vpc_security_group_ids']]
request['VpcConfig']['Subnets'] = [args['vpc_subnets']]
request['VpcConfig']['SecurityGroupIds'] = args['vpc_security_group_ids'].split(',')
request['VpcConfig']['Subnets'] = args['vpc_subnets'].split(',')
else:
request.pop('VpcConfig')

Expand Down Expand Up @@ -493,7 +493,7 @@ def create_hyperparameter_tuning_job_request(args):
request['HyperParameterTuningJobConfig']['ParameterRanges']['CategoricalParameterRanges'] = args['categorical_parameters']
request['HyperParameterTuningJobConfig']['TrainingJobEarlyStoppingType'] = args['early_stopping_type']

request['TrainingJobDefinition']['StaticHyperParameters'] = args['static_parameters']
request['TrainingJobDefinition']['StaticHyperParameters'] = create_hyperparameters(args['static_parameters'])
request['TrainingJobDefinition']['AlgorithmSpecification']['TrainingInputMode'] = args['training_input_mode']

### Update training image (for BYOC) or algorithm resource name
Expand Down Expand Up @@ -531,8 +531,8 @@ def create_hyperparameter_tuning_job_request(args):

### Update or pop VPC configs
if args['vpc_security_group_ids'] and args['vpc_subnets']:
request['TrainingJobDefinition']['VpcConfig']['SecurityGroupIds'] = [args['vpc_security_group_ids']]
request['TrainingJobDefinition']['VpcConfig']['Subnets'] = [args['vpc_subnets']]
request['TrainingJobDefinition']['VpcConfig']['SecurityGroupIds'] = args['vpc_security_group_ids'].split(',')
request['TrainingJobDefinition']['VpcConfig']['Subnets'] = args['vpc_subnets'].split(',')
else:
request['TrainingJobDefinition'].pop('VpcConfig')

Expand Down Expand Up @@ -835,6 +835,14 @@ def get_labeling_job_outputs(client, labeling_job_name, auto_labeling):
active_learning_model_arn = ' '
return output_manifest, active_learning_model_arn

def create_hyperparameters(hyperparam_args):
# Validate all values are strings
for key, value in hyperparam_args.items():
if not isinstance(value, str):
raise Exception(f"Could not parse hyperparameters. Value for {key} was not a string.")

return hyperparam_args

def enable_spot_instance_support(training_job_config, args):
if args['spot_instance']:
training_job_config['EnableManagedSpotTraining'] = args['spot_instance']
Expand Down
4 changes: 2 additions & 2 deletions components/aws/sagemaker/hyperparameter_tuning/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ max_num_jobs | The maximum number of training jobs that a hyperparameter tuning
max_parallel_jobs | The maximum number of concurrent training jobs that a hyperparameter tuning job can launch | No | No | Int | [1, 10] | |
max_run_time | The maximum run time in seconds per training job | Yes | Yes | Int | ≤ 432000 (5 days) | 86400 (1 day) |
resource_encryption_key | The AWS KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s) | Yes | Yes | String | | |
vpc_security_group_ids | The VPC security group IDs, in the form sg-xxxxxxxx | Yes | Yes | String | | |
vpc_subnets | The ID of the subnets in the VPC to which you want to connect your hpo job | Yes | Yes | String | | |
vpc_security_group_ids | A comma-delimited list of security group IDs, in the form sg-xxxxxxxx | Yes | Yes | String | | |
vpc_subnets | A comma-delimited list of subnet IDs in the VPC to which you want to connect your hpo job | Yes | Yes | String | | |
network_isolation | Isolates the training container if true | Yes | No | Boolean | False, True | True |
traffic_encryption | Encrypts all communications between ML compute instances in distributed training if true | Yes | No | Boolean | False, True | False |
spot_instance | Use managed spot training if true | Yes | No | Boolean | False, True | False |
Expand Down
80 changes: 80 additions & 0 deletions components/aws/sagemaker/tests/unit_tests/tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,86 @@ def setUpClass(cls):
parser = train.create_parser()
cls.parser = parser

def test_create_parser(self):
self.assertIsNotNone(self.parser)

def test_reasonable_required_args(self):
response = _utils.create_training_job_request(vars(self.parser.parse_args(required_args)))

# Ensure all of the optional arguments have reasonable default values
self.assertFalse(response['EnableManagedSpotTraining'])
self.assertDictEqual(response['HyperParameters'], {})
self.assertNotIn('VpcConfig', response)
self.assertNotIn('MetricDefinitions', response)
self.assertEqual(response['Tags'], [])
self.assertEqual(response['AlgorithmSpecification']['TrainingInputMode'], 'File')
self.assertEqual(response['OutputDataConfig']['S3OutputPath'], 'test-path')

def test_metric_definitions(self):
metric_definition_args = self.parser.parse_args(required_args + ['--metric_definitions', '{"metric1": "regexval1", "metric2": "regexval2"}'])
response = _utils.create_training_job_request(vars(metric_definition_args))

self.assertIn('MetricDefinitions', response['AlgorithmSpecification'])
response_metric_definitions = response['AlgorithmSpecification']['MetricDefinitions']

self.assertEqual(response_metric_definitions, [{
'Name': "metric1",
'Regex': "regexval1"
}, {
'Name': "metric2",
'Regex': "regexval2"
}])

def test_invalid_instance_type(self):
invalid_instance_args = required_args + ['--instance_type', 'invalid-instance']

with self.assertRaises(SystemExit):
self.parser.parse_args(invalid_instance_args)

def test_valid_hyperparameters(self):
hyperparameters_str = '{"hp1": "val1", "hp2": "val2", "hp3": "val3"}'

good_args = self.parser.parse_args(required_args + ['--hyperparameters', hyperparameters_str])
response = _utils.create_training_job_request(vars(good_args))

self.assertIn('hp1', response['HyperParameters'])
self.assertIn('hp2', response['HyperParameters'])
self.assertIn('hp3', response['HyperParameters'])
self.assertEqual(response['HyperParameters']['hp1'], "val1")
self.assertEqual(response['HyperParameters']['hp2'], "val2")
self.assertEqual(response['HyperParameters']['hp3'], "val3")

def test_empty_hyperparameters(self):
hyperparameters_str = '{}'

good_args = self.parser.parse_args(required_args + ['--hyperparameters', hyperparameters_str])
response = _utils.create_training_job_request(vars(good_args))

self.assertEqual(response['HyperParameters'], {})

def test_object_hyperparameters(self):
hyperparameters_str = '{"hp1": {"innerkey": "innerval"}}'

invalid_args = self.parser.parse_args(required_args + ['--hyperparameters', hyperparameters_str])
with self.assertRaises(Exception):
_utils.create_training_job_request(vars(invalid_args))

def test_vpc_configuration(self):
required_vpc_args = self.parser.parse_args(required_args + ['--vpc_security_group_ids', 'sg1,sg2', '--vpc_subnets', 'subnet1,subnet2'])
response = _utils.create_training_job_request(vars(required_vpc_args))

self.assertIn('VpcConfig', response)
self.assertIn('sg1', response['VpcConfig']['SecurityGroupIds'])
self.assertIn('sg2', response['VpcConfig']['SecurityGroupIds'])
self.assertIn('subnet1', response['VpcConfig']['Subnets'])
self.assertIn('subnet2', response['VpcConfig']['Subnets'])

def test_training_mode(self):
required_vpc_args = self.parser.parse_args(required_args + ['--training_input_mode', 'Pipe'])
response = _utils.create_training_job_request(vars(required_vpc_args))

self.assertEqual(response['AlgorithmSpecification']['TrainingInputMode'], 'Pipe')

def test_spot_bad_args(self):
no_max_wait_args = self.parser.parse_args(required_args + ['--spot_instance', 'True'])
no_checkpoint_args = self.parser.parse_args(required_args + ['--spot_instance', 'True', '--max_wait_time', '3600'])
Expand Down
4 changes: 2 additions & 2 deletions components/aws/sagemaker/train/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ resource_encryption_key | The AWS KMS key that Amazon SageMaker uses to encrypt
max_run_time | The maximum run time in seconds per training job | Yes | Int | ≤ 432000 (5 days) | 86400 (1 day) |
model_artifact_path | | No | String | | |
output_encryption_key | The AWS KMS key that Amazon SageMaker uses to encrypt the model artifacts | Yes | String | | |
vpc_security_group_ids | The VPC security group IDs, in the form sg-xxxxxxxx | Yes | String | | |
vpc_subnets | The ID of the subnets in the VPC to which you want to connect your hpo job | Yes | String | | |
vpc_security_group_ids | A comma-delimited list of security group IDs, in the form sg-xxxxxxxx | Yes | String | | |
vpc_subnets | A comma-delimited list of subnet IDs in the VPC to which you want to connect your hpo job | Yes | String | | |
network_isolation | Isolates the training container if true | No | Boolean | False, True | True |
traffic_encryption | Encrypts all communications between ML compute instances in distributed training if true | No | Boolean | False, True | False |
spot_instance | Use managed spot training if true | No | Boolean | False, True | False |
Expand Down