Skip to content

Commit 78919cc

Browse files
committed
Test async predictor with legacy cog
1 parent 4f7f4b6 commit 78919cc

File tree

6 files changed

+28
-21
lines changed

6 files changed

+28
-21
lines changed

internal/server/runner.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,13 @@ func (r *Runner) predict(req PredictionRequest) (chan PredictionResponse, error)
177177
req.CreatedAt = util.NowIso()
178178
}
179179
r.mu.Lock()
180-
if !r.asyncPredict && r.maxConcurrency > 0 && len(r.pending) > r.maxConcurrency {
180+
// blocking `def predict()`, max concurrency is always 1
181+
maxPending := 1
182+
// `async def predict()`, respect concurrency.max from cog.yaml
183+
if r.asyncPredict {
184+
maxPending = r.maxConcurrency
185+
}
186+
if !r.asyncPredict && len(r.pending) >= maxPending {
181187
r.mu.Unlock()
182188
log.Errorw("prediction rejected: Already running a prediction")
183189
return nil, ErrConflict

internal/tests/async_predictor_test.go

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,13 @@ package tests
33
import (
44
"strings"
55
"testing"
6-
"time"
76

87
"github.com/replicate/cog-runtime/internal/server"
98

109
"github.com/stretchr/testify/assert"
1110
)
1211

1312
func TestAsyncPredictorConcurrency(t *testing.T) {
14-
if *legacyCog {
15-
// Compat: legacy Cog rejects concurrent prediction requests
16-
t.SkipNow()
17-
}
1813
ct := NewCogTest(t, "async_sleep")
1914
ct.StartWebhook()
2015
assert.NoError(t, ct.Start())
@@ -23,8 +18,8 @@ func TestAsyncPredictorConcurrency(t *testing.T) {
2318
assert.Equal(t, server.StatusReady.String(), hc.Status)
2419
assert.Equal(t, server.SetupSucceeded, hc.Setup.Status)
2520

26-
barId := ct.AsyncPrediction(map[string]any{"i": 1, "s": "bar"})
27-
bazId := ct.AsyncPrediction(map[string]any{"i": 2, "s": "baz"})
21+
barId := ct.AsyncPredictionWithId("p01", map[string]any{"i": 1, "s": "bar"})
22+
bazId := ct.AsyncPredictionWithId("p02", map[string]any{"i": 2, "s": "baz"})
2823
wr := ct.WaitForWebhookCompletion()
2924
var barR []server.PredictionResponse
3025
var bazR []server.PredictionResponse
@@ -45,6 +40,11 @@ func TestAsyncPredictorConcurrency(t *testing.T) {
4540
}
4641

4742
func TestAsyncPredictorCanceled(t *testing.T) {
43+
if *legacyCog {
44+
// Cancellation bug as of 0.14.1
45+
// https://github.com/replicate/cog/issues/2212
46+
t.SkipNow()
47+
}
4848
ct := NewCogTest(t, "async_sleep")
4949
ct.StartWebhook()
5050
assert.NoError(t, ct.Start())
@@ -55,14 +55,9 @@ func TestAsyncPredictorCanceled(t *testing.T) {
5555

5656
pid := "p01"
5757
ct.AsyncPredictionWithId(pid, map[string]any{"i": 60, "s": "bar"})
58-
if *legacyCog {
59-
// Compat: legacy Cog does not send output webhook
60-
time.Sleep(time.Second)
61-
} else {
62-
ct.WaitForWebhook(func(response server.PredictionResponse) bool {
63-
return strings.Contains(response.Logs, "prediction in progress 1/60\n")
64-
})
65-
}
58+
ct.WaitForWebhook(func(response server.PredictionResponse) bool {
59+
return strings.Contains(response.Logs, "prediction in progress 1/60\n")
60+
})
6661
ct.Cancel(pid)
6762
wr := ct.WaitForWebhookCompletion()
6863
logs := "starting async prediction\nprediction in progress 1/60\nprediction canceled\n"

internal/tests/cog_test.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,12 @@ func (ct *CogTest) legacyCmd() *exec.Cmd {
185185
tmpDir := ct.t.TempDir()
186186
runnersPath := path.Join(basePath, "python", "tests", "runners")
187187
module := fmt.Sprintf("%s.py", ct.module)
188-
must.Do(os.Symlink(path.Join(runnersPath, "cog.yaml"), path.Join(tmpDir, "cog.yaml")))
188+
yaml := "cog.yaml"
189+
if strings.HasPrefix(ct.module, "async_") {
190+
// cog.yaml with concurrency.max
191+
yaml = "async_cog.yaml"
192+
}
193+
must.Do(os.Symlink(path.Join(runnersPath, yaml), path.Join(tmpDir, "cog.yaml")))
189194
must.Do(os.Symlink(path.Join(runnersPath, module), path.Join(tmpDir, "predict.py")))
190195
pythonBin := path.Join(basePath, ".venv-legacy", "bin", "python3")
191196
ct.serverPort = portFinder.Get()

python/tests/runners/async_cog.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
predict: "predict.py:Predictor"
2+
concurrency:
3+
max: 2

python/tests/runners/cog.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1 @@
1-
build:
2-
python_version: "3.8"
31
predict: "predict.py:Predictor"

script/init.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ cd "$base_dir"
1111
uv sync --all-extras
1212

1313
# venv with legacy Cog
14-
uv venv --python "$python_version" .venv-legacy
14+
# python >= 3.11 for async tests
15+
uv venv --python 3.13 .venv-legacy
1516
export VIRTUAL_ENV="$base_dir/.venv-legacy"
1617
export UV_PROJECT_ENVIRONMENT="$VIRTUAL_ENV"
17-
uv sync --all-extras
1818
uv pip install cog==0.14.1

0 commit comments

Comments
 (0)