|
1 | 1 | # TODO
|
2 | 2 | # - [ ] Support text streaming
|
3 | 3 | # - [ ] Support file streaming
|
| 4 | +import copy |
4 | 5 | import hashlib
|
5 | 6 | import os
|
6 | 7 | import tempfile
|
@@ -177,6 +178,60 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: # py
|
177 | 178 | return output
|
178 | 179 |
|
179 | 180 |
|
| 181 | +def _dereference_schema(schema: dict[str, Any]) -> dict[str, Any]: |
| 182 | + """ |
| 183 | + Performs basic dereferencing on an OpenAPI schema based on the current schemas generated |
| 184 | + by Replicate. This code assumes that: |
| 185 | +
|
| 186 | + 1) References will always point to a field within #/components/schemas and will error |
| 187 | + if the reference is more deeply nested. |
| 188 | + 2) That the references when used can be discarded. |
| 189 | +
|
| 190 | + Should something more in-depth be required we could consider using the jsonref package. |
| 191 | + """ |
| 192 | + dereferenced = copy.deepcopy(schema) |
| 193 | + schemas = dereferenced.get("components", {}).get("schemas", {}) |
| 194 | + dereferenced_refs = set() |
| 195 | + |
| 196 | + def _resolve_ref(obj: Any) -> Any: |
| 197 | + if isinstance(obj, dict): |
| 198 | + if "$ref" in obj: |
| 199 | + ref_path = obj["$ref"] |
| 200 | + if ref_path.startswith("#/components/schemas/"): |
| 201 | + parts = ref_path.replace("#/components/schemas/", "").split("/", 2) |
| 202 | + |
| 203 | + if len(parts) > 1: |
| 204 | + raise NotImplementedError( |
| 205 | + f"Unexpected nested $ref found in schema: {ref_path}" |
| 206 | + ) |
| 207 | + |
| 208 | + (schema_name,) = parts |
| 209 | + if schema_name in schemas: |
| 210 | + dereferenced_refs.add(schema_name) |
| 211 | + return _resolve_ref(schemas[schema_name]) |
| 212 | + else: |
| 213 | + return obj |
| 214 | + else: |
| 215 | + return obj |
| 216 | + else: |
| 217 | + return {key: _resolve_ref(value) for key, value in obj.items()} |
| 218 | + elif isinstance(obj, list): |
| 219 | + return [_resolve_ref(item) for item in obj] |
| 220 | + else: |
| 221 | + return obj |
| 222 | + |
| 223 | + result = _resolve_ref(dereferenced) |
| 224 | + |
| 225 | + # Filter out any references that have now been referenced. |
| 226 | + result["components"]["schemas"] = { |
| 227 | + k: v |
| 228 | + for k, v in result["components"]["schemas"].items() |
| 229 | + if k not in dereferenced_refs |
| 230 | + } |
| 231 | + |
| 232 | + return result |
| 233 | + |
| 234 | + |
180 | 235 | T = TypeVar("T")
|
181 | 236 |
|
182 | 237 |
|
@@ -428,7 +483,7 @@ def openapi_schema(self) -> dict[str, Any]:
|
428 | 483 | schema = latest_version.openapi_schema
|
429 | 484 | if cog_version := latest_version.cog_version:
|
430 | 485 | schema = make_schema_backwards_compatible(schema, cog_version)
|
431 |
| - return schema |
| 486 | + return _dereference_schema(schema) |
432 | 487 |
|
433 | 488 | def _client(self) -> Client:
|
434 | 489 | return Client()
|
@@ -630,7 +685,7 @@ async def openapi_schema(self) -> dict[str, Any]:
|
630 | 685 | schema = latest_version.openapi_schema
|
631 | 686 | if cog_version := latest_version.cog_version:
|
632 | 687 | schema = make_schema_backwards_compatible(schema, cog_version)
|
633 |
| - return schema |
| 688 | + return _dereference_schema(schema) |
634 | 689 |
|
635 | 690 |
|
636 | 691 | @overload
|
|
0 commit comments