Skip to content

Commit 36df860

Browse files
authored
Dereference OpenAPI schema returned by Function.openapi_schema (#439)
This PR updates the `openapi_schema` method in `Function` and `AsyncFunction` to return a dereferenced OpenAPI schema. This fixes a bug in output transformation for models that output objects with named interfaces. This PR also introduces a breaking change: `Function.openapi_schema()` is now a method rather than a property, to match `AsyncFunction.openapi_schema()`; the latter needs to be a method because the lookup of the version is both lazy and async. Lastly, the PR fixes an issue with the implementation where the latest OpenAPI schema was always returned. Now we return the schema for the version specified in the ref (or the latest version if unspecified).
1 parent 0510f24 commit 36df860

File tree

4 files changed

+160
-60
lines changed

4 files changed

+160
-60
lines changed

.github/workflows/ci.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ name: CI
22

33
on:
44
push:
5-
branches: ["main"]
5+
branches: ["main", "beta"]
66

77
pull_request:
8-
branches: ["main"]
8+
branches: ["main", "beta"]
99

1010
jobs:
1111
test:

replicate/use.py

Lines changed: 102 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# TODO
22
# - [ ] Support text streaming
33
# - [ ] Support file streaming
4+
import copy
45
import hashlib
56
import os
67
import tempfile
7-
from dataclasses import dataclass
88
from functools import cached_property
99
from pathlib import Path
1010
from typing import (
@@ -24,7 +24,6 @@
2424
cast,
2525
overload,
2626
)
27-
from urllib.parse import urlparse
2827

2928
import httpx
3029

@@ -61,36 +60,6 @@ def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool:
6160
return True
6261

6362

64-
def _has_iterator_output_type(openapi_schema: dict) -> bool:
65-
"""
66-
Returns true if the model output type is an iterator (non-concatenate).
67-
"""
68-
output = openapi_schema.get("components", {}).get("schemas", {}).get("Output", {})
69-
return (
70-
output.get("type") == "array" and output.get("x-cog-array-type") == "iterator"
71-
)
72-
73-
74-
def _download_file(url: str) -> Path:
75-
"""
76-
Download a file from URL to a temporary location and return the Path.
77-
"""
78-
parsed_url = urlparse(url)
79-
filename = os.path.basename(parsed_url.path)
80-
81-
if not filename or "." not in filename:
82-
filename = "download"
83-
84-
_, ext = os.path.splitext(filename)
85-
with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
86-
with httpx.stream("GET", url) as response:
87-
response.raise_for_status()
88-
for chunk in response.iter_bytes():
89-
temp_file.write(chunk)
90-
91-
return Path(temp_file.name)
92-
93-
9463
def _process_iterator_item(item: Any, openapi_schema: dict) -> Any:
9564
"""
9665
Process a single item from an iterator output based on schema.
@@ -177,6 +146,60 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: # py
177146
return output
178147

179148

149+
def _dereference_schema(schema: dict[str, Any]) -> dict[str, Any]:
150+
"""
151+
Performs basic dereferencing on an OpenAPI schema based on the current schemas generated
152+
by Replicate. This code assumes that:
153+
154+
1) References will always point to a field within #/components/schemas and will error
155+
if the reference is more deeply nested.
156+
2) That the references when used can be discarded.
157+
158+
Should something more in-depth be required we could consider using the jsonref package.
159+
"""
160+
dereferenced = copy.deepcopy(schema)
161+
schemas = dereferenced.get("components", {}).get("schemas", {})
162+
dereferenced_refs = set()
163+
164+
def _resolve_ref(obj: Any) -> Any:
165+
if isinstance(obj, dict):
166+
if "$ref" in obj:
167+
ref_path = obj["$ref"]
168+
if ref_path.startswith("#/components/schemas/"):
169+
parts = ref_path.replace("#/components/schemas/", "").split("/", 2)
170+
171+
if len(parts) > 1:
172+
raise NotImplementedError(
173+
f"Unexpected nested $ref found in schema: {ref_path}"
174+
)
175+
176+
(schema_name,) = parts
177+
if schema_name in schemas:
178+
dereferenced_refs.add(schema_name)
179+
return _resolve_ref(schemas[schema_name])
180+
else:
181+
return obj
182+
else:
183+
return obj
184+
else:
185+
return {key: _resolve_ref(value) for key, value in obj.items()}
186+
elif isinstance(obj, list):
187+
return [_resolve_ref(item) for item in obj]
188+
else:
189+
return obj
190+
191+
result = _resolve_ref(dereferenced)
192+
193+
# Filter out any references that have now been referenced.
194+
result["components"]["schemas"] = {
195+
k: v
196+
for k, v in result["components"]["schemas"].items()
197+
if k not in dereferenced_refs
198+
}
199+
200+
return result
201+
202+
180203
T = TypeVar("T")
181204

182205

@@ -302,7 +325,6 @@ class FunctionRef(Protocol, Generic[Input, Output]):
302325
__call__: Callable[Input, Output]
303326

304327

305-
@dataclass
306328
class Run[O]:
307329
"""
308330
Represents a running prediction with access to the underlying schema.
@@ -361,13 +383,13 @@ def logs(self) -> Optional[str]:
361383
return self._prediction.logs
362384

363385

364-
@dataclass
365386
class Function(Generic[Input, Output]):
366387
"""
367388
A wrapper for a Replicate model that can be called as a function.
368389
"""
369390

370391
_ref: str
392+
_streaming: bool
371393

372394
def __init__(self, ref: str, *, streaming: bool) -> None:
373395
self._ref = ref
@@ -405,7 +427,9 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
405427
)
406428

407429
return Run(
408-
prediction=prediction, schema=self.openapi_schema, streaming=self._streaming
430+
prediction=prediction,
431+
schema=self.openapi_schema(),
432+
streaming=self._streaming,
409433
)
410434

411435
@property
@@ -415,20 +439,28 @@ def default_example(self) -> Optional[dict[str, Any]]:
415439
"""
416440
raise NotImplementedError("This property has not yet been implemented")
417441

418-
@cached_property
419442
def openapi_schema(self) -> dict[str, Any]:
420443
"""
421444
Get the OpenAPI schema for this model version.
422445
"""
423-
latest_version = self._model.latest_version
424-
if latest_version is None:
425-
msg = f"Model {self._model.owner}/{self._model.name} has no latest version"
446+
return self._openapi_schema
447+
448+
@cached_property
449+
def _openapi_schema(self) -> dict[str, Any]:
450+
_, _, model_version = self._parsed_ref
451+
model = self._model
452+
453+
version = (
454+
model.versions.get(model_version) if model_version else model.latest_version
455+
)
456+
if version is None:
457+
msg = f"Model {self._model.owner}/{self._model.name} has no version"
426458
raise ValueError(msg)
427459

428-
schema = latest_version.openapi_schema
429-
if cog_version := latest_version.cog_version:
460+
schema = version.openapi_schema
461+
if cog_version := version.cog_version:
430462
schema = make_schema_backwards_compatible(schema, cog_version)
431-
return schema
463+
return _dereference_schema(schema)
432464

433465
def _client(self) -> Client:
434466
return Client()
@@ -469,7 +501,6 @@ def _version(self) -> Version | None:
469501
return version
470502

471503

472-
@dataclass
473504
class AsyncRun[O]:
474505
"""
475506
Represents a running prediction with access to its version (async version).
@@ -528,21 +559,25 @@ async def logs(self) -> Optional[str]:
528559
return self._prediction.logs
529560

530561

531-
@dataclass
532562
class AsyncFunction(Generic[Input, Output]):
533563
"""
534564
An async wrapper for a Replicate model that can be called as a function.
535565
"""
536566

537-
function_ref: str
538-
streaming: bool
567+
_ref: str
568+
_streaming: bool
569+
_openapi_schema: dict[str, Any] | None = None
570+
571+
def __init__(self, ref: str, *, streaming: bool) -> None:
572+
self._ref = ref
573+
self._streaming = streaming
539574

540575
def _client(self) -> Client:
541576
return Client()
542577

543578
@cached_property
544579
def _parsed_ref(self) -> Tuple[str, str, Optional[str]]:
545-
return ModelVersionIdentifier.parse(self.function_ref)
580+
return ModelVersionIdentifier.parse(self._ref)
546581

547582
async def _model(self) -> Model:
548583
client = self._client()
@@ -607,7 +642,7 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu
607642
return AsyncRun(
608643
prediction=prediction,
609644
schema=await self.openapi_schema(),
610-
streaming=self.streaming,
645+
streaming=self._streaming,
611646
)
612647

613648
@property
@@ -621,16 +656,26 @@ async def openapi_schema(self) -> dict[str, Any]:
621656
"""
622657
Get the OpenAPI schema for this model version asynchronously.
623658
"""
624-
model = await self._model()
625-
latest_version = model.latest_version
626-
if latest_version is None:
627-
msg = f"Model {model.owner}/{model.name} has no latest version"
628-
raise ValueError(msg)
659+
if not self._openapi_schema:
660+
_, _, model_version = self._parsed_ref
629661

630-
schema = latest_version.openapi_schema
631-
if cog_version := latest_version.cog_version:
632-
schema = make_schema_backwards_compatible(schema, cog_version)
633-
return schema
662+
model = await self._model()
663+
if model_version:
664+
version = await model.versions.async_get(model_version)
665+
else:
666+
version = model.latest_version
667+
668+
if version is None:
669+
msg = f"Model {model.owner}/{model.name} has no version"
670+
raise ValueError(msg)
671+
672+
schema = version.openapi_schema
673+
if cog_version := version.cog_version:
674+
schema = make_schema_backwards_compatible(schema, cog_version)
675+
676+
self._openapi_schema = _dereference_schema(schema)
677+
678+
return self._openapi_schema
634679

635680

636681
@overload

tests/test_use.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,61 @@ async def test_use_function_create_method(client_mode):
334334
assert run._prediction.input == {"prompt": "hello world"}
335335

336336

337+
@pytest.mark.asyncio
338+
@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
339+
@respx.mock
340+
async def test_use_function_openapi_schema_dereferenced(client_mode):
341+
mock_model_endpoints(
342+
versions=[
343+
create_mock_version(
344+
{
345+
"openapi_schema": {
346+
"components": {
347+
"schemas": {
348+
"Output": {"$ref": "#/components/schemas/ModelOutput"},
349+
"ModelOutput": {
350+
"type": "object",
351+
"properties": {
352+
"text": {"type": "string"},
353+
"image": {
354+
"type": "string",
355+
"format": "uri",
356+
},
357+
"count": {"type": "integer"},
358+
},
359+
},
360+
}
361+
}
362+
}
363+
}
364+
)
365+
]
366+
)
367+
368+
hotdog_detector = replicate.use(
369+
"acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC
370+
)
371+
372+
if client_mode == ClientMode.ASYNC:
373+
schema = await hotdog_detector.openapi_schema()
374+
else:
375+
schema = hotdog_detector.openapi_schema()
376+
377+
assert schema["components"]["schemas"]["Output"] == {
378+
"type": "object",
379+
"properties": {
380+
"text": {"type": "string"},
381+
"image": {
382+
"type": "string",
383+
"format": "uri",
384+
},
385+
"count": {"type": "integer"},
386+
},
387+
}
388+
389+
assert "ModelOutput" not in schema["components"]["schemas"]
390+
391+
337392
@pytest.mark.asyncio
338393
@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
339394
@respx.mock

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)