Skip to content
This repository was archived by the owner on Dec 19, 2023. It is now read-only.

Feature/fix test subscription reset #472

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.springframework.web.util.UriBuilderFactory;

import javax.websocket.ClientEndpointConfig;
import javax.websocket.CloseReason;
import javax.websocket.ContainerProvider;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
Expand All @@ -33,8 +34,8 @@
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Predicate;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.fail;
Expand All @@ -46,29 +47,42 @@
@Slf4j
public class GraphQLTestSubscription {

private static final WebSocketContainer WEB_SOCKET_CONTAINER = ContainerProvider.getWebSocketContainer();
private static final int SLEEP_INTERVAL_MS = 100;
private static final int ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT = 6000000;
private static final int ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT = 60000;
private static final AtomicInteger ID_COUNTER = new AtomicInteger(1);
private static final UriBuilderFactory URI_BUILDER_FACTORY = new DefaultUriBuilderFactory();
private static final Object STATE_LOCK = new Object();

@Getter
private Session session;

@Getter
private boolean initialized = false;
@Getter
private boolean acknowledged = false;
@Getter
private boolean started = false;
@Getter
private boolean stopped = false;
private SubscriptionState state = SubscriptionState.builder()
.id(ID_COUNTER.incrementAndGet())
.build();

private final Environment environment;
private final ObjectMapper objectMapper;
private final String subscriptionPath;

private final Queue<GraphQLResponse> responses = new ConcurrentLinkedQueue<>();
private int id = ID_COUNTER.getAndIncrement();
public boolean isInitialized() {
return state.isInitialized();
}

public boolean isAcknowledged() {
return state.isAcknowledged();
}

public boolean isStarted() {
return state.isStarted();
}

public boolean isStopped() {
return state.isStopped();
}

public boolean isCompleted() {
return state.isCompleted();
}

/**
* Sends the "connection_init" message to the GraphQL server without a payload.
Expand All @@ -85,7 +99,7 @@ public GraphQLTestSubscription init() {
* @return self reference
*/
public GraphQLTestSubscription init(@Nullable final Object payload) {
if (initialized) {
if (isInitialized()) {
fail("Subscription already initialized.");
}
try {
Expand All @@ -97,8 +111,9 @@ public GraphQLTestSubscription init(@Nullable final Object payload) {
message.put("type", "connection_init");
message.set("payload", getFinalPayload(payload));
sendMessage(message);
initialized = true;
state.setInitialized(true);
awaitAcknowledgement();
log.debug("Subscription successfully initialized.");
return this;
}

Expand All @@ -120,20 +135,21 @@ public GraphQLTestSubscription start(@NonNull final String graphQLResource) {
* @return self reference
*/
public GraphQLTestSubscription start(@NonNull final String graphGLResource, @Nullable final Object variables) {
if (!initialized) {
if (!isInitialized()) {
init();
}
if (started) {
if (isStarted()) {
fail("Start message already sent. To start a new subscription, please call reset first.");
}
started = true;
state.setStarted(true);
ObjectNode payload = objectMapper.createObjectNode();
payload.put("query", loadQuery(graphGLResource));
payload.set("variables", getFinalPayload(variables));
ObjectNode message = objectMapper.createObjectNode();
message.put("type", "start");
message.put("id", id);
message.put("id", state.getId());
message.set("payload", payload);
log.debug("Sending start message.");
sendMessage(message);
return this;
}
Expand All @@ -143,24 +159,25 @@ public GraphQLTestSubscription start(@NonNull final String graphGLResource, @Nul
* @return self reference
*/
public GraphQLTestSubscription stop() {
if (!initialized) {
if (!isInitialized()) {
fail("Subscription not yet initialized.");
}
if (stopped) {
if (isStopped()) {
fail("Subscription already stopped.");
}
final ObjectNode message = objectMapper.createObjectNode();
message.put("type", "stop");
message.put("id", id);
message.put("id", state.getId());
log.debug("Sending stop message.");
sendMessage(message);
stopped = true;
try {
log.debug("Closing web socket session.");
session.close();
session = null;
awaitStop();
log.debug("Web socket session closed.");
} catch (IOException e) {
fail("Could not close web socket session", e);
}
log.debug("Subscription stopped.");
return this;
}

Expand All @@ -169,20 +186,12 @@ public GraphQLTestSubscription stop() {
* ensure that the bean is reusable between tests.
*/
public void reset() {
if (initialized && !stopped) {
if (isInitialized() && !isStopped()) {
stop();
}
if (stopped) {
id = ID_COUNTER.getAndIncrement();
}
initialized = false;
started = false;
stopped = false;
acknowledged = false;
state = SubscriptionState.builder().id(ID_COUNTER.incrementAndGet()).build();
session = null;
synchronized (responses) {
responses.clear();
}
log.debug("Test subscription client reset.");
}

/**
Expand Down Expand Up @@ -264,15 +273,15 @@ public List<GraphQLResponse> awaitAndGetNextResponses(
final int numExpectedResponses,
final boolean stopAfter
) {
if (!started) {
if (!isStarted()) {
fail("Start message not sent. Please send start message first.");
}
if (stopped) {
if (isStopped()) {
fail("Subscription already stopped. Forgot to call reset after test case?");
}
int elapsedTime = 0;
while (
((responses.size() < numExpectedResponses) || numExpectedResponses <= 0)
((state.getResponses().size() < numExpectedResponses) || numExpectedResponses <= 0)
&& elapsedTime < timeout
) {
try {
Expand All @@ -282,10 +291,11 @@ public List<GraphQLResponse> awaitAndGetNextResponses(
fail("Test execution error - Thread.sleep failed.", e);
}
}
synchronized (responses) {
if (stopAfter) {
stop();
}
if (stopAfter) {
stop();
}
synchronized (STATE_LOCK) {
final Queue<GraphQLResponse> responses = state.getResponses();
int responsesToPoll = responses.size();
if (numExpectedResponses == 0) {
assertThat(responses)
Expand Down Expand Up @@ -336,16 +346,15 @@ public GraphQLTestSubscription waitAndExpectNoResponse(final int timeToWait) {
* @return the remaining responses.
*/
public List<GraphQLResponse> getRemainingResponses() {
if (!stopped) {
if (!isStopped()) {
fail("getRemainingResponses should only be called after the subscription was stopped.");
}
final ArrayList<GraphQLResponse> graphQLResponses = new ArrayList<>(responses);
responses.clear();
final ArrayList<GraphQLResponse> graphQLResponses = new ArrayList<>(state.getResponses());
state.getResponses().clear();
return graphQLResponses;
}

private void initClient() throws Exception {
final WebSocketContainer webSocketContainer = ContainerProvider.getWebSocketContainer();
final String port = environment.getProperty("local.server.port");
final URI uri = URI_BUILDER_FACTORY.builder().scheme("ws").host("localhost").port(port).path(subscriptionPath)
.build();
Expand All @@ -355,8 +364,8 @@ private void initClient() throws Exception {
.build();
clientEndpointConfig.getUserProperties().put("org.apache.tomcat.websocket.IO_TIMEOUT_MS",
String.valueOf(ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT));
session = webSocketContainer.connectToServer(TestWebSocketClient.class, clientEndpointConfig, uri);
session.addMessageHandler(new TestMessageHandler());
session = WEB_SOCKET_CONTAINER.connectToServer(new TestWebSocketClient(state), clientEndpointConfig, uri);
session.addMessageHandler(new TestMessageHandler(objectMapper, state));
}

private JsonNode getFinalPayload(final Object variables) {
Expand Down Expand Up @@ -384,8 +393,16 @@ private void sendMessage(final Object message) {
}

private void awaitAcknowledgement() {
await(GraphQLTestSubscription::isAcknowledged, "Connection was not acknowledged by the GraphQL server.");
}

private void awaitStop() {
await(GraphQLTestSubscription::isStopped, "Connection was not stopped in time.");
}

private void await(final Predicate<GraphQLTestSubscription> condition, final String timeoutDescription) {
int elapsedTime = 0;
while(!acknowledged && elapsedTime < ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT) {
while(!condition.test(this) && elapsedTime < ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT) {
try {
Thread.sleep(SLEEP_INTERVAL_MS);
elapsedTime += SLEEP_INTERVAL_MS;
Expand All @@ -394,31 +411,45 @@ private void awaitAcknowledgement() {
}
}

if (!acknowledged) {
fail("Timeout: Connection was not acknowledged by the GraphQL server.");
if (!condition.test(this)) {
fail(String.format("Timeout: " + timeoutDescription));
}
}

class TestMessageHandler implements MessageHandler.Whole<String> {
@RequiredArgsConstructor
static class TestMessageHandler implements MessageHandler.Whole<String> {

private final ObjectMapper objectMapper;
private final SubscriptionState state;

@Override
public void onMessage(final String message) {
try {
log.debug("Received message from web socket: {}", message);
final JsonNode jsonNode = objectMapper.readTree(message);
final JsonNode typeNode = jsonNode.get("type");
assertThat(typeNode.isNull()).as("GraphQL messages should have a type field.").isFalse();
assertThat(typeNode).as("GraphQL messages should have a type field.").isNotNull();
assertThat(typeNode.isNull()).as("GraphQL messages type should not be null.").isFalse();
final String type = typeNode.asText();
if (type.equals("connection_ack")) {
acknowledged = true;
if (type.equals("complete")) {
state.setCompleted(true);
log.debug("Subscription completed.");
} else if (type.equals("connection_ack")) {
state.setAcknowledged(true);
log.debug("WebSocket connection acknowledged by the GraphQL Server.");
} else if (type.equals("data") || type.equals("error")) {
final JsonNode payload = jsonNode.get("payload");
assertThat(payload).as("Data/error messages must have a payload.").isNotNull();
final String payloadString = objectMapper.writeValueAsString(payload);
final GraphQLResponse graphQLResponse = new GraphQLResponse(ResponseEntity.ok(payloadString),
objectMapper);
synchronized (responses) {
responses.add(graphQLResponse);
if (state.isStopped() || state.isCompleted()) {
log.debug("Response discarded because subscription was stopped or completed in the meanwhile.");
} else {
synchronized (STATE_LOCK) {
state.getResponses().add(graphQLResponse);
}
log.debug("New response recorded.");
}
}
} catch (JsonProcessingException e) {
Expand All @@ -427,11 +458,21 @@ public void onMessage(final String message) {
}
}

public static class TestWebSocketClient extends Endpoint {
@RequiredArgsConstructor
private static class TestWebSocketClient extends Endpoint {

private final SubscriptionState state;

@Override
public void onOpen(final Session session, final EndpointConfig config) {
log.debug("Connection established.");
}

@Override
public void onClose(Session session, CloseReason closeReason) {
super.onClose(session, closeReason);
state.setStopped(true);
}
}

static class TestWebSocketClientConfigurator extends ClientEndpointConfig.Configurator {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.graphql.spring.boot.test;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;

@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
class SubscriptionState {

private boolean initialized;
private boolean acknowledged;
private boolean started;
private boolean stopped;
private boolean completed;
@Builder.Default
private Queue<GraphQLResponse> responses = new ConcurrentLinkedQueue<>();
private int id;
}
Loading