1
1
import asyncio
2
2
import io
3
- import threading
4
3
import traceback
5
4
from asyncio import Task
6
5
from datetime import datetime , timezone
7
- from multiprocessing .pool import AsyncResult , ThreadPool
8
6
from typing import Any , Callable , Dict , Optional , Tuple
9
7
10
8
import requests
@@ -41,17 +39,14 @@ def __init__(
41
39
self ,
42
40
* ,
43
41
predictor_ref : str ,
44
- shutdown_event : Optional [threading .Event ],
42
+ shutdown_event : Optional [asyncio .Event ],
45
43
upload_url : Optional [str ] = None ,
46
44
) -> None :
47
- self ._thread = None
48
- self ._threadpool = ThreadPool (processes = 1 )
49
-
50
45
self ._response : Optional [schema .PredictionResponse ] = None
51
46
self ._result : Optional [Task ] = None
52
47
53
48
self ._worker = Worker (predictor_ref = predictor_ref )
54
- self ._should_cancel = threading .Event ()
49
+ self ._should_cancel = asyncio .Event ()
55
50
56
51
self ._shutdown_event = shutdown_event
57
52
self ._upload_url = upload_url
@@ -135,8 +130,7 @@ def is_busy(self) -> bool:
135
130
136
131
def shutdown (self ) -> None :
137
132
self ._worker .terminate ()
138
- self ._threadpool .terminate ()
139
- self ._threadpool .join ()
133
+ # TODO: cancel setup or predict task
140
134
141
135
def cancel (self , prediction_id : Optional [str ] = None ) -> None :
142
136
if not self .is_busy ():
@@ -316,7 +310,7 @@ async def predict(
316
310
worker : Worker ,
317
311
request : schema .PredictionRequest ,
318
312
event_handler : PredictionEventHandler ,
319
- should_cancel : threading .Event ,
313
+ should_cancel : asyncio .Event ,
320
314
) -> schema .PredictionResponse :
321
315
# Set up logger context within prediction thread.
322
316
structlog .contextvars .clear_contextvars ()
@@ -341,7 +335,7 @@ async def _predict(
341
335
worker : Worker ,
342
336
request : schema .PredictionRequest ,
343
337
event_handler : PredictionEventHandler ,
344
- should_cancel : threading .Event ,
338
+ should_cancel : asyncio .Event ,
345
339
) -> schema .PredictionResponse :
346
340
initial_prediction = request .dict ()
347
341
0 commit comments