Skip to content

Commit 8f2ec99

Browse files
committed
create event loop before predictor setup (#1366)
* conditionally create the event loop if predictor is async, and add a path for hypothetical async setup * don't use async for predict loop if predict is not async * add test cases for shared loop across setup and predict + asyncio.run in setup (reverts commit b533c6b) Signed-off-by: technillogue <technillogue@gmail.com>
1 parent 3c30e8f commit 8f2ec99

File tree

10 files changed

+182
-80
lines changed

10 files changed

+182
-80
lines changed

python/cog/command/ast_openapi_schema.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,15 @@ def find(obj: ast.AST, name: str) -> ast.AST:
309309
"""Find a particular named node in a tree"""
310310
return next(node for node in ast.walk(obj) if getattr(node, "name", "") == name)
311311

312+
312313
if typing.TYPE_CHECKING:
313-
AstVal: "typing.TypeAlias" = "int | float | complex | str | list[AstVal] | bytes | None"
314+
AstVal: "typing.TypeAlias" = (
315+
"int | float | complex | str | list[AstVal] | bytes | None"
316+
)
314317
AstValNoBytes: "typing.TypeAlias" = "int | float | str | list[AstValNoBytes]"
315-
JSONObject: "typing.TypeAlias" = "int | float | str | list[JSONObject] | JSONDict | None"
318+
JSONObject: "typing.TypeAlias" = (
319+
"int | float | str | list[JSONObject] | JSONDict | None"
320+
)
316321
JSONDict: "typing.TypeAlias" = "dict[str, JSONObject]"
317322

318323

@@ -327,6 +332,7 @@ def to_serializable(val: "AstVal") -> "JSONObject":
327332
else:
328333
return val
329334

335+
330336
def get_value(node: ast.AST) -> "AstVal":
331337
"""Return the value of constant or list of constants"""
332338
if isinstance(node, ast.Constant):
@@ -339,7 +345,7 @@ def get_value(node: ast.AST) -> "AstVal":
339345
if isinstance(node, (ast.List, ast.Tuple)):
340346
return [get_value(e) for e in node.elts]
341347
if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
342-
return -typing.cast(typing.Union[int, float, complex], get_value(node.operand))
348+
return -typing.cast(typing.Union[int, float, complex], get_value(node.operand))
343349
raise ValueError("Unexpected node type", type(node))
344350

345351

@@ -372,6 +378,7 @@ def parse_args(tree: ast.AST) -> "list[tuple[ast.arg, ast.expr | types.EllipsisT
372378
defaults = [...] * (len(args) - len(predict.args.defaults)) + predict.args.defaults
373379
return list(zip(args, defaults))
374380

381+
375382
def parse_assignment(assignment: ast.AST) -> "None | tuple[str, JSONObject]":
376383
"""Parse an assignment into an OpenAPI object property"""
377384
if isinstance(assignment, ast.AnnAssign):
@@ -403,7 +410,9 @@ def parse_class(classdef: ast.AST) -> "JSONDict":
403410
"""Parse a class definition into an OpenAPI object"""
404411
assert isinstance(classdef, ast.ClassDef)
405412
properties = {
406-
assignment[0]: assignment[1] for assignment in map(parse_assignment, classdef.body) if assignment
413+
assignment[0]: assignment[1]
414+
for assignment in map(parse_assignment, classdef.body)
415+
if assignment
407416
}
408417
return {
409418
"title": classdef.name,
@@ -428,15 +437,17 @@ def resolve_name(node: ast.expr) -> str:
428437
return node.id
429438
if isinstance(node, ast.Index):
430439
# deprecated, but needed for py3.8
431-
return resolve_name(node.value) # type: ignore
440+
return resolve_name(node.value) # type: ignore
432441
if isinstance(node, ast.Attribute):
433442
return node.attr
434443
if isinstance(node, ast.Subscript):
435444
return resolve_name(node.value)
436445
raise ValueError("Unexpected node type", type(node), ast.unparse(node))
437446

438447

439-
def parse_return_annotation(tree: ast.AST, fn: str = "predict") -> "tuple[JSONDict, JSONDict]":
448+
def parse_return_annotation(
449+
tree: ast.AST, fn: str = "predict"
450+
) -> "tuple[JSONDict, JSONDict]":
440451
predict = find(tree, fn)
441452
if not isinstance(predict, (ast.FunctionDef, ast.AsyncFunctionDef)):
442453
raise ValueError("Could not find predict function")
@@ -550,7 +561,7 @@ def extract_info(code: str) -> "JSONDict":
550561
**return_schema,
551562
}
552563
# trust me, typechecker, I know BASE_SCHEMA
553-
x: "JSONDict" = schema["components"]["schemas"] # type: ignore
564+
x: "JSONDict" = schema["components"]["schemas"] # type: ignore
554565
x.update(components)
555566
return schema
556567

python/cog/predictor.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import enum
22
import importlib.util
33
import inspect
4-
import io
54
import os.path
65
import sys
76
import types
87
from abc import ABC, abstractmethod
9-
from collections.abc import Iterator, AsyncIterator
8+
from collections.abc import AsyncIterator, Iterator
109
from pathlib import Path
1110
from typing import (
1211
Any,
12+
Awaitable,
1313
Callable,
1414
Dict,
1515
List,
@@ -48,7 +48,9 @@
4848

4949

5050
class BasePredictor(ABC):
51-
def setup(self, weights: Optional[Union[CogFile, CogPath]] = None) -> None:
51+
def setup(
52+
self, weights: Optional[Union[CogFile, CogPath]] = None
53+
) -> Optional[Awaitable[None]]:
5254
"""
5355
An optional method to prepare the model so multiple predictions run efficiently.
5456
"""
@@ -63,15 +65,25 @@ def predict(self, **kwargs: Any) -> Any:
6365

6466

6567
def run_setup(predictor: BasePredictor) -> None:
66-
weights_type = get_weights_type(predictor.setup)
67-
68-
# No weights need to be passed, so just run setup() without any arguments.
69-
if weights_type is None:
68+
weights = get_weights_argument(predictor)
69+
if weights:
70+
predictor.setup(weights=weights)
71+
else:
7072
predictor.setup()
71-
return
7273

73-
weights: Union[io.IOBase, Path, None]
7474

75+
async def run_setup_async(predictor: BasePredictor) -> None:
76+
weights = get_weights_argument(predictor)
77+
maybe_coro = predictor.setup(weights=weights) if weights else predictor.setup()
78+
if maybe_coro:
79+
return await maybe_coro
80+
81+
82+
def get_weights_argument(predictor: BasePredictor) -> Union[CogFile, CogPath, None]:
83+
# by the time we get here we assume predictor has a setup method
84+
weights_type = get_weights_type(predictor.setup)
85+
if weights_type is None:
86+
return None
7587
weights_url = os.environ.get("COG_WEIGHTS")
7688
weights_path = "weights"
7789

@@ -81,30 +93,27 @@ def run_setup(predictor: BasePredictor) -> None:
8193
# TODO: CogFile/CogPath should have subclasses for each of the subtypes
8294
if weights_url:
8395
if weights_type == CogFile:
84-
weights = cast(CogFile, CogFile.validate(weights_url))
85-
elif weights_type == CogPath:
96+
return cast(CogFile, CogFile.validate(weights_url))
97+
if weights_type == CogPath:
8698
# TODO: So this can be a url. evil!
87-
weights = cast(CogPath, CogPath.validate(weights_url))
88-
else:
89-
raise ValueError(
90-
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported"
91-
)
92-
elif os.path.exists(weights_path):
99+
return cast(CogPath, CogPath.validate(weights_url))
100+
raise ValueError(
101+
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported"
102+
)
103+
if os.path.exists(weights_path):
93104
if weights_type == CogFile:
94-
weights = cast(CogFile, open(weights_path, "rb"))
95-
elif weights_type == CogPath:
96-
weights = CogPath(weights_path)
97-
else:
98-
raise ValueError(
99-
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported"
100-
)
101-
else:
102-
weights = None
103-
104-
predictor.setup(weights=weights)
105+
return cast(CogFile, open(weights_path, "rb"))
106+
if weights_type == CogPath:
107+
return CogPath(weights_path)
108+
raise ValueError(
109+
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported"
110+
)
111+
return None
105112

106113

107-
def get_weights_type(setup_function: Callable[[Any], None]) -> Optional[Any]:
114+
def get_weights_type(
115+
setup_function: Callable[[Any], Optional[Awaitable[None]]]
116+
) -> Optional[Any]:
108117
signature = inspect.signature(setup_function)
109118
if "weights" not in signature.parameters:
110119
return None

python/cog/server/http.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,10 @@ async def healthcheck() -> Any:
247247
response_model=PredictionResponse,
248248
response_model_exclude_unset=True,
249249
)
250-
async def predict(request: PredictionRequest = Body(default=None), prefer: Union[str, None] = Header(default=None)) -> Any: # type: ignore
250+
async def predict(
251+
request: PredictionRequest = Body(default=None),
252+
prefer: Union[str, None] = Header(default=None),
253+
) -> Any: # type: ignore
251254
"""
252255
Run a single prediction on the model
253256
"""

python/cog/server/webhook.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def _get_version() -> str:
3737

3838

3939
def webhook_caller_filtered(
40-
webhook: str, webhook_events_filter: Set[WebhookEvent],
40+
webhook: str,
41+
webhook_events_filter: Set[WebhookEvent],
4142
) -> Callable[[Any, WebhookEvent], None]:
4243
upstream_caller = webhook_caller(webhook)
4344

0 commit comments

Comments
 (0)