Skip to content

Commit e7358c4

Browse files
committed
Add progress property to Prediction
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent 3b9cd7a commit e7358c4

File tree

2 files changed

+95
-0
lines changed

2 files changed

+95
-0
lines changed

replicate/prediction.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import re
12
import time
3+
from dataclasses import dataclass
24
from typing import Any, Dict, Iterator, List, Optional
35

46
from replicate.base_model import BaseModel
@@ -56,6 +58,37 @@ class Prediction(BaseModel):
5658
- `cancel`: A URL to cancel the prediction.
5759
"""
5860

61+
@dataclass
62+
class Progress:
63+
percentage: float
64+
"""The percentage of the prediction that has completed."""
65+
66+
current: int
67+
"""The number of items that have been processed."""
68+
69+
total: int
70+
"""The total number of items to process."""
71+
72+
@property
73+
def progress(self) -> Optional[Progress]:
74+
if self.logs is None or self.logs == "":
75+
return None
76+
77+
pattern = (
78+
r"^\s*(?P<percentage>\d+)%\s*\|.+?\|\s*(?P<current>\d+)\/(?P<total>\d+)"
79+
)
80+
re_compiled = re.compile(pattern)
81+
82+
lines = self.logs.split("\n")
83+
for i in reversed(range(len(lines))):
84+
line = lines[i].strip()
85+
if re_compiled.match(line):
86+
matches = re_compiled.findall(line)
87+
if len(matches) == 1:
88+
percentage, current, total = map(int, matches[0])
89+
return Prediction.Progress(percentage / 100.0, current, total)
90+
return None
91+
5992
def wait(self) -> None:
6093
"""
6194
Wait for prediction to finish.

tests/test_prediction.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import responses
22
from responses import matchers
33

4+
from replicate.prediction import Prediction
5+
46
from .factories import create_client, create_version
57

68

@@ -214,3 +216,63 @@ def test_async_timings():
214216
assert prediction.completed_at == "2022-04-26T20:02:27.648305Z"
215217
assert prediction.output == "hello world"
216218
assert prediction.metrics["predict_time"] == 1.2345
219+
220+
221+
def test_prediction_progress():
222+
client = create_client()
223+
version = create_version(client)
224+
prediction = Prediction(
225+
id="ufawqhfynnddngldkgtslldrkq", version=version, status="starting"
226+
)
227+
228+
lines = [
229+
"Using seed: 12345",
230+
"0%| | 0/5 [00:00<?, ?it/s]",
231+
"20%|██ | 1/5 [00:00<00:01, 21.38it/s]",
232+
"40%|████▍ | 2/5 [00:01<00:01, 22.46it/s]",
233+
"60%|████▍ | 3/5 [00:01<00:01, 22.46it/s]",
234+
"80%|████████ | 4/5 [00:01<00:00, 22.86it/s]",
235+
"100%|██████████| 5/5 [00:02<00:00, 22.26it/s]",
236+
]
237+
logs = ""
238+
239+
for i, line in enumerate(lines):
240+
logs += "\n" + line
241+
prediction.logs = logs
242+
243+
progress = prediction.progress
244+
245+
if i == 0:
246+
prediction.status = "processing"
247+
assert progress is None
248+
elif i == 1:
249+
assert progress is not None
250+
assert progress.current == 0
251+
assert progress.total == 5
252+
assert progress.percentage == 0.0
253+
elif i == 2:
254+
assert progress is not None
255+
assert progress.current == 1
256+
assert progress.total == 5
257+
assert progress.percentage == 0.2
258+
elif i == 3:
259+
assert progress is not None
260+
assert progress.current == 2
261+
assert progress.total == 5
262+
assert progress.percentage == 0.4
263+
elif i == 4:
264+
assert progress is not None
265+
assert progress.current == 3
266+
assert progress.total == 5
267+
assert progress.percentage == 0.6
268+
elif i == 5:
269+
assert progress is not None
270+
assert progress.current == 4
271+
assert progress.total == 5
272+
assert progress.percentage == 0.8
273+
elif i == 6:
274+
assert progress is not None
275+
prediction.status = "succeeded"
276+
assert progress.current == 5
277+
assert progress.total == 5
278+
assert progress.percentage == 1.0

0 commit comments

Comments
 (0)