Skip to content

Commit

Permalink
feat(router): consider public and private directives for cache control (
Browse files Browse the repository at this point in the history
  • Loading branch information
df-wg authored Oct 28, 2024
1 parent 174c11b commit f0da638
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 70 deletions.
28 changes: 24 additions & 4 deletions router-tests/header_propagation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,14 +317,19 @@ func TestHeaderPropagation(t *testing.T) {
cc := res.Response.Header.Get("Cache-Control")
require.Equal(t, "max-age=60", cc) // Most restrictive wins
require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body)

// Verify that it doesn't set Expires header
val, present := res.Response.Header["Expires"]
require.False(t, present)
require.Equal(t, []string(nil), val)
})
})

// Local test: Cache control rules are applied per subgraph (employees and hobbies)
t.Run("only enable cache control for subgraphs", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
Subgraphs: cacheOptions("max-age=120", "max-age=60"),
Subgraphs: cacheOptions("max-age=120", "max-age=60, private"),
CacheControlPolicy: config.CacheControlPolicy{
Subgraphs: []config.SubgraphCacheControlRule{
{Name: "employees"},
Expand All @@ -336,7 +341,7 @@ func TestHeaderPropagation(t *testing.T) {
Query: queryEmployeeWithHobby,
})
cc := res.Response.Header.Get("Cache-Control")
require.Equal(t, "max-age=60", cc) // Most restrictive wins
require.Equal(t, "max-age=60, private", cc) // Most restrictive wins
require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body)
})
})
Expand All @@ -350,13 +355,13 @@ func TestHeaderPropagation(t *testing.T) {
{Name: "employees"},
},
},
Subgraphs: cacheOptions("max-age=120", "max-age=60"),
Subgraphs: cacheOptions("max-age=120, public", "max-age=60"),
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: queryEmployeeWithHobby,
})
cc := res.Response.Header.Get("Cache-Control")
require.Equal(t, "max-age=120", cc) // Only employee subgraph is considered
require.Equal(t, "max-age=120, public", cc) // Only employee subgraph is considered
require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body)
})
})
Expand Down Expand Up @@ -424,6 +429,21 @@ func TestHeaderPropagation(t *testing.T) {
})
})

t.Run("selects shortest max-age and private vs private", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
CacheControlPolicy: config.CacheControlPolicy{Enabled: true},
Subgraphs: cacheOptions("max-age=600, private", "max-age=300, public"),
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: queryEmployeeWithHobby,
})
cc := res.Response.Header.Get("Cache-Control")
require.Equal(t, "max-age=300, private", cc) // Shorter max-age wins
require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body)
})
})

// Test case for Expires header: earliest expiration wins
t.Run("earliest Expires wins", func(t *testing.T) {
t.Parallel()
Expand Down
162 changes: 96 additions & 66 deletions router/core/header_rule_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package core
import (
"context"
"fmt"
cachedirective "github.com/pquerna/cachecontrol/cacheobject"
nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1"
"github.com/wundergraph/cosmo/router/pkg/config"
"github.com/wundergraph/cosmo/router/pkg/otel"
rtrace "github.com/wundergraph/cosmo/router/pkg/trace"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
Expand All @@ -12,12 +15,9 @@ import (
"net/http"
"regexp"
"slices"
"strings"
"sync"
"time"

cachedirective "github.com/pquerna/cachecontrol/cacheobject"
nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1"
"github.com/wundergraph/cosmo/router/pkg/config"
)

var (
Expand Down Expand Up @@ -50,6 +50,9 @@ var (
"Sec-Websocket-Protocol",
"Sec-Websocket-Version",
}
cacheControlKey = "Cache-Control"
expiresKey = "Expires"
noCache = "no-cache"
)

type responseHeaderPropagationKey struct{}
Expand All @@ -58,6 +61,7 @@ type responseHeaderPropagation struct {
header http.Header
m *sync.Mutex
previousCacheControl *cachedirective.Object
setCacheControl bool
}

func WithResponseHeaderPropagation(ctx *resolve.Context) *resolve.Context {
Expand Down Expand Up @@ -283,6 +287,10 @@ func (h *HeaderPropagation) OnOriginResponse(resp *http.Response, ctx RequestCon
func (h *HeaderPropagation) applyResponseRule(propagation *responseHeaderPropagation, res *http.Response, rule *config.ResponseHeaderRule) {
if rule.Operation == config.HeaderRuleOperationSet {
propagation.header.Set(rule.Name, rule.Value)
if rule.Name == cacheControlKey {
// Handle the case where the cache control header is set explicitly
propagation.setCacheControl = true
}
return
}

Expand Down Expand Up @@ -438,7 +446,10 @@ func (h *HeaderPropagation) applyRequestRule(ctx RequestContext, request *http.R
}

func (h *HeaderPropagation) applyResponseRuleMostRestrictiveCacheControl(res *http.Response, propagation *responseHeaderPropagation, rule *config.ResponseHeaderRule) {
cacheControlKey := "Cache-Control"
if propagation.setCacheControl {
// Handle the case where the cache control header is set explicitly using the set propagation rule
return
}

ctx := res.Request.Context()
tracer := rtrace.TracerFromContext(ctx)
Expand All @@ -450,39 +461,39 @@ func (h *HeaderPropagation) applyResponseRuleMostRestrictiveCacheControl(res *ht
trace.WithSpanKind(trace.SpanKindInternal),
trace.WithAttributes(commonAttributes...),
)
defer span.End()

// Set no-cache for all mutations, to ensure that requests to mutate data always work as expected (without returning cached data)
if resolve.SingleFlightDisallowed(ctx) {
var noCache = "no-cache"
propagation.header.Set(cacheControlKey, noCache)
return
}

reqCacheHeader := res.Request.Header.Get(cacheControlKey)
reqDir, _ := cachedirective.ParseRequestCacheControl(reqCacheHeader)
resCacheHeader := res.Header.Get(cacheControlKey)
resDir, _ := cachedirective.ParseResponseCacheControl(resCacheHeader)
expiresHeaderVal := res.Header.Get("Expires")
expiresHeader, _ := http.ParseTime(expiresHeaderVal)
expiresHeader, _ := http.ParseTime(res.Header.Get(expiresKey))
dateHeader, _ := http.ParseTime(res.Header.Get("Date"))
lastModifiedHeader, _ := http.ParseTime(res.Header.Get("Last-Modified"))

if propagation.previousCacheControl == nil && reqCacheHeader == "" && resCacheHeader == "" && expiresHeader.IsZero() && rule.Default == "" {
// There is no default/previous value to set, and since no cache control headers have been set, exit early
return
}

reqDir, _ := cachedirective.ParseRequestCacheControl(reqCacheHeader)
resDir, _ := cachedirective.ParseResponseCacheControl(resCacheHeader)
obj := &cachedirective.Object{
RespDirectives: resDir,
RespHeaders: res.Header,
RespStatusCode: res.StatusCode,
RespExpiresHeader: expiresHeader,
RespDateHeader: dateHeader,
RespLastModifiedHeader: lastModifiedHeader,

ReqDirectives: reqDir,
ReqHeaders: res.Request.Header,
ReqMethod: res.Request.Method,

NowUTC: time.Now().UTC(),
ReqDirectives: reqDir,
ReqHeaders: res.Request.Header,
NowUTC: time.Now().UTC(),
}
rv := cachedirective.ObjectResults{}

cachedirective.CachableObject(obj, &rv)
cachedirective.ExpirationObject(obj, &rv)

Expand All @@ -492,70 +503,89 @@ func (h *HeaderPropagation) applyResponseRuleMostRestrictiveCacheControl(res *ht
otel.WgResponseCacheControlExpiration.String(rv.OutExpirationTime.String()),
)

propagation.m.Lock()
defer propagation.m.Unlock()

defaultResponseCache, _ := cachedirective.ParseResponseCacheControl(rule.Default)
defaultCacheControlObj := &cachedirective.Object{
RespDirectives: defaultResponseCache,
// Add each cache control object to the policies list
policies := []*cachedirective.Object{obj}
if rule.Default != "" {
defaultResponseCache, _ := cachedirective.ParseResponseCacheControl(rule.Default)
policies = append(policies, &cachedirective.Object{RespDirectives: defaultResponseCache})
}

if propagation.previousCacheControl == nil {
if rule.Default != "" {
propagation.previousCacheControl = defaultCacheControlObj
propagation.header.Set(cacheControlKey, rule.Default)
} else if reqCacheHeader == "" && resCacheHeader == "" && expiresHeaderVal == "" {
// There is no default/previous value to set, and since no cache control headers have been set, exit early
return
} else {
propagation.previousCacheControl = obj
propagation.header.Set(cacheControlKey, res.Header.Get(cacheControlKey))
return
}
} else if rule.Default != "" && isMoreRestrictive(defaultCacheControlObj, propagation.previousCacheControl) {
// Overwriting previous cache control with the current subgraph default
propagation.previousCacheControl = defaultCacheControlObj
propagation.header.Set(cacheControlKey, rule.Default)
if propagation.previousCacheControl != nil {
policies = append(policies, propagation.previousCacheControl)
}

if !expiresHeader.IsZero() && (propagation.previousCacheControl.RespExpiresHeader.IsZero() || expiresHeader.Before(propagation.previousCacheControl.RespExpiresHeader)) {
propagation.previousCacheControl = obj
propagation.header.Set("Expires", res.Header.Get("Expires"))
// Determine the most restrictive cache policy and cache control header
restrictivePolicy, cacheControlHeader := createMostRestrictivePolicy(policies)

propagation.m.Lock()
defer propagation.m.Unlock()
propagation.previousCacheControl = restrictivePolicy
if cacheControlHeader != "" {
propagation.header.Set(cacheControlKey, cacheControlHeader)
}

// Compare the previous cache control with the current one to find the most restrictive
if !isMoreRestrictive(propagation.previousCacheControl, obj) {
// The current cache control is more restrictive, so update it
propagation.previousCacheControl = obj
propagation.header.Set(cacheControlKey, res.Header.Get(cacheControlKey))
// Update the Expires header if applicable
if !expiresHeader.IsZero() && !restrictivePolicy.RespExpiresHeader.IsZero() {
propagation.header.Set(expiresKey, restrictivePolicy.RespExpiresHeader.Format(http.TimeFormat))
}
}

// isMoreRestrictive compares two cachedirective.Object instances and returns true if the first is more restrictive
func isMoreRestrictive(prev *cachedirective.Object, curr *cachedirective.Object) bool {
// Example comparison logic: check if "no-store" or "no-cache" are present, which are more restrictive
if prev.RespDirectives.NoStore || curr.RespDirectives.NoStore {
return true // No store is the most restrictive
}
if prev.RespDirectives.NoCachePresent && !curr.RespDirectives.NoCachePresent {
return true // No-cache is more restrictive than not having it
func createMostRestrictivePolicy(policies []*cachedirective.Object) (*cachedirective.Object, string) {
result := cachedirective.Object{
RespDirectives: &cachedirective.ResponseCacheDirectives{},
}
if curr.RespDirectives.NoCachePresent && !prev.RespDirectives.NoCachePresent {
return false // Current response has no-cache, which is more restrictive
var minMaxAge cachedirective.DeltaSeconds = -1
isPrivate := false
isPublic := false

for _, policy := range policies {
// Check no-store and no-cache first
if policy.RespDirectives.NoStore {
result.RespDirectives.NoStore = true
return &result, "no-store"
}
if policy.RespDirectives.NoCachePresent {
result.RespDirectives.NoCachePresent = true
}

// Determine the shortest max-age if available
if policy.RespDirectives.MaxAge > 0 && (minMaxAge == -1 || policy.RespDirectives.MaxAge < minMaxAge) {
minMaxAge = policy.RespDirectives.MaxAge
}

// Track if any policy specifies "private"
if policy.RespDirectives.PrivatePresent {
isPrivate = true
} else if policy.RespDirectives.Public {
isPublic = true
}

// Handle expires header comparisons
if policy.RespExpiresHeader.Before(result.RespExpiresHeader) || result.RespExpiresHeader.IsZero() {
result.RespExpiresHeader = policy.RespExpiresHeader
}
}

// Compare max-age: the shorter max-age is more restrictive
if prev.RespDirectives.MaxAge > 0 && curr.RespDirectives.MaxAge > 0 {
return prev.RespDirectives.MaxAge < curr.RespDirectives.MaxAge
// Set the calculated max-age and privacy level on the result
if minMaxAge > 0 {
result.RespDirectives.MaxAge = minMaxAge
}
result.RespDirectives.PrivatePresent = isPrivate

// If neither has max-age, but one has other expiration controls like Expires header, use that
if !prev.RespExpiresHeader.IsZero() && !curr.RespExpiresHeader.IsZero() {
return prev.RespExpiresHeader.Before(curr.RespExpiresHeader)
// Format the final Cache-Control header
headerParts := []string{}
if result.RespDirectives.NoCachePresent {
headerParts = append(headerParts, noCache)
} else if minMaxAge > 0 {
headerParts = append(headerParts, fmt.Sprintf("max-age=%d", minMaxAge))
}
if isPrivate {
headerParts = append(headerParts, "private")
} else if isPublic {
headerParts = append(headerParts, "public")
}
cacheControlHeader := strings.Join(headerParts, ", ")

// Fallback: if they are equal in restrictiveness, keep the previous one
return true
return &result, cacheControlHeader
}

// SubgraphRules returns the list of header rules for the subgraph with the given name
Expand Down

0 comments on commit f0da638

Please sign in to comment.