Skip to content

Handle union and literal typing correctly in annotations #478

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions csp/impl/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ def __new__(cls, name, bases, dct):
# Lists need to be normalized too as potentially we need to add a boolean flag to use FastList
if v == FastList:
raise TypeError(f"{v} annotation is not supported without args")
if CspTypingUtils.is_generic_container(v) or CspTypingUtils.is_union_type(v):
if (
CspTypingUtils.is_generic_container(v)
or CspTypingUtils.is_union_type(v)
or CspTypingUtils.is_literal_type(v)
):
actual_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(v)
if CspTypingUtils.is_generic_container(actual_type):
raise TypeError(f"{v} annotation is not supported as a struct field [{actual_type}]")
Expand Down Expand Up @@ -191,7 +195,8 @@ def _obj_from_python(cls, json, obj_type):
if CspTypingUtils.is_generic_container(obj_type):
if CspTypingUtils.get_origin(obj_type) in (typing.List, typing.Set, typing.Tuple, FastList):
return_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(obj_type)
(expected_item_type,) = obj_type.__args__
# We only take the first item, so like for a Tuple, we would ignore arguments after
expected_item_type = obj_type.__args__[0]
return_type = list if isinstance(return_type, list) else return_type
return return_type(cls._obj_from_python(v, expected_item_type) for v in json)
elif CspTypingUtils.get_origin(obj_type) is typing.Dict:
Expand All @@ -206,6 +211,13 @@ def _obj_from_python(cls, json, obj_type):
return json
else:
raise NotImplementedError(f"Can not deserialize {obj_type} from json")
elif CspTypingUtils.is_union_type(obj_type):
return json ## no checks, just let it through
elif CspTypingUtils.is_literal_type(obj_type):
return_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(obj_type)
if isinstance(json, return_type):
return json
raise ValueError(f"Expected type {return_type} received {json.__class__}")
elif issubclass(obj_type, Struct):
if not isinstance(json, dict):
raise TypeError("Representation of struct as json is expected to be of dict type")
Expand Down
24 changes: 12 additions & 12 deletions csp/impl/types/container_type_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,21 +81,21 @@ def normalized_type_to_actual_python_type(cls, typ, level=0):
return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1), True]
if origin is typing.List and level == 0:
return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1)]
if origin is typing.Literal:
# Import here to prevent circular import
from csp.impl.types.instantiation_type_resolver import UpcastRegistry

args = typing.get_args(typ)
typ = type(args[0])
for arg in args[1:]:
typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False)
if typ:
return typ
else:
return object
return cls._NORMALIZED_TYPE_MAPPING.get(CspTypingUtils.get_origin(typ), typ)
elif CspTypingUtils.is_union_type(typ):
return object
elif CspTypingUtils.is_literal_type(typ):
# Import here to prevent circular import
from csp.impl.types.instantiation_type_resolver import UpcastRegistry

args = typing.get_args(typ)
typ = type(args[0])
for arg in args[1:]:
typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False)
if typ:
return typ
else:
return object
else:
return typ

Expand Down
4 changes: 3 additions & 1 deletion csp/impl/types/pydantic_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
import types
import typing
from typing import Any, ForwardRef, Generic, Optional, Type, TypeVar, Union, get_args, get_origin
from typing import Any, ForwardRef, Generic, Literal, Optional, Type, TypeVar, Union, get_args, get_origin

from pydantic import GetCoreSchemaHandler, ValidationInfo, ValidatorFunctionWrapHandler
from pydantic_core import CoreSchema, core_schema
Expand Down Expand Up @@ -184,6 +184,8 @@ def adjust_annotations(
return TsType[
adjust_annotations(args[0], top_level=False, in_ts=True, make_optional=False, forced_tvars=forced_tvars)
]
if origin is Literal: # for literals, we stop converting
return Optional[annotation] if make_optional else annotation
else:
try:
if origin is CspTypeVar or origin is CspTypeVarType:
Expand Down
18 changes: 8 additions & 10 deletions csp/impl/types/type_annotation_normalizer_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def visit_arg(self, node):
return node

def visit_Subscript(self, node):
# We choose to avoid parsing here
# to maintain current behavior of allowing empty lists in our types
return node

def visit_List(self, node):
Expand Down Expand Up @@ -98,17 +100,13 @@ def visit_Call(self, node):
return node

def visit_Constant(self, node):
if not self._cur_arg:
return node

if self._cur_arg:
return ast.Call(
func=ast.Attribute(value=ast.Name(id="typing", ctx=ast.Load()), attr="TypeVar", ctx=ast.Load()),
args=[node],
keywords=[],
)
else:
if not self._cur_arg or not isinstance(node.value, str):
return node
return ast.Call(
func=ast.Attribute(value=ast.Name(id="typing", ctx=ast.Load()), attr="TypeVar", ctx=ast.Load()),
args=[node],
keywords=[],
)

def visit_Str(self, node):
return self.visit_Constant(node)
6 changes: 5 additions & 1 deletion csp/impl/types/typing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class CspTypingUtils39:

@classmethod
def is_generic_container(cls, typ):
return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ is not typing.Union
return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ not in (typing.Union, typing.Literal)

@classmethod
def is_type_spec(cls, val):
Expand Down Expand Up @@ -56,6 +56,10 @@ def is_numpy_nd_array_type(cls, typ):
def is_union_type(cls, typ):
return isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Union

@classmethod
def is_literal_type(cls, typ):
return isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Literal

@classmethod
def is_forward_ref(cls, typ):
return isinstance(typ, typing.ForwardRef)
Expand Down
124 changes: 124 additions & 0 deletions csp/tests/impl/test_struct.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import enum
import json
import pickle
import sys
import typing
import unittest
from datetime import date, datetime, time, timedelta
Expand Down Expand Up @@ -3903,6 +3904,129 @@ class DataPoint(csp.Struct):
self.assertNotIn("_last_updated", json_data)
self.assertNotIn("_source", json_data["data"])

def test_literal_types_validation(self):
"""Test that Literal type annotations correctly validate input values in CSP Structs"""

# Define a simple class with various Literal types
class StructWithLiterals(csp.Struct):
# String literals
color: Literal["red", "green", "blue"]
# Integer literals
size: Literal[1, 2, 3]
# Mixed type literals
status: Literal["on", "off", 0, 1, True, False]
# Optional literal with default
mode: Optional[Literal["fast", "slow"]] = "fast"

# Test valid assignments
s1 = StructWithLiterals(color="red", size=2, status="on")
self.assertEqual(s1.color, "red")
self.assertEqual(s1.size, 2)
self.assertEqual(s1.status, "on")
self.assertEqual(s1.mode, "fast") # Default value

s2 = StructWithLiterals.from_dict(dict(color="blue", size=1, status=True, mode="slow"))
s2_dump = s2.to_json()
s2_looped = TypeAdapter(StructWithLiterals).validate_json(s2_dump)
self.assertEqual(s2, s2_looped)
s2_dict = s2.to_dict()
s2_looped_dict = s2.from_dict(s2_dict)
self.assertEqual(s2_looped_dict, s2)

# Invalid color, but from_dict still accepts
StructWithLiterals.from_dict(dict(color="yellow", size=1, status="on"))

# Invalid size but from_dict still accepts
StructWithLiterals.from_dict(dict(color="red", size=4, status="on"))

# Invalid status but from_dict still accepts
StructWithLiterals.from_dict(dict(color="red", size=1, status="standby"))

# Invalid mode but from_dict still accepts
StructWithLiterals.from_dict(dict(color="red", size=1, mode=12))

# Invalid size and since the literals are all the same type
# If we give an incorrect type, we catch the error
with self.assertRaises(ValueError) as exc_info:
StructWithLiterals.from_dict(dict(color="red", size="adasd", mode=12))
self.assertIn("Expected type <class 'int'> received <class 'str'>", str(exc_info.exception))

# Test valid values
result = TypeAdapter(StructWithLiterals).validate_python({"color": "green", "size": 3, "status": 0})
self.assertEqual(result.color, "green")
self.assertEqual(result.size, 3)
self.assertEqual(result.status, 0)

# Test invalid color with Pydantic validation
with self.assertRaises(ValidationError) as exc_info:
TypeAdapter(StructWithLiterals).validate_python({"color": "yellow", "size": 1, "status": "on"})
self.assertIn("1 validation error for", str(exc_info.exception))
self.assertIn("color", str(exc_info.exception))

# Test invalid size with Pydantic validation
with self.assertRaises(ValidationError) as exc_info:
TypeAdapter(StructWithLiterals).validate_python({"color": "red", "size": 4, "status": "on"})
self.assertIn("1 validation error for", str(exc_info.exception))
self.assertIn("size", str(exc_info.exception))

# Test invalid status with Pydantic validation
with self.assertRaises(ValidationError) as exc_info:
TypeAdapter(StructWithLiterals).validate_python({"color": "red", "size": 1, "status": "standby"})
self.assertIn("1 validation error for", str(exc_info.exception))
self.assertIn("status", str(exc_info.exception))

# Test invalid mode with Pydantic validation
with self.assertRaises(ValidationError) as exc_info:
TypeAdapter(StructWithLiterals).validate_python(
{"color": "red", "size": 1, "status": "on", "mode": "medium"}
)
self.assertIn("1 validation error for", str(exc_info.exception))
self.assertIn("mode", str(exc_info.exception))

def test_pipe_operator_types(self):
"""Test using the pipe operator for union types in Python 3.10+"""
if sys.version_info >= (3, 10): # Only run on Python 3.10+
# Define a class using various pipe operator combinations
class PipeTypesConfig(csp.Struct):
# Basic primitive types with pipe
id_field: str | int
# Pipe with None (similar to Optional)
description: str | None = None
# Multiple types with pipe
value: str | int | float | bool
# Container with pipe
tags: List[str] | Dict[str, str] | None = None
# Pipe with literal for comparison
status: Literal["active", "inactive"] | None = "active"

# Test all valid types
valid_cases = [
{"id_field": "string_id", "value": "string_value"},
{"id_field": 42, "value": 123},
{"id_field": "mixed", "value": 3.14},
{"id_field": 999, "value": True},
{"id_field": "with_desc", "value": 1, "description": "Description"},
{"id_field": "with_dict", "value": 1, "tags": None},
]

for case in valid_cases:
result = PipeTypesConfig.from_dict(case)
# use the other route to get back the result
result_to_dict_loop = TypeAdapter(PipeTypesConfig).validate_python(result.to_dict())
self.assertEqual(result, result_to_dict_loop)

# Test invalid values
invalid_cases = [
{"id_field": 3.14, "value": 1}, # Float for id_field
{"id_field": None, "value": 1}, # None for required id_field
{"id_field": "test", "value": {}}, # Dict for value
{"id_field": "test", "value": None}, # None for required value
{"id_field": "test", "value": 1, "status": "unknown"}, # Invalid literal
]
for case in invalid_cases:
with self.assertRaises(ValidationError):
TypeAdapter(PipeTypesConfig).validate_python(case)


if __name__ == "__main__":
unittest.main()
11 changes: 10 additions & 1 deletion csp/tests/impl/types/test_pydantic_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys
from inspect import isclass
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union, get_args, get_origin
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union, get_args, get_origin
from unittest import TestCase

import csp
Expand Down Expand Up @@ -160,3 +160,12 @@ def test_force_tvars(self):
self.assertAnnotationsEqual(
adjust_annotations(CspTypeVarType[T], forced_tvars={"T": float}), Union[Type[float], Type[int]]
)

def test_literal(self):
self.assertAnnotationsEqual(adjust_annotations(Literal["a", "b"]), Literal["a", "b"])
self.assertAnnotationsEqual(
adjust_annotations(Literal["a", "b"], make_optional=True), Optional[Literal["a", "b"]]
)
self.assertAnnotationsEqual(adjust_annotations(Literal[123, "a"]), Literal[123, "a"])
self.assertAnnotationsEqual(adjust_annotations(Literal[123, None]), Literal[123, None])
self.assertAnnotationsEqual(adjust_annotations(ts[Literal[123, None]]), ts[Literal[123, None]])
Loading
Loading