Skip to content

Commit 21ac011

Browse files
committed
Dereference the OpenAPI schema returned by openapi_schema
1 parent 0510f24 commit 21ac011

File tree

2 files changed

+112
-2
lines changed

2 files changed

+112
-2
lines changed

replicate/use.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# TODO
22
# - [ ] Support text streaming
33
# - [ ] Support file streaming
4+
import copy
45
import hashlib
56
import os
67
import tempfile
@@ -177,6 +178,60 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: # py
177178
return output
178179

179180

181+
def _dereference_schema(schema: dict[str, Any]) -> dict[str, Any]:
182+
"""
183+
Performs basic dereferencing on an OpenAPI schema based on the current schemas generated
184+
by Replicate. This code assumes that:
185+
186+
1) References will always point to a field within #/components/schemas and will error
187+
if the reference is more deeply nested.
188+
2) That the references when used can be discarded.
189+
190+
Should something more in-depth be required we could consider using the jsonref package.
191+
"""
192+
dereferenced = copy.deepcopy(schema)
193+
schemas = dereferenced.get("components", {}).get("schemas", {})
194+
dereferenced_refs = set()
195+
196+
def _resolve_ref(obj: Any) -> Any:
197+
if isinstance(obj, dict):
198+
if "$ref" in obj:
199+
ref_path = obj["$ref"]
200+
if ref_path.startswith("#/components/schemas/"):
201+
parts = ref_path.replace("#/components/schemas/", "").split("/", 2)
202+
203+
if len(parts) > 1:
204+
raise NotImplementedError(
205+
f"Unexpected nested $ref found in schema: {ref_path}"
206+
)
207+
208+
(schema_name,) = parts
209+
if schema_name in schemas:
210+
dereferenced_refs.add(schema_name)
211+
return _resolve_ref(schemas[schema_name])
212+
else:
213+
return obj
214+
else:
215+
return obj
216+
else:
217+
return {key: _resolve_ref(value) for key, value in obj.items()}
218+
elif isinstance(obj, list):
219+
return [_resolve_ref(item) for item in obj]
220+
else:
221+
return obj
222+
223+
result = _resolve_ref(dereferenced)
224+
225+
# Filter out any references that have now been referenced.
226+
result["components"]["schemas"] = {
227+
k: v
228+
for k, v in result["components"]["schemas"].items()
229+
if k not in dereferenced_refs
230+
}
231+
232+
return result
233+
234+
180235
T = TypeVar("T")
181236

182237

@@ -428,7 +483,7 @@ def openapi_schema(self) -> dict[str, Any]:
428483
schema = latest_version.openapi_schema
429484
if cog_version := latest_version.cog_version:
430485
schema = make_schema_backwards_compatible(schema, cog_version)
431-
return schema
486+
return _dereference_schema(schema)
432487

433488
def _client(self) -> Client:
434489
return Client()
@@ -630,7 +685,7 @@ async def openapi_schema(self) -> dict[str, Any]:
630685
schema = latest_version.openapi_schema
631686
if cog_version := latest_version.cog_version:
632687
schema = make_schema_backwards_compatible(schema, cog_version)
633-
return schema
688+
return _dereference_schema(schema)
634689

635690

636691
@overload

tests/test_use.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,61 @@ async def test_use_function_create_method(client_mode):
334334
assert run._prediction.input == {"prompt": "hello world"}
335335

336336

337+
@pytest.mark.asyncio
338+
@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
339+
@respx.mock
340+
async def test_use_function_openapi_schema_dereferenced(client_mode):
341+
mock_model_endpoints(
342+
versions=[
343+
create_mock_version(
344+
{
345+
"openapi_schema": {
346+
"components": {
347+
"schemas": {
348+
"Output": {"$ref": "#/components/schemas/ModelOutput"},
349+
"ModelOutput": {
350+
"type": "object",
351+
"properties": {
352+
"text": {"type": "string"},
353+
"image": {
354+
"type": "string",
355+
"format": "uri",
356+
},
357+
"count": {"type": "integer"},
358+
},
359+
},
360+
}
361+
}
362+
}
363+
}
364+
)
365+
]
366+
)
367+
368+
hotdog_detector = replicate.use(
369+
"acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC
370+
)
371+
372+
if client_mode == ClientMode.ASYNC:
373+
schema = await hotdog_detector.openapi_schema()
374+
else:
375+
schema = hotdog_detector.openapi_schema
376+
377+
assert schema["components"]["schemas"]["Output"] == {
378+
"type": "object",
379+
"properties": {
380+
"text": {"type": "string"},
381+
"image": {
382+
"type": "string",
383+
"format": "uri",
384+
},
385+
"count": {"type": "integer"},
386+
},
387+
}
388+
389+
assert "ModelOutput" not in schema["components"]["schemas"]
390+
391+
337392
@pytest.mark.asyncio
338393
@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
339394
@respx.mock

0 commit comments

Comments
 (0)