diff --git a/components/aws/sagemaker/common/_utils.py b/components/aws/sagemaker/common/_utils.py index 69ae36748ae..71d9ffe47ab 100644 --- a/components/aws/sagemaker/common/_utils.py +++ b/components/aws/sagemaker/common/_utils.py @@ -72,7 +72,9 @@ def get_component_version(): """Get component version from the first line of License file""" component_version = 'NULL' - with open('THIRD-PARTY-LICENSES.txt', 'r') as license_file: + # Get license file using known common directory + license_file_path = os.path.abspath(os.path.join(__cwd__, '../THIRD-PARTY-LICENSES.txt')) + with open(license_file_path, 'r') as license_file: version_match = re.search('Amazon SageMaker Components for Kubeflow Pipelines; version (([0-9]+[.])+[0-9]+)', license_file.readline()) if version_match is not None: @@ -99,7 +101,7 @@ def create_training_job_request(args): request['TrainingJobName'] = job_name request['RoleArn'] = args['role'] - request['HyperParameters'] = args['hyperparameters'] + request['HyperParameters'] = create_hyperparameters(args['hyperparameters']) request['AlgorithmSpecification']['TrainingInputMode'] = args['training_input_mode'] ### Update training image (for BYOC and built-in algorithms) or algorithm resource name @@ -137,8 +139,8 @@ def create_training_job_request(args): ### Update or pop VPC configs if args['vpc_security_group_ids'] and args['vpc_subnets']: - request['VpcConfig']['SecurityGroupIds'] = [args['vpc_security_group_ids']] - request['VpcConfig']['Subnets'] = [args['vpc_subnets']] + request['VpcConfig']['SecurityGroupIds'] = args['vpc_security_group_ids'].split(',') + request['VpcConfig']['Subnets'] = args['vpc_subnets'].split(',') else: request.pop('VpcConfig') @@ -191,7 +193,7 @@ def create_training_job(client, args): raise Exception(e.response['Error']['Message']) -def wait_for_training_job(client, training_job_name): +def wait_for_training_job(client, training_job_name, poll_interval=30): while(True): response = client.describe_training_job(TrainingJobName=training_job_name) status = response['TrainingJobStatus'] @@ -203,7 +205,7 @@ def wait_for_training_job(client, training_job_name): logging.info('Training failed with the following error: {}'.format(message)) raise Exception('Training job failed') logging.info("Training job is still in status: " + status) - time.sleep(30) + time.sleep(poll_interval) def get_model_artifacts_from_job(client, job_name): @@ -214,9 +216,9 @@ def get_model_artifacts_from_job(client, job_name): def get_image_from_job(client, job_name): info = client.describe_training_job(TrainingJobName=job_name) - try: + if 'TrainingImage' in info['AlgorithmSpecification']: image = info['AlgorithmSpecification']['TrainingImage'] - except: + else: algorithm_name = info['AlgorithmSpecification']['AlgorithmName'] image = client.describe_algorithm(AlgorithmName=algorithm_name)['TrainingSpecification']['TrainingImage'] @@ -273,8 +275,8 @@ def create_model_request(args): ### Update or pop VPC configs if args['vpc_security_group_ids'] and args['vpc_subnets']: - request['VpcConfig']['SecurityGroupIds'] = [args['vpc_security_group_ids']] - request['VpcConfig']['Subnets'] = [args['vpc_subnets']] + request['VpcConfig']['SecurityGroupIds'] = args['vpc_security_group_ids'].split(',') + request['VpcConfig']['Subnets'] = args['vpc_subnets'].split(',') else: request.pop('VpcConfig') @@ -494,7 +496,7 @@ def create_hyperparameter_tuning_job_request(args): request['HyperParameterTuningJobConfig']['ParameterRanges']['CategoricalParameterRanges'] = args['categorical_parameters'] request['HyperParameterTuningJobConfig']['TrainingJobEarlyStoppingType'] = args['early_stopping_type'] - request['TrainingJobDefinition']['StaticHyperParameters'] = args['static_parameters'] + request['TrainingJobDefinition']['StaticHyperParameters'] = create_hyperparameters(args['static_parameters']) request['TrainingJobDefinition']['AlgorithmSpecification']['TrainingInputMode'] = args['training_input_mode'] ### Update training image (for BYOC) or algorithm resource name @@ -532,8 +534,8 @@ def create_hyperparameter_tuning_job_request(args): ### Update or pop VPC configs if args['vpc_security_group_ids'] and args['vpc_subnets']: - request['TrainingJobDefinition']['VpcConfig']['SecurityGroupIds'] = [args['vpc_security_group_ids']] - request['TrainingJobDefinition']['VpcConfig']['Subnets'] = [args['vpc_subnets']] + request['TrainingJobDefinition']['VpcConfig']['SecurityGroupIds'] = args['vpc_security_group_ids'].split(',') + request['TrainingJobDefinition']['VpcConfig']['Subnets'] = args['vpc_subnets'].split(',') else: request['TrainingJobDefinition'].pop('VpcConfig') @@ -836,6 +838,14 @@ def get_labeling_job_outputs(client, labeling_job_name, auto_labeling): active_learning_model_arn = ' ' return output_manifest, active_learning_model_arn +def create_hyperparameters(hyperparam_args): + # Validate all values are strings + for key, value in hyperparam_args.items(): + if not isinstance(value, str): + raise Exception(f"Could not parse hyperparameters. Value for {key} was not a string.") + + return hyperparam_args + def enable_spot_instance_support(training_job_config, args): if args['spot_instance']: training_job_config['EnableManagedSpotTraining'] = args['spot_instance'] diff --git a/components/aws/sagemaker/hyperparameter_tuning/README.md b/components/aws/sagemaker/hyperparameter_tuning/README.md index 9a6a0090063..8f719b10a89 100644 --- a/components/aws/sagemaker/hyperparameter_tuning/README.md +++ b/components/aws/sagemaker/hyperparameter_tuning/README.md @@ -35,8 +35,8 @@ max_num_jobs | The maximum number of training jobs that a hyperparameter tuning max_parallel_jobs | The maximum number of concurrent training jobs that a hyperparameter tuning job can launch | No | No | Int | [1, 10] | | max_run_time | The maximum run time in seconds per training job | Yes | Yes | Int | ≤ 432000 (5 days) | 86400 (1 day) | resource_encryption_key | The AWS KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s) | Yes | Yes | String | | | -vpc_security_group_ids | The VPC security group IDs, in the form sg-xxxxxxxx | Yes | Yes | String | | | -vpc_subnets | The ID of the subnets in the VPC to which you want to connect your hpo job | Yes | Yes | String | | | +vpc_security_group_ids | A comma-delimited list of security group IDs, in the form sg-xxxxxxxx | Yes | Yes | String | | | +vpc_subnets | A comma-delimited list of subnet IDs in the VPC to which you want to connect your hpo job | Yes | Yes | String | | | network_isolation | Isolates the training container if true | Yes | No | Boolean | False, True | True | traffic_encryption | Encrypts all communications between ML compute instances in distributed training if true | Yes | No | Boolean | False, True | False | spot_instance | Use managed spot training if true | Yes | No | Boolean | False, True | False | diff --git a/components/aws/sagemaker/tests/unit_tests/pytest.ini b/components/aws/sagemaker/tests/unit_tests/pytest.ini new file mode 100644 index 00000000000..77749dcfd2b --- /dev/null +++ b/components/aws/sagemaker/tests/unit_tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = -rA \ No newline at end of file diff --git a/components/aws/sagemaker/tests/unit_tests/run_all_tests.sh b/components/aws/sagemaker/tests/unit_tests/run_all_tests.sh index 661d27da7c6..756716e0475 100755 --- a/components/aws/sagemaker/tests/unit_tests/run_all_tests.sh +++ b/components/aws/sagemaker/tests/unit_tests/run_all_tests.sh @@ -1,4 +1,4 @@ export PYTHONPATH=../../ coverage run -m pytest --ignore=tests/test_utils.py --junitxml ./unit_tests.log -coverage report --omit "*/usr/*,tests/*,*__init__*,*/Python/*" \ No newline at end of file +coverage report -m --omit "*/usr/*,tests/*,*__init__*,*/Python/*" \ No newline at end of file diff --git a/components/aws/sagemaker/tests/unit_tests/tests/test_train.py b/components/aws/sagemaker/tests/unit_tests/tests/test_train.py index 46389fca1f8..1993c72d248 100644 --- a/components/aws/sagemaker/tests/unit_tests/tests/test_train.py +++ b/components/aws/sagemaker/tests/unit_tests/tests/test_train.py @@ -1,7 +1,7 @@ import json import unittest -from unittest.mock import patch, Mock, MagicMock +from unittest.mock import patch, call, Mock, MagicMock, mock_open from botocore.exceptions import ClientError from datetime import datetime @@ -27,6 +27,272 @@ def setUpClass(cls): parser = train.create_parser() cls.parser = parser + def test_create_parser(self): + self.assertIsNotNone(self.parser) + + def test_main(self): + # Mock out all of utils except parser + train._utils = MagicMock() + train._utils.add_default_client_arguments = _utils.add_default_client_arguments + + # Set some static returns + train._utils.create_training_job.return_value = 'job-name' + train._utils.get_image_from_job.return_value = 'training-image' + train._utils.get_model_artifacts_from_job.return_value = 'model-artifacts' + + with patch('builtins.open', mock_open()) as file_open: + train.main(required_args) + + # Check if correct requests were created and triggered + train._utils.create_training_job.assert_called() + train._utils.wait_for_training_job.assert_called() + + # Check the file outputs + file_open.assert_has_calls([ + call('/tmp/model_artifact_url.txt', 'w'), + call('/tmp/job_name.txt', 'w'), + call('/tmp/training_image.txt', 'w') + ], any_order=True) + + file_open().write.assert_has_calls([ + call('model-artifacts'), + call('job-name'), + call('training-image'), + ], any_order=False) # Must be in the same order as called + + def test_create_training_job(self): + mock_client = MagicMock() + mock_args = self.parser.parse_args(required_args + ['--job_name', 'test-job']) + response = _utils.create_training_job(mock_client, vars(mock_args)) + + mock_client.create_training_job.assert_called_once_with(AlgorithmSpecification={'TrainingImage': 'test-image', + 'TrainingInputMode': 'File'}, EnableInterContainerTrafficEncryption=False, EnableManagedSpotTraining=False, + EnableNetworkIsolation=True, HyperParameters={}, InputDataConfig=[{'ChannelName': 'train', 'DataSource': + {'S3DataSource': {'S3Uri': 's3://fake-bucket/data', 'S3DataType': 'S3Prefix', 'S3DataDistributionType': + 'FullyReplicated'}}, 'ContentType': '', 'CompressionType': 'None', 'RecordWrapperType': 'None', 'InputMode': + 'File'}], OutputDataConfig={'KmsKeyId': '', 'S3OutputPath': 'test-path'}, ResourceConfig={'InstanceType': + 'ml.m4.xlarge', 'InstanceCount': 1, 'VolumeSizeInGB': 50, 'VolumeKmsKeyId': ''}, + RoleArn='arn:aws:iam::123456789012:user/Development/product_1234/*', StoppingCondition={'MaxRuntimeInSeconds': + 3600}, Tags=[], TrainingJobName='test-job') + self.assertEqual(response, 'test-job') + + def test_sagemaker_exception_in_create_training_job(self): + mock_client = MagicMock() + mock_exception = ClientError({"Error": {"Message": "SageMaker broke"}}, "create_training_job") + mock_client.create_training_job.side_effect = mock_exception + mock_args = self.parser.parse_args(required_args) + + with self.assertRaises(Exception): + response = _utils.create_training_job(mock_client, vars(mock_args)) + + def test_wait_for_training_job(self): + mock_client = MagicMock() + mock_client.describe_training_job.side_effect = [ + {"TrainingJobStatus": "Starting"}, + {"TrainingJobStatus": "InProgress"}, + {"TrainingJobStatus": "Downloading"}, + {"TrainingJobStatus": "Completed"}, + {"TrainingJobStatus": "Should not be called"} + ] + + _utils.wait_for_training_job(mock_client, 'training-job', 0) + self.assertEqual(mock_client.describe_training_job.call_count, 4) + + def test_wait_for_failed_job(self): + mock_client = MagicMock() + mock_client.describe_training_job.side_effect = [ + {"TrainingJobStatus": "Starting"}, + {"TrainingJobStatus": "InProgress"}, + {"TrainingJobStatus": "Downloading"}, + {"TrainingJobStatus": "Failed", "FailureReason": "Something broke lol"}, + {"TrainingJobStatus": "Should not be called"} + ] + + with self.assertRaises(Exception): + _utils.wait_for_training_job(mock_client, 'training-job', 0) + + self.assertEqual(mock_client.describe_training_job.call_count, 4) + + def test_get_model_artifacts_from_job(self): + mock_client = MagicMock() + mock_client.describe_training_job.return_value = {"ModelArtifacts": {"S3ModelArtifacts": "s3://path/"}} + + self.assertEqual(_utils.get_model_artifacts_from_job(mock_client, 'training-job'), 's3://path/') + + def test_get_image_from_defined_job(self): + mock_client = MagicMock() + mock_client.describe_training_job.return_value = {"AlgorithmSpecification": {"TrainingImage": "training-image-url"}} + + self.assertEqual(_utils.get_image_from_job(mock_client, 'training-job'), "training-image-url") + + def test_get_image_from_algorithm_job(self): + mock_client = MagicMock() + mock_client.describe_training_job.return_value = {"AlgorithmSpecification": {"AlgorithmName": "my-algorithm"}} + mock_client.describe_algorithm.return_value = {"TrainingSpecification": {"TrainingImage": "training-image-url"}} + + self.assertEqual(_utils.get_image_from_job(mock_client, 'training-job'), "training-image-url") + + def test_reasonable_required_args(self): + response = _utils.create_training_job_request(vars(self.parser.parse_args(required_args))) + + # Ensure all of the optional arguments have reasonable default values + self.assertFalse(response['EnableManagedSpotTraining']) + self.assertDictEqual(response['HyperParameters'], {}) + self.assertNotIn('VpcConfig', response) + self.assertNotIn('MetricDefinitions', response) + self.assertEqual(response['Tags'], []) + self.assertEqual(response['AlgorithmSpecification']['TrainingInputMode'], 'File') + self.assertEqual(response['OutputDataConfig']['S3OutputPath'], 'test-path') + + def test_metric_definitions(self): + metric_definition_args = self.parser.parse_args(required_args + ['--metric_definitions', '{"metric1": "regexval1", "metric2": "regexval2"}']) + response = _utils.create_training_job_request(vars(metric_definition_args)) + + self.assertIn('MetricDefinitions', response['AlgorithmSpecification']) + response_metric_definitions = response['AlgorithmSpecification']['MetricDefinitions'] + + self.assertEqual(response_metric_definitions, [{ + 'Name': "metric1", + 'Regex': "regexval1" + }, { + 'Name': "metric2", + 'Regex': "regexval2" + }]) + + def test_no_defined_image(self): + # Pass the image to pass the parser + no_image_args = required_args.copy() + image_index = no_image_args.index('--image') + # Cut out --image and it's associated value + no_image_args = no_image_args[:image_index] + no_image_args[image_index+2:] + + parsed_args = self.parser.parse_args(no_image_args) + + with self.assertRaises(Exception): + _utils.create_training_job_request(vars(parsed_args)) + + def test_first_party_algorithm(self): + algorithm_name_args = self.parser.parse_args(required_args + ['--algorithm_name', 'first-algorithm']) + + # Should not throw an exception + response = _utils.create_training_job_request(vars(algorithm_name_args)) + self.assertIn('TrainingImage', response['AlgorithmSpecification']) + self.assertNotIn('AlgorithmName', response['AlgorithmSpecification']) + + def test_known_algorithm_key(self): + # This passes an algorithm that is a known NAME of an algorithm + known_algorithm_args = required_args + ['--algorithm_name', 'seq2seq modeling'] + image_index = required_args.index('--image') + # Cut out --image and it's associated value + known_algorithm_args = known_algorithm_args[:image_index] + known_algorithm_args[image_index+2:] + + parsed_args = self.parser.parse_args(known_algorithm_args) + + # Patch get_image_uri + _utils.get_image_uri = MagicMock() + _utils.get_image_uri.return_value = "seq2seq-url" + + response = _utils.create_training_job_request(vars(parsed_args)) + + _utils.get_image_uri.assert_called_with('us-west-2', 'seq2seq') + self.assertEqual(response['AlgorithmSpecification']['TrainingImage'], "seq2seq-url") + + def test_known_algorithm_value(self): + # This passes an algorithm that is a known SageMaker algorithm name + known_algorithm_args = required_args + ['--algorithm_name', 'seq2seq'] + image_index = required_args.index('--image') + # Cut out --image and it's associated value + known_algorithm_args = known_algorithm_args[:image_index] + known_algorithm_args[image_index+2:] + + parsed_args = self.parser.parse_args(known_algorithm_args) + + # Patch get_image_uri + _utils.get_image_uri = MagicMock() + _utils.get_image_uri.return_value = "seq2seq-url" + + response = _utils.create_training_job_request(vars(parsed_args)) + + _utils.get_image_uri.assert_called_with('us-west-2', 'seq2seq') + self.assertEqual(response['AlgorithmSpecification']['TrainingImage'], "seq2seq-url") + + def test_unknown_algorithm(self): + known_algorithm_args = required_args + ['--algorithm_name', 'unknown algorithm'] + image_index = required_args.index('--image') + # Cut out --image and it's associated value + known_algorithm_args = known_algorithm_args[:image_index] + known_algorithm_args[image_index+2:] + + parsed_args = self.parser.parse_args(known_algorithm_args) + + # Patch get_image_uri + _utils.get_image_uri = MagicMock() + _utils.get_image_uri.return_value = "unknown-url" + + response = _utils.create_training_job_request(vars(parsed_args)) + + # Should just place the algorithm name in regardless + _utils.get_image_uri.assert_not_called() + self.assertEqual(response['AlgorithmSpecification']['AlgorithmName'], "unknown algorithm") + + def test_no_channels(self): + no_channels_args = required_args.copy() + channels_index = required_args.index('--channels') + # Replace the value after the flag with an empty list + no_channels_args[channels_index + 1] = '[]' + parsed_args = self.parser.parse_args(no_channels_args) + + with self.assertRaises(Exception): + _utils.create_training_job_request(vars(parsed_args)) + + def test_invalid_instance_type(self): + invalid_instance_args = required_args + ['--instance_type', 'invalid-instance'] + + with self.assertRaises(SystemExit): + self.parser.parse_args(invalid_instance_args) + + def test_valid_hyperparameters(self): + hyperparameters_str = '{"hp1": "val1", "hp2": "val2", "hp3": "val3"}' + + good_args = self.parser.parse_args(required_args + ['--hyperparameters', hyperparameters_str]) + response = _utils.create_training_job_request(vars(good_args)) + + self.assertIn('hp1', response['HyperParameters']) + self.assertIn('hp2', response['HyperParameters']) + self.assertIn('hp3', response['HyperParameters']) + self.assertEqual(response['HyperParameters']['hp1'], "val1") + self.assertEqual(response['HyperParameters']['hp2'], "val2") + self.assertEqual(response['HyperParameters']['hp3'], "val3") + + def test_empty_hyperparameters(self): + hyperparameters_str = '{}' + + good_args = self.parser.parse_args(required_args + ['--hyperparameters', hyperparameters_str]) + response = _utils.create_training_job_request(vars(good_args)) + + self.assertEqual(response['HyperParameters'], {}) + + def test_object_hyperparameters(self): + hyperparameters_str = '{"hp1": {"innerkey": "innerval"}}' + + invalid_args = self.parser.parse_args(required_args + ['--hyperparameters', hyperparameters_str]) + with self.assertRaises(Exception): + _utils.create_training_job_request(vars(invalid_args)) + + def test_vpc_configuration(self): + required_vpc_args = self.parser.parse_args(required_args + ['--vpc_security_group_ids', 'sg1,sg2', '--vpc_subnets', 'subnet1,subnet2']) + response = _utils.create_training_job_request(vars(required_vpc_args)) + + self.assertIn('VpcConfig', response) + self.assertIn('sg1', response['VpcConfig']['SecurityGroupIds']) + self.assertIn('sg2', response['VpcConfig']['SecurityGroupIds']) + self.assertIn('subnet1', response['VpcConfig']['Subnets']) + self.assertIn('subnet2', response['VpcConfig']['Subnets']) + + def test_training_mode(self): + required_vpc_args = self.parser.parse_args(required_args + ['--training_input_mode', 'Pipe']) + response = _utils.create_training_job_request(vars(required_vpc_args)) + + self.assertEqual(response['AlgorithmSpecification']['TrainingInputMode'], 'Pipe') + def test_spot_bad_args(self): no_max_wait_args = self.parser.parse_args(required_args + ['--spot_instance', 'True']) no_checkpoint_args = self.parser.parse_args(required_args + ['--spot_instance', 'True', '--max_wait_time', '3600']) @@ -54,7 +320,8 @@ def test_spot_local_path(self): self.assertEqual(response['CheckpointConfig']['S3Uri'], 's3://fake-uri/') self.assertEqual(response['CheckpointConfig']['LocalPath'], 'local-path') - def test_empty_string(self): - good_args = self.parser.parse_args(required_args + ['--spot_instance', 'True', '--max_wait_time', '3600', '--checkpoint_config', '{"S3Uri": "s3://fake-uri/"}']) - response = _utils.create_training_job_request(vars(good_args)) - test_utils.check_empty_string_values(response) \ No newline at end of file + def test_tags(self): + args = self.parser.parse_args(required_args + ['--tags', '{"key1": "val1", "key2": "val2"}']) + response = _utils.create_training_job_request(vars(args)) + self.assertIn({'Key': 'key1', 'Value': 'val1'}, response['Tags']) + self.assertIn({'Key': 'key2', 'Value': 'val2'}, response['Tags']) \ No newline at end of file diff --git a/components/aws/sagemaker/train/README.md b/components/aws/sagemaker/train/README.md index 5b9e68eeac0..a8455658655 100644 --- a/components/aws/sagemaker/train/README.md +++ b/components/aws/sagemaker/train/README.md @@ -28,8 +28,8 @@ resource_encryption_key | The AWS KMS key that Amazon SageMaker uses to encrypt max_run_time | The maximum run time in seconds per training job | Yes | Int | ≤ 432000 (5 days) | 86400 (1 day) | model_artifact_path | | No | String | | | output_encryption_key | The AWS KMS key that Amazon SageMaker uses to encrypt the model artifacts | Yes | String | | | -vpc_security_group_ids | The VPC security group IDs, in the form sg-xxxxxxxx | Yes | String | | | -vpc_subnets | The ID of the subnets in the VPC to which you want to connect your hpo job | Yes | String | | | +vpc_security_group_ids | A comma-delimited list of security group IDs, in the form sg-xxxxxxxx | Yes | String | | | +vpc_subnets | A comma-delimited list of subnet IDs in the VPC to which you want to connect your hpo job | Yes | String | | | network_isolation | Isolates the training container if true | No | Boolean | False, True | True | traffic_encryption | Encrypts all communications between ML compute instances in distributed training if true | No | Boolean | False, True | False | spot_instance | Use managed spot training if true | No | Boolean | False, True | False | diff --git a/components/aws/sagemaker/train/src/train.py b/components/aws/sagemaker/train/src/train.py index a6ae3bdbf11..2e01aaffb78 100644 --- a/components/aws/sagemaker/train/src/train.py +++ b/components/aws/sagemaker/train/src/train.py @@ -10,6 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys import argparse import logging @@ -21,7 +22,7 @@ def create_parser(): parser.add_argument('--job_name', type=str, required=False, help='The name of the training job.', default='') parser.add_argument('--role', type=str, required=True, help='The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf.') - parser.add_argument('--image', type=str, required=True, help='The registry path of the Docker image that contains the training algorithm.', default='') + parser.add_argument('--image', type=str, required=False, help='The registry path of the Docker image that contains the training algorithm.', default='') parser.add_argument('--algorithm_name', type=str, required=False, help='The name of the resource algorithm to use for the training job.', default='') parser.add_argument('--metric_definitions', type=_utils.yaml_or_json_str, required=False, help='The dictionary of name-regex pairs specify the metrics that the algorithm emits.', default={}) parser.add_argument('--training_input_mode', choices=['File', 'Pipe'], type=str, help='The input mode that the algorithm supports. File or Pipe.', default='File') @@ -53,7 +54,7 @@ def create_parser(): def main(argv=None): parser = create_parser() - args = parser.parse_args() + args = parser.parse_args(argv) logging.getLogger().setLevel(logging.INFO) client = _utils.get_sagemaker_client(args.region, args.endpoint_url) @@ -78,4 +79,4 @@ def main(argv=None): if __name__== "__main__": - main() + main(sys.argv[1:])