Skip to content

Commit ff833c7

Browse files
committed
delete remaining runner thread code :)
Signed-off-by: technillogue <technillogue@gmail.com>
1 parent ba06ec2 commit ff833c7

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

python/cog/server/runner.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import asyncio
22
import io
3-
import threading
43
import traceback
54
from asyncio import Task
65
from datetime import datetime, timezone
7-
from multiprocessing.pool import AsyncResult, ThreadPool
86
from typing import Any, Callable, Dict, Optional, Tuple
97

108
import requests
@@ -41,17 +39,14 @@ def __init__(
4139
self,
4240
*,
4341
predictor_ref: str,
44-
shutdown_event: Optional[threading.Event],
42+
shutdown_event: Optional[asyncio.Event],
4543
upload_url: Optional[str] = None,
4644
) -> None:
47-
self._thread = None
48-
self._threadpool = ThreadPool(processes=1)
49-
5045
self._response: Optional[schema.PredictionResponse] = None
5146
self._result: Optional[Task] = None
5247

5348
self._worker = Worker(predictor_ref=predictor_ref)
54-
self._should_cancel = threading.Event()
49+
self._should_cancel = asyncio.Event()
5550

5651
self._shutdown_event = shutdown_event
5752
self._upload_url = upload_url
@@ -135,8 +130,7 @@ def is_busy(self) -> bool:
135130

136131
def shutdown(self) -> None:
137132
self._worker.terminate()
138-
self._threadpool.terminate()
139-
self._threadpool.join()
133+
# TODO: cancel setup or predict task
140134

141135
def cancel(self, prediction_id: Optional[str] = None) -> None:
142136
if not self.is_busy():
@@ -316,7 +310,7 @@ async def predict(
316310
worker: Worker,
317311
request: schema.PredictionRequest,
318312
event_handler: PredictionEventHandler,
319-
should_cancel: threading.Event,
313+
should_cancel: asyncio.Event,
320314
) -> schema.PredictionResponse:
321315
# Set up logger context within prediction thread.
322316
structlog.contextvars.clear_contextvars()
@@ -341,7 +335,7 @@ async def _predict(
341335
worker: Worker,
342336
request: schema.PredictionRequest,
343337
event_handler: PredictionEventHandler,
344-
should_cancel: threading.Event,
338+
should_cancel: asyncio.Event,
345339
) -> schema.PredictionResponse:
346340
initial_prediction = request.dict()
347341

0 commit comments

Comments
 (0)