Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@
import static io.a2a.util.Assert.checkNotNullParam;

import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import io.a2a.client.transport.spi.interceptors.ClientCallContext;
import io.a2a.client.transport.spi.ClientTransport;
import io.a2a.client.transport.spi.interceptors.ClientCallInterceptor;
import io.a2a.client.transport.spi.interceptors.PayloadAndHeaders;
import io.a2a.client.transport.spi.interceptors.auth.AuthInterceptor;
import io.a2a.common.A2AHeaders;
import io.a2a.grpc.A2AServiceGrpc;
import io.a2a.grpc.CancelTaskRequest;
Expand All @@ -32,11 +36,14 @@
import io.a2a.spec.GetTaskPushNotificationConfigParams;
import io.a2a.spec.ListTaskPushNotificationConfigParams;
import io.a2a.spec.MessageSendParams;
import io.a2a.spec.SendStreamingMessageRequest;
import io.a2a.spec.SetTaskPushNotificationConfigRequest;
import io.a2a.spec.StreamingEventKind;
import io.a2a.spec.Task;
import io.a2a.spec.TaskIdParams;
import io.a2a.spec.TaskPushNotificationConfig;
import io.a2a.spec.TaskQueryParams;
import io.a2a.spec.TaskResubscriptionRequest;
import io.grpc.Channel;
import io.grpc.Metadata;
import io.grpc.StatusRuntimeException;
Expand All @@ -45,25 +52,39 @@

public class GrpcTransport implements ClientTransport {

private static final Metadata.Key<String> AUTHORIZATION_METADATA_KEY = Metadata.Key.of(
AuthInterceptor.AUTHORIZATION,
Metadata.ASCII_STRING_MARSHALLER);
private static final Metadata.Key<String> EXTENSIONS_KEY = Metadata.Key.of(
A2AHeaders.X_A2A_EXTENSIONS,
Metadata.ASCII_STRING_MARSHALLER);
private final A2AServiceBlockingV2Stub blockingStub;
private final A2AServiceStub asyncStub;
private final List<ClientCallInterceptor> interceptors;
private AgentCard agentCard;

public GrpcTransport(Channel channel, AgentCard agentCard) {
this(channel, agentCard, null);
}

public GrpcTransport(Channel channel, AgentCard agentCard, List<ClientCallInterceptor> interceptors) {
checkNotNullParam("channel", channel);
this.asyncStub = A2AServiceGrpc.newStub(channel);
this.blockingStub = A2AServiceGrpc.newBlockingV2Stub(channel);
this.agentCard = agentCard;
this.interceptors = interceptors;
}

@Override
public EventKind sendMessage(MessageSendParams request, ClientCallContext context) throws A2AClientException {
checkNotNullParam("request", request);

SendMessageRequest sendMessageRequest = createGrpcSendMessageRequest(request, context);
PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.SendMessageRequest.METHOD, sendMessageRequest,
agentCard, context);

try {
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders);
SendMessageResponse response = stubWithMetadata.sendMessage(sendMessageRequest);
if (response.hasMsg()) {
return FromProto.message(response.getMsg());
Expand All @@ -83,10 +104,12 @@ public void sendMessageStreaming(MessageSendParams request, Consumer<StreamingEv
checkNotNullParam("request", request);
checkNotNullParam("eventConsumer", eventConsumer);
SendMessageRequest grpcRequest = createGrpcSendMessageRequest(request, context);
PayloadAndHeaders payloadAndHeaders = applyInterceptors(SendStreamingMessageRequest.METHOD,
grpcRequest, agentCard, context);
StreamObserver<StreamResponse> streamObserver = new EventStreamObserver(eventConsumer, errorConsumer);

try {
A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context);
A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context, payloadAndHeaders);
stubWithMetadata.sendStreamingMessage(grpcRequest, streamObserver);
} catch (StatusRuntimeException e) {
throw GrpcErrorMapper.mapGrpcError(e, "Failed to send streaming message request: ");
Expand All @@ -103,9 +126,11 @@ public Task getTask(TaskQueryParams request, ClientCallContext context) throws A
requestBuilder.setHistoryLength(request.historyLength());
}
GetTaskRequest getTaskRequest = requestBuilder.build();
PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskRequest.METHOD, getTaskRequest,
agentCard, context);

try {
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders);
return FromProto.task(stubWithMetadata.getTask(getTaskRequest));
} catch (StatusRuntimeException e) {
throw GrpcErrorMapper.mapGrpcError(e, "Failed to get task: ");
Expand All @@ -119,9 +144,11 @@ public Task cancelTask(TaskIdParams request, ClientCallContext context) throws A
CancelTaskRequest cancelTaskRequest = CancelTaskRequest.newBuilder()
.setName("tasks/" + request.id())
.build();
PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.CancelTaskRequest.METHOD, cancelTaskRequest,
agentCard, context);

try {
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders);
return FromProto.task(stubWithMetadata.cancelTask(cancelTaskRequest));
} catch (StatusRuntimeException e) {
throw GrpcErrorMapper.mapGrpcError(e, "Failed to cancel task: ");
Expand All @@ -139,9 +166,11 @@ public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushN
.setConfig(ToProto.taskPushNotificationConfig(request))
.setConfigId(configId != null ? configId : request.taskId())
.build();
PayloadAndHeaders payloadAndHeaders = applyInterceptors(SetTaskPushNotificationConfigRequest.METHOD,
grpcRequest, agentCard, context);

try {
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders);
return FromProto.taskPushNotificationConfig(stubWithMetadata.createTaskPushNotificationConfig(grpcRequest));
} catch (StatusRuntimeException e) {
throw GrpcErrorMapper.mapGrpcError(e, "Failed to create task push notification config: ");
Expand All @@ -157,9 +186,11 @@ public TaskPushNotificationConfig getTaskPushNotificationConfiguration(
GetTaskPushNotificationConfigRequest grpcRequest = GetTaskPushNotificationConfigRequest.newBuilder()
.setName(getTaskPushNotificationConfigName(request))
.build();
PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskPushNotificationConfigRequest.METHOD,
grpcRequest, agentCard, context);

try {
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders);
return FromProto.taskPushNotificationConfig(stubWithMetadata.getTaskPushNotificationConfig(grpcRequest));
} catch (StatusRuntimeException e) {
throw GrpcErrorMapper.mapGrpcError(e, "Failed to get task push notification config: ");
Expand All @@ -175,9 +206,11 @@ public List<TaskPushNotificationConfig> listTaskPushNotificationConfigurations(
ListTaskPushNotificationConfigRequest grpcRequest = ListTaskPushNotificationConfigRequest.newBuilder()
.setParent("tasks/" + request.id())
.build();
PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.ListTaskPushNotificationConfigRequest.METHOD,
grpcRequest, agentCard, context);

try {
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders);
return stubWithMetadata.listTaskPushNotificationConfig(grpcRequest).getConfigsList().stream()
.map(FromProto::taskPushNotificationConfig)
.collect(Collectors.toList());
Expand All @@ -194,9 +227,11 @@ public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationC
DeleteTaskPushNotificationConfigRequest grpcRequest = DeleteTaskPushNotificationConfigRequest.newBuilder()
.setName(getTaskPushNotificationConfigName(request.id(), request.pushNotificationConfigId()))
.build();
PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.DeleteTaskPushNotificationConfigRequest.METHOD,
grpcRequest, agentCard, context);

try {
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders);
stubWithMetadata.deleteTaskPushNotificationConfig(grpcRequest);
} catch (StatusRuntimeException e) {
throw GrpcErrorMapper.mapGrpcError(e, "Failed to delete task push notification config: ");
Expand All @@ -212,11 +247,13 @@ public void resubscribe(TaskIdParams request, Consumer<StreamingEventKind> event
TaskSubscriptionRequest grpcRequest = TaskSubscriptionRequest.newBuilder()
.setName("tasks/" + request.id())
.build();
PayloadAndHeaders payloadAndHeaders = applyInterceptors(TaskResubscriptionRequest.METHOD,
grpcRequest, agentCard, context);

StreamObserver<StreamResponse> streamObserver = new EventStreamObserver(eventConsumer, errorConsumer);

try {
A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context);
A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context, payloadAndHeaders);
stubWithMetadata.taskSubscription(grpcRequest, streamObserver);
} catch (StatusRuntimeException e) {
throw GrpcErrorMapper.mapGrpcError(e, "Failed to resubscribe task push notification config: ");
Expand Down Expand Up @@ -249,43 +286,64 @@ private SendMessageRequest createGrpcSendMessageRequest(MessageSendParams messag
* Creates gRPC metadata from ClientCallContext headers.
* Extracts headers like X-A2A-Extensions and sets them as gRPC metadata.
*/
private Metadata createGrpcMetadata(ClientCallContext context) {
private Metadata createGrpcMetadata(ClientCallContext context, PayloadAndHeaders payloadAndHeaders) {
Metadata metadata = new Metadata();

if (context != null && context.getHeaders() != null) {
// Set X-A2A-Extensions header if present
String extensionsHeader = context.getHeaders().get(A2AHeaders.X_A2A_EXTENSIONS);
if (extensionsHeader != null) {
Metadata.Key<String> extensionsKey = Metadata.Key.of(A2AHeaders.X_A2A_EXTENSIONS, Metadata.ASCII_STRING_MARSHALLER);
metadata.put(extensionsKey, extensionsHeader);
metadata.put(EXTENSIONS_KEY, extensionsHeader);
}

// Add other headers as needed in the future
// For now, we only handle X-A2A-Extensions
}
if (payloadAndHeaders != null && payloadAndHeaders.getHeaders() != null) {
// Handle all headers from interceptors (including auth headers)
for (Map.Entry<String, String> headerEntry : payloadAndHeaders.getHeaders().entrySet()) {
String headerName = headerEntry.getKey();
String headerValue = headerEntry.getValue();

if (headerValue != null) {
// Use static key for common Authorization header, create dynamic keys for others
if (AuthInterceptor.AUTHORIZATION.equals(headerName)) {
metadata.put(AUTHORIZATION_METADATA_KEY, headerValue);
} else {
// Create a metadata key dynamically for API keys and other custom headers
Metadata.Key<String> metadataKey = Metadata.Key.of(headerName, Metadata.ASCII_STRING_MARSHALLER);
metadata.put(metadataKey, headerValue);
}
}
}
}

return metadata;
}

/**
* Creates a blocking stub with metadata attached from the ClientCallContext.
*
* @param context the client call context
*
* @param context the client call context
* @param payloadAndHeaders the payloadAndHeaders after applying any interceptors
* @return blocking stub with metadata interceptor
*/
private A2AServiceBlockingV2Stub createBlockingStubWithMetadata(ClientCallContext context) {
Metadata metadata = createGrpcMetadata(context);
private A2AServiceBlockingV2Stub createBlockingStubWithMetadata(ClientCallContext context,
PayloadAndHeaders payloadAndHeaders) {
Metadata metadata = createGrpcMetadata(context, payloadAndHeaders);
return blockingStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
}

/**
* Creates an async stub with metadata attached from the ClientCallContext.
*
* @param context the client call context
*
* @param context the client call context
* @param payloadAndHeaders the payloadAndHeaders after applying any interceptors
* @return async stub with metadata interceptor
*/
private A2AServiceStub createAsyncStubWithMetadata(ClientCallContext context) {
Metadata metadata = createGrpcMetadata(context);
private A2AServiceStub createAsyncStubWithMetadata(ClientCallContext context,
PayloadAndHeaders payloadAndHeaders) {
Metadata metadata = createGrpcMetadata(context, payloadAndHeaders);
return asyncStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
}

Expand All @@ -307,4 +365,17 @@ private String getTaskPushNotificationConfigName(String taskId, String pushNotif
return name.toString();
}

private PayloadAndHeaders applyInterceptors(String methodName, Object payload,
AgentCard agentCard, ClientCallContext clientCallContext) {
PayloadAndHeaders payloadAndHeaders = new PayloadAndHeaders(payload,
clientCallContext != null ? clientCallContext.getHeaders() : null);
if (interceptors != null && ! interceptors.isEmpty()) {
for (ClientCallInterceptor interceptor : interceptors) {
payloadAndHeaders = interceptor.intercept(methodName, payloadAndHeaders.getPayload(),
payloadAndHeaders.getHeaders(), agentCard, clientCallContext);
}
}
return payloadAndHeaders;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ public GrpcTransportConfigBuilder channelFactory(Function<String, Channel> chann

@Override
public GrpcTransportConfig build() {
return new GrpcTransportConfig(channelFactory);
GrpcTransportConfig config = new GrpcTransportConfig(channelFactory);
config.setInterceptors(interceptors);
return config;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public GrpcTransport create(GrpcTransportConfig grpcTransportConfig, AgentCard a

Channel channel = grpcTransportConfig.getChannelFactory().apply(agentUrl);
if (channel != null) {
return new GrpcTransport(channel, agentCard);
return new GrpcTransport(channel, agentCard, grpcTransportConfig.getInterceptors());
}

throw new A2AClientException("Missing required GrpcTransportConfig");
Expand Down
10 changes: 10 additions & 0 deletions client/transport/spi/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@
<groupId>io.github.a2asdk</groupId>
<artifactId>a2a-java-sdk-spec</artifactId>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

</project>
Loading