1
1
# TODO
2
2
# - [ ] Support text streaming
3
3
# - [ ] Support file streaming
4
+ import copy
4
5
import hashlib
5
6
import os
6
7
import tempfile
7
- from dataclasses import dataclass
8
8
from functools import cached_property
9
9
from pathlib import Path
10
10
from typing import (
24
24
cast ,
25
25
overload ,
26
26
)
27
- from urllib .parse import urlparse
28
27
29
28
import httpx
30
29
@@ -61,36 +60,6 @@ def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool:
61
60
return True
62
61
63
62
64
- def _has_iterator_output_type (openapi_schema : dict ) -> bool :
65
- """
66
- Returns true if the model output type is an iterator (non-concatenate).
67
- """
68
- output = openapi_schema .get ("components" , {}).get ("schemas" , {}).get ("Output" , {})
69
- return (
70
- output .get ("type" ) == "array" and output .get ("x-cog-array-type" ) == "iterator"
71
- )
72
-
73
-
74
- def _download_file (url : str ) -> Path :
75
- """
76
- Download a file from URL to a temporary location and return the Path.
77
- """
78
- parsed_url = urlparse (url )
79
- filename = os .path .basename (parsed_url .path )
80
-
81
- if not filename or "." not in filename :
82
- filename = "download"
83
-
84
- _ , ext = os .path .splitext (filename )
85
- with tempfile .NamedTemporaryFile (delete = False , suffix = ext ) as temp_file :
86
- with httpx .stream ("GET" , url ) as response :
87
- response .raise_for_status ()
88
- for chunk in response .iter_bytes ():
89
- temp_file .write (chunk )
90
-
91
- return Path (temp_file .name )
92
-
93
-
94
63
def _process_iterator_item (item : Any , openapi_schema : dict ) -> Any :
95
64
"""
96
65
Process a single item from an iterator output based on schema.
@@ -177,6 +146,60 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: # py
177
146
return output
178
147
179
148
149
+ def _dereference_schema (schema : dict [str , Any ]) -> dict [str , Any ]:
150
+ """
151
+ Performs basic dereferencing on an OpenAPI schema based on the current schemas generated
152
+ by Replicate. This code assumes that:
153
+
154
+ 1) References will always point to a field within #/components/schemas and will error
155
+ if the reference is more deeply nested.
156
+ 2) That the references when used can be discarded.
157
+
158
+ Should something more in-depth be required we could consider using the jsonref package.
159
+ """
160
+ dereferenced = copy .deepcopy (schema )
161
+ schemas = dereferenced .get ("components" , {}).get ("schemas" , {})
162
+ dereferenced_refs = set ()
163
+
164
+ def _resolve_ref (obj : Any ) -> Any :
165
+ if isinstance (obj , dict ):
166
+ if "$ref" in obj :
167
+ ref_path = obj ["$ref" ]
168
+ if ref_path .startswith ("#/components/schemas/" ):
169
+ parts = ref_path .replace ("#/components/schemas/" , "" ).split ("/" , 2 )
170
+
171
+ if len (parts ) > 1 :
172
+ raise NotImplementedError (
173
+ f"Unexpected nested $ref found in schema: { ref_path } "
174
+ )
175
+
176
+ (schema_name ,) = parts
177
+ if schema_name in schemas :
178
+ dereferenced_refs .add (schema_name )
179
+ return _resolve_ref (schemas [schema_name ])
180
+ else :
181
+ return obj
182
+ else :
183
+ return obj
184
+ else :
185
+ return {key : _resolve_ref (value ) for key , value in obj .items ()}
186
+ elif isinstance (obj , list ):
187
+ return [_resolve_ref (item ) for item in obj ]
188
+ else :
189
+ return obj
190
+
191
+ result = _resolve_ref (dereferenced )
192
+
193
+ # Filter out any references that have now been referenced.
194
+ result ["components" ]["schemas" ] = {
195
+ k : v
196
+ for k , v in result ["components" ]["schemas" ].items ()
197
+ if k not in dereferenced_refs
198
+ }
199
+
200
+ return result
201
+
202
+
180
203
T = TypeVar ("T" )
181
204
182
205
@@ -302,7 +325,6 @@ class FunctionRef(Protocol, Generic[Input, Output]):
302
325
__call__ : Callable [Input , Output ]
303
326
304
327
305
- @dataclass
306
328
class Run [O ]:
307
329
"""
308
330
Represents a running prediction with access to the underlying schema.
@@ -361,13 +383,13 @@ def logs(self) -> Optional[str]:
361
383
return self ._prediction .logs
362
384
363
385
364
- @dataclass
365
386
class Function (Generic [Input , Output ]):
366
387
"""
367
388
A wrapper for a Replicate model that can be called as a function.
368
389
"""
369
390
370
391
_ref : str
392
+ _streaming : bool
371
393
372
394
def __init__ (self , ref : str , * , streaming : bool ) -> None :
373
395
self ._ref = ref
@@ -405,7 +427,9 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
405
427
)
406
428
407
429
return Run (
408
- prediction = prediction , schema = self .openapi_schema , streaming = self ._streaming
430
+ prediction = prediction ,
431
+ schema = self .openapi_schema (),
432
+ streaming = self ._streaming ,
409
433
)
410
434
411
435
@property
@@ -415,20 +439,28 @@ def default_example(self) -> Optional[dict[str, Any]]:
415
439
"""
416
440
raise NotImplementedError ("This property has not yet been implemented" )
417
441
418
- @cached_property
419
442
def openapi_schema (self ) -> dict [str , Any ]:
420
443
"""
421
444
Get the OpenAPI schema for this model version.
422
445
"""
423
- latest_version = self ._model .latest_version
424
- if latest_version is None :
425
- msg = f"Model { self ._model .owner } /{ self ._model .name } has no latest version"
446
+ return self ._openapi_schema
447
+
448
+ @cached_property
449
+ def _openapi_schema (self ) -> dict [str , Any ]:
450
+ _ , _ , model_version = self ._parsed_ref
451
+ model = self ._model
452
+
453
+ version = (
454
+ model .versions .get (model_version ) if model_version else model .latest_version
455
+ )
456
+ if version is None :
457
+ msg = f"Model { self ._model .owner } /{ self ._model .name } has no version"
426
458
raise ValueError (msg )
427
459
428
- schema = latest_version .openapi_schema
429
- if cog_version := latest_version .cog_version :
460
+ schema = version .openapi_schema
461
+ if cog_version := version .cog_version :
430
462
schema = make_schema_backwards_compatible (schema , cog_version )
431
- return schema
463
+ return _dereference_schema ( schema )
432
464
433
465
def _client (self ) -> Client :
434
466
return Client ()
@@ -469,7 +501,6 @@ def _version(self) -> Version | None:
469
501
return version
470
502
471
503
472
- @dataclass
473
504
class AsyncRun [O ]:
474
505
"""
475
506
Represents a running prediction with access to its version (async version).
@@ -528,21 +559,25 @@ async def logs(self) -> Optional[str]:
528
559
return self ._prediction .logs
529
560
530
561
531
- @dataclass
532
562
class AsyncFunction (Generic [Input , Output ]):
533
563
"""
534
564
An async wrapper for a Replicate model that can be called as a function.
535
565
"""
536
566
537
- function_ref : str
538
- streaming : bool
567
+ _ref : str
568
+ _streaming : bool
569
+ _openapi_schema : dict [str , Any ] | None = None
570
+
571
+ def __init__ (self , ref : str , * , streaming : bool ) -> None :
572
+ self ._ref = ref
573
+ self ._streaming = streaming
539
574
540
575
def _client (self ) -> Client :
541
576
return Client ()
542
577
543
578
@cached_property
544
579
def _parsed_ref (self ) -> Tuple [str , str , Optional [str ]]:
545
- return ModelVersionIdentifier .parse (self .function_ref )
580
+ return ModelVersionIdentifier .parse (self ._ref )
546
581
547
582
async def _model (self ) -> Model :
548
583
client = self ._client ()
@@ -607,7 +642,7 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu
607
642
return AsyncRun (
608
643
prediction = prediction ,
609
644
schema = await self .openapi_schema (),
610
- streaming = self .streaming ,
645
+ streaming = self ._streaming ,
611
646
)
612
647
613
648
@property
@@ -621,16 +656,26 @@ async def openapi_schema(self) -> dict[str, Any]:
621
656
"""
622
657
Get the OpenAPI schema for this model version asynchronously.
623
658
"""
624
- model = await self ._model ()
625
- latest_version = model .latest_version
626
- if latest_version is None :
627
- msg = f"Model { model .owner } /{ model .name } has no latest version"
628
- raise ValueError (msg )
659
+ if not self ._openapi_schema :
660
+ _ , _ , model_version = self ._parsed_ref
629
661
630
- schema = latest_version .openapi_schema
631
- if cog_version := latest_version .cog_version :
632
- schema = make_schema_backwards_compatible (schema , cog_version )
633
- return schema
662
+ model = await self ._model ()
663
+ if model_version :
664
+ version = await model .versions .async_get (model_version )
665
+ else :
666
+ version = model .latest_version
667
+
668
+ if version is None :
669
+ msg = f"Model { model .owner } /{ model .name } has no version"
670
+ raise ValueError (msg )
671
+
672
+ schema = version .openapi_schema
673
+ if cog_version := version .cog_version :
674
+ schema = make_schema_backwards_compatible (schema , cog_version )
675
+
676
+ self ._openapi_schema = _dereference_schema (schema )
677
+
678
+ return self ._openapi_schema
634
679
635
680
636
681
@overload
0 commit comments