Skip to content

Commit f03807a

Browse files
refactor(internal): move to Stream and AsyncStream classes for streaming
refactor(internal): move to `Stream` and `AsyncStream` classes for streaming
1 parent a57e32a commit f03807a

File tree

3 files changed

+108
-62
lines changed

3 files changed

+108
-62
lines changed

src/increase/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
from ._client import (
77
ENVIRONMENTS,
88
Client,
9+
Stream,
910
Timeout,
1011
Increase,
1112
Transport,
1213
AsyncClient,
14+
AsyncStream,
1315
ProxiesTypes,
1416
AsyncIncrease,
1517
RequestOptions,
@@ -79,6 +81,8 @@
7981
"RequestOptions",
8082
"Client",
8183
"AsyncClient",
84+
"Stream",
85+
"AsyncStream",
8286
"Increase",
8387
"AsyncIncrease",
8488
"ENVIRONMENTS",

src/increase/_base_client.py

Lines changed: 100 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,88 @@ class StopStreaming(Exception):
9999
"""Raised internally when processing of a streamed response should be stopped."""
100100

101101

102+
class Stream(Generic[ResponseT]):
103+
response: httpx.Response
104+
105+
def __init__(
106+
self,
107+
*,
108+
cast_to: type[ResponseT],
109+
response: httpx.Response,
110+
client: SyncAPIClient,
111+
) -> None:
112+
self.response = response
113+
self._cast_to = cast_to
114+
self._client = client
115+
self._iterator = self.__iter()
116+
117+
def __next__(self) -> ResponseT:
118+
return self._iterator.__next__()
119+
120+
def __iter__(self) -> Iterator[ResponseT]:
121+
for item in self._iterator:
122+
yield item
123+
124+
def __iter(self) -> Iterator[ResponseT]:
125+
cast_to = self._cast_to
126+
response = self.response
127+
process_line = self._client._process_stream_line
128+
process_data = self._client._process_response_data
129+
130+
for raw_line in response.iter_lines():
131+
if not raw_line or raw_line == "\n":
132+
continue
133+
134+
try:
135+
line = process_line(raw_line)
136+
except StopStreaming:
137+
# we are done!
138+
break
139+
140+
yield process_data(data=json.loads(line), cast_to=cast_to, response=response)
141+
142+
143+
class AsyncStream(Generic[ResponseT]):
144+
response: httpx.Response
145+
146+
def __init__(
147+
self,
148+
*,
149+
cast_to: type[ResponseT],
150+
response: httpx.Response,
151+
client: AsyncAPIClient,
152+
) -> None:
153+
self.response = response
154+
self._cast_to = cast_to
155+
self._client = client
156+
self._iterator = self.__iter()
157+
158+
async def __anext__(self) -> ResponseT:
159+
return await self._iterator.__anext__()
160+
161+
async def __aiter__(self) -> AsyncIterator[ResponseT]:
162+
async for item in self._iterator:
163+
yield item
164+
165+
async def __iter(self) -> AsyncIterator[ResponseT]:
166+
cast_to = self._cast_to
167+
response = self.response
168+
process_line = self._client._process_stream_line
169+
process_data = self._client._process_response_data
170+
171+
async for raw_line in response.aiter_lines():
172+
if not raw_line or raw_line == "\n":
173+
continue
174+
175+
try:
176+
line = process_line(raw_line)
177+
except StopStreaming:
178+
# we are done!
179+
break
180+
181+
yield process_data(data=json.loads(line), cast_to=cast_to, response=response)
182+
183+
102184
class PageInfo:
103185
"""Stores the necesary information to build the request to retrieve the next page.
104186
@@ -526,7 +608,6 @@ def _process_response_data(
526608

527609
return cast(ResponseT, construct_type(type_=cast_to, value=data))
528610

529-
# TODO: make the constants in here configurable
530611
def _process_stream_line(self, contents: str) -> str:
531612
"""Pre-process an indiviudal line from a streaming response"""
532613
if contents == "data: [DONE]\n":
@@ -690,7 +771,7 @@ def request(
690771
remaining_retries: Optional[int] = None,
691772
*,
692773
stream: Literal[True],
693-
) -> Iterator[ResponseT]:
774+
) -> Stream[ResponseT]:
694775
...
695776

696777
@overload
@@ -712,7 +793,7 @@ def request(
712793
remaining_retries: Optional[int] = None,
713794
*,
714795
stream: bool = False,
715-
) -> ResponseT | Iterator[ResponseT]:
796+
) -> ResponseT | Stream[ResponseT]:
716797
...
717798

718799
def request(
@@ -722,7 +803,7 @@ def request(
722803
remaining_retries: Optional[int] = None,
723804
*,
724805
stream: bool = False,
725-
) -> ResponseT | Iterator[ResponseT]:
806+
) -> ResponseT | Stream[ResponseT]:
726807
return self._request(
727808
cast_to=cast_to,
728809
options=options,
@@ -737,7 +818,7 @@ def _request(
737818
options: FinalRequestOptions,
738819
remaining_retries: int | None,
739820
stream: bool,
740-
) -> ResponseT | Iterator[ResponseT]:
821+
) -> ResponseT | Stream[ResponseT]:
741822
retries = self._remaining_retries(remaining_retries, options)
742823
request = self._build_request(options)
743824

@@ -762,7 +843,7 @@ def _request(
762843
raise APIConnectionError(request=request) from err
763844

764845
if stream:
765-
return self._process_stream_response(cast_to=cast_to, response=response)
846+
return Stream(cast_to=cast_to, response=response, client=self)
766847

767848
try:
768849
rsp = self._process_response(cast_to=cast_to, options=options, response=response)
@@ -779,7 +860,7 @@ def _retry_request(
779860
response_headers: Optional[httpx.Headers] = None,
780861
*,
781862
stream: bool,
782-
) -> ResponseT | Iterator[ResponseT]:
863+
) -> ResponseT | Stream[ResponseT]:
783864
remaining = remaining_retries - 1
784865
timeout = self._calculate_retry_timeout(remaining, options, response_headers)
785866

@@ -794,24 +875,6 @@ def _retry_request(
794875
stream=stream,
795876
)
796877

797-
def _process_stream_response(
798-
self,
799-
*,
800-
cast_to: Type[ResponseT],
801-
response: httpx.Response,
802-
) -> Iterator[ResponseT]:
803-
for raw_line in response.iter_lines():
804-
if not raw_line or raw_line == "\n":
805-
continue
806-
807-
try:
808-
line = self._process_stream_line(raw_line)
809-
except StopStreaming:
810-
# we are done!
811-
break
812-
813-
yield self._process_response_data(data=json.loads(line), cast_to=cast_to, response=response)
814-
815878
def _request_api_list(
816879
self,
817880
model: Type[ModelT],
@@ -861,7 +924,7 @@ def post(
861924
options: RequestOptions = {},
862925
files: RequestFiles | None = None,
863926
stream: Literal[True],
864-
) -> Iterator[ResponseT]:
927+
) -> Stream[ResponseT]:
865928
...
866929

867930
@overload
@@ -874,7 +937,7 @@ def post(
874937
options: RequestOptions = {},
875938
files: RequestFiles | None = None,
876939
stream: bool,
877-
) -> ResponseT | Iterator[ResponseT]:
940+
) -> ResponseT | Stream[ResponseT]:
878941
...
879942

880943
def post(
@@ -886,7 +949,7 @@ def post(
886949
options: RequestOptions = {},
887950
files: RequestFiles | None = None,
888951
stream: bool = False,
889-
) -> ResponseT | Iterator[ResponseT]:
952+
) -> ResponseT | Stream[ResponseT]:
890953
opts = FinalRequestOptions.construct(method="post", url=path, json_data=body, files=files, **options)
891954
return cast(ResponseT, self.request(cast_to, opts, stream=stream))
892955

@@ -993,7 +1056,7 @@ async def request(
9931056
*,
9941057
stream: Literal[True],
9951058
remaining_retries: Optional[int] = None,
996-
) -> AsyncIterator[ResponseT]:
1059+
) -> AsyncStream[ResponseT]:
9971060
...
9981061

9991062
@overload
@@ -1004,7 +1067,7 @@ async def request(
10041067
*,
10051068
stream: bool,
10061069
remaining_retries: Optional[int] = None,
1007-
) -> ResponseT | AsyncIterator[ResponseT]:
1070+
) -> ResponseT | AsyncStream[ResponseT]:
10081071
...
10091072

10101073
async def request(
@@ -1014,7 +1077,7 @@ async def request(
10141077
*,
10151078
stream: bool = False,
10161079
remaining_retries: Optional[int] = None,
1017-
) -> ResponseT | AsyncIterator[ResponseT]:
1080+
) -> ResponseT | AsyncStream[ResponseT]:
10181081
return await self._request(
10191082
cast_to=cast_to,
10201083
options=options,
@@ -1029,7 +1092,7 @@ async def _request(
10291092
*,
10301093
stream: bool,
10311094
remaining_retries: int | None,
1032-
) -> ResponseT | AsyncIterator[ResponseT]:
1095+
) -> ResponseT | AsyncStream[ResponseT]:
10331096
retries = self._remaining_retries(remaining_retries, options)
10341097
request = self._build_request(options)
10351098

@@ -1064,7 +1127,7 @@ async def _request(
10641127
raise APIConnectionError(request=request) from err
10651128

10661129
if stream:
1067-
return self._process_stream_response(cast_to=cast_to, response=response)
1130+
return AsyncStream(cast_to=cast_to, response=response, client=self)
10681131

10691132
try:
10701133
rsp = self._process_response(cast_to=cast_to, options=options, response=response)
@@ -1081,7 +1144,7 @@ async def _retry_request(
10811144
response_headers: Optional[httpx.Headers] = None,
10821145
*,
10831146
stream: bool,
1084-
) -> ResponseT | AsyncIterator[ResponseT]:
1147+
) -> ResponseT | AsyncStream[ResponseT]:
10851148
remaining = remaining_retries - 1
10861149
timeout = self._calculate_retry_timeout(remaining, options, response_headers)
10871150

@@ -1094,24 +1157,6 @@ async def _retry_request(
10941157
stream=stream,
10951158
)
10961159

1097-
async def _process_stream_response(
1098-
self,
1099-
*,
1100-
cast_to: Type[ResponseT],
1101-
response: httpx.Response,
1102-
) -> AsyncIterator[ResponseT]:
1103-
async for raw_line in response.aiter_lines():
1104-
if not raw_line or raw_line == "\n":
1105-
continue
1106-
1107-
try:
1108-
line = self._process_stream_line(raw_line)
1109-
except StopStreaming:
1110-
# we are done!
1111-
break
1112-
1113-
yield self._process_response_data(data=json.loads(line), cast_to=cast_to, response=response)
1114-
11151160
def _request_api_list(
11161161
self,
11171162
model: Type[ModelT],
@@ -1153,7 +1198,7 @@ async def post(
11531198
files: RequestFiles | None = None,
11541199
options: RequestOptions = {},
11551200
stream: Literal[True],
1156-
) -> AsyncIterator[ResponseT]:
1201+
) -> AsyncStream[ResponseT]:
11571202
...
11581203

11591204
@overload
@@ -1166,7 +1211,7 @@ async def post(
11661211
files: RequestFiles | None = None,
11671212
options: RequestOptions = {},
11681213
stream: bool,
1169-
) -> ResponseT | AsyncIterator[ResponseT]:
1214+
) -> ResponseT | AsyncStream[ResponseT]:
11701215
...
11711216

11721217
async def post(
@@ -1178,7 +1223,7 @@ async def post(
11781223
files: RequestFiles | None = None,
11791224
options: RequestOptions = {},
11801225
stream: bool = False,
1181-
) -> ResponseT | AsyncIterator[ResponseT]:
1226+
) -> ResponseT | AsyncStream[ResponseT]:
11821227
opts = FinalRequestOptions.construct(method="post", url=path, json_data=body, files=files, **options)
11831228
return await self.request(cast_to, opts, stream=stream)
11841229

src/increase/_client.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,10 @@
2121
)
2222
from ._utils import is_mapping
2323
from ._version import __version__
24-
from ._base_client import (
25-
DEFAULT_LIMITS,
26-
DEFAULT_TIMEOUT,
27-
DEFAULT_MAX_RETRIES,
28-
SyncAPIClient,
29-
AsyncAPIClient,
30-
)
24+
from ._base_client import DEFAULT_LIMITS, DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES
25+
from ._base_client import Stream as Stream
26+
from ._base_client import AsyncStream as AsyncStream
27+
from ._base_client import SyncAPIClient, AsyncAPIClient
3128

3229
__all__ = [
3330
"ENVIRONMENTS",

0 commit comments

Comments
 (0)