Skip to content

Dereference OpenAPI schema returned by Function.openapi_schema #439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ name: CI

on:
push:
branches: ["main"]
branches: ["main", "beta"]

pull_request:
branches: ["main"]
branches: ["main", "beta"]

jobs:
test:
Expand Down
159 changes: 102 additions & 57 deletions replicate/use.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# TODO
# - [ ] Support text streaming
# - [ ] Support file streaming
import copy
import hashlib
import os
import tempfile
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import (
Expand All @@ -24,7 +24,6 @@
cast,
overload,
)
from urllib.parse import urlparse

import httpx

Expand Down Expand Up @@ -61,36 +60,6 @@ def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool:
return True


def _has_iterator_output_type(openapi_schema: dict) -> bool:
    """
    Report whether the schema's ``Output`` component is declared as an iterator.

    The ``Output`` entry under ``components.schemas`` must have ``type`` equal to
    ``"array"`` and ``x-cog-array-type`` equal to ``"iterator"``. Any missing
    level of the schema is treated as an empty mapping, so the result is False.
    """
    components = openapi_schema.get("components", {})
    schemas = components.get("schemas", {})
    output_schema = schemas.get("Output", {})

    if output_schema.get("type") != "array":
        return False
    return output_schema.get("x-cog-array-type") == "iterator"


def _download_file(url: str) -> Path:
    """
    Download a file from URL to a temporary location and return the Path.

    The temporary file keeps the extension of the last path segment of the URL
    when that segment contains a dot; otherwise no suffix is used. On success
    the file is left on disk (it is created with ``delete=False``), so the
    caller is responsible for removing it.

    If the request fails or the transfer is interrupted, the partially written
    temporary file is removed before the error is re-raised, so failed
    downloads do not accumulate in the temp directory.
    """
    parsed_url = urlparse(url)
    filename = os.path.basename(parsed_url.path)

    if not filename or "." not in filename:
        filename = "download"

    _, ext = os.path.splitext(filename)
    with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
        try:
            with httpx.stream("GET", url) as response:
                response.raise_for_status()
                for chunk in response.iter_bytes():
                    temp_file.write(chunk)
        except BaseException:
            # delete=False means nothing else will clean this up. Close the
            # handle first so the unlink also works on platforms that refuse
            # to remove open files, then let the original error propagate.
            temp_file.close()
            os.unlink(temp_file.name)
            raise

    return Path(temp_file.name)


def _process_iterator_item(item: Any, openapi_schema: dict) -> Any:
"""
Process a single item from an iterator output based on schema.
Expand Down Expand Up @@ -177,6 +146,60 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: # py
return output


def _dereference_schema(schema: dict[str, Any]) -> dict[str, Any]:
    """
    Return a copy of an OpenAPI schema with its local ``$ref`` pointers inlined.

    Performs basic dereferencing based on the schemas currently generated by
    Replicate. This code assumes that:

    1) References always point to a field directly within #/components/schemas;
       a more deeply nested reference raises NotImplementedError.
    2) The referenced definitions can be discarded once they have been inlined.

    A ``$ref`` that does not start with ``#/components/schemas/``, or that names
    a schema which does not exist, is left untouched. A reference that would
    recurse into a definition that is already being expanded (a cycle) is also
    left in place as a ``$ref``, and the definition it points to is kept in
    ``components.schemas`` so the remaining pointer stays resolvable.

    The input is never mutated. A schema without a ``components.schemas``
    section is returned as a plain deep copy.

    Should something more in-depth be required we could consider using the
    jsonref package.
    """
    prefix = "#/components/schemas/"
    dereferenced = copy.deepcopy(schema)
    schemas = dereferenced.get("components", {}).get("schemas", {})
    dereferenced_refs: set[str] = set()
    # Names currently being expanded; used to detect reference cycles.
    in_progress: set[str] = set()
    # Names still pointed at by a `$ref` that was left in place due to a cycle.
    cyclic_refs: set[str] = set()

    def _resolve_ref(obj: Any) -> Any:
        if isinstance(obj, list):
            return [_resolve_ref(item) for item in obj]
        if not isinstance(obj, dict):
            return obj
        if "$ref" not in obj:
            return {key: _resolve_ref(value) for key, value in obj.items()}

        ref_path = obj["$ref"]
        if not ref_path.startswith(prefix):
            return obj

        # Strip only the leading prefix; anything after it must be a bare name.
        schema_name = ref_path[len(prefix):]
        if "/" in schema_name:
            raise NotImplementedError(
                f"Unexpected nested $ref found in schema: {ref_path}"
            )

        if schema_name not in schemas:
            return obj

        if schema_name in in_progress:
            # Expanding again would recurse forever; keep the pointer instead.
            cyclic_refs.add(schema_name)
            return obj

        dereferenced_refs.add(schema_name)
        in_progress.add(schema_name)
        try:
            return _resolve_ref(schemas[schema_name])
        finally:
            in_progress.discard(schema_name)

    result = _resolve_ref(dereferenced)

    # Filter out any definitions that have been fully inlined. Definitions that
    # are still the target of a remaining (cyclic) `$ref` must survive.
    components = result.get("components")
    if isinstance(components, dict) and isinstance(components.get("schemas"), dict):
        discard = dereferenced_refs - cyclic_refs
        components["schemas"] = {
            k: v for k, v in components["schemas"].items() if k not in discard
        }

    return result


T = TypeVar("T")


Expand Down Expand Up @@ -302,7 +325,6 @@ class FunctionRef(Protocol, Generic[Input, Output]):
__call__: Callable[Input, Output]


@dataclass
class Run[O]:
"""
Represents a running prediction with access to the underlying schema.
Expand Down Expand Up @@ -361,13 +383,13 @@ def logs(self) -> Optional[str]:
return self._prediction.logs


@dataclass
class Function(Generic[Input, Output]):
"""
A wrapper for a Replicate model that can be called as a function.
"""

_ref: str
_streaming: bool

def __init__(self, ref: str, *, streaming: bool) -> None:
self._ref = ref
Expand Down Expand Up @@ -405,7 +427,9 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
)

return Run(
prediction=prediction, schema=self.openapi_schema, streaming=self._streaming
prediction=prediction,
schema=self.openapi_schema(),
streaming=self._streaming,
)

@property
Expand All @@ -415,20 +439,28 @@ def default_example(self) -> Optional[dict[str, Any]]:
"""
raise NotImplementedError("This property has not yet been implemented")

@cached_property
def openapi_schema(self) -> dict[str, Any]:
"""
Get the OpenAPI schema for this model version.
"""
latest_version = self._model.latest_version
if latest_version is None:
msg = f"Model {self._model.owner}/{self._model.name} has no latest version"
return self._openapi_schema

@cached_property
def _openapi_schema(self) -> dict[str, Any]:
_, _, model_version = self._parsed_ref
model = self._model

version = (
model.versions.get(model_version) if model_version else model.latest_version
)
if version is None:
msg = f"Model {self._model.owner}/{self._model.name} has no version"
raise ValueError(msg)

schema = latest_version.openapi_schema
if cog_version := latest_version.cog_version:
schema = version.openapi_schema
if cog_version := version.cog_version:
schema = make_schema_backwards_compatible(schema, cog_version)
return schema
return _dereference_schema(schema)

def _client(self) -> Client:
return Client()
Expand Down Expand Up @@ -469,7 +501,6 @@ def _version(self) -> Version | None:
return version


@dataclass
class AsyncRun[O]:
"""
Represents a running prediction with access to its version (async version).
Expand Down Expand Up @@ -528,21 +559,25 @@ async def logs(self) -> Optional[str]:
return self._prediction.logs


@dataclass
class AsyncFunction(Generic[Input, Output]):
"""
An async wrapper for a Replicate model that can be called as a function.
"""

function_ref: str
streaming: bool
_ref: str
_streaming: bool
_openapi_schema: dict[str, Any] | None = None

def __init__(self, ref: str, *, streaming: bool) -> None:
self._ref = ref
self._streaming = streaming

def _client(self) -> Client:
return Client()

@cached_property
def _parsed_ref(self) -> Tuple[str, str, Optional[str]]:
return ModelVersionIdentifier.parse(self.function_ref)
return ModelVersionIdentifier.parse(self._ref)

async def _model(self) -> Model:
client = self._client()
Expand Down Expand Up @@ -607,7 +642,7 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu
return AsyncRun(
prediction=prediction,
schema=await self.openapi_schema(),
streaming=self.streaming,
streaming=self._streaming,
)

@property
Expand All @@ -621,16 +656,26 @@ async def openapi_schema(self) -> dict[str, Any]:
"""
Get the OpenAPI schema for this model version asynchronously.
"""
model = await self._model()
latest_version = model.latest_version
if latest_version is None:
msg = f"Model {model.owner}/{model.name} has no latest version"
raise ValueError(msg)
if not self._openapi_schema:
_, _, model_version = self._parsed_ref

schema = latest_version.openapi_schema
if cog_version := latest_version.cog_version:
schema = make_schema_backwards_compatible(schema, cog_version)
return schema
model = await self._model()
if model_version:
version = await model.versions.async_get(model_version)
else:
version = model.latest_version

if version is None:
msg = f"Model {model.owner}/{model.name} has no version"
raise ValueError(msg)

schema = version.openapi_schema
if cog_version := version.cog_version:
schema = make_schema_backwards_compatible(schema, cog_version)

self._openapi_schema = _dereference_schema(schema)

return self._openapi_schema


@overload
Expand Down
55 changes: 55 additions & 0 deletions tests/test_use.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,61 @@ async def test_use_function_create_method(client_mode):
assert run._prediction.input == {"prompt": "hello world"}


@pytest.mark.asyncio
@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
@respx.mock
async def test_use_function_openapi_schema_dereferenced(client_mode):
    """
    The schema returned by ``openapi_schema()`` has the ``Output`` ``$ref``
    replaced by the referenced definition, and that definition is removed from
    ``components.schemas``. Exercised for both the sync and async wrappers.
    """

    def model_output_schema():
        # Fresh dict on every call so the mocked payload and the expected
        # value never share objects.
        return {
            "type": "object",
            "properties": {
                "text": {"type": "string"},
                "image": {
                    "type": "string",
                    "format": "uri",
                },
                "count": {"type": "integer"},
            },
        }

    component_schemas = {
        "Output": {"$ref": "#/components/schemas/ModelOutput"},
        "ModelOutput": model_output_schema(),
    }
    mock_version = create_mock_version(
        {"openapi_schema": {"components": {"schemas": component_schemas}}}
    )
    mock_model_endpoints(versions=[mock_version])

    use_async = client_mode == ClientMode.ASYNC
    hotdog_detector = replicate.use("acme/hotdog-detector", use_async=use_async)

    if use_async:
        schema = await hotdog_detector.openapi_schema()
    else:
        schema = hotdog_detector.openapi_schema()

    resolved_schemas = schema["components"]["schemas"]
    assert resolved_schemas["Output"] == model_output_schema()

    assert "ModelOutput" not in schema["components"]["schemas"]


@pytest.mark.asyncio
@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
@respx.mock
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.