Skip to content

Commit de717a0

Browse files
authored
Fix iterator support for replicate.run() (#383)
Prior to 1.0.0 `replicate.run()` would return an iterator for cog models that output a type of `Iterator[Any]`. This would poll the `predictions.get` endpoint for the in-progress prediction and yield any new output. When implementing the new file interface we introduced two bugs: 1. The iterator didn't convert URLs returned by the model into `FileOutput` types, making it inconsistent with the non-iterator interface. This is controlled by the `use_file_outputs` argument. 2. The iterator was returned without checking if we are using the new blocking API introduced by default and controlled by the `wait` argument. This commit fixes these two issues, consistently applying the `transform_output` function to the output of the iterator as well as returning the polling iterator (`prediction.output_iterator`) if the blocking API has not successfully returned a completed prediction. The tests have been updated to exercise both of these code paths.
1 parent 23bd903 commit de717a0

File tree

2 files changed

+606
-228
lines changed

2 files changed

+606
-228
lines changed

replicate/run.py

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from replicate.exceptions import ModelError
1616
from replicate.helpers import transform_output
1717
from replicate.model import Model
18-
from replicate.prediction import Prediction
1918
from replicate.schema import make_schema_backwards_compatible
2019
from replicate.version import Version, Versions
2120

@@ -59,15 +58,36 @@ def run(
5958
if not version and (owner and name and version_id):
6059
version = Versions(client, model=(owner, name)).get(version_id)
6160

62-
if version and (iterator := _make_output_iterator(version, prediction)):
63-
return iterator
61+
# Currently the "Prefer: wait" interface will return a prediction with a status
62+
# of "processing" rather than a terminal state because it returns before the
63+
# prediction has been fully processed. If the request exceeds the wait time, even if
64+
# it is actually processing, the prediction will be in a "starting" state.
65+
#
66+
# We should fix this in the blocking API itself. Predictions that are done should
67+
# be in a terminal state and predictions that are processing should be in state
68+
# "processing".
69+
in_terminal_state = is_blocking and prediction.status != "starting"
70+
if not in_terminal_state:
71+
# Return a "polling" iterator if the model has an output iterator array type.
72+
if version and _has_output_iterator_array_type(version):
73+
return (
74+
transform_output(chunk, client)
75+
for chunk in prediction.output_iterator()
76+
)
6477

65-
if not (is_blocking and prediction.status != "starting"):
6678
prediction.wait()
6779

6880
if prediction.status == "failed":
6981
raise ModelError(prediction)
7082

83+
# Return an iterator for the completed prediction when needed.
84+
if (
85+
version
86+
and _has_output_iterator_array_type(version)
87+
and prediction.output is not None
88+
):
89+
return (transform_output(chunk, client) for chunk in prediction.output)
90+
7191
if use_file_output:
7292
return transform_output(prediction.output, client)
7393

@@ -108,15 +128,39 @@ async def async_run(
108128
if not version and (owner and name and version_id):
109129
version = await Versions(client, model=(owner, name)).async_get(version_id)
110130

111-
if version and (iterator := _make_async_output_iterator(version, prediction)):
112-
return iterator
131+
# Currently the "Prefer: wait" interface will return a prediction with a status
132+
# of "processing" rather than a terminal state because it returns before the
133+
# prediction has been fully processed. If the request exceeds the wait time, even if
134+
# it is actually processing, the prediction will be in a "starting" state.
135+
#
136+
# We should fix this in the blocking API itself. Predictions that are done should
137+
# be in a terminal state and predictions that are processing should be in state
138+
# "processing".
139+
in_terminal_state = is_blocking and prediction.status != "starting"
140+
if not in_terminal_state:
141+
# Return a "polling" iterator if the model has an output iterator array type.
142+
if version and _has_output_iterator_array_type(version):
143+
return (
144+
transform_output(chunk, client)
145+
async for chunk in prediction.async_output_iterator()
146+
)
113147

114-
if not (is_blocking and prediction.status != "starting"):
115148
await prediction.async_wait()
116149

117150
if prediction.status == "failed":
118151
raise ModelError(prediction)
119152

153+
# Return an iterator for completed output if the model has an output iterator array type.
154+
if (
155+
version
156+
and _has_output_iterator_array_type(version)
157+
and prediction.output is not None
158+
):
159+
return (
160+
transform_output(chunk, client)
161+
async for chunk in _make_async_iterator(prediction.output)
162+
)
163+
120164
if use_file_output:
121165
return transform_output(prediction.output, client)
122166

@@ -133,22 +177,9 @@ def _has_output_iterator_array_type(version: Version) -> bool:
133177
)
134178

135179

136-
def _make_output_iterator(
137-
version: Version, prediction: Prediction
138-
) -> Optional[Iterator[Any]]:
139-
if _has_output_iterator_array_type(version):
140-
return prediction.output_iterator()
141-
142-
return None
143-
144-
145-
def _make_async_output_iterator(
146-
version: Version, prediction: Prediction
147-
) -> Optional[AsyncIterator[Any]]:
148-
if _has_output_iterator_array_type(version):
149-
return prediction.async_output_iterator()
150-
151-
return None
180+
async def _make_async_iterator(list: list) -> AsyncIterator:
181+
for item in list:
182+
yield item
152183

153184

154185
__all__: List = []

0 commit comments

Comments
 (0)