chore(sdk.v2): Migrate to the new IR with subdag support #4981

Merged · 8 commits · Feb 3, 2021
113 changes: 65 additions & 48 deletions sdk/python/kfp/v2/compiler/compiler.py
@@ -23,9 +23,11 @@

import kfp
from kfp.compiler._k8s_helper import sanitize_k8s_name
from kfp.components import _python_op
from kfp.v2 import dsl
from kfp.v2.compiler import compiler_utils
from kfp.components import _python_op
from kfp.v2.dsl import component_spec as dsl_component_spec
from kfp.v2.dsl import dsl_utils
from kfp.v2.dsl import importer_node
from kfp.v2.dsl import type_utils
from kfp.pipeline_spec import pipeline_spec_pb2
@@ -75,68 +77,95 @@ def _create_pipeline_spec(
"""
compiler_utils.validate_pipeline_name(pipeline.name)

pipeline_spec = pipeline_spec_pb2.PipelineSpec(
runtime_parameters=compiler_utils.build_runtime_parameter_spec(args))
pipeline_spec = pipeline_spec_pb2.PipelineSpec()

pipeline_spec.pipeline_info.name = pipeline.name
pipeline_spec.sdk_version = 'kfp-{}'.format(kfp.__version__)
pipeline_spec.schema_version = 'v2alpha1'
# Schema version 2.0.0 is required for kfp-pipeline-spec>0.1.3.1
pipeline_spec.schema_version = '2.0.0'

pipeline_spec.root.CopyFrom(
dsl_component_spec.build_root_spec_from_pipeline_params(args))

deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()
importer_tasks = []

for op in pipeline.ops.values():
component_spec = op._metadata
task = pipeline_spec.tasks.add()
task.CopyFrom(op.task_spec)
deployment_config.executors[task.executor_label].container.CopyFrom(
task_name = op.task_spec.task_info.name
component_name = op.task_spec.component_ref.name
executor_label = op.component_spec.executor_label

pipeline_spec.root.dag.tasks[task_name].CopyFrom(op.task_spec)
pipeline_spec.components[component_name].CopyFrom(op.component_spec)
deployment_config.executors[executor_label].container.CopyFrom(
op.container_spec)

task = pipeline_spec.root.dag.tasks[task_name]
# A task may have an explicit dependency on other tasks even though they
# have no input/output dependency, e.g.: op2.after(op1)
if op.dependent_names:
op.dependent_names = [
dsl_utils.sanitize_task_name(name) for name in op.dependent_names
]
task.dependent_tasks.extend(op.dependent_names)

# Check whether an importer node needs to be inserted.
for input_name in task.inputs.artifacts:
if not task.inputs.artifacts[input_name].producer_task:
if not task.inputs.artifacts[
input_name].task_output_artifact.producer_task:
type_schema = type_utils.get_input_artifact_type_schema(
input_name, component_spec.inputs)

importer_task = importer_node.build_importer_task_spec(
dependent_task=task,
input_name, op._metadata.inputs)

importer_name = importer_node.generate_importer_base_name(
dependent_task_name=task_name, input_name=input_name)
importer_task_spec = importer_node.build_importer_task_spec(
importer_name)
importer_comp_spec = importer_node.build_importer_component_spec(
importer_base_name=importer_name,
input_name=input_name,
input_type_schema=type_schema)
importer_tasks.append(importer_task)
importer_task_name = importer_task_spec.task_info.name
importer_comp_name = importer_task_spec.component_ref.name
importer_exec_label = importer_comp_spec.executor_label
pipeline_spec.root.dag.tasks[importer_task_name].CopyFrom(
importer_task_spec)
pipeline_spec.components[importer_comp_name].CopyFrom(
importer_comp_spec)

task.inputs.artifacts[
input_name].producer_task = importer_task.task_info.name
input_name].task_output_artifact.producer_task = (
importer_task_name)
task.inputs.artifacts[
input_name].output_artifact_key = importer_node.OUTPUT_KEY
input_name].task_output_artifact.output_artifact_key = (
importer_node.OUTPUT_KEY)

# Retrieve the pre-built importer spec
importer_spec = op.importer_spec[input_name]
deployment_config.executors[
importer_task.executor_label].importer.CopyFrom(importer_spec)
importer_spec = op.importer_specs[input_name]
deployment_config.executors[importer_exec_label].importer.CopyFrom(
importer_spec)

pipeline_spec.deployment_config.Pack(deployment_config)
pipeline_spec.tasks.extend(importer_tasks)
pipeline_spec.deployment_spec.update(
json_format.MessageToDict(deployment_config))

return pipeline_spec
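
The structural heart of the migration is visible in this hunk: tasks move out of a repeated tasks field into a name-keyed map under root.dag.tasks, component definitions are deduplicated into a top-level components map, and the deployment config is embedded as a JSON struct instead of being packed into an Any. A minimal sketch of the resulting layout (the task, component, and executor names are illustrative, not taken from this PR):

from google.protobuf import json_format
from kfp.pipeline_spec import pipeline_spec_pb2

pipeline_spec = pipeline_spec_pb2.PipelineSpec()
pipeline_spec.pipeline_info.name = 'my-pipeline'  # hypothetical name
pipeline_spec.schema_version = '2.0.0'

# Tasks live in a map keyed by task name, not in a repeated field.
task = pipeline_spec.root.dag.tasks['task-train']
task.task_info.name = 'task-train'
task.component_ref.name = 'comp-train'

# Each referenced component is defined once in the components map.
pipeline_spec.components['comp-train'].executor_label = 'exec-train'

# The deployment config is stored as a JSON struct, not a packed Any.
deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()
deployment_config.executors['exec-train'].container.image = (
    'gcr.io/project/image')  # hypothetical image
pipeline_spec.deployment_spec.update(
    json_format.MessageToDict(deployment_config))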

def _create_pipeline(
self,
pipeline_func: Callable[..., Any],
output_directory: str,
pipeline_name: Optional[str] = None,
) -> pipeline_spec_pb2.PipelineSpec:
pipeline_parameters_override: Optional[Mapping[str, Any]] = None,
) -> pipeline_spec_pb2.PipelineJob:
"""Creates a pipeline instance and constructs the pipeline spec from it.

Args:
pipeline_func: Pipeline function with @dsl.pipeline decorator.
pipeline_name: The name of the pipeline. Optional.
output_directory: The root of the pipeline outputs.
pipeline_parameters_override: The mapping from parameter names to values.
Optional.

Returns:
The IR representation (pipeline spec) of the pipeline.
A PipelineJob proto representing the compiled pipeline.
"""

# Create the arg list with no default values and call pipeline function.
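
A note on the importer wiring in the hunk above: when a task consumes an artifact input that has no upstream producer, the compiler now synthesizes an importer task plus a matching component and registers them in the same maps as ordinary tasks, instead of collecting them in a flat importer_tasks list. A hedged sketch of the helper calls, pulled out of the loop (the task name, input name, and type schema are illustrative):

from kfp.pipeline_spec import pipeline_spec_pb2
from kfp.v2.dsl import importer_node

pipeline_spec = pipeline_spec_pb2.PipelineSpec()  # as in the sketch above

importer_name = importer_node.generate_importer_base_name(
    dependent_task_name='task-train', input_name='examples')
importer_task_spec = importer_node.build_importer_task_spec(importer_name)
importer_comp_spec = importer_node.build_importer_component_spec(
    importer_base_name=importer_name,
    input_name='examples',
    input_type_schema='title: kfp.Dataset')  # hypothetical schema string

# Registered exactly like an ordinary task/component pair.
pipeline_spec.root.dag.tasks[importer_task_spec.task_info.name].CopyFrom(
    importer_task_spec)
pipeline_spec.components[importer_task_spec.component_ref.name].CopyFrom(
    importer_comp_spec)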
Expand Down Expand Up @@ -174,26 +203,14 @@ def _create_pipeline(
dsl_pipeline,
)

return pipeline_spec

def _create_pipeline_job(
self,
pipeline_spec: pipeline_spec_pb2.PipelineSpec,
pipeline_root: str,
pipeline_parameters: Optional[Mapping[str, Any]] = None,
) -> pipeline_spec_pb2.PipelineJob:
"""Creates the pipeline job spec object.

Args:
pipeline_spec: The pipeline spec object.
pipeline_root: The root of the pipeline outputs.
pipeline_parameters: The mapping from parameter names to values. Optional.

Returns:
A PipelineJob proto representing the compiled pipeline.
"""
pipeline_parameters = {
arg.name: arg.value for arg in args_list_with_defaults
}
# Apply pipeline parameter overrides, if any.
pipeline_parameters.update(pipeline_parameters_override or {})
runtime_config = compiler_utils.build_runtime_config_spec(
pipeline_root=pipeline_root, pipeline_parameters=pipeline_parameters)
output_directory=output_directory,
pipeline_parameters=pipeline_parameters)
pipeline_job = pipeline_spec_pb2.PipelineJob(runtime_config=runtime_config)
pipeline_job.pipeline_spec.update(json_format.MessageToDict(pipeline_spec))

@@ -220,11 +237,11 @@ def compile(self,
type_check_old_value = kfp.TYPE_CHECK
try:
kfp.TYPE_CHECK = type_check
pipeline = self._create_pipeline(pipeline_func, pipeline_name)
pipeline_job = self._create_pipeline_job(
pipeline_spec=pipeline,
pipeline_root=pipeline_root,
pipeline_parameters=pipeline_parameters)
pipeline_job = self._create_pipeline(
pipeline_func=pipeline_func,
output_directory=pipeline_root,
pipeline_name=pipeline_name,
pipeline_parameters_override=pipeline_parameters)
self._write_pipeline(pipeline_job, output_path)
finally:
kfp.TYPE_CHECK = type_check_old_value
54 changes: 8 additions & 46 deletions sdk/python/kfp/v2/compiler/compiler_utils.py
@@ -20,62 +20,22 @@
from kfp.pipeline_spec import pipeline_spec_pb2


def build_runtime_parameter_spec(
pipeline_params: List[dsl.PipelineParam]
) -> Mapping[str, pipeline_spec_pb2.PipelineSpec.RuntimeParameter]:
"""Converts pipeine parameters to runtime parameters mapping.

Args:
pipeline_params: The list of pipeline parameters.

Returns:
A map of pipeline parameter name to its runtime parameter message.
"""

def to_message(param: dsl.PipelineParam):
result = pipeline_spec_pb2.PipelineSpec.RuntimeParameter()
if param.param_type == 'Integer' or (param.param_type is None and
isinstance(param.value, int)):

result.type = pipeline_spec_pb2.PrimitiveType.INT
if param.value is not None:
result.default_value.int_value = int(param.value)
elif param.param_type == 'Float' or (param.param_type is None and
isinstance(param.value, float)):
result.type = pipeline_spec_pb2.PrimitiveType.DOUBLE
if param.value is not None:
result.default_value.double_value = float(param.value)
elif param.param_type == 'String' or param.param_type is None:
result.type = pipeline_spec_pb2.PrimitiveType.STRING
if param.value is not None:
result.default_value.string_value = str(param.value)
else:
raise TypeError('Unsupported type "{}" for argument "{}".'.format(
param.param_type, param.name))
return result

return {param.name: to_message(param) for param in pipeline_params}


def build_runtime_config_spec(
pipeline_root: str,
output_directory: str,
pipeline_parameters: Optional[Mapping[str, Any]] = None,
) -> pipeline_spec_pb2.PipelineJob.RuntimeConfig:
"""Converts pipeine parameters to runtime parameters mapping.

Args:
pipeline_root: The root of pipeline outputs.
output_directory: The root of pipeline outputs.
pipeline_parameters: The mapping from parameter names to values. Optional.

Returns:
A pipeline job RuntimeConfig object.
"""

def _get_value(
value: Optional[Union[int, float,
str]]) -> Optional[pipeline_spec_pb2.Value]:
if value is None:
return None
def _get_value(value: Union[int, float, str]) -> pipeline_spec_pb2.Value:
assert value is not None, 'None values should be filtered out.'

result = pipeline_spec_pb2.Value()
if isinstance(value, int):
@@ -91,8 +51,10 @@ def _get_value(

parameter_values = pipeline_parameters or {}
return pipeline_spec_pb2.PipelineJob.RuntimeConfig(
gcs_output_directory=pipeline_root,
parameters={k: _get_value(v) for k, v in parameter_values.items()})
gcs_output_directory=output_directory,
parameters={
k: _get_value(v) for k, v in parameter_values.items() if v is not None
})


def validate_pipeline_name(name: str) -> None:
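
One behavioral consequence worth calling out: None-valued parameters are now filtered out before _get_value runs, rather than _get_value returning None for them. A quick usage sketch (the expected output shape follows the updated test below):

from google.protobuf import json_format
from kfp.v2.compiler import compiler_utils

runtime_config = compiler_utils.build_runtime_config_spec(
    'gs://path', {'input1': 'test', 'input2': None})
# 'input2' is dropped by the None filter; only 'input1' survives.
print(json_format.MessageToDict(runtime_config))
# -> {'gcsOutputDirectory': 'gs://path',
#     'parameters': {'input1': {'stringValue': 'test'}}}
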
53 changes: 2 additions & 51 deletions sdk/python/kfp/v2/compiler/compiler_utils_test.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for kfp.v2.compiler.compiler_utils."""

import unittest

@@ -22,56 +23,6 @@

class CompilerUtilsTest(unittest.TestCase):

def test_build_runtime_parameter_spec(self):
pipeline_params = [
dsl.PipelineParam(name='input1', param_type='Integer', value=99),
dsl.PipelineParam(name='input2', param_type='String', value='hello'),
dsl.PipelineParam(name='input3', param_type='Float', value=3.1415926),
dsl.PipelineParam(name='input4', param_type=None, value=None),
]
expected_dict = {
'runtimeParameters': {
'input1': {
'type': 'INT',
'defaultValue': {
'intValue': '99'
}
},
'input2': {
'type': 'STRING',
'defaultValue': {
'stringValue': 'hello'
}
},
'input3': {
'type': 'DOUBLE',
'defaultValue': {
'doubleValue': '3.1415926'
}
},
'input4': {
'type': 'STRING'
}
}
}
expected_spec = pipeline_spec_pb2.PipelineSpec()
json_format.ParseDict(expected_dict, expected_spec)

pipeline_spec = pipeline_spec_pb2.PipelineSpec(
runtime_parameters=compiler_utils.build_runtime_parameter_spec(
pipeline_params))
self.maxDiff = None
self.assertEqual(expected_spec, pipeline_spec)

def test_build_runtime_parameter_spec_with_unsupported_type_should_fail(self):
pipeline_params = [
dsl.PipelineParam(name='input1', param_type='Dict'),
]

with self.assertRaisesRegexp(
TypeError, 'Unsupported type "Dict" for argument "input1"'):
compiler_utils.build_runtime_parameter_spec(pipeline_params)

def test_build_runtime_config_spec(self):
expected_dict = {
'gcsOutputDirectory': 'gs://path',
@@ -85,7 +36,7 @@ def test_build_runtime_config_spec(self):
json_format.ParseDict(expected_dict, expected_spec)

runtime_config = compiler_utils.build_runtime_config_spec(
'gs://path', {'input1': 'test'})
'gs://path', {'input1': 'test', 'input2': None})
self.assertEqual(expected_spec, runtime_config)

def test_validate_pipeline_name(self):
4 changes: 0 additions & 4 deletions sdk/python/kfp/v2/compiler_cli_tests/compiler_cli_tests.py
@@ -38,13 +38,9 @@ def _test_compile_py_to_json(self, file_base_name, additional_arguments = []):
golden = json.load(f)
# Correct the sdkVersion
golden['pipelineSpec']['sdkVersion'] = 'kfp-{}'.format(kfp.__version__)
# Need to sort the list items before comparison
golden['pipelineSpec']['tasks'].sort(key=lambda x: x['executorLabel'])

with open(os.path.join(test_data_dir, target_json), 'r') as f:
compiled = json.load(f)
# Need to sort the list items before comparison
compiled['pipelineSpec']['tasks'].sort(key=lambda x: x['executorLabel'])

self.maxDiff = None
self.assertEqual(golden, compiled)
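
The sort-before-compare workaround is obsolete because tasks are no longer a JSON array whose order could vary between compilations; they are now a JSON object keyed by task name, and Python dict equality ignores key order. A tiny illustration:

import json

golden = json.loads('{"root": {"dag": {"tasks": {"b": {}, "a": {}}}}}')
compiled = json.loads('{"root": {"dag": {"tasks": {"a": {}, "b": {}}}}}')
# Dict equality is order-independent, so no pre-sorting is needed.
assert golden == compiled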