Skip to content

Commit 83e0ed6

Browse files
authored
feat: add getter for sha256Hash for modules (#2300)
1 parent 88084ac commit 83e0ed6

File tree

4 files changed

+238
-3
lines changed

4 files changed

+238
-3
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package sha256_verifier
2+
3+
import (
4+
"net/http"
5+
6+
"github.com/wundergraph/cosmo/router/core"
7+
)
8+
9+
const myModuleID = "sha256VerifierModule"
10+
11+
// ResultContainer holds the SHA256 result, shared across module instances
12+
type ResultContainer struct {
13+
Sha256Result string
14+
}
15+
16+
// Sha256VerifierModule is a simple module that has access to the GraphQL operation and adds custom scopes to the response
17+
type Sha256VerifierModule struct {
18+
ForceSha256 bool
19+
ResultContainer *ResultContainer
20+
}
21+
22+
func (m *Sha256VerifierModule) Middleware(ctx core.RequestContext, next http.Handler) {
23+
m.ResultContainer.Sha256Result = ctx.Operation().Sha256Hash()
24+
next.ServeHTTP(ctx.ResponseWriter(), ctx.Request())
25+
}
26+
27+
func (m *Sha256VerifierModule) RouterOnRequest(ctx core.RequestContext, next http.Handler) {
28+
if m.ForceSha256 {
29+
ctx.SetForceSha256Compute()
30+
}
31+
next.ServeHTTP(ctx.ResponseWriter(), ctx.Request())
32+
}
33+
34+
func (m *Sha256VerifierModule) Module() core.ModuleInfo {
35+
return core.ModuleInfo{
36+
// This is the ID of your module, it must be unique
37+
ID: myModuleID,
38+
// The priority of your module, lower numbers are executed first
39+
Priority: 1,
40+
New: func() core.Module {
41+
return &Sha256VerifierModule{}
42+
},
43+
}
44+
}
45+
46+
// Interface guard
47+
var (
48+
_ core.RouterMiddlewareHandler = (*Sha256VerifierModule)(nil)
49+
_ core.RouterOnRequestHandler = (*Sha256VerifierModule)(nil)
50+
)
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
package module_test
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
sha256_verifier "github.com/wundergraph/cosmo/router-tests/modules/sha256-verifier"
9+
"github.com/wundergraph/cosmo/router-tests/testenv"
10+
"github.com/wundergraph/cosmo/router/core"
11+
"github.com/wundergraph/cosmo/router/pkg/config"
12+
)
13+
14+
func TestSha256VerifierModule(t *testing.T) {
15+
t.Parallel()
16+
17+
t.Run("verify Sha256Hash is not captured when sha256 force is not enabled", func(t *testing.T) {
18+
t.Parallel()
19+
20+
resultContainer := &sha256_verifier.ResultContainer{}
21+
22+
cfg := config.Config{
23+
Graph: config.Graph{},
24+
Modules: map[string]interface{}{
25+
"sha256VerifierModule": sha256_verifier.Sha256VerifierModule{
26+
ForceSha256: false,
27+
ResultContainer: resultContainer,
28+
},
29+
},
30+
}
31+
32+
testenv.Run(t, &testenv.Config{
33+
RouterOptions: []core.Option{
34+
core.WithModulesConfig(cfg.Modules),
35+
core.WithCustomModules(&sha256_verifier.Sha256VerifierModule{}),
36+
},
37+
}, func(t *testing.T, xEnv *testenv.Environment) {
38+
res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
39+
Query: `query MyQuery { employees { id } }`,
40+
OperationName: json.RawMessage(`"MyQuery"`),
41+
})
42+
require.NoError(t, err)
43+
require.Equal(t, 200, res.Response.StatusCode)
44+
45+
require.Empty(t, resultContainer.Sha256Result)
46+
})
47+
})
48+
49+
t.Run("verify sha256Hash is captured from operation when force is enabled", func(t *testing.T) {
50+
t.Parallel()
51+
52+
resultContainer := &sha256_verifier.ResultContainer{}
53+
54+
cfg := config.Config{
55+
Graph: config.Graph{},
56+
Modules: map[string]interface{}{
57+
"sha256VerifierModule": sha256_verifier.Sha256VerifierModule{
58+
ForceSha256: true,
59+
ResultContainer: resultContainer,
60+
},
61+
},
62+
}
63+
64+
testenv.Run(t, &testenv.Config{
65+
RouterOptions: []core.Option{
66+
core.WithModulesConfig(cfg.Modules),
67+
core.WithCustomModules(&sha256_verifier.Sha256VerifierModule{}),
68+
},
69+
}, func(t *testing.T, xEnv *testenv.Environment) {
70+
res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
71+
Query: `query MyQuery { employees { id } }`,
72+
OperationName: json.RawMessage(`"MyQuery"`),
73+
})
74+
require.NoError(t, err)
75+
require.Equal(t, 200, res.Response.StatusCode)
76+
77+
require.NotEmpty(t, resultContainer.Sha256Result)
78+
require.Equal(t, "f037469b9c85bb28ae4c13e1d51c1f7e3333ecbe3c28b877c8659a52378f56c0", resultContainer.Sha256Result)
79+
})
80+
})
81+
82+
t.Run("verify different queries produces different Sha256Hashes", func(t *testing.T) {
83+
t.Parallel()
84+
85+
resultContainer := &sha256_verifier.ResultContainer{}
86+
87+
cfg := config.Config{
88+
Graph: config.Graph{},
89+
Modules: map[string]interface{}{
90+
"sha256VerifierModule": sha256_verifier.Sha256VerifierModule{
91+
ForceSha256: true,
92+
ResultContainer: resultContainer,
93+
},
94+
},
95+
}
96+
97+
testenv.Run(t, &testenv.Config{
98+
RouterOptions: []core.Option{
99+
core.WithModulesConfig(cfg.Modules),
100+
core.WithCustomModules(&sha256_verifier.Sha256VerifierModule{}),
101+
},
102+
}, func(t *testing.T, xEnv *testenv.Environment) {
103+
_, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
104+
Query: `query ConsistentQuery { employees { id } }`,
105+
OperationName: json.RawMessage(`"ConsistentQuery"`),
106+
})
107+
require.NoError(t, err)
108+
firstHash := resultContainer.Sha256Result
109+
require.NotEmpty(t, firstHash)
110+
111+
_, err = xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
112+
Query: `query ConsistentQuery { employees { id tag } }`,
113+
OperationName: json.RawMessage(`"ConsistentQuery"`),
114+
})
115+
require.NoError(t, err)
116+
secondHash := resultContainer.Sha256Result
117+
require.NotEmpty(t, secondHash)
118+
119+
require.NotEqual(t, firstHash, secondHash)
120+
})
121+
})
122+
123+
t.Run("verify the same query produces same Sha256Hash", func(t *testing.T) {
124+
t.Parallel()
125+
126+
resultContainer := &sha256_verifier.ResultContainer{}
127+
128+
cfg := config.Config{
129+
Graph: config.Graph{},
130+
Modules: map[string]interface{}{
131+
"sha256VerifierModule": sha256_verifier.Sha256VerifierModule{
132+
ForceSha256: true,
133+
ResultContainer: resultContainer,
134+
},
135+
},
136+
}
137+
138+
testenv.Run(t, &testenv.Config{
139+
RouterOptions: []core.Option{
140+
core.WithModulesConfig(cfg.Modules),
141+
core.WithCustomModules(&sha256_verifier.Sha256VerifierModule{}),
142+
},
143+
}, func(t *testing.T, xEnv *testenv.Environment) {
144+
_, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
145+
Query: `query ConsistentQuery { employees { id } }`,
146+
OperationName: json.RawMessage(`"ConsistentQuery"`),
147+
})
148+
require.NoError(t, err)
149+
firstHash := resultContainer.Sha256Result
150+
require.NotEmpty(t, firstHash)
151+
152+
_, err = xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
153+
Query: `query ConsistentQuery { employees { id } }`,
154+
OperationName: json.RawMessage(`"ConsistentQuery"`),
155+
})
156+
require.NoError(t, err)
157+
secondHash := resultContainer.Sha256Result
158+
require.NotEmpty(t, secondHash)
159+
160+
require.Equal(t, firstHash, secondHash, "Same query should produce the same SHA256 hash")
161+
})
162+
})
163+
164+
}

router/core/context.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,14 @@ type RequestContext interface {
131131
// SetAuthenticationScopes sets the scopes for the request on Authentication
132132
// If Authentication is not set, it will be initialized with the scopes
133133
SetAuthenticationScopes(scopes []string)
134+
134135
// SetCustomFieldValueRenderer overrides the default field value rendering behavior
135136
// This can be used, e.g. to obfuscate sensitive data in the response
136137
SetCustomFieldValueRenderer(renderer resolve.FieldValueRenderer)
138+
139+
// SetForceSha256Compute forces the computation of the Sha256Hash of the operation
140+
// This is useful if the Sha256Hash is needed in custom modules but not used anywhere else
141+
SetForceSha256Compute()
137142
}
138143

139144
var metricAttrsPool = sync.Pool{
@@ -263,6 +268,8 @@ type requestContext struct {
263268
expressionContext expr.Context
264269
// customFieldValueRenderer is used to override the default field value rendering behavior
265270
customFieldValueRenderer resolve.FieldValueRenderer
271+
// forceSha256Compute indicates whether the Sha256Hash of the operation should definitely be computed
272+
forceSha256Compute bool
266273
}
267274

268275
func (c *requestContext) SetCustomFieldValueRenderer(renderer resolve.FieldValueRenderer) {
@@ -462,6 +469,10 @@ func (c *requestContext) SetAuthenticationScopes(scopes []string) {
462469
auth.SetScopes(scopes)
463470
}
464471

472+
func (c *requestContext) SetForceSha256Compute() {
473+
c.forceSha256Compute = true
474+
}
475+
465476
type OperationContext interface {
466477
// Name is the name of the operation
467478
Name() string
@@ -475,6 +486,12 @@ type OperationContext interface {
475486
Variables() *astjson.Value
476487
// ClientInfo returns information about the client that initiated this operation
477488
ClientInfo() ClientInfo
489+
490+
// Sha256Hash returns the SHA256 hash of the original operation
491+
// It is important to note that this hash is not calculated just because this method has been called
492+
// and is only calculated based on other existing logic (such as if sha256Hash is used in expressions)
493+
Sha256Hash() string
494+
478495
// QueryPlanStats returns some statistics about the query plan for the operation
479496
// if called too early in request chain, it may be inaccurate for modules, using
480497
// in Middleware is recommended
@@ -576,6 +593,10 @@ func (o *operationContext) ClientInfo() ClientInfo {
576593
return *o.clientInfo
577594
}
578595

596+
func (o *operationContext) Sha256Hash() string {
597+
return o.sha256Hash
598+
}
599+
579600
type QueryPlanStats struct {
580601
TotalSubgraphFetches int
581602
SubgraphFetches map[string]int

router/core/graphql_prehandler.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -439,9 +439,9 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler {
439439
})
440440
}
441441

442-
func (h *PreHandler) shouldComputeOperationSha256(operationKit *OperationKit) bool {
442+
func (h *PreHandler) shouldComputeOperationSha256(operationKit *OperationKit, reqCtx *requestContext) bool {
443443
// If forced, always compute the hash
444-
if h.computeOperationSha256 {
444+
if h.computeOperationSha256 || reqCtx.forceSha256Compute {
445445
return true
446446
}
447447

@@ -523,7 +523,7 @@ func (h *PreHandler) handleOperation(w http.ResponseWriter, req *http.Request, v
523523
}
524524

525525
// Compute the operation sha256 hash as soon as possible for observability reasons
526-
if h.shouldComputeOperationSha256(operationKit) {
526+
if h.shouldComputeOperationSha256(operationKit, requestContext) {
527527
if err := operationKit.ComputeOperationSha256(); err != nil {
528528
return &httpGraphqlError{
529529
message: fmt.Sprintf("error hashing operation: %s", err),

0 commit comments

Comments
 (0)