diff --git a/src/lib/core/ReferenceCounted.h b/src/lib/core/ReferenceCounted.h index 2d59af288c842c..431899abed8afa 100644 --- a/src/lib/core/ReferenceCounted.h +++ b/src/lib/core/ReferenceCounted.h @@ -39,6 +39,13 @@ class DeleteDeletor static void Release(T * obj) { chip::Platform::Delete(obj); } }; +template +class NoopDeletor +{ +public: + static void Release(T * obj) {} +}; + /** * A reference counted object maintains a count of usages and when the usage * count drops to 0, it deletes itself. diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index 3bd59609910af1..b1573e1859e000 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -129,6 +129,11 @@ CASESession::~CASESession() Clear(); } +void CASESession::OnSessionReleased() +{ + // TODO: interrupt pairing procedure, call OnSessionEstablishmentError, if the pairing is not finished +} + void CASESession::Finish() { mCASESessionEstablished = true; diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index 2d15387b0d0b8a..82d7b1e6a9605b 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -154,6 +154,9 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, void OnResponseTimeout(Messaging::ExchangeContext * ec) override; Messaging::ExchangeMessageDispatch & GetMessageDispatch() override { return SessionEstablishmentExchangeDispatch::Instance(); } + //// SessionReleaseDelegate //// + void OnSessionReleased() override; + FabricIndex GetFabricIndex() const { return mFabricInfo != nullptr ? mFabricInfo->GetFabricIndex() : kUndefinedFabricIndex; } // TODO: remove Clear, we should create a new instance instead reset the old instance. diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index cbe67852f698d1..3abe6e6314ace7 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -69,6 +69,11 @@ PASESession::~PASESession() Clear(); } +void PASESession::OnSessionReleased() +{ + // TODO: interrupt pairing procedure, call OnSessionEstablishmentError, if the pairing is not finished +} + void PASESession::Finish() { mPairingComplete = true; diff --git a/src/protocols/secure_channel/PASESession.h b/src/protocols/secure_channel/PASESession.h index dfa3b826bdaeec..32fe93ddf102bb 100644 --- a/src/protocols/secure_channel/PASESession.h +++ b/src/protocols/secure_channel/PASESession.h @@ -172,6 +172,9 @@ class DLL_EXPORT PASESession : public Messaging::UnsolicitedMessageHandler, Messaging::ExchangeMessageDispatch & GetMessageDispatch() override { return SessionEstablishmentExchangeDispatch::Instance(); } + //// SessionReleaseDelegate //// + void OnSessionReleased() override; + private: enum Spake2pErrorType : uint8_t { diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index 8e214221695dea..f7e1934f024899 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -183,11 +183,12 @@ CHIP_ERROR InitCredentialSets() void CASE_SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) { + SessionManager sessionManager; + // Test all combinations of invalid parameters TestCASESecurePairingDelegate delegate; CASESession pairing; FabricTable fabrics; - SessionManager sessionManager; NL_TEST_ASSERT(inSuite, pairing.GetSecureSessionType() == SecureSession::Type::kCASE); @@ -202,6 +203,7 @@ void CASE_SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); + SessionManager sessionManager; // Test all combinations of invalid parameters TestCASESecurePairingDelegate delegate; @@ -210,7 +212,6 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) FabricInfo * fabric = gCommissionerFabrics.FindFabricWithIndex(gCommissionerFabricIndex); NL_TEST_ASSERT(inSuite, fabric != nullptr); - SessionManager sessionManager; ExchangeContext * context = ctx.NewUnauthenticatedExchangeToBob(&pairing); @@ -250,15 +251,19 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) gLoopback.mMessageSendError = CHIP_NO_ERROR; } -void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, CASESession & pairingCommissioner, - TestCASESecurePairingDelegate & delegateCommissioner) + +void CASE_SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); + SessionManager sessionManager; + + TestCASESecurePairingDelegate delegateCommissioner; + CASESession pairingCommissioner; + pairingCommissioner.SetGroupDataProvider(&gCommissionerGroupDataProvider); // Test all combinations of invalid parameters TestCASESecurePairingDelegate delegateAccessory; CASESession pairingAccessory; - SessionManager sessionManager; gLoopback.mSentMessageCount = 0; @@ -285,14 +290,6 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte NL_TEST_ASSERT(inSuite, delegateCommissioner.mNumPairingComplete == 1); } -void CASE_SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext) -{ - TestCASESecurePairingDelegate delegateCommissioner; - CASESession pairingCommissioner; - pairingCommissioner.SetGroupDataProvider(&gCommissionerGroupDataProvider); - CASE_SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, delegateCommissioner); -} - CASEServerForTest gPairingServer; void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inContext) diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp index 598dee3c53dd5d..3142c29654778d 100644 --- a/src/protocols/secure_channel/tests/TestPASESession.cpp +++ b/src/protocols/secure_channel/tests/TestPASESession.cpp @@ -116,11 +116,11 @@ using namespace System::Clock::Literals; void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); + SessionManager sessionManager; // Test all combinations of invalid parameters TestSecurePairingDelegate delegate; PASESession pairing; - SessionManager sessionManager; NL_TEST_ASSERT(inSuite, pairing.GetSecureSessionType() == SecureSession::Type::kPASE); @@ -156,11 +156,11 @@ void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); + SessionManager sessionManager; // Test all combinations of invalid parameters TestSecurePairingDelegate delegate; PASESession pairing; - SessionManager sessionManager; gLoopback.Reset(); @@ -197,7 +197,8 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) gLoopback.mMessageSendError = CHIP_NO_ERROR; } -void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, PASESession & pairingCommissioner, +void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, SessionManager & sessionManager, + PASESession & pairingCommissioner, Optional mrpCommissionerConfig, Optional mrpAccessoryConfig, TestSecurePairingDelegate & delegateCommissioner) @@ -206,7 +207,6 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P TestSecurePairingDelegate delegateAccessory; PASESession pairingAccessory; - SessionManager sessionManager; gLoopback.mSentMessageCount = 0; @@ -279,53 +279,58 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P void SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext) { + SessionManager sessionManager; TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; gLoopback.Reset(); - SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, Optional::Missing(), + SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, Optional::Missing(), Optional::Missing(), delegateCommissioner); } void SecurePairingHandshakeWithCommissionerMRPTest(nlTestSuite * inSuite, void * inContext) { + SessionManager sessionManager; TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; gLoopback.Reset(); ReliableMessageProtocolConfig config(1000_ms32, 10000_ms32); - SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, + SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, Optional::Value(config), Optional::Missing(), delegateCommissioner); } void SecurePairingHandshakeWithDeviceMRPTest(nlTestSuite * inSuite, void * inContext) { + SessionManager sessionManager; TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; gLoopback.Reset(); ReliableMessageProtocolConfig config(1000_ms32, 10000_ms32); - SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, Optional::Missing(), + SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, Optional::Missing(), Optional::Value(config), delegateCommissioner); } void SecurePairingHandshakeWithAllMRPTest(nlTestSuite * inSuite, void * inContext) { + SessionManager sessionManager; TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; gLoopback.Reset(); ReliableMessageProtocolConfig commissionerConfig(1000_ms32, 10000_ms32); ReliableMessageProtocolConfig deviceConfig(2000_ms32, 7000_ms32); - SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, + SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, Optional::Value(commissionerConfig), Optional::Value(deviceConfig), delegateCommissioner); } void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inContext) { + SessionManager sessionManager; TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; gLoopback.Reset(); gLoopback.mNumMessagesToDrop = 2; - SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, Optional::Missing(), + SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, Optional::Missing(), Optional::Missing(), delegateCommissioner); NL_TEST_ASSERT(inSuite, gLoopback.mDroppedMessageCount == 2); NL_TEST_ASSERT(inSuite, gLoopback.mNumMessagesToDrop == 0); @@ -334,6 +339,7 @@ void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inCo void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); + SessionManager sessionManager; TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; @@ -341,8 +347,6 @@ void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext) TestSecurePairingDelegate delegateAccessory; PASESession pairingAccessory; - SessionManager sessionManager; - gLoopback.Reset(); gLoopback.mSentMessageCount = 0; diff --git a/src/transport/GroupSession.h b/src/transport/GroupSession.h index c3118653b9e384..c989713d0e4ee7 100644 --- a/src/transport/GroupSession.h +++ b/src/transport/GroupSession.h @@ -18,20 +18,28 @@ #include #include +#include #include #include namespace chip { namespace Transport { -class IncomingGroupSession : public Session +class IncomingGroupSession : public Session, public ReferenceCounted, 0> { public: IncomingGroupSession(GroupId group, FabricIndex fabricIndex, NodeId sourceNodeId) : mGroupId(group), mSourceNodeId(sourceNodeId) { SetFabricIndex(fabricIndex); } - ~IncomingGroupSession() override { NotifySessionReleased(); } + ~IncomingGroupSession() override + { + NotifySessionReleased(); + VerifyOrDie(GetReferenceCount() == 0); + } + + void Retain() override { ReferenceCounted, 0>::Retain(); } + void Release() override { ReferenceCounted, 0>::Release(); } Session::SessionType GetSessionType() const override { return Session::SessionType::kGroupIncoming; } #if CHIP_PROGRESS_LOGGING @@ -75,11 +83,18 @@ class IncomingGroupSession : public Session const NodeId mSourceNodeId; }; -class OutgoingGroupSession : public Session +class OutgoingGroupSession : public Session, public ReferenceCounted, 0> { public: OutgoingGroupSession(GroupId group, FabricIndex fabricIndex) : mGroupId(group) { SetFabricIndex(fabricIndex); } - ~OutgoingGroupSession() override { NotifySessionReleased(); } + ~OutgoingGroupSession() override + { + NotifySessionReleased(); + VerifyOrDie(GetReferenceCount() == 0); + } + + void Retain() override { ReferenceCounted, 0>::Retain(); } + void Release() override { ReferenceCounted, 0>::Release(); } Session::SessionType GetSessionType() const override { return Session::SessionType::kGroupOutgoing; } #if CHIP_PROGRESS_LOGGING diff --git a/src/transport/PairingSession.cpp b/src/transport/PairingSession.cpp index cd9da86f2eb60a..ebd934a80b1138 100644 --- a/src/transport/PairingSession.cpp +++ b/src/transport/PairingSession.cpp @@ -25,10 +25,9 @@ namespace chip { CHIP_ERROR PairingSession::AllocateSecureSession(SessionManager & sessionManager) { - auto handle = sessionManager.AllocateSession(); + auto handle = sessionManager.AllocateSession(GetSecureSessionType()); VerifyOrReturnError(handle.HasValue(), CHIP_ERROR_NO_MEMORY); mSecureSessionHolder.Grab(handle.Value()); - mSessionManager = &sessionManager; return CHIP_NO_ERROR; } @@ -39,7 +38,7 @@ CHIP_ERROR PairingSession::ActivateSecureSession(const Transport::PeerAddress & uint16_t peerSessionId = GetPeerSessionId(); ChipLogDetail(Inet, "New secure session created for device " ChipLogFormatScopedNodeId ", LSID:%d PSID:%d!", ChipLogValueScopedNodeId(GetPeer()), secureSession->GetLocalSessionId(), peerSessionId); - secureSession->Activate(GetSecureSessionType(), GetPeer(), GetPeerCATs(), peerSessionId, mRemoteMRPConfig); + secureSession->Activate(GetPeer(), GetPeerCATs(), peerSessionId, mRemoteMRPConfig); secureSession->SetPeerAddress(peerAddress); ReturnErrorOnFailure(DeriveSecureSession(secureSession->GetCryptoContext())); secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(LocalSessionMessageCounter::kInitialSyncValue); @@ -102,22 +101,8 @@ CHIP_ERROR PairingSession::DecodeMRPParametersIfPresent(TLV::Tag expectedTag, TL void PairingSession::Clear() { - if (mSessionManager != nullptr) - { - if (mSecureSessionHolder && !mSecureSessionHolder->AsSecureSession()->IsActiveSession()) - { - // Make sure to clean up our pending session, since we're the only - // ones who have access to it do do so. - mSessionManager->ExpirePairing(mSecureSessionHolder.Get()); - } - } - mPeerSessionId.ClearValue(); - // If we called ExpirePairing above, the holder has already released the - // session (due to it being destroyed). If not, we need to release it. - // Release is idempotent, so it's OK to just call it here. mSecureSessionHolder.Release(); - mSessionManager = nullptr; } } // namespace chip diff --git a/src/transport/PairingSession.h b/src/transport/PairingSession.h index c2d7576c20ada9..4d9b23d1f18209 100644 --- a/src/transport/PairingSession.h +++ b/src/transport/PairingSession.h @@ -37,9 +37,10 @@ namespace chip { class SessionManager; -class DLL_EXPORT PairingSession +class DLL_EXPORT PairingSession : public SessionReleaseDelegate { public: + PairingSession() : mSecureSessionHolder(*this) {} virtual ~PairingSession() { Clear(); } virtual Transport::SecureSession::Type GetSecureSessionType() const = 0; @@ -164,10 +165,7 @@ class DLL_EXPORT PairingSession protected: CryptoContext::SessionRole mRole; - SessionHolder mSecureSessionHolder; - // mSessionManager is set if we actually allocate a secure session, so we - // can clean it up later as needed. - SessionManager * mSessionManager = nullptr; + SessionHolderWithDelegate mSecureSessionHolder; // mLocalMRPConfig is our config which is sent to the other end and used by the peer session. // mRemoteMRPConfig is received from other end and set to our session. diff --git a/src/transport/SecureSession.cpp b/src/transport/SecureSession.cpp index 2a55e03e5bbd9a..34bd00c44d21dc 100644 --- a/src/transport/SecureSession.cpp +++ b/src/transport/SecureSession.cpp @@ -16,10 +16,16 @@ #include #include +#include namespace chip { namespace Transport { +void SecureSessionDeleter::Release(SecureSession * entry) +{ + entry->mTable.ReleaseSession(entry); +} + ScopedNodeId SecureSession::GetPeer() const { return ScopedNodeId(mPeerNodeId, GetFabricIndex()); diff --git a/src/transport/SecureSession.h b/src/transport/SecureSession.h index 49fea973b87127..ef89689ddb381a 100644 --- a/src/transport/SecureSession.h +++ b/src/transport/SecureSession.h @@ -22,19 +22,22 @@ #pragma once #include -#include +#include #include #include #include #include -#include -#include #include namespace chip { namespace Transport { -static constexpr uint32_t kUndefinedMessageIndex = UINT32_MAX; +class SecureSessionTable; +class SecureSessionDeleter +{ +public: + static void Release(SecureSession * entry); +}; /** * Defines state of a peer connection at a transport layer. @@ -49,7 +52,7 @@ static constexpr uint32_t kUndefinedMessageIndex = UINT32_MAX; * last used. Inactive connections can expire. * - CryptoContext contains the encryption context of a connection */ -class SecureSession : public Session +class SecureSession : public Session, public ReferenceCounted { public: /** @@ -60,25 +63,17 @@ class SecureSession : public Session { kPASE = 1, kCASE = 2, - // kPending denotes a secure session object that is internally - // reserved by the stack before and during session establishment. - // - // Although the stack can tolerate eviction of these (releasing one - // out from under the holder would exhibit as CHIP_ERROR_INCORRECT_STATE - // during CASE or PASE), intent is that we should not and would leave - // these untouched until CASE or PASE complete. - kPending = 3, }; - // TODO: This constructor should be private. Tests should allocate a - // kPending session and then call Activate(), just like non-test code does. - SecureSession(Type secureSessionType, uint16_t localSessionId, NodeId peerNodeId, CATValues peerCATs, uint16_t peerSessionId, + // Test-only: inject a session in Active state. + SecureSession(SecureSessionTable & table, Type secureSessionType, uint16_t localSessionId, NodeId peerNodeId, CATValues peerCATs, uint16_t peerSessionId, FabricIndex fabric, const ReliableMessageProtocolConfig & config) : - mSecureSessionType(secureSessionType), + mTable(table), mState(State::kActive), mSecureSessionType(secureSessionType), mPeerNodeId(peerNodeId), mPeerCATs(peerCATs), mLocalSessionId(localSessionId), mPeerSessionId(peerSessionId), mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mLastPeerActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config) { + Retain(); // Put the test session in Active state SetFabricIndex(fabric); } @@ -88,8 +83,8 @@ class SecureSession : public Session * session establishment attempt. The object for the pending session * receives a local session ID, but no other state. */ - SecureSession(uint16_t localSessionId) : - SecureSession(Type::kPending, localSessionId, kUndefinedNodeId, CATValues{}, 0, kUndefinedFabricIndex, GetLocalMRPConfig()) + SecureSession(SecureSessionTable & table, Type secureSessionType, uint16_t localSessionId) : + mTable(table), mState(State::kPairing), mSecureSessionType(secureSessionType), mLocalSessionId(localSessionId) {} /** @@ -98,23 +93,48 @@ class SecureSession : public Session * PASE, setting internal state according to the parameters used and * discovered during session establishment. */ - void Activate(Type secureSessionType, const ScopedNodeId & peer, CATValues peerCATs, uint16_t peerSessionId, - const ReliableMessageProtocolConfig & config) + void Activate(const ScopedNodeId & peer, CATValues peerCATs, uint16_t peerSessionId, const ReliableMessageProtocolConfig & config) { - mSecureSessionType = secureSessionType; + VerifyOrDie(mState == State::kPairing); + Retain(); mPeerNodeId = peer.GetNodeId(); mPeerCATs = peerCATs; mPeerSessionId = peerSessionId; mMRPConfig = config; SetFabricIndex(peer.GetFabricIndex()); } - ~SecureSession() override { NotifySessionReleased(); } + ~SecureSession() override {} SecureSession(SecureSession &&) = delete; SecureSession(const SecureSession &) = delete; SecureSession & operator=(const SecureSession &) = delete; SecureSession & operator=(SecureSession &&) = delete; + void Retain() override { ReferenceCounted::Retain(); } + void Release() override { ReferenceCounted::Release(); } + + bool IsPendingRemoval() const override { return mState == State::kPendingRemoval; } + /// @brief Mark as pending removal, all holders to this session will be cleared, and disallow future grab + void MarkForRemoval() + { + switch (mState) + { + case State::kPairing: + mState = State::kPendingRemoval; + // Interrupt the pairing + NotifySessionReleased(); + return; + case State::kActive: + Release(); // Decrease the ref which is retained at Activate + mState = State::kPendingRemoval; + NotifySessionReleased(); + return; + case State::kPendingRemoval: + // Do nothing + return; + } + } + Session::SessionType GetSessionType() const override { return Session::SessionType::kSecure; } #if CHIP_PROGRESS_LOGGING const char * GetSessionTypeString() const override { return "secure"; }; @@ -145,7 +165,7 @@ class SecureSession : public Session Type GetSecureSessionType() const { return mSecureSessionType; } bool IsCASESession() const { return GetSecureSessionType() == Type::kCASE; } bool IsPASESession() const { return GetSecureSessionType() == Type::kPASE; } - bool IsActiveSession() const { return GetSecureSessionType() != Type::kPending; } + bool IsActiveSession() const { return mState == State::kActive; } NodeId GetPeerNodeId() const { return mPeerNodeId; } CATValues GetPeerCATs() const { return mPeerCATs; } @@ -190,16 +210,44 @@ class SecureSession : public Session SessionMessageCounter & GetSessionMessageCounter() { return mSessionMessageCounter; } private: - Type mSecureSessionType; - NodeId mPeerNodeId; - CATValues mPeerCATs; + enum class State : uint8_t + { + // kPending denotes a secure session object that is internally + // reserved by the stack before and during session establishment. + // + // Although the stack can tolerate eviction of these (releasing one + // out from under the holder would exhibit as CHIP_ERROR_INCORRECT_STATE + // during CASE or PASE), intent is that we should not and would leave + // these untouched until CASE or PASE complete. + // + // During this stage, the reference counter is hold by the PairingSession + kPairing = 1, + + // The session is active, ready for use. During this stage, the + // reference counter increased by 1 in Activate, and will be decreased + // by 1 when MarkForRemoval is called. + kActive = 2, + + // The session is pending for removal, all SessionHolders are already + // cleared during MarkForRemoval, no future SessionHolder is able grab + // this session, when all SessionHandles goes out of scope, the session + // object will be released automatically. + kPendingRemoval = 3, + }; + + friend class SecureSessionDeleter; + SecureSessionTable & mTable; + State mState; + const Type mSecureSessionType; + NodeId mPeerNodeId = kUndefinedNodeId; + CATValues mPeerCATs = CATValues{}; const uint16_t mLocalSessionId; - uint16_t mPeerSessionId; + uint16_t mPeerSessionId = 0; PeerAddress mPeerAddress; - System::Clock::Timestamp mLastActivityTime; ///< Timestamp of last tx or rx - System::Clock::Timestamp mLastPeerActivityTime; ///< Timestamp of last rx - ReliableMessageProtocolConfig mMRPConfig; + System::Clock::Timestamp mLastActivityTime = System::SystemClock().GetMonotonicTimestamp(); ///< Timestamp of last tx or rx + System::Clock::Timestamp mLastPeerActivityTime = System::SystemClock().GetMonotonicTimestamp(); ///< Timestamp of last rx + ReliableMessageProtocolConfig mMRPConfig = GetLocalMRPConfig(); CryptoContext mCryptoContext; SessionMessageCounter mSessionMessageCounter; }; diff --git a/src/transport/SecureSessionTable.h b/src/transport/SecureSessionTable.h index 1804ea412c88d9..0f95fc0871b37c 100644 --- a/src/transport/SecureSessionTable.h +++ b/src/transport/SecureSessionTable.h @@ -35,7 +35,6 @@ constexpr uint16_t kUnsecuredSessionId = 0; * - handle session active time and expiration * - allocate and free space for sessions. */ -template class SecureSessionTable { public: @@ -65,37 +64,10 @@ class SecureSessionTable FabricIndex fabric, const ReliableMessageProtocolConfig & config) { SecureSession * result = - mEntries.CreateObject(secureSessionType, localSessionId, peerNodeId, peerCATs, peerSessionId, fabric, config); + mEntries.CreateObject(*this, secureSessionType, localSessionId, peerNodeId, peerCATs, peerSessionId, fabric, config); return result != nullptr ? MakeOptional(*result) : Optional::Missing(); } - /** - * Allocate a new secure session out of the internal resource pool with the - * specified session ID. The returned secure session will not become active - * until the call to SecureSession::Activate. If there is a resident - * session at the passed ID, an empty Optional will be returned to signal - * the error. - * - * This variant of the interface is primarily useful in testing, where - * session IDs may need to be predetermined. - * - * @param localSessionId unique identifier for the local node's secure unicast session context - * @returns allocated session, or NullOptional on failure - */ - CHECK_RETURN_VALUE - Optional CreateNewSecureSession(uint16_t localSessionId) - { - Optional rv = Optional::Missing(); - SecureSession * allocated = nullptr; - VerifyOrExit(localSessionId != kUnsecuredSessionId, rv = NullOptional); - VerifyOrExit(!FindSecureSessionByLocalKey(localSessionId).HasValue(), rv = NullOptional); - allocated = mEntries.CreateObject(localSessionId); - VerifyOrExit(allocated != nullptr, rv = Optional::Missing()); - rv = MakeOptional(*allocated); - exit: - return rv; - } - /** * Allocate a new secure session out of the internal resource pool with a * non-colliding session ID and increments mNextSessionId to give a clue to @@ -105,13 +77,13 @@ class SecureSessionTable * @returns allocated session, or NullOptional on failure */ CHECK_RETURN_VALUE - Optional CreateNewSecureSession() + Optional CreateNewSecureSession(SecureSession::Type secureSessionType) { Optional rv = Optional::Missing(); auto sessionId = FindUnusedSessionId(); SecureSession * allocated = nullptr; VerifyOrExit(sessionId.HasValue(), rv = Optional::Missing()); - allocated = mEntries.CreateObject(sessionId.Value()); + allocated = mEntries.CreateObject(*this, secureSessionType, sessionId.Value()); VerifyOrExit(allocated != nullptr, rv = Optional::Missing()); rv = MakeOptional(*allocated); mNextSessionId = sessionId.Value() == kMaxSessionID ? static_cast(kUnsecuredSessionId + 1) @@ -150,26 +122,6 @@ class SecureSessionTable return result != nullptr ? MakeOptional(*result) : Optional::Missing(); } - /** - * Iterates through all active sessions and expires any sessions with an idle time - * larger than the given amount. - * - * Expiring a session involves callback execution and then clearing the internal state. - */ - template - void ExpireInactiveSessions(System::Clock::Timestamp maxIdleTime, Callback callback) - { - mEntries.ForEachActiveObject([&](auto session) { - if (session->GetSecureSessionType() != SecureSession::Type::kPending && - session->GetLastActivityTime() + maxIdleTime < System::SystemClock().GetMonotonicTimestamp()) - { - callback(*session); - ReleaseSession(session); - } - return Loop::Continue; - }); - } - private: /** * Find an available session ID that is unused in the secure session table. @@ -179,7 +131,7 @@ class SecureSessionTable * from the starting mNextSessionId clue. * * The outer-loop considers 64 session IDs in each iteration to give a - * runtime complexity of O(kMaxSessionCount^2/64). Speed up could be + * runtime complexity of O(CHIP_CONFIG_PEER_CONNECTION_POOL_SIZE^2/64). Speed up could be * achieved with a sorted session table or additional storage. * * @return an unused session ID if any is found, else NullOptional @@ -237,7 +189,7 @@ class SecureSessionTable return NullOptional; } - BitMapObjectPool mEntries; + BitMapObjectPool mEntries; uint16_t mNextSessionId = 0; }; diff --git a/src/transport/Session.h b/src/transport/Session.h index b51c78964cf3d2..df29c27b567cd9 100644 --- a/src/transport/Session.h +++ b/src/transport/Session.h @@ -65,9 +65,10 @@ class Session mHolders.Remove(&holder); } - // For types of sessions using reference counter, override these functions, otherwise leave it empty. - virtual void Retain() {} - virtual void Release() {} + virtual void Retain() = 0; + virtual void Release() = 0; + + virtual bool IsPendingRemoval() const { return false; } virtual ScopedNodeId GetPeer() const = 0; virtual Access::SubjectDescriptor GetSubjectDescriptor() const = 0; diff --git a/src/transport/SessionHandle.h b/src/transport/SessionHandle.h index 344dab3580b0ec..6d20a3713cc36f 100644 --- a/src/transport/SessionHandle.h +++ b/src/transport/SessionHandle.h @@ -26,12 +26,12 @@ namespace Transport { class Session; } // namespace Transport -class SessionHolder; - /** @brief - * Non-copyable session reference. All SessionHandles are created within SessionManager. SessionHandle is not - * reference *counted, hence it is not allowed to store SessionHandle anywhere except for function arguments and - * return values. SessionHandle is short-lived as it is only available as stack variable, so it is never dangling. */ + * Non-copyable session reference. All SessionHandles are created within SessionManager. It is not allowed to store SessionHandle anywhere except for function arguments and + * return values. + * + * SessionHandle is reference counted such that it never dangling, but there can be a gray period when the session is mark as pending removal but not actually removed yet. During this period, the handle is functional, but the underlying session won't be able to grabbed by any SessionHolder. SessionHandle->IsPendingRemoval can be used to check if the session is pending removal. + */ class SessionHandle { public: diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index 5acaccf3cc8dd4..c1febe5863cbb0 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -346,7 +346,7 @@ CHIP_ERROR SessionManager::SendPreparedMessage(const SessionHandle & sessionHand void SessionManager::ExpirePairing(const SessionHandle & sessionHandle) { - mSecureSessions.ReleaseSession(sessionHandle->AsSecureSession()); + sessionHandle->AsSecureSession()->MarkForRemoval(); } void SessionManager::ExpireAllPairings(const ScopedNodeId & node) @@ -354,7 +354,7 @@ void SessionManager::ExpireAllPairings(const ScopedNodeId & node) mSecureSessions.ForEachSession([&](auto session) { if (session->GetPeer() == node) { - mSecureSessions.ReleaseSession(session); + session->MarkForRemoval(); } return Loop::Continue; }); @@ -366,7 +366,7 @@ void SessionManager::ExpireAllPairingsForFabric(FabricIndex fabric) mSecureSessions.ForEachSession([&](auto session) { if (session->GetFabricIndex() == fabric) { - mSecureSessions.ReleaseSession(session); + session->MarkForRemoval(); } return Loop::Continue; }); @@ -378,15 +378,15 @@ void SessionManager::ExpireAllPASEPairings() mSecureSessions.ForEachSession([&](auto session) { if (session->GetSecureSessionType() == Transport::SecureSession::Type::kPASE) { - mSecureSessions.ReleaseSession(session); + session->MarkForRemoval(); } return Loop::Continue; }); } -Optional SessionManager::AllocateSession() +Optional SessionManager::AllocateSession(SecureSession::Type secureSessionType) { - return mSecureSessions.CreateNewSecureSession(); + return mSecureSessions.CreateNewSecureSession(secureSessionType); } CHIP_ERROR SessionManager::InjectPaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId, NodeId peerNodeId, @@ -771,12 +771,6 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & packetHeade void SessionManager::ExpiryTimerCallback(System::Layer * layer, void * param) { SessionManager * mgr = reinterpret_cast(param); -#if CHIP_CONFIG_SESSION_REKEYING - // TODO(#14217): session expiration is currently disabled until rekeying is supported - // the #ifdef should be removed after that. - mgr->mSecureSessions.ExpireInactiveSessions(System::SystemClock().GetMonotonicTimestamp(), - System::Clock::Milliseconds32(CHIP_PEER_CONNECTION_TIMEOUT_MS)); -#endif mgr->ScheduleExpiryTimer(); // re-schedule the oneshot timer } diff --git a/src/transport/SessionManager.h b/src/transport/SessionManager.h index 0a5293dbe143e6..5b044b510aeef7 100644 --- a/src/transport/SessionManager.h +++ b/src/transport/SessionManager.h @@ -53,8 +53,6 @@ namespace chip { -class PairingSession; - /** * @brief * Tracks ownership of a encrypted packet buffer. @@ -169,7 +167,7 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate * @return SessionHandle with a reference to a SecureSession, else NullOptional on failure */ CHECK_RETURN_VALUE - Optional AllocateSession(); + Optional AllocateSession(Transport::SecureSession::Type secureSessionType); void ExpirePairing(const SessionHandle & session); void ExpireAllPairings(const ScopedNodeId & node); @@ -262,7 +260,7 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate System::Layer * mSystemLayer = nullptr; FabricTable * mFabricTable = nullptr; Transport::UnauthenticatedSessionTable mUnauthenticatedSessions; - Transport::SecureSessionTable mSecureSessions; + Transport::SecureSessionTable mSecureSessions; State mState; // < Initialization state of the object chip::Transport::GroupOutgoingCounters mGroupClientCounter; @@ -276,8 +274,6 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate GlobalUnencryptedMessageCounter mGlobalUnencryptedMessageCounter; - friend class SessionHandle; - /** Schedules a new oneshot timer for checking connection expiry. */ void ScheduleExpiryTimer(); diff --git a/src/transport/UnauthenticatedSessionTable.h b/src/transport/UnauthenticatedSessionTable.h index 15a118d47ccdd1..517c7320300154 100644 --- a/src/transport/UnauthenticatedSessionTable.h +++ b/src/transport/UnauthenticatedSessionTable.h @@ -20,11 +20,8 @@ #include #include #include -#include -#include #include #include -#include #include #include #include @@ -32,18 +29,11 @@ namespace chip { namespace Transport { -class UnauthenticatedSessionDeleter -{ -public: - // This is a no-op because life-cycle of UnauthenticatedSessionTable is rotated by LRU - static void Release(UnauthenticatedSession * entry) {} -}; - /** * @brief * An UnauthenticatedSession stores the binding of TransportAddress, and message counters. */ -class UnauthenticatedSession : public Session, public ReferenceCounted +class UnauthenticatedSession : public Session, public ReferenceCounted, 0> { public: enum class SessionRole @@ -58,7 +48,7 @@ class UnauthenticatedSession : public Session, public ReferenceCounted::Retain(); } - void Release() override { ReferenceCounted::Release(); } + void Retain() override { ReferenceCounted, 0>::Retain(); } + void Release() override { ReferenceCounted, 0>::Release(); } ScopedNodeId GetPeer() const override { return ScopedNodeId(kUndefinedNodeId, GetFabricIndex()); } diff --git a/src/transport/tests/TestPairingSession.cpp b/src/transport/tests/TestPairingSession.cpp index 6ffd2037f02f06..60cc83694ced0a 100644 --- a/src/transport/tests/TestPairingSession.cpp +++ b/src/transport/tests/TestPairingSession.cpp @@ -45,6 +45,8 @@ class TestPairingSession : public PairingSession ScopedNodeId GetPeer() const override { return ScopedNodeId(); } CATValues GetPeerCATs() const override { return CATValues(); }; + void OnSessionReleased() override {} + const ReliableMessageProtocolConfig & GetRemoteMRPConfig() const { return mRemoteMRPConfig; } CHIP_ERROR DeriveSecureSession(CryptoContext & session) const override { return CHIP_NO_ERROR; } diff --git a/src/transport/tests/TestPeerConnections.cpp b/src/transport/tests/TestPeerConnections.cpp index 8aea395d003414..7cd0f489ca1da8 100644 --- a/src/transport/tests/TestPeerConnections.cpp +++ b/src/transport/tests/TestPeerConnections.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include @@ -46,58 +47,73 @@ PeerAddress AddressFromString(const char * str) const PeerAddress kPeer1Addr = AddressFromString("fe80::1"); const PeerAddress kPeer2Addr = AddressFromString("fe80::2"); -const PeerAddress kPeer3Addr = AddressFromString("fe80::3"); const NodeId kPeer1NodeId = 123; const NodeId kPeer2NodeId = 6; -const NodeId kPeer3NodeId = 81; const SecureSession::Type kPeer1SessionType = SecureSession::Type::kCASE; const SecureSession::Type kPeer2SessionType = SecureSession::Type::kCASE; -const SecureSession::Type kPeer3SessionType = SecureSession::Type::kPASE; const CATValues kPeer1CATs = { { 0xABCD0001, 0xABCE0100, 0xABCD0020 } }; const CATValues kPeer2CATs = { { 0xABCD0012, kUndefinedCAT, kUndefinedCAT } }; + +#if !CHIP_SYSTEM_CONFIG_POOL_USE_HEAP +const PeerAddress kPeer3Addr = AddressFromString("fe80::3"); +const NodeId kPeer3NodeId = 81; +const SecureSession::Type kPeer3SessionType = SecureSession::Type::kPASE; const CATValues kPeer3CATs; +#endif void TestBasicFunctionality(nlTestSuite * inSuite, void * inContext) { - SecureSessionTable<2> connections; + SecureSessionTable connections; System::Clock::Internal::MockClock clock; System::Clock::ClockBase * realClock = &System::SystemClock(); System::Clock::Internal::SetSystemClockForTesting(&clock); clock.SetMonotonic(100_ms64); CATValues peerCATs; + Optional sessions[CHIP_CONFIG_PEER_CONNECTION_POOL_SIZE]; // Node ID 1, peer key 1, local key 2 - auto optionalSession = connections.CreateNewSecureSessionForTest(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, + sessions[0] = connections.CreateNewSecureSessionForTest(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, 0 /* fabricIndex */, GetLocalMRPConfig()); - NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); - NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetSecureSessionType() == kPeer1SessionType); - NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetPeerNodeId() == kPeer1NodeId); - peerCATs = optionalSession.Value()->AsSecureSession()->GetPeerCATs(); + NL_TEST_ASSERT(inSuite, sessions[0].HasValue()); + NL_TEST_ASSERT(inSuite, sessions[0].Value()->AsSecureSession()->GetSecureSessionType() == kPeer1SessionType); + NL_TEST_ASSERT(inSuite, sessions[0].Value()->AsSecureSession()->GetPeerNodeId() == kPeer1NodeId); + peerCATs = sessions[0].Value()->AsSecureSession()->GetPeerCATs(); NL_TEST_ASSERT(inSuite, memcmp(&peerCATs, &kPeer1CATs, sizeof(CATValues)) == 0); // Node ID 2, peer key 3, local key 4 - optionalSession = connections.CreateNewSecureSessionForTest(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, + sessions[1] = connections.CreateNewSecureSessionForTest(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, 0 /* fabricIndex */, GetLocalMRPConfig()); - NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); - NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetSecureSessionType() == kPeer2SessionType); - NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetPeerNodeId() == kPeer2NodeId); - NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetLastActivityTime() == 100_ms64); - peerCATs = optionalSession.Value()->AsSecureSession()->GetPeerCATs(); + NL_TEST_ASSERT(inSuite, sessions[1].HasValue()); + NL_TEST_ASSERT(inSuite, sessions[1].Value()->AsSecureSession()->GetSecureSessionType() == kPeer2SessionType); + NL_TEST_ASSERT(inSuite, sessions[1].Value()->AsSecureSession()->GetPeerNodeId() == kPeer2NodeId); + NL_TEST_ASSERT(inSuite, sessions[1].Value()->AsSecureSession()->GetLastActivityTime() == 100_ms64); + peerCATs = sessions[1].Value()->AsSecureSession()->GetPeerCATs(); NL_TEST_ASSERT(inSuite, memcmp(&peerCATs, &kPeer2CATs, sizeof(CATValues)) == 0); - // Insufficient space for new connections. Object is max size 2 - optionalSession = connections.CreateNewSecureSessionForTest(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, +#if !CHIP_SYSTEM_CONFIG_POOL_USE_HEAP + // If not using a heap, we can fill the SecureSessionTable + for (uint16_t i = 2; i < CHIP_CONFIG_PEER_CONNECTION_POOL_SIZE; ++i) + { + sessions[i] = connections.CreateNewSecureSessionForTest(kPeer2SessionType, i + 6, kPeer2NodeId, kPeer2CATs, 3, + 0 /* fabricIndex */, GetLocalMRPConfig()); + NL_TEST_ASSERT(inSuite, sessions[i].HasValue()); + } + + // Insufficient space for new connections. + auto optionalSession = connections.CreateNewSecureSessionForTest(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, 0 /* fabricIndex */, GetLocalMRPConfig()); NL_TEST_ASSERT(inSuite, !optionalSession.HasValue()); +#endif System::Clock::Internal::SetSystemClockForTesting(realClock); } void TestFindByKeyId(nlTestSuite * inSuite, void * inContext) { - SecureSessionTable<2> connections; + SessionManager sessionManager; + SecureSessionTable connections; System::Clock::Internal::MockClock clock; System::Clock::ClockBase * realClock = &System::SystemClock(); System::Clock::Internal::SetSystemClockForTesting(&clock); @@ -128,106 +144,6 @@ struct ExpiredCallInfo PeerAddress lastCallPeerAddress = PeerAddress::Uninitialized(); }; -void TestExpireConnections(nlTestSuite * inSuite, void * inContext) -{ - ExpiredCallInfo callInfo; - SecureSessionTable<2> connections; - - System::Clock::Internal::MockClock clock; - System::Clock::ClockBase * realClock = &System::SystemClock(); - System::Clock::Internal::SetSystemClockForTesting(&clock); - - clock.SetMonotonic(100_ms64); - - // Node ID 1, peer key 1, local key 2 - auto optionalSession = connections.CreateNewSecureSessionForTest(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, - 0 /* fabricIndex */, GetLocalMRPConfig()); - NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); - optionalSession.Value()->AsSecureSession()->SetPeerAddress(kPeer1Addr); - - clock.SetMonotonic(200_ms64); - // Node ID 2, peer key 3, local key 4 - optionalSession = connections.CreateNewSecureSessionForTest(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, - 0 /* fabricIndex */, GetLocalMRPConfig()); - NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); - optionalSession.Value()->AsSecureSession()->SetPeerAddress(kPeer2Addr); - - // cannot add before expiry - clock.SetMonotonic(300_ms64); - optionalSession = connections.CreateNewSecureSessionForTest(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, - 0 /* fabricIndex */, GetLocalMRPConfig()); - NL_TEST_ASSERT(inSuite, !optionalSession.HasValue()); - - // at time 300, this expires ip addr 1 - connections.ExpireInactiveSessions(150_ms64, [&callInfo](const SecureSession & state) { - callInfo.callCount++; - callInfo.lastCallNodeId = state.GetPeerNodeId(); - callInfo.lastCallPeerAddress = state.GetPeerAddress(); - }); - NL_TEST_ASSERT(inSuite, callInfo.callCount == 1); - NL_TEST_ASSERT(inSuite, callInfo.lastCallNodeId == kPeer1NodeId); - NL_TEST_ASSERT(inSuite, callInfo.lastCallPeerAddress == kPeer1Addr); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2).HasValue()); - - // now that the connections were expired, we can add peer3 - clock.SetMonotonic(300_ms64); - // Node ID 3, peer key 5, local key 6 - optionalSession = connections.CreateNewSecureSessionForTest(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, - 0 /* fabricIndex */, GetLocalMRPConfig()); - NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); - optionalSession.Value()->AsSecureSession()->SetPeerAddress(kPeer3Addr); - - clock.SetMonotonic(400_ms64); - optionalSession = connections.FindSecureSessionByLocalKey(4); - NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); - - optionalSession.Value()->AsSecureSession()->MarkActive(); - NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetLastActivityTime() == clock.GetMonotonicTimestamp()); - - // At this time: - // Peer 3 active at time 300 - // Peer 2 active at time 400 - - clock.SetMonotonic(500_ms64); - callInfo.callCount = 0; - connections.ExpireInactiveSessions(150_ms64, [&callInfo](const SecureSession & state) { - callInfo.callCount++; - callInfo.lastCallNodeId = state.GetPeerNodeId(); - callInfo.lastCallPeerAddress = state.GetPeerAddress(); - }); - - // peer 2 stays active - NL_TEST_ASSERT(inSuite, callInfo.callCount == 1); - NL_TEST_ASSERT(inSuite, callInfo.lastCallNodeId == kPeer3NodeId); - NL_TEST_ASSERT(inSuite, callInfo.lastCallPeerAddress == kPeer3Addr); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2).HasValue()); - NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4).HasValue()); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6).HasValue()); - - // Node ID 1, peer key 1, local key 2 - optionalSession = connections.CreateNewSecureSessionForTest(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, - 0 /* fabricIndex */, GetLocalMRPConfig()); - NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); - NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(2).HasValue()); - NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4).HasValue()); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6).HasValue()); - - // peer 1 and 2 are active - clock.SetMonotonic(1000_ms64); - callInfo.callCount = 0; - connections.ExpireInactiveSessions(100_ms64, [&callInfo](const SecureSession & state) { - callInfo.callCount++; - callInfo.lastCallNodeId = state.GetPeerNodeId(); - callInfo.lastCallPeerAddress = state.GetPeerAddress(); - }); - NL_TEST_ASSERT(inSuite, callInfo.callCount == 2); // everything expired - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2).HasValue()); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(4).HasValue()); - NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6).HasValue()); - - System::Clock::Internal::SetSystemClockForTesting(realClock); -} - } // namespace // clang-format off @@ -235,7 +151,6 @@ static const nlTest sTests[] = { NL_TEST_DEF("BasicFunctionality", TestBasicFunctionality), NL_TEST_DEF("FindByKeyId", TestFindByKeyId), - NL_TEST_DEF("ExpireConnections", TestExpireConnections), NL_TEST_SENTINEL() }; // clang-format on diff --git a/src/transport/tests/TestSessionManager.cpp b/src/transport/tests/TestSessionManager.cpp index 9f5cb44ea0bd80..c4bc7a595f06e8 100644 --- a/src/transport/tests/TestSessionManager.cpp +++ b/src/transport/tests/TestSessionManager.cpp @@ -702,7 +702,7 @@ static void RandomSessionIdAllocatorOffset(nlTestSuite * inSuite, SessionManager const int bound = rand() % max; for (int i = 0; i < bound; ++i) { - auto handle = sessionManager.AllocateSession(); + auto handle = sessionManager.AllocateSession(Transport::SecureSession::Type::kPASE); NL_TEST_ASSERT(inSuite, handle.HasValue()); sessionManager.ExpirePairing(handle.Value()); } @@ -716,7 +716,7 @@ void SessionAllocationTest(nlTestSuite * inSuite, void * inContext) // Allocate a session. uint16_t sessionId1; { - auto handle = sessionManager.AllocateSession(); + auto handle = sessionManager.AllocateSession(Transport::SecureSession::Type::kPASE); NL_TEST_ASSERT(inSuite, handle.HasValue()); SessionHolderWithDelegate session(handle.Value(), callback); sessionId1 = session->AsSecureSession()->GetLocalSessionId(); @@ -727,7 +727,7 @@ void SessionAllocationTest(nlTestSuite * inSuite, void * inContext) auto prevSessionId = sessionId1; for (uint32_t i = 0; i < 10; ++i) { - auto handle = sessionManager.AllocateSession(); + auto handle = sessionManager.AllocateSession(Transport::SecureSession::Type::kPASE); if (!handle.HasValue()) { break; @@ -748,7 +748,7 @@ void SessionAllocationTest(nlTestSuite * inSuite, void * inContext) // sessions are immediately freed. for (uint32_t i = 0; i < UINT16_MAX + 10; ++i) { - auto handle = sessionManager.AllocateSession(); + auto handle = sessionManager.AllocateSession(Transport::SecureSession::Type::kPASE); NL_TEST_ASSERT(inSuite, handle.HasValue()); auto sessionId = handle.Value()->AsSecureSession()->GetLocalSessionId(); NL_TEST_ASSERT(inSuite, sessionId - prevSessionId == 1 || (sessionId == 1 && prevSessionId == 65535)); @@ -769,7 +769,7 @@ void SessionAllocationTest(nlTestSuite * inSuite, void * inContext) for (size_t h = 0; h < numHandles; ++h) { constexpr int maxOffset = 5000; - handles[h] = sessionManager.AllocateSession(); + handles[h] = sessionManager.AllocateSession(Transport::SecureSession::Type::kPASE); NL_TEST_ASSERT(inSuite, handles[h].HasValue()); sessionIds[h] = handles[h].Value()->AsSecureSession()->GetLocalSessionId(); RandomSessionIdAllocatorOffset(inSuite, sessionManager, maxOffset); @@ -785,7 +785,7 @@ void SessionAllocationTest(nlTestSuite * inSuite, void * inContext) // these collide either. for (int j = 0; j < UINT16_MAX; ++j) { - auto handle = sessionManager.AllocateSession(); + auto handle = sessionManager.AllocateSession(Transport::SecureSession::Type::kPASE); NL_TEST_ASSERT(inSuite, handle.HasValue()); auto potentialCollision = handle.Value()->AsSecureSession()->GetLocalSessionId(); for (size_t h = 0; h < numHandles; ++h)