chore(sdk.v2): Migrate to the new IR with subdag support (kubeflow#4981)
* migrate to new ir

* address review comments

* fix bugs

* fix pipeline parameters and tests

* fix components import

* fix typo
chensun committed Feb 3, 2021
1 parent 9e0b9aa commit 91c5a93
Showing 21 changed files with 1,658 additions and 1,003 deletions.
113 changes: 65 additions & 48 deletions sdk/python/kfp/v2/compiler/compiler.py
@@ -23,9 +23,11 @@

import kfp
from kfp.compiler._k8s_helper import sanitize_k8s_name
from kfp.components import _python_op
from kfp.v2 import dsl
from kfp.v2.compiler import compiler_utils
from kfp.components import _python_op
from kfp.v2.dsl import component_spec as dsl_component_spec
from kfp.v2.dsl import dsl_utils
from kfp.v2.dsl import importer_node
from kfp.v2.dsl import type_utils
from kfp.pipeline_spec import pipeline_spec_pb2
@@ -75,68 +77,95 @@ def _create_pipeline_spec(
"""
compiler_utils.validate_pipeline_name(pipeline.name)

pipeline_spec = pipeline_spec_pb2.PipelineSpec(
runtime_parameters=compiler_utils.build_runtime_parameter_spec(args))
pipeline_spec = pipeline_spec_pb2.PipelineSpec()

pipeline_spec.pipeline_info.name = pipeline.name
pipeline_spec.sdk_version = 'kfp-{}'.format(kfp.__version__)
pipeline_spec.schema_version = 'v2alpha1'
# Schema version 2.0.0 is required for kfp-pipeline-spec>0.1.3.1
pipeline_spec.schema_version = '2.0.0'

pipeline_spec.root.CopyFrom(
dsl_component_spec.build_root_spec_from_pipeline_params(args))

deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()
importer_tasks = []

for op in pipeline.ops.values():
component_spec = op._metadata
task = pipeline_spec.tasks.add()
task.CopyFrom(op.task_spec)
deployment_config.executors[task.executor_label].container.CopyFrom(
task_name = op.task_spec.task_info.name
component_name = op.task_spec.component_ref.name
executor_label = op.component_spec.executor_label

pipeline_spec.root.dag.tasks[task_name].CopyFrom(op.task_spec)
pipeline_spec.components[component_name].CopyFrom(op.component_spec)
deployment_config.executors[executor_label].container.CopyFrom(
op.container_spec)

task = pipeline_spec.root.dag.tasks[task_name]
# A task may have explicit dependency on other tasks even though they may
# not have inputs/outputs dependency. e.g.: op2.after(op1)
if op.dependent_names:
op.dependent_names = [
dsl_utils.sanitize_task_name(name) for name in op.dependent_names
]
task.dependent_tasks.extend(op.dependent_names)

# Check if need to insert importer node
for input_name in task.inputs.artifacts:
if not task.inputs.artifacts[input_name].producer_task:
if not task.inputs.artifacts[
input_name].task_output_artifact.producer_task:
type_schema = type_utils.get_input_artifact_type_schema(
input_name, component_spec.inputs)

importer_task = importer_node.build_importer_task_spec(
dependent_task=task,
input_name, op._metadata.inputs)

importer_name = importer_node.generate_importer_base_name(
dependent_task_name=task_name, input_name=input_name)
importer_task_spec = importer_node.build_importer_task_spec(
importer_name)
importer_comp_spec = importer_node.build_importer_component_spec(
importer_base_name=importer_name,
input_name=input_name,
input_type_schema=type_schema)
importer_tasks.append(importer_task)
importer_task_name = importer_task_spec.task_info.name
importer_comp_name = importer_task_spec.component_ref.name
importer_exec_label = importer_comp_spec.executor_label
pipeline_spec.root.dag.tasks[importer_task_name].CopyFrom(
importer_task_spec)
pipeline_spec.components[importer_comp_name].CopyFrom(
importer_comp_spec)

task.inputs.artifacts[
input_name].producer_task = importer_task.task_info.name
input_name].task_output_artifact.producer_task = (
importer_task_name)
task.inputs.artifacts[
input_name].output_artifact_key = importer_node.OUTPUT_KEY
input_name].task_output_artifact.output_artifact_key = (
importer_node.OUTPUT_KEY)

# Retrieve the pre-built importer spec
importer_spec = op.importer_spec[input_name]
deployment_config.executors[
importer_task.executor_label].importer.CopyFrom(importer_spec)
importer_spec = op.importer_specs[input_name]
deployment_config.executors[importer_exec_label].importer.CopyFrom(
importer_spec)

pipeline_spec.deployment_config.Pack(deployment_config)
pipeline_spec.tasks.extend(importer_tasks)
pipeline_spec.deployment_spec.update(
json_format.MessageToDict(deployment_config))

return pipeline_spec

def _create_pipeline(
self,
pipeline_func: Callable[..., Any],
output_directory: str,
pipeline_name: Optional[str] = None,
) -> pipeline_spec_pb2.PipelineSpec:
pipeline_parameters_override: Optional[Mapping[str, Any]] = None,
) -> pipeline_spec_pb2.PipelineJob:
"""Creates a pipeline instance and constructs the pipeline spec from it.
Args:
pipeline_func: Pipeline function with @dsl.pipeline decorator.
pipeline_name: The name of the pipeline. Optional.
output_directory: The root of the pipeline outputs.
pipeline_parameters_override: The mapping from parameter names to values.
Optional.
Returns:
The IR representation (pipeline spec) of the pipeline.
A PipelineJob proto representing the compiled pipeline.
"""

# Create the arg list with no default values and call pipeline function.
@@ -174,26 +203,14 @@ def _create_pipeline(
dsl_pipeline,
)

return pipeline_spec

def _create_pipeline_job(
self,
pipeline_spec: pipeline_spec_pb2.PipelineSpec,
pipeline_root: str,
pipeline_parameters: Optional[Mapping[str, Any]] = None,
) -> pipeline_spec_pb2.PipelineJob:
"""Creates the pipeline job spec object.
Args:
pipeline_spec: The pipeline spec object.
pipeline_root: The root of the pipeline outputs.
pipeline_parameters: The mapping from parameter names to values. Optional.
Returns:
A PipelineJob proto representing the compiled pipeline.
"""
pipeline_parameters = {
arg.name: arg.value for arg in args_list_with_defaults
}
# Update pipeline parameters override if there were any.
pipeline_parameters.update(pipeline_parameters_override or {})
runtime_config = compiler_utils.build_runtime_config_spec(
pipeline_root=pipeline_root, pipeline_parameters=pipeline_parameters)
output_directory=output_directory,
pipeline_parameters=pipeline_parameters)
pipeline_job = pipeline_spec_pb2.PipelineJob(runtime_config=runtime_config)
pipeline_job.pipeline_spec.update(json_format.MessageToDict(pipeline_spec))

@@ -220,11 +237,11 @@ def compile(self,
type_check_old_value = kfp.TYPE_CHECK
try:
kfp.TYPE_CHECK = type_check
pipeline = self._create_pipeline(pipeline_func, pipeline_name)
pipeline_job = self._create_pipeline_job(
pipeline_spec=pipeline,
pipeline_root=pipeline_root,
pipeline_parameters=pipeline_parameters)
pipeline_job = self._create_pipeline(
pipeline_func=pipeline_func,
output_directory=pipeline_root,
pipeline_name=pipeline_name,
pipeline_parameters_override=pipeline_parameters)
self._write_pipeline(pipeline_job, output_path)
finally:
kfp.TYPE_CHECK = type_check_old_value
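To make the new IR layout easier to follow outside the diff, below is a minimal, self-contained sketch of the structure _create_pipeline_spec now populates: each op contributes a task under root.dag.tasks, a component under components, and an executor in the deployment config, which is embedded as a Struct via deployment_spec; the result is then wrapped in a PipelineJob. Message and field names follow kfp.pipeline_spec.pipeline_spec_pb2 as used in the diff; the pipeline, task, component, executor, and image names are illustrative only.

from google.protobuf import json_format
from kfp.pipeline_spec import pipeline_spec_pb2

pipeline_spec = pipeline_spec_pb2.PipelineSpec()
pipeline_spec.pipeline_info.name = 'my-pipeline'  # illustrative name
pipeline_spec.sdk_version = 'kfp-1.4.0'           # illustrative version
pipeline_spec.schema_version = '2.0.0'

deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()

# One op contributes one task in the root DAG, one component spec, one executor.
task = pipeline_spec.root.dag.tasks['task-train']  # map access creates the entry
task.task_info.name = 'task-train'
task.component_ref.name = 'comp-train'

component = pipeline_spec.components['comp-train']
component.executor_label = 'exec-train'

deployment_config.executors['exec-train'].container.image = 'gcr.io/project/image'

# The deployment config is stored as a Struct field rather than packed as Any.
pipeline_spec.deployment_spec.update(
    json_format.MessageToDict(deployment_config))

# The compiled artifact is a PipelineJob wrapping the spec plus runtime config.
pipeline_job = pipeline_spec_pb2.PipelineJob()
pipeline_job.pipeline_spec.update(json_format.MessageToDict(pipeline_spec))
pipeline_job.runtime_config.gcs_output_directory = 'gs://bucket/outputs'

print(json_format.MessageToJson(pipeline_job))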
54 changes: 8 additions & 46 deletions sdk/python/kfp/v2/compiler/compiler_utils.py
@@ -20,62 +20,22 @@
from kfp.pipeline_spec import pipeline_spec_pb2


def build_runtime_parameter_spec(
pipeline_params: List[dsl.PipelineParam]
) -> Mapping[str, pipeline_spec_pb2.PipelineSpec.RuntimeParameter]:
"""Converts pipeine parameters to runtime parameters mapping.
Args:
pipeline_params: The list of pipeline parameters.
Returns:
A map of pipeline parameter name to its runtime parameter message.
"""

def to_message(param: dsl.PipelineParam):
result = pipeline_spec_pb2.PipelineSpec.RuntimeParameter()
if param.param_type == 'Integer' or (param.param_type is None and
isinstance(param.value, int)):

result.type = pipeline_spec_pb2.PrimitiveType.INT
if param.value is not None:
result.default_value.int_value = int(param.value)
elif param.param_type == 'Float' or (param.param_type is None and
isinstance(param.value, float)):
result.type = pipeline_spec_pb2.PrimitiveType.DOUBLE
if param.value is not None:
result.default_value.double_value = float(param.value)
elif param.param_type == 'String' or param.param_type is None:
result.type = pipeline_spec_pb2.PrimitiveType.STRING
if param.value is not None:
result.default_value.string_value = str(param.value)
else:
raise TypeError('Unsupported type "{}" for argument "{}".'.format(
param.param_type, param.name))
return result

return {param.name: to_message(param) for param in pipeline_params}


def build_runtime_config_spec(
pipeline_root: str,
output_directory: str,
pipeline_parameters: Optional[Mapping[str, Any]] = None,
) -> pipeline_spec_pb2.PipelineJob.RuntimeConfig:
"""Converts pipeine parameters to runtime parameters mapping.
Args:
pipeline_root: The root of pipeline outputs.
output_directory: The root of pipeline outputs.
pipeline_parameters: The mapping from parameter names to values. Optional.
Returns:
A pipeline job RuntimeConfig object.
"""

def _get_value(
value: Optional[Union[int, float,
str]]) -> Optional[pipeline_spec_pb2.Value]:
if value is None:
return None
def _get_value(value: Union[int, float, str]) -> pipeline_spec_pb2.Value:
assert value is not None, 'None values should be filtered out.'

result = pipeline_spec_pb2.Value()
if isinstance(value, int):
@@ -91,8 +51,10 @@ def _get_value(

parameter_values = pipeline_parameters or {}
return pipeline_spec_pb2.PipelineJob.RuntimeConfig(
gcs_output_directory=pipeline_root,
parameters={k: _get_value(v) for k, v in parameter_values.items()})
gcs_output_directory=output_directory,
parameters={
k: _get_value(v) for k, v in parameter_values.items() if v is not None
})


def validate_pipeline_name(name: str) -> None:
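For context on the simplified parameter handling above: with the runtime-parameter branch removed, the only conversion left is from plain Python scalars to pipeline_spec_pb2.Value, and None-valued parameters are filtered out before conversion. A rough standalone sketch of that behavior follows — the helper name to_value and the sample params dict are made up for illustration, and the sketch assumes Value exposes int_value, double_value, and string_value as used elsewhere in this commit.

from typing import Union

from kfp.pipeline_spec import pipeline_spec_pb2


def to_value(value: Union[int, float, str]) -> pipeline_spec_pb2.Value:
  """Maps a Python scalar onto the matching pipeline_spec_pb2.Value field."""
  result = pipeline_spec_pb2.Value()
  if isinstance(value, int):
    result.int_value = value
  elif isinstance(value, float):
    result.double_value = value
  elif isinstance(value, str):
    result.string_value = value
  else:
    raise TypeError('Unsupported parameter type: {}'.format(type(value)))
  return result


# None-valued parameters are dropped, mirroring build_runtime_config_spec.
params = {'input1': 'test', 'input2': None, 'steps': 10}
converted = {k: to_value(v) for k, v in params.items() if v is not None}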
53 changes: 2 additions & 51 deletions sdk/python/kfp/v2/compiler/compiler_utils_test.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for kfp.v2.compiler.compiler_utils."""

import unittest

@@ -22,56 +23,6 @@

class CompilerUtilsTest(unittest.TestCase):

def test_build_runtime_parameter_spec(self):
pipeline_params = [
dsl.PipelineParam(name='input1', param_type='Integer', value=99),
dsl.PipelineParam(name='input2', param_type='String', value='hello'),
dsl.PipelineParam(name='input3', param_type='Float', value=3.1415926),
dsl.PipelineParam(name='input4', param_type=None, value=None),
]
expected_dict = {
'runtimeParameters': {
'input1': {
'type': 'INT',
'defaultValue': {
'intValue': '99'
}
},
'input2': {
'type': 'STRING',
'defaultValue': {
'stringValue': 'hello'
}
},
'input3': {
'type': 'DOUBLE',
'defaultValue': {
'doubleValue': '3.1415926'
}
},
'input4': {
'type': 'STRING'
}
}
}
expected_spec = pipeline_spec_pb2.PipelineSpec()
json_format.ParseDict(expected_dict, expected_spec)

pipeline_spec = pipeline_spec_pb2.PipelineSpec(
runtime_parameters=compiler_utils.build_runtime_parameter_spec(
pipeline_params))
self.maxDiff = None
self.assertEqual(expected_spec, pipeline_spec)

def test_build_runtime_parameter_spec_with_unsupported_type_should_fail(self):
pipeline_params = [
dsl.PipelineParam(name='input1', param_type='Dict'),
]

with self.assertRaisesRegexp(
TypeError, 'Unsupported type "Dict" for argument "input1"'):
compiler_utils.build_runtime_parameter_spec(pipeline_params)

def test_build_runtime_config_spec(self):
expected_dict = {
'gcsOutputDirectory': 'gs://path',
@@ -85,7 +36,7 @@ def test_build_runtime_config_spec(self):
json_format.ParseDict(expected_dict, expected_spec)

runtime_config = compiler_utils.build_runtime_config_spec(
'gs://path', {'input1': 'test'})
'gs://path', {'input1': 'test', 'input2': None})
self.assertEqual(expected_spec, runtime_config)

def test_validate_pipeline_name(self):
4 changes: 0 additions & 4 deletions sdk/python/kfp/v2/compiler_cli_tests/compiler_cli_tests.py
@@ -38,13 +38,9 @@ def _test_compile_py_to_json(self, file_base_name, additional_arguments = []):
golden = json.load(f)
# Correct the sdkVersion
golden['pipelineSpec']['sdkVersion'] = 'kfp-{}'.format(kfp.__version__)
# Need to sort the list items before comparison
golden['pipelineSpec']['tasks'].sort(key=lambda x: x['executorLabel'])

with open(os.path.join(test_data_dir, target_json), 'r') as f:
compiled = json.load(f)
# Need to sort the list items before comparison
compiled['pipelineSpec']['tasks'].sort(key=lambda x: x['executorLabel'])

self.maxDiff = None
self.assertEqual(golden, compiled)
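The sorting step removed here is no longer needed because in the new IR the tasks are keyed by task name under root.dag.tasks rather than stored in a repeated field, so the serialized JSON is an object and the golden-file comparison is already order-independent. A tiny illustration (task names are made up and the fragment is heavily trimmed):

golden = {'root': {'dag': {'tasks': {
    'task-train': {'taskInfo': {'name': 'task-train'}},
    'task-eval': {'taskInfo': {'name': 'task-eval'}},
}}}}
compiled = {'root': {'dag': {'tasks': {
    'task-eval': {'taskInfo': {'name': 'task-eval'}},
    'task-train': {'taskInfo': {'name': 'task-train'}},
}}}}
# Dict equality ignores key order, so no pre-sort is required before comparing.
assert golden == compiled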
(Diffs for the remaining 17 changed files are not shown here.)
