diff --git a/quic/api/QuicTransportBaseLite.cpp b/quic/api/QuicTransportBaseLite.cpp index 6cfae6eb6..6f793d0ec 100644 --- a/quic/api/QuicTransportBaseLite.cpp +++ b/quic/api/QuicTransportBaseLite.cpp @@ -417,69 +417,21 @@ void QuicTransportBaseLite::cancelDeliveryCallbacksForStream( void QuicTransportBaseLite::cancelByteEventCallbacksForStream( const StreamId id, - const Optional& offset) { - invokeForEachByteEventType(([this, id, &offset](const ByteEvent::Type type) { - cancelByteEventCallbacksForStream(type, id, offset); - })); + const Optional& offsetUpperBound) { + invokeForEachByteEventType( + ([this, id, &offsetUpperBound](const ByteEvent::Type type) { + cancelByteEventCallbacksForStream(type, id, offsetUpperBound); + })); } void QuicTransportBaseLite::cancelByteEventCallbacksForStream( const ByteEvent::Type type, const StreamId id, - const Optional& offset) { - if (isReceivingStream(conn_->nodeType, id)) { - return; - } - - auto& byteEventMap = getByteEventMap(type); - auto byteEventMapIt = byteEventMap.find(id); - if (byteEventMapIt == byteEventMap.end()) { - switch (type) { - case ByteEvent::Type::ACK: - conn_->streamManager->removeDeliverable(id); - break; - case ByteEvent::Type::TX: - conn_->streamManager->removeTx(id); - break; - } - return; - } - auto& streamByteEvents = byteEventMapIt->second; - - // Callbacks are kept sorted by offset, so we can just walk the queue and - // invoke those with offset below provided offset. - while (!streamByteEvents.empty()) { - // decomposition not supported for xplat - const auto cbOffset = streamByteEvents.front().offset; - const auto callback = streamByteEvents.front().callback; - if (!offset.has_value() || cbOffset < *offset) { - streamByteEvents.pop_front(); - ByteEventCancellation cancellation{id, cbOffset, type}; - callback->onByteEventCanceled(cancellation); - if (closeState_ != CloseState::OPEN) { - // socket got closed - we can't use streamByteEvents anymore, - // closeImpl should take care of cleaning up any remaining callbacks - return; - } - } else { - // Only larger or equal offsets left, exit the loop. - break; - } - } - - // Clean up state for this stream if no callbacks left to invoke. - if (streamByteEvents.empty()) { - switch (type) { - case ByteEvent::Type::ACK: - conn_->streamManager->removeDeliverable(id); - break; - case ByteEvent::Type::TX: - conn_->streamManager->removeTx(id); - break; - } - // The callback could have changed the map so erase by id. - byteEventMap.erase(id); - } + const Optional& offsetUpperBound) { + cancelByteEventCallbacksForStreamInternal( + type, id, [&offsetUpperBound](uint64_t cbOffset) { + return !offsetUpperBound || cbOffset < *offsetUpperBound; + }); } folly::Expected @@ -1682,6 +1634,65 @@ QuicTransportBaseLite::resetStreamInternal( return folly::unit; } +void QuicTransportBaseLite::cancelByteEventCallbacksForStreamInternal( + const ByteEvent::Type type, + const StreamId id, + const std::function& offsetFilter) { + if (isReceivingStream(conn_->nodeType, id)) { + return; + } + + auto& byteEventMap = getByteEventMap(type); + auto byteEventMapIt = byteEventMap.find(id); + if (byteEventMapIt == byteEventMap.end()) { + switch (type) { + case ByteEvent::Type::ACK: + conn_->streamManager->removeDeliverable(id); + break; + case ByteEvent::Type::TX: + conn_->streamManager->removeTx(id); + break; + } + return; + } + auto& streamByteEvents = byteEventMapIt->second; + + // Callbacks are kept sorted by offset, so we can just walk the queue and + // invoke those with offset below provided offset. + while (!streamByteEvents.empty()) { + // decomposition not supported for xplat + const auto cbOffset = streamByteEvents.front().offset; + const auto callback = streamByteEvents.front().callback; + if (offsetFilter(cbOffset)) { + streamByteEvents.pop_front(); + ByteEventCancellation cancellation{id, cbOffset, type}; + callback->onByteEventCanceled(cancellation); + if (closeState_ != CloseState::OPEN) { + // socket got closed - we can't use streamByteEvents anymore, + // closeImpl should take care of cleaning up any remaining callbacks + return; + } + } else { + // Only larger or equal offsets left, exit the loop. + break; + } + } + + // Clean up state for this stream if no callbacks left to invoke. + if (streamByteEvents.empty()) { + switch (type) { + case ByteEvent::Type::ACK: + conn_->streamManager->removeDeliverable(id); + break; + case ByteEvent::Type::TX: + conn_->streamManager->removeTx(id); + break; + } + // The callback could have changed the map so erase by id. + byteEventMap.erase(id); + } +} + void QuicTransportBaseLite::onSocketWritable() noexcept { // Remove the writable callback. socket_->pauseWrite(); diff --git a/quic/api/QuicTransportBaseLite.h b/quic/api/QuicTransportBaseLite.h index 77aac4a4d..1cc642357 100644 --- a/quic/api/QuicTransportBaseLite.h +++ b/quic/api/QuicTransportBaseLite.h @@ -106,22 +106,22 @@ class QuicTransportBaseLite : virtual public QuicSocketLite, * Cancel byte event callbacks for given stream. * * If an offset is provided, cancels only callbacks with an offset less than - * or equal to the provided offset, otherwise cancels all callbacks. + * the provided offset, otherwise cancels all callbacks. */ void cancelByteEventCallbacksForStream( const StreamId id, - const Optional& offset = none) override; + const Optional& offsetUpperBound = none) override; /** * Cancel byte event callbacks for given type and stream. * * If an offset is provided, cancels only callbacks with an offset less than - * or equal to the provided offset, otherwise cancels all callbacks. + * the provided offset, otherwise cancels all callbacks. */ void cancelByteEventCallbacksForStream( const ByteEvent::Type type, const StreamId id, - const Optional& offset = none) override; + const Optional& offsetUpperBound = none) override; /** * Register a byte event to be triggered when specified event type occurs for @@ -618,6 +618,12 @@ class QuicTransportBaseLite : virtual public QuicSocketLite, StreamId id, ApplicationErrorCode errorCode); + // Only remove byte event callbacks if offsetFilter returns true. + void cancelByteEventCallbacksForStreamInternal( + const ByteEvent::Type type, + const StreamId id, + const std::function& offsetFilter); + void onSocketWritable() noexcept override; void handleNewStreamCallbacks(std::vector& newPeerStreams);