Skip to content

Commit 3fa37dc

Browse files
committed
Add progress property to Prediction
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent 3b9cd7a commit 3fa37dc

File tree

2 files changed

+95
-0
lines changed

2 files changed

+95
-0
lines changed

replicate/prediction.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import re
12
import time
3+
from dataclasses import dataclass
24
from typing import Any, Dict, Iterator, List, Optional
35

46
from replicate.base_model import BaseModel
@@ -56,6 +58,36 @@ class Prediction(BaseModel):
5658
- `cancel`: A URL to cancel the prediction.
5759
"""
5860

61+
@dataclass
62+
class Progress:
63+
percentage: float
64+
"""The percentage of the prediction that has completed."""
65+
66+
current: int
67+
"""The number of items that have been processed."""
68+
69+
total: int
70+
"""The total number of items to process."""
71+
72+
@property
73+
def progress(self) -> Optional[Progress]:
74+
if self.logs is None or self.logs == "":
75+
return None
76+
77+
pattern = r"^\s*(?P<percentage>\d+)%\s*\|.+?\|\s*(?P<current>\d+)\/(?P<total>\d+)"
78+
re_compiled = re.compile(pattern)
79+
80+
lines = self.logs.split("\n")
81+
for i in reversed(range(len(lines))):
82+
line = lines[i].strip()
83+
if re_compiled.match(line):
84+
matches = re_compiled.findall(line)
85+
if len(matches) == 1:
86+
percentage, current, total = map(int, matches[0])
87+
return Prediction.Progress(percentage / 100.0, current, total)
88+
return None
89+
90+
5991
def wait(self) -> None:
6092
"""
6193
Wait for prediction to finish.

tests/test_prediction.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import responses
22
from responses import matchers
33

4+
from replicate.prediction import Prediction
5+
46
from .factories import create_client, create_version
57

68

@@ -214,3 +216,64 @@ def test_async_timings():
214216
assert prediction.completed_at == "2022-04-26T20:02:27.648305Z"
215217
assert prediction.output == "hello world"
216218
assert prediction.metrics["predict_time"] == 1.2345
219+
220+
def test_prediction_progress():
221+
client = create_client()
222+
version = create_version(client)
223+
prediction = Prediction(
224+
id="ufawqhfynnddngldkgtslldrkq",
225+
version=version,
226+
status="starting"
227+
)
228+
229+
lines = [
230+
"Using seed: 12345",
231+
"0%| | 0/5 [00:00<?, ?it/s]",
232+
"20%|██ | 1/5 [00:00<00:01, 21.38it/s]",
233+
"40%|████▍ | 2/5 [00:01<00:01, 22.46it/s]",
234+
"60%|████▍ | 3/5 [00:01<00:01, 22.46it/s]",
235+
"80%|████████ | 4/5 [00:01<00:00, 22.86it/s]",
236+
"100%|██████████| 5/5 [00:02<00:00, 22.26it/s]",
237+
]
238+
logs = ""
239+
240+
for i, line in enumerate(lines):
241+
logs += "\n" + line
242+
prediction.logs = logs
243+
244+
progress = prediction.progress
245+
246+
if i == 0:
247+
prediction.status = "processing"
248+
assert progress is None
249+
elif i == 1:
250+
assert progress is not None
251+
assert progress.current == 0
252+
assert progress.total == 5
253+
assert progress.percentage == 0.0
254+
elif i == 2:
255+
assert progress is not None
256+
assert progress.current == 1
257+
assert progress.total == 5
258+
assert progress.percentage == 0.2
259+
elif i == 3:
260+
assert progress is not None
261+
assert progress.current == 2
262+
assert progress.total == 5
263+
assert progress.percentage == 0.4
264+
elif i == 4:
265+
assert progress is not None
266+
assert progress.current == 3
267+
assert progress.total == 5
268+
assert progress.percentage == 0.6
269+
elif i == 5:
270+
assert progress is not None
271+
assert progress.current == 4
272+
assert progress.total == 5
273+
assert progress.percentage == 0.8
274+
elif i == 6:
275+
assert progress is not None
276+
prediction.status = "succeeded"
277+
assert progress.current == 5
278+
assert progress.total == 5
279+
assert progress.percentage == 1.0

0 commit comments

Comments
 (0)