diff --git a/plugin.go b/plugin.go index 34179fb..3c73fdb 100644 --- a/plugin.go +++ b/plugin.go @@ -82,33 +82,42 @@ func (c *Cache) ServeHTTP(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(cachedResponse.StatusCode) _, _ = rw.Write(cachedResponse.Body) return + } else { + log.Printf("Failed to serialize response for caching: %s", err.Error()) + _ = respClient.Delete(req.Context(), cacheKey) } - log.Printf("Failed to serialize response for caching: %s", err.Error()) - _ = respClient.Delete(req.Context(), cacheKey) + } // Cache miss - record the response - recorder := &responseRecorder{rw: rw} + recorder := &responseRecorder{ + rw: rw, + header: rw.Header().Clone(), // Initialize with the original headers. + } c.next.ServeHTTP(recorder, req) // Serialize the response data cachedResponse := CachedResponse{ StatusCode: recorder.status, - Headers: recorder.Header().Clone(), // Convert http.Header to a map for serialization + Headers: recorder.Header(), // Convert http.Header to a map for serialization Body: recorder.body.Bytes(), } var buffer bytes.Buffer enc := gob.NewEncoder(&buffer) if err := enc.Encode(cachedResponse); err != nil { log.Printf("Failed to serialize response for caching: %s", err) + http.Error(rw, "Internal Server Error", http.StatusInternalServerError) + return + } else { + // Store the serialized response in Redis + if err := respClient.SetWithTTL(req.Context(), cacheKey, buffer.String(), c.cacheExpiry); err != nil { + log.Printf("Failed to cache response in Redis: %s", err.Error()) + } } - // Store the serialized response in Redis as a string with an expiration time - if err := respClient.SetWithTTL(req.Context(), cacheKey, buffer.String(), c.cacheExpiry); err != nil { - log.Println("Failed to cache response in Redis:", err) + if _, err := rw.Write(recorder.body.Bytes()); err != nil { + log.Printf("Failed to write response body: %s", err) + return } - - // Write the original response - rw.WriteHeader(recorder.status) - _, _ = rw.Write(recorder.body.Bytes()) + return } diff --git a/recorder.go b/recorder.go index fb80993..f684bcc 100644 --- a/recorder.go +++ b/recorder.go @@ -9,18 +9,18 @@ type responseRecorder struct { rw http.ResponseWriter status int body bytes.Buffer + header http.Header } func (r *responseRecorder) Header() http.Header { - return r.rw.Header() + return r.header } func (r *responseRecorder) Write(b []byte) (int, error) { - r.body.Write(b) - return r.rw.Write(b) + return r.body.Write(b) // Just buffer the body, don't write to rw + } func (r *responseRecorder) WriteHeader(statusCode int) { r.status = statusCode - r.rw.WriteHeader(statusCode) }