11import typing
2- from typing import Generic , Callable , Any , TypeVar
2+ from typing import Generic , Callable , Any , TypeVar , Annotated
33
44import pydantic
55from pydantic import BaseModel
66
7+ from langdiff .parser .decoder import get_decoder
8+
79T = TypeVar ("T" )
810
911Field = pydantic .Field
1012
1113
14+ class PydanticType :
15+ """A hint that specifies the Pydantic type to use when converting to Pydantic models.
16+
17+ This is used with typing.Annotated to provide custom type hints for Pydantic model derivation.
18+
19+ Example:
20+ class Item(Object):
21+ field: Annotated[String, PydanticType(UUID)]
22+
23+ When Item.to_pydantic() is called, the generated field will have type UUID instead of str.
24+ """
25+
26+ def __init__ (self , pydantic_type : Any ):
27+ """Initialize with the desired Pydantic type.
28+
29+ Args:
30+ pydantic_type: The type to use in the generated Pydantic model
31+ """
32+ self .pydantic_type = pydantic_type
33+
34+
1235class StreamingValue (Generic [T ]):
1336 """A generic base class for a value that is streamed incrementally.
1437
@@ -65,12 +88,17 @@ def __init__(self):
6588 for key , type_hint in type (self ).__annotations__ .items ():
6689 self ._keys .append (key )
6790
68- # handle StreamingList[T], CompleteValue[T]
69- if hasattr (type_hint , "__origin__" ):
70- item_cls = typing .get_args (type_hint )[0 ]
71- setattr (self , key , type_hint .__origin__ (item_cls ))
91+ # Extract base type from Annotated[T, PydanticType(...), ...]
92+ base_type = type_hint
93+ if typing .get_origin (type_hint ) is Annotated :
94+ base_type = typing .get_args (type_hint )[0 ]
95+
96+ # handle List[T], Atom[T]
97+ if hasattr (base_type , "__origin__" ):
98+ item_cls = typing .get_args (base_type )[0 ]
99+ setattr (self , key , base_type .__origin__ (item_cls ))
72100 else :
73- setattr (self , key , type_hint ())
101+ setattr (self , key , base_type ())
74102
75103 def on_update (self , func : Callable [[dict ], Any ]):
76104 """Register a callback that is called whenever the object is updated."""
@@ -121,7 +149,7 @@ def to_pydantic(cls) -> type[BaseModel]:
121149 model = getattr (cls , "_pydantic_model" , None )
122150 if model is not None : # use cached model if available
123151 return model
124- fields = {}
152+ fields : dict [ str , Any ] = {}
125153 for name , type_hint in cls .__annotations__ .items ():
126154 type_hint = unwrap_raw_type (type_hint )
127155 field = getattr (cls , name , None )
@@ -130,15 +158,15 @@ def to_pydantic(cls) -> type[BaseModel]:
130158 else :
131159 fields [name ] = type_hint
132160 model = pydantic .create_model (cls .__name__ , ** fields , __doc__ = cls .__doc__ )
133- cls . _pydantic_model = model
161+ setattr ( cls , " _pydantic_model" , model )
134162 return model
135163
136164
137165class List (Generic [T ], StreamingValue [list ]):
138166 """Represents a JSON array that is streamed.
139167
140168 This class can handle a list of items that are themselves `StreamingValue`s
141- (like `StreamingObject ` or `StreamingString `) or complete values. It provides
169+ (like `langdiff.Object ` or `langdiff.String `) or complete values. It provides
142170 an `on_append` callback that is fired when a new item is added to the list.
143171 """
144172
@@ -154,9 +182,7 @@ def __init__(self, item_cls: type[T]):
154182 self ._value = []
155183 self ._item_cls = item_cls
156184 self ._item_streaming = issubclass (item_cls , StreamingValue )
157- self ._decode = (
158- item_cls .model_validate if issubclass (item_cls , BaseModel ) else None
159- )
185+ self ._decode = get_decoder (item_cls ) if not self ._item_streaming else None
160186 self ._streaming_values = []
161187 self ._on_append_funcs = []
162188
@@ -270,7 +296,7 @@ def update(self, value: str | None):
270296 else :
271297 if value is None or not value .startswith (self ._value ):
272298 raise ValueError (
273- "StreamingString can only be updated with a continuation of the current value."
299+ "langdiff.String can only be updated with a continuation of the current value."
274300 )
275301 if len (value ) == len (self ._value ):
276302 return
@@ -290,18 +316,16 @@ class Atom(Generic[T], StreamingValue[T]):
290316
291317 This is useful for types like numbers, booleans, or even entire objects/lists
292318 that are not streamed part-by-part but are present completely once available.
293- The `on_complete` callback is triggered when the parent `StreamingObject ` or
294- `StreamingList ` determines that this value is complete.
319+ The `on_complete` callback is triggered when the parent `langdiff.Object ` or
320+ `langdiff.List ` determines that this value is complete.
295321 """
296322
297323 _value : T | None
298324
299325 def __init__ (self , item_cls : type [T ]):
300326 super ().__init__ ()
301327 self ._value = None
302- self ._decode = (
303- item_cls .model_validate if issubclass (item_cls , BaseModel ) else None
304- )
328+ self ._decode = get_decoder (item_cls )
305329
306330 def update (self , value : T ):
307331 self ._trigger_start ()
@@ -320,23 +344,53 @@ def value(self) -> T | None:
320344 return self ._value
321345
322346
323- def unwrap_raw_type (type_hint : Any ) -> type :
347+ def _extract_pydantic_hint (type_hint : Any ) -> type | None :
348+ """Extract PydanticType from Annotated type if present."""
349+ if typing .get_origin (type_hint ) is Annotated :
350+ args = typing .get_args (type_hint )
351+ if len (args ) >= 2 :
352+ # Look for PydanticType in the metadata
353+ for metadata in args [1 :]:
354+ if isinstance (metadata , PydanticType ):
355+ return metadata .pydantic_type
356+ return None
357+
358+
359+ def unwrap_raw_type (type_hint : Any ):
324360 # Possible types:
361+ # - Annotated[T, PydanticType(U)] => U (custom Pydantic type)
325362 # - Atom[T] => T
326363 # - List[T] => list[unwrap(T)]
327364 # - String => str
328- # - T extends StreamableModel => T.to_pydantic()
365+ # - T extends Object => T.to_pydantic()
366+
367+ # First check for PydanticType in Annotated types
368+ pydantic_hint = _extract_pydantic_hint (type_hint )
369+ if pydantic_hint is not None :
370+ return pydantic_hint
371+
372+ # Handle Annotated[T, ...] by extracting the base type
373+ if typing .get_origin (type_hint ) is Annotated :
374+ type_hint = typing .get_args (type_hint )[0 ]
375+
329376 if hasattr (type_hint , "__origin__" ):
330377 origin = type_hint .__origin__
331378 if origin is Atom :
332379 return typing .get_args (type_hint )[0 ]
333380 elif origin is List :
334381 item_type = typing .get_args (type_hint )[0 ]
335- return list [unwrap_raw_type (item_type )]
382+ return list [unwrap_raw_type (item_type )] # type: ignore[misc]
336383 elif type_hint is String :
337384 return str
338385 elif issubclass (type_hint , Object ):
339386 return type_hint .to_pydantic ()
387+ elif issubclass (type_hint , StreamingValue ):
388+ to_pydantic = getattr (type_hint , "to_pydantic" , None )
389+ if to_pydantic is None or not callable (to_pydantic ):
390+ raise ValueError (
391+ f"Custom StreamingValue type { type_hint } must implement to_pydantic() method."
392+ )
393+ return to_pydantic ()
340394 elif (
341395 type_hint is str
342396 or type_hint is int
@@ -346,5 +400,5 @@ def unwrap_raw_type(type_hint: Any) -> type:
346400 ):
347401 return type_hint
348402 raise ValueError (
349- f"Unsupported type hint: { type_hint } . Expected Atom, List, String, or StreamableModel subclass."
403+ f"Unsupported type hint: { type_hint } . Expected LangDiff Atom, List, String, or Object subclass."
350404 )
0 commit comments