Skip to content

Commit 8c98f10

Browse files
committed
Make openapi_schema consistently a method
This commit also fixes the implementation to return the OpenAPI schema for the requested version rather than always the latest version.
1 parent 21ac011 commit 8c98f10

File tree

2 files changed

+46
-56
lines changed

2 files changed

+46
-56
lines changed

replicate/use.py

Lines changed: 45 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import hashlib
66
import os
77
import tempfile
8-
from dataclasses import dataclass
98
from functools import cached_property
109
from pathlib import Path
1110
from typing import (
@@ -25,7 +24,6 @@
2524
cast,
2625
overload,
2726
)
28-
from urllib.parse import urlparse
2927

3028
import httpx
3129

@@ -62,36 +60,6 @@ def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool:
6260
return True
6361

6462

65-
def _has_iterator_output_type(openapi_schema: dict) -> bool:
66-
"""
67-
Returns true if the model output type is an iterator (non-concatenate).
68-
"""
69-
output = openapi_schema.get("components", {}).get("schemas", {}).get("Output", {})
70-
return (
71-
output.get("type") == "array" and output.get("x-cog-array-type") == "iterator"
72-
)
73-
74-
75-
def _download_file(url: str) -> Path:
76-
"""
77-
Download a file from URL to a temporary location and return the Path.
78-
"""
79-
parsed_url = urlparse(url)
80-
filename = os.path.basename(parsed_url.path)
81-
82-
if not filename or "." not in filename:
83-
filename = "download"
84-
85-
_, ext = os.path.splitext(filename)
86-
with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
87-
with httpx.stream("GET", url) as response:
88-
response.raise_for_status()
89-
for chunk in response.iter_bytes():
90-
temp_file.write(chunk)
91-
92-
return Path(temp_file.name)
93-
94-
9563
def _process_iterator_item(item: Any, openapi_schema: dict) -> Any:
9664
"""
9765
Process a single item from an iterator output based on schema.
@@ -357,7 +325,6 @@ class FunctionRef(Protocol, Generic[Input, Output]):
357325
__call__: Callable[Input, Output]
358326

359327

360-
@dataclass
361328
class Run[O]:
362329
"""
363330
Represents a running prediction with access to the underlying schema.
@@ -416,13 +383,13 @@ def logs(self) -> Optional[str]:
416383
return self._prediction.logs
417384

418385

419-
@dataclass
420386
class Function(Generic[Input, Output]):
421387
"""
422388
A wrapper for a Replicate model that can be called as a function.
423389
"""
424390

425391
_ref: str
392+
_streaming: bool
426393

427394
def __init__(self, ref: str, *, streaming: bool) -> None:
428395
self._ref = ref
@@ -460,7 +427,9 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
460427
)
461428

462429
return Run(
463-
prediction=prediction, schema=self.openapi_schema, streaming=self._streaming
430+
prediction=prediction,
431+
schema=self.openapi_schema(),
432+
streaming=self._streaming,
464433
)
465434

466435
@property
@@ -470,18 +439,26 @@ def default_example(self) -> Optional[dict[str, Any]]:
470439
"""
471440
raise NotImplementedError("This property has not yet been implemented")
472441

473-
@cached_property
474442
def openapi_schema(self) -> dict[str, Any]:
475443
"""
476444
Get the OpenAPI schema for this model version.
477445
"""
478-
latest_version = self._model.latest_version
479-
if latest_version is None:
446+
return self._openapi_schema
447+
448+
@cached_property
449+
def _openapi_schema(self) -> dict[str, Any]:
450+
_, _, model_version = self._parsed_ref
451+
model = self._model
452+
453+
version = (
454+
model.versions.get(model_version) if model_version else model.latest_version
455+
)
456+
if version is None:
480457
msg = f"Model {self._model.owner}/{self._model.name} has no latest version"
481458
raise ValueError(msg)
482459

483-
schema = latest_version.openapi_schema
484-
if cog_version := latest_version.cog_version:
460+
schema = version.openapi_schema
461+
if cog_version := version.cog_version:
485462
schema = make_schema_backwards_compatible(schema, cog_version)
486463
return _dereference_schema(schema)
487464

@@ -524,7 +501,6 @@ def _version(self) -> Version | None:
524501
return version
525502

526503

527-
@dataclass
528504
class AsyncRun[O]:
529505
"""
530506
Represents a running prediction with access to its version (async version).
@@ -583,21 +559,25 @@ async def logs(self) -> Optional[str]:
583559
return self._prediction.logs
584560

585561

586-
@dataclass
587562
class AsyncFunction(Generic[Input, Output]):
588563
"""
589564
An async wrapper for a Replicate model that can be called as a function.
590565
"""
591566

592-
function_ref: str
593-
streaming: bool
567+
_ref: str
568+
_streaming: bool
569+
_openapi_schema: dict[str, Any] | None = None
570+
571+
def __init__(self, ref: str, *, streaming: bool) -> None:
572+
self._ref = ref
573+
self._streaming = streaming
594574

595575
def _client(self) -> Client:
596576
return Client()
597577

598578
@cached_property
599579
def _parsed_ref(self) -> Tuple[str, str, Optional[str]]:
600-
return ModelVersionIdentifier.parse(self.function_ref)
580+
return ModelVersionIdentifier.parse(self._ref)
601581

602582
async def _model(self) -> Model:
603583
client = self._client()
@@ -662,7 +642,7 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu
662642
return AsyncRun(
663643
prediction=prediction,
664644
schema=await self.openapi_schema(),
665-
streaming=self.streaming,
645+
streaming=self._streaming,
666646
)
667647

668648
@property
@@ -676,16 +656,26 @@ async def openapi_schema(self) -> dict[str, Any]:
676656
"""
677657
Get the OpenAPI schema for this model version asynchronously.
678658
"""
679-
model = await self._model()
680-
latest_version = model.latest_version
681-
if latest_version is None:
682-
msg = f"Model {model.owner}/{model.name} has no latest version"
683-
raise ValueError(msg)
659+
if not self._openapi_schema:
660+
_, _, model_version = self._parsed_ref
684661

685-
schema = latest_version.openapi_schema
686-
if cog_version := latest_version.cog_version:
687-
schema = make_schema_backwards_compatible(schema, cog_version)
688-
return _dereference_schema(schema)
662+
model = await self._model()
663+
if model_version:
664+
version = await model.versions.async_get(model_version)
665+
else:
666+
version = model.latest_version
667+
668+
if version is None:
669+
msg = f"Model {model.owner}/{model.name} has no version"
670+
raise ValueError(msg)
671+
672+
schema = version.openapi_schema
673+
if cog_version := version.cog_version:
674+
schema = make_schema_backwards_compatible(schema, cog_version)
675+
676+
self._openapi_schema = _dereference_schema(schema)
677+
678+
return self._openapi_schema
689679

690680

691681
@overload

tests/test_use.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ async def test_use_function_openapi_schema_dereferenced(client_mode):
372372
if client_mode == ClientMode.ASYNC:
373373
schema = await hotdog_detector.openapi_schema()
374374
else:
375-
schema = hotdog_detector.openapi_schema
375+
schema = hotdog_detector.openapi_schema()
376376

377377
assert schema["components"]["schemas"]["Output"] == {
378378
"type": "object",

0 commit comments

Comments
 (0)