Skip to content

Commit

Permalink
pw_rpc: Do not hold the Endpoint lock while executing user code
Browse files Browse the repository at this point in the history
Push call object updates to a queue (to maintain order) and then execute
them after releasing the Endpoint's intrinsic lock.

Change-Id: If4836baccccc39457c946372ee57c2a0baf7780c
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/129773
Reviewed-by: Alexei Frolov <frolv@google.com>
Commit-Queue: Auto-Submit <auto-submit@pigweed.google.com.iam.gserviceaccount.com>
Pigweed-Auto-Submit: Wyatt Hepler <hepler@google.com>
  • Loading branch information
255 authored and CQ Bot Account committed Feb 24, 2023
1 parent 2819376 commit 1f85b26
Showing 1 changed file with 80 additions and 37 deletions.
117 changes: 80 additions & 37 deletions pw_rpc/java/main/dev/pigweed/pw_rpc/Endpoint.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
Expand All @@ -30,12 +32,19 @@
*
* The RPC endpoint handles all RPC-related events and actions. It synchronizes interactions between
* the endpoint and any threads interacting with RPC call objects.
*
* The Endpoint's intrinsic lock is held when updating the channels or pending calls lists. Call
* objects only make updates to their own state through function calls made from the Endpoint, which
* ensures their states are also guarded by the Endpoint's lock. Updates to call objects are
* enqueued while the lock is held and processed after releasing the lock. This ensures updates
* occur in order without needing to hold the Endpoint's lock while possibly executing user code.
*/
class Endpoint {
private static final Logger logger = Logger.forClass(Endpoint.class);

private final Map<Integer, Channel> channels;
private final Map<PendingRpc, AbstractCall<?, ?>> pending = new HashMap<>();
private final BlockingQueue<Runnable> callUpdates = new LinkedBlockingQueue<>();

public Endpoint(List<Channel> channels) {
this.channels = channels.stream().collect(Collectors.toMap(Channel::id, c -> c));
Expand Down Expand Up @@ -102,24 +111,52 @@ private void registerCall(AbstractCall<?, ?> call) {
pending.put(call.rpc(), call);
}

/** Enqueues call object updates to make after release the Endpoint's lock. */
private void enqueueCallUpdate(Runnable callUpdate) {
while (!callUpdates.add(callUpdate)) {
// Retry until added successfully
}
}

/** Processes all enqueued call updates; the lock must NOT be held when this is called. */
private void processCallUpdates() {
while (true) {
Runnable callUpdate = callUpdates.poll();
if (callUpdate == null) {
break;
}
callUpdate.run();
}
}

/** Cancels an ongoing RPC */
public synchronized boolean cancel(AbstractCall<?, ?> call) throws ChannelOutputException {
if (pending.remove(call.rpc()) == null) {
return false;
public boolean cancel(AbstractCall<?, ?> call) throws ChannelOutputException {
try {
synchronized (this) {
if (pending.remove(call.rpc()) == null) {
return false;
}

enqueueCallUpdate(() -> call.handleError(Status.CANCELLED));
call.sendPacket(Packets.cancel(call.rpc()));
}
} finally {
logger.atFiner().log("Cancelling %s", call);
processCallUpdates();
}
logger.atFiner().log("Cancelling %s", call);
call.handleError(Status.CANCELLED);
call.sendPacket(Packets.cancel(call.rpc()));
return true;
}

/** Cancels an ongoing RPC without sending a cancellation packet. */
public synchronized boolean abandon(AbstractCall<?, ?> call) {
if (pending.remove(call.rpc()) == null) {
return false;
public boolean abandon(AbstractCall<?, ?> call) {
synchronized (this) {
if (pending.remove(call.rpc()) == null) {
return false;
}
enqueueCallUpdate(() -> call.handleError(Status.CANCELLED));
}
logger.atFiner().log("Abandoning %s", call);
call.handleError(Status.CANCELLED);
processCallUpdates();
return true;
}

Expand Down Expand Up @@ -148,14 +185,16 @@ public synchronized void openChannel(Channel channel) {
}
}

public synchronized boolean closeChannel(int id) {
if (channels.remove(id) == null) {
return false;
public boolean closeChannel(int id) {
synchronized (this) {
if (channels.remove(id) == null) {
return false;
}
pending.values().stream().filter(call -> call.getChannelId() == id).forEach(call -> {
enqueueCallUpdate(() -> call.handleError(Status.ABORTED));
});
}
pending.values()
.stream()
.filter(call -> call.getChannelId() == id)
.forEach(call -> call.handleError(Status.ABORTED));
processCallUpdates();
return true;
}

Expand All @@ -164,8 +203,8 @@ private boolean handleNext(PendingRpc rpc, ByteString payload) {
if (call == null) {
return false;
}
call.handleNext(payload);
logger.atFiner().log("%s received server stream with %d B payload", call, payload.size());
enqueueCallUpdate(() -> call.handleNext(payload));
return true;
}

Expand All @@ -174,9 +213,9 @@ private boolean handleUnaryCompleted(PendingRpc rpc, ByteString payload, Status
if (call == null) {
return false;
}
call.handleUnaryCompleted(payload, status);
logger.atFiner().log(
"%s completed with status %s and %d B payload", call, status, payload.size());
enqueueCallUpdate(() -> call.handleUnaryCompleted(payload, status));
return true;
}

Expand All @@ -185,8 +224,8 @@ private boolean handleStreamCompleted(PendingRpc rpc, Status status) {
if (call == null) {
return false;
}
call.handleStreamCompleted(status);
logger.atFiner().log("%s completed with status %s", call, status);
enqueueCallUpdate(() -> call.handleStreamCompleted(status));
return true;
}

Expand All @@ -195,31 +234,35 @@ private boolean handleError(PendingRpc rpc, Status status) {
if (call == null) {
return false;
}
call.handleError(status);
logger.atFiner().log("%s failed with error %s", call, status);
enqueueCallUpdate(() -> call.handleError(status));
return true;
}

public synchronized boolean processClientPacket(@Nullable Method method, RpcPacket packet) {
Channel channel = channels.get(packet.getChannelId());
if (channel == null) {
logger.atWarning().log("Received packet for unrecognized channel %d", packet.getChannelId());
return false;
}
public boolean processClientPacket(@Nullable Method method, RpcPacket packet) {
synchronized (this) {
Channel channel = channels.get(packet.getChannelId());
if (channel == null) {
logger.atWarning().log(
"Received packet for unrecognized channel %d", packet.getChannelId());
return false;
}

if (method == null) {
logger.atFine().log("Ignoring packet for unknown service method");
sendError(channel, packet, Status.NOT_FOUND);
return true; // true since the packet was handled, even though it was invalid.
}
if (method == null) {
logger.atFine().log("Ignoring packet for unknown service method");
sendError(channel, packet, Status.NOT_FOUND);
return true; // true since the packet was handled, even though it was invalid.
}

PendingRpc rpc = PendingRpc.create(channel, method);
if (!updateCall(packet, rpc)) {
logger.atFine().log("Ignoring packet for %s, which isn't pending", rpc);
sendError(channel, packet, Status.FAILED_PRECONDITION);
return true;
PendingRpc rpc = PendingRpc.create(channel, method);
if (!updateCall(packet, rpc)) {
logger.atFine().log("Ignoring packet for %s, which isn't pending", rpc);
sendError(channel, packet, Status.FAILED_PRECONDITION);
return true;
}
}

processCallUpdates();
return true;
}

Expand Down

0 comments on commit 1f85b26

Please sign in to comment.