diff --git a/pw_rpc/java/main/dev/pigweed/pw_rpc/Endpoint.java b/pw_rpc/java/main/dev/pigweed/pw_rpc/Endpoint.java index 810f1fc90e..974a13204e 100644 --- a/pw_rpc/java/main/dev/pigweed/pw_rpc/Endpoint.java +++ b/pw_rpc/java/main/dev/pigweed/pw_rpc/Endpoint.java @@ -21,6 +21,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; import java.util.function.BiFunction; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -30,12 +32,19 @@ * * The RPC endpoint handles all RPC-related events and actions. It synchronizes interactions between * the endpoint and any threads interacting with RPC call objects. + * + * The Endpoint's intrinsic lock is held when updating the channels or pending calls lists. Call + * objects only make updates to their own state through function calls made from the Endpoint, which + * ensures their states are also guarded by the Endpoint's lock. Updates to call objects are + * enqueued while the lock is held and processed after releasing the lock. This ensures updates + * occur in order without needing to hold the Endpoint's lock while possibly executing user code. */ class Endpoint { private static final Logger logger = Logger.forClass(Endpoint.class); private final Map channels; private final Map> pending = new HashMap<>(); + private final BlockingQueue callUpdates = new LinkedBlockingQueue<>(); public Endpoint(List channels) { this.channels = channels.stream().collect(Collectors.toMap(Channel::id, c -> c)); @@ -102,24 +111,52 @@ private void registerCall(AbstractCall call) { pending.put(call.rpc(), call); } + /** Enqueues call object updates to make after release the Endpoint's lock. */ + private void enqueueCallUpdate(Runnable callUpdate) { + while (!callUpdates.add(callUpdate)) { + // Retry until added successfully + } + } + + /** Processes all enqueued call updates; the lock must NOT be held when this is called. */ + private void processCallUpdates() { + while (true) { + Runnable callUpdate = callUpdates.poll(); + if (callUpdate == null) { + break; + } + callUpdate.run(); + } + } + /** Cancels an ongoing RPC */ - public synchronized boolean cancel(AbstractCall call) throws ChannelOutputException { - if (pending.remove(call.rpc()) == null) { - return false; + public boolean cancel(AbstractCall call) throws ChannelOutputException { + try { + synchronized (this) { + if (pending.remove(call.rpc()) == null) { + return false; + } + + enqueueCallUpdate(() -> call.handleError(Status.CANCELLED)); + call.sendPacket(Packets.cancel(call.rpc())); + } + } finally { + logger.atFiner().log("Cancelling %s", call); + processCallUpdates(); } - logger.atFiner().log("Cancelling %s", call); - call.handleError(Status.CANCELLED); - call.sendPacket(Packets.cancel(call.rpc())); return true; } /** Cancels an ongoing RPC without sending a cancellation packet. */ - public synchronized boolean abandon(AbstractCall call) { - if (pending.remove(call.rpc()) == null) { - return false; + public boolean abandon(AbstractCall call) { + synchronized (this) { + if (pending.remove(call.rpc()) == null) { + return false; + } + enqueueCallUpdate(() -> call.handleError(Status.CANCELLED)); } logger.atFiner().log("Abandoning %s", call); - call.handleError(Status.CANCELLED); + processCallUpdates(); return true; } @@ -148,14 +185,16 @@ public synchronized void openChannel(Channel channel) { } } - public synchronized boolean closeChannel(int id) { - if (channels.remove(id) == null) { - return false; + public boolean closeChannel(int id) { + synchronized (this) { + if (channels.remove(id) == null) { + return false; + } + pending.values().stream().filter(call -> call.getChannelId() == id).forEach(call -> { + enqueueCallUpdate(() -> call.handleError(Status.ABORTED)); + }); } - pending.values() - .stream() - .filter(call -> call.getChannelId() == id) - .forEach(call -> call.handleError(Status.ABORTED)); + processCallUpdates(); return true; } @@ -164,8 +203,8 @@ private boolean handleNext(PendingRpc rpc, ByteString payload) { if (call == null) { return false; } - call.handleNext(payload); logger.atFiner().log("%s received server stream with %d B payload", call, payload.size()); + enqueueCallUpdate(() -> call.handleNext(payload)); return true; } @@ -174,9 +213,9 @@ private boolean handleUnaryCompleted(PendingRpc rpc, ByteString payload, Status if (call == null) { return false; } - call.handleUnaryCompleted(payload, status); logger.atFiner().log( "%s completed with status %s and %d B payload", call, status, payload.size()); + enqueueCallUpdate(() -> call.handleUnaryCompleted(payload, status)); return true; } @@ -185,8 +224,8 @@ private boolean handleStreamCompleted(PendingRpc rpc, Status status) { if (call == null) { return false; } - call.handleStreamCompleted(status); logger.atFiner().log("%s completed with status %s", call, status); + enqueueCallUpdate(() -> call.handleStreamCompleted(status)); return true; } @@ -195,31 +234,35 @@ private boolean handleError(PendingRpc rpc, Status status) { if (call == null) { return false; } - call.handleError(status); logger.atFiner().log("%s failed with error %s", call, status); + enqueueCallUpdate(() -> call.handleError(status)); return true; } - public synchronized boolean processClientPacket(@Nullable Method method, RpcPacket packet) { - Channel channel = channels.get(packet.getChannelId()); - if (channel == null) { - logger.atWarning().log("Received packet for unrecognized channel %d", packet.getChannelId()); - return false; - } + public boolean processClientPacket(@Nullable Method method, RpcPacket packet) { + synchronized (this) { + Channel channel = channels.get(packet.getChannelId()); + if (channel == null) { + logger.atWarning().log( + "Received packet for unrecognized channel %d", packet.getChannelId()); + return false; + } - if (method == null) { - logger.atFine().log("Ignoring packet for unknown service method"); - sendError(channel, packet, Status.NOT_FOUND); - return true; // true since the packet was handled, even though it was invalid. - } + if (method == null) { + logger.atFine().log("Ignoring packet for unknown service method"); + sendError(channel, packet, Status.NOT_FOUND); + return true; // true since the packet was handled, even though it was invalid. + } - PendingRpc rpc = PendingRpc.create(channel, method); - if (!updateCall(packet, rpc)) { - logger.atFine().log("Ignoring packet for %s, which isn't pending", rpc); - sendError(channel, packet, Status.FAILED_PRECONDITION); - return true; + PendingRpc rpc = PendingRpc.create(channel, method); + if (!updateCall(packet, rpc)) { + logger.atFine().log("Ignoring packet for %s, which isn't pending", rpc); + sendError(channel, packet, Status.FAILED_PRECONDITION); + return true; + } } + processCallUpdates(); return true; }