Skip to content

Commit 976485f

Browse files
technillogue and mattt
authored and committed
replace requests with httpx and factor out clients (#1574, #1707, #1714)
* input downloads, output uploads, and webhooks are now handled by ClientManager, which persists for the lifetime of the runner, allowing us to reuse connections, which may significantly help with large uploads. * although I was originally going to drop output_file_prefix, it's not actually hard to maintain. the behavior is changed now and objects are uploaded as soon as they're outputted rather than after the prediction is completed. * there's an ugly hack with uploading an empty body to get the redirect instead of making the api time out from trying to upload a 140GB file. that can be fixed by implementing an MPU endpoint and/or a "fetch upload url" endpoint. * the behavior of the non-idempotent endpoint is changed; the id is now randomly generated if it's not provided in the body. this isn't strictly required for this change alone, but is hard to carve out. * the behavior of Path is changed significantly. see https://www.notion.so/replicate/Cog-Setup-Path-Problem-2fc41d40bcaf47579ccd8b2f4c71ee24 Co-authored-by: Mattt <mattt@replicate.com> * format * stick a %s on line 190 of clients.py (#1707) * local upload server can be called cluster.local in addition to .internal (#1714) Signed-off-by: technillogue <technillogue@gmail.com>
1 parent 7624116 commit 976485f

26 files changed

+784
-555
lines changed

pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ authors = [{ name = "Replicate", email = "team@replicate.com" }]
1010
license.file = "LICENSE"
1111
urls."Source" = "https://github.com/replicate/cog"
1212

13-
requires-python = ">=3.7"
13+
requires-python = ">=3.8"
1414
dependencies = [
1515
# intentionally loose. perhaps these should be vendored to not collide with user code?
1616
"attrs>=20.1,<24",
1717
"fastapi>=0.75.2,<0.99.0",
18+
# we may not need http2
19+
"httpx[http2]>=0.21.0,<1",
1820
"pydantic>=1.9,<2",
1921
"PyYAML",
2022
"requests>=2,<3",
@@ -27,9 +29,9 @@ dependencies = [
2729
optional-dependencies = { "dev" = [
2830
"black",
2931
"build",
30-
"httpx",
3132
'hypothesis<6.80.0; python_version < "3.8"',
3233
'hypothesis; python_version >= "3.8"',
34+
"respx",
3335
'numpy<1.22.0; python_version < "3.8"',
3436
'numpy; python_version >= "3.8"',
3537
"pillow",

python/cog/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
from pydantic import BaseModel
22

33
from .predictor import BasePredictor
4-
from .types import AsyncConcatenateIterator, ConcatenateIterator, File, Input, Path, Secret
4+
from .types import (
5+
AsyncConcatenateIterator,
6+
ConcatenateIterator,
7+
File,
8+
Input,
9+
Path,
10+
Secret,
11+
)
512

613
try:
714
from ._version import __version__

python/cog/json.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
1-
import io
21
from datetime import datetime
32
from enum import Enum
43
from types import GeneratorType
5-
from typing import Any, Callable
4+
from typing import Any
65

76
from pydantic import BaseModel
87

9-
from .types import Path
10-
118

129
def make_encodeable(obj: Any) -> Any:
1310
"""
@@ -39,24 +36,3 @@ def make_encodeable(obj: Any) -> Any:
3936
if isinstance(obj, np.ndarray):
4037
return obj.tolist()
4138
return obj
42-
43-
44-
def upload_files(obj: Any, upload_file: Callable[[io.IOBase], str]) -> Any:
45-
"""
46-
Iterates through an object from make_encodeable and uploads any files.
47-
48-
When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files.
49-
"""
50-
# skip four isinstance checks for fast text models
51-
if type(obj) == str: # noqa: E721
52-
return obj
53-
if isinstance(obj, dict):
54-
return {key: upload_files(value, upload_file) for key, value in obj.items()}
55-
if isinstance(obj, list):
56-
return [upload_files(value, upload_file) for value in obj]
57-
if isinstance(obj, Path):
58-
with obj.open("rb") as f:
59-
return upload_file(f)
60-
if isinstance(obj, io.IOBase):
61-
return upload_file(obj)
62-
return obj

python/cog/logging.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,5 @@ def setup_logging(*, log_level: int = logging.NOTSET) -> None:
8686

8787
# Reconfigure log levels for some overly chatty libraries
8888
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
89+
# FIXME: no more urllib3(?)
8990
logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)

python/cog/predictor.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,34 +22,24 @@
2222
)
2323
from unittest.mock import patch
2424

25-
import structlog
26-
27-
import cog.code_xforms as code_xforms
28-
2925
try:
3026
from typing import get_args, get_origin
3127
except ImportError: # Python < 3.8
3228
from typing_compat import get_args, get_origin # type: ignore
3329

30+
import structlog
3431
import yaml
3532
from pydantic import BaseModel, Field, create_model
3633
from pydantic.fields import FieldInfo
3734

3835
# Added in Python 3.9. Can be from typing if we drop support for <3.9
3936
from typing_extensions import Annotated
4037

38+
from . import code_xforms
4139
from .errors import ConfigDoesNotExist, PredictorNotSet
42-
from .types import (
43-
CogConfig,
44-
Input,
45-
URLPath,
46-
)
47-
from .types import (
48-
File as CogFile,
49-
)
50-
from .types import (
51-
Path as CogPath,
52-
)
40+
from .types import CogConfig, Input, URLTempFile
41+
from .types import File as CogFile
42+
from .types import Path as CogPath
5343
from .types import Secret as CogSecret
5444

5545
log = structlog.get_logger("cog.server.predictor")
@@ -97,27 +87,33 @@ async def run_setup_async(predictor: BasePredictor) -> None:
9787
return await maybe_coro
9888

9989

100-
def get_weights_argument(predictor: BasePredictor) -> Union[CogFile, CogPath, None]:
90+
def get_weights_argument(
91+
predictor: BasePredictor,
92+
) -> Union[CogFile, CogPath, str, None]:
10193
# by the time we get here we assume predictor has a setup method
10294
weights_type = get_weights_type(predictor.setup)
10395
if weights_type is None:
10496
return None
10597
weights_url = os.environ.get("COG_WEIGHTS")
106-
weights_path = "weights"
98+
weights_path = "weights" # this is the source of a bug isn't it?
10799

108100
# TODO: Cog{File,Path}.validate(...) methods accept either "real"
109101
# paths/files or URLs to those things. In future we can probably tidy this
110102
# up a little bit.
111103
# TODO: CogFile/CogPath should have subclasses for each of the subtypes
104+
105+
# this is a breaking change
106+
# previously, CogPath wouldn't be converted in setup(); now it is
107+
# essentially everyone needs to switch from Path to str (or a new URL type)
112108
if weights_url:
113109
if weights_type == CogFile:
114110
return cast(CogFile, CogFile.validate(weights_url))
115111
if weights_type == CogPath:
116112
# TODO: So this can be a url. evil!
117-
weights = cast(CogPath, CogPath.validate(weights_url))
113+
return cast(CogPath, CogPath.validate(weights_url))
118114
# allow people to download weights themselves
119115
elif weights_type == str: # noqa: E721
120-
weights = weights_url
116+
return weights_url
121117
else:
122118
raise ValueError(
123119
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File, Path and str are supported"
@@ -128,13 +124,13 @@ def get_weights_argument(predictor: BasePredictor) -> Union[CogFile, CogPath, No
128124
if weights_type == CogPath:
129125
return CogPath(weights_path)
130126
raise ValueError(
131-
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported"
127+
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File, Path and str are supported"
132128
)
133129
return None
134130

135131

136132
def get_weights_type(
137-
setup_function: Callable[[Any], Optional[Awaitable[None]]]
133+
setup_function: Callable[[Any], Optional[Awaitable[None]]],
138134
) -> Optional[Any]:
139135
signature = inspect.signature(setup_function)
140136
if "weights" not in signature.parameters:
@@ -276,12 +272,19 @@ def cleanup(self) -> None:
276272
Cleanup any temporary files created by the input.
277273
"""
278274
for _, value in self:
279-
# Handle URLPath objects specially for cleanup.
275+
# Handle URLTempFile objects specially for cleanup.
280276
# Also handle pathlib.Path objects, which cog.Path is a subclass of.
281277
# A pathlib.Path object shouldn't make its way here,
282278
# but both have an unlink() method, so we may as well be safe.
283-
if isinstance(value, (URLPath, Path)):
284-
value.unlink(missing_ok=True)
279+
if isinstance(value, (URLTempFile, Path)):
280+
try:
281+
value.unlink(missing_ok=True)
282+
except FileNotFoundError:
283+
pass
284+
285+
# if we had a separate method to traverse the input and apply some function to each value
286+
# we could have cleanup/get_tempfile/convert functions that operate on a single value
287+
# and do it that way. convert is supposed to mutate though, so it's tricky
285288

286289

287290
def validate_input_type(type: Type[Any], name: str) -> None:

python/cog/schema.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import importlib.util
22
import os
33
import os.path
4+
import secrets
45
import sys
56
import typing as t
67
from datetime import datetime
@@ -43,7 +44,14 @@ class PredictionBaseModel(pydantic.BaseModel, extra=pydantic.Extra.allow):
4344

4445

4546
class PredictionRequest(PredictionBaseModel):
46-
id: t.Optional[str]
47+
# there's a problem here where the idempotent endpoint is supposed to
48+
# let you pass id in the route and omit it from the input
49+
# however this fills in the default
50+
# maybe it should be allowed to be optional without the factory initially
51+
# and be filled in later
52+
#
53+
# actually, this changes the public api so we should really do this differently
54+
id: str = pydantic.Field(default_factory=lambda: secrets.token_hex(4))
4755
created_at: t.Optional[datetime]
4856

4957
# TODO: deprecate this

0 commit comments

Comments
 (0)