Skip to content

Commit 8b8bbac

Browse files
committed
don't use async for predict loop if predict is not async
Signed-off-by: technillogue <technillogue@gmail.com>
1 parent dcf5ac4 commit 8b8bbac

File tree

2 files changed

+78
-34
lines changed

2 files changed

+78
-34
lines changed

python/cog/predictor.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
class BasePredictor(ABC):
5252
def setup(
5353
self, weights: Optional[Union[CogFile, CogPath]] = None
54-
) -> Union[Awaitable[None], None]:
54+
) -> Optional[Awaitable[None]]:
5555
"""
5656
An optional method to prepare the model so multiple predictions run efficiently.
5757
"""
@@ -68,8 +68,9 @@ def predict(self, **kwargs: Any) -> Any:
6868
def run_setup(predictor: BasePredictor) -> None:
6969
weights = get_weights_argument(predictor)
7070
if weights:
71-
return predictor.setup()
72-
return predictor.setup(weights=weights)
71+
predictor.setup(weights=weights)
72+
else:
73+
predictor.setup()
7374

7475

7576
async def run_setup_async(predictor: BasePredictor) -> None:
@@ -80,6 +81,7 @@ async def run_setup_async(predictor: BasePredictor) -> None:
8081

8182

8283
def get_weights_argument(predictor: BasePredictor) -> Union[io.IOBase, CogPath, None]:
84+
# by the time we get here we assume predictor has a setup method
8385
weights_type = get_weights_type(predictor.setup)
8486
if weights_type is None:
8587
return None
@@ -110,7 +112,7 @@ def get_weights_argument(predictor: BasePredictor) -> Union[io.IOBase, CogPath,
110112
return None
111113

112114

113-
def get_weights_type(predictor: BasePredictor) -> Optional[Any]:
115+
def get_weights_type(setup_function: Callable) -> Optional[Any]:
114116
signature = inspect.signature(setup_function)
115117
if "weights" not in signature.parameters:
116118
return None

python/cog/server/worker.py

Lines changed: 72 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextlib
23
import inspect
34
import multiprocessing
45
import os
@@ -8,7 +9,7 @@
89
import types
910
from enum import Enum, auto, unique
1011
from multiprocessing.connection import Connection
11-
from typing import Any, Dict, Iterable, Optional, TextIO, Union
12+
from typing import Any, Dict, Iterator, Optional, TextIO, Union
1213

1314
from ..json import make_encodeable
1415
from ..predictor import (
@@ -58,7 +59,7 @@ def __init__(self, predictor_ref: str, tee_output: bool = True) -> None:
5859
self._child = _ChildWorker(predictor_ref, child_events, tee_output)
5960
self._terminating = False
6061

61-
def setup(self) -> Iterable[_PublicEventType]:
62+
def setup(self) -> Iterator[_PublicEventType]:
6263
self._assert_state(WorkerState.NEW)
6364
self._state = WorkerState.STARTING
6465
self._child.start()
@@ -67,7 +68,7 @@ def setup(self) -> Iterable[_PublicEventType]:
6768

6869
def predict(
6970
self, payload: Dict[str, Any], poll: Optional[float] = None
70-
) -> Iterable[_PublicEventType]:
71+
) -> Iterator[_PublicEventType]:
7172
self._assert_state(WorkerState.READY)
7273
self._state = WorkerState.PROCESSING
7374
self._allow_cancel = True
@@ -108,7 +109,7 @@ def _assert_state(self, state: WorkerState) -> None:
108109

109110
def _wait(
110111
self, poll: Optional[float] = None, raise_on_error: Optional[str] = None
111-
) -> Iterable[_PublicEventType]:
112+
) -> Iterator[_PublicEventType]:
112113
done = None
113114

114115
if poll:
@@ -178,15 +179,21 @@ def run(self) -> None:
178179
[ws_stdout, ws_stderr], self._stream_write_hook
179180
)
180181
self._stream_redirector.start()
181-
182182
self._setup()
183-
asyncio.run(self._loop())
183+
self._loop()
184184
self._stream_redirector.shutdown()
185185

186186
def _setup(self) -> None:
187-
done = Done()
188-
try:
187+
with self._handle_setup_error():
188+
# we need to load the predictor to know if setup is async
189189
self._predictor = load_predictor_from_ref(self._predictor_ref)
190+
# if users want to access the same event loop from setup and predict,
191+
# both have to be async. if setup isn't async, it doesn't matter if we
192+
# create the event loop here or after setup
193+
#
194+
# otherwise, if setup is sync and the user does new_event_loop to use a ClientSession,
195+
# then tries to use the same session from async predict, they would get an error.
196+
# that's significant if connections are open and would need to be discarded
190197
if is_async_predictor(self._predictor):
191198
self.loop = get_loop()
192199
# Could be a function or a class
@@ -195,6 +202,12 @@ def _setup(self) -> None:
195202
self.loop.run_until_complete(run_setup_async(self._predictor))
196203
else:
197204
run_setup(self._predictor)
205+
206+
@contextlib.contextmanager
207+
def _handle_setup_error(self) -> Iterator[None]:
208+
done = Done()
209+
try:
210+
yield
198211
except Exception as e:
199212
traceback.print_exc()
200213
done.error = True
@@ -210,50 +223,76 @@ def _setup(self) -> None:
210223
self._stream_redirector.drain()
211224
self._events.send(done)
212225

213-
async def _loop(self) -> None:
226+
def _loop_sync(self) -> None:
214227
while True:
215228
ev = self._events.recv()
216229
if isinstance(ev, Shutdown):
217230
break
218231
if isinstance(ev, PredictionInput):
219-
await self._predict(ev.payload)
232+
self._predict_sync(ev.payload)
220233
else:
221234
print(f"Got unexpected event: {ev}", file=sys.stderr)
222235

223-
async def _predict(self, payload: Dict[str, Any]) -> None:
236+
async def _loop_async(self) -> None:
237+
while True:
238+
ev = self._events.recv()
239+
if isinstance(ev, Shutdown):
240+
break
241+
if isinstance(ev, PredictionInput):
242+
await self._predict_async(ev.payload)
243+
else:
244+
print(f"Got unexpected event: {ev}", file=sys.stderr)
245+
246+
def _loop(self) -> None:
247+
if is_async(get_predict(self._predictor)):
248+
self.loop.run_until_complete(self._loop_async())
249+
else:
250+
self._loop_sync()
251+
252+
@contextlib.contextmanager
253+
def _handle_predict_error(self) -> Iterator[None]:
224254
assert self._predictor
225255
done = Done()
226256
self._cancelable = True
227257
try:
258+
yield
259+
except CancelationException:
260+
done.canceled = True
261+
except Exception as e:
262+
traceback.print_exc()
263+
done.error = True
264+
done.error_detail = str(e)
265+
finally:
266+
self._cancelable = False
267+
self._stream_redirector.drain()
268+
self._events.send(done)
269+
270+
async def _predict_async(self, payload: Dict[str, Any]) -> None:
271+
with self._handle_predict_error():
228272
predict = get_predict(self._predictor)
229273
result = predict(**payload)
230-
231274
if result:
232275
if inspect.isasyncgen(result):
233276
self._events.send(PredictionOutputType(multi=True))
234277
async for r in result:
235278
self._events.send(PredictionOutput(payload=make_encodeable(r)))
236-
elif inspect.isgenerator(result):
237-
self._events.send(PredictionOutputType(multi=True))
238-
for r in result:
239-
self._events.send(PredictionOutput(payload=make_encodeable(r)))
240279
elif inspect.isawaitable(result):
241280
result = await result
242281
self._events.send(PredictionOutputType(multi=False))
243282
self._events.send(PredictionOutput(payload=make_encodeable(result)))
283+
284+
def _predict_sync(self, payload: Dict[str, Any]) -> None:
285+
with self._handle_predict_error():
286+
predict = get_predict(self._predictor)
287+
result = predict(**payload)
288+
if result:
289+
if inspect.isgenerator(result):
290+
self._events.send(PredictionOutputType(multi=True))
291+
for r in result:
292+
self._events.send(PredictionOutput(payload=make_encodeable(r)))
244293
else:
245294
self._events.send(PredictionOutputType(multi=False))
246295
self._events.send(PredictionOutput(payload=make_encodeable(result)))
247-
except CancelationException:
248-
done.canceled = True
249-
except Exception as e:
250-
traceback.print_exc()
251-
done.error = True
252-
done.error_detail = str(e)
253-
finally:
254-
self._cancelable = False
255-
self._stream_redirector.drain()
256-
self._events.send(done)
257296

258297
def _signal_handler(self, signum: int, frame: Optional[types.FrameType]) -> None:
259298
if signum == signal.SIGUSR1 and self._cancelable:
@@ -270,13 +309,16 @@ def _stream_write_hook(
270309

271310
def get_loop() -> asyncio.AbstractEventLoop:
272311
try:
312+
# just in case something else created an event loop already
273313
return asyncio.get_running_loop()
274314
except RuntimeError:
275315
return asyncio.new_event_loop()
276316

277317

318+
def is_async(fn: Any) -> bool:
319+
return inspect.iscoroutinefunction(fn) or inspect.isasyncgenfunction(fn)
320+
321+
278322
def is_async_predictor(predictor: BasePredictor) -> bool:
279-
predict = get_predict(predictor)
280-
if inspect.iscoroutinefunction(predict) or inspect.isasyncgenfunction(predict):
281-
return True
282-
return inspect.iscoroutinefunction(getattr(predictor, "setup", None))
323+
setup = getattr(predictor, "setup", None)
324+
return is_async(setup) or is_async(get_predict(predictor))

0 commit comments

Comments
 (0)