Skip to content

Commit

Permalink
Adds job_id parameter. (#3850)
Browse files Browse the repository at this point in the history
* Adds job_id parameter to ml_engine train component, which takes precedence over job_id generated from job_id_prefix.

* Restores ipynb config.

Co-authored-by: andrewleach <andrewleach@google.com>
  • Loading branch information
AndrewLeach and andrewleach0 committed Jun 1, 2020
1 parent b63dd3f commit 58f1d13
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@
from ._client import MLEngineClient
from .. import common as gcp_common

def create_job(project_id, job, job_id_prefix=None, wait_interval=30):
def create_job(project_id, job, job_id_prefix=None, job_id=None,
wait_interval=30):
"""Creates a MLEngine job.
Args:
project_id: the ID of the parent project of the job.
job: the payload of the job. Must have ``jobId``
and ``trainingInput`` or ``predictionInput``.
job_id_prefix: the prefix of the generated job id.
job_id: the ID to use for the created job. If set, it takes
precedence over the ID generated from job_id_prefix.
wait_interval: optional wait interval between calls
to get job status. Defaults to 30.
Expand All @@ -42,15 +45,16 @@ def create_job(project_id, job, job_id_prefix=None, wait_interval=30):
/tmp/kfp/output/ml_engine/job_id.txt: The ID of the created job.
/tmp/kfp/output/ml_engine/job_dir.txt: The `jobDir` of the training job.
"""
return CreateJobOp(project_id, job, job_id_prefix,
wait_interval).execute_and_wait()
return CreateJobOp(project_id, job, job_id_prefix, job_id, wait_interval
).execute_and_wait()

class CreateJobOp:
def __init__(self, project_id, job, job_id_prefix=None, wait_interval=30):
def __init__(self,project_id, job, job_id_prefix=None, job_id=None,
wait_interval=30):
self._ml = MLEngineClient()
self._project_id = project_id
self._job_id_prefix = job_id_prefix
self._job_id = None
self._job_id = job_id
self._job = job
self._wait_interval = wait_interval

Expand All @@ -61,7 +65,9 @@ def execute_and_wait(self):
return wait_for_job_done(self._ml, self._project_id, self._job_id, self._wait_interval)

def _set_job_id(self, context_id):
if self._job_id_prefix:
if self._job_id:
job_id = self._job_id
elif self._job_id_prefix:
job_id = self._job_id_prefix + context_id[:16]
else:
job_id = 'job_' + context_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
def train(project_id, python_module=None, package_uris=None,
region=None, args=None, job_dir=None, python_version=None,
runtime_version=None, master_image_uri=None, worker_image_uri=None,
training_input=None, job_id_prefix=None, wait_interval=30):
training_input=None, job_id_prefix=None, job_id=None, wait_interval=30):
"""Creates a MLEngine training job.
Args:
Expand Down Expand Up @@ -50,6 +50,8 @@ def train(project_id, python_module=None, package_uris=None,
This image must be in Container Registry.
training_input (dict): Input parameters to create a training job.
job_id_prefix (str): the prefix of the generated job id.
job_id (str): the ID to use for the created job. If set, it takes
precedence over the generated job ID.
wait_interval (int): optional wait interval between calls
to get job status. Defaults to 30.
"""
Expand Down Expand Up @@ -80,4 +82,4 @@ def train(project_id, python_module=None, package_uris=None,
job = {
'trainingInput': training_input
}
return create_job(project_id, job, job_id_prefix, wait_interval)
return create_job(project_id, job, job_id_prefix, job_id, wait_interval)
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,27 @@ def test_create_job_with_job_id_prefix_succeed(self, mock_mlengine_client,
'jobId': 'mock_job_ctx1'
}
)

# Verifies that an explicitly supplied job_id is used as the submitted jobId
# and that create_job returns the job reported by the ML Engine client.
# NOTE(review): the mock_* arguments are presumably injected by mock.patch
# decorators on the enclosing test class, which is outside this excerpt.
def test_create_job_with_job_id_succeed(self, mock_mlengine_client,
mock_kfp_context, mock_dump_json, mock_display):
# Stub the context id. The expected jobId below does not contain 'ctx1',
# so the final assertion only passes when the explicit job_id is used
# instead of an id generated from the context.
mock_kfp_context().__enter__().context_id.return_value = 'ctx1'
# Empty payload: the test supplies no jobId of its own.
job = {}
returned_job = {
'jobId': 'mock_job',
'state': 'SUCCEEDED'
}
# get_job is stubbed to report the job as SUCCEEDED; presumably this lets
# the wait-for-completion step return on its first poll -- TODO confirm.
mock_mlengine_client().get_job.return_value = (
returned_job)

# Pass job_id by keyword and leave job_id_prefix at its default.
result = create_job('mock_project', job, job_id='mock_job')

# The stubbed job is handed back to the caller unchanged ...
self.assertEqual(returned_job, result)
# ... and the job was created with the explicit id as its jobId.
mock_mlengine_client().create_job.assert_called_with(
project_id = 'mock_project',
job = {
'jobId': 'mock_job'
}
)

def test_execute_retry_job_success(self, mock_mlengine_client,
mock_kfp_context, mock_dump_json, mock_display):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@
CREATE_JOB_MODULE = 'kfp_component.google.ml_engine._train'

@mock.patch(CREATE_JOB_MODULE + '.create_job')
class TestCreateTraingingJob(unittest.TestCase):
class TestCreateTrainingJob(unittest.TestCase):

def test_train_succeed(self, mock_create_job):
train('proj-1', 'mock.module', ['gs://test/package'],
'region-1', args=['arg-1', 'arg-2'], job_dir='gs://test/job/dir',
training_input={
'runtimeVersion': '1.10',
'pythonVersion': '2.7'
}, job_id_prefix='job-', master_image_uri='tensorflow:latest',
}, job_id_prefix='job-', job_id='job-1',
master_image_uri='tensorflow:latest',
worker_image_uri='debian:latest')

mock_create_job.assert_called_with('proj-1', {
Expand All @@ -48,4 +49,4 @@ def test_train_succeed(self, mock_create_job):
'imageUri': 'debian:latest'
}
}
}, 'job-', 30)
}, 'job-', 'job-1', 30)
3 changes: 3 additions & 0 deletions components/gcp/ml_engine/train/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Use this component to submit a training job to AI Platform from a Kubeflow pipel
| worker_image_uri | The Docker image to run on the worker replica. This image must be in Container Registry. | Yes | GCRPath |- | None |
| training_input | The input parameters to create a training job. | Yes | Dict | [TrainingInput](https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#TrainingInput) | None |
| job_id_prefix | The prefix of the job ID that is generated. | Yes | String | - | None |
| job_id | The ID of the job to create. If set, it takes precedence over the generated job ID. | Yes | String | - | None |
| wait_interval | The number of seconds to wait between API calls to get the status of the job. | Yes | Integer | - | 30 |


Expand Down Expand Up @@ -179,6 +180,7 @@ def pipeline(
worker_image_uri = '',
training_input = '',
job_id_prefix = '',
job_id = '',
wait_interval = '30'):
task = mlengine_train_op(
project_id=project_id,
Expand All @@ -193,6 +195,7 @@ def pipeline(
worker_image_uri=worker_image_uri,
training_input=training_input,
job_id_prefix=job_id_prefix,
job_id=job_id,
wait_interval=wait_interval)
```

Expand Down
7 changes: 7 additions & 0 deletions components/gcp/ml_engine/train/component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ inputs:
description: 'The prefix of the generated job id.'
default: ''
type: String
- name: job_id
description: >-
The ID of the job to create, takes precedence over generated
job id if set.
default: ''
type: String
- name: wait_interval
description: >-
Optional. A time-interval to wait for between calls to get the job status.
Expand Down Expand Up @@ -119,6 +125,7 @@ implementation:
--worker_image_uri, {inputValue: worker_image_uri},
--training_input, {inputValue: training_input},
--job_id_prefix, {inputValue: job_id_prefix},
--job_id, {inputValue: job_id},
--wait_interval, {inputValue: wait_interval},
]
env:
Expand Down
7 changes: 5 additions & 2 deletions components/gcp/ml_engine/train/sample.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"| worker_image_uri | The Docker image to run on the worker replica. This image must be in Container Registry. | Yes | GCRPath | | None |\n",
"| training_input | The input parameters to create a training job. | Yes | Dict | [TrainingInput](https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#TrainingInput) | None |\n",
"| job_id_prefix | The prefix of the job ID that is generated. | Yes | String | | None |\n",
"| job_id | The ID of the job to create, takes precedence over generated job id if set. | Yes | String | - | None |\n",
"| wait_interval | The number of seconds to wait between API calls to get the status of the job. | Yes | Integer | | 30 |\n",
"\n",
"\n",
Expand Down Expand Up @@ -238,6 +239,7 @@
" worker_image_uri = '',\n",
" training_input = '',\n",
" job_id_prefix = '',\n",
" job_id = '',\n",
" wait_interval = '30'):\n",
" task = mlengine_train_op(\n",
" project_id=project_id, \n",
Expand All @@ -251,7 +253,8 @@
" master_image_uri=master_image_uri, \n",
" worker_image_uri=worker_image_uri, \n",
" training_input=training_input, \n",
" job_id_prefix=job_id_prefix, \n",
" job_id_prefix=job_id_prefix,\n",
" job_id=job_id,\n",
" wait_interval=wait_interval)"
]
},
Expand Down Expand Up @@ -354,4 +357,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

0 comments on commit 58f1d13

Please sign in to comment.