diff --git a/examples/chip-tool/commands/common/CommandInvoker.h b/examples/chip-tool/commands/common/CommandInvoker.h index 6583cc9124280e..e51f141a50ea5b 100644 --- a/examples/chip-tool/commands/common/CommandInvoker.h +++ b/examples/chip-tool/commands/common/CommandInvoker.h @@ -100,33 +100,22 @@ class CommandInvoker final : public ResponseReceiver(this, aDevice->GetExchangeManager()); + auto commandSender = Platform::MakeUnique(this, exchangeManager); VerifyOrReturnError(commandSender != nullptr, CHIP_ERROR_NO_MEMORY); ReturnErrorOnFailure(commandSender->AddRequestData(commandPath, aRequestData)); - if (aDevice->GetSecureSession().HasValue()) + Optional session = exchangeManager->GetSessionManager()->CreateGroupSession(groupId); + if (!session.HasValue()) { - SessionHandle session = aDevice->GetSecureSession().Value(); - session.SetGroupId(groupId); - - if (!session.IsGroupSession()) - { - return CHIP_ERROR_INCORRECT_STATE; - } - - ReturnErrorOnFailure(commandSender->SendCommandRequest(session)); - } - else - { - // something fishy is going on - return CHIP_ERROR_INCORRECT_STATE; + return CHIP_ERROR_NO_MEMORY; } + ReturnErrorOnFailure(commandSender->SendCommandRequest(session)); commandSender.release(); return CHIP_NO_ERROR; diff --git a/examples/ota-provider-app/ota-provider-common/OTAProviderExample.cpp b/examples/ota-provider-app/ota-provider-common/OTAProviderExample.cpp index 968bcea2b61a6f..43a9d06efc75ed 100644 --- a/examples/ota-provider-app/ota-provider-common/OTAProviderExample.cpp +++ b/examples/ota-provider-app/ota-provider-common/OTAProviderExample.cpp @@ -114,7 +114,7 @@ EmberAfStatus OTAProviderExample::HandleQueryImage(chip::app::CommandHandler * c { // TODO: This uses the current node as the provider to supply the OTA image. This can be configurable such that the provider // supplying the response is not the provider supplying the OTA image. - FabricIndex fabricIndex = commandObj->GetExchangeContext()->GetSessionHandle().GetFabricIndex(); + FabricIndex fabricIndex = commandObj->GetExchangeContext()->GetSessionHandle()->AsSecureSession()->GetFabricIndex(); FabricInfo * fabricInfo = Server::GetInstance().GetFabricTable().FindFabricWithIndex(fabricIndex); NodeId nodeId = fabricInfo->GetPeerId().GetNodeId(); diff --git a/src/app/CASESessionManager.cpp b/src/app/CASESessionManager.cpp index 5d3ed0e5e3f07e..08d5d42ed3376a 100644 --- a/src/app/CASESessionManager.cpp +++ b/src/app/CASESessionManager.cpp @@ -109,14 +109,6 @@ CHIP_ERROR CASESessionManager::GetPeerAddress(PeerId peerId, Transport::PeerAddr return CHIP_NO_ERROR; } -void CASESessionManager::OnSessionReleased(const SessionHandle & sessionHandle) -{ - OperationalDeviceProxy * session = FindSession(sessionHandle); - VerifyOrReturn(session != nullptr, ChipLogDetail(Controller, "OnSessionReleased was called for unknown device, ignoring it.")); - - session->OnSessionReleased(sessionHandle); -} - OperationalDeviceProxy * CASESessionManager::FindSession(const SessionHandle & session) { return mConfig.devicePool->FindDevice(session); diff --git a/src/app/CASESessionManager.h b/src/app/CASESessionManager.h index 382fcd493784e8..f22bf998736dac 100644 --- a/src/app/CASESessionManager.h +++ b/src/app/CASESessionManager.h @@ -48,7 +48,7 @@ struct CASESessionManagerConfig * 4. During session establishment, trigger node ID resolution (if needed), and update the DNS-SD cache (if resolution is * successful) */ -class CASESessionManager : public SessionReleaseDelegate, public Dnssd::ResolverDelegate +class CASESessionManager : public Dnssd::ResolverDelegate { public: CASESessionManager() = delete; @@ -105,9 +105,6 @@ class CASESessionManager : public SessionReleaseDelegate, public Dnssd::Resolver */ CHIP_ERROR GetPeerAddress(PeerId peerId, Transport::PeerAddress & addr); - //////////// SessionReleaseDelegate Implementation /////////////// - void OnSessionReleased(const SessionHandle & session) override; - //////////// ResolverDelegate Implementation /////////////// void OnNodeIdResolved(const Dnssd::ResolvedNodeData & nodeData) override; void OnNodeIdResolutionFailed(const PeerId & peerId, CHIP_ERROR error) override; diff --git a/src/app/CommandHandler.cpp b/src/app/CommandHandler.cpp index ce31bfef101447..7e41b30b7d2c98 100644 --- a/src/app/CommandHandler.cpp +++ b/src/app/CommandHandler.cpp @@ -436,7 +436,7 @@ TLV::TLVWriter * CommandHandler::GetCommandDataIBTLVWriter() FabricIndex CommandHandler::GetAccessingFabricIndex() const { - return mpExchangeCtx->GetSessionHandle().GetFabricIndex(); + return mpExchangeCtx->GetSessionHandle()->AsSecureSession()->GetFabricIndex(); } CommandHandler * CommandHandler::Handle::Get() diff --git a/src/app/InteractionModelEngine.cpp b/src/app/InteractionModelEngine.cpp index eae1ec87fa7ce1..09d4bc01241475 100644 --- a/src/app/InteractionModelEngine.cpp +++ b/src/app/InteractionModelEngine.cpp @@ -269,8 +269,8 @@ CHIP_ERROR InteractionModelEngine::OnReadInitialRequest(Messaging::ExchangeConte for (auto & readHandler : mReadHandlers) { if (!readHandler.IsFree() && readHandler.IsSubscriptionType() && - readHandler.GetInitiatorNodeId() == apExchangeContext->GetSessionHandle().GetPeerNodeId() && - readHandler.GetAccessingFabricIndex() == apExchangeContext->GetSessionHandle().GetFabricIndex()) + readHandler.GetInitiatorNodeId() == apExchangeContext->GetSessionHandle()->AsSecureSession()->GetPeerNodeId() && + readHandler.GetAccessingFabricIndex() == apExchangeContext->GetSessionHandle()->AsSecureSession()->GetFabricIndex()) { bool keepSubscriptions = true; System::PacketBufferTLVReader reader; diff --git a/src/app/InteractionModelEngine.h b/src/app/InteractionModelEngine.h index a306706e70ba33..04becfc3bbb200 100644 --- a/src/app/InteractionModelEngine.h +++ b/src/app/InteractionModelEngine.h @@ -308,7 +308,7 @@ bool ServerClusterCommandExists(const ConcreteCommandPath & aCommandPath); * * @param[in] aSubjectDescriptor The subject descriptor for the read. * @param[in] aPath The concrete path of the data being read. - * @param[in] aAttributeReport The TLV Builder for Cluter attribute builder. + * @param[in] aAttributeReports The TLV Builder for Cluter attribute builder. * * @retval CHIP_NO_ERROR on success */ diff --git a/src/app/OperationalDeviceProxy.cpp b/src/app/OperationalDeviceProxy.cpp index 9cf2234098ca51..0164e5c08247d6 100644 --- a/src/app/OperationalDeviceProxy.cpp +++ b/src/app/OperationalDeviceProxy.cpp @@ -127,11 +127,7 @@ CHIP_ERROR OperationalDeviceProxy::UpdateDeviceData(const Transport::PeerAddress return CHIP_NO_ERROR; } - Transport::SecureSession * secureSession = mInitParams.sessionManager->GetSecureSession(mSecureSession.Get()); - if (secureSession != nullptr) - { - secureSession->SetPeerAddress(addr); - } + mSecureSession.Get()->AsSecureSession()->SetPeerAddress(addr); } return err; @@ -262,7 +258,7 @@ CHIP_ERROR OperationalDeviceProxy::Disconnect() return CHIP_NO_ERROR; } -void OperationalDeviceProxy::SetConnectedSession(SessionHandle handle) +void OperationalDeviceProxy::SetConnectedSession(const SessionHandle & handle) { mSecureSession.Grab(handle); mState = State::SecureConnected; @@ -296,12 +292,9 @@ void OperationalDeviceProxy::DeferCloseCASESession() mSystemLayer->ScheduleWork(CloseCASESessionTask, this); } -void OperationalDeviceProxy::OnSessionReleased(const SessionHandle & session) +void OperationalDeviceProxy::OnSessionReleased() { - VerifyOrReturn(mSecureSession.Contains(session), - ChipLogDetail(Controller, "Connection expired, but it doesn't match the current session")); mState = State::Initialized; - mSecureSession.Release(); } CHIP_ERROR OperationalDeviceProxy::ShutdownSubscriptions() diff --git a/src/app/OperationalDeviceProxy.h b/src/app/OperationalDeviceProxy.h index c56a92453c94f0..91dd127706ebce 100644 --- a/src/app/OperationalDeviceProxy.h +++ b/src/app/OperationalDeviceProxy.h @@ -80,7 +80,7 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, SessionReleaseDele { public: virtual ~OperationalDeviceProxy(); - OperationalDeviceProxy(DeviceProxyInitParams & params, PeerId peerId) + OperationalDeviceProxy(DeviceProxyInitParams & params, PeerId peerId) : mSecureSession(*this) { VerifyOrReturn(params.Validate() == CHIP_NO_ERROR); @@ -124,7 +124,7 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, SessionReleaseDele * Called when a connection is closing. * The object releases all resources associated with the connection. */ - void OnSessionReleased(const SessionHandle & session) override; + void OnSessionReleased() override; void OnNodeIdResolved(const Dnssd::ResolvedNodeData & nodeResolutionData) { @@ -150,7 +150,7 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, SessionReleaseDele * * Note: Avoid using this function generally as it is Deprecated */ - void SetConnectedSession(SessionHandle handle); + void SetConnectedSession(const SessionHandle & handle); NodeId GetDeviceId() const override { return mPeerId.GetNodeId(); } @@ -219,7 +219,7 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, SessionReleaseDele State mState = State::Uninitialized; - SessionHolder mSecureSession; + SessionHolderWithDelegate mSecureSession; uint8_t mSequenceNumber = 0; diff --git a/src/app/ReadClient.cpp b/src/app/ReadClient.cpp index c1b993ef34b83d..e46654b6a77359 100644 --- a/src/app/ReadClient.cpp +++ b/src/app/ReadClient.cpp @@ -182,7 +182,7 @@ CHIP_ERROR ReadClient::SendReadRequest(ReadPrepareParams & aReadPrepareParams) ReturnErrorOnFailure(writer.Finalize(&msgBuf)); } - mpExchangeCtx = mpExchangeMgr->NewContext(aReadPrepareParams.mSessionHandle, this); + mpExchangeCtx = mpExchangeMgr->NewContext(aReadPrepareParams.mSessionHolder.Get(), this); VerifyOrReturnError(mpExchangeCtx != nullptr, err = CHIP_ERROR_NO_MEMORY); mpExchangeCtx->SetResponseTimeout(aReadPrepareParams.mTimeout); @@ -190,8 +190,8 @@ CHIP_ERROR ReadClient::SendReadRequest(ReadPrepareParams & aReadPrepareParams) ReturnErrorOnFailure(mpExchangeCtx->SendMessage(Protocols::InteractionModel::MsgType::ReadRequest, std::move(msgBuf), Messaging::SendFlags(Messaging::SendMessageFlags::kExpectResponse))); - mPeerNodeId = aReadPrepareParams.mSessionHandle.GetPeerNodeId(); - mFabricIndex = aReadPrepareParams.mSessionHandle.GetFabricIndex(); + mPeerNodeId = aReadPrepareParams.mSessionHolder->AsSecureSession()->GetPeerNodeId(); + mFabricIndex = aReadPrepareParams.mSessionHolder->AsSecureSession()->GetFabricIndex(); MoveToState(ClientState::AwaitingInitialReport); @@ -576,8 +576,10 @@ CHIP_ERROR ReadClient::RefreshLivenessCheckTimer() CHIP_ERROR err = CHIP_NO_ERROR; CancelLivenessCheckTimer(); VerifyOrReturnError(mpExchangeCtx != nullptr, err = CHIP_ERROR_INCORRECT_STATE); + VerifyOrReturnError(mpExchangeCtx->HasSessionHandle(), err = CHIP_ERROR_INCORRECT_STATE); - System::Clock::Timeout timeout = System::Clock::Seconds16(mMaxIntervalCeilingSeconds) + mpExchangeCtx->GetAckTimeout(); + System::Clock::Timeout timeout = + System::Clock::Seconds16(mMaxIntervalCeilingSeconds) + mpExchangeCtx->GetSessionHandle()->GetAckTimeout(); // EFR32/MBED/INFINION/K32W's chrono count return long unsinged, but other platform returns unsigned ChipLogProgress(DataManagement, "Refresh LivenessCheckTime with %lu milliseconds", static_cast(timeout.count())); err = InteractionModelEngine::GetInstance()->GetExchangeManager()->GetSessionManager()->SystemLayer()->StartTimer( @@ -699,21 +701,15 @@ CHIP_ERROR ReadClient::SendSubscribeRequest(ReadPrepareParams & aReadPreparePara ReturnErrorOnFailure(writer.Finalize(&msgBuf)); - mpExchangeCtx = mpExchangeMgr->NewContext(aReadPrepareParams.mSessionHandle, this); + mpExchangeCtx = mpExchangeMgr->NewContext(aReadPrepareParams.mSessionHolder.Get(), this); VerifyOrReturnError(mpExchangeCtx != nullptr, err = CHIP_ERROR_NO_MEMORY); - mpExchangeCtx->SetResponseTimeout(kImMessageTimeout); - if (mpExchangeCtx->IsBLETransport()) - { - ChipLogError(DataManagement, "IM Subscribe cannot work with BLE"); - return CHIP_ERROR_INCORRECT_STATE; - } ReturnErrorOnFailure(mpExchangeCtx->SendMessage(Protocols::InteractionModel::MsgType::SubscribeRequest, std::move(msgBuf), Messaging::SendFlags(Messaging::SendMessageFlags::kExpectResponse))); - mPeerNodeId = aReadPrepareParams.mSessionHandle.GetPeerNodeId(); - mFabricIndex = aReadPrepareParams.mSessionHandle.GetFabricIndex(); + mPeerNodeId = aReadPrepareParams.mSessionHolder->AsSecureSession()->GetPeerNodeId(); + mFabricIndex = aReadPrepareParams.mSessionHolder->AsSecureSession()->GetFabricIndex(); MoveToState(ClientState::AwaitingInitialReport); diff --git a/src/app/ReadHandler.cpp b/src/app/ReadHandler.cpp index 8e9c93640b12a1..2d2f429710f04f 100644 --- a/src/app/ReadHandler.cpp +++ b/src/app/ReadHandler.cpp @@ -58,8 +58,8 @@ CHIP_ERROR ReadHandler::Init(Messaging::ExchangeManager * apExchangeMgr, Interac mActiveSubscription = false; mIsChunkedReport = false; mInteractionType = aInteractionType; - mInitiatorNodeId = apExchangeContext->GetSessionHandle().GetPeerNodeId(); - mSubjectDescriptor = apExchangeContext->GetSessionHandle().GetSubjectDescriptor(); + mInitiatorNodeId = apExchangeContext->GetSessionHandle()->AsSecureSession()->GetPeerNodeId(); + mSubjectDescriptor = apExchangeContext->GetSessionHandle()->GetSubjectDescriptor(); mHoldSync = false; mLastWrittenEventsBytes = 0; if (apExchangeContext != nullptr) @@ -201,12 +201,13 @@ CHIP_ERROR ReadHandler::SendReportData(System::PacketBufferHandle && aPayload, b VerifyOrReturnLogError(IsReportable(), CHIP_ERROR_INCORRECT_STATE); if (IsPriming() || IsChunkedReport()) { - mSessionHandle.SetValue(mpExchangeCtx->GetSessionHandle()); + mSessionHandle.Grab(mpExchangeCtx->GetSessionHandle()); } else { VerifyOrReturnLogError(mpExchangeCtx == nullptr, CHIP_ERROR_INCORRECT_STATE); - mpExchangeCtx = mpExchangeMgr->NewContext(mSessionHandle.Value(), this); + VerifyOrReturnLogError(mSessionHandle, CHIP_ERROR_INCORRECT_STATE); + mpExchangeCtx = mpExchangeMgr->NewContext(mSessionHandle.Get(), this); mpExchangeCtx->SetResponseTimeout(kImMessageTimeout); } VerifyOrReturnLogError(mpExchangeCtx != nullptr, CHIP_ERROR_INCORRECT_STATE); diff --git a/src/app/ReadHandler.h b/src/app/ReadHandler.h index 75ce763ff1362a..df8e1c443ea366 100644 --- a/src/app/ReadHandler.h +++ b/src/app/ReadHandler.h @@ -216,7 +216,7 @@ class ReadHandler : public Messaging::ExchangeDelegate uint64_t mSubscriptionId = 0; uint16_t mMinIntervalFloorSeconds = 0; uint16_t mMaxIntervalCeilingSeconds = 0; - Optional mSessionHandle; + SessionHolder mSessionHandle; bool mHoldReport = false; bool mDirty = false; bool mActiveSubscription = false; diff --git a/src/app/ReadPrepareParams.h b/src/app/ReadPrepareParams.h index 547da0607c62e3..7173e27ed23fda 100644 --- a/src/app/ReadPrepareParams.h +++ b/src/app/ReadPrepareParams.h @@ -29,7 +29,7 @@ namespace chip { namespace app { struct ReadPrepareParams { - SessionHandle mSessionHandle; + SessionHolder mSessionHolder; EventPathParams * mpEventPathParamsList = nullptr; size_t mEventPathParamsListSize = 0; AttributePathParams * mpAttributePathParamsList = nullptr; @@ -40,8 +40,8 @@ struct ReadPrepareParams uint16_t mMaxIntervalCeilingSeconds = 0; bool mKeepSubscriptions = true; - ReadPrepareParams(const SessionHandle & sessionHandle) : mSessionHandle(sessionHandle) {} - ReadPrepareParams(ReadPrepareParams && other) : mSessionHandle(other.mSessionHandle) + ReadPrepareParams(const SessionHandle & sessionHandle) { mSessionHolder.Grab(sessionHandle); } + ReadPrepareParams(ReadPrepareParams && other) : mSessionHolder(other.mSessionHolder) { mKeepSubscriptions = other.mKeepSubscriptions; mpEventPathParamsList = other.mpEventPathParamsList; @@ -64,7 +64,7 @@ struct ReadPrepareParams return *this; mKeepSubscriptions = other.mKeepSubscriptions; - mSessionHandle = other.mSessionHandle; + mSessionHolder = other.mSessionHolder; mpEventPathParamsList = other.mpEventPathParamsList; mEventPathParamsListSize = other.mEventPathParamsListSize; mpAttributePathParamsList = other.mpAttributePathParamsList; diff --git a/src/app/WriteClient.cpp b/src/app/WriteClient.cpp index 13bcbff5f80495..fb13647938d53a 100644 --- a/src/app/WriteClient.cpp +++ b/src/app/WriteClient.cpp @@ -245,7 +245,7 @@ CHIP_ERROR WriteClient::SendWriteRequest(const SessionHandle & session, System:: exit: if (err != CHIP_NO_ERROR) { - ChipLogError(DataManagement, "Write client failed to SendWriteRequest"); + ChipLogError(DataManagement, "Write client failed to SendWriteRequest: %s", ErrorStr(err)); ClearExistingExchangeContext(); } else @@ -254,7 +254,7 @@ CHIP_ERROR WriteClient::SendWriteRequest(const SessionHandle & session, System:: // handle this object dying (e.g. due to IM enging shutdown) while the // async bits are pending we'd need to malloc some state bit that we can // twiddle if we die. For now just do the OnDone callback sync. - if (session.IsGroupSession()) + if (session->IsGroupSession()) { // Always shutdown on Group communication ChipLogDetail(DataManagement, "Closing on group Communication "); diff --git a/src/app/WriteClient.h b/src/app/WriteClient.h index 35a5996e70959e..5be883e501bd04 100644 --- a/src/app/WriteClient.h +++ b/src/app/WriteClient.h @@ -117,7 +117,7 @@ class WriteClient : public Messaging::ExchangeDelegate NodeId GetSourceNodeId() const { - return mpExchangeCtx != nullptr ? mpExchangeCtx->GetSessionHandle().GetPeerNodeId() : kUndefinedNodeId; + return mpExchangeCtx != nullptr ? mpExchangeCtx->GetSessionHandle()->AsSecureSession()->GetPeerNodeId() : kUndefinedNodeId; } private: diff --git a/src/app/WriteHandler.cpp b/src/app/WriteHandler.cpp index ccb1779f4edd06..8ec46e57bb82cf 100644 --- a/src/app/WriteHandler.cpp +++ b/src/app/WriteHandler.cpp @@ -113,7 +113,7 @@ CHIP_ERROR WriteHandler::ProcessAttributeDataIBs(TLV::TLVReader & aAttributeData CHIP_ERROR err = CHIP_NO_ERROR; ReturnErrorCodeIf(mpExchangeCtx == nullptr, CHIP_ERROR_INTERNAL); - const Access::SubjectDescriptor subjectDescriptor = mpExchangeCtx->GetSessionHandle().GetSubjectDescriptor(); + const Access::SubjectDescriptor subjectDescriptor = mpExchangeCtx->GetSessionHandle()->GetSubjectDescriptor(); while (CHIP_NO_ERROR == (err = aAttributeDataIBsReader.Next())) { @@ -279,7 +279,7 @@ CHIP_ERROR WriteHandler::AddStatus(const AttributePathParams & aAttributePathPar FabricIndex WriteHandler::GetAccessingFabricIndex() const { - return mpExchangeCtx->GetSessionHandle().GetFabricIndex(); + return mpExchangeCtx->GetSessionHandle()->AsSecureSession()->GetFabricIndex(); } const char * WriteHandler::GetStateStr() const diff --git a/src/app/clusters/general-commissioning-server/general-commissioning-server.cpp b/src/app/clusters/general-commissioning-server/general-commissioning-server.cpp index 9a5fe3e5a3d60b..6e95ab8ff4a097 100644 --- a/src/app/clusters/general-commissioning-server/general-commissioning-server.cpp +++ b/src/app/clusters/general-commissioning-server/general-commissioning-server.cpp @@ -134,8 +134,9 @@ bool emberAfGeneralCommissioningClusterCommissioningCompleteCallback( * This allows device to send messages back to commissioner. * Once bindings are implemented, this may no longer be needed. */ - server->SetFabricIndex(commandObj->GetExchangeContext()->GetSessionHandle().GetFabricIndex()); - server->SetPeerNodeId(commandObj->GetExchangeContext()->GetSessionHandle().GetPeerNodeId()); + SessionHandle handle = commandObj->GetExchangeContext()->GetSessionHandle(); + server->SetFabricIndex(handle->AsSecureSession()->GetFabricIndex()); + server->SetPeerNodeId(handle->AsSecureSession()->GetPeerNodeId()); CheckSuccess(server->CommissioningComplete(), Failure); diff --git a/src/app/clusters/operational-credentials-server/operational-credentials-server.cpp b/src/app/clusters/operational-credentials-server/operational-credentials-server.cpp index 5725f72ddb80b4..b95982e6d2a729 100644 --- a/src/app/clusters/operational-credentials-server/operational-credentials-server.cpp +++ b/src/app/clusters/operational-credentials-server/operational-credentials-server.cpp @@ -173,12 +173,8 @@ CHIP_ERROR ComputeAttestationSignature(app::CommandHandler * commandObj, // TODO: Create an alternative way to retrieve the Attestation Challenge without this huge amount of calls. // Retrieve attestation challenge - ByteSpan attestationChallenge = commandObj->GetExchangeContext() - ->GetExchangeMgr() - ->GetSessionManager() - ->GetSecureSession(commandObj->GetExchangeContext()->GetSessionHandle()) - ->GetCryptoContext() - .GetAttestationChallenge(); + ByteSpan attestationChallenge = + commandObj->GetExchangeContext()->GetSessionHandle()->AsSecureSession()->GetCryptoContext().GetAttestationChallenge(); Hash_SHA256_stream hashStream; ReturnErrorOnFailure(hashStream.Begin()); @@ -291,7 +287,7 @@ class FabricCleanupExchangeDelegate : public chip::Messaging::ExchangeDelegate void OnResponseTimeout(chip::Messaging::ExchangeContext * ec) override {} void OnExchangeClosing(chip::Messaging::ExchangeContext * ec) override { - FabricIndex currentFabricIndex = ec->GetSessionHandle().GetFabricIndex(); + FabricIndex currentFabricIndex = ec->GetSessionHandle()->AsSecureSession()->GetFabricIndex(); ec->GetExchangeMgr()->GetSessionManager()->ExpireAllPairingsForFabric(currentFabricIndex); } }; diff --git a/src/app/clusters/ota-requestor/OTARequestor.cpp b/src/app/clusters/ota-requestor/OTARequestor.cpp index 3feb7966c00d13..d114e4f60665e5 100644 --- a/src/app/clusters/ota-requestor/OTARequestor.cpp +++ b/src/app/clusters/ota-requestor/OTARequestor.cpp @@ -200,7 +200,7 @@ EmberAfStatus OTARequestor::HandleAnnounceOTAProvider(app::CommandHandler * comm } mProviderNodeId = providerNodeId; - mProviderFabricIndex = commandObj->GetExchangeContext()->GetSessionHandle().GetFabricIndex(); + mProviderFabricIndex = commandObj->GetExchangeContext()->GetSessionHandle()->AsSecureSession()->GetFabricIndex(); mProviderEndpointId = providerEndpoint; ChipLogProgress(SoftwareUpdate, "OTA Requestor received AnnounceOTAProvider"); diff --git a/src/app/clusters/scenes/scenes.cpp b/src/app/clusters/scenes/scenes.cpp index 126a17f48b92f8..c2de7e8e1e6862 100644 --- a/src/app/clusters/scenes/scenes.cpp +++ b/src/app/clusters/scenes/scenes.cpp @@ -69,7 +69,7 @@ static FabricIndex GetFabricIndex(app::CommandHandler * commandObj) { VerifyOrReturnError(nullptr != commandObj, 0); VerifyOrReturnError(nullptr != commandObj->GetExchangeContext(), 0); - return commandObj->GetExchangeContext()->GetSessionHandle().GetFabricIndex(); + return commandObj->GetExchangeContext()->GetSessionHandle()->AsSecureSession()->GetFabricIndex(); } static bool readServerAttribute(EndpointId endpoint, ClusterId clusterId, AttributeId attributeId, const char * name, diff --git a/src/app/reporting/Engine.cpp b/src/app/reporting/Engine.cpp index 79471aef7556f6..e432af81147aff 100644 --- a/src/app/reporting/Engine.cpp +++ b/src/app/reporting/Engine.cpp @@ -314,7 +314,7 @@ CHIP_ERROR Engine::BuildAndSendSingleReportData(ReadHandler * apReadHandler) ReportDataMessage::Builder reportDataBuilder; chip::System::PacketBufferHandle bufHandle = System::PacketBufferHandle::New(chip::app::kMaxSecureSduLengthBytes); uint16_t reservedSize = 0; - bool hasMoreChunks; + bool hasMoreChunks = false; // Reserved size for the MoreChunks boolean flag, which takes up 1 byte for the control tag and 1 byte for the context tag. const uint32_t kReservedSizeForMoreChunksFlag = 1 + 1; diff --git a/src/app/tests/TestReadInteraction.cpp b/src/app/tests/TestReadInteraction.cpp index 7ac4043d85fe02..fb019d3a6cb6cd 100644 --- a/src/app/tests/TestReadInteraction.cpp +++ b/src/app/tests/TestReadInteraction.cpp @@ -1319,7 +1319,7 @@ void TestReadInteraction::TestSubscribeInvalidAttributePathRoundtrip(nlTestSuite readPrepareParams.mAttributePathParamsListSize = 1; - readPrepareParams.mSessionHandle = ctx.GetSessionBobToAlice(); + readPrepareParams.mSessionHolder.Grab(ctx.GetSessionBobToAlice()); readPrepareParams.mMinIntervalFloorSeconds = 2; readPrepareParams.mMaxIntervalCeilingSeconds = 5; printf("\nSend subscribe request message to Node: %" PRIu64 "\n", chip::kTestDeviceNodeId); @@ -1406,7 +1406,7 @@ void TestReadInteraction::TestSubscribeInvalidIterval(nlTestSuite * apSuite, voi readPrepareParams.mAttributePathParamsListSize = 1; - readPrepareParams.mSessionHandle = ctx.GetSessionBobToAlice(); + readPrepareParams.mSessionHolder.Grab(ctx.GetSessionBobToAlice()); readPrepareParams.mMinIntervalFloorSeconds = 6; readPrepareParams.mMaxIntervalCeilingSeconds = 5; diff --git a/src/app/tests/TestWriteInteraction.cpp b/src/app/tests/TestWriteInteraction.cpp index 3cc12e88e78c93..ee167fddb16ed9 100644 --- a/src/app/tests/TestWriteInteraction.cpp +++ b/src/app/tests/TestWriteInteraction.cpp @@ -257,7 +257,7 @@ void TestWriteInteraction::TestWriteClientGroup(nlTestSuite * apSuite, void * ap AddAttributeDataIB(apSuite, apContext, writeClientHandle); SessionHandle groupSession = ctx.GetSessionBobToFriends(); - NL_TEST_ASSERT(apSuite, groupSession.IsGroupSession()); + NL_TEST_ASSERT(apSuite, groupSession->IsGroupSession()); err = writeClientHandle.SendWriteRequest(groupSession); diff --git a/src/app/util/af-types.h b/src/app/util/af-types.h index cd838654dbfa55..e2049248875d8b 100644 --- a/src/app/util/af-types.h +++ b/src/app/util/af-types.h @@ -324,7 +324,7 @@ typedef struct */ struct EmberAfClusterCommand { - chip::NodeId SourceNodeId() const { return source->GetSessionHandle().GetPeerNodeId(); } + chip::NodeId SourceNodeId() const { return source->GetSessionHandle()->AsSecureSession()->GetPeerNodeId(); } /** * APS frame for the incoming message diff --git a/src/app/util/util.cpp b/src/app/util/util.cpp index 22f6d3794d5c30..67621dcfba7552 100644 --- a/src/app/util/util.cpp +++ b/src/app/util/util.cpp @@ -448,7 +448,7 @@ static bool dispatchZclMessage(EmberAfClusterCommand * cmd) } #ifdef EMBER_AF_PLUGIN_GROUPS_SERVER else if ((cmd->type == EMBER_INCOMING_MULTICAST || cmd->type == EMBER_INCOMING_MULTICAST_LOOPBACK) && - !emberAfGroupsClusterEndpointInGroupCallback(cmd->source->GetSessionHandle().GetFabricIndex(), + !emberAfGroupsClusterEndpointInGroupCallback(cmd->source->GetSessionHandle()->AsSecureSession()->GetFabricIndex(), cmd->apsFrame->destinationEndpoint, cmd->apsFrame->groupId)) { emberAfDebugPrint("Drop cluster " ChipLogFormatMEI " command " ChipLogFormatMEI, ChipLogValueMEI(cmd->apsFrame->clusterId), diff --git a/src/controller/CHIPCluster.cpp b/src/controller/CHIPCluster.cpp index 4d2ca1ddf44446..34a3f681261b39 100644 --- a/src/controller/CHIPCluster.cpp +++ b/src/controller/CHIPCluster.cpp @@ -52,15 +52,14 @@ CHIP_ERROR ClusterBase::AssociateWithGroup(DeviceProxy * device, GroupId groupId if (mDevice->GetSecureSession().HasValue()) { // Local copy to preserve original SessionHandle for future Unicast communication. - SessionHandle session = mDevice->GetSecureSession().Value(); - session.SetGroupId(groupId); - mSessionHandle.SetValue(session); - + Optional session = mDevice->GetExchangeManager()->GetSessionManager()->CreateGroupSession(groupId); // Sanity check - if (!mSessionHandle.Value().IsGroupSession()) + if (!session.HasValue() || !session.Value()->IsGroupSession()) { err = CHIP_ERROR_INCORRECT_STATE; } + + mGroupSession.Grab(session.Value()); } else { diff --git a/src/controller/CHIPCluster.h b/src/controller/CHIPCluster.h index 4d9452def8f0d0..ca02eedb9f4a8e 100644 --- a/src/controller/CHIPCluster.h +++ b/src/controller/CHIPCluster.h @@ -139,9 +139,17 @@ class DLL_EXPORT ClusterBase } }; - return chip::Controller::WriteAttribute( - (mSessionHandle.HasValue() ? mSessionHandle.Value() : mDevice->GetSecureSession().Value()), mEndpoint, clusterId, - attributeId, requestData, onSuccessCb, onFailureCb, aTimedWriteTimeoutMs, onDoneCb); + if (mGroupSession) + { + return chip::Controller::WriteAttribute(mGroupSession.Get(), mEndpoint, clusterId, attributeId, requestData, + onSuccessCb, onFailureCb, aTimedWriteTimeoutMs, onDoneCb); + } + else + { + return chip::Controller::WriteAttribute(mDevice->GetSecureSession().Value(), mEndpoint, clusterId, + attributeId, requestData, onSuccessCb, onFailureCb, + aTimedWriteTimeoutMs, onDoneCb); + } } template @@ -262,7 +270,7 @@ class DLL_EXPORT ClusterBase const ClusterId mClusterId; DeviceProxy * mDevice; EndpointId mEndpoint; - chip::Optional mSessionHandle; + SessionHolder mGroupSession; }; } // namespace Controller diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index b5e77c032f2adc..d2fae9c15bd373 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -268,17 +268,11 @@ void DeviceController::ReleaseOperationalDevice(NodeId remoteDeviceId) mCASESessionManager->ReleaseSession(mFabricInfo->GetPeerIdForNode(remoteDeviceId)); } -void DeviceController::OnSessionReleased(const SessionHandle & session) -{ - VerifyOrReturn(mState == State::Initialized, ChipLogError(Controller, "OnConnectionExpired was called in incorrect state")); - mCASESessionManager->OnSessionReleased(session); -} - void DeviceController::OnFirstMessageDeliveryFailed(const SessionHandle & session) { VerifyOrReturn(mState == State::Initialized, ChipLogError(Controller, "OnFirstMessageDeliveryFailed was called in incorrect state")); - UpdateDevice(session.GetPeerNodeId()); + UpdateDevice(session->AsSecureSession()->GetPeerNodeId()); } CHIP_ERROR DeviceController::InitializePairedDeviceList() @@ -622,7 +616,6 @@ CHIP_ERROR DeviceCommissioner::Init(CommissionerInitParams params) { ReturnErrorOnFailure(DeviceController::Init(params)); - params.systemState->SessionMgr()->RegisterReleaseDelegate(*this); params.systemState->SessionMgr()->RegisterRecoveryDelegate(*this); uint16_t nextKeyID = 0; @@ -686,16 +679,6 @@ CHIP_ERROR DeviceCommissioner::Shutdown() return CHIP_NO_ERROR; } -void DeviceCommissioner::OnSessionReleased(const SessionHandle & session) -{ - VerifyOrReturn(mState == State::Initialized, ChipLogError(Controller, "OnConnectionExpired was called in incorrect state")); - - CommissioneeDeviceProxy * device = FindCommissioneeDevice(session); - VerifyOrReturn(device != nullptr, ChipLogDetail(Controller, "OnConnectionExpired was called for unknown device, ignoring it.")); - - device->OnSessionReleased(session); -} - CommissioneeDeviceProxy * DeviceCommissioner::FindCommissioneeDevice(const SessionHandle & session) { CommissioneeDeviceProxy * foundDevice = nullptr; @@ -1198,10 +1181,8 @@ CHIP_ERROR DeviceCommissioner::ValidateAttestationInfo(const ByteSpan & attestat DeviceAttestationVerifier * dac_verifier = GetDeviceAttestationVerifier(); // Retrieve attestation challenge - ByteSpan attestationChallenge = mSystemState->SessionMgr() - ->GetSecureSession(mDeviceBeingCommissioned->GetSecureSession().Value()) - ->GetCryptoContext() - .GetAttestationChallenge(); + ByteSpan attestationChallenge = + mDeviceBeingCommissioned->GetSecureSession().Value()->AsSecureSession()->GetCryptoContext().GetAttestationChallenge(); dac_verifier->VerifyAttestationInformation(attestationElements, attestationChallenge, signature, mDeviceBeingCommissioned->GetPAI(), mDeviceBeingCommissioned->GetDAC(), @@ -1361,10 +1342,8 @@ CHIP_ERROR DeviceCommissioner::ProcessOpCSR(const ByteSpan & NOCSRElements, cons ReturnErrorOnFailure(ExtractPubkeyFromX509Cert(device->GetDAC(), dacPubkey)); // Retrieve attestation challenge - ByteSpan attestationChallenge = mSystemState->SessionMgr() - ->GetSecureSession(device->GetSecureSession().Value()) - ->GetCryptoContext() - .GetAttestationChallenge(); + ByteSpan attestationChallenge = + device->GetSecureSession().Value()->AsSecureSession()->GetCryptoContext().GetAttestationChallenge(); // The operational CA should also verify this on its end during NOC generation, if end-to-end attestation is desired. ReturnErrorOnFailure(dacVerifier->VerifyNodeOperationalCSRInformation(NOCSRElements, attestationChallenge, AttestationSignature, diff --git a/src/controller/CHIPDeviceController.h b/src/controller/CHIPDeviceController.h index f9437e75b806a9..22a61010ae7167 100644 --- a/src/controller/CHIPDeviceController.h +++ b/src/controller/CHIPDeviceController.h @@ -173,8 +173,7 @@ typedef void (*OnOpenCommissioningWindow)(void * context, NodeId deviceId, CHIP_ * and device pairing information for individual devices). Alternatively, this class can retrieve the * relevant information when the application tries to communicate with the device */ -class DLL_EXPORT DeviceController : public SessionReleaseDelegate, - public SessionRecoveryDelegate, +class DLL_EXPORT DeviceController : public SessionRecoveryDelegate, #if CHIP_DEVICE_CONFIG_ENABLE_DNSSD public AbstractDnssdDiscoveryController, #endif @@ -377,9 +376,6 @@ class DLL_EXPORT DeviceController : public SessionReleaseDelegate, ReliableMessageProtocolConfig mMRPConfig = gDefaultMRPConfig; - //////////// SessionReleaseDelegate Implementation /////////////// - void OnSessionReleased(const SessionHandle & session) override; - //////////// SessionRecoveryDelegate Implementation /////////////// void OnFirstMessageDeliveryFailed(const SessionHandle & session) override; @@ -687,9 +683,6 @@ class DLL_EXPORT DeviceCommissioner : public DeviceController, void OnSessionEstablishmentTimeout(); - //////////// SessionReleaseDelegate Implementation /////////////// - void OnSessionReleased(const SessionHandle & session) override; - static void OnSessionEstablishmentTimeoutCallback(System::Layer * aLayer, void * aAppState); /* This function sends a Device Attestation Certificate chain request to the device. diff --git a/src/controller/CommissioneeDeviceProxy.cpp b/src/controller/CommissioneeDeviceProxy.cpp index 5c33fb2834a62c..bebd673c3d2c22 100644 --- a/src/controller/CommissioneeDeviceProxy.cpp +++ b/src/controller/CommissioneeDeviceProxy.cpp @@ -58,9 +58,8 @@ CHIP_ERROR CommissioneeDeviceProxy::LoadSecureSessionParametersIfNeeded(bool & d { if (mSecureSession) { - Transport::SecureSession * secureSession = mSessionManager->GetSecureSession(mSecureSession.Get()); // Check if the connection state has the correct transport information - if (secureSession->GetPeerAddress().GetTransportType() == Transport::Type::kUndefined) + if (mSecureSession->AsSecureSession()->GetPeerAddress().GetTransportType() == Transport::Type::kUndefined) { mState = ConnectionState::NotConnected; ReturnErrorOnFailure(LoadSecureSessionParameters()); @@ -86,12 +85,9 @@ CHIP_ERROR CommissioneeDeviceProxy::SendCommands(app::CommandSender * commandObj return commandObj->SendCommandRequest(mSecureSession.Get()); } -void CommissioneeDeviceProxy::OnSessionReleased(const SessionHandle & session) +void CommissioneeDeviceProxy::OnSessionReleased() { - VerifyOrReturn(mSecureSession.Contains(session), - ChipLogDetail(Controller, "Connection expired, but it doesn't match the current session")); mState = ConnectionState::NotConnected; - mSecureSession.Release(); } CHIP_ERROR CommissioneeDeviceProxy::CloseSession() @@ -130,7 +126,7 @@ CHIP_ERROR CommissioneeDeviceProxy::UpdateDeviceData(const Transport::PeerAddres return CHIP_NO_ERROR; } - Transport::SecureSession * secureSession = mSessionManager->GetSecureSession(mSecureSession.Get()); + Transport::SecureSession * secureSession = mSecureSession.Get()->AsSecureSession(); secureSession->SetPeerAddress(addr); return CHIP_NO_ERROR; @@ -155,6 +151,7 @@ void CommissioneeDeviceProxy::Reset() CHIP_ERROR CommissioneeDeviceProxy::LoadSecureSessionParameters() { CHIP_ERROR err = CHIP_NO_ERROR; + SessionHolder sessionHolder; if (mSessionManager == nullptr || mState == ConnectionState::SecureConnected) { diff --git a/src/controller/CommissioneeDeviceProxy.h b/src/controller/CommissioneeDeviceProxy.h index b8183bee02a11f..200e5913b6e11e 100644 --- a/src/controller/CommissioneeDeviceProxy.h +++ b/src/controller/CommissioneeDeviceProxy.h @@ -84,7 +84,7 @@ class CommissioneeDeviceProxy : public DeviceProxy, public SessionReleaseDelegat { public: ~CommissioneeDeviceProxy(); - CommissioneeDeviceProxy() {} + CommissioneeDeviceProxy() : mSecureSession(*this) {} CommissioneeDeviceProxy(const CommissioneeDeviceProxy &) = delete; /** @@ -164,7 +164,7 @@ class CommissioneeDeviceProxy : public DeviceProxy, public SessionReleaseDelegat * * @param session A handle to the secure session */ - void OnSessionReleased(const SessionHandle & session) override; + void OnSessionReleased() override; /** * In case there exists an open session to the device, mark it as expired. @@ -298,7 +298,7 @@ class CommissioneeDeviceProxy : public DeviceProxy, public SessionReleaseDelegat Messaging::ExchangeManager * mExchangeMgr = nullptr; - SessionHolder mSecureSession; + SessionHolderWithDelegate mSecureSession; Controller::DeviceControllerInteractionModelDelegate * mpIMDelegate = nullptr; diff --git a/src/controller/WriteInteraction.h b/src/controller/WriteInteraction.h index a79d820616f954..4f2472b5ca27ea 100644 --- a/src/controller/WriteInteraction.h +++ b/src/controller/WriteInteraction.h @@ -111,7 +111,7 @@ CHIP_ERROR WriteAttribute(const SessionHandle & sessionHandle, chip::EndpointId // called. callback.release(); - if (sessionHandle.IsGroupSession()) + if (sessionHandle->IsGroupSession()) { ReturnErrorOnFailure( handle.EncodeAttributeWritePayload(chip::app::AttributePathParams(clusterId, attributeId), requestData)); diff --git a/src/controller/tests/data_model/TestRead.cpp b/src/controller/tests/data_model/TestRead.cpp index 1c3556ca0961ad..113586d7ada6c5 100644 --- a/src/controller/tests/data_model/TestRead.cpp +++ b/src/controller/tests/data_model/TestRead.cpp @@ -276,7 +276,7 @@ void TestReadInteraction::TestReadAttributeTimeout(nlTestSuite * apSuite, void * NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 2); - ctx.GetExchangeManager().ExpireExchangesForSession(ctx.GetSessionBobToAlice()); + ctx.ExpireSessionBobToAlice(); ctx.DrainAndServiceIO(); @@ -291,7 +291,7 @@ void TestReadInteraction::TestReadAttributeTimeout(nlTestSuite * apSuite, void * chip::app::InteractionModelEngine::GetInstance()->GetReportingEngine().Run(); ctx.DrainAndServiceIO(); - ctx.GetExchangeManager().ExpireExchangesForSession(ctx.GetSessionAliceToBob()); + ctx.ExpireSessionAliceToBob(); NL_TEST_ASSERT(apSuite, chip::app::InteractionModelEngine::GetInstance()->GetNumActiveReadHandlers() == 0); diff --git a/src/lib/core/CHIPConfig.h b/src/lib/core/CHIPConfig.h index 40d7f48fae977e..aa6edf837eeb1a 100644 --- a/src/lib/core/CHIPConfig.h +++ b/src/lib/core/CHIPConfig.h @@ -2306,6 +2306,15 @@ #define CHIP_CONFIG_UNAUTHENTICATED_CONNECTION_POOL_SIZE 4 #endif // CHIP_CONFIG_UNAUTHENTICATED_CONNECTION_POOL_SIZE +/** + * @def CHIP_CONFIG_GROUP_CONNECTION_POOL_SIZE + * + * @brief Define the size of the pool used for tracking CHIP groups. + */ +#ifndef CHIP_CONFIG_GROUP_CONNECTION_POOL_SIZE +#define CHIP_CONFIG_GROUP_CONNECTION_POOL_SIZE 8 +#endif // CHIP_CONFIG_GROUP_CONNECTION_POOL_SIZE + /** * @def CHIP_CONFIG_PEER_CONNECTION_POOL_SIZE * diff --git a/src/lib/support/ReferenceCountedHandle.h b/src/lib/support/ReferenceCountedHandle.h index c433a639db8d7a..2a54e17b3c9ba2 100644 --- a/src/lib/support/ReferenceCountedHandle.h +++ b/src/lib/support/ReferenceCountedHandle.h @@ -38,7 +38,7 @@ class ReferenceCountedHandle bool operator==(const ReferenceCountedHandle & that) const { return &mTarget == &that.mTarget; } bool operator!=(const ReferenceCountedHandle & that) const { return !(*this == that); } - Target * operator->() { return &mTarget; } + Target * operator->() const { return &mTarget; } Target & Get() const { return mTarget; } private: diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp index 5dc287c1438aa5..3a5cfb477ee7c7 100644 --- a/src/messaging/ExchangeContext.cpp +++ b/src/messaging/ExchangeContext.cpp @@ -87,7 +87,7 @@ void ExchangeContext::SetResponseTimeout(Timeout timeout) #if CONFIG_DEVICE_LAYER && CHIP_DEVICE_CONFIG_ENABLE_SED void ExchangeContext::UpdateSEDPollingMode() { - if (GetSessionHandle().GetPeerAddress(mExchangeMgr->GetSessionManager())->GetTransportType() != Transport::Type::kBle) + if (GetSessionHandle()->AsSecureSession()->GetPeerAddress().GetTransportType() != Transport::Type::kBle) { if (!IsResponseExpected() && !IsSendExpected() && (mExchangeMgr->GetNumActiveExchanges() == 1)) { @@ -125,10 +125,8 @@ CHIP_ERROR ExchangeContext::SendMessage(Protocols::Id protocolId, uint8_t msgTyp // an error arising below. at the end, we have to close it. ExchangeHandle ref(*this); - bool isUDPTransport = IsUDPTransport(); - - // this check is ignored by the ExchangeMsgDispatch if !AutoRequestAck() - bool reliableTransmissionRequested = isUDPTransport && !sendFlags.Has(SendMessageFlags::kNoAutoRequestAck); + // If session requires MRP and NoAutoRequestAck send flag is not specificed, request reliable transmission. + bool reliableTransmissionRequested = GetSessionHandle()->RequireMRP() && !sendFlags.Has(SendMessageFlags::kNoAutoRequestAck); // If a response message is expected... if (sendFlags.Has(SendMessageFlags::kExpectResponse) && !IsGroupExchangeContext()) @@ -252,7 +250,8 @@ void ExchangeContextDeletor::Release(ExchangeContext * ec) ExchangeContext::ExchangeContext(ExchangeManager * em, uint16_t ExchangeId, const SessionHandle & session, bool Initiator, ExchangeDelegate * delegate) : - mDispatch((delegate != nullptr) ? delegate->GetMessageDispatch() : ApplicationExchangeDispatch::Instance()) + mDispatch((delegate != nullptr) ? delegate->GetMessageDispatch() : ApplicationExchangeDispatch::Instance()), + mSession(*this) { VerifyOrDie(mExchangeMgr == nullptr); @@ -267,7 +266,7 @@ ExchangeContext::ExchangeContext(ExchangeManager * em, uint16_t ExchangeId, cons SetMsgRcvdFromPeer(false); // Do not request Ack for multicast - SetAutoRequestAck(!session.IsGroupSession()); + SetAutoRequestAck(!session->IsGroupSession()); #if defined(CHIP_EXCHANGE_CONTEXT_DETAIL_LOGGING) ChipLogDetail(ExchangeManager, "ec++ id: " ChipLogFormatExchange, ChipLogValueExchange(this)); @@ -316,14 +315,8 @@ bool ExchangeContext::MatchExchange(const SessionHandle & session, const PacketH && (payloadHeader.IsInitiator() != IsInitiator()); } -void ExchangeContext::OnConnectionExpired() +void ExchangeContext::OnSessionReleased() { - // Reset our mSession to a default-initialized (hence not matching any - // connection state) value, because it's still referencing the now-expired - // connection. This will mean that no more messages can be sent via this - // exchange, which seems fine given the semantics of connection expiration. - mSession.Release(); - if (!IsResponseExpected()) { // Nothing to do in this case @@ -488,43 +481,5 @@ void ExchangeContext::MessageHandled() Close(); } -bool ExchangeContext::IsUDPTransport() -{ - const Transport::PeerAddress * peerAddress = GetSessionHandle().GetPeerAddress(mExchangeMgr->GetSessionManager()); - return peerAddress && peerAddress->GetTransportType() == Transport::Type::kUdp; -} - -bool ExchangeContext::IsTCPTransport() -{ - const Transport::PeerAddress * peerAddress = GetSessionHandle().GetPeerAddress(mExchangeMgr->GetSessionManager()); - return peerAddress && peerAddress->GetTransportType() == Transport::Type::kTcp; -} - -bool ExchangeContext::IsBLETransport() -{ - const Transport::PeerAddress * peerAddress = GetSessionHandle().GetPeerAddress(mExchangeMgr->GetSessionManager()); - return peerAddress && peerAddress->GetTransportType() == Transport::Type::kBle; -} - -System::Clock::Milliseconds32 ExchangeContext::GetAckTimeout() -{ - System::Clock::Timeout timeout; - if (IsUDPTransport()) - { - timeout = GetMRPConfig().mIdleRetransTimeout * (CHIP_CONFIG_RMP_DEFAULT_MAX_RETRANS + 1); - } - else if (IsTCPTransport()) - { - // TODO: issue 12009, need actual tcp margin value considering restransmission - timeout = System::Clock::Seconds16(30); - } - return timeout; -} - -const ReliableMessageProtocolConfig & ExchangeContext::GetMRPConfig() const -{ - return GetSessionHandle().GetMRPConfig(GetExchangeMgr()->GetSessionManager()); -} - } // namespace Messaging } // namespace chip diff --git a/src/messaging/ExchangeContext.h b/src/messaging/ExchangeContext.h index 57c704ce12185e..c5ec80a44b7200 100644 --- a/src/messaging/ExchangeContext.h +++ b/src/messaging/ExchangeContext.h @@ -55,7 +55,9 @@ class ExchangeContextDeletor * It defines methods for encoding and communicating CHIP messages within an ExchangeContext * over various transport mechanisms, for example, TCP, UDP, or CHIP Reliable Messaging. */ -class DLL_EXPORT ExchangeContext : public ReliableMessageContext, public ReferenceCounted +class DLL_EXPORT ExchangeContext : public ReliableMessageContext, + public ReferenceCounted, + public SessionReleaseDelegate { friend class ExchangeManager; friend class ExchangeContextDeletor; @@ -77,7 +79,13 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, public Referen bool IsEncryptionRequired() const { return mDispatch.IsEncryptionRequired(); } - bool IsGroupExchangeContext() const { return (mSession && mSession.Get().IsGroupSession()); } + bool IsGroupExchangeContext() const + { + return (mSession && mSession->GetSessionType() == Transport::Session::SessionType::kGroup); + } + + // Implement SessionReleaseDelegate + void OnSessionReleased() override; /** * Send a CHIP message on this exchange. @@ -167,18 +175,6 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, public Referen void SetResponseTimeout(Timeout timeout); - // TODO: move following 5 functions into SessionHandle once we can access session vars w/o using a SessionManager - /* - * Get the overall acknowledge timeout period for the underneath transport(MRP+UDP/TCP) - */ - System::Clock::Milliseconds32 GetAckTimeout(); - - bool IsUDPTransport(); - bool IsTCPTransport(); - bool IsBLETransport(); - // Helper function for easily accessing MRP config - const ReliableMessageProtocolConfig & GetMRPConfig() const; - private: Timeout mResponseTimeout{ 0 }; // Maximum time to wait for response (in milliseconds); 0 disables response timeout. ExchangeDelegate * mDelegate = nullptr; @@ -186,8 +182,8 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, public Referen ExchangeMessageDispatch & mDispatch; - SessionHolder mSession; // The connection state - uint16_t mExchangeId; // Assigned exchange ID. + SessionHolderWithDelegate mSession; // The connection state + uint16_t mExchangeId; // Assigned exchange ID. /** * Determine whether a response is currently expected for a message that was sent over @@ -229,11 +225,6 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, public Referen */ bool MatchExchange(const SessionHandle & session, const PacketHeader & packetHeader, const PayloadHeader & payloadHeader); - /** - * Notify the exchange that its connection has expired. - */ - void OnConnectionExpired(); - /** * Notify our delegate, if any, that we have timed out waiting for a * response. diff --git a/src/messaging/ExchangeMgr.cpp b/src/messaging/ExchangeMgr.cpp index 41963fe6c1512d..a6c8421f08847d 100644 --- a/src/messaging/ExchangeMgr.cpp +++ b/src/messaging/ExchangeMgr.cpp @@ -83,7 +83,6 @@ CHIP_ERROR ExchangeManager::Init(SessionManager * sessionManager) handler.Reset(); } - sessionManager->RegisterReleaseDelegate(*this); sessionManager->SetMessageDelegate(this); mReliableMessageMgr.Init(sessionManager->SystemLayer()); @@ -106,7 +105,6 @@ CHIP_ERROR ExchangeManager::Shutdown() if (mSessionManager != nullptr) { mSessionManager->SetMessageDelegate(nullptr); - mSessionManager->UnregisterReleaseDelegate(*this); mSessionManager = nullptr; } @@ -313,24 +311,6 @@ void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const } } -void ExchangeManager::OnSessionReleased(const SessionHandle & session) -{ - ExpireExchangesForSession(session); -} - -void ExchangeManager::ExpireExchangesForSession(const SessionHandle & session) -{ - mContextPool.ForEachActiveObject([&](auto * ec) { - if (ec->mSession.Contains(session)) - { - ec->OnConnectionExpired(); - // Continue to iterate because there can be multiple exchanges - // associated with the connection. - } - return Loop::Continue; - }); -} - void ExchangeManager::CloseAllContextsForDelegate(const ExchangeDelegate * delegate) { mContextPool.ForEachActiveObject([&](auto * ec) { diff --git a/src/messaging/ExchangeMgr.h b/src/messaging/ExchangeMgr.h index 5cd2939e7d7b24..e4e5d0f723f77a 100644 --- a/src/messaging/ExchangeMgr.h +++ b/src/messaging/ExchangeMgr.h @@ -49,7 +49,7 @@ static constexpr int16_t kAnyMessageType = -1; * It works on be behalf of higher layers, creating ExchangeContexts and * handling the registration/unregistration of unsolicited message handlers. */ -class DLL_EXPORT ExchangeManager : public SessionMessageDelegate, public SessionReleaseDelegate +class DLL_EXPORT ExchangeManager : public SessionMessageDelegate { friend class ExchangeContext; @@ -193,10 +193,6 @@ class DLL_EXPORT ExchangeManager : public SessionMessageDelegate, public Session size_t GetNumActiveExchanges() { return mContextPool.Allocated(); } - // TODO: this should be test only, after OnSessionReleased is move to SessionHandle within the exchange context - // Expire all exchanges associated with the given session - void ExpireExchangesForSession(const SessionHandle & session); - private: enum class State { @@ -244,8 +240,6 @@ class DLL_EXPORT ExchangeManager : public SessionMessageDelegate, public Session void OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, const SessionHandle & session, const Transport::PeerAddress & source, DuplicateMessage isDuplicate, System::PacketBufferHandle && msgBuf) override; - - void OnSessionReleased(const SessionHandle & session) override; }; } // namespace Messaging diff --git a/src/messaging/ReliableMessageMgr.cpp b/src/messaging/ReliableMessageMgr.cpp index bc92ac837abd3e..4c696e891d7cd1 100644 --- a/src/messaging/ReliableMessageMgr.cpp +++ b/src/messaging/ReliableMessageMgr.cpp @@ -141,7 +141,8 @@ void ReliableMessageMgr::ExecuteActions() " Send Cnt %d", messageCounter, ChipLogValueExchange(&entry->ec.Get()), entry->sendCount); // TODO: Choose active/idle timeout corresponding to the activity of exchanges of the session. - entry->nextRetransTime = System::SystemClock().GetMonotonicTimestamp() + entry->ec->GetMRPConfig().mActiveRetransTimeout; + entry->nextRetransTime = + System::SystemClock().GetMonotonicTimestamp() + entry->ec->GetSessionHandle()->GetMRPConfig().mActiveRetransTimeout; SendFromRetransTable(entry); // For test not using async IO loop, the entry may have been removed after send, do not use entry below @@ -185,7 +186,8 @@ CHIP_ERROR ReliableMessageMgr::AddToRetransTable(ReliableMessageContext * rc, Re void ReliableMessageMgr::StartRetransmision(RetransTableEntry * entry) { // TODO: Choose active/idle timeout corresponding to the activity of exchanges of the session. - entry->nextRetransTime = System::SystemClock().GetMonotonicTimestamp() + entry->ec->GetMRPConfig().mIdleRetransTimeout; + entry->nextRetransTime = + System::SystemClock().GetMonotonicTimestamp() + entry->ec->GetSessionHandle()->GetMRPConfig().mIdleRetransTimeout; StartTimer(); } diff --git a/src/messaging/tests/MessagingContext.cpp b/src/messaging/tests/MessagingContext.cpp index 267cc9359750d8..39d1c51070e533 100644 --- a/src/messaging/tests/MessagingContext.cpp +++ b/src/messaging/tests/MessagingContext.cpp @@ -36,14 +36,11 @@ CHIP_ERROR MessagingContext::Init(TransportMgrBase * transport, IOContext * ioCo ReturnErrorOnFailure(mExchangeManager.Init(&mSessionManager)); ReturnErrorOnFailure(mMessageCounterManager.Init(&mExchangeManager)); - mSessionBobToFriends.Grab(mSessionManager.CreateGroupSession(GetBobKeyId(), GetFriendsGroupId(), GetFabricIndex()).Value()); + ReturnErrorOnFailure(CreateSessionBobToAlice()); + ReturnErrorOnFailure(CreateSessionAliceToBob()); + ReturnErrorOnFailure(CreateSessionBobToFriends()); - ReturnErrorOnFailure(mSessionManager.NewPairing(mSessionBobToAlice, Optional::Value(mAliceAddress), - GetAliceNodeId(), &mPairingBobToAlice, CryptoContext::SessionRole::kInitiator, - mSrcFabricIndex)); - - return mSessionManager.NewPairing(mSessionAliceToBob, Optional::Value(mBobAddress), GetBobNodeId(), - &mPairingAliceToBob, CryptoContext::SessionRole::kResponder, mDestFabricIndex); + return CHIP_NO_ERROR; } // Shutdown all layers, finalize operations @@ -71,6 +68,24 @@ CHIP_ERROR MessagingContext::ShutdownAndRestoreExisting(MessagingContext & exist return err; } +CHIP_ERROR MessagingContext::CreateSessionBobToAlice() +{ + return mSessionManager.NewPairing(mSessionBobToAlice, Optional::Value(mAliceAddress), GetAliceNodeId(), + &mPairingBobToAlice, CryptoContext::SessionRole::kInitiator, mSrcFabricIndex); +} + +CHIP_ERROR MessagingContext::CreateSessionAliceToBob() +{ + return mSessionManager.NewPairing(mSessionAliceToBob, Optional::Value(mBobAddress), GetBobNodeId(), + &mPairingAliceToBob, CryptoContext::SessionRole::kResponder, mDestFabricIndex); +} + +CHIP_ERROR MessagingContext::CreateSessionBobToFriends() +{ + mSessionBobToFriends.Grab(mSessionManager.CreateGroupSession(GetFriendsGroupId()).Value()); + return CHIP_NO_ERROR; +} + SessionHandle MessagingContext::GetSessionBobToAlice() { return mSessionBobToAlice.Get(); @@ -86,6 +101,21 @@ SessionHandle MessagingContext::GetSessionBobToFriends() return mSessionBobToFriends.Get(); } +void MessagingContext::ExpireSessionBobToAlice() +{ + mSessionManager.ExpirePairing(mSessionBobToAlice.Get()); +} + +void MessagingContext::ExpireSessionAliceToBob() +{ + mSessionManager.ExpirePairing(mSessionAliceToBob.Get()); +} + +void MessagingContext::ExpireSessionBobToFriends() +{ + // TODO: expire the group session +} + Messaging::ExchangeContext * MessagingContext::NewUnauthenticatedExchangeToAlice(Messaging::ExchangeDelegate * delegate) { return mExchangeManager.NewContext(mSessionManager.CreateUnauthenticatedSession(mAliceAddress, gDefaultMRPConfig).Value(), @@ -100,13 +130,11 @@ Messaging::ExchangeContext * MessagingContext::NewUnauthenticatedExchangeToBob(M Messaging::ExchangeContext * MessagingContext::NewExchangeToAlice(Messaging::ExchangeDelegate * delegate) { - // TODO: temporary create a SessionHandle from node id, will be fix in PR 3602 return mExchangeManager.NewContext(GetSessionBobToAlice(), delegate); } Messaging::ExchangeContext * MessagingContext::NewExchangeToBob(Messaging::ExchangeDelegate * delegate) { - // TODO: temporary create a SessionHandle from node id, will be fix in PR 3602 return mExchangeManager.NewContext(GetSessionAliceToBob(), delegate); } diff --git a/src/messaging/tests/MessagingContext.h b/src/messaging/tests/MessagingContext.h index ef3a261b47a067..be14c0a3f172fc 100644 --- a/src/messaging/tests/MessagingContext.h +++ b/src/messaging/tests/MessagingContext.h @@ -88,6 +88,14 @@ class MessagingContext Messaging::ExchangeManager & GetExchangeManager() { return mExchangeManager; } secure_channel::MessageCounterManager & GetMessageCounterManager() { return mMessageCounterManager; } + CHIP_ERROR CreateSessionBobToAlice(); + CHIP_ERROR CreateSessionAliceToBob(); + CHIP_ERROR CreateSessionBobToFriends(); + + void ExpireSessionBobToAlice(); + void ExpireSessionAliceToBob(); + void ExpireSessionBobToFriends(); + SessionHandle GetSessionBobToAlice(); SessionHandle GetSessionAliceToBob(); SessionHandle GetSessionBobToFriends(); diff --git a/src/messaging/tests/TestExchangeMgr.cpp b/src/messaging/tests/TestExchangeMgr.cpp index 965cd8fbd15073..9ee9be6954db8d 100644 --- a/src/messaging/tests/TestExchangeMgr.cpp +++ b/src/messaging/tests/TestExchangeMgr.cpp @@ -97,7 +97,7 @@ void CheckNewContextTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, ec1 != nullptr); NL_TEST_ASSERT(inSuite, ec1->IsInitiator() == true); NL_TEST_ASSERT(inSuite, ec1->GetExchangeId() != 0); - auto sessionPeerToLocal = ctx.GetSecureSessionManager().GetSecureSession(ec1->GetSessionHandle()); + auto sessionPeerToLocal = ec1->GetSessionHandle()->AsSecureSession(); NL_TEST_ASSERT(inSuite, sessionPeerToLocal->GetPeerNodeId() == ctx.GetBobNodeId()); NL_TEST_ASSERT(inSuite, sessionPeerToLocal->GetPeerSessionId() == ctx.GetBobKeyId()); NL_TEST_ASSERT(inSuite, ec1->GetDelegate() == &mockAppDelegate); @@ -105,7 +105,7 @@ void CheckNewContextTest(nlTestSuite * inSuite, void * inContext) ExchangeContext * ec2 = ctx.NewExchangeToAlice(&mockAppDelegate); NL_TEST_ASSERT(inSuite, ec2 != nullptr); NL_TEST_ASSERT(inSuite, ec2->GetExchangeId() > ec1->GetExchangeId()); - auto sessionLocalToPeer = ctx.GetSecureSessionManager().GetSecureSession(ec2->GetSessionHandle()); + auto sessionLocalToPeer = ec2->GetSessionHandle()->AsSecureSession(); NL_TEST_ASSERT(inSuite, sessionLocalToPeer->GetPeerNodeId() == ctx.GetAliceNodeId()); NL_TEST_ASSERT(inSuite, sessionLocalToPeer->GetPeerSessionId() == ctx.GetAliceKeyId()); @@ -121,7 +121,7 @@ void CheckSessionExpirationBasics(nlTestSuite * inSuite, void * inContext) ExchangeContext * ec1 = ctx.NewExchangeToBob(&sendDelegate); // Expire the session this exchange is supposedly on. - ctx.GetExchangeManager().ExpireExchangesForSession(ec1->GetSessionHandle()); + ctx.GetSecureSessionManager().ExpirePairing(ec1->GetSessionHandle()); MockAppDelegate receiveDelegate; CHIP_ERROR err = @@ -138,6 +138,9 @@ void CheckSessionExpirationBasics(nlTestSuite * inSuite, void * inContext) err = ctx.GetExchangeManager().UnregisterUnsolicitedMessageHandlerForType(Protocols::BDX::Id, kMsgType_TEST1); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + // recreate closed session. + NL_TEST_ASSERT(inSuite, ctx.CreateSessionAliceToBob() == CHIP_NO_ERROR); } void CheckSessionExpirationTimeout(nlTestSuite * inSuite, void * inContext) @@ -153,10 +156,12 @@ void CheckSessionExpirationTimeout(nlTestSuite * inSuite, void * inContext) ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, !sendDelegate.IsOnResponseTimeoutCalled); - // Expire the session this exchange is supposedly on. This should close the - // exchange. - ctx.GetExchangeManager().ExpireExchangesForSession(ec1->GetSessionHandle()); + // Expire the session this exchange is supposedly on. This should close the exchange. + ctx.GetSecureSessionManager().ExpirePairing(ec1->GetSessionHandle()); NL_TEST_ASSERT(inSuite, sendDelegate.IsOnResponseTimeoutCalled); + + // recreate closed session. + NL_TEST_ASSERT(inSuite, ctx.CreateSessionAliceToBob() == CHIP_NO_ERROR); } void CheckUmhRegistrationTest(nlTestSuite * inSuite, void * inContext) diff --git a/src/messaging/tests/TestReliableMessageProtocol.cpp b/src/messaging/tests/TestReliableMessageProtocol.cpp index d78d38fe750521..981c019880ed72 100644 --- a/src/messaging/tests/TestReliableMessageProtocol.cpp +++ b/src/messaging/tests/TestReliableMessageProtocol.cpp @@ -203,11 +203,10 @@ void CheckResendApplicationMessage(nlTestSuite * inSuite, void * inContext) ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); NL_TEST_ASSERT(inSuite, rm != nullptr); - exchange->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(), - { - 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL - 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL - }); + exchange->GetSessionHandle()->AsSecureSession()->SetMRPConfig({ + 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL + 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL + }); // Let's drop the initial message gLoopback.mSentMessageCount = 0; @@ -269,11 +268,10 @@ void CheckCloseExchangeAndResendApplicationMessage(nlTestSuite * inSuite, void * ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); NL_TEST_ASSERT(inSuite, rm != nullptr); - exchange->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(), - { - 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL - 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL - }); + exchange->GetSessionHandle()->AsSecureSession()->SetMRPConfig({ + 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL + 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL + }); // Let's drop the initial message gLoopback.mSentMessageCount = 0; @@ -330,11 +328,10 @@ void CheckFailedMessageRetainOnSend(nlTestSuite * inSuite, void * inContext) ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); NL_TEST_ASSERT(inSuite, rm != nullptr); - exchange->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(), - { - 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL - 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL - }); + exchange->GetSessionHandle()->AsSecureSession()->SetMRPConfig({ + 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL + 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL + }); mockSender.mMessageDispatch.mRetainMessageOnSend = false; @@ -421,11 +418,10 @@ void CheckResendApplicationMessageWithPeerExchange(nlTestSuite * inSuite, void * ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); NL_TEST_ASSERT(inSuite, rm != nullptr); - exchange->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(), - { - 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL - 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL - }); + exchange->GetSessionHandle()->AsSecureSession()->SetMRPConfig({ + 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL + 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL + }); // Let's drop the initial message gLoopback.mSentMessageCount = 0; @@ -484,11 +480,10 @@ void CheckDuplicateMessageClosedExchange(nlTestSuite * inSuite, void * inContext ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); NL_TEST_ASSERT(inSuite, rm != nullptr); - exchange->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(), - { - 64_ms32, // CHIP_CONFIG_RMP_DEFAULT_INITIAL_RETRY_INTERVAL - 64_ms32, // CHIP_CONFIG_RMP_DEFAULT_ACTIVE_RETRY_INTERVAL - }); + exchange->GetSessionHandle()->AsSecureSession()->SetMRPConfig({ + 64_ms32, // CHIP_CONFIG_RMP_DEFAULT_INITIAL_RETRY_INTERVAL + 64_ms32, // CHIP_CONFIG_RMP_DEFAULT_ACTIVE_RETRY_INTERVAL + }); // Let's not drop the message. Expectation is that it is received by the peer, but the ack is dropped gLoopback.mSentMessageCount = 0; @@ -556,11 +551,10 @@ void CheckResendSessionEstablishmentMessageWithPeerExchange(nlTestSuite * inSuit ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); NL_TEST_ASSERT(inSuite, rm != nullptr); - exchange->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(), - { - 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL - 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL - }); + exchange->GetSessionHandle()->AsUnauthenticatedSession()->SetMRPConfig({ + 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL + 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL + }); // Let's drop the initial message gLoopback.mSentMessageCount = 0; @@ -621,11 +615,10 @@ void CheckDuplicateMessage(nlTestSuite * inSuite, void * inContext) ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); NL_TEST_ASSERT(inSuite, rm != nullptr); - exchange->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(), - { - 64_ms32, // CHIP_CONFIG_RMP_DEFAULT_INITIAL_RETRY_INTERVAL - 64_ms32, // CHIP_CONFIG_RMP_DEFAULT_ACTIVE_RETRY_INTERVAL - }); + exchange->GetSessionHandle()->AsSecureSession()->SetMRPConfig({ + 64_ms32, // CHIP_CONFIG_RMP_DEFAULT_INITIAL_RETRY_INTERVAL + 64_ms32, // CHIP_CONFIG_RMP_DEFAULT_ACTIVE_RETRY_INTERVAL + }); // Let's not drop the message. Expectation is that it is received by the peer, but the ack is dropped gLoopback.mSentMessageCount = 0; @@ -1120,11 +1113,10 @@ void CheckLostResponseWithPiggyback(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0); // Make sure that we resend our message before the other side does. - exchange->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(), - { - 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL - 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL - }); + exchange->GetSessionHandle()->AsSecureSession()->SetMRPConfig({ + 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL + 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL + }); // We send a message, the other side sends an application-level response // (which is lost), then we do a retransmit that is acked, then the other @@ -1158,11 +1150,10 @@ void CheckLostResponseWithPiggyback(nlTestSuite * inSuite, void * inContext) // Make sure receiver resends after sender does, and there's enough of a gap // that we are very unlikely to actually trigger the resends on the receiver // when we trigger the resends on the sender. - mockReceiver.mExchange->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(), - { - 256_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL - 256_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL - }); + mockReceiver.mExchange->GetSessionHandle()->AsSecureSession()->SetMRPConfig({ + 256_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL + 256_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL + }); // Now send a message from the other side, but drop it. gLoopback.mNumMessagesToDrop = 1; diff --git a/src/protocols/echo/Echo.h b/src/protocols/echo/Echo.h index 1f3429b85abba7..cd00d684f5e5bb 100644 --- a/src/protocols/echo/Echo.h +++ b/src/protocols/echo/Echo.h @@ -103,7 +103,7 @@ class DLL_EXPORT EchoClient : public Messaging::ExchangeDelegate Messaging::ExchangeManager * mExchangeMgr = nullptr; Messaging::ExchangeContext * mExchangeCtx = nullptr; EchoFunct OnEchoResponseReceived = nullptr; - Optional mSecureSession = Optional(); + SessionHolder mSecureSession; CHIP_ERROR OnMessageReceived(Messaging::ExchangeContext * ec, const PayloadHeader & payloadHeader, System::PacketBufferHandle && payload) override; diff --git a/src/protocols/echo/EchoClient.cpp b/src/protocols/echo/EchoClient.cpp index e8bf65cf28ef68..6805dabb67a995 100644 --- a/src/protocols/echo/EchoClient.cpp +++ b/src/protocols/echo/EchoClient.cpp @@ -39,7 +39,7 @@ CHIP_ERROR EchoClient::Init(Messaging::ExchangeManager * exchangeMgr, const Sess return CHIP_ERROR_INCORRECT_STATE; mExchangeMgr = exchangeMgr; - mSecureSession.SetValue(session); + mSecureSession.Grab(session); OnEchoResponseReceived = nullptr; mExchangeCtx = nullptr; @@ -71,7 +71,7 @@ CHIP_ERROR EchoClient::SendEchoRequest(System::PacketBufferHandle && payload, Me } // Create a new exchange context. - mExchangeCtx = mExchangeMgr->NewContext(mSecureSession.Value(), this); + mExchangeCtx = mExchangeMgr->NewContext(mSecureSession.Get(), this); if (mExchangeCtx == nullptr) { return CHIP_ERROR_NO_MEMORY; diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index 6555232ccb863b..ff9b25b9ca2750 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -231,7 +231,7 @@ CHIP_ERROR CASESession::EstablishSession(const Transport::PeerAddress peerAddres mFabricInfo = fabric; mLocalMRPConfig = mrpConfig; - mExchangeCtxt->SetResponseTimeout(kSigma_Response_Timeout + mExchangeCtxt->GetAckTimeout()); + mExchangeCtxt->SetResponseTimeout(kSigma_Response_Timeout + mExchangeCtxt->GetSessionHandle()->GetAckTimeout()); SetPeerAddress(peerAddress); SetPeerNodeId(peerNodeId); @@ -1438,7 +1438,7 @@ CHIP_ERROR CASESession::ValidateReceivedMessage(ExchangeContext * ec, const Payl else { mExchangeCtxt = ec; - mExchangeCtxt->SetResponseTimeout(kSigma_Response_Timeout + mExchangeCtxt->GetAckTimeout()); + mExchangeCtxt->SetResponseTimeout(kSigma_Response_Timeout + mExchangeCtxt->GetSessionHandle()->GetAckTimeout()); } VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_INVALID_ARGUMENT); diff --git a/src/protocols/secure_channel/MessageCounterManager.cpp b/src/protocols/secure_channel/MessageCounterManager.cpp index 9c3a4d99143518..4ccf8a0960a10d 100644 --- a/src/protocols/secure_channel/MessageCounterManager.cpp +++ b/src/protocols/secure_channel/MessageCounterManager.cpp @@ -104,15 +104,13 @@ CHIP_ERROR MessageCounterManager::OnMessageReceived(Messaging::ExchangeContext * void MessageCounterManager::OnResponseTimeout(Messaging::ExchangeContext * exchangeContext) { - Transport::SecureSession * state = mExchangeMgr->GetSessionManager()->GetSecureSession(exchangeContext->GetSessionHandle()); - - if (state != nullptr) + if (exchangeContext->HasSessionHandle()) { - state->GetSessionMessageCounter().GetPeerMessageCounter().SyncFailed(); + exchangeContext->GetSessionHandle()->AsSecureSession()->GetSessionMessageCounter().GetPeerMessageCounter().SyncFailed(); } else { - ChipLogError(SecureChannel, "Timed out! Failed to clear message counter synchronization status."); + ChipLogError(SecureChannel, "MCSP Timeout! On a already released session."); } } @@ -223,39 +221,27 @@ CHIP_ERROR MessageCounterManager::SendMsgCounterSyncReq(const SessionHandle & se CHIP_ERROR MessageCounterManager::SendMsgCounterSyncResp(Messaging::ExchangeContext * exchangeContext, FixedByteSpan challenge) { - CHIP_ERROR err = CHIP_NO_ERROR; - Transport::SecureSession * state = nullptr; System::PacketBufferHandle msgBuf; - uint8_t * msg = nullptr; - state = mExchangeMgr->GetSessionManager()->GetSecureSession(exchangeContext->GetSessionHandle()); - VerifyOrExit(state != nullptr, err = CHIP_ERROR_NOT_CONNECTED); + VerifyOrDie(exchangeContext->HasSessionHandle()); // Allocate new buffer. msgBuf = MessagePacketBuffer::New(kSyncRespMsgSize); - VerifyOrExit(!msgBuf.IsNull(), err = CHIP_ERROR_NO_MEMORY); - - msg = msgBuf->Start(); + VerifyOrReturnError(!msgBuf.IsNull(), CHIP_ERROR_NO_MEMORY); { + uint8_t * msg = msgBuf->Start(); Encoding::LittleEndian::BufferWriter bbuf(msg, kSyncRespMsgSize); - bbuf.Put32(state->GetSessionMessageCounter().GetLocalMessageCounter().Value()); + bbuf.Put32( + exchangeContext->GetSessionHandle()->AsSecureSession()->GetSessionMessageCounter().GetLocalMessageCounter().Value()); bbuf.Put(challenge.data(), kChallengeSize); - VerifyOrExit(bbuf.Fit(), err = CHIP_ERROR_NO_MEMORY); + VerifyOrReturnError(bbuf.Fit(), CHIP_ERROR_NO_MEMORY); } msgBuf->SetDataLength(kSyncRespMsgSize); - err = exchangeContext->SendMessage(Protocols::SecureChannel::MsgType::MsgCounterSyncRsp, std::move(msgBuf), - Messaging::SendFlags(Messaging::SendMessageFlags::kNoAutoRequestAck)); - -exit: - if (err != CHIP_NO_ERROR) - { - ChipLogError(SecureChannel, "Failed to send message counter synchronization response with error:%s", ErrorStr(err)); - } - - return err; + return exchangeContext->SendMessage(Protocols::SecureChannel::MsgType::MsgCounterSyncRsp, std::move(msgBuf), + Messaging::SendFlags(Messaging::SendMessageFlags::kNoAutoRequestAck)); } CHIP_ERROR MessageCounterManager::HandleMsgCounterSyncReq(Messaging::ExchangeContext * exchangeContext, @@ -288,17 +274,14 @@ CHIP_ERROR MessageCounterManager::HandleMsgCounterSyncResp(Messaging::ExchangeCo { CHIP_ERROR err = CHIP_NO_ERROR; - Transport::SecureSession * state = nullptr; - uint32_t syncCounter = 0; + uint32_t syncCounter = 0; const uint8_t * resp = msgBuf->Start(); size_t resplen = msgBuf->DataLength(); ChipLogDetail(SecureChannel, "Received MsgCounterSyncResp response"); - // Find an active connection to the specified peer node - state = mExchangeMgr->GetSessionManager()->GetSecureSession(exchangeContext->GetSessionHandle()); - VerifyOrExit(state != nullptr, err = CHIP_ERROR_NOT_CONNECTED); + VerifyOrDie(exchangeContext->HasSessionHandle()); VerifyOrExit(msgBuf->DataLength() == kSyncRespMsgSize, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); @@ -310,11 +293,12 @@ CHIP_ERROR MessageCounterManager::HandleMsgCounterSyncResp(Messaging::ExchangeCo // Verify that the response field matches the expected Challenge field for the exchange. err = - state->GetSessionMessageCounter().GetPeerMessageCounter().VerifyChallenge(syncCounter, FixedByteSpan(resp)); + exchangeContext->GetSessionHandle()->AsSecureSession()->GetSessionMessageCounter().GetPeerMessageCounter().VerifyChallenge( + syncCounter, FixedByteSpan(resp)); SuccessOrExit(err); // Process all queued incoming messages after message counter synchronization is completed. - ProcessPendingMessages(exchangeContext->GetSessionHandle().GetPeerNodeId()); + ProcessPendingMessages(exchangeContext->GetSessionHandle()->AsSecureSession()->GetPeerNodeId()); exit: if (err != CHIP_NO_ERROR) diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index 5d1128badc01e0..4056ee1a44da31 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -321,7 +321,7 @@ CHIP_ERROR PASESession::Pair(const Transport::PeerAddress peerAddress, uint32_t SuccessOrExit(err); mExchangeCtxt = exchangeCtxt; - mExchangeCtxt->SetResponseTimeout(kSpake2p_Response_Timeout + mExchangeCtxt->GetAckTimeout()); + mExchangeCtxt->SetResponseTimeout(kSpake2p_Response_Timeout + mExchangeCtxt->GetSessionHandle()->GetAckTimeout()); SetPeerAddress(peerAddress); SetPeerNodeId(NodeIdFromPAKEKeyId(mPasscodeID)); @@ -890,7 +890,7 @@ CHIP_ERROR PASESession::ValidateReceivedMessage(ExchangeContext * exchange, cons else { mExchangeCtxt = exchange; - mExchangeCtxt->SetResponseTimeout(kSpake2p_Response_Timeout + mExchangeCtxt->GetAckTimeout()); + mExchangeCtxt->SetResponseTimeout(kSpake2p_Response_Timeout + mExchangeCtxt->GetSessionHandle()->GetAckTimeout()); } VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_INVALID_ARGUMENT); diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp index 3f50a98f2b5a95..1f2a6729f4db90 100644 --- a/src/protocols/secure_channel/tests/TestPASESession.cpp +++ b/src/protocols/secure_channel/tests/TestPASESession.cpp @@ -181,11 +181,10 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P NL_TEST_ASSERT(inSuite, rm != nullptr); NL_TEST_ASSERT(inSuite, rc != nullptr); - contextCommissioner->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(), - { - 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL - 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL - }); + contextCommissioner->GetSessionHandle()->AsUnauthenticatedSession()->SetMRPConfig({ + 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL + 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL + }); } NL_TEST_ASSERT(inSuite, @@ -309,11 +308,10 @@ void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, rm != nullptr); NL_TEST_ASSERT(inSuite, rc != nullptr); - contextCommissioner->GetSessionHandle().SetMRPConfig(&ctx.GetSecureSessionManager(), - { - 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL - 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL - }); + contextCommissioner->GetSessionHandle()->AsUnauthenticatedSession()->SetMRPConfig({ + 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_IDLE_RETRY_INTERVAL + 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL + }); NL_TEST_ASSERT(inSuite, ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType( diff --git a/src/transport/BUILD.gn b/src/transport/BUILD.gn index 106ac723082075..7ea65702fe84dd 100644 --- a/src/transport/BUILD.gn +++ b/src/transport/BUILD.gn @@ -24,6 +24,7 @@ static_library("transport") { sources = [ "CryptoContext.cpp", "CryptoContext.h", + "GroupSession.h", "MessageCounter.cpp", "MessageCounter.h", "MessageCounterManagerInterface.h", @@ -31,11 +32,16 @@ static_library("transport") { "PeerMessageCounter.h", "SecureMessageCodec.cpp", "SecureMessageCodec.h", + "SecureSession.cpp", "SecureSession.h", "SecureSessionTable.h", + "Session.cpp", + "Session.h", "SessionDelegate.h", "SessionHandle.cpp", "SessionHandle.h", + "SessionHolder.cpp", + "SessionHolder.h", "SessionManager.cpp", "SessionManager.h", "SessionMessageCounter.h", diff --git a/src/transport/GroupSession.h b/src/transport/GroupSession.h new file mode 100644 index 00000000000000..f0fd85a9b66a8e --- /dev/null +++ b/src/transport/GroupSession.h @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2021 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace chip { +namespace Transport { + +class GroupSession : public Session +{ +public: + GroupSession(GroupId group, FabricIndex fabricIndex) : mGroupId(group), mFabricIndex(fabricIndex) {} + ~GroupSession() { NotifySessionReleased(); } + + Session::SessionType GetSessionType() const override { return Session::SessionType::kGroup; } +#if CHIP_PROGRESS_LOGGING + const char * GetSessionTypeString() const override { return "secure"; }; +#endif + + Access::SubjectDescriptor GetSubjectDescriptor() const override + { + Access::SubjectDescriptor isd; + isd.authMode = Access::AuthMode::kGroup; + // TODO: fill other group subjects fields + return isd; // return an empty ISD for unauthenticated session. + } + + bool RequireMRP() const override { return false; } + + const ReliableMessageProtocolConfig & GetMRPConfig() const override + { + VerifyOrDie(false); + return gDefaultMRPConfig; + } + + System::Clock::Milliseconds32 GetAckTimeout() const override + { + VerifyOrDie(false); + return System::Clock::Timeout(); + } + + GroupId GetGroupId() const { return mGroupId; } + FabricIndex GetFabricIndex() const { return mFabricIndex; } + +private: + const GroupId mGroupId; + const FabricIndex mFabricIndex; +}; + +/* + * @brief + * An table which manages GroupSessions + */ +template +class GroupSessionTable +{ +public: + ~GroupSessionTable() { mEntries.ReleaseAll(); } + + /** + * Get a session given the peer address. If the session doesn't exist in the cache, allocate a new entry for it. + * + * @return the session found or allocated, nullptr if not found and allocation failed. + */ + CHECK_RETURN_VALUE + Optional AllocEntry(GroupId group, FabricIndex fabricIndex) + { + GroupSession * entry = mEntries.CreateObject(group, fabricIndex); + if (entry != nullptr) + { + return MakeOptional(*entry); + } + else + { + return Optional::Missing(); + } + } + + /** + * Get a session using given GroupId + */ + CHECK_RETURN_VALUE + Optional FindEntry(GroupId group, FabricIndex fabricIndex) + { + GroupSession * result = nullptr; + mEntries.ForEachActiveObject([&](GroupSession * entry) { + if (entry->GetGroupId() == group && entry->GetFabricIndex() == fabricIndex) + { + result = entry; + return Loop::Break; + } + return Loop::Continue; + }); + if (result != nullptr) + { + return MakeOptional(*result); + } + else + { + return Optional::Missing(); + } + } + +private: + BitMapObjectPool mEntries; +}; + +} // namespace Transport +} // namespace chip diff --git a/src/transport/SecureSession.cpp b/src/transport/SecureSession.cpp new file mode 100644 index 00000000000000..d9e27361cbc6ce --- /dev/null +++ b/src/transport/SecureSession.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2021 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace chip { +namespace Transport { + +Access::SubjectDescriptor SecureSession::GetSubjectDescriptor() const +{ + Access::SubjectDescriptor subjectDescriptor; + if (IsOperationalNodeId(mPeerNodeId)) + { + subjectDescriptor.authMode = Access::AuthMode::kCase; + subjectDescriptor.subject = mPeerNodeId; + subjectDescriptor.fabricIndex = mFabric; + // TODO(#10243): add CATs + } + else if (IsPAKEKeyId(mPeerNodeId)) + { + subjectDescriptor.authMode = Access::AuthMode::kPase; + subjectDescriptor.subject = mPeerNodeId; + // TODO(#10242): PASE *can* have fabric in some situations + } + else + { + VerifyOrDie(false); + } + return subjectDescriptor; +} + +} // namespace Transport +} // namespace chip diff --git a/src/transport/SecureSession.h b/src/transport/SecureSession.h index 84b0dbe85195fa..05efecc190ae32 100644 --- a/src/transport/SecureSession.h +++ b/src/transport/SecureSession.h @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -47,10 +48,8 @@ static constexpr uint32_t kUndefinedMessageIndex = UINT32_MAX; * - LastActivityTime is a monotonic timestamp of when this connection was * last used. Inactive connections can expire. * - CryptoContext contains the encryption context of a connection - * - * TODO: to add any message ACK information */ -class SecureSession +class SecureSession : public Session { public: /** @@ -70,14 +69,37 @@ class SecureSession mPeerNodeId(peerNodeId), mPeerCATs(peerCATs), mLocalSessionId(localSessionId), mPeerSessionId(peerSessionId), mFabric(fabric), mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config) {} + ~SecureSession() { NotifySessionReleased(); } SecureSession(SecureSession &&) = delete; SecureSession(const SecureSession &) = delete; SecureSession & operator=(const SecureSession &) = delete; SecureSession & operator=(SecureSession &&) = delete; + Session::SessionType GetSessionType() const override { return Session::SessionType::kSecure; } +#if CHIP_PROGRESS_LOGGING + const char * GetSessionTypeString() const override { return "secure"; }; +#endif + + Access::SubjectDescriptor GetSubjectDescriptor() const override; + + bool RequireMRP() const override { return GetPeerAddress().GetTransportType() == Transport::Type::kUdp; } + + System::Clock::Milliseconds32 GetAckTimeout() const override + { + switch (mPeerAddress.GetTransportType()) + { + case Transport::Type::kUdp: + return GetMRPConfig().mIdleRetransTimeout * (CHIP_CONFIG_RMP_DEFAULT_MAX_RETRANS + 1); + case Transport::Type::kTcp: + return System::Clock::Seconds16(30); + default: + break; + } + return System::Clock::Timeout(); + } + const PeerAddress & GetPeerAddress() const { return mPeerAddress; } - PeerAddress & GetPeerAddress() { return mPeerAddress; } void SetPeerAddress(const PeerAddress & address) { mPeerAddress = address; } Type GetSecureSessionType() const { return mSecureSessionType; } @@ -86,7 +108,7 @@ class SecureSession void SetMRPConfig(const ReliableMessageProtocolConfig & config) { mMRPConfig = config; } - const ReliableMessageProtocolConfig & GetMRPConfig() const { return mMRPConfig; } + const ReliableMessageProtocolConfig & GetMRPConfig() const override { return mMRPConfig; } uint16_t GetLocalSessionId() const { return mLocalSessionId; } uint16_t GetPeerSessionId() const { return mPeerSessionId; } diff --git a/src/transport/SecureSessionTable.h b/src/transport/SecureSessionTable.h index a680ae050a0641..9dce2198c3bbc4 100644 --- a/src/transport/SecureSessionTable.h +++ b/src/transport/SecureSessionTable.h @@ -60,11 +60,13 @@ class SecureSessionTable * has been reached (with CHIP_ERROR_NO_MEMORY). */ CHECK_RETURN_VALUE - SecureSession * CreateNewSecureSession(SecureSession::Type secureSessionType, uint16_t localSessionId, NodeId peerNodeId, - CATValues peerCATs, uint16_t peerSessionId, FabricIndex fabric, - const ReliableMessageProtocolConfig & config) + Optional CreateNewSecureSession(SecureSession::Type secureSessionType, uint16_t localSessionId, + NodeId peerNodeId, CATValues peerCATs, uint16_t peerSessionId, + FabricIndex fabric, const ReliableMessageProtocolConfig & config) { - return mEntries.CreateObject(secureSessionType, localSessionId, peerNodeId, peerCATs, peerSessionId, fabric, config); + SecureSession * result = + mEntries.CreateObject(secureSessionType, localSessionId, peerNodeId, peerCATs, peerSessionId, fabric, config); + return result != nullptr ? MakeOptional(*result) : Optional::Missing(); } void ReleaseSession(SecureSession * session) { mEntries.ReleaseObject(session); } @@ -83,7 +85,7 @@ class SecureSessionTable * @return the state found, nullptr if not found */ CHECK_RETURN_VALUE - SecureSession * FindSecureSessionByLocalKey(uint16_t localSessionId) + Optional FindSecureSessionByLocalKey(uint16_t localSessionId) { SecureSession * result = nullptr; mEntries.ForEachActiveObject([&](auto session) { @@ -94,7 +96,7 @@ class SecureSessionTable } return Loop::Continue; }); - return result; + return result != nullptr ? MakeOptional(*result) : Optional::Missing(); } /** diff --git a/src/transport/Session.cpp b/src/transport/Session.cpp new file mode 100644 index 00000000000000..768cc462f273c9 --- /dev/null +++ b/src/transport/Session.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2021 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +namespace chip { +namespace Transport { + +SecureSession * Session::AsSecureSession() +{ + VerifyOrDie(GetSessionType() == SessionType::kSecure); + return static_cast(this); +} + +UnauthenticatedSession * Session::AsUnauthenticatedSession() +{ + VerifyOrDie(GetSessionType() == SessionType::kUnauthenticated); + return static_cast(this); +} + +GroupSession * Session::AsGroupSession() +{ + VerifyOrDie(GetSessionType() == SessionType::kGroup); + return static_cast(this); +} + +} // namespace Transport +} // namespace chip diff --git a/src/transport/Session.h b/src/transport/Session.h new file mode 100644 index 00000000000000..4e343cbef26034 --- /dev/null +++ b/src/transport/Session.h @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2021 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace chip { +namespace Transport { + +class SecureSession; +class UnauthenticatedSession; +class GroupSession; + +class Session +{ +public: + virtual ~Session() {} + + enum class SessionType : uint8_t + { + kUndefined = 0, + kUnauthenticated = 1, + kSecure = 2, + kGroup = 3, + }; + + virtual SessionType GetSessionType() const = 0; +#if CHIP_PROGRESS_LOGGING + virtual const char * GetSessionTypeString() const = 0; +#endif + + void AddHolder(SessionHolder & holder) + { + VerifyOrDie(!holder.IsInList()); + mHolders.PushBack(&holder); + } + + void RemoveHolder(SessionHolder & holder) + { + VerifyOrDie(mHolders.Contains(&holder)); + mHolders.Remove(&holder); + } + + // For types of sessions using reference counter, override these functions, otherwise leave it empty. + virtual void Retain() {} + virtual void Release() {} + + virtual Access::SubjectDescriptor GetSubjectDescriptor() const = 0; + virtual bool RequireMRP() const = 0; + virtual const ReliableMessageProtocolConfig & GetMRPConfig() const = 0; + virtual System::Clock::Milliseconds32 GetAckTimeout() const = 0; + + SecureSession * AsSecureSession(); + UnauthenticatedSession * AsUnauthenticatedSession(); + GroupSession * AsGroupSession(); + + bool IsGroupSession() const { return GetSessionType() == SessionType::kGroup; } + +protected: + // This should be called by sub-classes at the very beginning of the destructor, before any data field is disposed, such that + // the session is still functional during the callback. + void NotifySessionReleased() + { + SessionHandle session(*this); + while (!mHolders.Empty()) + { + mHolders.begin()->OnSessionReleased(); // OnSessionReleased must remove the item from the linked list + } + } + +private: + IntrusiveList mHolders; +}; + +} // namespace Transport +} // namespace chip diff --git a/src/transport/SessionDelegate.h b/src/transport/SessionDelegate.h index bee1ab6970ea20..97e1d66faa3ba1 100644 --- a/src/transport/SessionDelegate.h +++ b/src/transport/SessionDelegate.h @@ -31,7 +31,7 @@ class DLL_EXPORT SessionReleaseDelegate * * @param session The handle to the secure session */ - virtual void OnSessionReleased(const SessionHandle & session) = 0; + virtual void OnSessionReleased() = 0; }; class DLL_EXPORT SessionRecoveryDelegate diff --git a/src/transport/SessionHandle.cpp b/src/transport/SessionHandle.cpp index eebed4f41e6e98..3fe0ebfc1ff4b0 100644 --- a/src/transport/SessionHandle.cpp +++ b/src/transport/SessionHandle.cpp @@ -23,83 +23,4 @@ namespace chip { using namespace Transport; -using AuthMode = Access::AuthMode; -using SubjectDescriptor = Access::SubjectDescriptor; - -SubjectDescriptor SessionHandle::GetSubjectDescriptor() const -{ - SubjectDescriptor subjectDescriptor; - if (IsSecure()) - { - if (IsOperationalNodeId(mPeerNodeId)) - { - subjectDescriptor.authMode = AuthMode::kCase; - subjectDescriptor.subject = mPeerNodeId; - subjectDescriptor.fabricIndex = mFabric; - // TODO(#10243): add CATs - } - else if (IsPAKEKeyId(mPeerNodeId)) - { - subjectDescriptor.authMode = AuthMode::kPase; - subjectDescriptor.subject = mPeerNodeId; - // TODO(#10242): PASE *can* have fabric in some situations - } - else if (mGroupId.HasValue()) - { - subjectDescriptor.authMode = AuthMode::kGroup; - subjectDescriptor.subject = NodeIdFromGroupId(mGroupId.Value()); - } - } - return subjectDescriptor; -} - -const PeerAddress * SessionHandle::GetPeerAddress(SessionManager * sessionManager) const -{ - if (IsSecure()) - { - SecureSession * state = sessionManager->GetSecureSession(*this); - if (state == nullptr) - { - return nullptr; - } - - return &state->GetPeerAddress(); - } - - return &GetUnauthenticatedSession()->GetPeerAddress(); -} - -const ReliableMessageProtocolConfig & SessionHandle::GetMRPConfig(SessionManager * sessionManager) const -{ - if (IsSecure()) - { - SecureSession * secureSession = sessionManager->GetSecureSession(*this); - if (secureSession == nullptr) - { - return gDefaultMRPConfig; - } - return secureSession->GetMRPConfig(); - } - else - { - return GetUnauthenticatedSession()->GetMRPConfig(); - } -} - -void SessionHandle::SetMRPConfig(SessionManager * sessionManager, const ReliableMessageProtocolConfig & config) -{ - if (IsSecure()) - { - SecureSession * secureSession = sessionManager->GetSecureSession(*this); - if (secureSession != nullptr) - { - secureSession->SetMRPConfig(config); - } - } - else - { - return GetUnauthenticatedSession()->SetMRPConfig(config); - } -} - } // namespace chip diff --git a/src/transport/SessionHandle.h b/src/transport/SessionHandle.h index fb9dc31d21ebe4..344dab3580b0ec 100644 --- a/src/transport/SessionHandle.h +++ b/src/transport/SessionHandle.h @@ -17,92 +17,39 @@ #pragma once -#include -#include -#include -#include -#include -#include +#include +#include namespace chip { -class SessionManager; +namespace Transport { +class Session; +} // namespace Transport +class SessionHolder; + +/** @brief + * Non-copyable session reference. All SessionHandles are created within SessionManager. SessionHandle is not + * reference *counted, hence it is not allowed to store SessionHandle anywhere except for function arguments and + * return values. SessionHandle is short-lived as it is only available as stack variable, so it is never dangling. */ class SessionHandle { public: - using SubjectDescriptor = Access::SubjectDescriptor; - - SessionHandle(NodeId peerNodeId, FabricIndex fabric) : mPeerNodeId(peerNodeId), mFabric(fabric) {} - - SessionHandle(Transport::UnauthenticatedSessionHandle session) : - mPeerNodeId(kPlaceholderNodeId), mFabric(kUndefinedFabricIndex), mUnauthenticatedSessionHandle(session) - {} - - SessionHandle(Transport::SecureSession & session) : mPeerNodeId(session.GetPeerNodeId()), mFabric(session.GetFabricIndex()) - { - mLocalSessionId.SetValue(session.GetLocalSessionId()); - mPeerSessionId.SetValue(session.GetPeerSessionId()); - } - - SessionHandle(NodeId peerNodeId, GroupId groupId, FabricIndex fabric) : mPeerNodeId(peerNodeId), mFabric(fabric) - { - mGroupId.SetValue(groupId); - } - - bool IsSecure() const { return !mUnauthenticatedSessionHandle.HasValue(); } + SessionHandle(Transport::Session & session) : mSession(session) {} + ~SessionHandle() {} - bool HasFabricIndex() const { return (mFabric != kUndefinedFabricIndex); } - FabricIndex GetFabricIndex() const { return mFabric; } - void SetFabricIndex(FabricIndex fabricId) { mFabric = fabricId; } - void SetGroupId(GroupId groupId) { mGroupId.SetValue(groupId); } + SessionHandle(const SessionHandle &) = delete; + SessionHandle operator=(const SessionHandle &) = delete; + SessionHandle(SessionHandle &&) = default; + SessionHandle & operator=(SessionHandle &&) = delete; - SubjectDescriptor GetSubjectDescriptor() const; + bool operator==(const SessionHandle & that) const { return &mSession.Get() == &that.mSession.Get(); } - bool operator==(const SessionHandle & that) const - { - if (IsSecure()) - { - return that.IsSecure() && mLocalSessionId.Value() == that.mLocalSessionId.Value(); - } - else - { - return !that.IsSecure() && mUnauthenticatedSessionHandle.Value() == that.mUnauthenticatedSessionHandle.Value(); - } - } - - NodeId GetPeerNodeId() const { return mPeerNodeId; } - bool IsGroupSession() const { return mGroupId.HasValue(); } - const Optional & GetGroupId() const { return mGroupId; } - const Optional & GetPeerSessionId() const { return mPeerSessionId; } - const Optional & GetLocalSessionId() const { return mLocalSessionId; } - - // Return the peer address for this session. May return null if the peer - // address is not known. This can happen for secure sessions that have been - // torn down, at the very least. - const Transport::PeerAddress * GetPeerAddress(SessionManager * sessionManager) const; - - const ReliableMessageProtocolConfig & GetMRPConfig(SessionManager * sessionManager) const; - void SetMRPConfig(SessionManager * sessionManager, const ReliableMessageProtocolConfig & config); - - Transport::UnauthenticatedSessionHandle GetUnauthenticatedSession() const { return mUnauthenticatedSessionHandle.Value(); } + Transport::Session * operator->() const { return mSession.operator->(); } private: - friend class SessionManager; - - // Fields for secure session - NodeId mPeerNodeId; - Optional mLocalSessionId; - Optional mPeerSessionId; - Optional mGroupId; - // TODO: Re-evaluate the storing of Fabric ID in SessionHandle - // The Fabric ID will not be available for PASE and group sessions. So need - // to identify an approach that'll allow looking up the corresponding information for - // such sessions. - FabricIndex mFabric; - - // Fields for unauthenticated session - Optional mUnauthenticatedSessionHandle; + friend class SessionHolder; + ReferenceCountedHandle mSession; }; } // namespace chip diff --git a/src/transport/SessionHolder.cpp b/src/transport/SessionHolder.cpp new file mode 100644 index 00000000000000..290cc1803fafcb --- /dev/null +++ b/src/transport/SessionHolder.cpp @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2021 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace chip { + +SessionHolder::~SessionHolder() +{ + Release(); +} + +SessionHolder::SessionHolder(const SessionHolder & that) : IntrusiveListNodeBase() +{ + mSession = that.mSession; + if (mSession.HasValue()) + { + mSession.Value()->AddHolder(*this); + } +} + +SessionHolder::SessionHolder(SessionHolder && that) : IntrusiveListNodeBase() +{ + mSession = that.mSession; + if (mSession.HasValue()) + { + mSession.Value()->AddHolder(*this); + } + + that.Release(); +} + +SessionHolder & SessionHolder::operator=(const SessionHolder & that) +{ + Release(); + + mSession = that.mSession; + if (mSession.HasValue()) + { + mSession.Value()->AddHolder(*this); + } + + return *this; +} + +SessionHolder & SessionHolder::operator=(SessionHolder && that) +{ + Release(); + + mSession = that.mSession; + if (mSession.HasValue()) + { + mSession.Value()->AddHolder(*this); + } + + that.Release(); + + return *this; +} + +void SessionHolder::Grab(const SessionHandle & session) +{ + Release(); + mSession.Emplace(session.mSession); + session->AddHolder(*this); +} + +void SessionHolder::Release() +{ + if (mSession.HasValue()) + { + mSession.Value()->RemoveHolder(*this); + mSession.ClearValue(); + } +} + +} // namespace chip diff --git a/src/transport/SessionHolder.h b/src/transport/SessionHolder.h index 8210f8e5241d81..897004f59f6dcc 100644 --- a/src/transport/SessionHolder.h +++ b/src/transport/SessionHolder.h @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include @@ -26,42 +27,59 @@ namespace chip { * Managed session reference. The object is used to store a session, the stored session will be automatically * released when the underlying session is released. One must verify it is available before use. The object can be * created using SessionHandle.Grab() - * - * TODO: release holding session when the session is released. This will be implemented by following PRs */ -class SessionHolder : public SessionReleaseDelegate +class SessionHolder : public SessionReleaseDelegate, public IntrusiveListNodeBase { public: SessionHolder() {} - SessionHolder(const SessionHandle & session) : mSession(session) {} - ~SessionHolder() { Release(); } + ~SessionHolder(); SessionHolder(const SessionHolder &); - SessionHolder operator=(const SessionHolder &); SessionHolder(SessionHolder && that); - SessionHolder operator=(SessionHolder && that); + SessionHolder & operator=(const SessionHolder &); + SessionHolder & operator=(SessionHolder && that); + + // Implement SessionReleaseDelegate + void OnSessionReleased() override { Release(); } - void Grab(const SessionHandle & sessionHandle) + bool Contains(const SessionHandle & session) const { - Release(); - mSession.SetValue(sessionHandle); + return mSession.HasValue() && &mSession.Value().Get() == &session.mSession.Get(); } - void Release() { mSession.ClearValue(); } + void Grab(const SessionHandle & session); + void Release(); - // TODO: call this function when the underlying session is released - // Implement SessionReleaseDelegate - void OnSessionReleased(const SessionHandle & session) override { Release(); } + operator bool() const { return mSession.HasValue(); } + SessionHandle Get() const { return SessionHandle{ mSession.Value().Get() }; } + Optional ToOptional() const + { + return mSession.HasValue() ? chip::MakeOptional(Get()) : chip::Optional::Missing(); + } - // Check whether the SessionHolder contains a session matching given session - bool Contains(const SessionHandle & session) const { return mSession.HasValue() && mSession.Value() == session; } + Transport::Session * operator->() const { return &mSession.Value().Get(); } - operator bool() const { return mSession.HasValue(); } - const SessionHandle & Get() const { return mSession.Value(); } - Optional ToOptional() const { return mSession; } +private: + Optional> mSession; +}; + +// @brief Extends SessionHolder to allow propagate OnSessionReleased event to an extra given destination +class SessionHolderWithDelegate : public SessionHolder +{ +public: + SessionHolderWithDelegate(SessionReleaseDelegate & delegate) : mDelegate(delegate) {} + operator bool() const { return SessionHolder::operator bool(); } + + void OnSessionReleased() override + { + Release(); + + // Note, the session is already cleared during mDelegate.OnSessionReleased + mDelegate.OnSessionReleased(); + } private: - Optional mSession; + SessionReleaseDelegate & mDelegate; }; } // namespace chip diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index 0e6452efe755ff..1edf5c24e364a3 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -37,6 +37,7 @@ #include #include #include +#include #include #include #include @@ -95,7 +96,6 @@ void SessionManager::Shutdown() { CancelExpiryTimer(); - mSessionReleaseDelegates.ReleaseAll(); mSessionRecoveryDelegates.ReleaseAll(); mMessageCounterManager = nullptr; @@ -119,48 +119,47 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P NodeId destination; FabricIndex fabricIndex; #endif // CHIP_PROGRESS_LOGGING - if (sessionHandle.IsSecure()) + + switch (sessionHandle->GetSessionType()) { - if (sessionHandle.IsGroupSession()) + case Transport::Session::SessionType::kGroup: { + // TODO : #11911 + // For now, just set the packetHeader with the correct data. + packetHeader.SetDestinationGroupId(sessionHandle->AsGroupSession()->GetGroupId()); + packetHeader.SetFlags(Header::SecFlagValues::kPrivacyFlag); + packetHeader.SetSessionType(Header::SessionType::kGroupSession); + // TODO : Replace the PeerNodeId with Our nodeId + packetHeader.SetSourceNodeId(kUndefinedNodeId); + + if (!packetHeader.IsValidGroupMsg()) { - // TODO : #11911 - // For now, just set the packetHeader with the correct data. - packetHeader.SetDestinationGroupId(sessionHandle.GetGroupId()); - packetHeader.SetFlags(Header::SecFlagValues::kPrivacyFlag); - packetHeader.SetSessionType(Header::SessionType::kGroupSession); - // TODO : Replace the PeerNodeId with Our nodeId - packetHeader.SetSourceNodeId(sessionHandle.GetPeerNodeId()); - - if (!packetHeader.IsValidGroupMsg()) - { - return CHIP_ERROR_INTERNAL; - } - // TODO #11911 Update SecureMessageCodec::Encrypt for Group - ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message)); + return CHIP_ERROR_INTERNAL; + } + // TODO #11911 Update SecureMessageCodec::Encrypt for Group + ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message)); #if CHIP_PROGRESS_LOGGING - destination = sessionHandle.GetPeerNodeId(); - fabricIndex = sessionHandle.GetFabricIndex(); + destination = kUndefinedNodeId; + fabricIndex = kUndefinedFabricIndex; #endif // CHIP_PROGRESS_LOGGING - } - else + } + break; + case Transport::Session::SessionType::kSecure: { + SecureSession * session = sessionHandle->AsSecureSession(); + if (session == nullptr) { - SecureSession * session = GetSecureSession(sessionHandle); - if (session == nullptr) - { - return CHIP_ERROR_NOT_CONNECTED; - } - MessageCounter & counter = GetSendCounterForPacket(payloadHeader, *session); - ReturnErrorOnFailure(SecureMessageCodec::Encrypt(session, payloadHeader, packetHeader, message, counter)); + return CHIP_ERROR_NOT_CONNECTED; + } + MessageCounter & counter = GetSendCounterForPacket(payloadHeader, *session); + ReturnErrorOnFailure(SecureMessageCodec::Encrypt(session, payloadHeader, packetHeader, message, counter)); #if CHIP_PROGRESS_LOGGING - destination = session->GetPeerNodeId(); - fabricIndex = session->GetFabricIndex(); + destination = session->GetPeerNodeId(); + fabricIndex = session->GetFabricIndex(); #endif // CHIP_PROGRESS_LOGGING - } } - else - { + break; + case Transport::Session::SessionType::kUnauthenticated: { ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message)); MessageCounter & counter = mGlobalUnencryptedMessageCounter; @@ -174,13 +173,17 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P fabricIndex = kUndefinedFabricIndex; #endif // CHIP_PROGRESS_LOGGING } + break; + default: + return CHIP_ERROR_INTERNAL; + } ChipLogProgress(Inet, "Prepared %s message %p to 0x" ChipLogFormatX64 " (%u) of type " ChipLogFormatMessageType " and protocolId " ChipLogFormatProtocolId " on exchange " ChipLogFormatExchangeId " with MessageCounter:" ChipLogFormatMessageCounter ".", - sessionHandle.IsSecure() ? "encrypted" : "plaintext", &preparedMessage, ChipLogValueX64(destination), - fabricIndex, payloadHeader.GetMessageType(), ChipLogValueProtocolId(payloadHeader.GetProtocolID()), + sessionHandle->GetSessionTypeString(), &preparedMessage, ChipLogValueX64(destination), fabricIndex, + payloadHeader.GetMessageType(), ChipLogValueProtocolId(payloadHeader.GetProtocolID()), ChipLogValueExchangeIdFromSentHeader(payloadHeader), packetHeader.GetMessageCounter()); ReturnErrorOnFailure(packetHeader.EncodeBeforeData(message)); @@ -197,58 +200,56 @@ CHIP_ERROR SessionManager::SendPreparedMessage(const SessionHandle & sessionHand const Transport::PeerAddress * destination; - if (sessionHandle.IsSecure()) + switch (sessionHandle->GetSessionType()) { - if (sessionHandle.IsGroupSession()) - { - chip::Transport::PeerAddress multicastAddress = - Transport::PeerAddress::Multicast(sessionHandle.GetFabricIndex(), sessionHandle.GetGroupId().Value()); - destination = static_cast(&multicastAddress); - char addressStr[Transport::PeerAddress::kMaxToStringSize]; - multicastAddress.ToString(addressStr, Transport::PeerAddress::kMaxToStringSize); - - ChipLogProgress(Inet, - "Sending %s msg %p with MessageCounter:" ChipLogFormatMessageCounter " to %d" - " at monotonic time: %" PRId64 - " msec to Multicast IPV6 address : %s with GroupID of %d and fabric Id of %d", - "encrypted", &preparedMessage, preparedMessage.GetMessageCounter(), sessionHandle.GetGroupId().Value(), - System::SystemClock().GetMonotonicMilliseconds64().count(), addressStr, - sessionHandle.GetGroupId().Value(), sessionHandle.GetFabricIndex()); - } - else - { - // Find an active connection to the specified peer node - SecureSession * session = GetSecureSession(sessionHandle); - if (session == nullptr) - { - ChipLogError(Inet, "Secure transport could not find a valid PeerConnection"); - return CHIP_ERROR_NOT_CONNECTED; - } - - // This marks any connection where we send data to as 'active' - session->MarkActive(); - - destination = &session->GetPeerAddress(); - - ChipLogProgress(Inet, - "Sending %s msg %p with MessageCounter:" ChipLogFormatMessageCounter " to 0x" ChipLogFormatX64 - " (%u) at monotonic time: %" PRId64 " msec", - "encrypted", &preparedMessage, preparedMessage.GetMessageCounter(), - ChipLogValueX64(session->GetPeerNodeId()), session->GetFabricIndex(), - System::SystemClock().GetMonotonicMilliseconds64().count()); - } + case Transport::Session::SessionType::kGroup: { + auto groupSession = sessionHandle->AsGroupSession(); + Transport::PeerAddress multicastAddress = + Transport::PeerAddress::Multicast(groupSession->GetFabricIndex(), groupSession->GetGroupId()); + destination = &multicastAddress; // XXX: this is dangling pointer, must be fixed + char addressStr[Transport::PeerAddress::kMaxToStringSize]; + multicastAddress.ToString(addressStr, Transport::PeerAddress::kMaxToStringSize); + + ChipLogProgress(Inet, + "Sending %s msg %p with MessageCounter:" ChipLogFormatMessageCounter " to %d" + " at monotonic time: %" PRId64 + " msec to Multicast IPV6 address : %s with GroupID of %d and fabric Id of %d", + "encrypted", &preparedMessage, preparedMessage.GetMessageCounter(), groupSession->GetGroupId(), + System::SystemClock().GetMonotonicMilliseconds64().count(), addressStr, groupSession->GetGroupId(), + groupSession->GetFabricIndex()); } - else - { - auto unauthenticated = sessionHandle.GetUnauthenticatedSession(); + break; + case Transport::Session::SessionType::kSecure: { + // Find an active connection to the specified peer node + SecureSession * secure = sessionHandle->AsSecureSession(); + + // This marks any connection where we send data to as 'active' + secure->MarkActive(); + + destination = &secure->GetPeerAddress(); + + ChipLogProgress(Inet, + "Sending %s msg %p with MessageCounter:" ChipLogFormatMessageCounter " to 0x" ChipLogFormatX64 + " (%u) at monotonic time: %" PRId64 " msec", + "encrypted", &preparedMessage, preparedMessage.GetMessageCounter(), + ChipLogValueX64(secure->GetPeerNodeId()), secure->GetFabricIndex(), + System::SystemClock().GetMonotonicMilliseconds64().count()); + } + break; + case Transport::Session::SessionType::kUnauthenticated: { + auto unauthenticated = sessionHandle->AsUnauthenticatedSession(); unauthenticated->MarkActive(); destination = &unauthenticated->GetPeerAddress(); ChipLogProgress(Inet, "Sending %s msg %p with MessageCounter:" ChipLogFormatMessageCounter " to 0x" ChipLogFormatX64 " at monotonic time: %" PRId64 " msec", - "plaintext", &preparedMessage, preparedMessage.GetMessageCounter(), ChipLogValueX64(kUndefinedNodeId), - System::SystemClock().GetMonotonicMilliseconds64().count()); + sessionHandle->GetSessionTypeString(), &preparedMessage, preparedMessage.GetMessageCounter(), + ChipLogValueX64(kUndefinedNodeId), System::SystemClock().GetMonotonicMilliseconds64().count()); + } + break; + default: + return CHIP_ERROR_INTERNAL; } PacketBufferHandle msgBuf = preparedMessage.CastToWritable(); @@ -268,12 +269,7 @@ CHIP_ERROR SessionManager::SendPreparedMessage(const SessionHandle & sessionHand void SessionManager::ExpirePairing(const SessionHandle & sessionHandle) { - SecureSession * session = GetSecureSession(sessionHandle); - if (session != nullptr) - { - HandleConnectionExpired(*session); - mSecureSessions.ReleaseSession(session); - } + mSecureSessions.ReleaseSession(sessionHandle->AsSecureSession()); } void SessionManager::ExpireAllPairings(NodeId peerNodeId, FabricIndex fabric) @@ -281,7 +277,6 @@ void SessionManager::ExpireAllPairings(NodeId peerNodeId, FabricIndex fabric) mSecureSessions.ForEachSession([&](auto session) { if (session->GetPeerNodeId() == peerNodeId && session->GetFabricIndex() == fabric) { - HandleConnectionExpired(*session); mSecureSessions.ReleaseSession(session); } return Loop::Continue; @@ -294,7 +289,6 @@ void SessionManager::ExpireAllPairingsForFabric(FabricIndex fabric) mSecureSessions.ForEachSession([&](auto session) { if (session->GetFabricIndex() == fabric) { - HandleConnectionExpired(*session); mSecureSessions.ReleaseSession(session); } return Loop::Continue; @@ -305,30 +299,30 @@ CHIP_ERROR SessionManager::NewPairing(SessionHolder & sessionHolder, const Optio NodeId peerNodeId, PairingSession * pairing, CryptoContext::SessionRole direction, FabricIndex fabric) { - uint16_t peerSessionId = pairing->GetPeerSessionId(); - uint16_t localSessionId = pairing->GetLocalSessionId(); - SecureSession * session = mSecureSessions.FindSecureSessionByLocalKey(localSessionId); + uint16_t peerSessionId = pairing->GetPeerSessionId(); + uint16_t localSessionId = pairing->GetLocalSessionId(); + Optional session = mSecureSessions.FindSecureSessionByLocalKey(localSessionId); // Find any existing connection with the same local key ID - if (session) + if (session.HasValue()) { - HandleConnectionExpired(*session); - mSecureSessions.ReleaseSession(session); + mSecureSessions.ReleaseSession(session.Value()->AsSecureSession()); } ChipLogDetail(Inet, "New secure session created for device 0x" ChipLogFormatX64 ", key %d!!", ChipLogValueX64(peerNodeId), peerSessionId); session = mSecureSessions.CreateNewSecureSession(pairing->GetSecureSessionType(), localSessionId, peerNodeId, pairing->GetPeerCATs(), peerSessionId, fabric, pairing->GetMRPConfig()); - ReturnErrorCodeIf(session == nullptr, CHIP_ERROR_NO_MEMORY); + ReturnErrorCodeIf(!session.HasValue(), CHIP_ERROR_NO_MEMORY); + Transport::SecureSession * secureSession = session.Value()->AsSecureSession(); if (peerAddr.HasValue() && peerAddr.Value().GetIPAddress() != Inet::IPAddress::Any) { - session->SetPeerAddress(peerAddr.Value()); + secureSession->SetPeerAddress(peerAddr.Value()); } else if (peerAddr.HasValue() && peerAddr.Value().GetTransportType() == Transport::Type::kBle) { - session->SetPeerAddress(peerAddr.Value()); + secureSession->SetPeerAddress(peerAddr.Value()); } else if (peerAddr.HasValue() && (peerAddr.Value().GetTransportType() == Transport::Type::kTcp || @@ -337,11 +331,10 @@ CHIP_ERROR SessionManager::NewPairing(SessionHolder & sessionHolder, const Optio return CHIP_ERROR_INVALID_ARGUMENT; } - ReturnErrorOnFailure(pairing->DeriveSecureSession(session->GetCryptoContext(), direction)); - - session->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(pairing->GetPeerCounter()); - sessionHolder.Grab(SessionHandle(*session)); + ReturnErrorOnFailure(pairing->DeriveSecureSession(secureSession->GetCryptoContext(), direction)); + secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(pairing->GetPeerCounter()); + sessionHolder.Grab(session.Value()); return CHIP_NO_ERROR; } @@ -419,19 +412,19 @@ void SessionManager::RefreshSessionOperationalData(const SessionHandle & session void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) { - Optional optionalSession = - mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress, gDefaultMRPConfig); + Optional optionalSession = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress, gDefaultMRPConfig); if (!optionalSession.HasValue()) { ChipLogError(Inet, "UnauthenticatedSession exhausted"); return; } - Transport::UnauthenticatedSessionHandle session = optionalSession.Value(); + const SessionHandle & session = optionalSession.Value(); SessionMessageDelegate::DuplicateMessage isDuplicate = SessionMessageDelegate::DuplicateMessage::No; // Verify message counter - CHIP_ERROR err = session->GetPeerMessageCounter().VerifyOrTrustFirst(packetHeader.GetMessageCounter()); + CHIP_ERROR err = + session->AsUnauthenticatedSession()->GetPeerMessageCounter().VerifyOrTrustFirst(packetHeader.GetMessageCounter()); if (err == CHIP_ERROR_DUPLICATE_MESSAGE_RECEIVED) { isDuplicate = SessionMessageDelegate::DuplicateMessage::Yes; @@ -439,7 +432,7 @@ void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Tr } VerifyOrDie(err == CHIP_NO_ERROR); - session->MarkActive(); + session->AsUnauthenticatedSession()->MarkActive(); PayloadHeader payloadHeader; ReturnOnFailure(payloadHeader.DecodeAndConsume(msg)); @@ -452,11 +445,11 @@ void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Tr packetHeader.GetMessageCounter(), ChipLogValueExchangeIdFromReceivedHeader(payloadHeader)); } - session->GetPeerMessageCounter().Commit(packetHeader.GetMessageCounter()); + session->AsUnauthenticatedSession()->GetPeerMessageCounter().Commit(packetHeader.GetMessageCounter()); if (mCB != nullptr) { - mCB->OnMessageReceived(packetHeader, payloadHeader, SessionHandle(session), peerAddress, isDuplicate, std::move(msg)); + mCB->OnMessageReceived(packetHeader, payloadHeader, optionalSession.Value(), peerAddress, isDuplicate, std::move(msg)); } } @@ -465,7 +458,7 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & packetHea { CHIP_ERROR err = CHIP_NO_ERROR; - SecureSession * session = mSecureSessions.FindSecureSessionByLocalKey(packetHeader.GetSessionId()); + Optional session = mSecureSessions.FindSecureSessionByLocalKey(packetHeader.GetSessionId()); PayloadHeader payloadHeader; @@ -477,20 +470,21 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & packetHea return; } - if (session == nullptr) + if (!session.HasValue()) { ChipLogError(Inet, "Data received on an unknown connection (%d). Dropping it!!", packetHeader.GetSessionId()); return; } + Transport::SecureSession * secureSession = session.Value()->AsSecureSession(); // Decrypt and verify the message before message counter verification or any further processing. - if (SecureMessageCodec::Decrypt(session, payloadHeader, packetHeader, msg) != CHIP_NO_ERROR) + if (SecureMessageCodec::Decrypt(secureSession, payloadHeader, packetHeader, msg) != CHIP_NO_ERROR) { ChipLogError(Inet, "Secure transport received message, but failed to decode/authenticate it, discarding"); return; } - err = session->GetSessionMessageCounter().GetPeerMessageCounter().Verify(packetHeader.GetMessageCounter()); + err = secureSession->GetSessionMessageCounter().GetPeerMessageCounter().Verify(packetHeader.GetMessageCounter()); if (err == CHIP_ERROR_DUPLICATE_MESSAGE_RECEIVED) { isDuplicate = SessionMessageDelegate::DuplicateMessage::Yes; @@ -502,7 +496,7 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & packetHea return; } - session->MarkActive(); + secureSession->MarkActive(); if (isDuplicate == SessionMessageDelegate::DuplicateMessage::Yes && !payloadHeader.NeedsAck()) { @@ -518,20 +512,19 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & packetHea } } - session->GetSessionMessageCounter().GetPeerMessageCounter().Commit(packetHeader.GetMessageCounter()); + secureSession->GetSessionMessageCounter().GetPeerMessageCounter().Commit(packetHeader.GetMessageCounter()); // TODO: once mDNS address resolution is available reconsider if this is required // This updates the peer address once a packet is received from a new address // and serves as a way to auto-detect peer changing IPs. - if (session->GetPeerAddress() != peerAddress) + if (secureSession->GetPeerAddress() != peerAddress) { - session->SetPeerAddress(peerAddress); + secureSession->SetPeerAddress(peerAddress); } if (mCB != nullptr) { - SessionHandle sessionHandle(*session); - mCB->OnMessageReceived(packetHeader, payloadHeader, sessionHandle, peerAddress, isDuplicate, std::move(msg)); + mCB->OnMessageReceived(packetHeader, payloadHeader, session.Value(), peerAddress, isDuplicate, std::move(msg)); } } @@ -540,9 +533,19 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & packetHeade { PayloadHeader payloadHeader; SessionMessageDelegate::DuplicateMessage isDuplicate = SessionMessageDelegate::DuplicateMessage::No; - FabricIndex fabricIndex = 0; // TODO : remove initialization once GroupDataProvider->Decrypt is implemented // Credentials::GroupDataProvider * groups = Credentials::GetGroupDataProvider(); + if (!packetHeader.GetDestinationGroupId().HasValue()) + { + return; // malformed packet + } + + Optional session = FindGroupSession(packetHeader.GetDestinationGroupId().Value()); + if (!session.HasValue()) + { + return; + } + if (msg.IsNull()) { ChipLogError(Inet, "Secure transport received Groupcast NULL packet, discarding"); @@ -600,50 +603,22 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & packetHeade if (mCB != nullptr) { - SessionHandle session(packetHeader.GetSourceNodeId().Value(), packetHeader.GetDestinationGroupId().Value(), fabricIndex); - mCB->OnMessageReceived(packetHeader, payloadHeader, session, peerAddress, isDuplicate, std::move(msg)); + mCB->OnMessageReceived(packetHeader, payloadHeader, session.Value(), peerAddress, isDuplicate, std::move(msg)); } } -void SessionManager::HandleConnectionExpired(Transport::SecureSession & session) -{ - ChipLogDetail(Inet, "Marking old secure session for device 0x" ChipLogFormatX64 " as expired", - ChipLogValueX64(session.GetPeerNodeId())); - - SessionHandle sessionHandle(session); - mSessionReleaseDelegates.ForEachActiveObject([&](std::reference_wrapper * cb) { - cb->get().OnSessionReleased(sessionHandle); - return Loop::Continue; - }); - - mTransportMgr->Disconnect(session.GetPeerAddress()); -} - void SessionManager::ExpiryTimerCallback(System::Layer * layer, void * param) { SessionManager * mgr = reinterpret_cast(param); #if CHIP_CONFIG_SESSION_REKEYING // TODO(#2279): session expiration is currently disabled until rekeying is supported // the #ifdef should be removed after that. - mgr->mSecureSessions.ExpireInactiveSessions( - System::SystemClock().GetMonotonicTimestamp(), System::Clock::Milliseconds32(CHIP_PEER_CONNECTION_TIMEOUT_MS), - [this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); }); + mgr->mSecureSessions.ExpireInactiveSessions(System::SystemClock().GetMonotonicTimestamp(), + System::Clock::Milliseconds32(CHIP_PEER_CONNECTION_TIMEOUT_MS)); #endif mgr->ScheduleExpiryTimer(); // re-schedule the oneshot timer } -SecureSession * SessionManager::GetSecureSession(const SessionHandle & session) -{ - if (session.mLocalSessionId.HasValue()) - { - return mSecureSessions.FindSecureSessionByLocalKey(session.mLocalSessionId.Value()); - } - else - { - return nullptr; - } -} - SessionHandle SessionManager::FindSecureSessionForNode(NodeId peerNodeId) { SecureSession * found = nullptr; diff --git a/src/transport/SessionManager.h b/src/transport/SessionManager.h index 70add956a8178f..62d40e3ea96045 100644 --- a/src/transport/SessionManager.h +++ b/src/transport/SessionManager.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -143,37 +144,10 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate */ CHIP_ERROR SendPreparedMessage(const SessionHandle & session, const EncryptedPacketBufferHandle & preparedMessage); - Transport::SecureSession * GetSecureSession(const SessionHandle & session); - /// @brief Set the delegate for handling incoming messages. There can be only one message delegate (probably the /// ExchangeManager) void SetMessageDelegate(SessionMessageDelegate * cb) { mCB = cb; } - /// @brief Set the delegate for handling session release. - void RegisterReleaseDelegate(SessionReleaseDelegate & cb) - { -#ifndef NDEBUG - mSessionReleaseDelegates.ForEachActiveObject([&](std::reference_wrapper * i) { - VerifyOrDie(std::addressof(cb) != std::addressof(i->get())); - return Loop::Continue; - }); -#endif - std::reference_wrapper * slot = mSessionReleaseDelegates.CreateObject(cb); - VerifyOrDie(slot != nullptr); - } - - void UnregisterReleaseDelegate(SessionReleaseDelegate & cb) - { - mSessionReleaseDelegates.ForEachActiveObject([&](std::reference_wrapper * i) { - if (std::addressof(cb) == std::addressof(i->get())) - { - mSessionReleaseDelegates.ReleaseObject(i); - return Loop::Break; - } - return Loop::Continue; - }); - } - void RegisterRecoveryDelegate(SessionRecoveryDelegate & cb); void UnregisterRecoveryDelegate(SessionRecoveryDelegate & cb); void RefreshSessionOperationalData(const SessionHandle & sessionHandle); @@ -232,16 +206,12 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate Optional CreateUnauthenticatedSession(const Transport::PeerAddress & peerAddress, const ReliableMessageProtocolConfig & config) { - Optional session = - mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress, config); - return session.HasValue() ? MakeOptional(session.Value()) : NullOptional; + return mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress, config); } - // TODO: placeholder function for creating GroupSession. Implements a GroupSession class in the future - Optional CreateGroupSession(NodeId peerNodeId, GroupId groupId, FabricIndex fabricIndex) - { - return MakeOptional(SessionHandle(peerNodeId, groupId, fabricIndex)); - } + // TODO: implements group sessions + Optional CreateGroupSession(GroupId group) { return mGroupSessions.AllocEntry(group, kUndefinedFabricIndex); } + Optional FindGroupSession(GroupId group) { return mGroupSessions.FindEntry(group, kUndefinedFabricIndex); } // TODO: this is a temporary solution for legacy tests which use nodeId to send packets // and tv-casting-app that uses the TV's node ID to find the associated secure session @@ -265,17 +235,12 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate System::Layer * mSystemLayer = nullptr; Transport::UnauthenticatedSessionTable mUnauthenticatedSessions; - Transport::SecureSessionTable mSecureSessions; // < Active connections to other peers - State mState; // < Initialization state of the object + Transport::SecureSessionTable mSecureSessions; + Transport::GroupSessionTable mGroupSessions; + State mState; // < Initialization state of the object SessionMessageDelegate * mCB = nullptr; - // TODO: This is a temporary solution to release sessions, in the near future, SessionReleaseDelegate will be - // directly associated with the every SessionHolder. Then the callback function is called on over the handle - // delegate directly, in order to prevent dangling handles. - BitMapObjectPool, CHIP_CONFIG_MAX_SESSION_RELEASE_DELEGATES> - mSessionReleaseDelegates; - BitMapObjectPool, CHIP_CONFIG_MAX_SESSION_RECOVERY_DELEGATES> mSessionRecoveryDelegates; @@ -285,17 +250,14 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate GlobalUnencryptedMessageCounter mGlobalUnencryptedMessageCounter; GlobalEncryptedMessageCounter mGlobalEncryptedMessageCounter; + friend class SessionHandle; + /** Schedules a new oneshot timer for checking connection expiry. */ void ScheduleExpiryTimer(); /** Cancels any active timers for connection expiry checks. */ void CancelExpiryTimer(); - /** - * Called when a specific connection expires. - */ - void HandleConnectionExpired(Transport::SecureSession & state); - /** * Callback for timer expiry check */ diff --git a/src/transport/UnauthenticatedSessionTable.h b/src/transport/UnauthenticatedSessionTable.h index c4ca79f1396cf1..0b645d011a79e3 100644 --- a/src/transport/UnauthenticatedSessionTable.h +++ b/src/transport/UnauthenticatedSessionTable.h @@ -26,14 +26,12 @@ #include #include #include +#include #include namespace chip { namespace Transport { -class UnauthenticatedSession; -using UnauthenticatedSessionHandle = ReferenceCountedHandle; - class UnauthenticatedSessionDeleter { public: @@ -45,12 +43,13 @@ class UnauthenticatedSessionDeleter * @brief * An UnauthenticatedSession stores the binding of TransportAddress, and message counters. */ -class UnauthenticatedSession : public ReferenceCounted +class UnauthenticatedSession : public Session, public ReferenceCounted { public: UnauthenticatedSession(const PeerAddress & address, const ReliableMessageProtocolConfig & config) : mPeerAddress(address), mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config) {} + ~UnauthenticatedSession() { NotifySessionReleased(); } UnauthenticatedSession(const UnauthenticatedSession &) = delete; UnauthenticatedSession & operator=(const UnauthenticatedSession &) = delete; @@ -60,11 +59,41 @@ class UnauthenticatedSession : public ReferenceCounted::Retain(); } + void Release() override { ReferenceCounted::Release(); } + + Access::SubjectDescriptor GetSubjectDescriptor() const override + { + return Access::SubjectDescriptor(); // return an empty ISD for unauthenticated session. + } + + bool RequireMRP() const override { return GetPeerAddress().GetTransportType() == Transport::Type::kUdp; } + + System::Clock::Milliseconds32 GetAckTimeout() const override + { + switch (mPeerAddress.GetTransportType()) + { + case Transport::Type::kUdp: + return GetMRPConfig().mIdleRetransTimeout * (CHIP_CONFIG_RMP_DEFAULT_MAX_RETRANS + 1); + case Transport::Type::kTcp: + return System::Clock::Seconds16(30); + default: + break; + } + return System::Clock::Timeout(); + } + + NodeId GetPeerNodeId() const { return kUndefinedNodeId; } const PeerAddress & GetPeerAddress() const { return mPeerAddress; } void SetMRPConfig(const ReliableMessageProtocolConfig & config) { mMRPConfig = config; } - const ReliableMessageProtocolConfig & GetMRPConfig() const { return mMRPConfig; } + const ReliableMessageProtocolConfig & GetMRPConfig() const override { return mMRPConfig; } PeerMessageCounter & GetPeerMessageCounter() { return mPeerMessageCounter; } @@ -79,9 +108,8 @@ class UnauthenticatedSession : public ReferenceCounted class UnauthenticatedSessionTable @@ -95,21 +123,20 @@ class UnauthenticatedSessionTable * @return the session found or allocated, nullptr if not found and allocation failed. */ CHECK_RETURN_VALUE - Optional FindOrAllocateEntry(const PeerAddress & address, - const ReliableMessageProtocolConfig & config) + Optional FindOrAllocateEntry(const PeerAddress & address, const ReliableMessageProtocolConfig & config) { UnauthenticatedSession * result = FindEntry(address); if (result != nullptr) - return MakeOptional(*result); + return MakeOptional(*result); CHIP_ERROR err = AllocEntry(address, config, result); if (err == CHIP_NO_ERROR) { - return MakeOptional(*result); + return MakeOptional(*result); } else { - return Optional::Missing(); + return Optional::Missing(); } } diff --git a/src/transport/tests/TestPeerConnections.cpp b/src/transport/tests/TestPeerConnections.cpp index 259defeb627574..e716f8604778db 100644 --- a/src/transport/tests/TestPeerConnections.cpp +++ b/src/transport/tests/TestPeerConnections.cpp @@ -62,7 +62,6 @@ const CATValues kPeer3CATs; void TestBasicFunctionality(nlTestSuite * inSuite, void * inContext) { - SecureSession * statePtr; SecureSessionTable<2> connections; System::Clock::Internal::MockClock clock; System::Clock::ClockBase * realClock = &System::SystemClock(); @@ -71,54 +70,53 @@ void TestBasicFunctionality(nlTestSuite * inSuite, void * inContext) CATValues peerCATs; // Node ID 1, peer key 1, local key 2 - statePtr = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, 0 /* fabricIndex */, - gDefaultMRPConfig); - NL_TEST_ASSERT(inSuite, statePtr != nullptr); - NL_TEST_ASSERT(inSuite, statePtr->GetSecureSessionType() == kPeer1SessionType); - NL_TEST_ASSERT(inSuite, statePtr->GetPeerNodeId() == kPeer1NodeId); - peerCATs = statePtr->GetPeerCATs(); + auto optionalSession = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, + 0 /* fabricIndex */, gDefaultMRPConfig); + NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); + NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetSecureSessionType() == kPeer1SessionType); + NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetPeerNodeId() == kPeer1NodeId); + peerCATs = optionalSession.Value()->AsSecureSession()->GetPeerCATs(); NL_TEST_ASSERT(inSuite, memcmp(&peerCATs, &kPeer1CATs, sizeof(CATValues)) == 0); // Node ID 2, peer key 3, local key 4 - statePtr = connections.CreateNewSecureSession(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, 0 /* fabricIndex */, - gDefaultMRPConfig); - NL_TEST_ASSERT(inSuite, statePtr != nullptr); - NL_TEST_ASSERT(inSuite, statePtr->GetSecureSessionType() == kPeer2SessionType); - NL_TEST_ASSERT(inSuite, statePtr->GetPeerNodeId() == kPeer2NodeId); - NL_TEST_ASSERT(inSuite, statePtr->GetLastActivityTime() == 100_ms64); - peerCATs = statePtr->GetPeerCATs(); + optionalSession = connections.CreateNewSecureSession(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, 0 /* fabricIndex */, + gDefaultMRPConfig); + NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); + NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetSecureSessionType() == kPeer2SessionType); + NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetPeerNodeId() == kPeer2NodeId); + NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetLastActivityTime() == 100_ms64); + peerCATs = optionalSession.Value()->AsSecureSession()->GetPeerCATs(); NL_TEST_ASSERT(inSuite, memcmp(&peerCATs, &kPeer2CATs, sizeof(CATValues)) == 0); // Insufficient space for new connections. Object is max size 2 - statePtr = connections.CreateNewSecureSession(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */, - gDefaultMRPConfig); - NL_TEST_ASSERT(inSuite, statePtr == nullptr); + optionalSession = connections.CreateNewSecureSession(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */, + gDefaultMRPConfig); + NL_TEST_ASSERT(inSuite, !optionalSession.HasValue()); System::Clock::Internal::SetSystemClockForTesting(realClock); } void TestFindByKeyId(nlTestSuite * inSuite, void * inContext) { - SecureSession * statePtr; SecureSessionTable<2> connections; System::Clock::Internal::MockClock clock; System::Clock::ClockBase * realClock = &System::SystemClock(); System::Clock::Internal::SetSystemClockForTesting(&clock); // Node ID 1, peer key 1, local key 2 - statePtr = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, 0 /* fabricIndex */, - gDefaultMRPConfig); - NL_TEST_ASSERT(inSuite, statePtr != nullptr); + auto optionalSession = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, + 0 /* fabricIndex */, gDefaultMRPConfig); + NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(1)); - NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(2)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(1).HasValue()); + NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(2).HasValue()); // Node ID 2, peer key 3, local key 4 - statePtr = connections.CreateNewSecureSession(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, 0 /* fabricIndex */, - gDefaultMRPConfig); - NL_TEST_ASSERT(inSuite, statePtr != nullptr); + optionalSession = connections.CreateNewSecureSession(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, 0 /* fabricIndex */, + gDefaultMRPConfig); + NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(3)); - NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(3).HasValue()); + NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4).HasValue()); System::Clock::Internal::SetSystemClockForTesting(realClock); } @@ -133,7 +131,6 @@ struct ExpiredCallInfo void TestExpireConnections(nlTestSuite * inSuite, void * inContext) { ExpiredCallInfo callInfo; - SecureSession * statePtr; SecureSessionTable<2> connections; System::Clock::Internal::MockClock clock; @@ -143,23 +140,23 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext) clock.SetMonotonic(100_ms64); // Node ID 1, peer key 1, local key 2 - statePtr = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, 0 /* fabricIndex */, - gDefaultMRPConfig); - NL_TEST_ASSERT(inSuite, statePtr != nullptr); - statePtr->SetPeerAddress(kPeer1Addr); + auto optionalSession = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, + 0 /* fabricIndex */, gDefaultMRPConfig); + NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); + optionalSession.Value()->AsSecureSession()->SetPeerAddress(kPeer1Addr); clock.SetMonotonic(200_ms64); // Node ID 2, peer key 3, local key 4 - statePtr = connections.CreateNewSecureSession(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, 0 /* fabricIndex */, - gDefaultMRPConfig); - NL_TEST_ASSERT(inSuite, statePtr != nullptr); - statePtr->SetPeerAddress(kPeer2Addr); + optionalSession = connections.CreateNewSecureSession(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, 0 /* fabricIndex */, + gDefaultMRPConfig); + NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); + optionalSession.Value()->AsSecureSession()->SetPeerAddress(kPeer2Addr); // cannot add before expiry clock.SetMonotonic(300_ms64); - statePtr = connections.CreateNewSecureSession(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */, - gDefaultMRPConfig); - NL_TEST_ASSERT(inSuite, statePtr == nullptr); + optionalSession = connections.CreateNewSecureSession(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */, + gDefaultMRPConfig); + NL_TEST_ASSERT(inSuite, !optionalSession.HasValue()); // at time 300, this expires ip addr 1 connections.ExpireInactiveSessions(150_ms64, [&callInfo](const SecureSession & state) { @@ -170,21 +167,22 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, callInfo.callCount == 1); NL_TEST_ASSERT(inSuite, callInfo.lastCallNodeId == kPeer1NodeId); NL_TEST_ASSERT(inSuite, callInfo.lastCallPeerAddress == kPeer1Addr); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2).HasValue()); // now that the connections were expired, we can add peer3 clock.SetMonotonic(300_ms64); // Node ID 3, peer key 5, local key 6 - statePtr = connections.CreateNewSecureSession(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */, - gDefaultMRPConfig); - NL_TEST_ASSERT(inSuite, statePtr != nullptr); - statePtr->SetPeerAddress(kPeer3Addr); + optionalSession = connections.CreateNewSecureSession(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */, + gDefaultMRPConfig); + NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); + optionalSession.Value()->AsSecureSession()->SetPeerAddress(kPeer3Addr); clock.SetMonotonic(400_ms64); - NL_TEST_ASSERT(inSuite, statePtr = connections.FindSecureSessionByLocalKey(4)); + optionalSession = connections.FindSecureSessionByLocalKey(4); + NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); - statePtr->MarkActive(); - NL_TEST_ASSERT(inSuite, statePtr->GetLastActivityTime() == clock.GetMonotonicTimestamp()); + optionalSession.Value()->AsSecureSession()->MarkActive(); + NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetLastActivityTime() == clock.GetMonotonicTimestamp()); // At this time: // Peer 3 active at time 300 @@ -202,17 +200,17 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, callInfo.callCount == 1); NL_TEST_ASSERT(inSuite, callInfo.lastCallNodeId == kPeer3NodeId); NL_TEST_ASSERT(inSuite, callInfo.lastCallPeerAddress == kPeer3Addr); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2)); - NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4)); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2).HasValue()); + NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4).HasValue()); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6).HasValue()); // Node ID 1, peer key 1, local key 2 - statePtr = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, 0 /* fabricIndex */, - gDefaultMRPConfig); - NL_TEST_ASSERT(inSuite, statePtr != nullptr); - NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(2)); - NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4)); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6)); + optionalSession = connections.CreateNewSecureSession(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, 0 /* fabricIndex */, + gDefaultMRPConfig); + NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); + NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(2).HasValue()); + NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4).HasValue()); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6).HasValue()); // peer 1 and 2 are active clock.SetMonotonic(1000_ms64); @@ -223,9 +221,9 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext) callInfo.lastCallPeerAddress = state.GetPeerAddress(); }); NL_TEST_ASSERT(inSuite, callInfo.callCount == 2); // everything expired - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2)); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(4)); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2).HasValue()); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(4).HasValue()); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6).HasValue()); System::Clock::Internal::SetSystemClockForTesting(realClock); } diff --git a/src/transport/tests/TestSessionManager.cpp b/src/transport/tests/TestSessionManager.cpp index b39cd4aa5d737c..acdc76e542afed 100644 --- a/src/transport/tests/TestSessionManager.cpp +++ b/src/transport/tests/TestSessionManager.cpp @@ -59,15 +59,20 @@ constexpr NodeId kDestinationNodeId = 111222333; const char LARGE_PAYLOAD[kMaxAppMessageLen + 1] = "test message"; -class TestSessMgrCallback : public SessionReleaseDelegate, public SessionMessageDelegate +class TestSessionReleaseCallback : public SessionReleaseDelegate +{ +public: + void OnSessionReleased() override { mOldConnectionDropped = true; } + bool mOldConnectionDropped = false; +}; + +class TestSessMgrCallback : public SessionMessageDelegate { public: void OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, const SessionHandle & session, const Transport::PeerAddress & source, DuplicateMessage isDuplicate, System::PacketBufferHandle && msgBuf) override { - NL_TEST_ASSERT(mSuite, mRemoteToLocalSession.Contains(session)); // Packet received by remote peer - size_t data_len = msgBuf->DataLength(); if (LargeMessageSent) @@ -84,17 +89,9 @@ class TestSessMgrCallback : public SessionReleaseDelegate, public SessionMessage ReceiveHandlerCallCount++; } - void OnSessionReleased(const SessionHandle & session) override { mOldConnectionDropped = true; } - - bool mOldConnectionDropped = false; - - nlTestSuite * mSuite = nullptr; - SessionHolder mRemoteToLocalSession; - SessionHolder mLocalToRemoteSession; - int ReceiveHandlerCallCount = 0; - int NewConnectionHandlerCallCount = 0; - - bool LargeMessageSent = false; + nlTestSuite * mSuite = nullptr; + int ReceiveHandlerCallCount = 0; + bool LargeMessageSent = false; }; void CheckSimpleInitTest(nlTestSuite * inSuite, void * inContext) @@ -145,19 +142,19 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) sessionManager.SetMessageDelegate(&callback); Optional peer(Transport::PeerAddress::UDP(addr, CHIP_PORT)); + SessionHolder localToRemoteSession; + SessionHolder remoteToLocalSession; SecurePairingUsingTestSecret pairing1(1, 2); - err = sessionManager.NewPairing(callback.mRemoteToLocalSession, peer, kSourceNodeId, &pairing1, - CryptoContext::SessionRole::kInitiator, 1); + err = + sessionManager.NewPairing(localToRemoteSession, peer, kSourceNodeId, &pairing1, CryptoContext::SessionRole::kInitiator, 1); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); SecurePairingUsingTestSecret pairing2(2, 1); - err = sessionManager.NewPairing(callback.mLocalToRemoteSession, peer, kDestinationNodeId, &pairing2, + err = sessionManager.NewPairing(remoteToLocalSession, peer, kDestinationNodeId, &pairing2, CryptoContext::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SessionHandle localToRemoteSession = callback.mLocalToRemoteSession.Get(); - // Should be able to send a message to itself by just calling send. callback.ReceiveHandlerCallCount = 0; @@ -170,10 +167,10 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) payloadHeader.SetMessageType(chip::Protocols::Echo::MsgType::EchoRequest); EncryptedPacketBufferHandle preparedMessage; - err = sessionManager.PrepareMessage(localToRemoteSession, payloadHeader, std::move(buffer), preparedMessage); + err = sessionManager.PrepareMessage(localToRemoteSession.Get(), payloadHeader, std::move(buffer), preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - err = sessionManager.SendPreparedMessage(localToRemoteSession, preparedMessage); + err = sessionManager.SendPreparedMessage(localToRemoteSession.Get(), preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); @@ -184,10 +181,10 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) callback.LargeMessageSent = true; - err = sessionManager.PrepareMessage(localToRemoteSession, payloadHeader, std::move(large_buffer), preparedMessage); + err = sessionManager.PrepareMessage(localToRemoteSession.Get(), payloadHeader, std::move(large_buffer), preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - err = sessionManager.SendPreparedMessage(localToRemoteSession, preparedMessage); + err = sessionManager.SendPreparedMessage(localToRemoteSession.Get(), preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 2); @@ -200,7 +197,7 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) callback.LargeMessageSent = true; - err = sessionManager.PrepareMessage(localToRemoteSession, payloadHeader, std::move(extra_large_buffer), preparedMessage); + err = sessionManager.PrepareMessage(localToRemoteSession.Get(), payloadHeader, std::move(extra_large_buffer), preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_ERROR_MESSAGE_TOO_LONG); sessionManager.Shutdown(); @@ -237,19 +234,19 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) sessionManager.SetMessageDelegate(&callback); Optional peer(Transport::PeerAddress::UDP(addr, CHIP_PORT)); + SessionHolder localToRemoteSession; + SessionHolder remoteToLocalSession; SecurePairingUsingTestSecret pairing1(1, 2); - err = sessionManager.NewPairing(callback.mRemoteToLocalSession, peer, kSourceNodeId, &pairing1, - CryptoContext::SessionRole::kInitiator, 1); + err = + sessionManager.NewPairing(localToRemoteSession, peer, kSourceNodeId, &pairing1, CryptoContext::SessionRole::kInitiator, 1); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); SecurePairingUsingTestSecret pairing2(2, 1); - err = sessionManager.NewPairing(callback.mLocalToRemoteSession, peer, kDestinationNodeId, &pairing2, + err = sessionManager.NewPairing(remoteToLocalSession, peer, kDestinationNodeId, &pairing2, CryptoContext::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SessionHandle localToRemoteSession = callback.mLocalToRemoteSession.Get(); - // Should be able to send a message to itself by just calling send. callback.ReceiveHandlerCallCount = 0; @@ -264,19 +261,19 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) payloadHeader.SetInitiator(true); - err = sessionManager.PrepareMessage(localToRemoteSession, payloadHeader, std::move(buffer), preparedMessage); + err = sessionManager.PrepareMessage(localToRemoteSession.Get(), payloadHeader, std::move(buffer), preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - err = sessionManager.SendPreparedMessage(localToRemoteSession, preparedMessage); + err = sessionManager.SendPreparedMessage(localToRemoteSession.Get(), preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); // Reset receive side message counter, or duplicated message will be denied. - Transport::SecureSession * state = sessionManager.GetSecureSession(callback.mRemoteToLocalSession.Get()); - state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); + Transport::SecureSession * session = remoteToLocalSession.Get()->AsSecureSession(); + session->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); - err = sessionManager.SendPreparedMessage(localToRemoteSession, preparedMessage); + err = sessionManager.SendPreparedMessage(localToRemoteSession.Get(), preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 2); @@ -315,19 +312,19 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) sessionManager.SetMessageDelegate(&callback); Optional peer(Transport::PeerAddress::UDP(addr, CHIP_PORT)); + SessionHolder localToRemoteSession; + SessionHolder remoteToLocalSession; SecurePairingUsingTestSecret pairing1(1, 2); - err = sessionManager.NewPairing(callback.mRemoteToLocalSession, peer, kSourceNodeId, &pairing1, - CryptoContext::SessionRole::kInitiator, 1); + err = + sessionManager.NewPairing(localToRemoteSession, peer, kSourceNodeId, &pairing1, CryptoContext::SessionRole::kInitiator, 1); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); SecurePairingUsingTestSecret pairing2(2, 1); - err = sessionManager.NewPairing(callback.mLocalToRemoteSession, peer, kDestinationNodeId, &pairing2, + err = sessionManager.NewPairing(remoteToLocalSession, peer, kDestinationNodeId, &pairing2, CryptoContext::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SessionHandle localToRemoteSession = callback.mLocalToRemoteSession.Get(); - // Should be able to send a message to itself by just calling send. callback.ReceiveHandlerCallCount = 0; @@ -342,23 +339,21 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) payloadHeader.SetInitiator(true); - err = sessionManager.PrepareMessage(localToRemoteSession, payloadHeader, std::move(buffer), preparedMessage); + err = sessionManager.PrepareMessage(localToRemoteSession.Get(), payloadHeader, std::move(buffer), preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - err = sessionManager.SendPreparedMessage(localToRemoteSession, preparedMessage); + err = sessionManager.SendPreparedMessage(localToRemoteSession.Get(), preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); /* -------------------------------------------------------------------------------------------*/ // Reset receive side message counter, or duplicated message will be denied. - Transport::SecureSession * state = sessionManager.GetSecureSession(callback.mRemoteToLocalSession.Get()); - state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); + Transport::SecureSession * session = remoteToLocalSession.Get()->AsSecureSession(); + session->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); PacketHeader packetHeader; - state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); - // Change Message ID EncryptedPacketBufferHandle badMessageCounterMsg = preparedMessage.CloneData(); NL_TEST_ASSERT(inSuite, badMessageCounterMsg.ExtractPacketHeader(packetHeader) == CHIP_NO_ERROR); @@ -367,13 +362,13 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) packetHeader.SetMessageCounter(messageCounter + 1); NL_TEST_ASSERT(inSuite, badMessageCounterMsg.InsertPacketHeader(packetHeader) == CHIP_NO_ERROR); - err = sessionManager.SendPreparedMessage(localToRemoteSession, badMessageCounterMsg); + err = sessionManager.SendPreparedMessage(localToRemoteSession.Get(), badMessageCounterMsg); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); /* -------------------------------------------------------------------------------------------*/ - state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); + session->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); // Change Key ID EncryptedPacketBufferHandle badKeyIdMsg = preparedMessage.CloneData(); @@ -383,16 +378,16 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) packetHeader.SetSessionId(3); NL_TEST_ASSERT(inSuite, badKeyIdMsg.InsertPacketHeader(packetHeader) == CHIP_NO_ERROR); - err = sessionManager.SendPreparedMessage(localToRemoteSession, badKeyIdMsg); + err = sessionManager.SendPreparedMessage(localToRemoteSession.Get(), badKeyIdMsg); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); /* -------------------------------------------------------------------------------------------*/ - state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); + session->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); // Send the correct encrypted msg - err = sessionManager.SendPreparedMessage(localToRemoteSession, preparedMessage); + err = sessionManager.SendPreparedMessage(localToRemoteSession.Get(), preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 2); @@ -418,51 +413,46 @@ void StaleConnectionDropTest(nlTestSuite * inSuite, void * inContext) err = sessionManager.Init(&ctx.GetSystemLayer(), &transportMgr, &gMessageCounterManager); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - TestSessMgrCallback callback; - callback.mSuite = inSuite; - - sessionManager.RegisterReleaseDelegate(callback); - sessionManager.SetMessageDelegate(&callback); - Optional peer(Transport::PeerAddress::UDP(addr, CHIP_PORT)); + TestSessionReleaseCallback callback; + SessionHolderWithDelegate session1(callback); + SessionHolderWithDelegate session2(callback); + SessionHolderWithDelegate session3(callback); + SessionHolderWithDelegate session4(callback); + SessionHolderWithDelegate session5(callback); // First pairing SecurePairingUsingTestSecret pairing1(1, 1); callback.mOldConnectionDropped = false; - err = sessionManager.NewPairing(callback.mRemoteToLocalSession, peer, kSourceNodeId, &pairing1, - CryptoContext::SessionRole::kInitiator, 1); + err = sessionManager.NewPairing(session1, peer, kSourceNodeId, &pairing1, CryptoContext::SessionRole::kInitiator, 1); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, !callback.mOldConnectionDropped); // New pairing with different peer node ID and different local key ID (same peer key ID) SecurePairingUsingTestSecret pairing2(1, 2); callback.mOldConnectionDropped = false; - err = sessionManager.NewPairing(callback.mLocalToRemoteSession, peer, kSourceNodeId, &pairing2, - CryptoContext::SessionRole::kResponder, 0); + err = sessionManager.NewPairing(session2, peer, kSourceNodeId, &pairing2, CryptoContext::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, !callback.mOldConnectionDropped); // New pairing with undefined node ID and different local key ID (same peer key ID) SecurePairingUsingTestSecret pairing3(1, 3); callback.mOldConnectionDropped = false; - err = sessionManager.NewPairing(callback.mLocalToRemoteSession, peer, kUndefinedNodeId, &pairing3, - CryptoContext::SessionRole::kResponder, 0); + err = sessionManager.NewPairing(session3, peer, kUndefinedNodeId, &pairing3, CryptoContext::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, !callback.mOldConnectionDropped); // New pairing with same local key ID, and a given node ID SecurePairingUsingTestSecret pairing4(1, 2); callback.mOldConnectionDropped = false; - err = sessionManager.NewPairing(callback.mLocalToRemoteSession, peer, kSourceNodeId, &pairing4, - CryptoContext::SessionRole::kResponder, 0); + err = sessionManager.NewPairing(session4, peer, kSourceNodeId, &pairing4, CryptoContext::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.mOldConnectionDropped); // New pairing with same local key ID, and undefined node ID SecurePairingUsingTestSecret pairing5(1, 1); callback.mOldConnectionDropped = false; - err = sessionManager.NewPairing(callback.mLocalToRemoteSession, peer, kUndefinedNodeId, &pairing5, - CryptoContext::SessionRole::kResponder, 0); + err = sessionManager.NewPairing(session5, peer, kUndefinedNodeId, &pairing5, CryptoContext::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.mOldConnectionDropped); diff --git a/third_party/nxp/k32w0_sdk/sdk_fixes/patch_ble_utils_h.patch b/third_party/nxp/k32w0_sdk/sdk_fixes/patch_ble_utils_h.patch new file mode 100644 index 00000000000000..bb2051b158ecf9 --- /dev/null +++ b/third_party/nxp/k32w0_sdk/sdk_fixes/patch_ble_utils_h.patch @@ -0,0 +1,14 @@ +--- a/ble_utils.h 2022-01-07 16:41:34.017433835 +0000 ++++ b/ble_utils.h 2022-01-07 16:42:09.797788620 +0000 +@@ -73,11 +73,6 @@ + #define ALIGN_64BIT #pragma pack(8) + #endif + +-/*! Marks that this variable is in the interface. */ +-#ifndef global +-#define global +-#endif +- + /*! Marks a function that never returns. */ + #if !defined(__IAR_SYSTEMS_ICC__) + #ifndef __noreturn diff --git a/third_party/nxp/k32w0_sdk/sdk_fixes/patch_k32w_sdk.sh b/third_party/nxp/k32w0_sdk/sdk_fixes/patch_k32w_sdk.sh index 378feb75f3ceba..d3f262f11595e6 100755 --- a/third_party/nxp/k32w0_sdk/sdk_fixes/patch_k32w_sdk.sh +++ b/third_party/nxp/k32w0_sdk/sdk_fixes/patch_k32w_sdk.sh @@ -5,6 +5,9 @@ if [[ ! -d $NXP_K32W061_SDK_ROOT ]]; then exit 1 fi +SOURCE=${BASH_SOURCE[0]} +SOURCE_DIR=$(cd "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd) + cp ./third_party/nxp/k32w0_sdk/sdk_fixes/gpio_pins.h "$NXP_K32W061_SDK_ROOT"/boards/k32w061dk6/wireless_examples/openthread/reed/bm/ cp ./third_party/nxp/k32w0_sdk/sdk_fixes/pin_mux.c "$NXP_K32W061_SDK_ROOT"/boards/k32w061dk6/wireless_examples/openthread/enablement/ @@ -17,5 +20,7 @@ cp -r ./third_party/nxp/k32w0_sdk/sdk_fixes/app_dual_mode_switch.h "$NXP_K32W061 cp -r ./third_party/nxp/k32w0_sdk/sdk_fixes/OtaUtils.c "$NXP_K32W061_SDK_ROOT"/middleware/wireless/framework/OtaSupport/Source/ +patch -d "$NXP_K32W061_SDK_ROOT"/middleware/wireless/bluetooth/host/interface -p1 <"$SOURCE_DIR/patch_ble_utils_h.patch" + echo "K32W SDK MR3 QP1 was patched!" exit 0