Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,5 @@ nbactions.xml
# Private Claude config
.claude/
.serena/
.bob/
claudedocs
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,32 @@
import java.util.function.Consumer;
import java.util.logging.Logger;

import io.a2a.client.transport.spi.sse.AbstractSSEEventListener;
import io.a2a.grpc.StreamResponse;
import io.a2a.grpc.utils.JSONRPCUtils;
import io.a2a.grpc.utils.ProtoUtils;
import io.a2a.jsonrpc.common.json.JsonProcessingException;
import io.a2a.spec.A2AError;
import io.a2a.spec.StreamingEventKind;
import io.a2a.spec.Task;
import io.a2a.spec.TaskState;
import io.a2a.spec.TaskStatusUpdateEvent;
import org.jspecify.annotations.Nullable;

public class SSEEventListener {
/**
* JSON-RPC transport implementation of SSE event listener.
* Handles parsing of JSON-RPC formatted messages from SSE streams.
*/
public class SSEEventListener extends AbstractSSEEventListener {

private static final Logger log = Logger.getLogger(SSEEventListener.class.getName());
private final Consumer<StreamingEventKind> eventHandler;
private final @Nullable
Consumer<Throwable> errorHandler;
private volatile boolean completed = false;

public SSEEventListener(Consumer<StreamingEventKind> eventHandler,
@Nullable Consumer<Throwable> errorHandler) {
this.eventHandler = eventHandler;
this.errorHandler = errorHandler;
super(eventHandler, errorHandler);
}

@Override
public void onMessage(String message, @Nullable Future<Void> completableFuture) {
handleMessage(message, completableFuture);
}

public void onError(Throwable throwable, @Nullable Future<Void> future) {
if (errorHandler != null) {
errorHandler.accept(throwable);
}
if (future != null) {
future.cancel(true); // close SSE channel
}
parseAndHandleMessage(message, completableFuture);
}

public void onComplete() {
Expand All @@ -52,40 +42,30 @@ public void onComplete() {

// Signal normal stream completion (null error means successful completion)
log.fine("SSEEventListener.onComplete() called - signaling successful stream completion");
if (errorHandler != null) {
if (getErrorHandler() != null) {
log.fine("Calling errorHandler.accept(null) to signal successful completion");
errorHandler.accept(null);
getErrorHandler().accept(null);
} else {
log.warning("errorHandler is null, cannot signal completion");
}
}

private void handleMessage(String message, @Nullable Future<Void> future) {
/**
* Parses a JSON-RPC message and delegates to the base class for event handling.
*
* @param message The raw JSON-RPC message string
* @param future Optional future for controlling the SSE connection
*/
private void parseAndHandleMessage(String message, @Nullable Future<Void> future) {
try {
StreamResponse response = JSONRPCUtils.parseResponseEvent(message);

StreamingEventKind event = ProtoUtils.FromProto.streamingEventKind(response);
eventHandler.accept(event);

// Client-side auto-close on final events to prevent connection leaks
// Handles both TaskStatusUpdateEvent and Task objects with final states
// This covers late subscriptions to completed tasks and ensures no connection leaks
boolean shouldClose = false;
if (event instanceof TaskStatusUpdateEvent tue && tue.isFinal()) {
shouldClose = true;
} else if (event instanceof Task task) {
TaskState state = task.status().state();
if (state.isFinal()) {
shouldClose = true;
}
}

if (shouldClose && future != null) {
future.cancel(true); // close SSE channel
}

// Delegate to base class for common event handling and auto-close logic
handleEvent(event, future);
} catch (A2AError error) {
if (errorHandler != null) {
errorHandler.accept(error);
if (getErrorHandler() != null) {
getErrorHandler().accept(error);
}
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,22 @@ public void testFinalTaskStatusUpdateEventCancels() {
error -> {}
);

// Parse the message event JSON
String eventData = JsonStreamingMessages.STREAMING_STATUS_UPDATE_EVENT_FINAL.substring(
JsonStreamingMessages.STREAMING_STATUS_UPDATE_EVENT_FINAL.indexOf("{"));

// Call onMessage with a cancellable future
CancelCapturingFuture future = new CancelCapturingFuture();
listener.onMessage(eventData, future);

// Verify the event was received and processed
assertNotNull(receivedEvent.get());
assertTrue(receivedEvent.get() instanceof TaskStatusUpdateEvent);
TaskStatusUpdateEvent received = (TaskStatusUpdateEvent) receivedEvent.get();
assertTrue(received.isFinal());

// Verify the future was cancelled (auto-close on final event)
assertTrue(future.cancelHandlerCalled);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import io.a2a.client.http.A2AHttpClient;
import io.a2a.client.http.A2AHttpClientFactory;
import io.a2a.client.http.A2AHttpResponse;
import io.a2a.client.transport.rest.sse.RestSSEEventListener;
import io.a2a.client.transport.rest.sse.SSEEventListener;
import io.a2a.client.transport.spi.ClientTransport;
import io.a2a.client.transport.spi.interceptors.ClientCallContext;
import io.a2a.client.transport.spi.interceptors.ClientCallInterceptor;
Expand Down Expand Up @@ -110,7 +110,7 @@ public void sendMessageStreaming(MessageSendParams messageSendParams, Consumer<S
io.a2a.grpc.SendMessageRequest.Builder builder = io.a2a.grpc.SendMessageRequest.newBuilder(ProtoUtils.ToProto.sendMessageRequest(messageSendParams));
PayloadAndHeaders payloadAndHeaders = applyInterceptors(SEND_STREAMING_MESSAGE_METHOD, builder, agentCard, context);
AtomicReference<CompletableFuture<Void>> ref = new AtomicReference<>();
RestSSEEventListener sseEventListener = new RestSSEEventListener(eventConsumer, errorConsumer);
SSEEventListener sseEventListener = new SSEEventListener(eventConsumer, errorConsumer);
try {
A2AHttpClient.PostBuilder postBuilder = createPostBuilder(Utils.buildBaseUrl(agentInterface, messageSendParams.tenant()) + "/message:stream", payloadAndHeaders);
ref.set(postBuilder.postAsyncSSE(
Expand Down Expand Up @@ -395,7 +395,7 @@ public void resubscribe(TaskIdParams request, Consumer<StreamingEventKind> event
PayloadAndHeaders payloadAndHeaders = applyInterceptors(SUBSCRIBE_TO_TASK_METHOD, builder,
agentCard, context);
AtomicReference<CompletableFuture<Void>> ref = new AtomicReference<>();
RestSSEEventListener sseEventListener = new RestSSEEventListener(eventConsumer, errorConsumer);
SSEEventListener sseEventListener = new SSEEventListener(eventConsumer, errorConsumer);
try {
String url = Utils.buildBaseUrl(agentInterface, request.tenant()) + String.format("/tasks/%1s:subscribe", request.id());
A2AHttpClient.PostBuilder postBuilder = createPostBuilder(url, payloadAndHeaders);
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package io.a2a.client.transport.rest.sse;

import java.util.concurrent.Future;
import java.util.function.Consumer;
import java.util.logging.Logger;

import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import io.a2a.client.transport.spi.sse.AbstractSSEEventListener;
import io.a2a.client.transport.rest.RestErrorMapper;
import io.a2a.grpc.StreamResponse;
import io.a2a.grpc.utils.ProtoUtils;
import io.a2a.spec.StreamingEventKind;
import org.jspecify.annotations.Nullable;

/**
* REST transport implementation of SSE event listener.
* Handles parsing of JSON-formatted protobuf messages from REST SSE streams.
*/
public class SSEEventListener extends AbstractSSEEventListener {

private static final Logger log = Logger.getLogger(SSEEventListener.class.getName());

public SSEEventListener(Consumer<StreamingEventKind> eventHandler,
@Nullable Consumer<Throwable> errorHandler) {
super(eventHandler, errorHandler);
}

@Override
public void onMessage(String message, @Nullable Future<Void> completableFuture) {
try {
log.fine("Streaming message received: " + message);
io.a2a.grpc.StreamResponse.Builder builder = io.a2a.grpc.StreamResponse.newBuilder();
JsonFormat.parser().merge(message, builder);
parseAndHandleMessage(builder.build(), completableFuture);
} catch (InvalidProtocolBufferException e) {
if (getErrorHandler() != null) {
getErrorHandler().accept(RestErrorMapper.mapRestError(message, 500));
}
}
}

/**
* Parses a StreamResponse protobuf message and delegates to the base class for event handling.
*
* @param response The parsed StreamResponse
* @param future Optional future for controlling the SSE connection
*/
private void parseAndHandleMessage(StreamResponse response, @Nullable Future<Void> future) {
StreamingEventKind event;
switch (response.getPayloadCase()) {
case MESSAGE ->
event = ProtoUtils.FromProto.message(response.getMessage());
case TASK ->
event = ProtoUtils.FromProto.task(response.getTask());
case STATUS_UPDATE ->
event = ProtoUtils.FromProto.taskStatusUpdateEvent(response.getStatusUpdate());
case ARTIFACT_UPDATE ->
event = ProtoUtils.FromProto.taskArtifactUpdateEvent(response.getArtifactUpdate());
default -> {
log.warning("Invalid stream response " + response.getPayloadCase());
if (getErrorHandler() != null) {
getErrorHandler().accept(new IllegalStateException("Invalid stream response from server: " + response.getPayloadCase()));
}
return;
}
}

// Delegate to base class for common event handling and auto-close logic
handleEvent(event, future);
}

}
Loading