diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp index d786877c51d0cc..5a183272d719df 100644 --- a/src/messaging/ExchangeContext.cpp +++ b/src/messaging/ExchangeContext.cpp @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -79,8 +80,13 @@ void ExchangeContext::SetResponseTimeout(Timeout timeout) CHIP_ERROR ExchangeContext::SendMessage(Protocols::Id protocolId, uint8_t msgType, PacketBufferHandle && msgBuf, const SendFlags & sendFlags) { - // If we were waiting for a message send, this is it. - mFlags.Clear(Flags::kFlagWillSendMessage); + if (protocolId != Protocols::SecureChannel::Id || msgType != to_underlying(Protocols::SecureChannel::MsgType::StandaloneAck)) + { + // If we were waiting for a message send, this is it. Standalone acks + // are not application-level sends, which is why we don't allow those to + // clear the WillSendMessage flag. + mFlags.Clear(Flags::kFlagWillSendMessage); + } CHIP_ERROR err = CHIP_NO_ERROR; Transport::PeerConnectionState * state = nullptr; diff --git a/src/protocols/Protocols.h b/src/protocols/Protocols.h index deeec5b4e6aae4..1cc93f6a7d1b79 100644 --- a/src/protocols/Protocols.h +++ b/src/protocols/Protocols.h @@ -42,6 +42,8 @@ class Id return mVendorId == aOther.mVendorId && mProtocolId == aOther.mProtocolId; } + constexpr bool operator!=(const Id & aOther) const { return !(*this == aOther); } + // Convert the Protocols::Id to a TLV profile id. // NOTE: We may want to change the TLV reader/writer to take Protocols::Id // directly later on and get rid of this method.