Skip to content

actually allow concurrent predictions and refactor runner #1500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 40 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
0df9b82
run CI for this branch the same way as for main
technillogue Jan 16, 2024
513e837
Revert "Revert PR "async runner" (#1352)"
technillogue Jan 16, 2024
db88489
Revert "Revert PR "create event loop before predictor setup" (#1366)"
technillogue Jan 16, 2024
3444169
lints
technillogue Jan 16, 2024
73a6de9
minimal async worker (#1410)
technillogue Jan 22, 2024
27b288f
format
technillogue Jan 22, 2024
45afedc
select utility for racing awaitables
technillogue Dec 7, 2023
54b34e3
start mux
technillogue Nov 29, 2023
0fbadea
tag events with id, read pipe in a task, get events from mux
technillogue Nov 29, 2023
3239230
use async pipe for async child loop
technillogue Nov 29, 2023
58afbfb
_shutting_down vs _terminating
technillogue Dec 2, 2023
a70b027
select shutdown event
technillogue Dec 7, 2023
74a3dac
keep reading events during shutdown, but call terminate after the las…
technillogue Dec 4, 2023
0ce8bd3
imagine a world with context managers for our highly stateful resources
technillogue Dec 7, 2023
597816e
emit heartbeats from mux.read
technillogue Dec 1, 2023
83b325b
don't use _wait. instead, setup reads event from the mux too
technillogue Dec 1, 2023
c1256ea
worker semaphore and prediction ctx
technillogue Jan 23, 2024
18640dc
where _wait used to raise a fatal error, have _read_events set an err…
technillogue Jan 23, 2024
4469766
fix event loop errors for <3.9
technillogue Dec 5, 2023
7399614
keep track of predictions in flight explicitly and use that to route …
technillogue Dec 5, 2023
de29ee7
don't wait for executor shutdown
technillogue Dec 22, 2023
d4e5c6f
progress: check for cancelation in task done_handler
technillogue Jan 18, 2024
13fec6c
let mux check if child is alive and set mux shutdown after leaving re…
technillogue Jan 19, 2024
2f687b7
close pipe when exiting
technillogue Jan 19, 2024
2f193bf
note a different way to do eager state checks with a decorator
technillogue Jan 23, 2024
4015925
predict requires IDLE or PROCESSING
technillogue Jan 23, 2024
e47868c
idk, try adding a BUSY state distinct from PROCESSING when we no long…
technillogue Jan 23, 2024
b441ea0
move resetting events to setup() instead of _read_events()
technillogue Jan 25, 2024
e939c0b
add concurrency to config
technillogue Jan 25, 2024
eaceff7
this basically works!
technillogue Jan 26, 2024
2c28885
more descriptive names for predict functions
technillogue Jan 26, 2024
d44e467
maybe pass through prediction id and try to make cancelation do both?
technillogue Jan 27, 2024
74da750
don't cancel from signal handler if a loop is running. expose worker …
technillogue Jan 27, 2024
b416c97
move handle_event_stream to PredictionEventHandler
technillogue Jan 27, 2024
07e86b8
make setup and canceling work
technillogue Jan 27, 2024
39ea156
drop some checks around cancelation
technillogue Jan 27, 2024
b145c5d
try out eager_predict_state_change
technillogue Jan 29, 2024
0b1ad74
try to make idempotent endpoint return the same result and fix tests …
technillogue Jan 29, 2024
8480e51
fix idempotent tests
technillogue Jan 29, 2024
9fb111e
fix remaining errors?
technillogue Jan 29, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@ on:
push:
branches:
- main
- async
tags:
- "**"
pull_request:
branches:
- main
- async
merge_group:
branches:
- main
- async
types:
- checks_requested
jobs:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ optional-dependencies = { "dev" = [
"pillow",
"pyright==1.1.347",
"pytest",
"pytest-asyncio",
"pytest-httpserver",
"pytest-rerunfailures",
"pytest-xdist",
Expand Down
29 changes: 20 additions & 9 deletions python/cog/command/ast_openapi_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,15 @@ def find(obj: ast.AST, name: str) -> ast.AST:
"""Find a particular named node in a tree"""
return next(node for node in ast.walk(obj) if getattr(node, "name", "") == name)


if typing.TYPE_CHECKING:
AstVal: "typing.TypeAlias" = "int | float | complex | str | list[AstVal] | bytes | None"
AstVal: "typing.TypeAlias" = (
"int | float | complex | str | list[AstVal] | bytes | None"
)
AstValNoBytes: "typing.TypeAlias" = "int | float | str | list[AstValNoBytes]"
JSONObject: "typing.TypeAlias" = "int | float | str | list[JSONObject] | JSONDict | None"
JSONObject: "typing.TypeAlias" = (
"int | float | str | list[JSONObject] | JSONDict | None"
)
JSONDict: "typing.TypeAlias" = "dict[str, JSONObject]"


Expand All @@ -327,6 +332,7 @@ def to_serializable(val: "AstVal") -> "JSONObject":
else:
return val


def get_value(node: ast.AST) -> "AstVal":
"""Return the value of constant or list of constants"""
if isinstance(node, ast.Constant):
Expand All @@ -339,7 +345,7 @@ def get_value(node: ast.AST) -> "AstVal":
if isinstance(node, (ast.List, ast.Tuple)):
return [get_value(e) for e in node.elts]
if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
return -typing.cast(typing.Union[int, float, complex], get_value(node.operand))
return -typing.cast(typing.Union[int, float, complex], get_value(node.operand))
raise ValueError("Unexpected node type", type(node))


Expand All @@ -366,12 +372,13 @@ def get_call_name(call: ast.Call) -> str:
def parse_args(tree: ast.AST) -> "list[tuple[ast.arg, ast.expr | types.EllipsisType]]":
    """Parse (argument, default) pairs from a file with a predict function.

    Args:
        tree: AST that contains a node named ``predict``.

    Returns:
        One ``(arg, default)`` tuple per positional argument of ``predict``,
        in declaration order. Arguments that have no default are paired with
        ``Ellipsis``.

    Raises:
        AssertionError: if the node named ``predict`` is neither a sync nor
            an async function definition.
    """
    predict = find(tree, "predict")
    # `async def predict` parses to ast.AsyncFunctionDef, which is not a
    # subclass of ast.FunctionDef, so both node types have to be accepted.
    assert isinstance(predict, (ast.FunctionDef, ast.AsyncFunctionDef))
    args = predict.args.args
    # Defaults apply to the trailing arguments, so pad on the left.
    # use Ellipsis instead of None here to distinguish a default of None
    missing = len(args) - len(predict.args.defaults)
    defaults = [...] * missing + predict.args.defaults
    return list(zip(args, defaults))


def parse_assignment(assignment: ast.AST) -> "None | tuple[str, JSONObject]":
"""Parse an assignment into an OpenAPI object property"""
if isinstance(assignment, ast.AnnAssign):
Expand Down Expand Up @@ -403,7 +410,9 @@ def parse_class(classdef: ast.AST) -> "JSONDict":
"""Parse a class definition into an OpenAPI object"""
assert isinstance(classdef, ast.ClassDef)
properties = {
assignment[0]: assignment[1] for assignment in map(parse_assignment, classdef.body) if assignment
assignment[0]: assignment[1]
for assignment in map(parse_assignment, classdef.body)
if assignment
}
return {
"title": classdef.name,
Expand All @@ -428,17 +437,19 @@ def resolve_name(node: ast.expr) -> str:
return node.id
if isinstance(node, ast.Index):
# deprecated, but needed for py3.8
return resolve_name(node.value) # type: ignore
return resolve_name(node.value) # type: ignore
if isinstance(node, ast.Attribute):
return node.attr
if isinstance(node, ast.Subscript):
return resolve_name(node.value)
raise ValueError("Unexpected node type", type(node), ast.unparse(node))


def parse_return_annotation(tree: ast.AST, fn: str = "predict") -> "tuple[JSONDict, JSONDict]":
def parse_return_annotation(
tree: ast.AST, fn: str = "predict"
) -> "tuple[JSONDict, JSONDict]":
predict = find(tree, fn)
if not isinstance(predict, ast.FunctionDef):
if not isinstance(predict, (ast.FunctionDef, ast.AsyncFunctionDef)):
raise ValueError("Could not find predict function")
annotation = predict.returns
if not annotation:
Expand Down Expand Up @@ -550,7 +561,7 @@ def extract_info(code: str) -> "JSONDict":
**return_schema,
}
# trust me, typechecker, I know BASE_SCHEMA
x: "JSONDict" = schema["components"]["schemas"] # type: ignore
x: "JSONDict" = schema["components"]["schemas"] # type: ignore
x.update(components)
return schema

Expand Down
69 changes: 39 additions & 30 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import enum
import importlib.util
import inspect
import io
import os.path
import sys
import types
from abc import ABC, abstractmethod
from collections.abc import Iterator
from collections.abc import AsyncIterator, Iterator
from pathlib import Path
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Expand Down Expand Up @@ -48,7 +48,9 @@


class BasePredictor(ABC):
def setup(self, weights: Optional[Union[CogFile, CogPath]] = None) -> None:
def setup(
self, weights: Optional[Union[CogFile, CogPath]] = None
) -> Optional[Awaitable[None]]:
"""
An optional method to prepare the model so multiple predictions run efficiently.
"""
Expand All @@ -63,15 +65,25 @@ def predict(self, **kwargs: Any) -> Any:


def run_setup(predictor: BasePredictor) -> None:
weights_type = get_weights_type(predictor.setup)

# No weights need to be passed, so just run setup() without any arguments.
if weights_type is None:
weights = get_weights_argument(predictor)
if weights:
predictor.setup(weights=weights)
else:
predictor.setup()
return

weights: Union[io.IOBase, Path, None]

async def run_setup_async(predictor: BasePredictor) -> None:
    """Call ``predictor.setup()`` and await its result when there is one.

    The value from ``get_weights_argument`` is passed as ``weights=`` only
    when it is truthy; otherwise ``setup()`` is called with no arguments.
    A truthy return value from ``setup`` is awaited.
    """
    weights = get_weights_argument(predictor)
    if weights:
        setup_result = predictor.setup(weights=weights)
    else:
        setup_result = predictor.setup()
    if not setup_result:
        return None
    return await setup_result


def get_weights_argument(predictor: BasePredictor) -> Union[CogFile, CogPath, None]:
# by the time we get here we assume predictor has a setup method
weights_type = get_weights_type(predictor.setup)
if weights_type is None:
return None
weights_url = os.environ.get("COG_WEIGHTS")
weights_path = "weights"

Expand All @@ -81,30 +93,27 @@ def run_setup(predictor: BasePredictor) -> None:
# TODO: CogFile/CogPath should have subclasses for each of the subtypes
if weights_url:
if weights_type == CogFile:
weights = cast(CogFile, CogFile.validate(weights_url))
elif weights_type == CogPath:
return cast(CogFile, CogFile.validate(weights_url))
if weights_type == CogPath:
# TODO: So this can be a url. evil!
weights = cast(CogPath, CogPath.validate(weights_url))
else:
raise ValueError(
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported"
)
elif os.path.exists(weights_path):
return cast(CogPath, CogPath.validate(weights_url))
raise ValueError(
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported"
)
if os.path.exists(weights_path):
if weights_type == CogFile:
weights = cast(CogFile, open(weights_path, "rb"))
elif weights_type == CogPath:
weights = CogPath(weights_path)
else:
raise ValueError(
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported"
)
else:
weights = None

predictor.setup(weights=weights)
return cast(CogFile, open(weights_path, "rb"))
if weights_type == CogPath:
return CogPath(weights_path)
raise ValueError(
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported"
)
return None


def get_weights_type(setup_function: Callable[[Any], None]) -> Optional[Any]:
def get_weights_type(
setup_function: Callable[[Any], Optional[Awaitable[None]]]
) -> Optional[Any]:
signature = inspect.signature(setup_function)
if "weights" not in signature.parameters:
return None
Expand Down Expand Up @@ -341,7 +350,7 @@ def predict(
OutputType = signature.return_annotation

# The type that goes in the response is a list of the yielded type
if get_origin(OutputType) is Iterator:
if get_origin(OutputType) in {Iterator, AsyncIterator}:
# Annotated allows us to attach Field annotations to the list, which we use to mark that this is an iterator
# https://pydantic-docs.helpmanual.io/usage/schema/#typingannotated-fields
field = Field(**{"x-cog-array-type": "iterator"}) # type: ignore
Expand Down
10 changes: 9 additions & 1 deletion python/cog/schema.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import secrets
import typing as t
from datetime import datetime
from enum import Enum
Expand Down Expand Up @@ -36,7 +37,12 @@ class PredictionBaseModel(pydantic.BaseModel, extra=pydantic.Extra.allow):


class PredictionRequest(PredictionBaseModel):
id: t.Optional[str]
# FIXME: the idempotent endpoint is supposed to let callers pass the id in
# the route and omit it from the request body. Because of the default
# factory below, an omitted id is filled in with a random one here instead.
# Maybe id should be optional, with no factory, at validation time
# and be filled in later.
id: str = pydantic.Field(default_factory=lambda: secrets.token_hex(4))
created_at: t.Optional[datetime]

# TODO: deprecate this
Expand Down Expand Up @@ -85,8 +91,10 @@ def with_types(cls, input_type: t.Type[t.Any], output_type: t.Type[t.Any]) -> t.
output=(output_type, None),
)


# Subclass of PredictionRequest that declares no fields or behavior of its
# own. Presumably it exists so training endpoints get a distinctly named
# model — confirm against callers.
class TrainingRequest(PredictionRequest):
    pass


# Subclass of PredictionResponse that declares no fields or behavior of its
# own. Presumably it exists so training endpoints get a distinctly named
# model — confirm against callers.
class TrainingResponse(PredictionResponse):
    pass
23 changes: 22 additions & 1 deletion python/cog/server/eventtypes.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
from typing import Any, Dict
import secrets
from typing import Any, Dict, Union

from attrs import define, field, validators

from .. import schema, types


# From worker parent process
#
@define
class PredictionInput:
    """A prediction payload together with the id that tags it.

    ``id`` defaults to a random 8-character hex token when not supplied.
    """

    payload: Dict[str, Any]
    id: str = field(factory=lambda: secrets.token_hex(4))

    @classmethod
    def from_request(cls, request: schema.PredictionRequest) -> "PredictionInput":
        """Build a PredictionInput from a request, reusing the request's id.

        Every input value that is a ``types.URLPath`` is replaced by the
        result of its ``convert()`` method; other values pass through as-is.
        """
        assert request.id, "PredictionRequest must have an id"
        raw_input = request.dict()["input"]
        converted = {
            name: value.convert() if isinstance(value, types.URLPath) else value
            for name, value in raw_input.items()
        }
        return cls(payload=converted, id=request.id)


@define
class Cancel:
    # Identifier carried by this event; presumably the id of the prediction
    # to cancel (compare PredictionInput.id) — confirm against the worker.
    id: str


@define
Expand Down Expand Up @@ -43,3 +61,6 @@ class Done:
@define
class Heartbeat:
pass


PublicEventType = Union[Done, Heartbeat, Log, PredictionOutput, PredictionOutputType]
Loading