Skip to content

Commit b864aa6

Browse files
committed
Rework concurrency
1 parent 033bb1e commit b864aa6

File tree

9 files changed

+368
-140
lines changed

9 files changed

+368
-140
lines changed

internal/server/runner.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ func NewRunner(ipcUrl, uploadUrl string) *Runner {
118118
return newRunner(DefaultRunner, ipcUrl, uploadUrl)
119119
}
120120

121-
func NewProcedureRunner(ipcUrl, uploadUrl, srcURL, srcDir string) *Runner {
121+
func NewProcedureRunner(ipcUrl, uploadUrl, name, srcDir string) *Runner {
122122
// Use srcDir as name
123-
r := newRunner(srcURL, ipcUrl, uploadUrl)
123+
r := newRunner(name, ipcUrl, uploadUrl)
124124
r.cmd.Dir = srcDir
125125
return r
126126
}
@@ -177,6 +177,15 @@ func (r *Runner) SrcDir() string {
177177
return r.cmd.Dir
178178
}
179179

180+
func (r *Runner) Concurrency() Concurrency {
181+
r.mu.Lock()
182+
defer r.mu.Unlock()
183+
return Concurrency{
184+
Max: r.maxConcurrency,
185+
Current: len(r.pending),
186+
}
187+
}
188+
180189
func (r *Runner) Idle() bool {
181190
// IPC from Python runner is the source of truth for Runner.status where
182191
// * Ready: pending predictions < max concurrency
@@ -402,7 +411,6 @@ func (r *Runner) HandleIPC(s IPCStatus) {
402411
r.updateSetupResult()
403412
if _, err := os.Stat(path.Join(r.workingDir, "async_predict")); err == nil {
404413
r.asyncPredict = true
405-
406414
} else if errors.Is(err, os.ErrNotExist) && r.maxConcurrency > 1 {
407415
log.Warnw("max concurrency > 1 for blocking predict, reset to 1", "max_concurrency", r.maxConcurrency)
408416
r.maxConcurrency = 1

internal/server/server.go

Lines changed: 114 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ import (
77
"encoding/json"
88
"errors"
99
"fmt"
10+
"os"
1011
"runtime"
12+
"strconv"
1113

1214
"github.com/replicate/go/must"
1315

@@ -63,10 +65,19 @@ func NewHandler(cfg Config, shutdown context.CancelFunc) (*Handler, error) {
6365
// Reset Go server to 1 to make room for Python runners
6466
autoMaxProcs := runtime.GOMAXPROCS(1)
6567
if cfg.UseProcedureMode {
66-
// At least 2 Python runners in procedure mode so that:
67-
// * Server status is READY if available runner slot >= 1, either empty or IDLE
68-
// * The IDLE runner can be evicted for one with a new procedure source URL
69-
h.maxRunners = max(autoMaxProcs, 2)
68+
concurrencyPerCPU := 4
69+
if s, ok := os.LookupEnv("COG_PROCEDURE_CONCURRENCY_PER_CPU"); ok {
70+
if i, err := strconv.Atoi(s); err == nil {
71+
concurrencyPerCPU = i
72+
} else {
73+
log.Errorw("failed to parse COG_PROCEDURE_CONCURRENCY_PER_CPU", "value", s)
74+
}
75+
}
76+
// Set both max runners and max concurrency across all runners to CPU * n,
77+
// regardless what max concurrency each runner has.
78+
// In the worst case scenario where all runners are non-async,
79+
// completion of any runner frees up concurrency.
80+
h.maxRunners = autoMaxProcs * concurrencyPerCPU
7081
log.Infow("running in procedure mode", "max_runners", h.maxRunners)
7182
} else {
7283
h.runners[DefaultRunner] = NewRunner(cfg.IPCUrl, cfg.UploadUrl)
@@ -103,6 +114,17 @@ func (h *Handler) Root(w http.ResponseWriter, r *http.Request) {
103114
}
104115

105116
func (h *Handler) HealthCheck(w http.ResponseWriter, r *http.Request) {
117+
if bs, err := json.Marshal(h.healthCheck()); err != nil {
118+
http.Error(w, err.Error(), http.StatusBadRequest)
119+
} else {
120+
w.WriteHeader(http.StatusOK)
121+
writeBytes(w, bs)
122+
}
123+
}
124+
125+
func (h *Handler) healthCheck() *HealthCheck {
126+
// FIXME: remove ready/busy IPC
127+
// Use Go runner as source of truth for readiness and concurrency
106128
log := logger.Sugar()
107129
var hc HealthCheck
108130
if h.cfg.UseProcedureMode {
@@ -112,10 +134,13 @@ func (h *Handler) HealthCheck(w http.ResponseWriter, r *http.Request) {
112134
CompletedAt: util.FormatTime(h.startedAt),
113135
Status: SetupSucceeded,
114136
},
137+
Concurrency: Concurrency{
138+
// Max runners as max concurrency
139+
Max: h.maxRunners,
140+
},
115141
}
116142
h.mu.Lock()
117143
defer h.mu.Unlock()
118-
hasIdle := false
119144
toRemove := make([]string, 0)
120145
for name, runner := range h.runners {
121146
if runner.status == StatusDefunct || runner.status == StatusSetupFailed {
@@ -128,34 +153,25 @@ func (h *Handler) HealthCheck(w http.ResponseWriter, r *http.Request) {
128153
}()
129154
continue
130155
}
131-
if runner.Idle() {
132-
hasIdle = true
133-
}
156+
// Aggregate current concurrency across workers
157+
hc.Concurrency.Current += runner.Concurrency().Current
134158
}
135-
// In procedure mode, a server is only READY if available runner slot >= 1, either empty or IDLE.
136-
// In the case of a request with a new procedure source URL, the IDLE runner can be evicted.
137-
// Otherwise, we report BUSY even if all runners are READY but not IDLE, e.g. len(pending) > 0.
138159
for _, name := range toRemove {
139160
delete(h.runners, name)
140161
}
141-
if len(h.runners) < h.maxRunners || hasIdle {
162+
if hc.Concurrency.Current < hc.Concurrency.Max {
142163
hc.Status = StatusReady.String()
143164
} else {
144165
hc.Status = StatusBusy.String()
145166
}
146167
} else {
147168
hc = HealthCheck{
148-
Status: h.runners[DefaultRunner].status.String(),
149-
Setup: &h.runners[DefaultRunner].setupResult,
169+
Status: h.runners[DefaultRunner].status.String(),
170+
Setup: &h.runners[DefaultRunner].setupResult,
171+
Concurrency: h.runners[DefaultRunner].Concurrency(),
150172
}
151173
}
152-
153-
if bs, err := json.Marshal(hc); err != nil {
154-
http.Error(w, err.Error(), http.StatusBadRequest)
155-
} else {
156-
w.WriteHeader(http.StatusOK)
157-
writeBytes(w, bs)
158-
}
174+
return &hc
159175
}
160176

161177
func (h *Handler) OpenApi(w http.ResponseWriter, r *http.Request) {
@@ -196,7 +212,7 @@ func (h *Handler) Stop() error {
196212
eg := errgroup.Group{}
197213
for name, runner := range h.runners {
198214
if err = runner.Stop(); err != nil {
199-
log.Errorw("failed to stop runner", "name", name, "err", err)
215+
log.Errorw("failed to stop runner", "name", name, "error", err)
200216
}
201217
eg.Go(func() error {
202218
runner.WaitForStop()
@@ -235,16 +251,22 @@ func (h *Handler) HandleIPC(w http.ResponseWriter, r *http.Request) {
235251
}
236252
}
237253

238-
func (h *Handler) getRunner(srcURL, srcDir string) (*Runner, error) {
254+
func (h *Handler) predictWithRunner(srcURL string, req PredictionRequest) (chan PredictionResponse, error) {
239255
log := logger.Sugar()
240256

241257
// Lock before checking to avoid thrashing runner replacements
242258
h.mu.Lock()
243259
defer h.mu.Unlock()
244260

245-
// Reuse current runner, nothing to do
246-
if runner, ok := h.runners[srcURL]; ok {
247-
return runner, nil
261+
// Look for an existing runner copy for source URL in READY state
262+
// There might be multiple copies if the # pending predictions > max concurrency of a single runner
263+
// For non-async predictors, the same runner might occupy all runner slots
264+
for i := 0; i <= h.maxRunners; i++ {
265+
name := fmt.Sprintf("%02d:%s", i, srcURL)
266+
runner, ok := h.runners[name]
267+
if ok && runner.Concurrency().Current < runner.Concurrency().Max {
268+
return runner.Predict(req)
269+
}
248270
}
249271

250272
// Need to evict one
@@ -253,7 +275,7 @@ func (h *Handler) getRunner(srcURL, srcDir string) (*Runner, error) {
253275
if !runner.Idle() {
254276
continue
255277
}
256-
log.Infow("stopping procedure runner", "src_url", name)
278+
log.Infow("stopping procedure runner", "name", name)
257279
if err := runner.Stop(); err != nil {
258280
log.Errorw("failed to stop runner", "error", err)
259281
} else {
@@ -262,14 +284,37 @@ func (h *Handler) getRunner(srcURL, srcDir string) (*Runner, error) {
262284
}
263285
}
264286
}
287+
// Failed to evict one, this should not happen
265288
if len(h.runners) == h.maxRunners {
289+
log.Errorw("failed to find idle runner to evict", "src_url", srcURL)
290+
return nil, ErrConflict
291+
}
292+
293+
// Find the first available slot for the new runner copy
294+
var name string
295+
var slot int
296+
for i := 0; i <= h.maxRunners; i++ {
297+
n := fmt.Sprintf("%02d:%s", i, srcURL)
298+
if _, ok := h.runners[n]; !ok {
299+
name = n
300+
slot = i
301+
break
302+
}
303+
}
304+
// Max out slots, this should not happen
305+
if name == "" {
306+
log.Errorw("reached max copies of runner", "src_url", srcURL)
266307
return nil, ErrConflict
267308
}
268309

269310
// Start new runner
270-
log.Infow("starting procedure runner", "src_url", srcURL)
271-
r := NewProcedureRunner(h.cfg.IPCUrl, h.cfg.UploadUrl, srcURL, srcDir)
272-
h.runners[srcURL] = r
311+
srcDir, err := util.PrepareProcedureSourceURL(srcURL, slot)
312+
if err != nil {
313+
return nil, err
314+
}
315+
log.Infow("starting procedure runner", "src_url", srcURL, "src_dir", srcDir)
316+
r := NewProcedureRunner(h.cfg.IPCUrl, h.cfg.UploadUrl, name, srcDir)
317+
h.runners[name] = r
273318

274319
if err := r.Start(); err != nil {
275320
return nil, err
@@ -282,12 +327,36 @@ func (h *Handler) getRunner(srcURL, srcDir string) (*Runner, error) {
282327
}
283328
if r.status == StatusSetupFailed {
284329
log.Errorw("procedure runner setup failed", "logs", r.setupResult.Logs)
285-
delete(h.runners, srcURL)
286-
// Include failed runner here so that the caller can extract setup logs and respond with a prediction failure
287-
return r, ErrSetupFailed
330+
delete(h.runners, name)
331+
332+
// Translate setup failure to prediction failure
333+
resp := PredictionResponse{
334+
Input: req.Input,
335+
Id: req.Id,
336+
CreatedAt: r.setupResult.StartedAt,
337+
StartedAt: r.setupResult.StartedAt,
338+
CompletedAt: r.setupResult.CompletedAt,
339+
Logs: r.setupResult.Logs,
340+
Status: PredictionFailed,
341+
Error: ErrSetupFailed.Error(),
342+
}
343+
if req.Webhook == "" {
344+
c := make(chan PredictionResponse, 1)
345+
c <- resp
346+
return c, nil
347+
} else {
348+
// Async prediction, send webhook
349+
go func() {
350+
if err := SendWebhook(req.Webhook, &resp); err != nil {
351+
log.Errorw("failed to send webhook", "url", "error", err)
352+
}
353+
}()
354+
return nil, nil
355+
}
356+
288357
}
289358
if time.Since(start) > 10*time.Second {
290-
delete(h.runners, srcURL)
359+
delete(h.runners, name)
291360
log.Errorw("stopping procedure runner after time out", "elapsed", time.Since(start))
292361
if err := r.Stop(); err != nil {
293362
log.Errorw("failed to stop procedure runner", "error", err)
@@ -296,11 +365,10 @@ func (h *Handler) getRunner(srcURL, srcDir string) (*Runner, error) {
296365
}
297366
time.Sleep(10 * time.Millisecond)
298367
}
299-
return r, nil
368+
return r.Predict(req)
300369
}
301370

302371
func (h *Handler) Predict(w http.ResponseWriter, r *http.Request) {
303-
log := logger.Sugar()
304372
if r.Header.Get("Content-Type") != "application/json" {
305373
http.Error(w, "invalid content type", http.StatusUnsupportedMediaType)
306374
return
@@ -330,8 +398,15 @@ func (h *Handler) Predict(w http.ResponseWriter, r *http.Request) {
330398
req.Id = util.PredictionId()
331399
}
332400

333-
var runner *Runner
401+
var c chan PredictionResponse
334402
if h.cfg.UseProcedureMode {
403+
// Although individual runners may have higher concurrency than the global max runners/concurrency
404+
// We still bail early if the global max has been reached
405+
concurrency := h.healthCheck().Concurrency
406+
if concurrency.Current == concurrency.Max {
407+
http.Error(w, ErrConflict.Error(), http.StatusConflict)
408+
return
409+
}
335410
val, ok := req.Context["procedure_source_url"]
336411
if !ok {
337412
http.Error(w, "missing procedure_source_url in context", http.StatusBadRequest)
@@ -350,47 +425,11 @@ func (h *Handler) Predict(w http.ResponseWriter, r *http.Request) {
350425
http.Error(w, "empty procedure_source_url or replicate_api_token", http.StatusBadRequest)
351426
return
352427
}
353-
srcDir, err := util.PrepareProcedureSourceURL(procedureSourceUrl)
354-
if err != nil {
355-
http.Error(w, "invalid procedure_source_url", http.StatusBadRequest)
356-
}
357-
if r, err := h.getRunner(procedureSourceUrl, srcDir); err == nil {
358-
runner = r
359-
} else if errors.Is(err, ErrConflict) {
360-
http.Error(w, err.Error(), http.StatusConflict)
361-
return
362-
} else if errors.Is(err, ErrSetupFailed) {
363-
// Translate setup failure to prediction failure
364-
resp := PredictionResponse{
365-
Input: req.Input,
366-
Id: req.Id,
367-
CreatedAt: r.setupResult.StartedAt,
368-
StartedAt: r.setupResult.StartedAt,
369-
CompletedAt: r.setupResult.CompletedAt,
370-
Logs: r.setupResult.Logs,
371-
Status: PredictionFailed,
372-
}
373-
374-
if req.Webhook == "" {
375-
w.WriteHeader(http.StatusOK)
376-
writeResponse(w, resp)
377-
} else {
378-
w.WriteHeader(http.StatusAccepted)
379-
writeResponse(w, PredictionResponse{Id: req.Id, Status: "starting"})
380-
if err := SendWebhook(req.Webhook, &resp); err != nil {
381-
log.Errorw("failed to send webhook", "url", "error", err)
382-
}
383-
}
384-
return
385-
} else {
386-
http.Error(w, err.Error(), http.StatusInternalServerError)
387-
return
388-
}
428+
c, err = h.predictWithRunner(procedureSourceUrl, req)
389429
} else {
390-
runner = h.runners[DefaultRunner]
430+
c, err = h.runners[DefaultRunner].Predict(req)
391431
}
392432

393-
c, err := runner.Predict(req)
394433
if errors.Is(err, ErrConflict) {
395434
http.Error(w, err.Error(), http.StatusConflict)
396435
return

internal/server/types.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,15 @@ const (
7979
WebhookCompleted WebhookEvent = "completed"
8080
)
8181

82+
type Concurrency struct {
83+
Max int `json:"max,omitempty"`
84+
Current int `json:"current,omitempty"`
85+
}
86+
8287
type HealthCheck struct {
83-
Status string `json:"status"`
84-
Setup *SetupResult `json:"setup,omitempty"`
88+
Status string `json:"status"`
89+
Setup *SetupResult `json:"setup,omitempty"`
90+
Concurrency Concurrency `json:"concurrency,omitempty"`
8591
}
8692

8793
type SetupResult struct {

internal/tests/async_prediction_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,19 @@ func TestAsyncPredictionConcurrency(t *testing.T) {
133133
hc := ct.WaitForSetup()
134134
assert.Equal(t, server.StatusReady.String(), hc.Status)
135135
assert.Equal(t, server.SetupSucceeded, hc.Setup.Status)
136+
if !*legacyCog {
137+
// Compat: not implemented in legacy Cog
138+
assert.Equal(t, 1, hc.Concurrency.Max)
139+
assert.Equal(t, 0, hc.Concurrency.Current)
140+
}
136141

137142
ct.AsyncPrediction(map[string]any{"i": 1, "s": "bar"})
143+
if !*legacyCog {
144+
// Compat: not implemented in legacy Cog
145+
hc = ct.HealthCheck()
146+
assert.Equal(t, 1, hc.Concurrency.Max)
147+
assert.Equal(t, 1, hc.Concurrency.Current)
148+
}
138149

139150
// Fail prediction requests when one is in progress
140151
req := server.PredictionRequest{

internal/tests/cog_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func NewCogTest(t *testing.T, module string) *CogTest {
121121
}
122122

123123
func NewCogProcedureTest(t *testing.T) *CogTest {
124-
t.Parallel()
124+
// No parallel procedure test since they use the same temp source directory
125125
return &CogTest{
126126
t: t,
127127
procedure: true,

0 commit comments

Comments
 (0)