Skip to content

Commit 20a37d1

Browse files
committed
Add get_url_path() helper to get underlying URL for a PathProxy object
1 parent 8d85629 commit 20a37d1

File tree

2 files changed

+99
-21
lines changed

2 files changed

+99
-21
lines changed

replicate/use.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,4 @@
11
# TODO
2-
# - [x] Support downloading files and conversion into Path when schema is URL
3-
# - [x] Support list outputs
4-
# - [x] Support iterator outputs
5-
# - [x] Support helpers for working with ContatenateIterator
6-
# - [ ] Support reusing output URL when passing to new method
7-
# - [ ] Support lazy downloading of files into Path
82
# - [ ] Support text streaming
93
# - [ ] Support file streaming
104
# - [ ] Support asyncio variant
@@ -28,6 +22,9 @@
2822
from replicate.version import Version
2923

3024

25+
__all__ = ["use", "get_path_url"]
26+
27+
3128
def _in_module_scope() -> bool:
3229
"""
3330
Returns True when called from top level module scope.
@@ -41,9 +38,6 @@ def _in_module_scope() -> bool:
4138
return False
4239

4340

44-
__all__ = ["use"]
45-
46-
4741
def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool:
4842
"""
4943
Returns true if the model output type is ConcatenateIterator or
@@ -218,29 +212,41 @@ def ensure_path() -> Path:
218212
path = _download_file(target)
219213
return path
220214

221-
object.__setattr__(self, "__target__", target)
222-
object.__setattr__(self, "__path__", ensure_path)
215+
object.__setattr__(self, "__replicate_target__", target)
216+
object.__setattr__(self, "__replicate_path__", ensure_path)
223217

224218
def __getattribute__(self, name) -> Any:
225-
if name in ("__path__", "__target__"):
219+
if name in ("__replicate_path__", "__replicate_target__"):
226220
return object.__getattribute__(self, name)
227221

228222
# TODO: We should cover other common properties on Path...
229223
if name == "__class__":
230224
return Path
231225

232-
return getattr(object.__getattribute__(self, "__path__")(), name)
226+
return getattr(object.__getattribute__(self, "__replicate_path__")(), name)
233227

234228
def __setattr__(self, name, value) -> None:
235-
if name in ("__path__", "__target__"):
229+
if name in ("__replicate_path__", "__replicate_target__"):
236230
raise ValueError()
237231

238-
object.__setattr__(object.__getattribute__(self, "__path__")(), name, value)
232+
object.__setattr__(
233+
object.__getattribute__(self, "__replicate_path__")(), name, value
234+
)
239235

240236
def __delattr__(self, name) -> None:
241-
if name in ("__path__", "__target__"):
237+
if name in ("__replicate_path__", "__replicate_target__"):
242238
raise ValueError()
243-
delattr(object.__getattribute__(self, "__path__")(), name)
239+
delattr(object.__getattribute__(self, "__replicate_path__")(), name)
240+
241+
242+
def get_path_url(path: Any) -> str | None:
243+
"""
244+
Return the remote URL (if any) for a Path output from a model.
245+
"""
246+
try:
247+
return object.__getattribute__(path, "__replicate_target__")
248+
except AttributeError:
249+
return None
244250

245251

246252
@dataclass
@@ -252,7 +258,7 @@ class Run:
252258
prediction: Prediction
253259
schema: dict
254260

255-
def wait(self) -> Union[Any, Iterator[Any]]:
261+
def output(self) -> Union[Any, Iterator[Any]]:
256262
"""
257263
Wait for the prediction to complete and return its output.
258264
"""
@@ -330,7 +336,7 @@ def _version(self) -> Version | None:
330336

331337
def __call__(self, **inputs: Dict[str, Any]) -> Any:
332338
run = self.create(**inputs)
333-
return run.wait()
339+
return run.output()
334340

335341
def create(self, **inputs: Dict[str, Any]) -> Run:
336342
"""
@@ -341,8 +347,8 @@ def create(self, **inputs: Dict[str, Any]) -> Run:
341347
for key, value in inputs.items():
342348
if isinstance(value, OutputIterator) and value.is_concatenate:
343349
processed_inputs[key] = str(value)
344-
elif isinstance(value, PathProxy):
345-
processed_inputs[key] = object.__getattribute__(value, "__target__")
350+
elif url := get_path_url(value):
351+
processed_inputs[key] = url
346352
else:
347353
processed_inputs[key] = value
348354

tests/test_use.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,78 @@ async def test_use_iterator_of_paths_output(use_async_client):
608608
assert output_list[1].read_bytes() == b"fake image 2 data"
609609

610610

611+
def test_get_path_url_with_pathproxy():
612+
"""Test get_path_url returns the URL for PathProxy instances."""
613+
from replicate.use import get_path_url, PathProxy
614+
615+
url = "https://example.com/test.jpg"
616+
path_proxy = PathProxy(url)
617+
618+
result = get_path_url(path_proxy)
619+
assert result == url
620+
621+
622+
def test_get_path_url_with_regular_path():
623+
"""Test get_path_url returns None for regular Path instances."""
624+
from replicate.use import get_path_url
625+
626+
regular_path = Path("/tmp/test.txt")
627+
628+
result = get_path_url(regular_path)
629+
assert result is None
630+
631+
632+
def test_get_path_url_with_object_without_target():
633+
"""Test get_path_url returns None for objects without __replicate_target__."""
634+
from replicate.use import get_path_url
635+
636+
# Test with a string
637+
result = get_path_url("not a path")
638+
assert result is None
639+
640+
# Test with a dict
641+
result = get_path_url({"key": "value"})
642+
assert result is None
643+
644+
# Test with None
645+
result = get_path_url(None)
646+
assert result is None
647+
648+
649+
def test_get_path_url_with_object_with_target():
650+
"""Test get_path_url returns URL for any object with __replicate_target__."""
651+
from replicate.use import get_path_url
652+
653+
class MockObjectWithTarget:
654+
def __init__(self, target):
655+
object.__setattr__(self, "__replicate_target__", target)
656+
657+
url = "https://example.com/mock.png"
658+
mock_obj = MockObjectWithTarget(url)
659+
660+
result = get_path_url(mock_obj)
661+
assert result == url
662+
663+
664+
def test_get_path_url_with_empty_target():
665+
"""Test get_path_url with empty/falsy target values."""
666+
from replicate.use import get_path_url
667+
668+
class MockObjectWithEmptyTarget:
669+
def __init__(self, target):
670+
object.__setattr__(self, "__replicate_target__", target)
671+
672+
# Test with empty string
673+
mock_obj = MockObjectWithEmptyTarget("")
674+
result = get_path_url(mock_obj)
675+
assert result == ""
676+
677+
# Test with None
678+
mock_obj = MockObjectWithEmptyTarget(None)
679+
result = get_path_url(mock_obj)
680+
assert result is None
681+
682+
611683
@pytest.mark.asyncio
612684
@pytest.mark.parametrize("use_async_client", [False])
613685
@respx.mock

0 commit comments

Comments
 (0)