|
1 | 1 | import responses
|
2 | 2 | from responses import matchers
|
3 | 3 |
|
| 4 | +from replicate.prediction import Prediction |
| 5 | + |
4 | 6 | from .factories import create_client, create_version
|
5 | 7 |
|
6 | 8 |
|
@@ -214,3 +216,63 @@ def test_async_timings():
|
214 | 216 | assert prediction.completed_at == "2022-04-26T20:02:27.648305Z"
|
215 | 217 | assert prediction.output == "hello world"
|
216 | 218 | assert prediction.metrics["predict_time"] == 1.2345
|
| 219 | + |
| 220 | + |
| 221 | +def test_prediction_progress(): |
| 222 | + client = create_client() |
| 223 | + version = create_version(client) |
| 224 | + prediction = Prediction( |
| 225 | + id="ufawqhfynnddngldkgtslldrkq", version=version, status="starting" |
| 226 | + ) |
| 227 | + |
| 228 | + lines = [ |
| 229 | + "Using seed: 12345", |
| 230 | + "0%| | 0/5 [00:00<?, ?it/s]", |
| 231 | + "20%|██ | 1/5 [00:00<00:01, 21.38it/s]", |
| 232 | + "40%|████▍ | 2/5 [00:01<00:01, 22.46it/s]", |
| 233 | + "60%|████▍ | 3/5 [00:01<00:01, 22.46it/s]", |
| 234 | + "80%|████████ | 4/5 [00:01<00:00, 22.86it/s]", |
| 235 | + "100%|██████████| 5/5 [00:02<00:00, 22.26it/s]", |
| 236 | + ] |
| 237 | + logs = "" |
| 238 | + |
| 239 | + for i, line in enumerate(lines): |
| 240 | + logs += "\n" + line |
| 241 | + prediction.logs = logs |
| 242 | + |
| 243 | + progress = prediction.progress |
| 244 | + |
| 245 | + if i == 0: |
| 246 | + prediction.status = "processing" |
| 247 | + assert progress is None |
| 248 | + elif i == 1: |
| 249 | + assert progress is not None |
| 250 | + assert progress.current == 0 |
| 251 | + assert progress.total == 5 |
| 252 | + assert progress.percentage == 0.0 |
| 253 | + elif i == 2: |
| 254 | + assert progress is not None |
| 255 | + assert progress.current == 1 |
| 256 | + assert progress.total == 5 |
| 257 | + assert progress.percentage == 0.2 |
| 258 | + elif i == 3: |
| 259 | + assert progress is not None |
| 260 | + assert progress.current == 2 |
| 261 | + assert progress.total == 5 |
| 262 | + assert progress.percentage == 0.4 |
| 263 | + elif i == 4: |
| 264 | + assert progress is not None |
| 265 | + assert progress.current == 3 |
| 266 | + assert progress.total == 5 |
| 267 | + assert progress.percentage == 0.6 |
| 268 | + elif i == 5: |
| 269 | + assert progress is not None |
| 270 | + assert progress.current == 4 |
| 271 | + assert progress.total == 5 |
| 272 | + assert progress.percentage == 0.8 |
| 273 | + elif i == 6: |
| 274 | + assert progress is not None |
| 275 | + prediction.status = "succeeded" |
| 276 | + assert progress.current == 5 |
| 277 | + assert progress.total == 5 |
| 278 | + assert progress.percentage == 1.0 |
0 commit comments