Skip to content

Commit

Permalink
chore(sdk): Refactored command-line resolving (#5379)
Browse files Browse the repository at this point in the history
* SDK - Refactored command-line resolving

Moved the execution engine specific code to the component bridge.

* Added placeholder_resolver

This simplifies adding custom placeholder resolving logic.
  • Loading branch information
Ark-kun committed Apr 9, 2021
1 parent 02dbcfa commit ba3a92f
Show file tree
Hide file tree
Showing 5 changed files with 289 additions and 257 deletions.
20 changes: 10 additions & 10 deletions sdk/python/kfp/compiler/_data_passing_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
from typing import Any, Dict, List, Optional, Set, Tuple

from kfp.components import _components
from kfp.dsl import _component_bridge
from kfp import dsl


Expand Down Expand Up @@ -389,15 +389,15 @@ def deconstruct_single_placeholder(s: str) -> List[str]:
def _replace_output_dir_and_run_id(command_line: str,
output_directory: Optional[str] = None) -> str:
"""Replaces the output directory placeholder."""
if _components.OUTPUT_DIR_PLACEHOLDER in command_line:
if _component_bridge.OUTPUT_DIR_PLACEHOLDER in command_line:
if not output_directory:
raise ValueError('output_directory of a pipeline must be specified '
'when URI placeholder is used.')
command_line = command_line.replace(
_components.OUTPUT_DIR_PLACEHOLDER, output_directory)
if _components.RUN_ID_PLACEHOLDER in command_line:
_component_bridge.OUTPUT_DIR_PLACEHOLDER, output_directory)
if _component_bridge.RUN_ID_PLACEHOLDER in command_line:
command_line = command_line.replace(
_components.RUN_ID_PLACEHOLDER, dsl.RUN_ID_PLACEHOLDER)
_component_bridge.RUN_ID_PLACEHOLDER, dsl.RUN_ID_PLACEHOLDER)
return command_line


Expand Down Expand Up @@ -426,11 +426,11 @@ def _refactor_outputs_if_uri_placeholder(
for artifact_output in container_template['outputs']['artifacts']:
# Check if this is an output associated with URI placeholder based
# on its path.
if _components.OUTPUT_DIR_PLACEHOLDER in artifact_output['path']:
if _component_bridge.OUTPUT_DIR_PLACEHOLDER in artifact_output['path']:
# If so, we'll add a parameter output to output the pod name
parameter_outputs.append(
{
'name': _components.PRODUCER_POD_NAME_PARAMETER.format(
'name': _component_bridge.PRODUCER_POD_NAME_PARAMETER.format(
artifact_output['name']),
'value': '{{pod.name}}'
})
Expand Down Expand Up @@ -473,7 +473,7 @@ def _refactor_inputs_if_uri_placeholder(
for artifact_input in container_template['inputs']['artifacts']:
# Check if this is an input artifact associated with URI placeholder,
# according to its path.
if _components.OUTPUT_DIR_PLACEHOLDER in artifact_input['path']:
if _component_bridge.OUTPUT_DIR_PLACEHOLDER in artifact_input['path']:
# If so, we'll add a parameter input to receive the producer's pod
# name.
# The correct input parameter name should be parsed from the
Expand All @@ -491,7 +491,7 @@ def _refactor_inputs_if_uri_placeholder(
artifact_input['name'])] = input_name

# In the container implementation, the pod name is already connected
# to the input parameter per the implementation in _components.
# to the input parameter per the implementation in _component_bridge.
# The only thing yet to be reconciled is the file name.

def reconcile_filename(
Expand Down Expand Up @@ -620,7 +620,7 @@ def _refactor_dag_template_uri_inputs(
'value': '{{{{tasks.{task_name}.outputs.'
'parameters.{output}}}}}'.format(
task_name=task_name,
output=_components.PRODUCER_POD_NAME_PARAMETER.format(
output=_component_bridge.PRODUCER_POD_NAME_PARAMETER.format(
output_name)
)})
else:
Expand Down
183 changes: 13 additions & 170 deletions sdk/python/kfp/components/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import copy
from collections import OrderedDict
import pathlib
from typing import Any, Callable, List, Mapping, NamedTuple, Optional, Sequence, Union
from typing import Any, Callable, List, Mapping, NamedTuple, Sequence, Union
from ._naming import _sanitize_file_name, _sanitize_python_function_name, generate_unique_name_conversion_table
from ._yaml_utils import load_yaml
from .structures import *
Expand Down Expand Up @@ -193,94 +193,6 @@ def _generate_output_file_name(port_name):
_single_io_file_name
))


# Placeholder to represent the output directory hosting all the generated URIs.
# Its actual value will be specified during pipeline compilation.
# The format of OUTPUT_DIR_PLACEHOLDER is serialized dsl.PipelineParam, to
# ensure being extracted as a pipeline parameter during compilation.
# Note that we cannot directly import the dsl module here due to circular
# dependencies.
OUTPUT_DIR_PLACEHOLDER = '{{pipelineparam:op=;name=pipeline-output-directory}}'
# Placeholder to represent the UID of the current pipeline at runtime.
# Will be replaced by engine-specific placeholder during compilation.
RUN_ID_PLACEHOLDER = '{{kfp.run_uid}}'
# Format of the Argo parameter used to pass the producer's Pod ID to
# the consumer.
PRODUCER_POD_NAME_PARAMETER = '{}-producer-pod-id-'
# Format of the input output port name placeholder.
INPUT_OUTPUT_NAME_PATTERN = '{{{{kfp.input-output-name.{}}}}}'
# Fixed name for per-task output metadata json file.
OUTPUT_METADATA_JSON = '/tmp/outputs/executor_output.json'
# Executor input placeholder.
_EXECUTOR_INPUT_PLACEHOLDER = '{{$}}'


def _generate_output_uri(port_name: str) -> str:
"""Generates a unique URI for an output.
Args:
port_name: The name of the output associated with this URI.
Returns:
The URI assigned to this output, which is unique within the pipeline.
"""
return str(pathlib.PurePosixPath(
OUTPUT_DIR_PLACEHOLDER,
RUN_ID_PLACEHOLDER, '{{pod.name}}', port_name))


def _generate_input_uri(port_name: str) -> str:
"""Generates the URI for an input.
Args:
port_name: The name of the input associated with this URI.
Returns:
The URI assigned to this input, will be consistent with the URI where
the actual content is written after compilation.
"""
return str(pathlib.PurePosixPath(
OUTPUT_DIR_PLACEHOLDER,
RUN_ID_PLACEHOLDER,
'{{{{inputs.parameters.{input}}}}}'.format(
input=PRODUCER_POD_NAME_PARAMETER.format(port_name)),
port_name
))


def _generate_output_metadata_path() -> str:
"""Generates the URI to write the output metadata JSON file."""

return OUTPUT_METADATA_JSON


def _generate_input_metadata_path(port_name: str) -> str:
"""Generates the placeholder for input artifact metadata file."""

# Return a placeholder for path to input artifact metadata, which will be
# rewritten during pipeline compilation.
return str(pathlib.PurePosixPath(
OUTPUT_DIR_PLACEHOLDER,
RUN_ID_PLACEHOLDER,
'{{{{inputs.parameters.{input}}}}}'.format(
input=PRODUCER_POD_NAME_PARAMETER.format(port_name)),
OUTPUT_METADATA_JSON
))


def _generate_input_output_name(port_name: str) -> str:
"""Generates the placeholder for input artifact's output name."""

# Return a placeholder for the output port name of the input artifact, which
# will be rewritten during pipeline compilation.
return INPUT_OUTPUT_NAME_PATTERN.format(port_name)


def _generate_executor_input() -> str:
"""Generates the placeholder for serialized executor input."""
return _EXECUTOR_INPUT_PLACEHOLDER


def _react_to_incompatible_reference_type(
input_type,
argument_type,
Expand Down Expand Up @@ -476,8 +388,6 @@ def component_default_to_func_default(component_default: str, is_optional: bool)
('input_paths', Mapping[str, str]),
('output_paths', Mapping[str, str]),
('inputs_consumed_by_value', Mapping[str, str]),
('input_uris', Mapping[str, str]),
('output_uris', Mapping[str, str]),
],
)

Expand All @@ -488,16 +398,7 @@ def _resolve_command_line_and_paths(
input_path_generator: Callable[[str], str] = _generate_input_file_name,
output_path_generator: Callable[[str], str] = _generate_output_file_name,
argument_serializer: Callable[[str], str] = serialize_value,
input_uri_generator: Callable[[str], str] = _generate_input_uri,
output_uri_generator: Callable[[str], str] = _generate_output_uri,
input_value_generator: Optional[Callable[[str], str]] = None,
input_metadata_path_generator: Callable[
[str], str] = _generate_input_metadata_path,
output_metadata_path_generator: Callable[
[], str] = _generate_output_metadata_path,
input_output_name_generator: Callable[
[str], str] = _generate_input_output_name,
executor_input_generator: Callable[[], str] = _generate_executor_input,
placeholder_resolver: Callable[[Any, ComponentSpec, Mapping[str, str]], str] = None,
) -> _ResolvedCommandLineAndPaths:
"""Resolves the command line argument placeholders. Also produces the maps of the generated input/output paths."""
argument_values = arguments
Expand All @@ -516,27 +417,28 @@ def _resolve_command_line_and_paths(

input_paths = OrderedDict()
inputs_consumed_by_value = {}
input_uris = OrderedDict()
input_metadata_paths = OrderedDict()
output_uris = OrderedDict()

def expand_command_part(arg) -> Union[str, List[str], None]:
if arg is None:
return None
if placeholder_resolver:
resolved_arg = placeholder_resolver(
arg=arg,
component_spec=component_spec,
arguments=arguments,
)
if resolved_arg is not None:
return resolved_arg
if isinstance(arg, (str, int, float, bool)):
return str(arg)
if isinstance(arg, ExecutorInputPlaceholder):
return executor_input_generator()
if isinstance(arg, InputValuePlaceholder):
input_name = arg.input_name
input_spec = inputs_dict[input_name]
input_value = argument_values.get(input_name, None)
if input_value is not None:
if input_value_generator is not None:
inputs_consumed_by_value[input_name] = input_value_generator(input_name)
else:
inputs_consumed_by_value[input_name] = argument_serializer(input_value, input_spec.type)
return inputs_consumed_by_value[input_name]
serialized_argument = argument_serializer(input_value, input_spec.type)
inputs_consumed_by_value[input_name] = serialized_argument
return serialized_argument
else:
if input_spec.optional:
return None
Expand Down Expand Up @@ -569,63 +471,6 @@ def expand_command_part(arg) -> Union[str, List[str], None]:
output_paths[output_name] = output_filename

return output_filename

elif isinstance(arg, InputUriPlaceholder):
input_name = arg.input_name
if input_name in argument_values:
input_uri = input_uri_generator(input_name)
input_uris[input_name] = input_uri
return input_uri
else:
input_spec = inputs_dict[input_name]
if input_spec.optional:
return None
else:
raise ValueError('No value provided for input {}'.format(input_name))

elif isinstance(arg, InputMetadataPlaceholder):
input_name = arg.input_name
if input_name in argument_values:
input_metadata_path = input_metadata_path_generator(input_name)
input_metadata_paths[input_name] = input_metadata_path
return input_metadata_path
else:
input_spec = inputs_dict[input_name]
if input_spec.optional:
return None
else:
raise ValueError(
'No value provided for input {}'.format(input_name))

elif isinstance(arg, InputOutputPortNamePlaceholder):
input_name = arg.input_name
if input_name in argument_values:
return input_output_name_generator(input_name)
else:
input_spec = inputs_dict[input_name]
if input_spec.optional:
return None
else:
raise ValueError(
'No value provided for input {}'.format(input_name))

elif isinstance(arg, OutputUriPlaceholder):
output_name = arg.output_name
output_uri = output_uri_generator(output_name)
if arg.output_name in output_uris:
if output_uris[output_name] != output_uri:
raise ValueError(
'Conflicting output URIs specified for port {}: {} and {}'.format(
output_name, output_uris[output_name], output_uri))
else:
output_uris[output_name] = output_uri

return output_uri

elif isinstance(arg, OutputMetadataPlaceholder):
# TODO: Consider making the output metadata per-artifact.
return output_metadata_path_generator()

elif isinstance(arg, ConcatPlaceholder):
expanded_argument_strings = expand_argument_list(arg.items)
return ''.join(expanded_argument_strings)
Expand Down Expand Up @@ -671,8 +516,6 @@ def expand_argument_list(argument_list):
input_paths=input_paths,
output_paths=output_paths,
inputs_consumed_by_value=inputs_consumed_by_value,
input_uris=input_uris,
output_uris=output_uris,
)


Expand Down
71 changes: 0 additions & 71 deletions sdk/python/kfp/components_tests/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,42 +719,6 @@ def test_prevent_passing_unserializable_objects_as_argument(self):
component(input_1="value 1", input_2=task1)
with self.assertRaises(TypeError):
component(input_1="value 1", input_2=open)

def test_input_output_uri_resolving(self):
component_text = textwrap.dedent('''\
inputs:
- {name: In1}
outputs:
- {name: Out1}
implementation:
container:
image: busybox
command:
- program
- --in1-uri
- {inputUri: In1}
- --out1-uri
- {outputUri: Out1}
'''
)
op = comp.load_component_from_text(text=component_text)
task = op(in1='foo')
resolved_cmd = _resolve_command_line_and_paths(
component_spec=task.component_ref.spec,
arguments=task.arguments
)

self.assertEqual(
[
'program',
'--in1-uri',
'{{pipelineparam:op=;name=pipeline-output-directory}}/{{kfp.run_uid}}/{{inputs.parameters.In1-producer-pod-id-}}/In1',
'--out1-uri',
'{{pipelineparam:op=;name=pipeline-output-directory}}/{{kfp.run_uid}}/{{pod.name}}/Out1',
],
resolved_cmd.command
)

def test_check_type_validation_of_task_spec_outputs(self):
producer_component_text = '''\
outputs:
Expand Down Expand Up @@ -1079,41 +1043,6 @@ def test_fail_type_compatibility_check_for_types_with_different_schemas(self):
with self.assertRaises(TypeError):
b_task = task_factory_b(in1=a_task.outputs['out1'])

def test_convert_executor_input_and_output_metadata_placeholder(self):
test_component = textwrap.dedent("""\
inputs:
- {name: in1}
outputs:
- {name: out1}
implementation:
container:
image: busybox
command: [echo, {executorInput}, {outputMetadata}]
""")
task_factory = comp.load_component_from_text(test_component)
task = task_factory(in1='foo')
resolved_cmd = _resolve_command_line_and_paths(
component_spec=task.component_ref.spec,
arguments=task.arguments
)
self.assertListEqual(
['echo', '{{$}}', '/tmp/outputs/executor_output.json'],
resolved_cmd.command)

def test_fail_executor_input_with_key(self):
test_component = textwrap.dedent("""\
inputs:
- {name: in1}
outputs:
- {name: out1}
implementation:
container:
image: busybox
command: [echo, {executorInput: a_bad_key}]
""")
with self.assertRaises(TypeError):
_ = comp.load_component_from_text(test_component)


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit ba3a92f

Please sign in to comment.