From 12016328929b29d64d42bb2aada9cbb9c4721350 Mon Sep 17 00:00:00 2001 From: Jerry Johns Date: Mon, 25 Jul 2022 11:42:12 -0700 Subject: [PATCH] Move CommandSender/Handler and WriteClient/Handler over to `ExchangeHolder` (#21081) * Move CommandSender/Handler and WriteClient/Handler over to using the new but safer, ExchangeHolder way of EC management. * Update src/app/CommandHandler.h --- src/app/CommandHandler.cpp | 77 +++++------------------- src/app/CommandHandler.h | 46 +++++++++----- src/app/CommandSender.cpp | 66 ++++++-------------- src/app/CommandSender.h | 21 +------ src/app/WriteClient.cpp | 67 ++++++--------------- src/app/WriteClient.h | 21 +++---- src/app/WriteHandler.cpp | 57 ++++++------------ src/app/WriteHandler.h | 8 ++- src/app/tests/TestCommandInteraction.cpp | 18 +++--- src/app/tests/TestWriteInteraction.cpp | 9 --- src/messaging/ExchangeHolder.h | 2 + 11 files changed, 122 insertions(+), 270 deletions(-) diff --git a/src/app/CommandHandler.cpp b/src/app/CommandHandler.cpp index 2fc2dc55c4d850..d17c4f36cd33e1 100644 --- a/src/app/CommandHandler.cpp +++ b/src/app/CommandHandler.cpp @@ -40,7 +40,7 @@ namespace chip { namespace app { -CommandHandler::CommandHandler(Callback * apCallback) : mpCallback(apCallback), mSuppressResponse(false) {} +CommandHandler::CommandHandler(Callback * apCallback) : mExchangeCtx(*this), mpCallback(apCallback), mSuppressResponse(false) {} CHIP_ERROR CommandHandler::AllocateBuffer() { @@ -73,7 +73,7 @@ CHIP_ERROR CommandHandler::OnInvokeCommandRequest(Messaging::ExchangeContext * e // NOTE: we already know this is an InvokeCommand Request message because we explicitly registered with the // Exchange Manager for unsolicited InvokeCommand Requests. - mpExchangeCtx = ec; + mExchangeCtx.Grab(ec); // Use the RAII feature, if this is the only Handle when this function returns, DecrementHoldOff will trigger sending response. // TODO: This is broken! If something under here returns error, we will try @@ -81,7 +81,7 @@ CHIP_ERROR CommandHandler::OnInvokeCommandRequest(Messaging::ExchangeContext * e // response too. Figure out at what point it's our responsibility to // handler errors vs our caller's. Handle workHandle(this); - mpExchangeCtx->WillSendMessage(); + mExchangeCtx->WillSendMessage(); ReturnErrorOnFailure(ProcessInvokeRequest(std::move(payload), isTimedInvoke)); return CHIP_NO_ERROR; @@ -103,26 +103,19 @@ CHIP_ERROR CommandHandler::ProcessInvokeRequest(System::PacketBufferHandle && pa ReturnErrorOnFailure(invokeRequestMessage.GetTimedRequest(&mTimedRequest)); ReturnErrorOnFailure(invokeRequestMessage.GetInvokeRequests(&invokeRequests)); - VerifyOrReturnError(mpExchangeCtx != nullptr, CHIP_ERROR_INCORRECT_STATE); + VerifyOrReturnError(mExchangeCtx, CHIP_ERROR_INCORRECT_STATE); if (mTimedRequest != isTimedInvoke) { // The message thinks it should be part of a timed interaction but it's // not, or vice versa. Spec says to Respond with UNSUPPORTED_ACCESS. - err = StatusResponse::Send(Protocols::InteractionModel::Status::UnsupportedAccess, mpExchangeCtx, + err = StatusResponse::Send(Protocols::InteractionModel::Status::UnsupportedAccess, mExchangeCtx.Get(), /* aExpectResponse = */ false); - - if (err != CHIP_NO_ERROR) + if (err == CHIP_NO_ERROR) { - // We have to manually close the exchange, because we called - // WillSendMessage already. - mpExchangeCtx->Close(); + mSentStatusResponse = true; } - // Null out the (now-closed) exchange, so that when we try to - // SendCommandResponse() later (when our holdoff count drops to 0) it - // just fails and we don't double-respond. - mpExchangeCtx = nullptr; return err; } @@ -142,7 +135,7 @@ CHIP_ERROR CommandHandler::ProcessInvokeRequest(System::PacketBufferHandle && pa CommandDataIB::Parser commandData; ReturnErrorOnFailure(commandData.Init(invokeRequestsReader)); - if (mpExchangeCtx->IsGroupExchangeContext()) + if (mExchangeCtx->IsGroupExchangeContext()) { ReturnErrorOnFailure(ProcessGroupCommandDataIB(commandData)); } @@ -172,18 +165,6 @@ void CommandHandler::Close() VerifyOrDieWithMsg(mPendingWork == 0, DataManagement, "CommandHandler::Close() called with %u unfinished async work items", static_cast(mPendingWork)); - // OnDone below can destroy us before we unwind all the way back into the - // exchange code and it tries to close itself. Make sure that it doesn't - // try to notify us that it's closing, since we will be dead. - // - // For more details, see #10344. - if (mpExchangeCtx != nullptr) - { - mpExchangeCtx->SetDelegate(nullptr); - } - - mpExchangeCtx = nullptr; - if (mpCallback) { mpCallback->OnDone(*this); @@ -205,21 +186,12 @@ void CommandHandler::DecrementHoldOff() return; } - if (mpExchangeCtx->IsGroupExchangeContext()) - { - mpExchangeCtx->Close(); - } - else + if (!mExchangeCtx->IsGroupExchangeContext() && !mSentStatusResponse) { CHIP_ERROR err = SendCommandResponse(); if (err != CHIP_NO_ERROR) { ChipLogError(DataManagement, "Failed to send command response: %" CHIP_ERROR_FORMAT, err.Format()); - // We marked the exchange as "WillSendMessage", need to shutdown the exchange manually to avoid leaking exchanges. - if (mpExchangeCtx != nullptr) - { - mpExchangeCtx->Close(); - } } } @@ -232,11 +204,11 @@ CHIP_ERROR CommandHandler::SendCommandResponse() VerifyOrReturnError(mPendingWork == 0, CHIP_ERROR_INCORRECT_STATE); VerifyOrReturnError(mState == State::AddedCommand, CHIP_ERROR_INCORRECT_STATE); - VerifyOrReturnError(mpExchangeCtx != nullptr, CHIP_ERROR_INCORRECT_STATE); + VerifyOrReturnError(mExchangeCtx, CHIP_ERROR_INCORRECT_STATE); ReturnErrorOnFailure(Finalize(commandPacket)); ReturnErrorOnFailure( - mpExchangeCtx->SendMessage(Protocols::InteractionModel::MsgType::InvokeCommandResponse, std::move(commandPacket))); + mExchangeCtx->SendMessage(Protocols::InteractionModel::MsgType::InvokeCommandResponse, std::move(commandPacket))); // The ExchangeContext is automatically freed here, and it makes mpExchangeCtx be temporarily dangling, but in // all cases, we are going to call Close immediately after this function, which nulls out mpExchangeCtx. @@ -290,7 +262,7 @@ CHIP_ERROR CommandHandler::ProcessCommandDataIB(CommandDataIB::Parser & aCommand } } - VerifyOrExit(mpExchangeCtx != nullptr && mpExchangeCtx->HasSessionHandle(), err = CHIP_ERROR_INCORRECT_STATE); + VerifyOrExit(mExchangeCtx && mExchangeCtx->HasSessionHandle(), err = CHIP_ERROR_INCORRECT_STATE); { Access::SubjectDescriptor subjectDescriptor = GetSubjectDescriptor(); @@ -379,7 +351,7 @@ CHIP_ERROR CommandHandler::ProcessGroupCommandDataIB(CommandDataIB::Parser & aCo err = commandPath.GetCommandId(&commandId); SuccessOrExit(err); - groupId = mpExchangeCtx->GetSessionHandle()->AsIncomingGroupSession()->GetGroupId(); + groupId = mExchangeCtx->GetSessionHandle()->AsIncomingGroupSession()->GetGroupId(); fabric = GetAccessingFabricIndex(); ChipLogDetail(DataManagement, "Received group command for Group=%u Cluster=" ChipLogFormatMEI " Command=" ChipLogFormatMEI, @@ -596,7 +568,7 @@ TLV::TLVWriter * CommandHandler::GetCommandDataIBTLVWriter() FabricIndex CommandHandler::GetAccessingFabricIndex() const { - return mpExchangeCtx->GetSessionHandle()->GetFabricIndex(); + return mExchangeCtx->GetSessionHandle()->GetFabricIndex(); } CommandHandler * CommandHandler::Handle::Get() @@ -666,27 +638,6 @@ void CommandHandler::MoveToState(const State aTargetState) ChipLogDetail(DataManagement, "ICR moving to [%10.10s]", GetStateStr()); } -void CommandHandler::Abort() -{ - // - // If the exchange context hasn't already been gracefully closed - // (signaled by setting it to null), then we need to forcibly - // tear it down. - // - if (mpExchangeCtx != nullptr) - { - // We might be a delegate for this exchange, and we don't want the - // OnExchangeClosing notification in that case. Null out the delegate - // to avoid that. - // - // TODO: This makes all sorts of assumptions about what the delegate is - // (notice the "might" above!) that might not hold in practice. We - // really need a better solution here.... - mpExchangeCtx->SetDelegate(nullptr); - mpExchangeCtx->Abort(); - mpExchangeCtx = nullptr; - } -} } // namespace app } // namespace chip diff --git a/src/app/CommandHandler.h b/src/app/CommandHandler.h index 200feee67a1347..da198cad0af6ef 100644 --- a/src/app/CommandHandler.h +++ b/src/app/CommandHandler.h @@ -33,7 +33,7 @@ #include #include #include -#include +#include #include #include #include @@ -46,16 +46,9 @@ namespace chip { namespace app { -class CommandHandler +class CommandHandler : public Messaging::ExchangeDelegate { public: - /* - * Destructor - as part of destruction, it will abort the exchange context - * if a valid one still exists. - * - * See Abort() for details on when that might occur. - */ - virtual ~CommandHandler() { Abort(); } class Callback { public: @@ -221,11 +214,15 @@ class CommandHandler /** * Gets the inner exchange context object, without ownership. * + * WARNING: This is dangerous, since it is directly interacting with the + * exchange being managed automatically by mExchangeCtx and + * if not done carefully, may end up with use-after-free errors. + * * @return The inner exchange context, might be nullptr if no * exchange context has been assigned or the context * has been released. */ - Messaging::ExchangeContext * GetExchangeContext() const { return mpExchangeCtx; } + Messaging::ExchangeContext * GetExchangeContext() const { return mExchangeCtx.Get(); } /** * @brief Flush acks right away for a slow command @@ -240,18 +237,35 @@ class CommandHandler */ void FlushAcksRightAwayOnSlowCommand() { - VerifyOrReturn(mpExchangeCtx != nullptr); - auto * msgContext = mpExchangeCtx->GetReliableMessageContext(); + VerifyOrReturn(mExchangeCtx); + auto * msgContext = mExchangeCtx->GetReliableMessageContext(); VerifyOrReturn(msgContext != nullptr); msgContext->FlushAcks(); } - Access::SubjectDescriptor GetSubjectDescriptor() const { return mpExchangeCtx->GetSessionHandle()->GetSubjectDescriptor(); } + Access::SubjectDescriptor GetSubjectDescriptor() const { return mExchangeCtx->GetSessionHandle()->GetSubjectDescriptor(); } private: friend class TestCommandInteraction; friend class CommandHandler::Handle; + CHIP_ERROR OnMessageReceived(Messaging::ExchangeContext * ec, const PayloadHeader & payloadHeader, + System::PacketBufferHandle && payload) override + { + // + // We shouldn't be receiving any further messages on this exchange. + // + return CHIP_ERROR_INCORRECT_STATE; + } + + void OnResponseTimeout(Messaging::ExchangeContext * ec) override + { + // + // We're not expecting responses to any messages we send out on this EC. + // + VerifyOrDie(false); + } + enum class State { Idle, ///< Default state that the object starts out in, where no work has commenced @@ -343,14 +357,16 @@ class CommandHandler return FinishCommand(/* aEndDataStruct = */ false); } - Messaging::ExchangeContext * mpExchangeCtx = nullptr; - Callback * mpCallback = nullptr; + Messaging::ExchangeHolder mExchangeCtx; + Callback * mpCallback = nullptr; InvokeResponseMessage::Builder mInvokeResponseBuilder; TLV::TLVType mDataElementContainerType = TLV::kTLVType_NotSpecified; size_t mPendingWork = 0; bool mSuppressResponse = false; bool mTimedRequest = false; + bool mSentStatusResponse = false; + State mState = State::Idle; chip::System::PacketBufferTLVWriter mCommandMessageWriter; TLV::TLVWriter mBackupWriter; diff --git a/src/app/CommandSender.cpp b/src/app/CommandSender.cpp index 63053751240eb7..803204dadaebcb 100644 --- a/src/app/CommandSender.cpp +++ b/src/app/CommandSender.cpp @@ -33,7 +33,8 @@ namespace chip { namespace app { CommandSender::CommandSender(Callback * apCallback, Messaging::ExchangeManager * apExchangeMgr, bool aIsTimedRequest) : - mpCallback(apCallback), mpExchangeMgr(apExchangeMgr), mSuppressResponse(false), mTimedRequest(aIsTimedRequest) + mExchangeCtx(*this), mpCallback(apCallback), mpExchangeMgr(apExchangeMgr), mSuppressResponse(false), + mTimedRequest(aIsTimedRequest) {} CHIP_ERROR CommandSender::AllocateBuffer() @@ -67,15 +68,17 @@ CHIP_ERROR CommandSender::SendCommandRequest(const SessionHandle & session, Opti ReturnErrorOnFailure(Finalize(mPendingInvokeData)); // Create a new exchange context. - mpExchangeCtx = mpExchangeMgr->NewContext(session, this); - VerifyOrReturnError(mpExchangeCtx != nullptr, CHIP_ERROR_NO_MEMORY); - VerifyOrReturnError(!mpExchangeCtx->IsGroupExchangeContext(), CHIP_ERROR_INVALID_MESSAGE_TYPE); + auto exchange = mpExchangeMgr->NewContext(session, this); + VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_NO_MEMORY); - mpExchangeCtx->SetResponseTimeout(timeout.ValueOr(session->ComputeRoundTripTimeout(app::kExpectedIMProcessingTime))); + mExchangeCtx.Grab(exchange); + VerifyOrReturnError(!mExchangeCtx->IsGroupExchangeContext(), CHIP_ERROR_INVALID_MESSAGE_TYPE); + + mExchangeCtx->SetResponseTimeout(timeout.ValueOr(session->ComputeRoundTripTimeout(app::kExpectedIMProcessingTime))); if (mTimedInvokeTimeoutMs.HasValue()) { - ReturnErrorOnFailure(TimedRequest::Send(mpExchangeCtx, mTimedInvokeTimeoutMs.Value())); + ReturnErrorOnFailure(TimedRequest::Send(mExchangeCtx.Get(), mTimedInvokeTimeoutMs.Value())); MoveToState(State::AwaitingTimedStatus); return CHIP_NO_ERROR; } @@ -90,14 +93,13 @@ CHIP_ERROR CommandSender::SendGroupCommandRequest(const SessionHandle & session) ReturnErrorOnFailure(Finalize(mPendingInvokeData)); // Create a new exchange context. - mpExchangeCtx = mpExchangeMgr->NewContext(session, this); - VerifyOrReturnError(mpExchangeCtx != nullptr, CHIP_ERROR_NO_MEMORY); - VerifyOrReturnError(mpExchangeCtx->IsGroupExchangeContext(), CHIP_ERROR_INVALID_MESSAGE_TYPE); + auto exchange = mpExchangeMgr->NewContext(session, this); + VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_NO_MEMORY); - ReturnErrorOnFailure(SendInvokeRequest()); + mExchangeCtx.Grab(exchange); + VerifyOrReturnError(mExchangeCtx->IsGroupExchangeContext(), CHIP_ERROR_INVALID_MESSAGE_TYPE); - // Exchange is gone now, since it closed itself on successful send. - mpExchangeCtx = nullptr; + ReturnErrorOnFailure(SendInvokeRequest()); Close(); return CHIP_NO_ERROR; @@ -108,8 +110,8 @@ CHIP_ERROR CommandSender::SendInvokeRequest() using namespace Protocols::InteractionModel; using namespace Messaging; - ReturnErrorOnFailure(mpExchangeCtx->SendMessage(MsgType::InvokeCommandRequest, std::move(mPendingInvokeData), - SendMessageFlags::kExpectResponse)); + ReturnErrorOnFailure( + mExchangeCtx->SendMessage(MsgType::InvokeCommandRequest, std::move(mPendingInvokeData), SendMessageFlags::kExpectResponse)); MoveToState(State::CommandSent); return CHIP_NO_ERROR; @@ -124,7 +126,7 @@ CHIP_ERROR CommandSender::OnMessageReceived(Messaging::ExchangeContext * apExcha } CHIP_ERROR err = CHIP_NO_ERROR; - VerifyOrExit(apExchangeContext == mpExchangeCtx, err = CHIP_ERROR_INCORRECT_STATE); + VerifyOrExit(apExchangeContext == mExchangeCtx.Get(), err = CHIP_ERROR_INCORRECT_STATE); if (mState == State::AwaitingTimedStatus) { @@ -223,18 +225,6 @@ void CommandSender::Close() mTimedRequest = false; MoveToState(State::AwaitingDestruction); - // OnDone below can destroy us before we unwind all the way back into the - // exchange code and it tries to close itself. Make sure that it doesn't - // try to notify us that it's closing, since we will be dead. - // - // For more details, see #10344. - if (mpExchangeCtx != nullptr) - { - mpExchangeCtx->SetDelegate(nullptr); - } - - mpExchangeCtx = nullptr; - if (mpCallback) { mpCallback->OnDone(this); @@ -443,27 +433,5 @@ void CommandSender::MoveToState(const State aTargetState) ChipLogDetail(DataManagement, "ICR moving to [%10.10s]", GetStateStr()); } -void CommandSender::Abort() -{ - // - // If the exchange context hasn't already been gracefully closed - // (signaled by setting it to null), then we need to forcibly - // tear it down. - // - if (mpExchangeCtx != nullptr) - { - // We might be a delegate for this exchange, and we don't want the - // OnExchangeClosing notification in that case. Null out the delegate - // to avoid that. - // - // TODO: This makes all sorts of assumptions about what the delegate is - // (notice the "might" above!) that might not hold in practice. We - // really need a better solution here.... - mpExchangeCtx->SetDelegate(nullptr); - mpExchangeCtx->Abort(); - mpExchangeCtx = nullptr; - } -} - } // namespace app } // namespace chip diff --git a/src/app/CommandSender.h b/src/app/CommandSender.h index f0f0f945c7f8a4..e0ad257d5722fb 100644 --- a/src/app/CommandSender.h +++ b/src/app/CommandSender.h @@ -38,7 +38,7 @@ #include #include #include -#include +#include #include #include #include @@ -53,23 +53,6 @@ namespace app { class CommandSender final : public Messaging::ExchangeDelegate { public: - /* - * Destructor - as part of destruction, it will abort the exchange context - * if a valid one still exists. - * - * See Abort() for details on when that might occur. - */ - ~CommandSender() override { Abort(); } - - /** - * Gets the inner exchange context object, without ownership. - * - * @return The inner exchange context, might be nullptr if no - * exchange context has been assigned or the context - * has been released. - */ - Messaging::ExchangeContext * GetExchangeContext() const { return mpExchangeCtx; } - class Callback { public: @@ -290,7 +273,7 @@ class CommandSender final : public Messaging::ExchangeDelegate CHIP_ERROR Finalize(System::PacketBufferHandle & commandPacket); - Messaging::ExchangeContext * mpExchangeCtx = nullptr; + Messaging::ExchangeHolder mExchangeCtx; Callback * mpCallback = nullptr; Messaging::ExchangeManager * mpExchangeMgr = nullptr; InvokeRequestMessage::Builder mInvokeRequestBuilder; diff --git a/src/app/WriteClient.cpp b/src/app/WriteClient.cpp index c17c63af372500..3e70a003d24446 100644 --- a/src/app/WriteClient.cpp +++ b/src/app/WriteClient.cpp @@ -35,46 +35,12 @@ void WriteClient::Close() { MoveToState(State::AwaitingDestruction); - // OnDone below can destroy us before we unwind all the way back into the - // exchange code and it tries to close itself. Make sure that it doesn't - // try to notify us that it's closing, since we will be dead. - // - // For more details, see #10344. - if (mpExchangeCtx != nullptr) - { - mpExchangeCtx->SetDelegate(nullptr); - } - - mpExchangeCtx = nullptr; - if (mpCallback) { mpCallback->OnDone(this); } } -void WriteClient::Abort() -{ - // - // If the exchange context hasn't already been gracefully closed - // (signaled by setting it to null), then we need to forcibly - // tear it down. - // - if (mpExchangeCtx != nullptr) - { - // We might be a delegate for this exchange, and we don't want the - // OnExchangeClosing notification in that case. Null out the delegate - // to avoid that. - // - // TODO: This makes all sorts of assumptions about what the delegate is - // (notice the "might" above!) that might not hold in practice. We - // really need a better solution here.... - mpExchangeCtx->SetDelegate(nullptr); - mpExchangeCtx->Abort(); - mpExchangeCtx = nullptr; - } -} - CHIP_ERROR WriteClient::ProcessWriteResponseMessage(System::PacketBufferHandle && payload) { CHIP_ERROR err = CHIP_NO_ERROR; @@ -365,23 +331,28 @@ CHIP_ERROR WriteClient::SendWriteRequest(const SessionHandle & session, System:: err = FinalizeMessage(false /* hasMoreChunks */); SuccessOrExit(err); - // Create a new exchange context. - mpExchangeCtx = mpExchangeMgr->NewContext(session, this); - VerifyOrExit(mpExchangeCtx != nullptr, err = CHIP_ERROR_NO_MEMORY); - VerifyOrReturnError(!(mpExchangeCtx->IsGroupExchangeContext() && mHasDataVersion), CHIP_ERROR_INVALID_MESSAGE_TYPE); + { + // Create a new exchange context. + auto exchange = mpExchangeMgr->NewContext(session, this); + VerifyOrExit(exchange != nullptr, err = CHIP_ERROR_NO_MEMORY); + + mExchangeCtx.Grab(exchange); + } + + VerifyOrReturnError(!(mExchangeCtx->IsGroupExchangeContext() && mHasDataVersion), CHIP_ERROR_INVALID_MESSAGE_TYPE); if (timeout == System::Clock::kZero) { - mpExchangeCtx->UseSuggestedResponseTimeout(app::kExpectedIMProcessingTime); + mExchangeCtx->UseSuggestedResponseTimeout(app::kExpectedIMProcessingTime); } else { - mpExchangeCtx->SetResponseTimeout(timeout); + mExchangeCtx->SetResponseTimeout(timeout); } if (mTimedWriteTimeoutMs.HasValue()) { - err = TimedRequest::Send(mpExchangeCtx, mTimedWriteTimeoutMs.Value()); + err = TimedRequest::Send(mExchangeCtx.Get(), mTimedWriteTimeoutMs.Value()); SuccessOrExit(err); MoveToState(State::AwaitingTimedStatus); } @@ -425,7 +396,7 @@ CHIP_ERROR WriteClient::SendWriteRequest() System::PacketBufferHandle data = mChunks.PopHead(); - bool isGroupWrite = mpExchangeCtx->IsGroupExchangeContext(); + bool isGroupWrite = mExchangeCtx->IsGroupExchangeContext(); if (!mChunks.IsNull() && isGroupWrite) { // Reject this request if we have more than one chunk (mChunks is not null after PopHead()), and this is a group @@ -434,13 +405,8 @@ CHIP_ERROR WriteClient::SendWriteRequest() } // kExpectResponse is ignored by ExchangeContext in case of groupcast - ReturnErrorOnFailure(mpExchangeCtx->SendMessage(MsgType::WriteRequest, std::move(data), SendMessageFlags::kExpectResponse)); - if (isGroupWrite) - { - // Exchange is closed now, since there are no group responses. Drop our - // ref to it. - mpExchangeCtx = nullptr; - } + ReturnErrorOnFailure(mExchangeCtx->SendMessage(MsgType::WriteRequest, std::move(data), SendMessageFlags::kExpectResponse)); + MoveToState(State::AwaitingResponse); return CHIP_NO_ERROR; } @@ -456,11 +422,12 @@ CHIP_ERROR WriteClient::OnMessageReceived(Messaging::ExchangeContext * apExchang } CHIP_ERROR err = CHIP_NO_ERROR; + // Assert that the exchange context matches the client's current context. // This should never fail because even if SendWriteRequest is called // back-to-back, the second call will call Close() on the first exchange, // which clears the OnMessageReceived callback. - VerifyOrExit(apExchangeContext == mpExchangeCtx, err = CHIP_ERROR_INCORRECT_STATE); + VerifyOrExit(apExchangeContext == mExchangeCtx.Get(), err = CHIP_ERROR_INCORRECT_STATE); if (mState == State::AwaitingTimedStatus) { diff --git a/src/app/WriteClient.h b/src/app/WriteClient.h index 2f7ca2366de728..395809dc6cb750 100644 --- a/src/app/WriteClient.h +++ b/src/app/WriteClient.h @@ -33,7 +33,7 @@ #include #include #include -#include +#include #include #include #include @@ -125,14 +125,15 @@ class WriteClient : public Messaging::ExchangeDelegate WriteClient(Messaging::ExchangeManager * apExchangeMgr, Callback * apCallback, const Optional & aTimedWriteTimeoutMs, bool aSuppressResponse = false) : mpExchangeMgr(apExchangeMgr), - mpCallback(apCallback), mTimedWriteTimeoutMs(aTimedWriteTimeoutMs), mSuppressResponse(aSuppressResponse) + mExchangeCtx(*this), mpCallback(apCallback), mTimedWriteTimeoutMs(aTimedWriteTimeoutMs), + mSuppressResponse(aSuppressResponse) {} #if CONFIG_BUILD_FOR_HOST_UNIT_TEST WriteClient(Messaging::ExchangeManager * apExchangeMgr, Callback * apCallback, const Optional & aTimedWriteTimeoutMs, uint16_t aReservedSize) : mpExchangeMgr(apExchangeMgr), - mpCallback(apCallback), mTimedWriteTimeoutMs(aTimedWriteTimeoutMs), mReservedSize(aReservedSize) + mExchangeCtx(*this), mpCallback(apCallback), mTimedWriteTimeoutMs(aTimedWriteTimeoutMs), mReservedSize(aReservedSize) {} #endif @@ -226,14 +227,6 @@ class WriteClient : public Messaging::ExchangeDelegate */ void Shutdown(); - /* - * Destructor - as part of destruction, it will abort the exchange context - * if a valid one still exists. - * - * See Abort() for details on when that might occur. - */ - ~WriteClient() override { Abort(); } - private: friend class TestWriteInteraction; friend class InteractionModelEngine; @@ -378,9 +371,9 @@ class WriteClient : public Messaging::ExchangeDelegate CHIP_ERROR FinalizeMessage(bool aHasMoreChunks); Messaging::ExchangeManager * mpExchangeMgr = nullptr; - Messaging::ExchangeContext * mpExchangeCtx = nullptr; - Callback * mpCallback = nullptr; - State mState = State::Initialized; + Messaging::ExchangeHolder mExchangeCtx; + Callback * mpCallback = nullptr; + State mState = State::Initialized; System::PacketBufferTLVWriter mMessageWriter; WriteRequestMessage::Builder mWriteRequestBuilder; // TODO Maybe we should change PacketBufferTLVWriter so we can finalize it diff --git a/src/app/WriteHandler.cpp b/src/app/WriteHandler.cpp index ba52461882abe9..3ede722aa33bf0 100644 --- a/src/app/WriteHandler.cpp +++ b/src/app/WriteHandler.cpp @@ -36,7 +36,7 @@ constexpr uint8_t kListAttributeType = 0x48; CHIP_ERROR WriteHandler::Init() { - VerifyOrReturnError(mpExchangeCtx == nullptr, CHIP_ERROR_INCORRECT_STATE); + VerifyOrReturnError(!mExchangeCtx, CHIP_ERROR_INCORRECT_STATE); MoveToState(State::Initialized); @@ -51,31 +51,11 @@ void WriteHandler::Close() mSuppressResponse = false; VerifyOrReturn(mState != State::Uninitialized); - if (mpExchangeCtx != nullptr) - { - mpExchangeCtx->SetDelegate(nullptr); - mpExchangeCtx = nullptr; - } - ClearState(); } void WriteHandler::Abort() { - if (mpExchangeCtx != nullptr) - { - // We might be a delegate for this exchange, and we don't want the - // OnExchangeClosing notification in that case. Null out the delegate - // to avoid that. - // - // TODO: This makes all sorts of assumptions about what the delegate is - // (notice the "might" above!) that might not hold in practice. We - // really need a better solution here.... - mpExchangeCtx->SetDelegate(nullptr); - mpExchangeCtx->Abort(); - mpExchangeCtx = nullptr; - } - ClearState(); } @@ -110,13 +90,11 @@ Status WriteHandler::HandleWriteRequestMessage(Messaging::ExchangeContext * apEx Status WriteHandler::OnWriteRequest(Messaging::ExchangeContext * apExchangeContext, System::PacketBufferHandle && aPayload, bool aIsTimedWrite) { - mpExchangeCtx = apExchangeContext; - // // Let's take over further message processing on this exchange from the IM. // This is only relevant during chunked requests. // - mpExchangeCtx->SetDelegate(this); + mExchangeCtx.Grab(apExchangeContext); Status status = HandleWriteRequestMessage(apExchangeContext, std::move(aPayload), aIsTimedWrite); @@ -134,7 +112,7 @@ CHIP_ERROR WriteHandler::OnMessageReceived(Messaging::ExchangeContext * apExchan { CHIP_ERROR err = CHIP_NO_ERROR; - VerifyOrDieWithMsg(apExchangeContext == mpExchangeCtx, DataManagement, + VerifyOrDieWithMsg(apExchangeContext == mExchangeCtx.Get(), DataManagement, "Incoming exchange context should be same as the initial request."); VerifyOrDieWithMsg(!apExchangeContext->IsGroupExchangeContext(), DataManagement, "OnMessageReceived should not be called on GroupExchangeContext"); @@ -191,11 +169,11 @@ CHIP_ERROR WriteHandler::SendWriteResponse(System::PacketBufferTLVWriter && aMes err = FinalizeMessage(std::move(aMessageWriter), packet); SuccessOrExit(err); - VerifyOrExit(mpExchangeCtx != nullptr, err = CHIP_ERROR_INCORRECT_STATE); - mpExchangeCtx->UseSuggestedResponseTimeout(app::kExpectedIMProcessingTime); - err = mpExchangeCtx->SendMessage(Protocols::InteractionModel::MsgType::WriteResponse, std::move(packet), - mHasMoreChunks ? Messaging::SendMessageFlags::kExpectResponse - : Messaging::SendMessageFlags::kNone); + VerifyOrExit(mExchangeCtx, err = CHIP_ERROR_INCORRECT_STATE); + mExchangeCtx->UseSuggestedResponseTimeout(app::kExpectedIMProcessingTime); + err = mExchangeCtx->SendMessage(Protocols::InteractionModel::MsgType::WriteResponse, std::move(packet), + mHasMoreChunks ? Messaging::SendMessageFlags::kExpectResponse + : Messaging::SendMessageFlags::kNone); SuccessOrExit(err); MoveToState(State::Sending); @@ -237,7 +215,7 @@ CHIP_ERROR WriteHandler::DeliverFinalListWriteEndForGroupWrite(bool writeWasSucc Credentials::GroupDataProvider * groupDataProvider = Credentials::GetGroupDataProvider(); Credentials::GroupDataProvider::EndpointIterator * iterator; - GroupId groupId = mpExchangeCtx->GetSessionHandle()->AsIncomingGroupSession()->GetGroupId(); + GroupId groupId = mExchangeCtx->GetSessionHandle()->AsIncomingGroupSession()->GetGroupId(); FabricIndex fabricIndex = GetAccessingFabricIndex(); auto processingConcreteAttributePath = mProcessingAttributePath.Value(); @@ -289,8 +267,8 @@ CHIP_ERROR WriteHandler::ProcessAttributeDataIBs(TLV::TLVReader & aAttributeData { CHIP_ERROR err = CHIP_NO_ERROR; - ReturnErrorCodeIf(mpExchangeCtx == nullptr, CHIP_ERROR_INTERNAL); - const Access::SubjectDescriptor subjectDescriptor = mpExchangeCtx->GetSessionHandle()->GetSubjectDescriptor(); + ReturnErrorCodeIf(!mExchangeCtx, CHIP_ERROR_INTERNAL); + const Access::SubjectDescriptor subjectDescriptor = mExchangeCtx->GetSessionHandle()->GetSubjectDescriptor(); while (CHIP_NO_ERROR == (err = aAttributeDataIBsReader.Next())) { @@ -396,11 +374,11 @@ CHIP_ERROR WriteHandler::ProcessGroupAttributeDataIBs(TLV::TLVReader & aAttribut { CHIP_ERROR err = CHIP_NO_ERROR; - ReturnErrorCodeIf(mpExchangeCtx == nullptr, CHIP_ERROR_INTERNAL); + ReturnErrorCodeIf(!mExchangeCtx, CHIP_ERROR_INTERNAL); const Access::SubjectDescriptor subjectDescriptor = - mpExchangeCtx->GetSessionHandle()->AsIncomingGroupSession()->GetSubjectDescriptor(); + mExchangeCtx->GetSessionHandle()->AsIncomingGroupSession()->GetSubjectDescriptor(); - GroupId groupId = mpExchangeCtx->GetSessionHandle()->AsIncomingGroupSession()->GetGroupId(); + GroupId groupId = mExchangeCtx->GetSessionHandle()->AsIncomingGroupSession()->GetGroupId(); FabricIndex fabric = GetAccessingFabricIndex(); while (CHIP_NO_ERROR == (err = aAttributeDataIBsReader.Next())) @@ -590,7 +568,7 @@ Status WriteHandler::ProcessWriteRequest(System::PacketBufferHandle && aPayload, } SuccessOrExit(err); - if (mHasMoreChunks && (mpExchangeCtx->IsGroupExchangeContext() || mIsTimedRequest)) + if (mHasMoreChunks && (mExchangeCtx->IsGroupExchangeContext() || mIsTimedRequest)) { // Sanity check: group exchange context should only have one chunk. // Also, timed requests should not have more than one chunk. @@ -610,7 +588,7 @@ Status WriteHandler::ProcessWriteRequest(System::PacketBufferHandle && aPayload, AttributeDataIBsParser.GetReader(&AttributeDataIBsReader); - if (mpExchangeCtx->IsGroupExchangeContext()) + if (mExchangeCtx->IsGroupExchangeContext()) { err = ProcessGroupAttributeDataIBs(AttributeDataIBsReader); } @@ -680,7 +658,7 @@ CHIP_ERROR WriteHandler::AddStatus(const ConcreteDataAttributePath & aPath, cons FabricIndex WriteHandler::GetAccessingFabricIndex() const { - return mpExchangeCtx->GetSessionHandle()->GetFabricIndex(); + return mExchangeCtx->GetSessionHandle()->GetFabricIndex(); } const char * WriteHandler::GetStateStr() const @@ -712,6 +690,7 @@ void WriteHandler::MoveToState(const State aTargetState) void WriteHandler::ClearState() { DeliverFinalListWriteEnd(false /* wasSuccessful */); + mExchangeCtx.Release(); MoveToState(State::Uninitialized); } diff --git a/src/app/WriteHandler.h b/src/app/WriteHandler.h index 1edf40b15273ff..43163a2b8abab9 100644 --- a/src/app/WriteHandler.h +++ b/src/app/WriteHandler.h @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include #include @@ -41,6 +41,8 @@ namespace app { class WriteHandler : public Messaging::ExchangeDelegate { public: + WriteHandler() : mExchangeCtx(*this) {} + /** * Initialize the WriteHandler. Within the lifetime * of this instance, this method is invoked once after object @@ -96,7 +98,7 @@ class WriteHandler : public Messaging::ExchangeDelegate bool MatchesExchangeContext(Messaging::ExchangeContext * apExchangeContext) const { - return !IsFree() && mpExchangeCtx == apExchangeContext; + return !IsFree() && mExchangeCtx.Get() == apExchangeContext; } void CacheACLCheckResult(const AttributeAccessToken & aToken) { mACLCheckCache.SetValue(aToken); } @@ -158,7 +160,7 @@ class WriteHandler : public Messaging::ExchangeDelegate System::PacketBufferHandle && aPayload) override; void OnResponseTimeout(Messaging::ExchangeContext * apExchangeContext) override; - Messaging::ExchangeContext * mpExchangeCtx = nullptr; + Messaging::ExchangeHolder mExchangeCtx; WriteResponseMessage::Builder mWriteResponseBuilder; State mState = State::Uninitialized; bool mIsTimedRequest = false; diff --git a/src/app/tests/TestCommandInteraction.cpp b/src/app/tests/TestCommandInteraction.cpp index 6b7a089104229c..06a007c83c57b9 100644 --- a/src/app/tests/TestCommandInteraction.cpp +++ b/src/app/tests/TestCommandInteraction.cpp @@ -452,8 +452,8 @@ void TestCommandInteraction::TestCommandHandlerWithWrongState(nlTestSuite * apSu NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); TestExchangeDelegate delegate; - commandHandler.mpExchangeCtx = ctx.NewExchangeToAlice(&delegate); - err = commandHandler.SendCommandResponse(); + commandHandler.mExchangeCtx.Grab(ctx.NewExchangeToAlice(&delegate)); + err = commandHandler.SendCommandResponse(); NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_INCORRECT_STATE); } @@ -488,7 +488,7 @@ void TestCommandInteraction::TestCommandHandlerWithSendEmptyCommand(nlTestSuite System::PacketBufferHandle commandDatabuf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize); TestExchangeDelegate delegate; - commandHandler.mpExchangeCtx = ctx.NewExchangeToAlice(&delegate); + commandHandler.mExchangeCtx.Grab(ctx.NewExchangeToAlice(&delegate)); err = commandHandler.PrepareCommand(path); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); @@ -521,7 +521,7 @@ void TestCommandInteraction::ValidateCommandHandlerWithSendCommand(nlTestSuite * System::PacketBufferHandle commandPacket; TestExchangeDelegate delegate; - commandHandler.mpExchangeCtx = ctx.NewExchangeToAlice(&delegate); + commandHandler.mExchangeCtx.Grab(ctx.NewExchangeToAlice(&delegate)); AddInvokeResponseData(apSuite, apContext, &commandHandler, aNeedStatusCode); err = commandHandler.Finalize(commandPacket); @@ -581,7 +581,7 @@ void TestCommandInteraction::TestCommandHandlerCommandDataEncoding(nlTestSuite * System::PacketBufferHandle commandPacket; TestExchangeDelegate delegate; - commandHandler.mpExchangeCtx = ctx.NewExchangeToAlice(&delegate); + commandHandler.mExchangeCtx.Grab(ctx.NewExchangeToAlice(&delegate)); auto path = MakeTestCommandPath(); @@ -608,7 +608,7 @@ void TestCommandInteraction::TestCommandHandlerCommandEncodeFailure(nlTestSuite System::PacketBufferHandle commandPacket; TestExchangeDelegate delegate; - commandHandler.mpExchangeCtx = ctx.NewExchangeToAlice(&delegate); + commandHandler.mExchangeCtx.Grab(ctx.NewExchangeToAlice(&delegate)); auto path = MakeTestCommandPath(); @@ -635,7 +635,7 @@ void TestCommandInteraction::TestCommandHandlerCommandEncodeExternalFailure(nlTe System::PacketBufferHandle commandPacket; TestExchangeDelegate delegate; - commandHandler.mpExchangeCtx = ctx.NewExchangeToAlice(&delegate); + commandHandler.mExchangeCtx.Grab(ctx.NewExchangeToAlice(&delegate)); auto path = MakeTestCommandPath(); @@ -672,7 +672,7 @@ void TestCommandInteraction::TestCommandHandlerWithProcessReceivedMsg(nlTestSuit System::PacketBufferHandle commandDatabuf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize); TestExchangeDelegate delegate; - commandHandler.mpExchangeCtx = ctx.NewExchangeToAlice(&delegate); + commandHandler.mExchangeCtx.Grab(ctx.NewExchangeToAlice(&delegate)); GenerateInvokeRequest(apSuite, apContext, commandDatabuf, /* aIsTimedRequest = */ false, kTestCommandIdWithData); err = commandHandler.ProcessInvokeRequest(std::move(commandDatabuf), false); @@ -711,7 +711,7 @@ void TestCommandInteraction::TestCommandHandlerWithProcessReceivedEmptyDataMsg(n System::PacketBufferHandle commandDatabuf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize); TestExchangeDelegate delegate; - commandHandler.mpExchangeCtx = ctx.NewExchangeToAlice(&delegate); + commandHandler.mExchangeCtx.Grab(ctx.NewExchangeToAlice(&delegate)); chip::isCommandDispatched = false; GenerateInvokeRequest(apSuite, apContext, commandDatabuf, messageIsTimed, kTestCommandIdNoData); diff --git a/src/app/tests/TestWriteInteraction.cpp b/src/app/tests/TestWriteInteraction.cpp index 9836c8933cd348..f1af80cef330b7 100644 --- a/src/app/tests/TestWriteInteraction.cpp +++ b/src/app/tests/TestWriteInteraction.cpp @@ -306,15 +306,6 @@ void TestWriteInteraction::TestWriteHandler(nlTestSuite * apSuite, void * apCont } else { - // - // In a normal execution flow, the exchange manager would have closed out the exchange after the - // message dispatch call path had unwound. In this test however, we've manually allocated the exchange - // ourselves (as opposed to the exchange manager), so we need to take ownership of closing out the exchange. - // - // Note that this doesn't happen in the success case above, since that results in a call to send a message through - // the exchange context, which results in the exchange manager correctly closing it. - // - exchange->Close(); NL_TEST_ASSERT(apSuite, status == Status::UnsupportedAccess); } diff --git a/src/messaging/ExchangeHolder.h b/src/messaging/ExchangeHolder.h index 7bb27458489c58..fb7e2cd39c55bd 100644 --- a/src/messaging/ExchangeHolder.h +++ b/src/messaging/ExchangeHolder.h @@ -64,6 +64,8 @@ class ExchangeHolder : public ExchangeDelegate */ void Grab(ExchangeContext * exchange) { + VerifyOrDie(exchange != nullptr); + Release(); mpExchangeCtx = exchange;