@@ -1,7 +1,6 @@
 # TODO
 # - [ ] Support text streaming
 # - [ ] Support file streaming
-# - [ ] Support asyncio variant
 import hashlib
 import inspect
 import os
@@ -12,14 +11,17 @@
 from pathlib import Path
 from typing import (
     Any,
+    AsyncIterator,
     Callable,
     Generic,
     Iterator,
+    Literal,
     Optional,
     ParamSpec,
     Protocol,
     Tuple,
     TypeVar,
+    Union,
     cast,
     overload,
 )
@@ -211,27 +213,61 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any:
 class OutputIterator:
     """
     An iterator wrapper that handles both regular iteration and string conversion.
+    Supports both sync and async iteration patterns.
     """

-    def __init__(self, iterator_factory, schema: dict, *, is_concatenate: bool) -> None:
+    def __init__(
+        self,
+        iterator_factory: Callable[[], Iterator[Any]],
+        async_iterator_factory: Callable[[], AsyncIterator[Any]],
+        schema: dict,
+        *,
+        is_concatenate: bool
+    ) -> None:
         self.iterator_factory = iterator_factory
+        self.async_iterator_factory = async_iterator_factory
         self.schema = schema
         self.is_concatenate = is_concatenate

     def __iter__(self) -> Iterator[Any]:
-        """Iterate over output items."""
+        """Iterate over output items synchronously."""
         for chunk in self.iterator_factory():
             if self.is_concatenate:
                 yield str(chunk)
             else:
                 yield _process_iterator_item(chunk, self.schema)

+    async def __aiter__(self) -> AsyncIterator[Any]:
+        """Iterate over output items asynchronously."""
+        async for chunk in self.async_iterator_factory():
+            if self.is_concatenate:
+                yield str(chunk)
+            else:
+                yield _process_iterator_item(chunk, self.schema)
+
     def __str__(self) -> str:
         """Convert to string by joining segments with empty string."""
         if self.is_concatenate:
             return "".join([str(segment) for segment in self.iterator_factory()])
         else:
-            return str(self.iterator_factory())
+            return str(list(self.iterator_factory()))
+
+    def __await__(self):
+        """Make OutputIterator awaitable, returning appropriate result based on concatenate mode."""
+        async def _collect_result():
+            if self.is_concatenate:
+                # For concatenate iterators, return the joined string
+                segments = []
+                async for segment in self:
+                    segments.append(segment)
+                return "".join(segments)
+            else:
+                # For regular iterators, return the list of items
+                items = []
+                async for item in self:
+                    items.append(item)
+                return items
+        return _collect_result().__await__()


 class URLPath(os.PathLike):
@@ -319,6 +355,7 @@ def output(self) -> O:
                 O,
                 OutputIterator(
                     lambda: self.prediction.output_iterator(),
+                    lambda: self.prediction.async_output_iterator(),
                     self.schema,
                     is_concatenate=is_concatenate,
                 ),
@@ -435,21 +472,186 @@ def openapi_schema(self) -> dict[str, Any]:
         return schema


+@dataclass
+class AsyncRun[O]:
+    """
+    Represents a running prediction with access to its version (async version).
+    """
+
+    prediction: Prediction
+    schema: dict
+
+    async def output(self) -> O:
+        """
+        Wait for the prediction to complete and return its output asynchronously.
+        """
+        await self.prediction.async_wait()
+
+        if self.prediction.status == "failed":
+            raise ModelError(self.prediction)
+
+        # Return an OutputIterator for iterator output types (including concatenate iterators)
+        if _has_iterator_output_type(self.schema):
+            is_concatenate = _has_concatenate_iterator_output_type(self.schema)
+            return cast(
+                O,
+                OutputIterator(
+                    lambda: self.prediction.output_iterator(),
+                    lambda: self.prediction.async_output_iterator(),
+                    self.schema,
+                    is_concatenate=is_concatenate,
+                ),
+            )
+
+        # Process output for file downloads based on schema
+        return _process_output_with_schema(self.prediction.output, self.schema)
+
+    async def logs(self) -> Optional[str]:
+        """
+        Fetch and return the logs from the prediction asynchronously.
+        """
+        await self.prediction.async_reload()
+
+        return self.prediction.logs
+
+
+@dataclass
+class AsyncFunction(Generic[Input, Output]):
+    """
+    An async wrapper for a Replicate model that can be called as a function.
+    """
+
+    function_ref: str
+
+    def _client(self) -> Client:
+        return Client()
+
+    @cached_property
+    def _parsed_ref(self) -> Tuple[str, str, Optional[str]]:
+        return ModelVersionIdentifier.parse(self.function_ref)
+
+    async def _model(self) -> Model:
+        client = self._client()
+        model_owner, model_name, _ = self._parsed_ref
+        return await client.models.async_get(f"{model_owner}/{model_name}")
+
+    async def _version(self) -> Version | None:
+        _, _, model_version = self._parsed_ref
+        model = await self._model()
+        try:
+            versions = await model.versions.async_list()
+            if len(versions) == 0:
+                # if we got an empty list when getting model versions, this
+                # model is possibly a procedure instead and should be called via
+                # the versionless API
+                return None
+        except ReplicateError as e:
+            if e.status == 404:
+                # if we get a 404 when getting model versions, this is an official
+                # model and doesn't have addressable versions (despite what
+                # latest_version might tell us)
+                return None
+            raise
+
+        if model_version:
+            version = await model.versions.async_get(model_version)
+        else:
+            version = model.latest_version
+
+        return version
+
+    async def __call__(self, *args: Input.args, **inputs: Input.kwargs) -> Output:
+        run = await self.create(*args, **inputs)
+        return await run.output()
+
+    async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Output]:
+        """
+        Start a prediction with the specified inputs asynchronously.
+        """
+        # Process inputs to convert concatenate OutputIterators to strings and URLPath to URLs
+        processed_inputs = {}
+        for key, value in inputs.items():
+            if isinstance(value, OutputIterator) and value.is_concatenate:
+                processed_inputs[key] = str(value)
+            elif url := get_path_url(value):
+                processed_inputs[key] = url
+            else:
+                processed_inputs[key] = value
+
+        version = await self._version()
+
+        if version:
+            prediction = await self._client().predictions.async_create(
+                version=version, input=processed_inputs
+            )
+        else:
+            model = await self._model()
+            prediction = await self._client().models.predictions.async_create(
+                model=model, input=processed_inputs
+            )
+
+        return AsyncRun(prediction, await self.openapi_schema())
+
+    @property
+    def default_example(self) -> Optional[dict[str, Any]]:
+        """
+        Get the default example for this model.
+        """
+        raise NotImplementedError("This property has not yet been implemented")
+
+    async def openapi_schema(self) -> dict[str, Any]:
+        """
+        Get the OpenAPI schema for this model version asynchronously.
+        """
+        model = await self._model()
+        latest_version = model.latest_version
+        if latest_version is None:
+            msg = f"Model {model.owner}/{model.name} has no latest version"
+            raise ValueError(msg)
+
+        schema = latest_version.openapi_schema
+        if cog_version := latest_version.cog_version:
+            schema = make_schema_backwards_compatible(schema, cog_version)
+        return schema
+
+
 @overload
 def use(ref: FunctionRef[Input, Output]) -> Function[Input, Output]: ...


 @overload
 def use(
-    ref: str, *, hint: Callable[Input, Output] | None = None
+    ref: FunctionRef[Input, Output], *, use_async: Literal[False]
+) -> Function[Input, Output]: ...
+
+
+@overload
+def use(
+    ref: FunctionRef[Input, Output], *, use_async: Literal[True]
+) -> AsyncFunction[Input, Output]: ...
+
+
+@overload
+def use(
+    ref: str, *, hint: Callable[Input, Output] | None = None, use_async: Literal[True]
+) -> AsyncFunction[Input, Output]: ...
+
+
+@overload
+def use(
+    ref: str,
+    *,
+    hint: Callable[Input, Output] | None = None,
+    use_async: Literal[False] = False,
 ) -> Function[Input, Output]: ...


 def use(
     ref: str | FunctionRef[Input, Output],
     *,
     hint: Callable[Input, Output] | None = None,
-) -> Function[Input, Output]:
+    use_async: bool = False,
+) -> Function[Input, Output] | AsyncFunction[Input, Output]:
     """
     Use a Replicate model as a function.

@@ -469,4 +671,29 @@ def use(
     except AttributeError:
         pass

+    if use_async:
+        return AsyncFunction(function_ref=str(ref))
+
     return Function(str(ref))
+
+
+# class Model:
+#     name = "foo"
+
+#     def __call__(self) -> str: ...
+
+
+# def model() -> int: ...
+
+
+# flux = use("")
+# flux_sync = use("", use_async=False)
+# flux_async = use("", use_async=True)
+
+# flux = use("", hint=model)
+# flux_sync = use("", hint=model, use_async=False)
+# flux_async = use("", hint=model, use_async=True)
+
+# flux = use(Model())
+# flux_sync = use(Model(), use_async=False)
+# flux_async = use(Model(), use_async=True)