Skip to content

Commit

Permalink
SDK - Lightweight - Fixed custom types in multi-output case (kubeflow…
Browse files Browse the repository at this point in the history
…#1875)

The type was mistakenly serialized as `_ForwardRef('CustomType')`.
The input parameter types and single-output types were not affected.
  • Loading branch information
Ark-kun authored Aug 21, 2019
1 parent 2e7f2d4 commit 203307d
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 2 deletions.
6 changes: 4 additions & 2 deletions sdk/python/kfp/components/_python_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import inspect
from pathlib import Path
import typing
from typing import TypeVar, Generic, List

T = TypeVar('T')
Expand Down Expand Up @@ -146,8 +147,9 @@ def annotation_to_type_struct(annotation):
return None
if isinstance(annotation, type):
return str(annotation.__name__)
else:
return str(annotation)
if hasattr(annotation, '__forward_arg__'): # Handling typing.ForwardRef('Type_name') (the name was _ForwardRef in python 3.5-3.6)
return str(annotation.__forward_arg__) # It can only be string
return str(annotation)

for parameter in parameters:
type_struct = annotation_to_type_struct(parameter.annotation)
Expand Down
75 changes: 75 additions & 0 deletions sdk/python/tests/components/test_python_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,81 @@ def add_multiply_two_numbers(a: float, b: float) -> NamedTuple('DummyName', [('s

self.helper_test_2_in_2_out_component_using_local_call(func, op, output_names=['sum', 'product'])

def test_extract_component_interface(self):
from typing import NamedTuple
def my_func(
required_param,
int_param: int = 42,
float_param : float = 3.14,
str_param : str = 'string',
bool_param : bool = True,
none_param = None,
custom_type_param: 'Custom type' = None,
) -> NamedTuple('DummyName', [
#('required_param',), # All typing.NamedTuple fields must have types
('int_param', int),
('float_param', float),
('str_param', str),
('bool_param', bool),
#('custom_type_param', 'Custom type'), #SyntaxError: Forward reference must be an expression -- got 'Custom type'
('custom_type_param', 'CustomType'),
]
):
'''Function docstring'''
pass

component_spec = comp._python_op._extract_component_interface(my_func)

from kfp.components._structures import InputSpec, OutputSpec
self.assertEqual(
component_spec.inputs,
[
InputSpec(name='required_param'),
InputSpec(name='int_param', type='int', default='42', optional=True),
InputSpec(name='float_param', type='float', default='3.14', optional=True),
InputSpec(name='str_param', type='str', default='string', optional=True),
InputSpec(name='bool_param', type='bool', default='True', optional=True),
InputSpec(name='none_param', optional=True), # No default='None'
InputSpec(name='custom_type_param', type='Custom type', optional=True),
]
)
self.assertEqual(
component_spec.outputs,
[
OutputSpec(name='int_param', type='int'),
OutputSpec(name='float_param', type='float'),
OutputSpec(name='str_param', type='str'),
OutputSpec(name='bool_param', type='bool'),
#OutputSpec(name='custom_type_param', type='Custom type', default='None'),
OutputSpec(name='custom_type_param', type='CustomType'),
]
)

self.maxDiff = None
self.assertDictEqual(
component_spec.to_dict(),
{
'name': 'My func',
'description': 'Function docstring\n',
'inputs': [
{'name': 'required_param'},
{'name': 'int_param', 'type': 'int', 'default': '42', 'optional': True},
{'name': 'float_param', 'type': 'float', 'default': '3.14', 'optional': True},
{'name': 'str_param', 'type': 'str', 'default': 'string', 'optional': True},
{'name': 'bool_param', 'type': 'bool', 'default': 'True', 'optional': True},
{'name': 'none_param', 'optional': True}, # No default='None'
{'name': 'custom_type_param', 'type': 'Custom type', 'optional': True},
],
'outputs': [
{'name': 'int_param', 'type': 'int'},
{'name': 'float_param', 'type': 'float'},
{'name': 'str_param', 'type': 'str'},
{'name': 'bool_param', 'type': 'bool'},
{'name': 'custom_type_param', 'type': 'CustomType'},
]
}
)

@unittest.skip #TODO: #Simplified multi-output syntax is not implemented yet
def test_func_to_container_op_multiple_named_typed_outputs_using_list_syntax(self):
def add_multiply_two_numbers(a: float, b: float) -> [('sum', float), ('product', float)]:
Expand Down

0 comments on commit 203307d

Please sign in to comment.