@@ -99,6 +99,88 @@ class StopStreaming(Exception):
9999 """Raised internally when processing of a streamed response should be stopped."""
100100
101101
102+ class Stream (Generic [ResponseT ]):
103+ response : httpx .Response
104+
105+ def __init__ (
106+ self ,
107+ * ,
108+ cast_to : type [ResponseT ],
109+ response : httpx .Response ,
110+ client : SyncAPIClient ,
111+ ) -> None :
112+ self .response = response
113+ self ._cast_to = cast_to
114+ self ._client = client
115+ self ._iterator = self .__iter ()
116+
117+ def __next__ (self ) -> ResponseT :
118+ return self ._iterator .__next__ ()
119+
120+ def __iter__ (self ) -> Iterator [ResponseT ]:
121+ for item in self ._iterator :
122+ yield item
123+
124+ def __iter (self ) -> Iterator [ResponseT ]:
125+ cast_to = self ._cast_to
126+ response = self .response
127+ process_line = self ._client ._process_stream_line
128+ process_data = self ._client ._process_response_data
129+
130+ for raw_line in response .iter_lines ():
131+ if not raw_line or raw_line == "\n " :
132+ continue
133+
134+ try :
135+ line = process_line (raw_line )
136+ except StopStreaming :
137+ # we are done!
138+ break
139+
140+ yield process_data (data = json .loads (line ), cast_to = cast_to , response = response )
141+
142+
143+ class AsyncStream (Generic [ResponseT ]):
144+ response : httpx .Response
145+
146+ def __init__ (
147+ self ,
148+ * ,
149+ cast_to : type [ResponseT ],
150+ response : httpx .Response ,
151+ client : AsyncAPIClient ,
152+ ) -> None :
153+ self .response = response
154+ self ._cast_to = cast_to
155+ self ._client = client
156+ self ._iterator = self .__iter ()
157+
158+ async def __anext__ (self ) -> ResponseT :
159+ return await self ._iterator .__anext__ ()
160+
161+ async def __aiter__ (self ) -> AsyncIterator [ResponseT ]:
162+ async for item in self ._iterator :
163+ yield item
164+
165+ async def __iter (self ) -> AsyncIterator [ResponseT ]:
166+ cast_to = self ._cast_to
167+ response = self .response
168+ process_line = self ._client ._process_stream_line
169+ process_data = self ._client ._process_response_data
170+
171+ async for raw_line in response .aiter_lines ():
172+ if not raw_line or raw_line == "\n " :
173+ continue
174+
175+ try :
176+ line = process_line (raw_line )
177+ except StopStreaming :
178+ # we are done!
179+ break
180+
181+ yield process_data (data = json .loads (line ), cast_to = cast_to , response = response )
182+
183+
102184class PageInfo :
103185 """Stores the necesary information to build the request to retrieve the next page.
104186
@@ -526,7 +608,6 @@ def _process_response_data(
526608
527609 return cast (ResponseT , construct_type (type_ = cast_to , value = data ))
528610
529- # TODO: make the constants in here configurable
530611 def _process_stream_line (self , contents : str ) -> str :
531612 """Pre-process an indiviudal line from a streaming response"""
532613 if contents == "data: [DONE]\n " :
@@ -690,7 +771,7 @@ def request(
690771 remaining_retries : Optional [int ] = None ,
691772 * ,
692773 stream : Literal [True ],
693- ) -> Iterator [ResponseT ]:
774+ ) -> Stream [ResponseT ]:
694775 ...
695776
696777 @overload
@@ -712,7 +793,7 @@ def request(
712793 remaining_retries : Optional [int ] = None ,
713794 * ,
714795 stream : bool = False ,
715- ) -> ResponseT | Iterator [ResponseT ]:
796+ ) -> ResponseT | Stream [ResponseT ]:
716797 ...
717798
718799 def request (
@@ -722,7 +803,7 @@ def request(
722803 remaining_retries : Optional [int ] = None ,
723804 * ,
724805 stream : bool = False ,
725- ) -> ResponseT | Iterator [ResponseT ]:
806+ ) -> ResponseT | Stream [ResponseT ]:
726807 return self ._request (
727808 cast_to = cast_to ,
728809 options = options ,
@@ -737,7 +818,7 @@ def _request(
737818 options : FinalRequestOptions ,
738819 remaining_retries : int | None ,
739820 stream : bool ,
740- ) -> ResponseT | Iterator [ResponseT ]:
821+ ) -> ResponseT | Stream [ResponseT ]:
741822 retries = self ._remaining_retries (remaining_retries , options )
742823 request = self ._build_request (options )
743824
@@ -762,7 +843,7 @@ def _request(
762843 raise APIConnectionError (request = request ) from err
763844
764845 if stream :
765- return self . _process_stream_response (cast_to = cast_to , response = response )
846+ return Stream (cast_to = cast_to , response = response , client = self )
766847
767848 try :
768849 rsp = self ._process_response (cast_to = cast_to , options = options , response = response )
@@ -779,7 +860,7 @@ def _retry_request(
779860 response_headers : Optional [httpx .Headers ] = None ,
780861 * ,
781862 stream : bool ,
782- ) -> ResponseT | Iterator [ResponseT ]:
863+ ) -> ResponseT | Stream [ResponseT ]:
783864 remaining = remaining_retries - 1
784865 timeout = self ._calculate_retry_timeout (remaining , options , response_headers )
785866
@@ -794,24 +875,6 @@ def _retry_request(
794875 stream = stream ,
795876 )
796877
797- def _process_stream_response (
798- self ,
799- * ,
800- cast_to : Type [ResponseT ],
801- response : httpx .Response ,
802- ) -> Iterator [ResponseT ]:
803- for raw_line in response .iter_lines ():
804- if not raw_line or raw_line == "\n " :
805- continue
806-
807- try :
808- line = self ._process_stream_line (raw_line )
809- except StopStreaming :
810- # we are done!
811- break
812-
813- yield self ._process_response_data (data = json .loads (line ), cast_to = cast_to , response = response )
814-
815878 def _request_api_list (
816879 self ,
817880 model : Type [ModelT ],
@@ -861,7 +924,7 @@ def post(
861924 options : RequestOptions = {},
862925 files : RequestFiles | None = None ,
863926 stream : Literal [True ],
864- ) -> Iterator [ResponseT ]:
927+ ) -> Stream [ResponseT ]:
865928 ...
866929
867930 @overload
@@ -874,7 +937,7 @@ def post(
874937 options : RequestOptions = {},
875938 files : RequestFiles | None = None ,
876939 stream : bool ,
877- ) -> ResponseT | Iterator [ResponseT ]:
940+ ) -> ResponseT | Stream [ResponseT ]:
878941 ...
879942
880943 def post (
@@ -886,7 +949,7 @@ def post(
886949 options : RequestOptions = {},
887950 files : RequestFiles | None = None ,
888951 stream : bool = False ,
889- ) -> ResponseT | Iterator [ResponseT ]:
952+ ) -> ResponseT | Stream [ResponseT ]:
890953 opts = FinalRequestOptions .construct (method = "post" , url = path , json_data = body , files = files , ** options )
891954 return cast (ResponseT , self .request (cast_to , opts , stream = stream ))
892955
@@ -993,7 +1056,7 @@ async def request(
9931056 * ,
9941057 stream : Literal [True ],
9951058 remaining_retries : Optional [int ] = None ,
996- ) -> AsyncIterator [ResponseT ]:
1059+ ) -> AsyncStream [ResponseT ]:
9971060 ...
9981061
9991062 @overload
@@ -1004,7 +1067,7 @@ async def request(
10041067 * ,
10051068 stream : bool ,
10061069 remaining_retries : Optional [int ] = None ,
1007- ) -> ResponseT | AsyncIterator [ResponseT ]:
1070+ ) -> ResponseT | AsyncStream [ResponseT ]:
10081071 ...
10091072
10101073 async def request (
@@ -1014,7 +1077,7 @@ async def request(
10141077 * ,
10151078 stream : bool = False ,
10161079 remaining_retries : Optional [int ] = None ,
1017- ) -> ResponseT | AsyncIterator [ResponseT ]:
1080+ ) -> ResponseT | AsyncStream [ResponseT ]:
10181081 return await self ._request (
10191082 cast_to = cast_to ,
10201083 options = options ,
@@ -1029,7 +1092,7 @@ async def _request(
10291092 * ,
10301093 stream : bool ,
10311094 remaining_retries : int | None ,
1032- ) -> ResponseT | AsyncIterator [ResponseT ]:
1095+ ) -> ResponseT | AsyncStream [ResponseT ]:
10331096 retries = self ._remaining_retries (remaining_retries , options )
10341097 request = self ._build_request (options )
10351098
@@ -1064,7 +1127,7 @@ async def _request(
10641127 raise APIConnectionError (request = request ) from err
10651128
10661129 if stream :
1067- return self . _process_stream_response (cast_to = cast_to , response = response )
1130+ return AsyncStream (cast_to = cast_to , response = response , client = self )
10681131
10691132 try :
10701133 rsp = self ._process_response (cast_to = cast_to , options = options , response = response )
@@ -1081,7 +1144,7 @@ async def _retry_request(
10811144 response_headers : Optional [httpx .Headers ] = None ,
10821145 * ,
10831146 stream : bool ,
1084- ) -> ResponseT | AsyncIterator [ResponseT ]:
1147+ ) -> ResponseT | AsyncStream [ResponseT ]:
10851148 remaining = remaining_retries - 1
10861149 timeout = self ._calculate_retry_timeout (remaining , options , response_headers )
10871150
@@ -1094,24 +1157,6 @@ async def _retry_request(
10941157 stream = stream ,
10951158 )
10961159
1097- async def _process_stream_response (
1098- self ,
1099- * ,
1100- cast_to : Type [ResponseT ],
1101- response : httpx .Response ,
1102- ) -> AsyncIterator [ResponseT ]:
1103- async for raw_line in response .aiter_lines ():
1104- if not raw_line or raw_line == "\n " :
1105- continue
1106-
1107- try :
1108- line = self ._process_stream_line (raw_line )
1109- except StopStreaming :
1110- # we are done!
1111- break
1112-
1113- yield self ._process_response_data (data = json .loads (line ), cast_to = cast_to , response = response )
1114-
11151160 def _request_api_list (
11161161 self ,
11171162 model : Type [ModelT ],
@@ -1153,7 +1198,7 @@ async def post(
11531198 files : RequestFiles | None = None ,
11541199 options : RequestOptions = {},
11551200 stream : Literal [True ],
1156- ) -> AsyncIterator [ResponseT ]:
1201+ ) -> AsyncStream [ResponseT ]:
11571202 ...
11581203
11591204 @overload
@@ -1166,7 +1211,7 @@ async def post(
11661211 files : RequestFiles | None = None ,
11671212 options : RequestOptions = {},
11681213 stream : bool ,
1169- ) -> ResponseT | AsyncIterator [ResponseT ]:
1214+ ) -> ResponseT | AsyncStream [ResponseT ]:
11701215 ...
11711216
11721217 async def post (
@@ -1178,7 +1223,7 @@ async def post(
11781223 files : RequestFiles | None = None ,
11791224 options : RequestOptions = {},
11801225 stream : bool = False ,
1181- ) -> ResponseT | AsyncIterator [ResponseT ]:
1226+ ) -> ResponseT | AsyncStream [ResponseT ]:
11821227 opts = FinalRequestOptions .construct (method = "post" , url = path , json_data = body , files = files , ** options )
11831228 return await self .request (cast_to , opts , stream = stream )
11841229
0 commit comments