5
5
import hashlib
6
6
import os
7
7
import tempfile
8
- from dataclasses import dataclass
9
8
from functools import cached_property
10
9
from pathlib import Path
11
10
from typing import (
25
24
cast ,
26
25
overload ,
27
26
)
28
- from urllib .parse import urlparse
29
27
30
28
import httpx
31
29
@@ -62,36 +60,6 @@ def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool:
62
60
return True
63
61
64
62
65
- def _has_iterator_output_type (openapi_schema : dict ) -> bool :
66
- """
67
- Returns true if the model output type is an iterator (non-concatenate).
68
- """
69
- output = openapi_schema .get ("components" , {}).get ("schemas" , {}).get ("Output" , {})
70
- return (
71
- output .get ("type" ) == "array" and output .get ("x-cog-array-type" ) == "iterator"
72
- )
73
-
74
-
75
- def _download_file (url : str ) -> Path :
76
- """
77
- Download a file from URL to a temporary location and return the Path.
78
- """
79
- parsed_url = urlparse (url )
80
- filename = os .path .basename (parsed_url .path )
81
-
82
- if not filename or "." not in filename :
83
- filename = "download"
84
-
85
- _ , ext = os .path .splitext (filename )
86
- with tempfile .NamedTemporaryFile (delete = False , suffix = ext ) as temp_file :
87
- with httpx .stream ("GET" , url ) as response :
88
- response .raise_for_status ()
89
- for chunk in response .iter_bytes ():
90
- temp_file .write (chunk )
91
-
92
- return Path (temp_file .name )
93
-
94
-
95
63
def _process_iterator_item (item : Any , openapi_schema : dict ) -> Any :
96
64
"""
97
65
Process a single item from an iterator output based on schema.
@@ -357,7 +325,6 @@ class FunctionRef(Protocol, Generic[Input, Output]):
357
325
__call__ : Callable [Input , Output ]
358
326
359
327
360
- @dataclass
361
328
class Run [O ]:
362
329
"""
363
330
Represents a running prediction with access to the underlying schema.
@@ -416,13 +383,13 @@ def logs(self) -> Optional[str]:
416
383
return self ._prediction .logs
417
384
418
385
419
- @dataclass
420
386
class Function (Generic [Input , Output ]):
421
387
"""
422
388
A wrapper for a Replicate model that can be called as a function.
423
389
"""
424
390
425
391
_ref : str
392
+ _streaming : bool
426
393
427
394
def __init__ (self , ref : str , * , streaming : bool ) -> None :
428
395
self ._ref = ref
@@ -460,7 +427,9 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
460
427
)
461
428
462
429
return Run (
463
- prediction = prediction , schema = self .openapi_schema , streaming = self ._streaming
430
+ prediction = prediction ,
431
+ schema = self .openapi_schema (),
432
+ streaming = self ._streaming ,
464
433
)
465
434
466
435
@property
@@ -470,18 +439,26 @@ def default_example(self) -> Optional[dict[str, Any]]:
470
439
"""
471
440
raise NotImplementedError ("This property has not yet been implemented" )
472
441
473
- @cached_property
474
442
def openapi_schema (self ) -> dict [str , Any ]:
475
443
"""
476
444
Get the OpenAPI schema for this model version.
477
445
"""
478
- latest_version = self ._model .latest_version
479
- if latest_version is None :
446
+ return self ._openapi_schema
447
+
448
+ @cached_property
449
+ def _openapi_schema (self ) -> dict [str , Any ]:
450
+ _ , _ , model_version = self ._parsed_ref
451
+ model = self ._model
452
+
453
+ version = (
454
+ model .versions .get (model_version ) if model_version else model .latest_version
455
+ )
456
+ if version is None :
480
457
msg = f"Model { self ._model .owner } /{ self ._model .name } has no latest version"
481
458
raise ValueError (msg )
482
459
483
- schema = latest_version .openapi_schema
484
- if cog_version := latest_version .cog_version :
460
+ schema = version .openapi_schema
461
+ if cog_version := version .cog_version :
485
462
schema = make_schema_backwards_compatible (schema , cog_version )
486
463
return _dereference_schema (schema )
487
464
@@ -524,7 +501,6 @@ def _version(self) -> Version | None:
524
501
return version
525
502
526
503
527
- @dataclass
528
504
class AsyncRun [O ]:
529
505
"""
530
506
Represents a running prediction with access to its version (async version).
@@ -583,21 +559,25 @@ async def logs(self) -> Optional[str]:
583
559
return self ._prediction .logs
584
560
585
561
586
- @dataclass
587
562
class AsyncFunction (Generic [Input , Output ]):
588
563
"""
589
564
An async wrapper for a Replicate model that can be called as a function.
590
565
"""
591
566
592
- function_ref : str
593
- streaming : bool
567
+ _ref : str
568
+ _streaming : bool
569
+ _openapi_schema : dict [str , Any ] | None = None
570
+
571
+ def __init__ (self , ref : str , * , streaming : bool ) -> None :
572
+ self ._ref = ref
573
+ self ._streaming = streaming
594
574
595
575
def _client (self ) -> Client :
596
576
return Client ()
597
577
598
578
@cached_property
599
579
def _parsed_ref (self ) -> Tuple [str , str , Optional [str ]]:
600
- return ModelVersionIdentifier .parse (self .function_ref )
580
+ return ModelVersionIdentifier .parse (self ._ref )
601
581
602
582
async def _model (self ) -> Model :
603
583
client = self ._client ()
@@ -662,7 +642,7 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu
662
642
return AsyncRun (
663
643
prediction = prediction ,
664
644
schema = await self .openapi_schema (),
665
- streaming = self .streaming ,
645
+ streaming = self ._streaming ,
666
646
)
667
647
668
648
@property
@@ -676,16 +656,26 @@ async def openapi_schema(self) -> dict[str, Any]:
676
656
"""
677
657
Get the OpenAPI schema for this model version asynchronously.
678
658
"""
679
- model = await self ._model ()
680
- latest_version = model .latest_version
681
- if latest_version is None :
682
- msg = f"Model { model .owner } /{ model .name } has no latest version"
683
- raise ValueError (msg )
659
+ if not self ._openapi_schema :
660
+ _ , _ , model_version = self ._parsed_ref
684
661
685
- schema = latest_version .openapi_schema
686
- if cog_version := latest_version .cog_version :
687
- schema = make_schema_backwards_compatible (schema , cog_version )
688
- return _dereference_schema (schema )
662
+ model = await self ._model ()
663
+ if model_version :
664
+ version = await model .versions .async_get (model_version )
665
+ else :
666
+ version = model .latest_version
667
+
668
+ if version is None :
669
+ msg = f"Model { model .owner } /{ model .name } has no version"
670
+ raise ValueError (msg )
671
+
672
+ schema = version .openapi_schema
673
+ if cog_version := version .cog_version :
674
+ schema = make_schema_backwards_compatible (schema , cog_version )
675
+
676
+ self ._openapi_schema = _dereference_schema (schema )
677
+
678
+ return self ._openapi_schema
689
679
690
680
691
681
@overload
0 commit comments