@@ -13,8 +13,6 @@ import (
1313
1414 "github.com/replicate/cog-runtime/internal/util"
1515
16- "github.com/replicate/go/must"
17-
1816 "github.com/replicate/go/logging"
1917)
2018
@@ -31,15 +29,17 @@ type Handler struct {
3129 mu sync.Mutex
3230}
3331
34- func NewHandler (cfg Config , shutdown context.CancelFunc ) * Handler {
32+ func NewHandler (cfg Config , shutdown context.CancelFunc ) ( * Handler , error ) {
3533 h := & Handler {
3634 cfg : cfg ,
3735 shutdown : shutdown ,
3836 startedAt : time .Now (),
3937 }
4038 if ! cfg .UseProcedureMode {
4139 h .runner = NewRunner (cfg .UploadUrl )
42- must .Do (h .runner .Start ())
40+ if err := h .runner .Start (); err != nil {
41+ return nil , err
42+ }
4343 if ! cfg .AwaitExplicitShutdown {
4444 go func () {
4545 // Shut down as soon as runner exists
@@ -48,7 +48,7 @@ func NewHandler(cfg Config, shutdown context.CancelFunc) *Handler {
4848 }()
4949 }
5050 }
51- return h
51+ return h , nil
5252}
5353
5454func (h * Handler ) ExitCode () int {
@@ -87,22 +87,22 @@ func (h *Handler) HealthCheck(w http.ResponseWriter, r *http.Request) {
8787 http .Error (w , err .Error (), http .StatusBadRequest )
8888 } else {
8989 w .WriteHeader (http .StatusOK )
90- must . Get ( w . Write ( bs ) )
90+ writeBytes ( w , bs )
9191 }
9292}
9393
9494func (h * Handler ) OpenApi (w http.ResponseWriter , r * http.Request ) {
9595 if h .cfg .UseProcedureMode {
9696 w .WriteHeader (http .StatusOK )
97- must . Get ( w . Write ( []byte (procedureSchema ) ))
97+ writeBytes ( w , []byte (procedureSchema ))
9898 return
9999 }
100100
101101 if h .runner .schema == "" {
102102 http .Error (w , "unavailable" , http .StatusServiceUnavailable )
103103 } else {
104104 w .WriteHeader (http .StatusOK )
105- must . Get ( w . Write ( []byte (h .runner .schema ) ))
105+ writeBytes ( w , []byte (h .runner .schema ))
106106 }
107107}
108108
@@ -187,7 +187,11 @@ func (h *Handler) Predict(w http.ResponseWriter, r *http.Request) {
187187 return
188188 }
189189 var req PredictionRequest
190- if err := json .Unmarshal (must .Get (io .ReadAll (r .Body )), & req ); err != nil {
190+ bs , err := io .ReadAll (r .Body )
191+ if err != nil {
192+ http .Error (w , err .Error (), http .StatusBadRequest )
193+ }
194+ if err := json .Unmarshal (bs , & req ); err != nil {
191195 http .Error (w , err .Error (), http .StatusBadRequest )
192196 return
193197 }
@@ -257,12 +261,28 @@ func (h *Handler) Predict(w http.ResponseWriter, r *http.Request) {
257261 if c == nil {
258262 w .WriteHeader (http .StatusAccepted )
259263 resp := PredictionResponse {Id : req .Id , Status : "starting" }
260- must . Get ( w . Write ( must . Get ( json . Marshal ( resp ))) )
264+ writeResponse ( w , resp )
261265 } else {
262266 resp := <- c
263267 w .WriteHeader (http .StatusOK )
264- must .Get (w .Write (must .Get (json .Marshal (resp ))))
268+ writeResponse (w , resp )
269+ }
270+ }
271+
272+ func writeBytes (w http.ResponseWriter , bs []byte ) {
273+ log := logger .Sugar ()
274+ if _ , err := w .Write (bs ); err != nil {
275+ log .Errorw ("failed to write response" , "error" , err )
276+ }
277+ }
278+
279+ func writeResponse (w http.ResponseWriter , resp PredictionResponse ) {
280+ log := logger .Sugar ()
281+ bs , err := json .Marshal (resp )
282+ if err != nil {
283+ log .Errorw ("failed to marshal response" , "error" , err )
265284 }
285+ writeBytes (w , bs )
266286}
267287
268288func (h * Handler ) Cancel (w http.ResponseWriter , r * http.Request ) {
0 commit comments