Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support replaceable arguments in command as well (besides arguments) in container op. #623

Merged
merged 5 commits into from
Jan 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions samples/basic/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ def __init__(self, msg):
super(PrintOp, self).__init__(
name='Print',
image='alpine:3.6',
command=['echo'],
arguments=[msg]
command=['echo', msg],
)


Expand Down
37 changes: 22 additions & 15 deletions sdk/python/kfp/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,22 @@ def _build_conventional_artifact(self, name):
},
}

def _process_args(self, raw_args, argument_inputs):
  """Serialize a container op's command/argument list and rewrite any
  embedded pipeline-parameter placeholders as Argo input references.

  Args:
    raw_args: The op's raw ``command`` or ``arguments`` list. May be
      None or empty, in which case no processing is needed.
    argument_inputs: PipelineParams whose serialized placeholder text
      (``{{pipelineparam:op=...;name=...;value=...}}``) may appear
      inside the raw args.

  Returns:
    A list of strings in which every placeholder occurrence has been
    replaced by ``{{inputs.parameters.<full_name>}}``; an empty list
    when ``raw_args`` is falsy.
  """
  if not raw_args:
    return []

  # Stringify every entry up front; args may be ints, PipelineParams, etc.
  processed_args = list(map(str, raw_args))
  for i, arg in enumerate(processed_args):
    for param in argument_inputs:
      full_name = self._pipelineparam_full_name(param)
      # Use literal string replacement rather than re.sub: the serialized
      # placeholder (notably its value=... part) can contain regex
      # metacharacters, which re.sub would misinterpret as pattern syntax.
      arg = arg.replace(str(param), '{{inputs.parameters.%s}}' % full_name)
    processed_args[i] = arg

  return processed_args

def _op_to_template(self, op):
"""Generate template given an operator inherited from dsl.ContainerOp."""

processed_args = None
if op.arguments:
processed_args = list(map(str, op.arguments))
for i, _ in enumerate(processed_args):
if op.argument_inputs:
for param in op.argument_inputs:
full_name = self._pipelineparam_full_name(param)
processed_args[i] = re.sub(str(param), '{{inputs.parameters.%s}}' % full_name,
processed_args[i])

input_parameters = []
for param in op.inputs:
one_parameter = {'name': self._pipelineparam_full_name(param)}
Expand All @@ -110,8 +114,12 @@ def _op_to_template(self, op):
'image': op.image,
}
}
if processed_args:
template['container']['args'] = processed_args
processed_arguments = self._process_args(op.arguments, op.argument_inputs)
processed_command = self._process_args(op.command, op.argument_inputs)
if processed_arguments:
template['container']['args'] = processed_arguments
if processed_command:
template['container']['command'] = processed_command
if input_parameters:
template['inputs'] = {'parameters': input_parameters}

Expand All @@ -129,8 +137,7 @@ def _op_to_template(self, op):
output_artifacts.append(self._build_conventional_artifact('mlpipeline-ui-metadata'))
output_artifacts.append(self._build_conventional_artifact('mlpipeline-metrics'))
template['outputs']['artifacts'] = output_artifacts
if op.command:
template['container']['command'] = op.command


# Set resources.
if op.resource_limits or op.resource_requests:
Expand Down Expand Up @@ -536,4 +543,4 @@ def compile(self, pipeline_func, package_path):
with closing(BytesIO(yaml_text.encode())) as yaml_file:
tarinfo = tarfile.TarInfo('pipeline.yaml')
tarinfo.size = len(yaml_file.getvalue())
tar.addfile(tarinfo, fileobj=yaml_file)
tar.addfile(tarinfo, fileobj=yaml_file)
3 changes: 0 additions & 3 deletions sdk/python/kfp/components/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ def _create_task_factory_from_component_spec(component_spec:ComponentSpec, compo
container_spec = component_spec.implementation.container
container_image = container_spec.image

file_inputs={}
file_outputs_from_def = OrderedDict()
if container_spec.file_outputs != None:
for param, path in container_spec.file_outputs.items():
Expand Down Expand Up @@ -267,7 +266,6 @@ def expand_command_part(arg): #input values with original names
input_key = input_name_to_kubernetes[port_name]
input_value = pythonic_input_argument_values[input_name_to_pythonic[port_name]]
if input_value is not None:
file_inputs[input_key] = {'local_path': input_filename, 'data_source': input_value}
return input_filename
else:
input_spec = inputs_dict[port_name]
Expand Down Expand Up @@ -350,7 +348,6 @@ def expand_argument_list(argument_list):
container_image=container_image,
command=expanded_command,
arguments=expanded_args,
file_inputs=file_inputs,
file_outputs=file_outputs_to_pass,
)

Expand Down
3 changes: 1 addition & 2 deletions sdk/python/kfp/components/_dsl_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

_dummy_pipeline=None

def _create_task_object(name:str, container_image:str, command=None, arguments=None, file_inputs=None, file_outputs=None):
def _create_task_object(name:str, container_image:str, command=None, arguments=None, file_outputs=None):
from .. import dsl
global _dummy_pipeline
need_dummy = dsl.Pipeline._default_pipeline is None
Expand All @@ -28,7 +28,6 @@ def _create_task_object(name:str, container_image:str, command=None, arguments=N
image=container_image,
command=command,
arguments=arguments,
file_inputs=file_inputs,
file_outputs=file_outputs,
)

Expand Down
15 changes: 3 additions & 12 deletions sdk/python/kfp/dsl/_container_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class ContainerOp(object):
"""Represents an op implemented by a docker container image."""

def __init__(self, name: str, image: str, command: str=None, arguments: str=None,
file_inputs : Dict[_pipeline_param.PipelineParam, str]=None,
file_outputs : Dict[str, str]=None, is_exit_handler=False):
"""Create a new instance of ContainerOp.

Expand All @@ -35,9 +34,6 @@ def __init__(self, name: str, image: str, command: str=None, arguments: str=None
arguments: the arguments of the command. The command can include "%s" and supply
a PipelineParam as the string replacement. For example, ('echo %s' % input_param).
At container run time the argument will be 'echo param_value'.
file_inputs: Maps PipelineParams to local file paths. At pipeline run time,
the value of a PipelineParam is saved to its corresponding local file. It is
not implemented yet.
file_outputs: Maps output labels to local file paths. At pipeline run time,
the value of a PipelineParam is saved to its corresponding local file. It's
one way for outside world to receive outputs of the container.
Expand All @@ -63,24 +59,19 @@ def __init__(self, name: str, image: str, command: str=None, arguments: str=None
self.pod_labels = {}

matches = []
if arguments:
for arg in arguments:
match = re.findall(r'{{pipelineparam:op=([\w-]*);name=([\w-]+);value=(.*?)}}', str(arg))
matches += match
for arg in (command or []) + (arguments or []):
match = re.findall(r'{{pipelineparam:op=([\w-]*);name=([\w-]+);value=(.*?)}}', str(arg))
matches += match

self.argument_inputs = [_pipeline_param.PipelineParam(x[1], x[0], x[2])
for x in list(set(matches))]
self.file_inputs = file_inputs
self.file_outputs = file_outputs
self.dependent_op_names = []

self.inputs = []
if self.argument_inputs:
self.inputs += self.argument_inputs

if file_inputs:
self.inputs += list(file_inputs.keys())

self.outputs = {}
if file_outputs:
self.outputs = {name: _pipeline_param.PipelineParam(name, op_name=self.name)
Expand Down
8 changes: 4 additions & 4 deletions sdk/python/tests/compiler/testdata/coin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ def __init__(self, name):

class PrintOp(dsl.ContainerOp):

def __init__(self, name):
def __init__(self, name, msg):
super(PrintOp, self).__init__(
name=name,
image='alpine:3.6',
command=['echo', '"it was tail"'])
command=['echo', msg])


@dsl.pipeline(
Expand All @@ -48,7 +48,7 @@ def flipcoin():
flip2 = FlipCoinOp('flip-again')

with dsl.Condition(flip2.output=='tails'):
PrintOp('print1')
PrintOp('print1', flip2.output)

with dsl.Condition(flip.output=='tails'):
PrintOp('print2')
PrintOp('print2', flip2.output)
21 changes: 19 additions & 2 deletions sdk/python/tests/compiler/testdata/coin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ spec:
parameters:
- name: flip-output
name: condition-1
outputs:
parameters:
- name: flip-again-output
valueFrom:
parameter: '{{tasks.flip-again.outputs.parameters.flip-again-output}}'
- dag:
tasks:
- arguments:
Expand All @@ -63,12 +68,15 @@ spec:
tasks:
- arguments:
parameters:
- name: flip-again-output
value: '{{inputs.parameters.flip-again-output}}'
- name: flip-output
value: '{{inputs.parameters.flip-output}}'
name: print2
template: print2
inputs:
parameters:
- name: flip-again-output
- name: flip-output
name: condition-3
- container:
Expand Down Expand Up @@ -166,9 +174,12 @@ spec:
when: '{{tasks.flip.outputs.parameters.flip-output}} == heads'
- arguments:
parameters:
- name: flip-again-output
value: '{{tasks.condition-1.outputs.parameters.flip-again-output}}'
- name: flip-output
value: '{{tasks.flip.outputs.parameters.flip-output}}'
dependencies:
- condition-1
- flip
name: condition-3
template: condition-3
Expand All @@ -179,8 +190,11 @@ spec:
- container:
command:
- echo
- '"it was tail"'
- '{{inputs.parameters.flip-again-output}}'
image: alpine:3.6
inputs:
parameters:
- name: flip-again-output
name: print1
outputs:
artifacts:
Expand Down Expand Up @@ -213,8 +227,11 @@ spec:
- container:
command:
- echo
- '"it was tail"'
- '{{inputs.parameters.flip-again-output}}'
image: alpine:3.6
inputs:
parameters:
- name: flip-again-output
name: print2
outputs:
artifacts:
Expand Down