Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SDK/DSL/Compiler - Reverted fix of dsl.Condition until the UI is ready. #94

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 80 additions & 50 deletions sdk/python/kfp/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,17 +289,14 @@ def _get_dependencies(self, pipeline, root_group, op_groups):
dependencies[downstream_groups[0]].add(upstream_groups[0])
return dependencies

def _resolve_value_or_reference(self, value_or_reference, inputs):
if isinstance(value_or_reference, dsl.PipelineParam):
parameter_name = self._param_full_name(value_or_reference)
task_names = [task_name for param_name, task_name in inputs if param_name == parameter_name]
if task_names:
task_name = task_names[0]
return '{{tasks.%s.outputs.parameters.%s}}' % (task_name, parameter_name)
else:
return '{{inputs.parameters.%s}}' % parameter_name
else:
return str(value_or_reference)
def _create_condition(self, condition):
left = ('{{inputs.parameters.%s}}' % self._param_full_name(condition.operand1)
if isinstance(condition.operand1, dsl.PipelineParam)
else str(condition.operand1))
right = ('{{inputs.parameters.%s}}' % self._param_full_name(condition.operand2)
if isinstance(condition.operand2, dsl.PipelineParam)
else str(condition.operand2))
return ('%s == %s' % (left, right))

def _group_to_template(self, group, inputs, outputs, dependencies):
"""Generate template given an OpsGroup.
Expand Down Expand Up @@ -329,56 +326,89 @@ def _group_to_template(self, group, inputs, outputs, dependencies):
template_outputs.sort(key=lambda x: x['name'])
template['outputs'] = {'parameters': template_outputs}

# Generate tasks section.
tasks = []
for sub_group in group.groups + group.ops:
task = {
'name': sub_group.name,
'template': sub_group.name,
if group.type == 'condition':
# This is a workaround for the fact that argo does not support conditions in DAG mode.
# Basically, we insert an extra group that contains only the original group. The extra group
# operates in "step" mode where condition is supported.
only_child = group.groups[0]
step = {
'name': only_child.name,
'template': only_child.name,
}

if isinstance(sub_group, dsl.OpsGroup) and sub_group.type == 'condition':
subgroup_inputs = inputs.get(sub_group.name, [])
condition = sub_group.condition
condition_operation = '=='
operand1_value = self._resolve_value_or_reference(condition.operand1, subgroup_inputs)
operand2_value = self._resolve_value_or_reference(condition.operand2, subgroup_inputs)
task['when'] = '{} {} {}'.format(operand1_value, condition_operation, operand2_value)

# Generate dependencies section for this task.
if dependencies.get(sub_group.name, None):
group_dependencies = list(dependencies[sub_group.name])
group_dependencies.sort()
task['dependencies'] = group_dependencies

# Generate arguments section for this task.
if inputs.get(sub_group.name, None):
if inputs.get(only_child.name, None):
arguments = []
for param_name, dependent_name in inputs[sub_group.name]:
if dependent_name:
# The value comes from an upstream sibling.
arguments.append({
'name': param_name,
'value': '{{tasks.%s.outputs.parameters.%s}}' % (dependent_name, param_name)
})
else:
# The value comes from its parent.
arguments.append({
for param_name, dependent_name in inputs[only_child.name]:
arguments.append({
'name': param_name,
'value': '{{inputs.parameters.%s}}' % param_name
})
})
arguments.sort(key=lambda x: x['name'])
task['arguments'] = {'parameters': arguments}
tasks.append(task)
tasks.sort(key=lambda x: x['name'])
template['dag'] = {'tasks': tasks}
step['arguments'] = {'parameters': arguments}
step['when'] = self._create_condition(group.condition)
template['steps'] = [[step]]
else:
# Generate tasks section.
tasks = []
for sub_group in group.groups + group.ops:
task = {
'name': sub_group.name,
'template': sub_group.name,
}
# Generate dependencies section for this task.
if dependencies.get(sub_group.name, None):
group_dependencies = list(dependencies[sub_group.name])
group_dependencies.sort()
task['dependencies'] = group_dependencies

# Generate arguments section for this task.
if inputs.get(sub_group.name, None):
arguments = []
for param_name, dependent_name in inputs[sub_group.name]:
if dependent_name:
# The value comes from an upstream sibling.
arguments.append({
'name': param_name,
'value': '{{tasks.%s.outputs.parameters.%s}}' % (dependent_name, param_name)
})
else:
# The value comes from its parent.
arguments.append({
'name': param_name,
'value': '{{inputs.parameters.%s}}' % param_name
})
arguments.sort(key=lambda x: x['name'])
task['arguments'] = {'parameters': arguments}
tasks.append(task)
tasks.sort(key=lambda x: x['name'])
template['dag'] = {'tasks': tasks}
return template

def _create_new_groups(self, root_group):
"""Create a copy of the input group, and insert extra groups for conditions."""

new_group = copy.deepcopy(root_group)

def _insert_group_for_condition_helper(group):
for i, g in enumerate(group.groups):
if g.type == 'condition':
child_condition_group = dsl.OpsGroup('condition-child', g.name + '-child')
child_condition_group.ops = g.ops
child_condition_group.groups = g.groups
g.groups = [child_condition_group]
g.ops = list()
_insert_group_for_condition_helper(child_condition_group)
else:
_insert_group_for_condition_helper(g)

_insert_group_for_condition_helper(new_group)
return new_group

def _create_templates(self, pipeline):
"""Create all groups and ops templates in the pipeline."""

new_root_group = pipeline.groups[0]
# This is needed only because Argo does not support condition in DAG mode.
# Revisit when https://github.com/argoproj/argo/issues/921 is fixed.
new_root_group = self._create_new_groups(pipeline.groups[0])

op_groups = self._get_groups_for_ops(new_root_group)
inputs, outputs = self._get_inputs_outputs(pipeline, new_root_group, op_groups)
Expand Down
48 changes: 42 additions & 6 deletions sdk/python/tests/compiler/testdata/coin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ spec:
entrypoint: pipeline-flip-coin
serviceAccountName: pipeline-runner
templates:
- inputs:
parameters:
- name: flip-output
name: condition-1
steps:
- - arguments:
parameters:
- name: flip-output
value: '{{inputs.parameters.flip-output}}'
name: condition-1-child
template: condition-1-child
when: '{{inputs.parameters.flip-output}} == heads'
- dag:
tasks:
- arguments:
Expand All @@ -33,7 +45,6 @@ spec:
- flip-again
name: condition-2
template: condition-2
when: '{{tasks.flip-again.outputs.parameters.flip-again-output}} == tails'
- arguments:
parameters:
- name: flip-output
Expand All @@ -43,7 +54,22 @@ spec:
inputs:
parameters:
- name: flip-output
name: condition-1
name: condition-1-child
- inputs:
parameters:
- name: flip-again-output
- name: flip-output
name: condition-2
steps:
- - arguments:
parameters:
- name: flip-again-output
value: '{{inputs.parameters.flip-again-output}}'
- name: flip-output
value: '{{inputs.parameters.flip-output}}'
name: condition-2-child
template: condition-2-child
when: '{{inputs.parameters.flip-again-output}} == tails'
- dag:
tasks:
- arguments:
Expand All @@ -58,7 +84,19 @@ spec:
parameters:
- name: flip-again-output
- name: flip-output
name: condition-2
name: condition-2-child
- inputs:
parameters:
- name: flip-output
name: condition-3
steps:
- - arguments:
parameters:
- name: flip-output
value: '{{inputs.parameters.flip-output}}'
name: condition-3-child
template: condition-3-child
when: '{{inputs.parameters.flip-output}} == tails'
- dag:
tasks:
- arguments:
Expand All @@ -70,7 +108,7 @@ spec:
inputs:
parameters:
- name: flip-output
name: condition-3
name: condition-3-child
- container:
args:
- python -c "import random; result = 'heads' if random.randint(0,1) == 0 else
Expand Down Expand Up @@ -163,7 +201,6 @@ spec:
- flip
name: condition-1
template: condition-1
when: '{{tasks.flip.outputs.parameters.flip-output}} == heads'
- arguments:
parameters:
- name: flip-output
Expand All @@ -172,7 +209,6 @@ spec:
- flip
name: condition-3
template: condition-3
when: '{{tasks.flip.outputs.parameters.flip-output}} == tails'
- name: flip
template: flip
name: pipeline-flip-coin
Expand Down