Skip to content

Commit

Permalink
Fix bi-di subscription to support dapr-api-token
Browse files Browse the repository at this point in the history
Signed-off-by: Artur Souza <asouza.pro@gmail.com>
  • Loading branch information
artursouza committed Oct 10, 2024
1 parent cb552ba commit 5958266
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 153 deletions.
17 changes: 11 additions & 6 deletions sdk-actors/src/main/java/io/dapr/actors/client/ActorClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public ActorClient(ResiliencyOptions resiliencyOptions) {
* @param overrideProperties Override properties.
*/
public ActorClient(Properties overrideProperties) {
this(buildManagedChannel(overrideProperties), null);
this(buildManagedChannel(overrideProperties), null, overrideProperties.getValue(Properties.API_TOKEN));
}

/**
Expand All @@ -69,7 +69,7 @@ public ActorClient(Properties overrideProperties) {
* @param resiliencyOptions Client resiliency options.
*/
public ActorClient(Properties overrideProperties, ResiliencyOptions resiliencyOptions) {
this(buildManagedChannel(overrideProperties), resiliencyOptions);
this(buildManagedChannel(overrideProperties), resiliencyOptions, overrideProperties.getValue(Properties.API_TOKEN));
}

/**
Expand All @@ -80,9 +80,10 @@ public ActorClient(Properties overrideProperties, ResiliencyOptions resiliencyOp
*/
private ActorClient(
ManagedChannel grpcManagedChannel,
ResiliencyOptions resiliencyOptions) {
ResiliencyOptions resiliencyOptions,
String daprApiToken) {
this.grpcManagedChannel = grpcManagedChannel;
this.daprClient = buildDaprClient(grpcManagedChannel, resiliencyOptions);
this.daprClient = buildDaprClient(grpcManagedChannel, resiliencyOptions, daprApiToken);
}

/**
Expand Down Expand Up @@ -136,7 +137,11 @@ private static ManagedChannel buildManagedChannel(Properties overrideProperties)
*/
private static DaprClient buildDaprClient(
Channel grpcManagedChannel,
ResiliencyOptions resiliencyOptions) {
return new DaprClientImpl(DaprGrpc.newStub(grpcManagedChannel), resiliencyOptions);
ResiliencyOptions resiliencyOptions,
String daprApiToken) {
return new DaprClientImpl(
DaprGrpc.newStub(grpcManagedChannel),
resiliencyOptions,
daprApiToken);
}
}
66 changes: 12 additions & 54 deletions sdk-actors/src/main/java/io/dapr/actors/client/DaprClientImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,6 @@
*/
class DaprClientImpl implements DaprClient {

/**
* Timeout policy for SDK calls to Dapr API.
*/
private final TimeoutPolicy timeoutPolicy;

/**
* Retry policy for SDK calls to Dapr API.
*/
Expand All @@ -57,16 +52,22 @@ class DaprClientImpl implements DaprClient {
*/
private final DaprGrpc.DaprStub client;

/**
* gRPC client interceptors.
*/
private final DaprClientGrpcInterceptors grpcInterceptors;

/**
* Internal constructor.
*
* @param grpcClient Dapr's GRPC client.
* @param resiliencyOptions Client resiliency options (optional)
* @param resiliencyOptions Client resiliency options (optional).
* @param daprApiToken Dapr API token (optional).
*/
DaprClientImpl(DaprGrpc.DaprStub grpcClient, ResiliencyOptions resiliencyOptions) {
this.client = intercept(grpcClient);
this.timeoutPolicy = new TimeoutPolicy(
resiliencyOptions == null ? null : resiliencyOptions.getTimeout());
DaprClientImpl(DaprGrpc.DaprStub grpcClient, ResiliencyOptions resiliencyOptions, String daprApiToken) {
this.client = grpcClient;
this.grpcInterceptors = new DaprClientGrpcInterceptors(daprApiToken,
new TimeoutPolicy(resiliencyOptions == null ? null : resiliencyOptions.getTimeout()));
this.retryPolicy = new RetryPolicy(
resiliencyOptions == null ? null : resiliencyOptions.getMaxRetries());
}
Expand All @@ -85,54 +86,11 @@ public Mono<byte[]> invoke(String actorType, String actorId, String methodName,
.build();
return Mono.deferContextual(
context -> this.<DaprProtos.InvokeActorResponse>createMono(
it -> intercept(context, this.timeoutPolicy, client).invokeActor(req, it)
it -> this.grpcInterceptors.intercept(client, context).invokeActor(req, it)
)
).map(r -> r.getData().toByteArray());
}

/**
* Populates GRPC client with interceptors.
*
* @param client GRPC client for Dapr.
* @return Client after adding interceptors.
*/
private DaprGrpc.DaprStub intercept(DaprGrpc.DaprStub client) {
ClientInterceptor interceptor = new ClientInterceptor() {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> methodDescriptor,
CallOptions options,
Channel channel) {
ClientCall<ReqT, RespT> clientCall = channel.newCall(methodDescriptor, timeoutPolicy.apply(options));
return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(clientCall) {
@Override
public void start(final Listener<RespT> responseListener, final Metadata metadata) {
String daprApiToken = Properties.API_TOKEN.get();
if (daprApiToken != null) {
metadata.put(Metadata.Key.of("dapr-api-token", Metadata.ASCII_STRING_MARSHALLER), daprApiToken);
}

super.start(responseListener, metadata);
}
};
}
};
return client.withInterceptors(interceptor);
}

/**
* Populates GRPC client with interceptors for telemetry.
*
* @param context Reactor's context.
* @param timeoutPolicy Timeout policy for gRPC call.
* @param client GRPC client for Dapr.
* @return Client after adding interceptors.
*/
private static DaprGrpc.DaprStub intercept(
ContextView context, TimeoutPolicy timeoutPolicy, DaprGrpc.DaprStub client) {
return DaprClientGrpcInterceptors.intercept(client, timeoutPolicy, context);
}

private <T> Mono<T> createMono(Consumer<StreamObserver<T>> consumer) {
return retryPolicy.apply(
Mono.create(sink -> DaprException.wrap(() -> consumer.accept(createStreamObserver(sink))).run()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public void setup() throws IOException {
InProcessChannelBuilder.forName(serverName).directExecutor().build());

// Create a HelloWorldClient using the in-process channel;
client = new DaprClientImpl(DaprGrpc.newStub(channel), null);
client = new DaprClientImpl(DaprGrpc.newStub(channel), null, null);
}

@Test
Expand Down
33 changes: 25 additions & 8 deletions sdk-tests/src/test/java/io/dapr/it/DaprRun.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
import org.apache.commons.lang3.tuple.ImmutablePair;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
Expand Down Expand Up @@ -68,19 +71,34 @@ public class DaprRun implements Stoppable {

private final boolean hasAppHealthCheck;

private final Map<Property<?>, String> propertyOverrides;

private DaprRun(String testName,
DaprPorts ports,
String successMessage,
Class serviceClass,
int maxWaitMilliseconds,
AppRun.AppProtocol appProtocol) {
this(testName, ports, successMessage, serviceClass, maxWaitMilliseconds, appProtocol, UUID.randomUUID().toString());
}

private DaprRun(String testName,
DaprPorts ports,
String successMessage,
Class serviceClass,
int maxWaitMilliseconds,
AppRun.AppProtocol appProtocol,
String daprApiToken) {
// The app name needs to be deterministic since we depend on it to kill previous runs.
this.appName = serviceClass == null ?
testName.toLowerCase() :
String.format("%s-%s", testName, serviceClass.getSimpleName()).toLowerCase();
this.appProtocol = appProtocol;
this.startCommand =
new Command(successMessage, buildDaprCommand(this.appName, serviceClass, ports, appProtocol));
new Command(
successMessage,
buildDaprCommand(this.appName, serviceClass, ports, appProtocol),
Map.of("DAPR_API_TOKEN", daprApiToken));
this.listCommand = new Command(
this.appName,
"dapr list");
Expand All @@ -91,6 +109,9 @@ private DaprRun(String testName,
this.maxWaitMilliseconds = maxWaitMilliseconds;
this.started = new AtomicBoolean(false);
this.hasAppHealthCheck = isAppHealthCheckEnabled(serviceClass);
this.propertyOverrides = Collections.unmodifiableMap(new HashMap<>(ports.getPropertyOverrides()) {{
put(Properties.API_TOKEN, daprApiToken);
}});
}

public void start() throws InterruptedException, IOException {
Expand Down Expand Up @@ -149,7 +170,7 @@ public void stop() throws InterruptedException, IOException {
}

public Map<Property<?>, String> getPropertyOverrides() {
return this.ports.getPropertyOverrides();
return this.propertyOverrides;
}

public DaprClientBuilder newDaprClientBuilder() {
Expand Down Expand Up @@ -239,17 +260,13 @@ public String getAppName() {

public DaprClient newDaprClient() {
return new DaprClientBuilder()
.withPropertyOverride(Properties.GRPC_PORT, ports.getGrpcPort().toString())
.withPropertyOverride(Properties.HTTP_PORT, ports.getHttpPort().toString())
.withPropertyOverride(Properties.SIDECAR_IP, "127.0.0.1")
.withPropertyOverrides(this.getPropertyOverrides())
.build();
}

public DaprPreviewClient newDaprPreviewClient() {
return new DaprClientBuilder()
.withPropertyOverride(Properties.GRPC_PORT, ports.getGrpcPort().toString())
.withPropertyOverride(Properties.HTTP_PORT, ports.getHttpPort().toString())
.withPropertyOverride(Properties.SIDECAR_IP, "127.0.0.1")
.withPropertyOverrides(this.getPropertyOverrides())
.buildPreviewClient();
}

Expand Down
3 changes: 2 additions & 1 deletion sdk/src/main/java/io/dapr/client/DaprClientBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ private DaprClientImpl buildDaprClient() {
daprHttp,
this.objectSerializer,
this.stateSerializer,
this.resiliencyOptions);
this.resiliencyOptions,
properties.getValue(Properties.API_TOKEN));
}
}
84 changes: 67 additions & 17 deletions sdk/src/main/java/io/dapr/client/DaprClientImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,6 @@ public class DaprClientImpl extends AbstractDaprClient {
*/
private final GrpcChannelFacade channel;

/**
* The timeout policy.
*/
private final TimeoutPolicy timeoutPolicy;

/**
* The retry policy.
*/
Expand All @@ -141,9 +136,10 @@ public class DaprClientImpl extends AbstractDaprClient {
*/
private final DaprHttp httpClient;

private final DaprClientGrpcInterceptors grpcInterceptors;

/**
* Default access level constructor, in order to create an instance of this
* class use io.dapr.client.DaprClientBuilder
* Default access level constructor, in order to create an instance of this class use io.dapr.client.DaprClientBuilder
*
* @param channel Facade for the managed GRPC channel
* @param asyncStub async gRPC stub
Expand All @@ -157,7 +153,27 @@ public class DaprClientImpl extends AbstractDaprClient {
DaprHttp httpClient,
DaprObjectSerializer objectSerializer,
DaprObjectSerializer stateSerializer) {
this(channel, asyncStub, httpClient, objectSerializer, stateSerializer, null);
this(channel, asyncStub, httpClient, objectSerializer, stateSerializer, null, null);
}

/**
* Default access level constructor, in order to create an instance of this class use io.dapr.client.DaprClientBuilder
*
* @param channel Facade for the managed GRPC channel
* @param asyncStub async gRPC stub
* @param objectSerializer Serializer for transient request/response objects.
* @param stateSerializer Serializer for state objects.
* @param daprApiToken Dapr API Token.
* @see DaprClientBuilder
*/
DaprClientImpl(
GrpcChannelFacade channel,
DaprGrpc.DaprStub asyncStub,
DaprHttp httpClient,
DaprObjectSerializer objectSerializer,
DaprObjectSerializer stateSerializer,
String daprApiToken) {
this(channel, asyncStub, httpClient, objectSerializer, stateSerializer, null, daprApiToken);
}

/**
Expand All @@ -169,6 +185,7 @@ public class DaprClientImpl extends AbstractDaprClient {
* @param objectSerializer Serializer for transient request/response objects.
* @param stateSerializer Serializer for state objects.
* @param resiliencyOptions Client-level override for resiliency options.
* @param daprApiToken Dapr API Token.
* @see DaprClientBuilder
*/
DaprClientImpl(
Expand All @@ -177,15 +194,47 @@ public class DaprClientImpl extends AbstractDaprClient {
DaprHttp httpClient,
DaprObjectSerializer objectSerializer,
DaprObjectSerializer stateSerializer,
ResiliencyOptions resiliencyOptions) {
ResiliencyOptions resiliencyOptions,
String daprApiToken) {
this(
channel,
asyncStub,
httpClient,
objectSerializer,
stateSerializer,
new TimeoutPolicy(resiliencyOptions == null ? null : resiliencyOptions.getTimeout()),
new RetryPolicy(resiliencyOptions == null ? null : resiliencyOptions.getMaxRetries()),
daprApiToken);
}

/**
* Instantiates a new DaprClient.
*
* @param channel Facade for the managed GRPC channel
* @param asyncStub async gRPC stub
* @param httpClient client for http service invocation
* @param objectSerializer Serializer for transient request/response objects.
* @param stateSerializer Serializer for state objects.
* @param timeoutPolicy Client-level timeout policy.
* @param retryPolicy Client-level retry policy.
* @param daprApiToken Dapr API Token.
* @see DaprClientBuilder
*/
private DaprClientImpl(
GrpcChannelFacade channel,
DaprGrpc.DaprStub asyncStub,
DaprHttp httpClient,
DaprObjectSerializer objectSerializer,
DaprObjectSerializer stateSerializer,
TimeoutPolicy timeoutPolicy,
RetryPolicy retryPolicy,
String daprApiToken) {
super(objectSerializer, stateSerializer);
this.channel = channel;
this.asyncStub = asyncStub;
this.httpClient = httpClient;
this.timeoutPolicy = new TimeoutPolicy(
resiliencyOptions == null ? null : resiliencyOptions.getTimeout());
this.retryPolicy = new RetryPolicy(
resiliencyOptions == null ? null : resiliencyOptions.getMaxRetries());
this.retryPolicy = retryPolicy;
this.grpcInterceptors = new DaprClientGrpcInterceptors(daprApiToken, timeoutPolicy);
}

private CommonProtos.StateOptions.StateConsistency getGrpcStateConsistency(StateOptions options) {
Expand Down Expand Up @@ -215,7 +264,7 @@ private CommonProtos.StateOptions.StateConcurrency getGrpcStateConcurrency(State
*/
public <T extends AbstractStub<T>> T newGrpcStub(String appId, Function<Channel, T> stubBuilder) {
// Adds Dapr interceptors to populate gRPC metadata automatically.
return DaprClientGrpcInterceptors.intercept(appId, stubBuilder.apply(this.channel.getGrpcChannel()), timeoutPolicy);
return this.grpcInterceptors.intercept(appId, stubBuilder.apply(this.channel.getGrpcChannel()));
}

/**
Expand Down Expand Up @@ -425,7 +474,8 @@ private <T> Subscription<T> buildSubscription(
SubscriptionListener<T> listener,
TypeRef<T> type,
DaprProtos.SubscribeTopicEventsRequestAlpha1 request) {
Subscription<T> subscription = new Subscription<>(this.asyncStub, request, listener, response -> {
var interceptedStub = this.grpcInterceptors.intercept(this.asyncStub);
Subscription<T> subscription = new Subscription<>(interceptedStub, request, listener, response -> {
if (response.getEventMessage() == null) {
return null;
}
Expand Down Expand Up @@ -1268,7 +1318,7 @@ private ConfigurationItem buildConfigurationItem(
* @return Client after adding interceptors.
*/
private DaprGrpc.DaprStub intercept(ContextView context, DaprGrpc.DaprStub client) {
return DaprClientGrpcInterceptors.intercept(client, this.timeoutPolicy, context);
return this.grpcInterceptors.intercept(client, context);
}

/**
Expand All @@ -1281,7 +1331,7 @@ private DaprGrpc.DaprStub intercept(ContextView context, DaprGrpc.DaprStub clien
*/
private DaprGrpc.DaprStub intercept(
ContextView context, DaprGrpc.DaprStub client, Consumer<Metadata> metadataConsumer) {
return DaprClientGrpcInterceptors.intercept(client, this.timeoutPolicy, context, metadataConsumer);
return this.grpcInterceptors.intercept(client, context, metadataConsumer);
}

private <T> Mono<T> createMono(Consumer<StreamObserver<T>> consumer) {
Expand Down
Loading

0 comments on commit 5958266

Please sign in to comment.