Skip to content

Commit

Permalink
fix(sdk.v2): fix a small bug in io_types.is_artifact_annotation() (#5699)
Browse files Browse the repository at this point in the history

* fix is_artifact_annotation

* move tests
  • Loading branch information
chensun committed May 19, 2021
1 parent 9632509 commit b7084f2
Show file tree
Hide file tree
Showing 18 changed files with 55 additions and 253 deletions.
3 changes: 3 additions & 0 deletions sdk/python/kfp/dsl/io_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,9 @@ def is_artifact_annotation(typ) -> bool:
if not hasattr(typ, '__args__') or len(typ.__args__) != 2:
return False

if typ.__args__[1] not in [InputAnnotation, OutputAnnotation]:
return False

return True

def is_input_artifact(typ) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import unittest
import json
import os
from typing import List, Optional, Union

from kfp.dsl import io_types
from kfp.dsl.io_types import Input, InputAnnotation, Output, Model, OutputAnnotation
Expand All @@ -31,34 +32,40 @@ def test_complex_metrics(self):
metrics.log_confusion_matrix_row('dog', [2, 6, 0])
metrics.log_confusion_matrix_cell('cat', 'dog', 3)
metrics.log_confusion_matrix_cell('horses', 'horses', 3)
metrics.metadata["test"] = 1.0
with open(os.path.join(os.path.dirname(__file__),
'test_data', 'expected_io_types_classification_metrics.json')) as json_file:
metrics.metadata['test'] = 1.0
with open(
os.path.join(
os.path.dirname(__file__), 'test_data',
'expected_io_types_classification_metrics.json')) as json_file:
expected_json = json.load(json_file)
self.assertEqual(expected_json, metrics.metadata)

def test_complex_metrics_bulk_loading(self):
metrics = io_types.ClassificationMetrics()
metrics.log_roc_curve(fpr=[85.1, 85.1, 85.1],
tpr=[52.6, 52.6, 52.6],
threshold=[53.6, 53.6, 53.6])
metrics.log_roc_curve(
fpr=[85.1, 85.1, 85.1],
tpr=[52.6, 52.6, 52.6],
threshold=[53.6, 53.6, 53.6])
metrics.log_confusion_matrix(['dog', 'cat', 'horses'],
[[2, 6, 0], [3, 5, 6], [5, 7, 8]])
with open(os.path.join(os.path.dirname(__file__),
'test_data', 'expected_io_types_bulk_load_classification_metrics.json')) as json_file:
with open(
os.path.join(
os.path.dirname(__file__), 'test_data',
'expected_io_types_bulk_load_classification_metrics.json')
) as json_file:
expected_json = json.load(json_file)
self.assertEqual(expected_json, metrics.metadata)


class InputOutputArtifacts(unittest.TestCase):

def test_is_artifact_annotation(self):
    """is_artifact_annotation() accepts Input[...]/Output[...] wrappers and rejects plain types."""
    # Parameterized Input/Output annotations are recognized as artifact annotations.
    for annotation in (Input[Model], Output[Model], Output['MyArtifact']):
        self.assertTrue(io_types.is_artifact_annotation(annotation))

    # Bare artifact classes, builtins, strings, and other typing generics are not.
    for non_annotation in (Model, int, 'Dataset', List[str], Optional[str]):
        self.assertFalse(io_types.is_artifact_annotation(non_annotation))

def test_is_input_artifact(self):
self.assertTrue(io_types.is_input_artifact(Input[Model]))
Expand All @@ -83,18 +90,18 @@ def test_get_io_artifact_class(self):
self.assertEqual(io_types.get_io_artifact_class(str), None)

def test_get_io_artifact_annotation(self):
self.assertEqual(io_types.get_io_artifact_annotation(Output[Model]),
OutputAnnotation)
self.assertEqual(io_types.get_io_artifact_annotation(Input[Model]),
InputAnnotation)
self.assertEqual(io_types.get_io_artifact_annotation(Input),
InputAnnotation)
self.assertEqual(io_types.get_io_artifact_annotation(Output),
OutputAnnotation)
self.assertEqual(
io_types.get_io_artifact_annotation(Output[Model]), OutputAnnotation)
self.assertEqual(
io_types.get_io_artifact_annotation(Input[Model]), InputAnnotation)
self.assertEqual(
io_types.get_io_artifact_annotation(Input), InputAnnotation)
self.assertEqual(
io_types.get_io_artifact_annotation(Output), OutputAnnotation)

self.assertEqual(io_types.get_io_artifact_annotation(Model), None)
self.assertEqual(io_types.get_io_artifact_annotation(str), None)


if __name__ == '__main__':
unittest.main()
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,30 @@

from google.protobuf import json_format


class MetricsUtilsTest(unittest.TestCase):

def test_confusion_matrix(self):
conf_matrix = metrics_utils.ConfusionMatrix()
conf_matrix.set_categories(['dog', 'cat', 'horses'])
conf_matrix.log_row('dog', [2, 6, 0])
conf_matrix.log_cell('cat', 'dog', 3)
with open(os.path.join(os.path.dirname(__file__),
'test_data', 'expected_confusion_matrix.json')) as json_file:
with open(
os.path.join(
os.path.dirname(__file__), 'test_data',
'expected_confusion_matrix.json')) as json_file:
expected_json = json.load(json_file)
self.assertEqual(expected_json, conf_matrix.get_metrics())

def test_bulkload_confusion_matrix(self):
conf_matrix = metrics_utils.ConfusionMatrix()
conf_matrix.load_matrix(['dog', 'cat', 'horses'], [
[2, 6, 0], [3, 5,6], [5,7,8]])
conf_matrix.load_matrix(['dog', 'cat', 'horses'],
[[2, 6, 0], [3, 5, 6], [5, 7, 8]])

with open(os.path.join(os.path.dirname(__file__),
'test_data',
'expected_bulk_loaded_confusion_matrix.json')) as json_file:
with open(
os.path.join(
os.path.dirname(__file__), 'test_data',
'expected_bulk_loaded_confusion_matrix.json')) as json_file:
expected_json = json.load(json_file)
self.assertEqual(expected_json, conf_matrix.get_metrics())

Expand All @@ -55,11 +59,12 @@ def test_confidence_metrics(self):
confid_metrics.recall = 24.5
confid_metrics.falsePositiveRate = 98.4
expected_dict = {
'confidenceThreshold': 24.3,
'recall': 24.5,
'falsePositiveRate': 98.4
'confidenceThreshold': 24.3,
'recall': 24.5,
'falsePositiveRate': 98.4
}
self.assertEqual(expected_dict, confid_metrics.get_metrics())


if __name__ == '__main__':
unittest.main()

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

This file was deleted.

This file was deleted.

16 changes: 0 additions & 16 deletions sdk/python/tests/dsl/test_data/expected_complete_artifact.json

This file was deleted.

10 changes: 0 additions & 10 deletions sdk/python/tests/dsl/test_data/expected_dataset_artifact.json

This file was deleted.

12 changes: 0 additions & 12 deletions sdk/python/tests/dsl/test_data/expected_metrics.json

This file was deleted.

17 changes: 0 additions & 17 deletions sdk/python/tests/dsl/test_data/expected_model_artifact.json

This file was deleted.

This file was deleted.

0 comments on commit b7084f2

Please sign in to comment.