diff --git a/example_custom_storage_test.go b/example_custom_storage_test.go index ff44424..b6852ac 100644 --- a/example_custom_storage_test.go +++ b/example_custom_storage_test.go @@ -54,7 +54,7 @@ func (c customInMemStorage) Origin() string { func Example_withCustomStorage() { client := &http.Client{} - handler, err := httpcache.NewWithCustomStorageCache(client, NewCustomInMemStorage()) + handler, err := httpcache.NewWithCustomStorageCache(client, true, NewCustomInMemStorage()) if err != nil { log.Fatal(err) } diff --git a/example_inmemory_storage_test.go b/example_inmemory_storage_test.go index 9dfaed3..40b24b4 100644 --- a/example_inmemory_storage_test.go +++ b/example_inmemory_storage_test.go @@ -11,7 +11,7 @@ import ( func Example_inMemoryStorageDefault() { client := &http.Client{} - handler, err := httpcache.NewWithInmemoryCache(client, time.Second*15) + handler, err := httpcache.NewWithInmemoryCache(client, true, time.Second*15) if err != nil { log.Fatal(err) } diff --git a/httpcache.go b/httpcache.go index 434d6b5..bec8b76 100644 --- a/httpcache.go +++ b/httpcache.go @@ -13,22 +13,22 @@ import ( // NewWithCustomStorageCache will initiate the httpcache with your defined cache storage // To use your own cache storage handler, you need to implement the cache.Interactor interface // And pass it to httpcache. -func NewWithCustomStorageCache(client *http.Client, cacheInteractor cache.ICacheInteractor) (cacheHandler *CacheHandler, err error) { - return newClient(client, cacheInteractor) +func NewWithCustomStorageCache(client *http.Client, rfcCompliance bool, cacheInteractor cache.ICacheInteractor) (cacheHandler *CacheHandler, err error) { + return newClient(client, rfcCompliance, cacheInteractor) } -func newClient(client *http.Client, cacheInteractor cache.ICacheInteractor) (cachedHandler *CacheHandler, err error) { +func newClient(client *http.Client, rfcCompliance bool, cacheInteractor cache.ICacheInteractor) (cachedHandler *CacheHandler, err error) { if client.Transport == nil { client.Transport = http.DefaultTransport } - cachedHandler = NewRoundtrip(client.Transport, cacheInteractor) + cachedHandler = NewRoundtrip(client.Transport, cacheInteractor, rfcCompliance) client.Transport = cachedHandler return } // NewWithInmemoryCache will create a complete cache-support of HTTP client with using inmemory cache. // If the duration not set, the cache will use LFU algorithm -func NewWithInmemoryCache(client *http.Client, duration ...time.Duration) (cachedHandler *CacheHandler, err error) { +func NewWithInmemoryCache(client *http.Client, rfcCompliance bool, duration ...time.Duration) (cachedHandler *CacheHandler, err error) { var expiryTime time.Duration if len(duration) > 0 { expiryTime = duration[0] @@ -38,5 +38,5 @@ func NewWithInmemoryCache(client *http.Client, duration ...time.Duration) (cache SetExpiryTime(expiryTime).SetMaxSizeItem(100), ) - return newClient(client, inmem.NewCache(c)) + return newClient(client, rfcCompliance, inmem.NewCache(c)) } diff --git a/roundtriper.go b/roundtriper.go index a88275a..92e3754 100644 --- a/roundtriper.go +++ b/roundtriper.go @@ -27,16 +27,18 @@ const ( type CacheHandler struct { DefaultRoundTripper http.RoundTripper CacheInteractor cache.ICacheInteractor + ComplyRFC bool } // NewRoundtrip will create an implementations of cache http roundtripper -func NewRoundtrip(defaultRoundTripper http.RoundTripper, cacheActor cache.ICacheInteractor) *CacheHandler { +func NewRoundtrip(defaultRoundTripper http.RoundTripper, cacheActor cache.ICacheInteractor, rfcCompliance bool) *CacheHandler { if cacheActor == nil { log.Fatal("cache interactor is nil") } return &CacheHandler{ DefaultRoundTripper: defaultRoundTripper, CacheInteractor: cacheActor, + ComplyRFC: rfcCompliance, } } @@ -91,8 +93,7 @@ func validateTheCacheControl(req *http.Request, resp *http.Response) (validation return validationResult, nil } -// RoundTrip the implementation of http.RoundTripper -func (r *CacheHandler) RoundTrip(req *http.Request) (resp *http.Response, err error) { +func (r *CacheHandler) roundTripRFCCompliance(req *http.Request) (resp *http.Response, err error) { allowCache := allowedFromCache(req.Header) if allowCache { cachedResp, cachedItem, cachedErr := getCachedResponse(r.CacheInteractor, req) @@ -111,18 +112,21 @@ func (r *CacheHandler) RoundTrip(req *http.Request) (resp *http.Response, err er return } - validationResult, err := validateTheCacheControl(req, resp) - if err != nil { - return + validationResult, errValidation := validateTheCacheControl(req, resp) + if errValidation != nil { + log.Printf("Can't validate the response to RFC 7234, plase check. Err: %v\n", errValidation) + return // return directly, not sure can be stored or not } if validationResult.OutErr != nil { - return + log.Printf("Can't validate the response to RFC 7234, plase check. Err: %v\n", validationResult.OutErr) + return // return directly, not sure can be stored or not } // reasons to not to cache if len(validationResult.OutReasons) > 0 { - return + log.Printf("Can't validate the response to RFC 7234, plase check. Err: %v\n", validationResult.OutReasons) + return // return directly, not sure can be stored or not. } err = storeRespToCache(r.CacheInteractor, req, resp) @@ -134,6 +138,40 @@ func (r *CacheHandler) RoundTrip(req *http.Request) (resp *http.Response, err er return } +// RoundTrip the implementation of http.RoundTripper +func (r *CacheHandler) RoundTrip(req *http.Request) (resp *http.Response, err error) { + if r.ComplyRFC { + return r.roundTripRFCCompliance(req) + } + cachedResp, cachedItem, cachedErr := getCachedResponse(r.CacheInteractor, req) + if cachedResp != nil && cachedErr == nil { + buildTheCachedResponseHeader(cachedResp, cachedItem, r.CacheInteractor.Origin()) + return cachedResp, cachedErr + } + // if error when getting from cachce, ignore it, re-try a live version + if cachedErr != nil { + log.Println(cachedErr, "failed to retrieve from cache, trying with a live version") + } + + resp, err = r.DefaultRoundTripper.RoundTrip(req) + if err != nil { + return + } + + err = storeRespToCache(r.CacheInteractor, req, resp) + if err != nil { + log.Printf("Can't store the response to database, plase check. Err: %v\n", err) + err = nil // set err back to nil to make the call still success. + } + return +} + +// RFC7234Compliance used for enable/disable the RFC 7234 compliance +func (r *CacheHandler) RFC7234Compliance(val bool) *CacheHandler { + r.ComplyRFC = val + return r +} + func storeRespToCache(cacheInteractor cache.ICacheInteractor, req *http.Request, resp *http.Response) (err error) { cachedResp := cache.CachedResponse{ RequestMethod: req.Method, diff --git a/roundtripper_test.go b/roundtripper_test.go index 4ece191..25b952a 100644 --- a/roundtripper_test.go +++ b/roundtripper_test.go @@ -20,7 +20,7 @@ func TestSetToCacheRoundtrip(t *testing.T) { mockCacheInteractor.On("Get", mock.AnythingOfType("string")).Once().Return(cachedResponse, errors.New("uknown error")) mockCacheInteractor.On("Set", mock.AnythingOfType("string"), mock.Anything).Once().Return(nil) client := &http.Client{} - client.Transport = httpcache.NewRoundtrip(http.DefaultTransport, mockCacheInteractor) + client.Transport = httpcache.NewRoundtrip(http.DefaultTransport, mockCacheInteractor, true) // HTTP GET 200 jsonResp := []byte(`{"message": "Hello World!"}`) handler := func() (res http.Handler) {