Skip to content

Commit 4111e82

Browse files
committed
Actually use PathProxy
1 parent bc5d7d8 commit 4111e82

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

replicate/use.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any:
154154
# Handle direct string with format=uri
155155
if output_schema.get("type") == "string" and output_schema.get("format") == "uri":
156156
if isinstance(output, str) and output.startswith(("http://", "https://")):
157-
return _download_file(output)
157+
return PathProxy(output)
158158
return output
159159

160160
# Handle array of strings with format=uri
@@ -163,7 +163,7 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any:
163163
if items.get("type") == "string" and items.get("format") == "uri":
164164
if isinstance(output, list):
165165
return [
166-
_download_file(url)
166+
PathProxy(url)
167167
if isinstance(url, str) and url.startswith(("http://", "https://"))
168168
else url
169169
for url in output
@@ -187,15 +187,15 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any:
187187
if isinstance(value, str) and value.startswith(
188188
("http://", "https://")
189189
):
190-
result[prop_name] = _download_file(value)
190+
result[prop_name] = PathProxy(value)
191191

192192
# Array of files property
193193
elif prop_schema.get("type") == "array":
194194
items = prop_schema.get("items", {})
195195
if items.get("type") == "string" and items.get("format") == "uri":
196196
if isinstance(value, list):
197197
result[prop_name] = [
198-
_download_file(url)
198+
PathProxy(url)
199199
if isinstance(url, str)
200200
and url.startswith(("http://", "https://"))
201201
else url

tests/test_use.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import respx
1010

1111
import replicate
12+
from replicate.use import PathProxy
1213

1314

1415
class ClientMode(str, Enum):
@@ -540,6 +541,7 @@ async def test_use_path_output(client_mode):
540541
# Call function with prompt="hello world"
541542
output = hotdog_detector(prompt="hello world")
542543

544+
assert isinstance(output, PathProxy)
543545
assert isinstance(output, Path)
544546
assert output.exists()
545547
assert output.read_bytes() == b"fake image data"
@@ -598,6 +600,7 @@ async def test_use_list_of_paths_output(client_mode):
598600

599601
assert isinstance(output, list)
600602
assert len(output) == 2
603+
assert all(isinstance(path, PathProxy) for path in output)
601604
assert all(isinstance(path, Path) for path in output)
602605
assert all(path.exists() for path in output)
603606
assert output[0].read_bytes() == b"fake image 1 data"
@@ -663,6 +666,7 @@ async def test_use_iterator_of_paths_output(client_mode):
663666
# Convert to list to check contents
664667
output_list = list(output)
665668
assert len(output_list) == 2
669+
assert all(isinstance(path, PathProxy) for path in output_list)
666670
assert all(isinstance(path, Path) for path in output_list)
667671
assert all(path.exists() for path in output_list)
668672
assert output_list[0].read_bytes() == b"fake image 1 data"

0 commit comments

Comments
 (0)