Commit c982d53

Add asyncio support to use() function
This introduces a new `use_async` keyword argument to `use()`. When set, `use()` returns an `AsyncFunction` instead of a `Function` instance, providing an asyncio-compatible interface. The `OutputIterator` has also been updated to implement the `AsyncIterator` interface and to be awaitable itself.
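
A minimal usage sketch of the new interface (the model ref `"owner/some-model"` and the `prompt` input are placeholders, and importing `use` directly from `replicate.use` is an assumption here):

import asyncio

from replicate.use import use


async def main() -> None:
    # use_async=True returns an AsyncFunction instead of a Function
    model = use("owner/some-model", use_async=True)

    # Calling it starts a prediction, waits for it, and returns the output
    output = await model(prompt="a lighthouse at dusk")
    print(output)

    # Or start the prediction explicitly and inspect it while it runs
    run = await model.create(prompt="a lighthouse at dusk")
    print(await run.logs())
    print(await run.output())


asyncio.run(main())
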
1 parent 2df34ed commit c982d53

2 files changed: +559 -63 lines changed

replicate/use.py

Lines changed: 233 additions & 6 deletions
@@ -1,7 +1,6 @@
 # TODO
 # - [ ] Support text streaming
 # - [ ] Support file streaming
-# - [ ] Support asyncio variant
 import hashlib
 import inspect
 import os
@@ -12,14 +11,17 @@
 from pathlib import Path
 from typing import (
     Any,
+    AsyncIterator,
     Callable,
     Generic,
     Iterator,
+    Literal,
     Optional,
     ParamSpec,
     Protocol,
     Tuple,
     TypeVar,
+    Union,
     cast,
     overload,
 )
@@ -211,27 +213,61 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any:
 class OutputIterator:
     """
     An iterator wrapper that handles both regular iteration and string conversion.
+    Supports both sync and async iteration patterns.
     """
 
-    def __init__(self, iterator_factory, schema: dict, *, is_concatenate: bool) -> None:
+    def __init__(
+        self,
+        iterator_factory: Callable[[], Iterator[Any]],
+        async_iterator_factory: Callable[[], AsyncIterator[Any]],
+        schema: dict,
+        *,
+        is_concatenate: bool
+    ) -> None:
         self.iterator_factory = iterator_factory
+        self.async_iterator_factory = async_iterator_factory
         self.schema = schema
         self.is_concatenate = is_concatenate
 
     def __iter__(self) -> Iterator[Any]:
-        """Iterate over output items."""
+        """Iterate over output items synchronously."""
         for chunk in self.iterator_factory():
             if self.is_concatenate:
                 yield str(chunk)
             else:
                 yield _process_iterator_item(chunk, self.schema)
 
+    async def __aiter__(self) -> AsyncIterator[Any]:
+        """Iterate over output items asynchronously."""
+        async for chunk in self.async_iterator_factory():
+            if self.is_concatenate:
+                yield str(chunk)
+            else:
+                yield _process_iterator_item(chunk, self.schema)
+
     def __str__(self) -> str:
         """Convert to string by joining segments with empty string."""
         if self.is_concatenate:
             return "".join([str(segment) for segment in self.iterator_factory()])
         else:
-            return str(self.iterator_factory())
+            return str(list(self.iterator_factory()))
+
+    def __await__(self):
+        """Make OutputIterator awaitable, returning appropriate result based on concatenate mode."""
+        async def _collect_result():
+            if self.is_concatenate:
+                # For concatenate iterators, return the joined string
+                segments = []
+                async for segment in self:
+                    segments.append(segment)
+                return "".join(segments)
+            else:
+                # For regular iterators, return the list of items
+                items = []
+                async for item in self:
+                    items.append(item)
+                return items
+        return _collect_result().__await__()
 
 
 class URLPath(os.PathLike):
@@ -319,6 +355,7 @@ def output(self) -> O:
                 O,
                 OutputIterator(
                     lambda: self.prediction.output_iterator(),
+                    lambda: self.prediction.async_output_iterator(),
                     self.schema,
                     is_concatenate=is_concatenate,
                 ),
@@ -435,21 +472,186 @@ def openapi_schema(self) -> dict[str, Any]:
         return schema
 
 
+@dataclass
+class AsyncRun[O]:
+    """
+    Represents a running prediction with access to its version (async version).
+    """
+
+    prediction: Prediction
+    schema: dict
+
+    async def output(self) -> O:
+        """
+        Wait for the prediction to complete and return its output asynchronously.
+        """
+        await self.prediction.async_wait()
+
+        if self.prediction.status == "failed":
+            raise ModelError(self.prediction)
+
+        # Return an OutputIterator for iterator output types (including concatenate iterators)
+        if _has_iterator_output_type(self.schema):
+            is_concatenate = _has_concatenate_iterator_output_type(self.schema)
+            return cast(
+                O,
+                OutputIterator(
+                    lambda: self.prediction.output_iterator(),
+                    lambda: self.prediction.async_output_iterator(),
+                    self.schema,
+                    is_concatenate=is_concatenate,
+                ),
+            )
+
+        # Process output for file downloads based on schema
+        return _process_output_with_schema(self.prediction.output, self.schema)
+
+    async def logs(self) -> Optional[str]:
+        """
+        Fetch and return the logs from the prediction asynchronously.
+        """
+        await self.prediction.async_reload()
+
+        return self.prediction.logs
+
+
+@dataclass
+class AsyncFunction(Generic[Input, Output]):
+    """
+    An async wrapper for a Replicate model that can be called as a function.
+    """
+
+    function_ref: str
+
+    def _client(self) -> Client:
+        return Client()
+
+    @cached_property
+    def _parsed_ref(self) -> Tuple[str, str, Optional[str]]:
+        return ModelVersionIdentifier.parse(self.function_ref)
+
+    async def _model(self) -> Model:
+        client = self._client()
+        model_owner, model_name, _ = self._parsed_ref
+        return await client.models.async_get(f"{model_owner}/{model_name}")
+
+    async def _version(self) -> Version | None:
+        _, _, model_version = self._parsed_ref
+        model = await self._model()
+        try:
+            versions = await model.versions.async_list()
+            if len(versions) == 0:
+                # if we got an empty list when getting model versions, this
+                # model is possibly a procedure instead and should be called via
+                # the versionless API
+                return None
+        except ReplicateError as e:
+            if e.status == 404:
+                # if we get a 404 when getting model versions, this is an official
+                # model and doesn't have addressable versions (despite what
+                # latest_version might tell us)
+                return None
+            raise
+
+        if model_version:
+            version = await model.versions.async_get(model_version)
+        else:
+            version = model.latest_version
+
+        return version
+
+    async def __call__(self, *args: Input.args, **inputs: Input.kwargs) -> Output:
+        run = await self.create(*args, **inputs)
+        return await run.output()
+
+    async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Output]:
+        """
+        Start a prediction with the specified inputs asynchronously.
+        """
+        # Process inputs to convert concatenate OutputIterators to strings and URLPath to URLs
+        processed_inputs = {}
+        for key, value in inputs.items():
+            if isinstance(value, OutputIterator) and value.is_concatenate:
+                processed_inputs[key] = str(value)
+            elif url := get_path_url(value):
+                processed_inputs[key] = url
+            else:
+                processed_inputs[key] = value
+
+        version = await self._version()
+
+        if version:
+            prediction = await self._client().predictions.async_create(
+                version=version, input=processed_inputs
+            )
+        else:
+            model = await self._model()
+            prediction = await self._client().models.predictions.async_create(
+                model=model, input=processed_inputs
+            )
+
+        return AsyncRun(prediction, await self.openapi_schema())
+
+    @property
+    def default_example(self) -> Optional[dict[str, Any]]:
+        """
+        Get the default example for this model.
+        """
+        raise NotImplementedError("This property has not yet been implemented")
+
+    async def openapi_schema(self) -> dict[str, Any]:
+        """
+        Get the OpenAPI schema for this model version asynchronously.
+        """
+        model = await self._model()
+        latest_version = model.latest_version
+        if latest_version is None:
+            msg = f"Model {model.owner}/{model.name} has no latest version"
+            raise ValueError(msg)
+
+        schema = latest_version.openapi_schema
+        if cog_version := latest_version.cog_version:
+            schema = make_schema_backwards_compatible(schema, cog_version)
+        return schema
+
+
 @overload
 def use(ref: FunctionRef[Input, Output]) -> Function[Input, Output]: ...
 
 
 @overload
 def use(
-    ref: str, *, hint: Callable[Input, Output] | None = None
+    ref: FunctionRef[Input, Output], *, use_async: Literal[False]
+) -> Function[Input, Output]: ...
+
+
+@overload
+def use(
+    ref: FunctionRef[Input, Output], *, use_async: Literal[True]
+) -> AsyncFunction[Input, Output]: ...
+
+
+@overload
+def use(
+    ref: str, *, hint: Callable[Input, Output] | None = None, use_async: Literal[True]
+) -> AsyncFunction[Input, Output]: ...
+
+
+@overload
+def use(
+    ref: str,
+    *,
+    hint: Callable[Input, Output] | None = None,
+    use_async: Literal[False] = False,
 ) -> Function[Input, Output]: ...
 
 
 def use(
     ref: str | FunctionRef[Input, Output],
     *,
     hint: Callable[Input, Output] | None = None,
-) -> Function[Input, Output]:
+    use_async: bool = False,
+) -> Function[Input, Output] | AsyncFunction[Input, Output]:
     """
     Use a Replicate model as a function.
 
@@ -469,4 +671,29 @@ def use(
     except AttributeError:
         pass
 
+    if use_async:
+        return AsyncFunction(function_ref=str(ref))
+
     return Function(str(ref))
+
+
+# class Model:
+#     name = "foo"
+
+#     def __call__(self) -> str: ...
+
+
+# def model() -> int: ...
+
+
+# flux = use("")
+# flux_sync = use("", use_async=False)
+# flux_async = use("", use_async=True)
+
+# flux = use("", hint=model)
+# flux_sync = use("", hint=model, use_async=False)
+# flux_async = use("", hint=model, use_async=True)
+
+# flux = use(Model())
+# flux_sync = use(Model(), use_async=False)
+# flux_async = use(Model(), use_async=True)
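
For models whose output is an iterator, the `OutputIterator` returned by `AsyncRun.output()` can now be consumed incrementally with `async for` or collected in one step by awaiting it (a joined string for concatenate output, otherwise a list of items). A sketch, assuming a streaming text model behind the placeholder ref `"owner/streaming-model"`:

import asyncio

from replicate.use import use


async def main() -> None:
    streamer = use("owner/streaming-model", use_async=True)
    run = await streamer.create(prompt="tell me a story")

    output = await run.output()

    # Stream chunks as they arrive (OutputIterator.__aiter__) ...
    async for chunk in output:
        print(chunk, end="")

    # ... or, alternatively, await the iterator itself (OutputIterator.__await__)
    # to collect the full result in one step:
    # text = await output


asyncio.run(main())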

0 commit comments
