Skip to content

Add stream and async_stream methods to client #204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,34 @@ Some models, like [methexis-inc/img2prompt](https://replicate.com/methexis-inc/i
> print(results)
> ```

## Run a model and stream its output

Replicate’s API supports server-sent events (SSE) for streaming output from language models.
Use the `stream` method to consume tokens as they're produced by the model.

```python
import replicate

# https://replicate.com/meta/llama-2-70b-chat
model_version = "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"

tokens = []
for event in replicate.stream(
model_version,
input={
"prompt": "Please write a haiku about llamas.",
},
):
print(event)
tokens.append(str(event))

print("".join(tokens))
```

For more information, see
["Streaming output"](https://replicate.com/docs/streaming) in Replicate's docs.


## Run a model in the background

You can start a model and run it in the background:
Expand Down
3 changes: 3 additions & 0 deletions replicate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
run = default_client.run
async_run = default_client.async_run

stream = default_client.stream
async_stream = default_client.async_stream

paginate = _paginate
async_paginate = _async_paginate

Expand Down
30 changes: 30 additions & 0 deletions replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import time
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Iterable,
Iterator,
Expand All @@ -24,8 +26,12 @@
from replicate.model import Models
from replicate.prediction import Predictions
from replicate.run import async_run, run
from replicate.stream import async_stream, stream
from replicate.training import Trainings

if TYPE_CHECKING:
from replicate.stream import ServerSentEvent


class Client:
"""A Replicate API client library"""
Expand Down Expand Up @@ -152,6 +158,30 @@ async def async_run(

return await async_run(self, ref, input, **params)

def stream(
self,
ref: str,
input: Optional[Dict[str, Any]] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Iterator["ServerSentEvent"]:
"""
Stream a model's output.
"""

return stream(self, ref, input, **params)

async def async_stream(
self,
ref: str,
input: Optional[Dict[str, Any]] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> AsyncIterator["ServerSentEvent"]:
"""
Stream a model's output asynchronously.
"""

return async_stream(self, ref, input, **params)


# Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155
class RetryTransport(httpx.AsyncBaseTransport, httpx.BaseTransport):
Expand Down
26 changes: 26 additions & 0 deletions replicate/identifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import re
from typing import NamedTuple


class ModelVersionIdentifier(NamedTuple):
"""
A reference to a model version in the format owner/name:version.
"""

owner: str
name: str
version: str

@classmethod
def parse(cls, ref: str) -> "ModelVersionIdentifier":
"""
Split a reference in the format owner/name:version into its components.
"""

match = re.match(r"^(?P<owner>[^/]+)/(?P<name>[^:]+):(?P<version>.+)$", ref)
if not match:
raise ValueError(
f"Invalid reference to model version: {ref}. Expected format: owner/name:version"
)

return cls(match.group("owner"), match.group("name"), match.group("version"))
26 changes: 4 additions & 22 deletions replicate/run.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import asyncio
import re
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union

from typing_extensions import Unpack

from replicate.exceptions import ModelError, ReplicateError
from replicate.exceptions import ModelError
from replicate.identifier import ModelVersionIdentifier
from replicate.schema import make_schema_backwards_compatible
from replicate.version import Versions

Expand All @@ -23,16 +23,7 @@ def run(
Run a model and wait for its output.
"""

# Split ref into owner, name, version in format owner/name:version
match = re.match(r"^(?P<owner>[^/]+)/(?P<name>[^:]+):(?P<version>.+)$", ref)
if not match:
raise ReplicateError(
f"Invalid reference to model version: {ref}. Expected format: owner/name:version"
)

owner = match.group("owner")
name = match.group("name")
version_id = match.group("version")
owner, name, version_id = ModelVersionIdentifier.parse(ref)

prediction = client.predictions.create(
version=version_id, input=input or {}, **params
Expand Down Expand Up @@ -70,16 +61,7 @@ async def async_run(
Run a model and wait for its output asynchronously.
"""

# Split ref into owner, name, version in format owner/name:version
match = re.match(r"^(?P<owner>[^/]+)/(?P<name>[^:]+):(?P<version>.+)$", ref)
if not match:
raise ReplicateError(
f"Invalid reference to model version: {ref}. Expected format: owner/name:version"
)

owner = match.group("owner")
name = match.group("name")
version_id = match.group("version")
owner, name, version_id = ModelVersionIdentifier.parse(ref)

prediction = await client.predictions.async_create(
version=version_id, input=input or {}, **params
Expand Down
216 changes: 216 additions & 0 deletions replicate/stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
)

from typing_extensions import Unpack

from replicate.exceptions import ReplicateError
from replicate.identifier import ModelVersionIdentifier

try:
from pydantic import v1 as pydantic # type: ignore
except ImportError:
import pydantic # type: ignore


if TYPE_CHECKING:
import httpx

from replicate.client import Client
from replicate.prediction import Predictions


class ServerSentEvent(pydantic.BaseModel):
"""
A server-sent event.
"""

class EventType(Enum):
"""
A server-sent event type.
"""

OUTPUT = "output"
LOGS = "logs"
ERROR = "error"
DONE = "done"

event: EventType
data: str
id: str
retry: Optional[int]

def __str__(self) -> str:
if self.event == "output":
return self.data

return ""


class EventSource:
"""
A server-sent event source.
"""

response: "httpx.Response"

def __init__(self, response: "httpx.Response") -> None:
self.response = response
content_type, _, _ = response.headers["content-type"].partition(";")
if content_type != "text/event-stream":
raise ValueError(
"Expected response Content-Type to be 'text/event-stream', "
f"got {content_type!r}"
)

class Decoder:
"""
A decoder for server-sent events.
"""

event: Optional["ServerSentEvent.EventType"] = None
data: List[str] = []
last_event_id: Optional[str] = None
retry: Optional[int] = None

def decode(self, line: str) -> Optional[ServerSentEvent]:
"""
Decode a line and return a server-sent event if applicable.
"""

if not line:
if (
not any([self.event, self.data, self.last_event_id, self.retry])
or self.event is None
or self.last_event_id is None
):
return None

sse = ServerSentEvent(
event=self.event,
data="\n".join(self.data),
id=self.last_event_id,
retry=self.retry,
)

self.event = None
self.data = []
self.retry = None

return sse

if line.startswith(":"):
return None

fieldname, _, value = line.partition(":")
value = value.lstrip()

if fieldname == "event":
if event := ServerSentEvent.EventType(value):
self.event = event
elif fieldname == "data":
self.data.append(value)
elif fieldname == "id":
if "\0" not in value:
self.last_event_id = value
elif fieldname == "retry":
try:
self.retry = int(value)
except (TypeError, ValueError):
pass

return None

def __iter__(self) -> Iterator[ServerSentEvent]:
decoder = EventSource.Decoder()
for line in self.response.iter_lines():
line = line.rstrip("\n")
sse = decoder.decode(line)
if sse is not None:
if sse.event == "done":
return
elif sse.event == "error":
raise RuntimeError(sse.data)
else:
yield sse

async def __aiter__(self) -> AsyncIterator[ServerSentEvent]:
decoder = EventSource.Decoder()
async for line in self.response.aiter_lines():
line = line.rstrip("\n")
sse = decoder.decode(line)
if sse is not None:
if sse.event == "done":
return
elif sse.event == "error":
raise RuntimeError(sse.data)
else:
yield sse


def stream(
client: "Client",
ref: str,
input: Optional[Dict[str, Any]] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Iterator[ServerSentEvent]:
"""
Run a model and stream its output.
"""

params = params or {}
params["stream"] = True

_, _, version_id = ModelVersionIdentifier.parse(ref)
prediction = client.predictions.create(
version=version_id, input=input or {}, **params
)

url = prediction.urls and prediction.urls.get("stream", None)
if not url or not isinstance(url, str):
raise ReplicateError("Model does not support streaming")

headers = {}
headers["Accept"] = "text/event-stream"
headers["Cache-Control"] = "no-store"

with client._client.stream("GET", url, headers=headers) as response:
yield from EventSource(response)


async def async_stream(
client: "Client",
ref: str,
input: Optional[Dict[str, Any]] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> AsyncIterator[ServerSentEvent]:
"""
Run a model and stream its output asynchronously.
"""

params = params or {}
params["stream"] = True

_, _, version_id = ModelVersionIdentifier.parse(ref)
prediction = await client.predictions.async_create(
version=version_id, input=input or {}, **params
)

url = prediction.urls and prediction.urls.get("stream", None)
if not url or not isinstance(url, str):
raise ReplicateError("Model does not support streaming")

headers = {}
headers["Accept"] = "text/event-stream"
headers["Cache-Control"] = "no-store"

async with client._async_client.stream("GET", url, headers=headers) as response:
async for event in EventSource(response):
yield event
Loading