Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SDK - Components - Added type to graph input references #2451

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/python/kfp/components/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def resolve_argument(argument):
if isinstance(argument, (str, int, float, bool)):
return argument
elif isinstance(argument, GraphInputArgument):
return graph_input_arguments[argument.input_name]
return graph_input_arguments[argument.graph_input.input_name]
elif isinstance(argument, TaskOutputArgument):
upstream_task_output_ref = argument.task_output
upstream_task_outputs = outputs_of_tasks[upstream_task_output_ref.task_id]
Expand Down
4 changes: 2 additions & 2 deletions sdk/python/kfp/components/_python_to_graph_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import Callable

from . import _components
from ._structures import TaskSpec, ComponentSpec, OutputSpec, GraphInputArgument, TaskOutputArgument, GraphImplementation, GraphSpec
from ._structures import TaskSpec, ComponentSpec, OutputSpec, GraphInputReference, TaskOutputArgument, GraphImplementation, GraphSpec
from ._naming import _make_name_unique_by_adding_index
from ._python_op import _extract_component_interface

Expand Down Expand Up @@ -90,7 +90,7 @@ def task_construction_handler(task: TaskSpec):

# Preparing the pipeline_func arguments
# TODO: The key should be original parameter name if different
pipeline_func_args = {input.name: GraphInputArgument(input_name=input.name) for input in input_specs}
pipeline_func_args = {input.name: GraphInputReference(input_name=input.name).as_argument() for input in input_specs}

try:
#Setting the handler to fix and catch the tasks.
Expand Down
31 changes: 28 additions & 3 deletions sdk/python/kfp/components/_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

'ComponentReference',

'GraphInputReference',
'GraphInputArgument',
'TaskOutputReference',
'TaskOutputArgument',
Expand Down Expand Up @@ -306,7 +307,7 @@ def verify_arg(arg):
for task in graph.tasks.values():
if task.arguments is not None:
for argument in task.arguments.values():
if isinstance(argument, GraphInputArgument) and argument.input_name not in self._inputs_dict:
if isinstance(argument, GraphInputArgument) and argument.graph_input.input_name not in self._inputs_dict:
raise TypeError('Argument "{}" references non-existing input.'.format(argument))

def save(self, file_path: str):
Expand Down Expand Up @@ -334,14 +335,38 @@ def _post_init(self) -> None:
raise TypeError('Need at least one argument.')


class GraphInputReference(ModelBase):
    '''References the input of the graph (the scope is a single graph).'''
    # Maps the Python attribute name to the key used in the serialized form
    # (e.g. `graphInput: {inputName: ...}` in component YAML).
    _serialized_names = {
        'input_name': 'inputName',
    }

    def __init__(self,
        input_name: str,
        type: Optional[TypeSpecType] = None, # Can be used to override the reference data type
    ):
        '''Creates a reference to the graph input called `input_name`.

        Args:
            input_name: Name of the graph component input being referenced.
            type: Optional type spec that overrides the data type of the reference.
        '''
        # NOTE(review): the constructor arguments are handed to the base class
        # via locals(); presumably ModelBase stores them as attributes - confirm
        # against ModelBase. Do not introduce extra local variables before this call.
        super().__init__(locals())

    def as_argument(self) -> 'GraphInputArgument':
        '''Wraps this reference in a GraphInputArgument so it can be used as a task argument.'''
        return GraphInputArgument(graph_input=self)

    def with_type(self, type_spec: Optional[TypeSpecType]) -> 'GraphInputReference':
        '''Returns a new reference to the same input carrying the given type.

        A new object is constructed; this reference is not modified.
        Passing None produces an untyped reference (see `without_type`).
        '''
        return GraphInputReference(
            input_name=self.input_name,
            type=type_spec,
        )

    def without_type(self) -> 'GraphInputReference':
        '''Returns a new reference to the same input with the type removed.'''
        return self.with_type(None)

class GraphInputArgument(ModelBase):
'''Represents the component argument value that comes from the graph component input.'''
_serialized_names = {
'input_name': 'graphInput',
'graph_input': 'graphInput',
}

def __init__(self,
input_name: str,
graph_input: GraphInputReference,
):
super().__init__(locals())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ implementation:
url: https://raw.githubusercontent.com/kubeflow/pipelines/b3179d86b239a08bf4884b50dbf3a9151da96d66/components/gcp/automl/create_dataset_for_tables/component.yaml
arguments:
gcp_project_id:
graphInput: gcp_project_id
graphInput:
inputName: gcp_project_id
gcp_region:
graphInput: gcp_region
graphInput:
inputName: gcp_region
display_name:
graphInput: dataset_display_name
graphInput:
inputName: dataset_display_name
Automl import data from bigquery:
componentRef:
url: https://raw.githubusercontent.com/kubeflow/pipelines/b3179d86b239a08bf4884b50dbf3a9151da96d66/components/gcp/automl/import_data_from_bigquery/component.yaml
Expand All @@ -56,7 +59,8 @@ implementation:
taskId: Automl create dataset for tables
type: String
input_uri:
graphInput: dataset_bq_input_uri
graphInput:
inputName: dataset_bq_input_uri
Automl split dataset table column names:
componentRef:
url: https://raw.githubusercontent.com/kubeflow/pipelines/b3179d86b239a08bf4884b50dbf3a9151da96d66/components/gcp/automl/split_dataset_table_column_names/component.yaml
Expand All @@ -67,18 +71,22 @@ implementation:
taskId: Automl import data from bigquery
type: String
target_column_name:
graphInput: target_column_name
graphInput:
inputName: target_column_name
table_index: '0'
Automl create model for tables:
componentRef:
url: https://raw.githubusercontent.com/kubeflow/pipelines/b3179d86b239a08bf4884b50dbf3a9151da96d66/components/gcp/automl/create_model_for_tables/component.yaml
arguments:
gcp_project_id:
graphInput: gcp_project_id
graphInput:
inputName: gcp_project_id
gcp_region:
graphInput: gcp_region
graphInput:
inputName: gcp_region
display_name:
graphInput: model_display_name
graphInput:
inputName: model_display_name
dataset_id:
taskOutput:
outputName: dataset_path
Expand All @@ -96,7 +104,8 @@ implementation:
type: JsonArray
optimization_objective: MAXIMIZE_AU_PRC
train_budget_milli_node_hours:
graphInput: train_budget_milli_node_hours
graphInput:
inputName: train_budget_milli_node_hours
Automl prediction service batch predict:
componentRef:
url: https://raw.githubusercontent.com/kubeflow/pipelines/b3179d86b239a08bf4884b50dbf3a9151da96d66/components/gcp/automl/prediction_service_batch_predict/component.yaml
Expand All @@ -107,9 +116,11 @@ implementation:
taskId: Automl create model for tables
type: String
gcs_output_uri_prefix:
graphInput: batch_predict_gcs_output_uri_prefix
graphInput:
inputName: batch_predict_gcs_output_uri_prefix
bq_input_uri:
graphInput: batch_predict_bq_input_uri
graphInput:
inputName: batch_predict_bq_input_uri
outputValues:
model_path:
taskOutput:
Expand Down
18 changes: 9 additions & 9 deletions sdk/python/tests/components/test_graph_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@


import kfp.components as comp
from kfp.components._structures import ComponentReference, ComponentSpec, ContainerSpec, GraphInputArgument, GraphSpec, InputSpec, InputValuePlaceholder, GraphImplementation, OutputPathPlaceholder, OutputSpec, TaskOutputArgument, TaskSpec
from kfp.components._structures import ComponentReference, ComponentSpec, ContainerSpec, GraphInputReference, GraphSpec, InputSpec, InputValuePlaceholder, GraphImplementation, OutputPathPlaceholder, OutputSpec, TaskOutputArgument, TaskSpec

from kfp.components._yaml_utils import load_yaml

class GraphComponentTestCase(unittest.TestCase):
def test_handle_constructing_graph_component(self):
task1 = TaskSpec(component_ref=ComponentReference(name='comp 1'), arguments={'in1 1': 11})
task2 = TaskSpec(component_ref=ComponentReference(name='comp 2'), arguments={'in2 1': 21, 'in2 2': TaskOutputArgument.construct(task_id='task 1', output_name='out1 1')})
task3 = TaskSpec(component_ref=ComponentReference(name='comp 3'), arguments={'in3 1': TaskOutputArgument.construct(task_id='task 2', output_name='out2 1'), 'in3 2': GraphInputArgument(input_name='graph in 1')})
task3 = TaskSpec(component_ref=ComponentReference(name='comp 3'), arguments={'in3 1': TaskOutputArgument.construct(task_id='task 2', output_name='out2 1'), 'in3 2': GraphInputReference(input_name='graph in 1').as_argument()})

graph_component1 = ComponentSpec(
inputs=[
Expand Down Expand Up @@ -75,7 +75,7 @@ def test_handle_parsing_graph_component(self):
componentRef: {name: Comp 3}
arguments:
in3 1: {taskOutput: {taskId: task 2, outputName: out2 1}}
in3 2: {graphInput: graph in 1}
in3 2: {graphInput: {inputName: graph in 1}}
outputValues:
graph out 1: {taskOutput: {taskId: task 3, outputName: out3 1}}
graph out 2: {taskOutput: {taskId: task 1, outputName: out1 2}}
Expand Down Expand Up @@ -231,11 +231,11 @@ def test_load_graph_component(self):
command: [sh, -c, 'cat "$0" "$1" > $2', {inputValue: in3_1}, {inputValue: in3_2}, {outputPath: out3_1}]
arguments:
in3_1: {taskOutput: {taskId: task 2, outputName: out2_1}}
in3_2: {graphInput: graph in 1}
in3_2: {graphInput: {inputName: graph in 1}}
outputValues:
graph out 1: {taskOutput: {taskId: task 3, outputName: out3_1}}
graph out 2: {taskOutput: {taskId: task 1, outputName: out1_2}}
graph out 3: {graphInput: graph in 2}
graph out 3: {graphInput: {inputName: graph in 2}}
graph out 4: '42'
'''
op = comp.load_component_from_text(component_text)
Expand Down Expand Up @@ -311,17 +311,17 @@ def test_load_nested_graph_components(self):
image: busybox
command: [sh, -c, 'cat "$0" "$1" > $2', {inputValue: in3_1}, {inputValue: in3_2}, {outputPath: out3_1}]
arguments:
in3_1: {graphInput: in3_1}
in3_2: {graphInput: in3_1}
in3_1: {graphInput: {inputName: in3_1}}
in3_2: {graphInput: {inputName: in3_1}}
outputValues:
out3_1: {taskOutput: {taskId: graph subtask, outputName: out3_1}}
arguments:
in3_1: {taskOutput: {taskId: task 2, outputName: out2_1}}
in3_2: {graphInput: graph in 1}
in3_2: {graphInput: {inputName: graph in 1}}
outputValues:
graph out 1: {taskOutput: {taskId: task 3, outputName: out3_1}}
graph out 2: {taskOutput: {taskId: task 1, outputName: out1_2}}
graph out 3: {graphInput: graph in 2}
graph out 3: {graphInput: {inputName: graph in 2}}
graph out 4: '42'
'''
op = comp.load_component_from_text(component_text)
Expand Down