1
+ from enum import Enum
1
2
from typing import (
2
3
TYPE_CHECKING ,
3
4
Any ,
4
5
AsyncIterator ,
5
6
Dict ,
6
7
Iterator ,
7
8
List ,
8
- Literal ,
9
9
Optional ,
10
10
)
11
11
12
12
from typing_extensions import Unpack
13
13
14
- from replicate .identifier import ModelVersionIdentifier
15
14
from replicate .exceptions import ReplicateError
15
+ from replicate .identifier import ModelVersionIdentifier
16
16
17
17
try :
18
18
from pydantic import v1 as pydantic # type: ignore
@@ -32,10 +32,20 @@ class ServerSentEvent(pydantic.BaseModel):
32
32
A server-sent event.
33
33
"""
34
34
35
- event : Literal ["message" , "output" , "logs" , "error" , "done" ] = "message"
36
- data : str = ""
37
- id : str = ""
38
- retry : Optional [int ] = None
35
+ class EventType (Enum ):
36
+ """
37
+ A server-sent event type.
38
+ """
39
+
40
+ OUTPUT = "output"
41
+ LOGS = "logs"
42
+ ERROR = "error"
43
+ DONE = "done"
44
+
45
+ event : EventType
46
+ data : str
47
+ id : str
48
+ retry : Optional [int ]
39
49
40
50
def __str__ (self ) -> str :
41
51
if self .event == "output" :
@@ -45,6 +55,10 @@ def __str__(self) -> str:
45
55
46
56
47
57
class EventSource :
58
+ """
59
+ A server-sent event source.
60
+ """
61
+
48
62
response : "httpx.Response"
49
63
50
64
def __init__ (self , response : "httpx.Response" ) -> None :
@@ -57,27 +71,36 @@ def __init__(self, response: "httpx.Response") -> None:
57
71
)
58
72
59
73
class Decoder :
60
- event : Optional [str ] = None
74
+ """
75
+ A decoder for server-sent events.
76
+ """
77
+
78
+ event : Optional ["ServerSentEvent.EventType" ] = None
61
79
data : List [str ] = []
62
80
last_event_id : Optional [str ] = None
63
81
retry : Optional [int ] = None
64
82
65
83
def decode (self , line : str ) -> Optional [ServerSentEvent ]:
84
+ """
85
+ Decode a line and return a server-sent event if applicable.
86
+ """
87
+
66
88
if not line :
67
- if not any ([self .event , self .data , self .last_event_id , self .retry ]):
89
+ if (
90
+ not any ([self .event , self .data , self .last_event_id , self .retry ])
91
+ or self .event is None
92
+ or self .last_event_id is None
93
+ ):
68
94
return None
69
95
70
- try :
71
- sse = ServerSentEvent (
72
- event = self .event ,
73
- data = "\n " .join (self .data ),
74
- id = self .last_event_id ,
75
- retry = self .retry ,
76
- )
77
- except pydantic .ValidationError :
78
- return None
96
+ sse = ServerSentEvent (
97
+ event = self .event ,
98
+ data = "\n " .join (self .data ),
99
+ id = self .last_event_id ,
100
+ retry = self .retry ,
101
+ )
79
102
80
- self .event = ""
103
+ self .event = None
81
104
self .data = []
82
105
self .retry = None
83
106
@@ -91,7 +114,8 @@ def decode(self, line: str) -> Optional[ServerSentEvent]:
91
114
92
115
match fieldname :
93
116
case "event" :
94
- self .event = value
117
+ if event := ServerSentEvent .EventType (value ):
118
+ self .event = event
95
119
case "data" :
96
120
self .data .append (value )
97
121
case "id" :
@@ -155,7 +179,7 @@ def stream(
155
179
)
156
180
157
181
url = prediction .urls and prediction .urls .get ("stream" , None )
158
- if not url :
182
+ if not url or not isinstance ( url , str ) :
159
183
raise ReplicateError ("Model does not support streaming" )
160
184
161
185
headers = {}
@@ -185,7 +209,7 @@ async def async_stream(
185
209
)
186
210
187
211
url = prediction .urls and prediction .urls .get ("stream" , None )
188
- if not url :
212
+ if not url or not isinstance ( url , str ) :
189
213
raise ReplicateError ("Model does not support streaming" )
190
214
191
215
headers = {}
0 commit comments