Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow to conditionally block mutation via expressions #1480

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 168 additions & 19 deletions router-tests/block_operations_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package integration_test

import (
"bytes"
"encoding/json"
"github.com/stretchr/testify/require"
"github.com/wundergraph/cosmo/router-tests/testenv"
"github.com/wundergraph/cosmo/router/core"
Expand Down Expand Up @@ -32,7 +34,7 @@ func TestBlockOperations(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.BlockMutations = config.BlockMutationConfiguration{
securityConfiguration.BlockMutations = config.BlockOperationConfiguration{
Enabled: true,
}
},
Expand All @@ -49,12 +51,15 @@ func TestBlockOperations(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.BlockMutations = config.BlockMutationConfiguration{
securityConfiguration.BlockMutations = config.BlockOperationConfiguration{
Enabled: true,
Condition: "Request.Header.Get('graphql-client-name') == 'my-client'",
Condition: "request.header.Get('graphql-client-name') == 'my-client'",
}
},
}, func(t *testing.T, xEnv *testenv.Environment) {

// Positive test

res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Header: map[string][]string{
"graphql-client-name": {"my-client-different"},
Expand All @@ -64,6 +69,8 @@ func TestBlockOperations(t *testing.T) {
require.Equal(t, http.StatusOK, res.Response.StatusCode)
require.Equal(t, `{"data":{"updateEmployeeTag":{"id":1,"tag":"test"}}}`, res.Body)

// Negative test

res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Header: map[string][]string{
"graphql-client-name": {"my-client"},
Expand All @@ -75,6 +82,46 @@ func TestBlockOperations(t *testing.T) {
})
})

t.Run("should block operations by query match expression", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.BlockMutations = config.BlockOperationConfiguration{
Enabled: true,
Condition: "request.url.query.foo == 'bar'",
}
},
}, func(t *testing.T, xEnv *testenv.Environment) {

// Negative test

data, err := json.Marshal(testenv.GraphQLRequest{
Query: `mutation { updateEmployeeTag(id: 1, tag: "test") { id tag } }`,
})

require.NoError(t, err)
req, err := http.NewRequestWithContext(xEnv.Context, http.MethodPost, xEnv.GraphQLRequestURL(), bytes.NewReader(data))
require.NoError(t, err)

res, err := xEnv.MakeGraphQLRequestRaw(req)
require.NoError(t, err)

require.Equal(t, http.StatusOK, res.Response.StatusCode)
require.Equal(t, `{"data":{"updateEmployeeTag":{"id":1,"tag":"test"}}}`, res.Body)

// Positive test

req, err = http.NewRequestWithContext(xEnv.Context, http.MethodPost, xEnv.GraphQLRequestURL()+"?foo=bar", bytes.NewReader(data))
require.NoError(t, err)

res, err = xEnv.MakeGraphQLRequestRaw(req)
require.NoError(t, err)

require.Equal(t, http.StatusOK, res.Response.StatusCode)
require.Equal(t, `{"errors":[{"message":"operation type 'mutation' is blocked"}]}`, res.Body)
})
})

t.Run("should block operation by claim expression condition", func(t *testing.T) {
t.Parallel()

Expand All @@ -84,12 +131,15 @@ func TestBlockOperations(t *testing.T) {
core.WithAccessController(core.NewAccessController(authenticators, false)),
},
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.BlockMutations = config.BlockMutationConfiguration{
securityConfiguration.BlockMutations = config.BlockOperationConfiguration{
Enabled: true,
Condition: "'read:miscellaneous' in Request.Auth.Scopes",
Condition: "'read:miscellaneous' in request.auth.scopes && request.auth.isAuthenticated",
}
},
}, func(t *testing.T, xEnv *testenv.Environment) {

// Positive test

token, err := authServer.Token(map[string]any{
"scope": "write:fact read:miscellaneous read:all",
})
Expand All @@ -107,6 +157,8 @@ func TestBlockOperations(t *testing.T) {
require.NoError(t, err)
require.Equal(t, `{"errors":[{"message":"operation type 'mutation' is blocked"}]}`, string(data))

// Negative test

token, err = authServer.Token(map[string]any{
"scope": "write:fact read:all",
})
Expand All @@ -127,57 +179,153 @@ func TestBlockOperations(t *testing.T) {
})
})

t.Run("block non-persisted operations", func(t *testing.T) {
t.Parallel()
t.Run("block subscriptions", func(t *testing.T) {

t.Run("should block all subscriptions", func(t *testing.T) {
t.Parallel()

t.Run("allow", func(t *testing.T) {
testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.BlockNonPersistedOperations = config.BlockNonPersistedConfiguration{
securityConfiguration.BlockSubscriptions = config.BlockOperationConfiguration{
Enabled: true,
}
},
}, func(t *testing.T, xEnv *testenv.Environment) {

conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil)
err := conn.WriteJSON(&testenv.WebSocketMessage{
ID: "1",
Type: "subscribe",
Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`),
})
require.NoError(t, err)

var msg testenv.WebSocketMessage
err = conn.ReadJSON(&msg)
require.NoError(t, err)
require.Equal(t, "1", msg.ID)
require.Equal(t, "error", msg.Type)
require.Equal(t, `[{"message":"operation type 'subscription' is blocked"}]`, string(msg.Payload))
})
})

t.Run("should block subscriptions by header match expression", func(t *testing.T) {
t.Parallel()

testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.BlockSubscriptions = config.BlockOperationConfiguration{
Enabled: true,
Condition: "request.header.Get('graphql-client-name') == 'my-client'",
}
},
}, func(t *testing.T, xEnv *testenv.Environment) {

type currentTimePayload struct {
Data struct {
CurrentTime struct {
UnixTime float64 `json:"unixTime"`
Timestamp string `json:"timestamp"`
} `json:"currentTime"`
} `json:"data"`
}

// Positive test

conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil)
err := conn.WriteJSON(&testenv.WebSocketMessage{
ID: "1",
Type: "subscribe",
Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`),
})
require.NoError(t, err)

var msg testenv.WebSocketMessage
var payload currentTimePayload

err = conn.ReadJSON(&msg)
require.NoError(t, err)
require.Equal(t, "1", msg.ID)
require.Equal(t, "next", msg.Type)

err = json.Unmarshal(msg.Payload, &payload)
require.NoError(t, err)

require.NotEmpty(t, payload.Data.CurrentTime.UnixTime)
require.NotEmpty(t, payload.Data.CurrentTime.Timestamp)

_ = conn.Close()

// Negative test

header := make(http.Header)
header.Add("graphql-client-name", "my-client")
res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
OperationName: []byte(`"Employees"`),
Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "dc67510fb4289672bea757e862d6b00e83db5d3cbbcfb15260601b6f29bb2b8f"}}`),
Header: header,
conn = xEnv.InitGraphQLWebSocketConnection(header, nil, nil)
err = conn.WriteJSON(&testenv.WebSocketMessage{
ID: "1",
Type: "subscribe",
Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`),
})
require.NoError(t, err)
require.Equal(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, res.Body)

msg = testenv.WebSocketMessage{}
err = conn.ReadJSON(&msg)
require.NoError(t, err)
require.Equal(t, "1", msg.ID)
require.Equal(t, "error", msg.Type)
require.Equal(t, `[{"message":"operation type 'subscription' is blocked"}]`, string(msg.Payload))
})
})

t.Run("block", func(t *testing.T) {
t.Parallel()
})

t.Run("block non-persisted operations", func(t *testing.T) {
t.Parallel()

t.Run("should allow operations", func(t *testing.T) {
testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.BlockNonPersistedOperations = config.BlockNonPersistedConfiguration{
securityConfiguration.BlockNonPersistedOperations = config.BlockOperationConfiguration{
Enabled: true,
}
},
}, func(t *testing.T, xEnv *testenv.Environment) {

// Negative test

res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `mutation { updateEmployeeTag(id: 1, tag: "test") { id tag } }`,
})
require.Equal(t, http.StatusOK, res.Response.StatusCode)
require.Equal(t, res.Response.Header.Get("Content-Type"), "application/json")
require.Equal(t, `{"errors":[{"message":"non-persisted operation is blocked"}]}`, res.Body)

// Positive test

header := make(http.Header)
header.Add("graphql-client-name", "my-client")
res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
OperationName: []byte(`"Employees"`),
Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "dc67510fb4289672bea757e862d6b00e83db5d3cbbcfb15260601b6f29bb2b8f"}}`),
Header: header,
})
require.NoError(t, err)
require.Equal(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, res.Body)
})
})

t.Run("should block operation by header match expression", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.BlockNonPersistedOperations = config.BlockNonPersistedConfiguration{
securityConfiguration.BlockNonPersistedOperations = config.BlockOperationConfiguration{
Enabled: true,
Condition: "Request.Header.Get('graphql-client-name') == 'my-client'",
Condition: "request.header.Get('graphql-client-name') == 'my-client'",
}
},
}, func(t *testing.T, xEnv *testenv.Environment) {

// Negative test
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Header: map[string][]string{
"graphql-client-name": {"my-client-different"},
Expand All @@ -187,6 +335,7 @@ func TestBlockOperations(t *testing.T) {
require.Equal(t, http.StatusOK, res.Response.StatusCode)
require.Equal(t, `{"data":{"updateEmployeeTag":{"id":1,"tag":"test"}}}`, res.Body)

// Positive test
res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Header: map[string][]string{
"graphql-client-name": {"my-client"},
Expand Down
4 changes: 2 additions & 2 deletions router-tests/kafka_events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ func TestKafkaEvents(t *testing.T) {
KafkaSeeds: seeds,
EnableKafka: true,
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.BlockSubscriptions = config.BlockSubscriptionConfiguration{
securityConfiguration.BlockSubscriptions = config.BlockOperationConfiguration{
Enabled: true,
}
},
Expand Down Expand Up @@ -660,7 +660,7 @@ func TestKafkaEvents(t *testing.T) {
KafkaSeeds: seeds,
EnableKafka: true,
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.BlockSubscriptions = config.BlockSubscriptionConfiguration{
securityConfiguration.BlockSubscriptions = config.BlockOperationConfiguration{
Enabled: true,
}
},
Expand Down
4 changes: 2 additions & 2 deletions router-tests/nats_events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ func TestNatsEvents(t *testing.T) {
testenv.Run(t, &testenv.Config{
EnableNats: true,
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.BlockSubscriptions = config.BlockSubscriptionConfiguration{
securityConfiguration.BlockSubscriptions = config.BlockOperationConfiguration{
Enabled: true,
}
},
Expand Down Expand Up @@ -625,7 +625,7 @@ func TestNatsEvents(t *testing.T) {
testenv.Run(t, &testenv.Config{
EnableNats: true,
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.BlockSubscriptions = config.BlockSubscriptionConfiguration{
securityConfiguration.BlockSubscriptions = config.BlockOperationConfiguration{
Enabled: true,
}
},
Expand Down
10 changes: 5 additions & 5 deletions router-tests/testenv/testenv.go
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,7 @@ func (e *Environment) MakeGraphQLRequestWithContext(ctx context.Context, request
req.Header = request.Header
}
req.Header.Set("Accept-Encoding", "identity")
return e.makeGraphQLRequest(req)
return e.MakeGraphQLRequestRaw(req)
}

func (e *Environment) MakeGraphQLRequestWithHeaders(request GraphQLRequest, headers map[string]string) (*TestResponse, error) {
Expand All @@ -1166,11 +1166,10 @@ func (e *Environment) MakeGraphQLRequestWithHeaders(request GraphQLRequest, head
if request.Header != nil {
req.Header = request.Header
}
req.Header.Set("Accept-Encoding", "identity")
for k, v := range headers {
req.Header.Set(k, v)
}
return e.makeGraphQLRequest(req)
return e.MakeGraphQLRequestRaw(req)
}

func (e *Environment) MakeGraphQLRequestOverGET(request GraphQLRequest) (*TestResponse, error) {
Expand All @@ -1179,7 +1178,7 @@ func (e *Environment) MakeGraphQLRequestOverGET(request GraphQLRequest) (*TestRe
return nil, err
}

return e.makeGraphQLRequest(req)
return e.MakeGraphQLRequestRaw(req)
}

func (e *Environment) newGraphQLRequestOverGET(baseURL string, request GraphQLRequest) (*http.Request, error) {
Expand Down Expand Up @@ -1210,7 +1209,8 @@ func (e *Environment) newGraphQLRequestOverGET(baseURL string, request GraphQLRe
return req, nil
}

func (e *Environment) makeGraphQLRequest(request *http.Request) (*TestResponse, error) {
func (e *Environment) MakeGraphQLRequestRaw(request *http.Request) (*TestResponse, error) {
request.Header.Set("Accept-Encoding", "identity")
resp, err := e.RouterClient.Do(request)
if err != nil {
return nil, err
Expand Down
Loading
Loading