diff --git a/src/app/tests/TestCommandInteraction.cpp b/src/app/tests/TestCommandInteraction.cpp index 8b64654e94c84c..181f9d71e8df26 100644 --- a/src/app/tests/TestCommandInteraction.cpp +++ b/src/app/tests/TestCommandInteraction.cpp @@ -322,8 +322,9 @@ void TestCommandInteraction::TestCommandSenderWithSendCommand(nlTestSuite * apSu System::PacketBufferHandle buf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize); AddCommandDataIB(apSuite, apContext, &commandSender, false); - err = commandSender.SendCommandRequest(kTestDeviceNodeId, gFabricIndex, Optional::Missing()); - NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_NOT_CONNECTED); + err = + commandSender.SendCommandRequest(0 /* nodeid */, 0 /* fabricindex */, Optional(ctx.GetSessionBobToAlice())); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); GenerateReceivedCommand(apSuite, apContext, buf, true /*aNeedCommandData*/); err = commandSender.ProcessCommandMessage(std::move(buf), Command::CommandRoleId::SenderId); diff --git a/src/transport/SecureSessionTable.h b/src/transport/SecureSessionTable.h index 4bc1eea9042a4d..00b126b5fb5765 100644 --- a/src/transport/SecureSessionTable.h +++ b/src/transport/SecureSessionTable.h @@ -30,73 +30,31 @@ namespace Transport { constexpr const uint16_t kAnyKeyId = 0xffff; /** - * Handles a set of peer connection states. + * Handles a set of sessions. * * Intended for: - * - handle connection active time and expiration - * - allocate and free space for connection states. + * - handle session active time and expiration + * - allocate and free space for sessions. */ -template +template class SecureSessionTable { public: /** - * Allocates a new peer connection state state object out of the internal resource pool. + * Allocates a new secure session out of the internal resource pool. * - * @param address represents the connection state address - * @param state [out] will contain the connection state if one was available. May be null if no return value is desired. - * - * @note the newly created state will have an 'active' time set based on the current time source. - * - * @returns CHIP_NO_ERROR if state could be initialized. May fail if maximum connection count - * has been reached (with CHIP_ERROR_NO_MEMORY). - */ - CHECK_RETURN_VALUE - CHIP_ERROR CreateNewPeerConnectionState(const PeerAddress & address, SecureSession ** state) - { - CHIP_ERROR err = CHIP_ERROR_NO_MEMORY; - - if (state) - { - *state = nullptr; - } - - for (size_t i = 0; i < kMaxConnectionCount; i++) - { - if (!mStates[i].IsInitialized()) - { - mStates[i] = SecureSession(address); - mStates[i].SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); - - if (state) - { - *state = &mStates[i]; - } - - err = CHIP_NO_ERROR; - break; - } - } - - return err; - } - - /** - * Allocates a new peer connection state state object out of the internal resource pool. - * - * @param peerNode represents optional peer Node's ID + * @param peerNode represents peer Node's ID * @param peerSessionId represents the encryption key ID assigned by peer node * @param localSessionId represents the encryption key ID assigned by local node - * @param state [out] will contain the connection state if one was available. May be null if no return value is desired. + * @param state [out] will contain the session if one was available. May be null if no return value is desired. * * @note the newly created state will have an 'active' time set based on the current time source. * - * @returns CHIP_NO_ERROR if state could be initialized. May fail if maximum connection count + * @returns CHIP_NO_ERROR if state could be initialized. May fail if maximum session count * has been reached (with CHIP_ERROR_NO_MEMORY). */ CHECK_RETURN_VALUE - CHIP_ERROR CreateNewPeerConnectionState(const Optional & peerNode, uint16_t peerSessionId, uint16_t localSessionId, - SecureSession ** state) + CHIP_ERROR CreateNewSecureSession(NodeId peerNode, uint16_t peerSessionId, uint16_t localSessionId, SecureSession ** state) { CHIP_ERROR err = CHIP_ERROR_NO_MEMORY; @@ -105,20 +63,16 @@ class SecureSessionTable *state = nullptr; } - for (size_t i = 0; i < kMaxConnectionCount; i++) + for (size_t i = 0; i < kMaxSessionCount; i++) { if (!mStates[i].IsInitialized()) { mStates[i] = SecureSession(); + mStates[i].SetPeerNodeId(peerNode); mStates[i].SetPeerSessionId(peerSessionId); mStates[i].SetLocalSessionId(localSessionId); mStates[i].SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); - if (peerNode.ValueOr(kUndefinedNodeId) != kUndefinedNodeId) - { - mStates[i].SetPeerNodeId(peerNode.Value()); - } - if (state) { *state = &mStates[i]; @@ -133,56 +87,25 @@ class SecureSessionTable } /** - * Get a peer connection state given a Peer address. + * Get a secure session given a Node Id. * - * @param address is the connection to find (based on address) + * @param nodeId is the session to find (based on nodeId). * @param begin If a member of the pool, will start search from the next item. Can be nullptr to search from start. * * @return the state found, nullptr if not found */ CHECK_RETURN_VALUE - SecureSession * FindPeerConnectionState(const PeerAddress & address, SecureSession * begin) + SecureSession * FindSecureSession(NodeId nodeId, SecureSession * begin) { SecureSession * state = nullptr; SecureSession * iter = &mStates[0]; - if (begin >= iter && begin < &mStates[kMaxConnectionCount]) + if (begin >= iter && begin < &mStates[kMaxSessionCount]) { iter = begin + 1; } - for (; iter < &mStates[kMaxConnectionCount]; iter++) - { - if (iter->GetPeerAddress() == address) - { - state = iter; - break; - } - } - return state; - } - - /** - * Get a peer connection state given a Node Id. - * - * @param nodeId is the connection to find (based on nodeId). Note that initial connections - * do not have a node id set. Use this if you know the node id should be set. - * @param begin If a member of the pool, will start search from the next item. Can be nullptr to search from start. - * - * @return the state found, nullptr if not found - */ - CHECK_RETURN_VALUE - SecureSession * FindPeerConnectionState(NodeId nodeId, SecureSession * begin) - { - SecureSession * state = nullptr; - SecureSession * iter = &mStates[0]; - - if (begin >= iter && begin < &mStates[kMaxConnectionCount]) - { - iter = begin + 1; - } - - for (; iter < &mStates[kMaxConnectionCount]; iter++) + for (; iter < &mStates[kMaxSessionCount]; iter++) { if (!iter->IsInitialized()) { @@ -198,104 +121,25 @@ class SecureSessionTable } /** - * Get a peer connection state given a Node Id and Peer's Encryption Key Id. - * - * @param nodeId is the connection to find (based on nodeId). Note that initial connections - * do not have a node id set. Use this if you know the node id should be set. - * @param peerSessionId Encryption key ID used by the peer node. - * @param begin If a member of the pool, will start search from the next item. Can be nullptr to search from start. + * Get a secure session given a Node Id and Peer's Encryption Key Id. * - * @return the state found, nullptr if not found - */ - CHECK_RETURN_VALUE - SecureSession * FindPeerConnectionState(Optional nodeId, uint16_t peerSessionId, SecureSession * begin) - { - SecureSession * state = nullptr; - SecureSession * iter = &mStates[0]; - - if (begin >= iter && begin < &mStates[kMaxConnectionCount]) - { - iter = begin + 1; - } - - for (; iter < &mStates[kMaxConnectionCount]; iter++) - { - if (!iter->IsInitialized()) - { - continue; - } - if (peerSessionId == kAnyKeyId || iter->GetPeerSessionId() == peerSessionId) - { - if (nodeId.ValueOr(kUndefinedNodeId) == kUndefinedNodeId || iter->GetPeerNodeId() == kUndefinedNodeId || - iter->GetPeerNodeId() == nodeId.Value()) - { - state = iter; - break; - } - } - } - return state; - } - - /** - * Get a peer connection state given the local Encryption Key Id. - * - * @param keyId Encryption key ID assigned by the local node. - * @param begin If a member of the pool, will start search from the next item. Can be nullptr to search from start. - * - * @return the state found, nullptr if not found - */ - CHECK_RETURN_VALUE - SecureSession * FindPeerConnectionState(uint16_t keyId, SecureSession * begin) - { - SecureSession * state = nullptr; - SecureSession * iter = &mStates[0]; - - VerifyOrDie(begin == nullptr || (begin >= iter && begin < &mStates[kMaxConnectionCount])); - - if (begin != nullptr) - { - iter = begin + 1; - } - - for (; iter < &mStates[kMaxConnectionCount]; iter++) - { - if (!iter->IsInitialized()) - { - continue; - } - - if (iter->GetLocalSessionId() == keyId) - { - state = iter; - break; - } - } - return state; - } - - /** - * Get a peer connection state given a Node Id and Peer's Encryption Key Id. - * - * @param nodeId is the connection to find (based on peer nodeId). Note that initial connections - * do not have a node id set. Use this if you know the node id should be set. * @param localSessionId Encryption key ID used by the local node. * @param begin If a member of the pool, will start search from the next item. Can be nullptr to search from start. * * @return the state found, nullptr if not found */ CHECK_RETURN_VALUE - SecureSession * FindPeerConnectionStateByLocalKey(Optional nodeId, uint16_t localSessionId, SecureSession * begin) + SecureSession * FindSecureSessionByLocalKey(uint16_t localSessionId, SecureSession * begin) { SecureSession * state = nullptr; SecureSession * iter = &mStates[0]; - if (begin >= iter && begin < &mStates[kMaxConnectionCount]) + if (begin >= iter && begin < &mStates[kMaxSessionCount]) { iter = begin + 1; } - for (; iter < &mStates[kMaxConnectionCount]; iter++) + for (; iter < &mStates[kMaxSessionCount]; iter++) { if (!iter->IsInitialized()) { @@ -303,26 +147,22 @@ class SecureSessionTable } if (iter->GetLocalSessionId() == localSessionId) { - if (nodeId.ValueOr(kUndefinedNodeId) == kUndefinedNodeId || iter->GetPeerNodeId() == kUndefinedNodeId || - iter->GetPeerNodeId() == nodeId.Value()) - { - state = iter; - break; - } + state = iter; + break; } } return state; } /** - * Get the first peer connection state that matches the given fabric index. + * Get the first session that matches the given fabric index. * * @param fabric The fabric index to match * - * @return the state found, nullptr if not found + * @return the session found, nullptr if not found */ CHECK_RETURN_VALUE - SecureSession * FindPeerConnectionStateByFabric(FabricIndex fabric) + SecureSession * FindSecureSessionByFabric(FabricIndex fabric) { for (auto & state : mStates) { @@ -338,51 +178,51 @@ class SecureSessionTable return nullptr; } - /// Convenience method to mark a peer connection state as active - void MarkConnectionActive(SecureSession * state) { state->SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); } + /// Convenience method to mark a session as active + void MarkSessionActive(SecureSession * state) { state->SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); } - /// Convenience method to expired a peer connection state and fired the related callback + /// Convenience method to expired a session and fired the related callback template - void MarkConnectionExpired(SecureSession * state, Callback callback) + void MarkSessionExpired(SecureSession * state, Callback callback) { callback(*state); *state = SecureSession(PeerAddress::Uninitialized()); } /** - * Iterates through all active connections and expires any connection with an idle time + * Iterates through all active sessions and expires any sessions with an idle time * larger than the given amount. * - * Expiring a connection involves callback execution and then clearing the internal state. + * Expiring a session involves callback execution and then clearing the internal state. */ template - void ExpireInactiveConnections(uint64_t maxIdleTimeMs, Callback callback) + void ExpireInactiveSessions(uint64_t maxIdleTimeMs, Callback callback) { const uint64_t currentTime = mTimeSource.GetCurrentMonotonicTimeMs(); - for (size_t i = 0; i < kMaxConnectionCount; i++) + for (size_t i = 0; i < kMaxSessionCount; i++) { - if (!mStates[i].GetPeerAddress().IsInitialized()) + if (!mStates[i].IsInitialized()) { - continue; // not an active connection + continue; // not an active session } - uint64_t connectionActiveTime = mStates[i].GetLastActivityTimeMs(); - if (connectionActiveTime + maxIdleTimeMs >= currentTime) + uint64_t sessionActiveTime = mStates[i].GetLastActivityTimeMs(); + if (sessionActiveTime + maxIdleTimeMs >= currentTime) { continue; // not expired } - MarkConnectionExpired(&mStates[i], callback); + MarkSessionExpired(&mStates[i], callback); } } - /// Allows access to the underlying time source used for keeping track of connection active time + /// Allows access to the underlying time source used for keeping track of session active time Time::TimeSource & GetTimeSource() { return mTimeSource; } private: Time::TimeSource mTimeSource; - SecureSession mStates[kMaxConnectionCount]; + SecureSession mStates[kMaxSessionCount]; }; } // namespace Transport diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index 5881002bf3b1fd..fb6508c0c30d32 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -177,7 +177,7 @@ CHIP_ERROR SessionManager::SendPreparedMessage(SessionHandle session, const Encr } // This marks any connection where we send data to as 'active' - mPeerConnections.MarkConnectionActive(state); + mPeerConnections.MarkSessionActive(state); destination = &state->GetPeerAddress(); @@ -220,25 +220,25 @@ void SessionManager::ExpirePairing(SessionHandle session) SecureSession * state = GetSecureSession(session); if (state != nullptr) { - mPeerConnections.MarkConnectionExpired( - state, [this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); }); + mPeerConnections.MarkSessionExpired(state, + [this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); }); } } void SessionManager::ExpireAllPairings(NodeId peerNodeId, FabricIndex fabric) { - SecureSession * state = mPeerConnections.FindPeerConnectionState(peerNodeId, nullptr); + SecureSession * state = mPeerConnections.FindSecureSession(peerNodeId, nullptr); while (state != nullptr) { if (fabric == state->GetFabricIndex()) { - mPeerConnections.MarkConnectionExpired( + mPeerConnections.MarkSessionExpired( state, [this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); }); - state = mPeerConnections.FindPeerConnectionState(peerNodeId, nullptr); + state = mPeerConnections.FindSecureSession(peerNodeId, nullptr); } else { - state = mPeerConnections.FindPeerConnectionState(peerNodeId, state); + state = mPeerConnections.FindSecureSession(peerNodeId, state); } } } @@ -246,12 +246,12 @@ void SessionManager::ExpireAllPairings(NodeId peerNodeId, FabricIndex fabric) void SessionManager::ExpireAllPairingsForFabric(FabricIndex fabric) { ChipLogDetail(Inet, "Expiring all connections for fabric %d!!", fabric); - SecureSession * state = mPeerConnections.FindPeerConnectionStateByFabric(fabric); + SecureSession * state = mPeerConnections.FindSecureSessionByFabric(fabric); while (state != nullptr) { - mPeerConnections.MarkConnectionExpired( - state, [this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); }); - state = mPeerConnections.FindPeerConnectionStateByFabric(fabric); + mPeerConnections.MarkSessionExpired(state, + [this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); }); + state = mPeerConnections.FindSecureSessionByFabric(fabric); } } @@ -260,21 +260,19 @@ CHIP_ERROR SessionManager::NewPairing(const Optional & p { uint16_t peerSessionId = pairing->GetPeerSessionId(); uint16_t localSessionId = pairing->GetLocalSessionId(); - SecureSession * state = - mPeerConnections.FindPeerConnectionStateByLocalKey(Optional::Value(peerNodeId), localSessionId, nullptr); + SecureSession * state = mPeerConnections.FindSecureSessionByLocalKey(localSessionId, nullptr); // Find any existing connection with the same local key ID if (state) { - mPeerConnections.MarkConnectionExpired( - state, [this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); }); + mPeerConnections.MarkSessionExpired(state, + [this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); }); } ChipLogDetail(Inet, "New secure session created for device 0x" ChipLogFormatX64 ", key %d!!", ChipLogValueX64(peerNodeId), peerSessionId); state = nullptr; - ReturnErrorOnFailure( - mPeerConnections.CreateNewPeerConnectionState(Optional::Value(peerNodeId), peerSessionId, localSessionId, &state)); + ReturnErrorOnFailure(mPeerConnections.CreateNewSecureSession(peerNodeId, peerSessionId, localSessionId, &state)); ReturnErrorCodeIf(state == nullptr, CHIP_ERROR_NO_MEMORY); state->SetFabricIndex(fabric); @@ -392,7 +390,7 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & packetHea { CHIP_ERROR err = CHIP_NO_ERROR; - SecureSession * state = mPeerConnections.FindPeerConnectionState(packetHeader.GetSessionId(), nullptr); + SecureSession * state = mPeerConnections.FindSecureSessionByLocalKey(packetHeader.GetSessionId(), nullptr); PayloadHeader payloadHeader; @@ -454,7 +452,7 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & packetHea SuccessOrExit(err); } - mPeerConnections.MarkConnectionActive(state); + mPeerConnections.MarkSessionActive(state); if (isDuplicate == SessionManagerDelegate::DuplicateMessage::Yes && !payloadHeader.NeedsAck()) { @@ -576,7 +574,7 @@ void SessionManager::ExpiryTimerCallback(System::Layer * layer, void * param) #if CHIP_CONFIG_SESSION_REKEYING // TODO(#2279): session expiration is currently disabled until rekeying is supported // the #ifdef should be removed after that. - mgr->mPeerConnections.ExpireInactiveConnections( + mgr->mPeerConnections.ExpireInactiveSessions( CHIP_PEER_CONNECTION_TIMEOUT_MS, [this](const Transport::SecureSession & state1) { HandleConnectionExpired(state1); }); #endif mgr->ScheduleExpiryTimer(); // re-schedule the oneshot timer @@ -584,8 +582,14 @@ void SessionManager::ExpiryTimerCallback(System::Layer * layer, void * param) SecureSession * SessionManager::GetSecureSession(SessionHandle session) { - return mPeerConnections.FindPeerConnectionStateByLocalKey(Optional::Value(session.mPeerNodeId), - session.mLocalSessionId.ValueOr(0), nullptr); + if (session.mLocalSessionId.HasValue()) + { + return mPeerConnections.FindSecureSessionByLocalKey(session.mLocalSessionId.Value(), nullptr); + } + else + { + return nullptr; + } } } // namespace chip diff --git a/src/transport/UnauthenticatedSessionTable.h b/src/transport/UnauthenticatedSessionTable.h index 70b46f26bf56b2..eb063494e0dc05 100644 --- a/src/transport/UnauthenticatedSessionTable.h +++ b/src/transport/UnauthenticatedSessionTable.h @@ -78,7 +78,7 @@ class UnauthenticatedSession : public ReferenceCounted +template class UnauthenticatedSessionTable { public: @@ -111,14 +111,14 @@ class UnauthenticatedSessionTable session->SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); } - /// Allows access to the underlying time source used for keeping track of connection active time + /// Allows access to the underlying time source used for keeping track of session active time Time::TimeSource & GetTimeSource() { return mTimeSource; } private: /** * Allocates a new session out of the internal resource pool. * - * @returns CHIP_NO_ERROR if new session created. May fail if maximum connection count has been reached (with + * @returns CHIP_NO_ERROR if new session created. May fail if maximum session count has been reached (with * CHIP_ERROR_NO_MEMORY). */ CHECK_RETURN_VALUE @@ -198,7 +198,7 @@ class UnauthenticatedSessionTable } Time::TimeSource mTimeSource; - BitMapObjectPool mEntries; + BitMapObjectPool mEntries; }; } // namespace Transport diff --git a/src/transport/tests/TestPeerConnections.cpp b/src/transport/tests/TestPeerConnections.cpp index 60b94eba794c8e..55ee2b2665b5d3 100644 --- a/src/transport/tests/TestPeerConnections.cpp +++ b/src/transport/tests/TestPeerConnections.cpp @@ -58,92 +58,61 @@ void TestBasicFunctionality(nlTestSuite * inSuite, void * inContext) SecureSessionTable<2, Time::Source::kTest> connections; connections.GetTimeSource().SetCurrentMonotonicTimeMs(100); - err = connections.CreateNewPeerConnectionState(kPeer1Addr, nullptr); + // Node ID 1, peer key 1, local key 2 + err = connections.CreateNewSecureSession(kPeer1NodeId, 1, 2, nullptr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - err = connections.CreateNewPeerConnectionState(kPeer2Addr, &statePtr); + // Node ID 2, peer key 3, local key 4 + err = connections.CreateNewSecureSession(kPeer2NodeId, 3, 4, &statePtr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, statePtr != nullptr); - NL_TEST_ASSERT(inSuite, statePtr->GetPeerAddress() == kPeer2Addr); + NL_TEST_ASSERT(inSuite, statePtr->GetPeerNodeId() == kPeer2NodeId); NL_TEST_ASSERT(inSuite, statePtr->GetLastActivityTimeMs() == 100); // Insufficient space for new connections. Object is max size 2 - err = connections.CreateNewPeerConnectionState(kPeer3Addr, &statePtr); + err = connections.CreateNewSecureSession(kPeer3NodeId, 5, 6, &statePtr); NL_TEST_ASSERT(inSuite, err != CHIP_NO_ERROR); } -void TestFindByAddress(nlTestSuite * inSuite, void * inContext) -{ - CHIP_ERROR err; - SecureSession * statePtr; - SecureSessionTable<3, Time::Source::kTest> connections; - - SecureSession * state1 = nullptr; - SecureSession * state2 = nullptr; - SecureSession * state3 = nullptr; - - err = connections.CreateNewPeerConnectionState(kPeer1Addr, &state1); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - err = connections.CreateNewPeerConnectionState(kPeer1Addr, &state2); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - err = connections.CreateNewPeerConnectionState(kPeer2Addr, &state3); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - NL_TEST_ASSERT(inSuite, state1 != state2); - NL_TEST_ASSERT(inSuite, state1 != state3); - NL_TEST_ASSERT(inSuite, state2 != state3); - - NL_TEST_ASSERT(inSuite, statePtr = connections.FindPeerConnectionState(kPeer1Addr, nullptr)); - NL_TEST_ASSERT(inSuite, statePtr->GetPeerAddress() == kPeer1Addr); - - NL_TEST_ASSERT(inSuite, statePtr = connections.FindPeerConnectionState(kPeer1Addr, statePtr)); - NL_TEST_ASSERT(inSuite, statePtr->GetPeerAddress() == kPeer1Addr); - - NL_TEST_ASSERT(inSuite, (statePtr = connections.FindPeerConnectionState(kPeer1Addr, statePtr)) == nullptr); - - NL_TEST_ASSERT(inSuite, statePtr = connections.FindPeerConnectionState(kPeer2Addr, nullptr)); - NL_TEST_ASSERT(inSuite, statePtr->GetPeerAddress() == kPeer2Addr); - NL_TEST_ASSERT(inSuite, !connections.FindPeerConnectionState(kPeer3Addr, nullptr)); -} - void TestFindByNodeId(nlTestSuite * inSuite, void * inContext) { CHIP_ERROR err; SecureSession * statePtr; SecureSessionTable<3, Time::Source::kTest> connections; - err = connections.CreateNewPeerConnectionState(kPeer1Addr, &statePtr); + // Node ID 1, peer key 1, local key 2 + err = connections.CreateNewSecureSession(kPeer1NodeId, 1, 2, &statePtr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - statePtr->SetPeerNodeId(kPeer1NodeId); + statePtr->SetPeerAddress(kPeer1Addr); - err = connections.CreateNewPeerConnectionState(kPeer2Addr, &statePtr); + // Node ID 2, peer key 3, local key 4 + err = connections.CreateNewSecureSession(kPeer2NodeId, 3, 4, &statePtr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - statePtr->SetPeerNodeId(kPeer2NodeId); + statePtr->SetPeerAddress(kPeer2Addr); - err = connections.CreateNewPeerConnectionState(kPeer2Addr, &statePtr); + // Same Node ID 1, peer key 5, local key 6 + err = connections.CreateNewSecureSession(kPeer1NodeId, 5, 6, &statePtr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - statePtr->SetPeerNodeId(kPeer1NodeId); + statePtr->SetPeerAddress(kPeer3Addr); - NL_TEST_ASSERT(inSuite, statePtr = connections.FindPeerConnectionState(kPeer1NodeId, nullptr)); + NL_TEST_ASSERT(inSuite, statePtr = connections.FindSecureSession(kPeer1NodeId, nullptr)); char buf[100]; statePtr->GetPeerAddress().ToString(buf); NL_TEST_ASSERT(inSuite, statePtr->GetPeerAddress() == kPeer1Addr); NL_TEST_ASSERT(inSuite, statePtr->GetPeerNodeId() == kPeer1NodeId); - NL_TEST_ASSERT(inSuite, statePtr = connections.FindPeerConnectionState(kPeer1NodeId, statePtr)); + NL_TEST_ASSERT(inSuite, statePtr = connections.FindSecureSession(kPeer1NodeId, statePtr)); statePtr->GetPeerAddress().ToString(buf); - NL_TEST_ASSERT(inSuite, statePtr->GetPeerAddress() == kPeer2Addr); + NL_TEST_ASSERT(inSuite, statePtr->GetPeerAddress() == kPeer3Addr); NL_TEST_ASSERT(inSuite, statePtr->GetPeerNodeId() == kPeer1NodeId); - NL_TEST_ASSERT(inSuite, (statePtr = connections.FindPeerConnectionState(kPeer1NodeId, statePtr)) == nullptr); + NL_TEST_ASSERT(inSuite, (statePtr = connections.FindSecureSession(kPeer1NodeId, statePtr)) == nullptr); - NL_TEST_ASSERT(inSuite, statePtr = connections.FindPeerConnectionState(kPeer2NodeId, nullptr)); + NL_TEST_ASSERT(inSuite, statePtr = connections.FindSecureSession(kPeer2NodeId, nullptr)); NL_TEST_ASSERT(inSuite, statePtr->GetPeerAddress() == kPeer2Addr); NL_TEST_ASSERT(inSuite, statePtr->GetPeerNodeId() == kPeer2NodeId); - NL_TEST_ASSERT(inSuite, !connections.FindPeerConnectionState(kPeer3NodeId, nullptr)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSession(kPeer3NodeId, nullptr)); } void TestFindByKeyId(nlTestSuite * inSuite, void * inContext) @@ -152,44 +121,19 @@ void TestFindByKeyId(nlTestSuite * inSuite, void * inContext) SecureSession * statePtr; SecureSessionTable<2, Time::Source::kTest> connections; - // No Node ID, peer key 1, local key 2 - err = connections.CreateNewPeerConnectionState(Optional::Missing(), 1, 2, &statePtr); + // Node ID 1, peer key 1, local key 2 + err = connections.CreateNewSecureSession(kPeer1NodeId, 1, 2, &statePtr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - // Lookup using no node, and peer key - NL_TEST_ASSERT(inSuite, connections.FindPeerConnectionState(Optional::Missing(), 1, nullptr)); - // Lookup using no node, and local key - NL_TEST_ASSERT(inSuite, connections.FindPeerConnectionStateByLocalKey(Optional::Missing(), 2, nullptr)); - - // Lookup using no node, and incorrect peer key - NL_TEST_ASSERT(inSuite, !connections.FindPeerConnectionState(Optional::Missing(), 2, nullptr)); - - // Lookup using no node, and incorrect local key - NL_TEST_ASSERT(inSuite, !connections.FindPeerConnectionStateByLocalKey(Optional::Missing(), 1, nullptr)); - - // Lookup using a node ID, and peer key - NL_TEST_ASSERT(inSuite, connections.FindPeerConnectionState(Optional::Value(kPeer1NodeId), 1, nullptr)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(1, nullptr)); + NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(2, nullptr)); - // Lookup using a node ID, and local key - NL_TEST_ASSERT(inSuite, connections.FindPeerConnectionStateByLocalKey(Optional::Value(kPeer1NodeId), 2, nullptr)); - - // Some Node ID, peer key 3, local key 4 - err = connections.CreateNewPeerConnectionState(Optional::Value(kPeer1NodeId), 3, 4, &statePtr); + // Node ID 2, peer key 3, local key 4 + err = connections.CreateNewSecureSession(kPeer2NodeId, 3, 4, &statePtr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - // Lookup using correct node (or no node), and correct keys - NL_TEST_ASSERT(inSuite, connections.FindPeerConnectionState(Optional::Value(kPeer1NodeId), 3, nullptr)); - NL_TEST_ASSERT(inSuite, connections.FindPeerConnectionStateByLocalKey(Optional::Value(kPeer1NodeId), 4, nullptr)); - NL_TEST_ASSERT(inSuite, connections.FindPeerConnectionState(Optional::Missing(), 3, nullptr)); - NL_TEST_ASSERT(inSuite, connections.FindPeerConnectionStateByLocalKey(Optional::Missing(), 4, nullptr)); - - // Lookup using incorrect keys - NL_TEST_ASSERT(inSuite, !connections.FindPeerConnectionState(Optional::Value(kPeer1NodeId), 4, nullptr)); - NL_TEST_ASSERT(inSuite, !connections.FindPeerConnectionStateByLocalKey(Optional::Value(kPeer1NodeId), 3, nullptr)); - - // Lookup using incorrect node, but correct keys - NL_TEST_ASSERT(inSuite, !connections.FindPeerConnectionState(Optional::Value(kPeer2NodeId), 3, nullptr)); - NL_TEST_ASSERT(inSuite, !connections.FindPeerConnectionStateByLocalKey(Optional::Value(kPeer2NodeId), 4, nullptr)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(3, nullptr)); + NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4, nullptr)); } struct ExpiredCallInfo @@ -208,40 +152,44 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext) connections.GetTimeSource().SetCurrentMonotonicTimeMs(100); - err = connections.CreateNewPeerConnectionState(kPeer1Addr, nullptr); + // Node ID 1, peer key 1, local key 2 + err = connections.CreateNewSecureSession(kPeer1NodeId, 1, 2, &statePtr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + statePtr->SetPeerAddress(kPeer1Addr); connections.GetTimeSource().SetCurrentMonotonicTimeMs(200); - err = connections.CreateNewPeerConnectionState(kPeer2Addr, &statePtr); + // Node ID 2, peer key 3, local key 4 + err = connections.CreateNewSecureSession(kPeer2NodeId, 3, 4, &statePtr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - statePtr->SetPeerNodeId(kPeer2NodeId); + statePtr->SetPeerAddress(kPeer2Addr); // cannot add before expiry connections.GetTimeSource().SetCurrentMonotonicTimeMs(300); - err = connections.CreateNewPeerConnectionState(kPeer3Addr, &statePtr); + err = connections.CreateNewSecureSession(kPeer3NodeId, 5, 6, &statePtr); NL_TEST_ASSERT(inSuite, err != CHIP_NO_ERROR); // at time 300, this expires ip addr 1 - connections.ExpireInactiveConnections(150, [&callInfo](const SecureSession & state) { + connections.ExpireInactiveSessions(150, [&callInfo](const SecureSession & state) { callInfo.callCount++; callInfo.lastCallNodeId = state.GetPeerNodeId(); callInfo.lastCallPeerAddress = state.GetPeerAddress(); }); NL_TEST_ASSERT(inSuite, callInfo.callCount == 1); - NL_TEST_ASSERT(inSuite, callInfo.lastCallNodeId == kUndefinedNodeId); + NL_TEST_ASSERT(inSuite, callInfo.lastCallNodeId == kPeer1NodeId); NL_TEST_ASSERT(inSuite, callInfo.lastCallPeerAddress == kPeer1Addr); - NL_TEST_ASSERT(inSuite, !connections.FindPeerConnectionState(kPeer1NodeId, nullptr)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2, nullptr)); // now that the connections were expired, we can add peer3 connections.GetTimeSource().SetCurrentMonotonicTimeMs(300); - err = connections.CreateNewPeerConnectionState(kPeer3Addr, &statePtr); + // Node ID 3, peer key 5, local key 6 + err = connections.CreateNewSecureSession(kPeer3NodeId, 5, 6, &statePtr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - statePtr->SetPeerNodeId(kPeer3NodeId); + statePtr->SetPeerAddress(kPeer3Addr); connections.GetTimeSource().SetCurrentMonotonicTimeMs(400); - NL_TEST_ASSERT(inSuite, statePtr = connections.FindPeerConnectionState(kPeer2NodeId, nullptr)); + NL_TEST_ASSERT(inSuite, statePtr = connections.FindSecureSessionByLocalKey(4, nullptr)); - connections.MarkConnectionActive(statePtr); + connections.MarkSessionActive(statePtr); NL_TEST_ASSERT(inSuite, statePtr->GetLastActivityTimeMs() == connections.GetTimeSource().GetCurrentMonotonicTimeMs()); // At this time: @@ -250,7 +198,7 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext) connections.GetTimeSource().SetCurrentMonotonicTimeMs(500); callInfo.callCount = 0; - connections.ExpireInactiveConnections(150, [&callInfo](const SecureSession & state) { + connections.ExpireInactiveSessions(150, [&callInfo](const SecureSession & state) { callInfo.callCount++; callInfo.lastCallNodeId = state.GetPeerNodeId(); callInfo.lastCallPeerAddress = state.GetPeerAddress(); @@ -260,28 +208,29 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, callInfo.callCount == 1); NL_TEST_ASSERT(inSuite, callInfo.lastCallNodeId == kPeer3NodeId); NL_TEST_ASSERT(inSuite, callInfo.lastCallPeerAddress == kPeer3Addr); - NL_TEST_ASSERT(inSuite, !connections.FindPeerConnectionState(kPeer1Addr, nullptr)); - NL_TEST_ASSERT(inSuite, connections.FindPeerConnectionState(kPeer2Addr, nullptr)); - NL_TEST_ASSERT(inSuite, !connections.FindPeerConnectionState(kPeer3Addr, nullptr)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2, nullptr)); + NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4, nullptr)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6, nullptr)); - err = connections.CreateNewPeerConnectionState(kPeer1Addr, nullptr); + // Node ID 1, peer key 1, local key 2 + err = connections.CreateNewSecureSession(kPeer1NodeId, 1, 2, &statePtr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, connections.FindPeerConnectionState(kPeer1Addr, nullptr)); - NL_TEST_ASSERT(inSuite, connections.FindPeerConnectionState(kPeer2Addr, nullptr)); - NL_TEST_ASSERT(inSuite, !connections.FindPeerConnectionState(kPeer3Addr, nullptr)); + NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(2, nullptr)); + NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4, nullptr)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6, nullptr)); // peer 1 and 2 are active connections.GetTimeSource().SetCurrentMonotonicTimeMs(1000); callInfo.callCount = 0; - connections.ExpireInactiveConnections(100, [&callInfo](const SecureSession & state) { + connections.ExpireInactiveSessions(100, [&callInfo](const SecureSession & state) { callInfo.callCount++; callInfo.lastCallNodeId = state.GetPeerNodeId(); callInfo.lastCallPeerAddress = state.GetPeerAddress(); }); NL_TEST_ASSERT(inSuite, callInfo.callCount == 2); // everything expired - NL_TEST_ASSERT(inSuite, !connections.FindPeerConnectionState(kPeer1Addr, nullptr)); - NL_TEST_ASSERT(inSuite, !connections.FindPeerConnectionState(kPeer2Addr, nullptr)); - NL_TEST_ASSERT(inSuite, !connections.FindPeerConnectionState(kPeer3Addr, nullptr)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2, nullptr)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(4, nullptr)); + NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6, nullptr)); } } // namespace @@ -290,7 +239,6 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext) static const nlTest sTests[] = { NL_TEST_DEF("BasicFunctionality", TestBasicFunctionality), - NL_TEST_DEF("FindByPeerAddress", TestFindByAddress), NL_TEST_DEF("FindByNodeId", TestFindByNodeId), NL_TEST_DEF("FindByKeyId", TestFindByKeyId), NL_TEST_DEF("ExpireConnections", TestExpireConnections),