SDK - Refactoring - Split the K8sHelper class
One part was only used by the container builder and provided a higher-level API over the Kubernetes client.
The other part was used by the compiler and did not use the kubernetes library at all.
Ark-kun committed Oct 9, 2019
1 parent 6a8d105 commit 1c2a539
Showing 8 changed files with 172 additions and 158 deletions.
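
Only five of the eight changed files are rendered below. The net effect of the split is roughly the following layout, a sketch inferred from the hunks that follow (the module path of the new helper is taken from the relative import in _container_builder.py; the bodies are elided):

# kfp/compiler/_k8s_helper.py -- keeps only the pure-Python helpers, so the
# compiler no longer pulls in the kubernetes client library.
def sanitize_k8s_name(name): ...
def convert_k8s_obj_to_json(k8s_obj): ...

# kfp/containers/_k8s_job_helper.py -- the half that wraps the Kubernetes
# client to run pods; used only by the container builder.
class K8sJobHelper(object):
    def run_job(self, yaml_spec, timeout=600): ...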
4 changes: 2 additions & 2 deletions sdk/python/kfp/_client.py
@@ -29,7 +29,7 @@
 import kfp_server_api
 
 from kfp.compiler import compiler
-from kfp.compiler import _k8s_helper
+from kfp.compiler._k8s_helper import sanitize_k8s_name
 
 from kfp._auth import get_auth_token, get_gcp_access_token
 
@@ -298,7 +298,7 @@ def run_pipeline(self, experiment_id, job_name, pipeline_package_path=None, para
     if pipeline_package_path:
       pipeline_obj = self._extract_pipeline_yaml(pipeline_package_path)
       pipeline_json_string = json.dumps(pipeline_obj)
-    api_params = [kfp_server_api.ApiParameter(name=_k8s_helper.K8sHelper.sanitize_k8s_name(k), value=str(v))
+    api_params = [kfp_server_api.ApiParameter(name=sanitize_k8s_name(k), value=str(v))
                   for k,v in params.items()]
     key = kfp_server_api.models.ApiResourceKey(id=experiment_id,
                                                type=kfp_server_api.models.ApiResourceType.EXPERIMENT)
133 changes: 6 additions & 127 deletions sdk/python/kfp/compiler/_k8s_helper.py
@@ -12,140 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 from datetime import datetime
-from kubernetes import client as k8s_client
-from kubernetes import config
-import time
-import logging
 import re
 
 from .. import dsl
 
 
-class K8sHelper(object):
-  """ Kubernetes Helper """
-
-  def __init__(self):
-    if not self._configure_k8s():
-      raise Exception('K8sHelper __init__ failure')
-
-  def _configure_k8s(self):
-    try:
-      config.load_incluster_config()
-      logging.info('Initialized with in-cluster config.')
-    except:
-      logging.info('Cannot find in-cluster config, trying the local kubernetes config. ')
-      try:
-        config.load_kube_config()
-        logging.info('Found local kubernetes config. Initialized with kube_config.')
-      except:
-        raise RuntimeError('Forgot to run the gcloud command? Check out the link: \
-        https://cloud.google.com/kubernetes-engine/docs/how-to/cluster-access-for-kubectl for more information')
-    self._api_client = k8s_client.ApiClient()
-    self._corev1 = k8s_client.CoreV1Api(self._api_client)
-    return True
-
-  def _create_k8s_job(self, yaml_spec):
-    """ _create_k8s_job creates a kubernetes job based on the yaml spec """
-    pod = k8s_client.V1Pod(metadata=k8s_client.V1ObjectMeta(generate_name=yaml_spec['metadata']['generateName'],
-                                                            annotations=yaml_spec['metadata']['annotations']))
-    container = k8s_client.V1Container(name = yaml_spec['spec']['containers'][0]['name'],
-                                       image = yaml_spec['spec']['containers'][0]['image'],
-                                       args = yaml_spec['spec']['containers'][0]['args'],
-                                       volume_mounts = [k8s_client.V1VolumeMount(
-                                           name=yaml_spec['spec']['containers'][0]['volumeMounts'][0]['name'],
-                                           mount_path=yaml_spec['spec']['containers'][0]['volumeMounts'][0]['mountPath'],
-                                       )],
-                                       env = [k8s_client.V1EnvVar(
-                                           name=yaml_spec['spec']['containers'][0]['env'][0]['name'],
-                                           value=yaml_spec['spec']['containers'][0]['env'][0]['value'],
-                                       )])
-    pod.spec = k8s_client.V1PodSpec(restart_policy=yaml_spec['spec']['restartPolicy'],
-                                    containers = [container],
-                                    service_account_name=yaml_spec['spec']['serviceAccountName'],
-                                    volumes=[k8s_client.V1Volume(
-                                        name=yaml_spec['spec']['volumes'][0]['name'],
-                                        secret=k8s_client.V1SecretVolumeSource(
-                                            secret_name=yaml_spec['spec']['volumes'][0]['secret']['secretName'],
-                                        )
-                                    )])
-    try:
-      api_response = self._corev1.create_namespaced_pod(yaml_spec['metadata']['namespace'], pod)
-      return api_response.metadata.name, True
-    except k8s_client.rest.ApiException as e:
-      logging.exception("Exception when calling CoreV1Api->create_namespaced_pod: {}\n".format(str(e)))
-      return '', False
-
-  def _wait_for_k8s_job(self, pod_name, yaml_spec, timeout):
-    """ _wait_for_k8s_job waits for the job to complete """
-    status = 'running'
-    start_time = datetime.now()
-    while status in ['pending', 'running']:
-      # Pod pending values: https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1PodStatus.md
-      try:
-        api_response = self._corev1.read_namespaced_pod(pod_name, yaml_spec['metadata']['namespace'])
-        status = api_response.status.phase.lower()
-        time.sleep(5)
-        elapsed_time = (datetime.now() - start_time).seconds
-        logging.info('{} seconds: waiting for job to complete'.format(elapsed_time))
-        if elapsed_time > timeout:
-          logging.info('Kubernetes job timeout')
-          return False
-      except k8s_client.rest.ApiException as e:
-        logging.exception('Exception when calling CoreV1Api->read_namespaced_pod: {}\n'.format(str(e)))
-        return False
-    return status == 'succeeded'
-
-  def _delete_k8s_job(self, pod_name, yaml_spec):
-    """ _delete_k8s_job deletes a pod """
-    try:
-      api_response = self._corev1.delete_namespaced_pod(pod_name, yaml_spec['metadata']['namespace'], body=k8s_client.V1DeleteOptions())
-    except k8s_client.rest.ApiException as e:
-      logging.exception('Exception when calling CoreV1Api->delete_namespaced_pod: {}\n'.format(str(e)))
-
-  def _read_pod_log(self, pod_name, yaml_spec):
-    try:
-      api_response = self._corev1.read_namespaced_pod_log(pod_name, yaml_spec['metadata']['namespace'])
-    except k8s_client.rest.ApiException as e:
-      logging.exception('Exception when calling CoreV1Api->read_namespaced_pod_log: {}\n'.format(str(e)))
-      return False
-    return api_response
-
-  def _read_pod_status(self, pod_name, namespace):
-    try:
-      # Using read_namespaced_pod due to the following error: "pods \"kaniko-p2phh\" is forbidden: User \"system:serviceaccount:kubeflow:jupyter-notebook\" cannot get pods/status in the namespace \"kubeflow\""
-      #api_response = self._corev1.read_namespaced_pod_status(pod_name, namespace)
-      api_response = self._corev1.read_namespaced_pod(pod_name, namespace)
-    except k8s_client.rest.ApiException as e:
-      logging.exception('Exception when calling CoreV1Api->read_namespaced_pod_status: {}\n'.format(str(e)))
-      return False
-    return api_response
-
-  def run_job(self, yaml_spec, timeout=600):
-    """ run_job runs a kubernetes job and clean up afterwards """
-    pod_name, succ = self._create_k8s_job(yaml_spec)
-    namespace = yaml_spec['metadata']['namespace']
-    if not succ:
-      raise RuntimeError('Kubernetes job creation failed.')
-    # timeout in seconds
-    succ = self._wait_for_k8s_job(pod_name, yaml_spec, timeout)
-    if not succ:
-      logging.info('Kubernetes job failed.')
-      print(self._read_pod_log(pod_name, yaml_spec))
-      raise RuntimeError('Kubernetes job failed.')
-    status_obj = self._read_pod_status(pod_name, namespace)
-    self._delete_k8s_job(pod_name, yaml_spec)
-    return status_obj
 
-  @staticmethod
-  def sanitize_k8s_name(name):
+def sanitize_k8s_name(name):
     """From _make_kubernetes_name
     sanitize_k8s_name cleans and converts the names in the workflow.
     """
     return re.sub('-+', '-', re.sub('[^-0-9a-z]+', '-', name.lower())).lstrip('-').rstrip('-')
 
-  @staticmethod
-  def convert_k8s_obj_to_json(k8s_obj):
+
+def convert_k8s_obj_to_json(k8s_obj):
     """
     Builds a JSON K8s object.
@@ -170,10 +49,10 @@ def convert_k8s_obj_to_json(k8s_obj):
     elif isinstance(k8s_obj, PRIMITIVE_TYPES):
       return k8s_obj
     elif isinstance(k8s_obj, list):
-      return [K8sHelper.convert_k8s_obj_to_json(sub_obj)
+      return [convert_k8s_obj_to_json(sub_obj)
               for sub_obj in k8s_obj]
     elif isinstance(k8s_obj, tuple):
-      return tuple(K8sHelper.convert_k8s_obj_to_json(sub_obj)
+      return tuple(convert_k8s_obj_to_json(sub_obj)
                    for sub_obj in k8s_obj)
     elif isinstance(k8s_obj, (datetime, date)):
       return k8s_obj.isoformat()
@@ -196,5 +75,5 @@ def convert_k8s_obj_to_json(k8s_obj):
                 for attr, _ in iteritems(attr_types)
                 if getattr(k8s_obj, attr) is not None}
 
-    return {key: K8sHelper.convert_k8s_obj_to_json(val)
+    return {key: convert_k8s_obj_to_json(val)
             for key, val in iteritems(obj_dict)}
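
For reference, what the retained name helper does (hypothetical inputs; the behavior follows directly from the regexes above):

from kfp.compiler._k8s_helper import sanitize_k8s_name

# Lowercase, collapse runs of illegal characters into '-', trim edge dashes:
print(sanitize_k8s_name('My Pipeline Step #1'))  # my-pipeline-step-1
print(sanitize_k8s_name('--Weird__Name--'))      # weird-name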
18 changes: 9 additions & 9 deletions sdk/python/kfp/compiler/_op_to_template.py
@@ -18,7 +18,7 @@
 from collections import OrderedDict
 from typing import Union, List, Any, Callable, TypeVar, Dict
 
-from ._k8s_helper import K8sHelper
+from ._k8s_helper import convert_k8s_obj_to_json
 from .. import dsl
 from ..dsl._container_op import BaseOp
 from ..dsl._artifact_location import ArtifactLocation
@@ -73,15 +73,15 @@ def _process_obj(obj: Any, map_to_tmpl_var: dict):
     for key in obj.swagger_types.keys():
       setattr(obj, key, _process_obj(getattr(obj, key), map_to_tmpl_var))
     # return json representation of the k8s obj
-    return K8sHelper.convert_k8s_obj_to_json(obj)
+    return convert_k8s_obj_to_json(obj)
 
   # k8s objects (generated from openapi)
   if hasattr(obj, 'openapi_types') and isinstance(obj.openapi_types, dict):
     # process everything inside recursively
     for key in obj.openapi_types.keys():
       setattr(obj, key, _process_obj(getattr(obj, key), map_to_tmpl_var))
     # return json representation of the k8s obj
-    return K8sHelper.convert_k8s_obj_to_json(obj)
+    return convert_k8s_obj_to_json(obj)
 
   # do nothing
   return obj
@@ -194,7 +194,7 @@ def _op_to_template(op: BaseOp):
   output_artifact_paths.update(sorted(((param.full_name, processed_op.file_outputs[param.name]) for param in processed_op.outputs.values()), key=lambda x: x[0]))
 
   output_artifacts = [
-    K8sHelper.convert_k8s_obj_to_json(
+    convert_k8s_obj_to_json(
       ArtifactLocation.create_artifact_for_s3(
         op.artifact_location,
         name=name,
@@ -206,7 +206,7 @@
   # workflow template
   template = {
     'name': processed_op.name,
-    'container': K8sHelper.convert_k8s_obj_to_json(
+    'container': convert_k8s_obj_to_json(
       processed_op.container
     )
   }
@@ -216,12 +216,12 @@
 
     # workflow template
     processed_op.resource["manifest"] = yaml.dump(
-      K8sHelper.convert_k8s_obj_to_json(processed_op.k8s_resource),
+      convert_k8s_obj_to_json(processed_op.k8s_resource),
       default_flow_style=False
     )
     template = {
      'name': processed_op.name,
-      'resource': K8sHelper.convert_k8s_obj_to_json(
+      'resource': convert_k8s_obj_to_json(
        processed_op.resource
      )
    }
@@ -252,7 +252,7 @@ def _op_to_template(op: BaseOp):
 
   # affinity
   if processed_op.affinity:
-    template['affinity'] = K8sHelper.convert_k8s_obj_to_json(processed_op.affinity)
+    template['affinity'] = convert_k8s_obj_to_json(processed_op.affinity)
 
   # metadata
   if processed_op.pod_annotations or processed_op.pod_labels:
@@ -279,7 +279,7 @@
 
   # volumes
   if processed_op.volumes:
-    template['volumes'] = [K8sHelper.convert_k8s_obj_to_json(volume) for volume in processed_op.volumes]
+    template['volumes'] = [convert_k8s_obj_to_json(volume) for volume in processed_op.volumes]
     template['volumes'].sort(key=lambda x: x['name'])
 
   # Display name
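
All of these call sites follow the same pattern: build a typed Kubernetes model object, then serialize it into the plain dict that goes into the Argo workflow template. A small illustration with a hypothetical volume (assuming the kubernetes package is installed; the camelCase keys come from the models' attribute_map, which convert_k8s_obj_to_json consults):

from kubernetes import client as k8s_client
from kfp.compiler._k8s_helper import convert_k8s_obj_to_json

volume = k8s_client.V1Volume(
    name='gcp-credentials',
    secret=k8s_client.V1SecretVolumeSource(secret_name='user-gcp-sa'))
# Snake_case attributes come back as the camelCase keys Argo/K8s expect,
# and unset (None) attributes are dropped:
print(convert_k8s_obj_to_json(volume))
# {'name': 'gcp-credentials', 'secret': {'secretName': 'user-gcp-sa'}}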
26 changes: 13 additions & 13 deletions sdk/python/kfp/compiler/compiler.py
@@ -23,7 +23,7 @@
 from kfp.dsl import _for_loop
 
 from .. import dsl
-from ._k8s_helper import K8sHelper
+from ._k8s_helper import convert_k8s_obj_to_json, sanitize_k8s_name
 from ._op_to_template import _op_to_template
 from ._default_transformers import add_pod_env
 
@@ -639,7 +639,7 @@ def _create_pipeline_workflow(self, args, pipeline, op_transformers=None, pipeli
     if len(pipeline_conf.image_pull_secrets) > 0:
       image_pull_secrets = []
       for image_pull_secret in pipeline_conf.image_pull_secrets:
-        image_pull_secrets.append(K8sHelper.convert_k8s_obj_to_json(image_pull_secret))
+        image_pull_secrets.append(convert_k8s_obj_to_json(image_pull_secret))
       workflow['spec']['imagePullSecrets'] = image_pull_secrets
 
     if pipeline_conf.timeout:
@@ -684,26 +684,26 @@ def _sanitize_and_inject_artifact(self, pipeline: dsl.Pipeline, pipeline_conf=No
       if artifact_location and not op.artifact_location:
         op.artifact_location = artifact_location
 
-      sanitized_name = K8sHelper.sanitize_k8s_name(op.name)
+      sanitized_name = sanitize_k8s_name(op.name)
       op.name = sanitized_name
       for param in op.outputs.values():
-        param.name = K8sHelper.sanitize_k8s_name(param.name)
+        param.name = sanitize_k8s_name(param.name)
         if param.op_name:
-          param.op_name = K8sHelper.sanitize_k8s_name(param.op_name)
+          param.op_name = sanitize_k8s_name(param.op_name)
       if op.output is not None:
-        op.output.name = K8sHelper.sanitize_k8s_name(op.output.name)
-        op.output.op_name = K8sHelper.sanitize_k8s_name(op.output.op_name)
+        op.output.name = sanitize_k8s_name(op.output.name)
+        op.output.op_name = sanitize_k8s_name(op.output.op_name)
       if op.dependent_names:
-        op.dependent_names = [K8sHelper.sanitize_k8s_name(name) for name in op.dependent_names]
+        op.dependent_names = [sanitize_k8s_name(name) for name in op.dependent_names]
       if isinstance(op, dsl.ContainerOp) and op.file_outputs is not None:
         sanitized_file_outputs = {}
         for key in op.file_outputs.keys():
-          sanitized_file_outputs[K8sHelper.sanitize_k8s_name(key)] = op.file_outputs[key]
+          sanitized_file_outputs[sanitize_k8s_name(key)] = op.file_outputs[key]
         op.file_outputs = sanitized_file_outputs
       elif isinstance(op, dsl.ResourceOp) and op.attribute_outputs is not None:
         sanitized_attribute_outputs = {}
         for key in op.attribute_outputs.keys():
-          sanitized_attribute_outputs[K8sHelper.sanitize_k8s_name(key)] = \
+          sanitized_attribute_outputs[sanitize_k8s_name(key)] = \
             op.attribute_outputs[key]
         op.attribute_outputs = sanitized_attribute_outputs
       sanitized_ops[sanitized_name] = op
@@ -725,7 +725,7 @@ def _create_workflow(self,
     pipeline_meta = _extract_pipeline_metadata(pipeline_func)
     pipeline_meta.name = pipeline_name or pipeline_meta.name
     pipeline_meta.description = pipeline_description or pipeline_meta.description
-    pipeline_name = K8sHelper.sanitize_k8s_name(pipeline_meta.name)
+    pipeline_name = sanitize_k8s_name(pipeline_meta.name)
 
     # Need to first clear the default value of dsl.PipelineParams. Otherwise, it
     # will be resolved immediately in place when being to each component.
@@ -746,7 +746,7 @@ def _create_workflow(self,
         if arg_name == input.name:
           arg_type = input.type
           break
-      args_list.append(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(arg_name), param_type=arg_type))
+      args_list.append(dsl.PipelineParam(sanitize_k8s_name(arg_name), param_type=arg_type))
 
     with dsl.Pipeline(pipeline_name) as dsl_pipeline:
       pipeline_func(*args_list)
@@ -759,7 +759,7 @@ def _create_workflow(self,
     # Fill in the default values.
     args_list_with_defaults = []
     if pipeline_meta.inputs:
-      args_list_with_defaults = [dsl.PipelineParam(K8sHelper.sanitize_k8s_name(arg_name))
+      args_list_with_defaults = [dsl.PipelineParam(sanitize_k8s_name(arg_name))
                                  for arg_name in argspec.args]
      if argspec.defaults:
        for arg, default in zip(reversed(args_list_with_defaults), reversed(argspec.defaults)):
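
The sanitization above is mechanical renaming; for a hypothetical op it amounts to:

from kfp.compiler._k8s_helper import sanitize_k8s_name

# Hypothetical ContainerOp.file_outputs before the compiler sanitizes the keys:
file_outputs = {'Output Path': '/tmp/out.txt'}
sanitized = {sanitize_k8s_name(k): v for k, v in file_outputs.items()}
print(sanitized)  # {'output-path': '/tmp/out.txt'}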
4 changes: 2 additions & 2 deletions sdk/python/kfp/containers/_container_builder.py
@@ -181,8 +181,8 @@ def build(self, local_dir, docker_filename : str = 'Dockerfile', target_image=No
                                              docker_filename=docker_filename,
                                              target_image=target_image)
     logging.info('Start a kaniko job for build.')
-    from ..compiler._k8s_helper import K8sHelper
-    k8s_helper = K8sHelper()
+    from ._k8s_job_helper import K8sJobHelper
+    k8s_helper = K8sJobHelper()
     result_pod_obj = k8s_helper.run_job(kaniko_spec, timeout)
     logging.info('Kaniko job complete.')
 
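
The moved class keeps the same public surface, so the builder still just constructs a helper and calls run_job. A minimal sketch of the spec shape run_job reads (field paths taken from the deleted _create_k8s_job above; every value here is a placeholder rather than the builder's real kaniko spec, and the import path is inferred from the relative import in this file):

from kfp.containers._k8s_job_helper import K8sJobHelper

yaml_spec = {
    'metadata': {
        'generateName': 'kaniko-',   # pod name prefix
        'annotations': {},           # placeholder
        'namespace': 'kubeflow',     # placeholder
    },
    'spec': {
        'restartPolicy': 'Never',
        'serviceAccountName': 'default',  # placeholder
        'containers': [{
            'name': 'kaniko',
            'image': 'gcr.io/kaniko-project/executor:latest',  # placeholder
            'args': ['--dockerfile=Dockerfile'],               # placeholder
            'volumeMounts': [{'name': 'secret', 'mountPath': '/secret'}],
            'env': [{'name': 'GOOGLE_APPLICATION_CREDENTIALS',
                     'value': '/secret/key.json'}],
        }],
        'volumes': [{'name': 'secret',
                     'secret': {'secretName': 'user-gcp-sa'}}],
    },
}

helper = K8sJobHelper()  # raises unless an in-cluster or local kube config loads
result_pod = helper.run_job(yaml_spec, timeout=600)  # create, wait, log, delete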