diff --git a/flytekit/models/filters.py b/flytekit/models/filters.py index 2b0cb04d88..5d7bb55104 100644 --- a/flytekit/models/filters.py +++ b/flytekit/models/filters.py @@ -118,6 +118,8 @@ def __init__(self, key, values): :param Text key: The name of the field to compare against :param list[Text] values: A list of textual values to compare. """ + if not isinstance(values, list): + raise TypeError(f"values must be a list. but got {type(values)}") super(SetFilter, self).__init__(key, ";".join(values)) @classmethod diff --git a/tests/flytekit/unit/models/test_filters.py b/tests/flytekit/unit/models/test_filters.py index 7f4f9c9b86..d995eeb805 100644 --- a/tests/flytekit/unit/models/test_filters.py +++ b/tests/flytekit/unit/models/test_filters.py @@ -1,5 +1,5 @@ from flytekit.models import filters - +import pytest def test_eq_filter(): assert filters.Equal("key", "value").to_flyte_idl() == "eq(key,value)" @@ -28,6 +28,10 @@ def test_lte_filter(): def test_value_in_filter(): assert filters.ValueIn("key", ["1", "2", "3"]).to_flyte_idl() == "value_in(key,1;2;3)" +def test_invalid_value_in_filter(): + with pytest.raises(TypeError, match=r"values must be a list. but got .*"): + filters.ValueIn("key", "1") + def test_contains_filter(): assert filters.Contains("key", ["1", "2", "3"]).to_flyte_idl() == "contains(key,1;2;3)"