Skip to content

Commit

Permalink
update refreshing mechanism (#3981)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiaxiao Zheng authored and Bobgy committed Jul 2, 2020
1 parent 87ae181 commit efa5f1e
Showing 1 changed file with 42 additions and 11 deletions.
53 changes: 42 additions & 11 deletions sdk/python/kfp/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import string
import random
import time
import logging
import json
Expand All @@ -24,7 +22,7 @@
import warnings
import yaml
import zipfile
from datetime import datetime
import datetime
from typing import Mapping, Callable

import kfp
Expand All @@ -35,6 +33,12 @@

from kfp._auth import get_auth_token, get_gcp_access_token

# TTL of the access token associated with the client. This is needed because
# `gcloud auth print-access-token` generates a token with TTL=1 hour, after
# which the authentication expires. This TTL is needed for kfp.Client()
# initialized with host=<inverse proxy endpoint>.
# Set to 55 mins to provide some safe margin.
_GCP_ACCESS_TOKEN_TIMEOUT = datetime.timedelta(minutes=55)


def _add_generated_apis(target_struct, api_module, api_client):
Expand Down Expand Up @@ -108,6 +112,9 @@ def __init__(self, host=None, client_id=None, namespace='kubeflow', other_client
host = host or os.environ.get(KF_PIPELINES_ENDPOINT_ENV)
self._uihost = os.environ.get(KF_PIPELINES_UI_ENDPOINT_ENV, host)
config = self._load_config(host, client_id, namespace, other_client_id, other_client_secret, existing_token)
# Save the loaded API client configuration, as a reference if update is
# needed.
self._existing_config = config
api_client = kfp_server_api.api_client.ApiClient(config)
_add_generated_apis(self, kfp_server_api, api_client)
self._job_api = kfp_server_api.api.job_service_api.JobServiceApi(api_client)
Expand Down Expand Up @@ -150,10 +157,13 @@ def _load_config(self, host, client_id, namespace, other_client_id, other_client
#
if existing_token:
token = existing_token
self._is_refresh_token = False
elif client_id:
token = get_auth_token(client_id, other_client_id, other_client_secret)
self._is_refresh_token = True
elif self._is_inverse_proxy_host(host):
token = get_gcp_access_token()
self._is_refresh_token = False

if token:
config.api_key['authorization'] = token
Expand Down Expand Up @@ -226,6 +236,14 @@ def _load_context_setting_or_default(self):
self._context_setting = {
'namespace': '',
}

def _refresh_api_client_token(self):
"""Refreshes the existing token associated with the kfp_api_client."""
if getattr(self, '_is_refresh_token'):
return

new_token = get_gcp_access_token()
self._existing_config.api_key['authorization'] = new_token

def set_user_namespace(self, namespace):
"""Set user namespace into local context setting file.
Expand Down Expand Up @@ -531,7 +549,7 @@ def create_run_from_pipeline_func(self, pipeline_func: Callable, arguments: Mapp
'''
#TODO: Check arguments against the pipeline function
pipeline_name = pipeline_func.__name__
run_name = run_name or pipeline_name + ' ' + datetime.now().strftime('%Y-%m-%d %H-%M-%S')
run_name = run_name or pipeline_name + ' ' + datetime.datetime.now().strftime('%Y-%m-%d %H-%M-%S')
with tempfile.TemporaryDirectory() as tmpdir:
pipeline_package_path = os.path.join(tmpdir, 'pipeline.yaml')
compiler.Compiler().compile(pipeline_func, pipeline_package_path, pipeline_conf=pipeline_conf)
Expand All @@ -558,7 +576,7 @@ def __init__(self, client, run_info):
self.run_id = run_info.id

def wait_for_run_completion(self, timeout=None):
timeout = timeout or datetime.datetime.max - datetime.datetime.min
timeout = timeout or datetime.timedelta.max
return self._client.wait_for_run_completion(self.run_id, timeout)

def __repr__(self):
Expand All @@ -572,7 +590,9 @@ def __repr__(self):
import warnings
warnings.warn('Changing experiment name from "{}" to "{}".'.format(experiment_name, overridden_experiment_name))
experiment_name = overridden_experiment_name or 'Default'
run_name = run_name or pipeline_name + ' ' + datetime.now().strftime('%Y-%m-%d %H-%M-%S')
run_name = run_name or (pipeline_name + ' ' +
datetime.datetime.now().strftime(
'%Y-%m-%d %H-%M-%S'))
experiment = self.create_experiment(name=experiment_name, namespace=namespace)
run_info = self.run_pipeline(experiment.id, run_name, pipeline_file, arguments)
return RunPipelineResult(self, run_info)
Expand Down Expand Up @@ -639,19 +659,30 @@ def get_run(self, run_id):
return self._run_api.get_run(run_id=run_id)

def wait_for_run_completion(self, run_id, timeout):
"""Wait for a run to complete.
"""Waits for a run to complete.
Args:
run_id: run id, returned from run_pipeline.
timeout: timeout in seconds.
Returns:
A run detail object: Most important fields are run and pipeline_runtime
A run detail object: Most important fields are run and pipeline_runtime.
Raises:
TimeoutError: if the pipeline run failed to finish before the specified
timeout.
"""
status = 'Running:'
start_time = datetime.now()
while status is None or status.lower() not in ['succeeded', 'failed', 'skipped', 'error']:
start_time = datetime.datetime.now()
last_token_refresh_time = datetime.datetime.now()
while (status is None or
status.lower() not in ['succeeded', 'failed', 'skipped', 'error']):
# Refreshes the access token before it hits the TTL.
if (datetime.datetime.now() - last_token_refresh_time
> _GCP_ACCESS_TOKEN_TIMEOUT):
self._refresh_api_client_token()
last_token_refresh_time = datetime.datetime.now()

get_run_response = self._run_api.get_run(run_id=run_id)
status = get_run_response.run.status
elapsed_time = (datetime.now() - start_time).seconds
elapsed_time = (datetime.datetime.now() - start_time).seconds
logging.info('Waiting for the job to complete...')
if elapsed_time > timeout:
raise TimeoutError('Run timeout')
Expand Down

0 comments on commit efa5f1e

Please sign in to comment.