77 "encoding/json"
88 "errors"
99 "fmt"
10+ "os"
1011 "runtime"
12+ "strconv"
1113
1214 "github.com/replicate/go/must"
1315
@@ -63,10 +65,19 @@ func NewHandler(cfg Config, shutdown context.CancelFunc) (*Handler, error) {
6365 // Reset Go server to 1 to make room for Python runners
6466 autoMaxProcs := runtime .GOMAXPROCS (1 )
6567 if cfg .UseProcedureMode {
66- // At least 2 Python runners in procedure mode so that:
67- // * Server status is READY if available runner slot >= 1, either empty or IDLE
68- // * The IDLE runner can be evicted for one with a new procedure source URL
69- h .maxRunners = max (autoMaxProcs , 2 )
68+ concurrencyPerCPU := 4
69+ if s , ok := os .LookupEnv ("COG_PROCEDURE_CONCURRENCY_PER_CPU" ); ok {
70+ if i , err := strconv .Atoi (s ); err == nil {
71+ concurrencyPerCPU = i
72+ } else {
73+ log .Errorw ("failed to parse COG_PROCEDURE_CONCURRENCY_PER_CPU" , "value" , s )
74+ }
75+ }
76+ // Set both max runners and max concurrency across all runners to CPU * n,
77+ // regardless what max concurrency each runner has.
78+ // In the worst case scenario where all runners are non-async,
79+ // completion of any runner frees up concurrency.
80+ h .maxRunners = autoMaxProcs * concurrencyPerCPU
7081 log .Infow ("running in procedure mode" , "max_runners" , h .maxRunners )
7182 } else {
7283 h .runners [DefaultRunner ] = NewRunner (cfg .IPCUrl , cfg .UploadUrl )
@@ -103,6 +114,17 @@ func (h *Handler) Root(w http.ResponseWriter, r *http.Request) {
103114}
104115
105116func (h * Handler ) HealthCheck (w http.ResponseWriter , r * http.Request ) {
117+ if bs , err := json .Marshal (h .healthCheck ()); err != nil {
118+ http .Error (w , err .Error (), http .StatusBadRequest )
119+ } else {
120+ w .WriteHeader (http .StatusOK )
121+ writeBytes (w , bs )
122+ }
123+ }
124+
125+ func (h * Handler ) healthCheck () * HealthCheck {
126+ // FIXME: remove ready/busy IPC
127+ // Use Go runner as source of truth for readiness and concurrency
106128 log := logger .Sugar ()
107129 var hc HealthCheck
108130 if h .cfg .UseProcedureMode {
@@ -112,10 +134,13 @@ func (h *Handler) HealthCheck(w http.ResponseWriter, r *http.Request) {
112134 CompletedAt : util .FormatTime (h .startedAt ),
113135 Status : SetupSucceeded ,
114136 },
137+ Concurrency : Concurrency {
138+ // Max runners as max concurrency
139+ Max : h .maxRunners ,
140+ },
115141 }
116142 h .mu .Lock ()
117143 defer h .mu .Unlock ()
118- hasIdle := false
119144 toRemove := make ([]string , 0 )
120145 for name , runner := range h .runners {
121146 if runner .status == StatusDefunct || runner .status == StatusSetupFailed {
@@ -128,34 +153,25 @@ func (h *Handler) HealthCheck(w http.ResponseWriter, r *http.Request) {
128153 }()
129154 continue
130155 }
131- if runner .Idle () {
132- hasIdle = true
133- }
156+ // Aggregate current concurrency across workers
157+ hc .Concurrency .Current += runner .Concurrency ().Current
134158 }
135- // In procedure mode, a server is only READY if available runner slot >= 1, either empty or IDLE.
136- // In the case of a request with a new procedure source URL, the IDLE runner can be evicted.
137- // Otherwise, we report BUSY even if all runners are READY but not IDLE, e.g. len(pending) > 0.
138159 for _ , name := range toRemove {
139160 delete (h .runners , name )
140161 }
141- if len ( h . runners ) < h . maxRunners || hasIdle {
162+ if hc . Concurrency . Current < hc . Concurrency . Max {
142163 hc .Status = StatusReady .String ()
143164 } else {
144165 hc .Status = StatusBusy .String ()
145166 }
146167 } else {
147168 hc = HealthCheck {
148- Status : h .runners [DefaultRunner ].status .String (),
149- Setup : & h .runners [DefaultRunner ].setupResult ,
169+ Status : h .runners [DefaultRunner ].status .String (),
170+ Setup : & h .runners [DefaultRunner ].setupResult ,
171+ Concurrency : h .runners [DefaultRunner ].Concurrency (),
150172 }
151173 }
152-
153- if bs , err := json .Marshal (hc ); err != nil {
154- http .Error (w , err .Error (), http .StatusBadRequest )
155- } else {
156- w .WriteHeader (http .StatusOK )
157- writeBytes (w , bs )
158- }
174+ return & hc
159175}
160176
161177func (h * Handler ) OpenApi (w http.ResponseWriter , r * http.Request ) {
@@ -196,7 +212,7 @@ func (h *Handler) Stop() error {
196212 eg := errgroup.Group {}
197213 for name , runner := range h .runners {
198214 if err = runner .Stop (); err != nil {
199- log .Errorw ("failed to stop runner" , "name" , name , "err " , err )
215+ log .Errorw ("failed to stop runner" , "name" , name , "error " , err )
200216 }
201217 eg .Go (func () error {
202218 runner .WaitForStop ()
@@ -235,16 +251,22 @@ func (h *Handler) HandleIPC(w http.ResponseWriter, r *http.Request) {
235251 }
236252}
237253
238- func (h * Handler ) getRunner (srcURL , srcDir string ) (* Runner , error ) {
254+ func (h * Handler ) predictWithRunner (srcURL string , req PredictionRequest ) (chan PredictionResponse , error ) {
239255 log := logger .Sugar ()
240256
241257 // Lock before checking to avoid thrashing runner replacements
242258 h .mu .Lock ()
243259 defer h .mu .Unlock ()
244260
245- // Reuse current runner, nothing to do
246- if runner , ok := h .runners [srcURL ]; ok {
247- return runner , nil
261+ // Look for an existing runner copy for source URL in READY state
262+ // There might be multiple copies if the # pending predictions > max concurrency of a single runner
263+ // For non-async predictors, the same runner might occupy all runner slots
264+ for i := 0 ; i <= h .maxRunners ; i ++ {
265+ name := fmt .Sprintf ("%02d:%s" , i , srcURL )
266+ runner , ok := h .runners [name ]
267+ if ok && runner .Concurrency ().Current < runner .Concurrency ().Max {
268+ return runner .Predict (req )
269+ }
248270 }
249271
250272 // Need to evict one
@@ -253,7 +275,7 @@ func (h *Handler) getRunner(srcURL, srcDir string) (*Runner, error) {
253275 if ! runner .Idle () {
254276 continue
255277 }
256- log .Infow ("stopping procedure runner" , "src_url " , name )
278+ log .Infow ("stopping procedure runner" , "name " , name )
257279 if err := runner .Stop (); err != nil {
258280 log .Errorw ("failed to stop runner" , "error" , err )
259281 } else {
@@ -262,14 +284,37 @@ func (h *Handler) getRunner(srcURL, srcDir string) (*Runner, error) {
262284 }
263285 }
264286 }
287+ // Failed to evict one, this should not happen
265288 if len (h .runners ) == h .maxRunners {
289+ log .Errorw ("failed to find idle runner to evict" , "src_url" , srcURL )
290+ return nil , ErrConflict
291+ }
292+
293+ // Find the first available slot for the new runner copy
294+ var name string
295+ var slot int
296+ for i := 0 ; i <= h .maxRunners ; i ++ {
297+ n := fmt .Sprintf ("%02d:%s" , i , srcURL )
298+ if _ , ok := h .runners [n ]; ! ok {
299+ name = n
300+ slot = i
301+ break
302+ }
303+ }
304+ // Max out slots, this should not happen
305+ if name == "" {
306+ log .Errorw ("reached max copies of runner" , "src_url" , srcURL )
266307 return nil , ErrConflict
267308 }
268309
269310 // Start new runner
270- log .Infow ("starting procedure runner" , "src_url" , srcURL )
271- r := NewProcedureRunner (h .cfg .IPCUrl , h .cfg .UploadUrl , srcURL , srcDir )
272- h .runners [srcURL ] = r
311+ srcDir , err := util .PrepareProcedureSourceURL (srcURL , slot )
312+ if err != nil {
313+ return nil , err
314+ }
315+ log .Infow ("starting procedure runner" , "src_url" , srcURL , "src_dir" , srcDir )
316+ r := NewProcedureRunner (h .cfg .IPCUrl , h .cfg .UploadUrl , name , srcDir )
317+ h .runners [name ] = r
273318
274319 if err := r .Start (); err != nil {
275320 return nil , err
@@ -282,12 +327,36 @@ func (h *Handler) getRunner(srcURL, srcDir string) (*Runner, error) {
282327 }
283328 if r .status == StatusSetupFailed {
284329 log .Errorw ("procedure runner setup failed" , "logs" , r .setupResult .Logs )
285- delete (h .runners , srcURL )
286- // Include failed runner here so that the caller can extract setup logs and respond with a prediction failure
287- return r , ErrSetupFailed
330+ delete (h .runners , name )
331+
332+ // Translate setup failure to prediction failure
333+ resp := PredictionResponse {
334+ Input : req .Input ,
335+ Id : req .Id ,
336+ CreatedAt : r .setupResult .StartedAt ,
337+ StartedAt : r .setupResult .StartedAt ,
338+ CompletedAt : r .setupResult .CompletedAt ,
339+ Logs : r .setupResult .Logs ,
340+ Status : PredictionFailed ,
341+ Error : ErrSetupFailed .Error (),
342+ }
343+ if req .Webhook == "" {
344+ c := make (chan PredictionResponse , 1 )
345+ c <- resp
346+ return c , nil
347+ } else {
348+ // Async prediction, send webhook
349+ go func () {
350+ if err := SendWebhook (req .Webhook , & resp ); err != nil {
351+ log .Errorw ("failed to send webhook" , "url" , "error" , err )
352+ }
353+ }()
354+ return nil , nil
355+ }
356+
288357 }
289358 if time .Since (start ) > 10 * time .Second {
290- delete (h .runners , srcURL )
359+ delete (h .runners , name )
291360 log .Errorw ("stopping procedure runner after time out" , "elapsed" , time .Since (start ))
292361 if err := r .Stop (); err != nil {
293362 log .Errorw ("failed to stop procedure runner" , "error" , err )
@@ -296,11 +365,10 @@ func (h *Handler) getRunner(srcURL, srcDir string) (*Runner, error) {
296365 }
297366 time .Sleep (10 * time .Millisecond )
298367 }
299- return r , nil
368+ return r . Predict ( req )
300369}
301370
302371func (h * Handler ) Predict (w http.ResponseWriter , r * http.Request ) {
303- log := logger .Sugar ()
304372 if r .Header .Get ("Content-Type" ) != "application/json" {
305373 http .Error (w , "invalid content type" , http .StatusUnsupportedMediaType )
306374 return
@@ -330,8 +398,15 @@ func (h *Handler) Predict(w http.ResponseWriter, r *http.Request) {
330398 req .Id = util .PredictionId ()
331399 }
332400
333- var runner * Runner
401+ var c chan PredictionResponse
334402 if h .cfg .UseProcedureMode {
403+ // Although individual runners may have higher concurrency than the global max runners/concurrency
404+ // We still bail early if the global max has been reached
405+ concurrency := h .healthCheck ().Concurrency
406+ if concurrency .Current == concurrency .Max {
407+ http .Error (w , ErrConflict .Error (), http .StatusConflict )
408+ return
409+ }
335410 val , ok := req .Context ["procedure_source_url" ]
336411 if ! ok {
337412 http .Error (w , "missing procedure_source_url in context" , http .StatusBadRequest )
@@ -350,47 +425,11 @@ func (h *Handler) Predict(w http.ResponseWriter, r *http.Request) {
350425 http .Error (w , "empty procedure_source_url or replicate_api_token" , http .StatusBadRequest )
351426 return
352427 }
353- srcDir , err := util .PrepareProcedureSourceURL (procedureSourceUrl )
354- if err != nil {
355- http .Error (w , "invalid procedure_source_url" , http .StatusBadRequest )
356- }
357- if r , err := h .getRunner (procedureSourceUrl , srcDir ); err == nil {
358- runner = r
359- } else if errors .Is (err , ErrConflict ) {
360- http .Error (w , err .Error (), http .StatusConflict )
361- return
362- } else if errors .Is (err , ErrSetupFailed ) {
363- // Translate setup failure to prediction failure
364- resp := PredictionResponse {
365- Input : req .Input ,
366- Id : req .Id ,
367- CreatedAt : r .setupResult .StartedAt ,
368- StartedAt : r .setupResult .StartedAt ,
369- CompletedAt : r .setupResult .CompletedAt ,
370- Logs : r .setupResult .Logs ,
371- Status : PredictionFailed ,
372- }
373-
374- if req .Webhook == "" {
375- w .WriteHeader (http .StatusOK )
376- writeResponse (w , resp )
377- } else {
378- w .WriteHeader (http .StatusAccepted )
379- writeResponse (w , PredictionResponse {Id : req .Id , Status : "starting" })
380- if err := SendWebhook (req .Webhook , & resp ); err != nil {
381- log .Errorw ("failed to send webhook" , "url" , "error" , err )
382- }
383- }
384- return
385- } else {
386- http .Error (w , err .Error (), http .StatusInternalServerError )
387- return
388- }
428+ c , err = h .predictWithRunner (procedureSourceUrl , req )
389429 } else {
390- runner = h .runners [DefaultRunner ]
430+ c , err = h .runners [DefaultRunner ]. Predict ( req )
391431 }
392432
393- c , err := runner .Predict (req )
394433 if errors .Is (err , ErrConflict ) {
395434 http .Error (w , err .Error (), http .StatusConflict )
396435 return
0 commit comments