Skip to content

Commit 7c066ac

Browse files
committed
Add more error handling
1 parent 02a3b4e commit 7c066ac

File tree

4 files changed

+40
-18
lines changed

4 files changed

+40
-18
lines changed

cmd/cog/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ func serverCommand() *ff.Command {
8686
UploadUrl: cfg.UploadUrl,
8787
}
8888
ctx, cancel := context.WithCancel(ctx)
89-
h := server.NewHandler(serverCfg, cancel)
89+
h := must.Get(server.NewHandler(serverCfg, cancel))
9090
s := server.NewServer(addr, h, cfg.UseProcedureMode)
9191
go func() {
9292
<-ctx.Done()

internal/server/http.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ import (
55
"net/http"
66
"os"
77
"strconv"
8-
9-
"github.com/replicate/go/must"
108
)
119

1210
var (
@@ -39,7 +37,10 @@ func NewServer(addr string, handler *Handler, useProcedureMode bool) *http.Serve
3937
if _, ok := os.LookupEnv("TEST_COG"); ok {
4038
serveMux.HandleFunc("/_pid", func(w http.ResponseWriter, r *http.Request) {
4139
w.WriteHeader(http.StatusOK)
42-
must.Get(w.Write([]byte(strconv.Itoa(os.Getpid()))))
40+
if _, err := w.Write([]byte(strconv.Itoa(os.Getpid()))); err != nil {
41+
log := logger.Sugar()
42+
log.Errorw("failed to write response", "error", err)
43+
}
4344
})
4445
}
4546

internal/server/path.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ import (
1414

1515
"github.com/getkin/kin-openapi/openapi3"
1616

17-
"github.com/replicate/go/must"
18-
1917
"github.com/gabriel-vasile/mimetype"
2018
)
2119

@@ -208,7 +206,10 @@ func outputToUpload(uploadUrl string, predictionId string) func(s string, paths
208206
*paths = append(*paths, p)
209207
filename := path.Base(p)
210208
uUpload := fmt.Sprintf("%s%s", uploadUrl, filename)
211-
req := must.Get(http.NewRequest(http.MethodPut, uUpload, bytes.NewReader(bs)))
209+
req, err := http.NewRequest(http.MethodPut, uUpload, bytes.NewReader(bs))
210+
if err != nil {
211+
return "", err
212+
}
212213
req.Header.Set("X-Prediction-ID", predictionId)
213214
resp, err := http.DefaultClient.Do(req)
214215
if err != nil {

internal/server/server.go

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ import (
1313

1414
"github.com/replicate/cog-runtime/internal/util"
1515

16-
"github.com/replicate/go/must"
17-
1816
"github.com/replicate/go/logging"
1917
)
2018

@@ -31,15 +29,17 @@ type Handler struct {
3129
mu sync.Mutex
3230
}
3331

34-
func NewHandler(cfg Config, shutdown context.CancelFunc) *Handler {
32+
func NewHandler(cfg Config, shutdown context.CancelFunc) (*Handler, error) {
3533
h := &Handler{
3634
cfg: cfg,
3735
shutdown: shutdown,
3836
startedAt: time.Now(),
3937
}
4038
if !cfg.UseProcedureMode {
4139
h.runner = NewRunner(cfg.UploadUrl)
42-
must.Do(h.runner.Start())
40+
if err := h.runner.Start(); err != nil {
41+
return nil, err
42+
}
4343
if !cfg.AwaitExplicitShutdown {
4444
go func() {
4545
// Shut down as soon as runner exists
@@ -48,7 +48,7 @@ func NewHandler(cfg Config, shutdown context.CancelFunc) *Handler {
4848
}()
4949
}
5050
}
51-
return h
51+
return h, nil
5252
}
5353

5454
func (h *Handler) ExitCode() int {
@@ -87,22 +87,22 @@ func (h *Handler) HealthCheck(w http.ResponseWriter, r *http.Request) {
8787
http.Error(w, err.Error(), http.StatusBadRequest)
8888
} else {
8989
w.WriteHeader(http.StatusOK)
90-
must.Get(w.Write(bs))
90+
writeBytes(w, bs)
9191
}
9292
}
9393

9494
func (h *Handler) OpenApi(w http.ResponseWriter, r *http.Request) {
9595
if h.cfg.UseProcedureMode {
9696
w.WriteHeader(http.StatusOK)
97-
must.Get(w.Write([]byte(procedureSchema)))
97+
writeBytes(w, []byte(procedureSchema))
9898
return
9999
}
100100

101101
if h.runner.schema == "" {
102102
http.Error(w, "unavailable", http.StatusServiceUnavailable)
103103
} else {
104104
w.WriteHeader(http.StatusOK)
105-
must.Get(w.Write([]byte(h.runner.schema)))
105+
writeBytes(w, []byte(h.runner.schema))
106106
}
107107
}
108108

@@ -187,7 +187,11 @@ func (h *Handler) Predict(w http.ResponseWriter, r *http.Request) {
187187
return
188188
}
189189
var req PredictionRequest
190-
if err := json.Unmarshal(must.Get(io.ReadAll(r.Body)), &req); err != nil {
190+
bs, err := io.ReadAll(r.Body)
191+
if err != nil {
192+
http.Error(w, err.Error(), http.StatusBadRequest)
193+
}
194+
if err := json.Unmarshal(bs, &req); err != nil {
191195
http.Error(w, err.Error(), http.StatusBadRequest)
192196
return
193197
}
@@ -257,12 +261,28 @@ func (h *Handler) Predict(w http.ResponseWriter, r *http.Request) {
257261
if c == nil {
258262
w.WriteHeader(http.StatusAccepted)
259263
resp := PredictionResponse{Id: req.Id, Status: "starting"}
260-
must.Get(w.Write(must.Get(json.Marshal(resp))))
264+
writeResponse(w, resp)
261265
} else {
262266
resp := <-c
263267
w.WriteHeader(http.StatusOK)
264-
must.Get(w.Write(must.Get(json.Marshal(resp))))
268+
writeResponse(w, resp)
269+
}
270+
}
271+
272+
func writeBytes(w http.ResponseWriter, bs []byte) {
273+
log := logger.Sugar()
274+
if _, err := w.Write(bs); err != nil {
275+
log.Errorw("failed to write response", "error", err)
276+
}
277+
}
278+
279+
func writeResponse(w http.ResponseWriter, resp PredictionResponse) {
280+
log := logger.Sugar()
281+
bs, err := json.Marshal(resp)
282+
if err != nil {
283+
log.Errorw("failed to marshal response", "error", err)
265284
}
285+
writeBytes(w, bs)
266286
}
267287

268288
func (h *Handler) Cancel(w http.ResponseWriter, r *http.Request) {

0 commit comments

Comments
 (0)