Skip to content

Support creating the new context map #1292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ package com.expediagroup.graphql.generator.execution
* Marker interface to indicate that the implementing class should be considered
* as the GraphQL context. This means the implementing class will not appear in the schema.
*/
@Deprecated("The generic context object is deprecated in favor of the context map")
interface GraphQLContext
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,12 @@ interface GraphQLContextFactory<out Context : GraphQLContext, Request> {
* Generate GraphQL context based on the incoming request and the corresponding response.
* If no context should be generated and used in the request, return null.
*/
@Deprecated("The generic context object is deprecated in favor of the context map")
suspend fun generateContext(request: Request): Context?

/**
* GraphQL Java 17 has a new context map instead of a generic object. Implementing this method
* will set the context map in the execution input.
*/
suspend fun generateContextMap(request: Request): Map<*, Any>? = null
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ open class GraphQLRequestHandler(
* This should only be used for queries and mutations.
* Subscriptions require more specific server logic and will need to be handled separately.
*/
open suspend fun executeRequest(request: GraphQLRequest, context: GraphQLContext? = null): GraphQLResponse<*> {
open suspend fun executeRequest(request: GraphQLRequest, context: GraphQLContext? = null, graphQLContext: Map<*, Any>? = null): GraphQLResponse<*> {
// We should generate a new registry for every request
val dataLoaderRegistry = dataLoaderRegistryFactory?.generate()
val executionInput = request.toExecutionInput(context, dataLoaderRegistry)
val executionInput = request.toExecutionInput(context, dataLoaderRegistry, graphQLContext)

return try {
graphQL.executeAsync(executionInput).await().toGraphQLResponse()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,13 @@ open class GraphQLServer<Request>(

return if (graphQLRequest != null) {
val context = contextFactory.generateContext(request)
val graphQLContext = contextFactory.generateContextMap(request)

when (graphQLRequest) {
is GraphQLRequest -> requestHandler.executeRequest(graphQLRequest, context)
is GraphQLRequest -> requestHandler.executeRequest(graphQLRequest, context, graphQLContext)
is GraphQLBatchRequest -> GraphQLBatchResponse(
graphQLRequest.requests.map {
requestHandler.executeRequest(it, context)
requestHandler.executeRequest(it, context, graphQLContext)
}
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@ import org.dataloader.DataLoaderRegistry
/**
* Convert the common [GraphQLRequest] to the execution input used by graphql-java
*/
fun GraphQLRequest.toExecutionInput(graphQLContext: Any? = null, dataLoaderRegistry: DataLoaderRegistry? = null): ExecutionInput =
fun GraphQLRequest.toExecutionInput(graphQLContext: Any? = null, dataLoaderRegistry: DataLoaderRegistry? = null, graphQLContextMap: Map<*, Any>? = null): ExecutionInput =
ExecutionInput.newExecutionInput()
.query(this.query)
.operationName(this.operationName)
.variables(this.variables ?: emptyMap())
.context(graphQLContext)
.also { builder ->
graphQLContext?.let { builder.context(it) }
graphQLContextMap?.let { builder.graphQLContext(it) }
}
.dataLoaderRegistry(dataLoaderRegistry ?: DataLoaderRegistry())
.build()
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import com.expediagroup.graphql.server.types.GraphQLRequest
import graphql.ExecutionInput
import graphql.GraphQL
import graphql.execution.AbortExecutionException
import graphql.schema.DataFetchingEnvironment
import graphql.schema.GraphQLSchema
import io.mockk.every
import io.mockk.mockk
Expand Down Expand Up @@ -122,6 +123,22 @@ class GraphQLRequestHandlerTest {
assertNull(response.extensions)
}

@Test
@ExperimentalCoroutinesApi
fun `execute graphQL query with graphql context`() = runBlockingTest {
val context = mapOf("foo" to "JUNIT context value")
val request = GraphQLRequest(query = "query { graphQLContextualValue }")

val response = graphQLRequestHandler.executeRequest(request, context = null, graphQLContext = context)
assertNotNull(response.data as? Map<*, *>) { data ->
assertNotNull(data["graphQLContextualValue"] as? String) { msg ->
assertEquals("JUNIT context value", msg)
}
}
assertNull(response.errors)
assertNull(response.extensions)
}

@Test
@ExperimentalCoroutinesApi
fun `execute graphQL query throwing uncaught exception`() = runBlockingTest {
Expand Down Expand Up @@ -164,6 +181,8 @@ class GraphQLRequestHandlerTest {
fun alwaysThrows(): String = throw Exception("JUNIT Failure")

fun contextualValue(context: MyContext): String = context.value ?: "default"

fun graphQLContextualValue(dataFetchingEnvironment: DataFetchingEnvironment): String = dataFetchingEnvironment.graphQlContext.get("foo") ?: "default"
}

data class MyContext(val value: String? = null) : GraphQLContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ class GraphQLServerTest {
}
val mockContextFactory = mockk<GraphQLContextFactory<MockContext, MockHttpRequest>> {
coEvery { generateContext(any()) } returns MockContext()
coEvery { generateContextMap(any()) } returns mapOf("foo" to 1)
}
val mockHandler = mockk<GraphQLRequestHandler> {
coEvery { executeRequest(any(), any()) } returns mockk()
coEvery { executeRequest(any(), any(), any()) } returns mockk()
}

val server = GraphQLServer(mockParser, mockContextFactory, mockHandler)
Expand All @@ -52,7 +53,7 @@ class GraphQLServerTest {
coVerify(exactly = 1) {
mockParser.parseRequest(any())
mockContextFactory.generateContext(any())
mockHandler.executeRequest(any(), any())
mockHandler.executeRequest(any(), any(), any())
}
}

Expand All @@ -63,9 +64,34 @@ class GraphQLServerTest {
}
val mockContextFactory = mockk<GraphQLContextFactory<MockContext, MockHttpRequest>> {
coEvery { generateContext(any()) } returns null
coEvery { generateContextMap(any()) } returns mapOf(1 to "foo")
}
val mockHandler = mockk<GraphQLRequestHandler> {
coEvery { executeRequest(any(), null, any()) } returns mockk()
}

val server = GraphQLServer(mockParser, mockContextFactory, mockHandler)

runBlockingTest { server.execute(mockk()) }

coVerify(exactly = 1) {
mockParser.parseRequest(any())
mockContextFactory.generateContext(any())
mockHandler.executeRequest(any(), null, any())
}
}

@Test
fun `null graphQL context is used and passed to the request handler`() {
val mockParser = mockk<GraphQLRequestParser<MockHttpRequest>> {
coEvery { parseRequest(any()) } returns mockk<GraphQLRequest>()
}
val mockContextFactory = mockk<GraphQLContextFactory<MockContext, MockHttpRequest>> {
coEvery { generateContext(any()) } returns null
coEvery { generateContextMap(any()) } returns null
}
val mockHandler = mockk<GraphQLRequestHandler> {
coEvery { executeRequest(any(), null) } returns mockk()
coEvery { executeRequest(any(), null, any()) } returns mockk()
}

val server = GraphQLServer(mockParser, mockContextFactory, mockHandler)
Expand All @@ -75,7 +101,7 @@ class GraphQLServerTest {
coVerify(exactly = 1) {
mockParser.parseRequest(any())
mockContextFactory.generateContext(any())
mockHandler.executeRequest(any(), null)
mockHandler.executeRequest(any(), any(), null)
}
}

Expand All @@ -86,6 +112,7 @@ class GraphQLServerTest {
}
val mockContextFactory = mockk<GraphQLContextFactory<MockContext, MockHttpRequest>> {
coEvery { generateContext(any()) } returns MockContext()
coEvery { generateContextMap(any()) } returns null
}
val mockHandler = mockk<GraphQLRequestHandler> {
coEvery { executeRequest(any(), any()) } returns mockk()
Expand All @@ -100,7 +127,7 @@ class GraphQLServerTest {
}
coVerify(exactly = 0) {
mockContextFactory.generateContext(any())
mockHandler.executeRequest(any(), any())
mockHandler.executeRequest(any(), any(), any())
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import org.dataloader.DataLoaderRegistry
import org.junit.jupiter.api.Test
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertNull

class RequestExtensionsKtTest {

Expand All @@ -31,7 +30,7 @@ class RequestExtensionsKtTest {
val request = GraphQLRequest(query = "query { whatever }")
val executionInput = request.toExecutionInput()
assertEquals(request.query, executionInput.query)
assertNull(executionInput.context)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it is not set, then graphql-java defaults to empty object. This is a much better validation regardless of the new map

assertNotNull(executionInput.variables)
assertNotNull(executionInput.dataLoaderRegistry)
}

Expand Down Expand Up @@ -67,4 +66,13 @@ class RequestExtensionsKtTest {
assertEquals(request.query, executionInput.query)
assertEquals(dataLoaderRegistry, executionInput.dataLoaderRegistry)
}

@Test
fun `verify can convert request with context map to execution input`() {
val request = GraphQLRequest(query = "query { whatever }")
val context = mapOf("foo" to 1)

val executionInput = request.toExecutionInput(graphQLContextMap = context)
assertEquals(1, executionInput.graphQLContext.get("foo"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,45 @@ interface ApolloSubscriptionHooks {
* You can reject the connection by throwing an exception.
* If you need to forward state to execution, update and return the [GraphQLContext].
*/
@Deprecated("The generic context object is deprecated in favor of the context map")
fun onConnect(
connectionParams: Map<String, String>,
session: WebSocketSession,
graphQLContext: GraphQLContext?
): GraphQLContext? = graphQLContext

/**
* Allows validation of connectionParams prior to starting the connection.
* You can reject the connection by throwing an exception.
* If you need to forward state to execution, update and return the context map.
*/
fun onConnectWithContext(
connectionParams: Map<String, String>,
session: WebSocketSession,
graphQLContext: Map<*, Any>?
): Map<*, Any>? = graphQLContext

/**
* Called when the client executes a GraphQL operation.
* The context can not be updated here, it is read only.
*/
@Deprecated("The generic context object is deprecated in favor of the context map")
fun onOperation(
operationMessage: SubscriptionOperationMessage,
session: WebSocketSession,
graphQLContext: GraphQLContext?
): Unit = Unit

/**
* Called when the client executes a GraphQL operation.
* The context can not be updated here, it is read only.
*/
fun onOperationWithContext(
operationMessage: SubscriptionOperationMessage,
session: WebSocketSession,
graphQLContext: Map<*, Any>?
): Unit = Unit

/**
* Called when client's unsubscribes
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,10 @@ class ApolloSubscriptionProtocolHandler(
session: WebSocketSession
): Flux<SubscriptionOperationMessage> {
val context = sessionState.getContext(session)
val graphQLContext = sessionState.getGraphQLContext(session)

subscriptionHooks.onOperation(operationMessage, session, context)
subscriptionHooks.onOperationWithContext(operationMessage, session, graphQLContext)

if (operationMessage.id == null) {
logger.error("GraphQL subscription operation id is required")
Expand All @@ -130,7 +132,7 @@ class ApolloSubscriptionProtocolHandler(

try {
val request = objectMapper.convertValue<GraphQLRequest>(payload)
return subscriptionHandler.executeSubscription(request, context)
return subscriptionHandler.executeSubscription(request, context, graphQLContext)
.asFlux()
.map {
if (it.errors?.isNotEmpty() == true) {
Expand Down Expand Up @@ -164,8 +166,11 @@ class ApolloSubscriptionProtocolHandler(
runBlocking {
val connectionParams = castToMapOfStringString(operationMessage.payload)
val context = contextFactory.generateContext(session)
val onConnect = subscriptionHooks.onConnect(connectionParams, session, context)
sessionState.saveContext(session, onConnect)
val graphQLContext = contextFactory.generateContextMap(session)
val onConnectContext = subscriptionHooks.onConnect(connectionParams, session, context)
val onConnectGraphQLContext = subscriptionHooks.onConnectWithContext(connectionParams, session, graphQLContext)
sessionState.saveContext(session, onConnectContext)
sessionState.saveContextMap(session, onConnectGraphQLContext)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ internal class ApolloSubscriptionSessionState {
// The context is saved by web socket session id
private val cachedContext = ConcurrentHashMap<String, GraphQLContext>()

// The graphQL context is saved by web socket session id
private val cachedGraphQLContext = ConcurrentHashMap<String, Map<*, Any>>()

/**
* Save the context created from the factory and possibly updated in the onConnect hook.
* This allows us to include some intial state to be used when handling all the messages.
Expand All @@ -45,11 +48,27 @@ internal class ApolloSubscriptionSessionState {
}
}

/**
* Save the context created from the factory and possibly updated in the onConnect hook.
* This allows us to include some intial state to be used when handling all the messages.
* This will be removed in [terminateSession].
*/
fun saveContextMap(session: WebSocketSession, graphQLContext: Map<*, Any>?) {
if (graphQLContext != null) {
cachedGraphQLContext[session.id] = graphQLContext
}
}

/**
* Return the context for this session.
*/
fun getContext(session: WebSocketSession): GraphQLContext? = cachedContext[session.id]

/**
* Return the graphQL context for this session.
*/
fun getGraphQLContext(session: WebSocketSession): Map<*, Any>? = cachedGraphQLContext[session.id]

/**
* Save the session that is sending keep alive messages.
* This will override values without cancelling the subscription so it is the responsibility of the consumer to cancel.
Expand Down Expand Up @@ -122,6 +141,7 @@ internal class ApolloSubscriptionSessionState {
activeOperations[session.id]?.forEach { (_, subscription) -> subscription.cancel() }
activeOperations.remove(session.id)
cachedContext.remove(session.id)
cachedGraphQLContext.remove(session.id)
activeKeepAliveSessions[session.id]?.cancel()
activeKeepAliveSessions.remove(session.id)
session.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ open class SpringGraphQLSubscriptionHandler(
private val dataLoaderRegistryFactory: DataLoaderRegistryFactory? = null
) {

fun executeSubscription(graphQLRequest: GraphQLRequest, graphQLContext: GraphQLContext?): Flow<GraphQLResponse<*>> {
fun executeSubscription(graphQLRequest: GraphQLRequest, graphQLContext: GraphQLContext?, graphQLContextMap: Map<*, Any>? = null): Flow<GraphQLResponse<*>> {
val dataLoaderRegistry = dataLoaderRegistryFactory?.generate()
val input = graphQLRequest.toExecutionInput(graphQLContext, dataLoaderRegistry)
val input = graphQLRequest.toExecutionInput(graphQLContext, dataLoaderRegistry, graphQLContextMap)

return graphQL.execute(input)
.getData<Flow<ExecutionResult>>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class SubscriptionConfigurationTest {

@Bean
fun subscriptionHandler(): SpringGraphQLSubscriptionHandler = mockk {
every { executeSubscription(any(), any()) } returns flowOf()
every { executeSubscription(any(), any(), any()) } returns flowOf()
}

@Bean
Expand Down
Loading