Skip to content

Commit cd5383d

Browse files
committed
Fix linting errors in replicate/stream.py
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent 2ead118 commit cd5383d

File tree

1 file changed

+45
-21
lines changed

1 file changed

+45
-21
lines changed

replicate/stream.py

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1+
from enum import Enum
12
from typing import (
23
TYPE_CHECKING,
34
Any,
45
AsyncIterator,
56
Dict,
67
Iterator,
78
List,
8-
Literal,
99
Optional,
1010
)
1111

1212
from typing_extensions import Unpack
1313

14-
from replicate.identifier import ModelVersionIdentifier
1514
from replicate.exceptions import ReplicateError
15+
from replicate.identifier import ModelVersionIdentifier
1616

1717
try:
1818
from pydantic import v1 as pydantic # type: ignore
@@ -32,10 +32,20 @@ class ServerSentEvent(pydantic.BaseModel):
3232
A server-sent event.
3333
"""
3434

35-
event: Literal["message", "output", "logs", "error", "done"] = "message"
36-
data: str = ""
37-
id: str = ""
38-
retry: Optional[int] = None
35+
class EventType(Enum):
36+
"""
37+
A server-sent event type.
38+
"""
39+
40+
OUTPUT = "output"
41+
LOGS = "logs"
42+
ERROR = "error"
43+
DONE = "done"
44+
45+
event: EventType
46+
data: str
47+
id: str
48+
retry: Optional[int]
3949

4050
def __str__(self) -> str:
4151
if self.event == "output":
@@ -45,6 +55,10 @@ def __str__(self) -> str:
4555

4656

4757
class EventSource:
58+
"""
59+
A server-sent event source.
60+
"""
61+
4862
response: "httpx.Response"
4963

5064
def __init__(self, response: "httpx.Response") -> None:
@@ -57,27 +71,36 @@ def __init__(self, response: "httpx.Response") -> None:
5771
)
5872

5973
class Decoder:
60-
event: Optional[str] = None
74+
"""
75+
A decoder for server-sent events.
76+
"""
77+
78+
event: Optional["ServerSentEvent.EventType"] = None
6179
data: List[str] = []
6280
last_event_id: Optional[str] = None
6381
retry: Optional[int] = None
6482

6583
def decode(self, line: str) -> Optional[ServerSentEvent]:
84+
"""
85+
Decode a line and return a server-sent event if applicable.
86+
"""
87+
6688
if not line:
67-
if not any([self.event, self.data, self.last_event_id, self.retry]):
89+
if (
90+
not any([self.event, self.data, self.last_event_id, self.retry])
91+
or self.event is None
92+
or self.last_event_id is None
93+
):
6894
return None
6995

70-
try:
71-
sse = ServerSentEvent(
72-
event=self.event,
73-
data="\n".join(self.data),
74-
id=self.last_event_id,
75-
retry=self.retry,
76-
)
77-
except pydantic.ValidationError:
78-
return None
96+
sse = ServerSentEvent(
97+
event=self.event,
98+
data="\n".join(self.data),
99+
id=self.last_event_id,
100+
retry=self.retry,
101+
)
79102

80-
self.event = ""
103+
self.event = None
81104
self.data = []
82105
self.retry = None
83106

@@ -91,7 +114,8 @@ def decode(self, line: str) -> Optional[ServerSentEvent]:
91114

92115
match fieldname:
93116
case "event":
94-
self.event = value
117+
if event := ServerSentEvent.EventType(value):
118+
self.event = event
95119
case "data":
96120
self.data.append(value)
97121
case "id":
@@ -155,7 +179,7 @@ def stream(
155179
)
156180

157181
url = prediction.urls and prediction.urls.get("stream", None)
158-
if not url:
182+
if not url or not isinstance(url, str):
159183
raise ReplicateError("Model does not support streaming")
160184

161185
headers = {}
@@ -185,7 +209,7 @@ async def async_stream(
185209
)
186210

187211
url = prediction.urls and prediction.urls.get("stream", None)
188-
if not url:
212+
if not url or not isinstance(url, str):
189213
raise ReplicateError("Model does not support streaming")
190214

191215
headers = {}

0 commit comments

Comments
 (0)