Skip to content

Commit

Permalink
Adding HPO unit test (#3791)
Browse files Browse the repository at this point in the history
* Adding HPO unit test

* Adding best training job

* Addressing comment
  • Loading branch information
goswamig authored May 23, 2020
1 parent 43ce639 commit bbe598d
Show file tree
Hide file tree
Showing 4 changed files with 357 additions and 10 deletions.
8 changes: 4 additions & 4 deletions components/aws/sagemaker/common/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def create_hyperparameter_tuning_job(client, args):
"""Create a Sagemaker HPO job"""
request = create_hyperparameter_tuning_job_request(args)
try:
job_arn = client.create_hyper_parameter_tuning_job(**request)
client.create_hyper_parameter_tuning_job(**request)
hpo_job_name = request['HyperParameterTuningJobName']
logging.info("Created Hyperparameter Training Job with name: " + hpo_job_name)
logging.info("HPO job in SageMaker: https://{}.console.aws.amazon.com/sagemaker/home?region={}#/hyper-tuning-jobs/{}"
Expand All @@ -604,7 +604,7 @@ def create_hyperparameter_tuning_job(client, args):
raise Exception(e.response['Error']['Message'])


def wait_for_hyperparameter_training_job(client, hpo_job_name):
def wait_for_hyperparameter_training_job(client, hpo_job_name, poll_interval=30):
### Wait until the job finishes
while(True):
response = client.describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=hpo_job_name)
Expand All @@ -617,7 +617,7 @@ def wait_for_hyperparameter_training_job(client, hpo_job_name):
logging.error('Hyperparameter tuning failed with the following error: {}'.format(message))
raise Exception('Hyperparameter tuning job failed')
logging.info("Hyperparameter tuning job is still in status: " + status)
time.sleep(30)
time.sleep(poll_interval)


def get_best_training_job_and_hyperparameters(client, hpo_job_name):
Expand Down Expand Up @@ -880,4 +880,4 @@ def yaml_or_json_str(str):
def str_to_bool(str):
# This distutils function returns an integer representation of the boolean
# rather than a True/False value. This simply hard casts it.
return bool(strtobool(str))
return bool(strtobool(str))
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def create_parser():

parser.add_argument('--job_name', type=str, required=False, help='The name of the tuning job. Must be unique within the same AWS account and AWS region.')
parser.add_argument('--role', type=str, required=True, help='The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf.')
parser.add_argument('--image', type=str, required=True, help='The registry path of the Docker image that contains the training algorithm.', default='')
parser.add_argument('--image', type=str, required=False, help='The registry path of the Docker image that contains the training algorithm.', default='')
parser.add_argument('--algorithm_name', type=str, required=False, help='The name of the resource algorithm to use for the hyperparameter tuning job.', default='')
parser.add_argument('--training_input_mode', choices=['File', 'Pipe'], type=str, required=False, help='The input mode that the algorithm supports. File or Pipe.', default='File')
parser.add_argument('--metric_definitions', type=_utils.yaml_or_json_str, required=False, help='The dictionary of name-regex pairs specify the metrics that the algorithm emits.', default={})
Expand Down Expand Up @@ -63,7 +63,7 @@ def create_parser():

def main(argv=None):
parser = create_parser()
args = parser.parse_args()
args = parser.parse_args(argv)

logging.getLogger().setLevel(logging.INFO)
client = _utils.get_sagemaker_client(args.region)
Expand All @@ -90,4 +90,4 @@ def main(argv=None):


if __name__== "__main__":
main()
main(sys.argv[1:])
Loading

0 comments on commit bbe598d

Please sign in to comment.