diff --git a/src/app/server/RendezvousServer.cpp b/src/app/server/RendezvousServer.cpp index a1adb9f2a1f9a4..bedd9463b37fb7 100644 --- a/src/app/server/RendezvousServer.cpp +++ b/src/app/server/RendezvousServer.cpp @@ -111,7 +111,7 @@ CHIP_ERROR RendezvousServer::WaitForPairing(const RendezvousParameters & params, strlen(kSpake2pKeyExchangeSalt), mNextKeyId++, this)); } - ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(transportMgr)); + ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(mExchangeManager->GetReliableMessageMgr(), transportMgr)); mPairingSession.MessageDispatch().SetPeerAddress(params.GetPeerAddress()); return CHIP_NO_ERROR; diff --git a/src/controller/CHIPDevice.cpp b/src/controller/CHIPDevice.cpp index 785024c82dd600..a010515c0139ac 100644 --- a/src/controller/CHIPDevice.cpp +++ b/src/controller/CHIPDevice.cpp @@ -485,7 +485,8 @@ CHIP_ERROR Device::EstablishCASESession() Messaging::ExchangeContext * exchange = mExchangeMgr->NewContext(SecureSessionHandle(), &mCASESession); VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_INTERNAL); - ReturnErrorOnFailure(mCASESession.MessageDispatch().Init(mSessionManager->GetTransportManager())); + ReturnErrorOnFailure( + mCASESession.MessageDispatch().Init(mExchangeMgr->GetReliableMessageMgr(), mSessionManager->GetTransportManager())); mCASESession.MessageDispatch().SetPeerAddress(mDeviceAddress); ReturnErrorOnFailure(mCASESession.EstablishSession(mDeviceAddress, mCredentials, mDeviceId, 0, exchange, this)); diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index f23e186f6fb2a9..e43bee247dd0a0 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -851,7 +851,7 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam mIsIPRendezvous = (params.GetPeerAddress().GetTransportType() != Transport::Type::kBle); - err = mPairingSession.MessageDispatch().Init(mTransportMgr); + err = mPairingSession.MessageDispatch().Init(mExchangeMgr->GetReliableMessageMgr(), mTransportMgr); SuccessOrExit(err); mPairingSession.MessageDispatch().SetPeerAddress(params.GetPeerAddress()); diff --git a/src/messaging/ReliableMessageMgr.cpp b/src/messaging/ReliableMessageMgr.cpp index e5a8c2025839a1..ab84a91cb1031a 100644 --- a/src/messaging/ReliableMessageMgr.cpp +++ b/src/messaging/ReliableMessageMgr.cpp @@ -296,7 +296,8 @@ CHIP_ERROR ReliableMessageMgr::AddToRetransTable(ReliableMessageContext * rc, Re void ReliableMessageMgr::StartRetransmision(RetransTableEntry * entry) { - VerifyOrDie(entry != nullptr && entry->rc != nullptr); + VerifyOrReturn(entry != nullptr && entry->rc != nullptr, + ChipLogError(ExchangeManager, "StartRetransmission was called for invalid entry")); entry->nextRetransTimeTick = static_cast(entry->rc->GetInitialRetransmitTimeoutTick() + GetTickCounterFromTimeDelta(System::Timer::GetCurrentEpoch())); diff --git a/src/messaging/tests/TestReliableMessageProtocol.cpp b/src/messaging/tests/TestReliableMessageProtocol.cpp index 32d1b855a28ead..e26d7dcfcd4080 100644 --- a/src/messaging/tests/TestReliableMessageProtocol.cpp +++ b/src/messaging/tests/TestReliableMessageProtocol.cpp @@ -128,6 +128,16 @@ class MockSessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessa return gTransportMgr.SendMessage(Transport::PeerAddress(), std::move(message)); } + CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, + EncryptedPacketBufferHandle * retainedMessage) const override + { + if (retainedMessage != nullptr && mRetainMessageOnSend) + { + *retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(message.Retain()); + } + return gTransportMgr.SendMessage(Transport::PeerAddress(), std::move(message)); + } + bool MessagePermitted(uint16_t protocol, uint8_t type) override { return true; } bool mRetainMessageOnSend = true; @@ -140,6 +150,12 @@ class MockSessionEstablishmentDelegate : public ExchangeDelegate System::PacketBufferHandle && buffer) override { IsOnMessageReceivedCalled = true; + ec->Close(); + if (mTestSuite != nullptr) + { + NL_TEST_ASSERT(mTestSuite, buffer->TotalLength() == sizeof(PAYLOAD)); + NL_TEST_ASSERT(mTestSuite, memcmp(buffer->Start(), PAYLOAD, buffer->TotalLength()) == 0); + } } void OnResponseTimeout(ExchangeContext * ec) override {} @@ -151,6 +167,7 @@ class MockSessionEstablishmentDelegate : public ExchangeDelegate bool IsOnMessageReceivedCalled = false; MockSessionEstablishmentExchangeDispatch mMessageDispatch; + nlTestSuite * mTestSuite = nullptr; }; void test_os_sleep_ms(uint64_t millisecs) @@ -464,6 +481,94 @@ void CheckResendApplicationMessageWithPeerExchange(nlTestSuite * inSuite, void * rm->ClearRetransTable(rc); } +void CheckResendSessionEstablishmentMessageWithPeerExchange(nlTestSuite * inSuite, void * inContext) +{ + // Making this static to reduce stack usage, as some platforms have limits on stack size. + static TestContext ctx; + + CHIP_ERROR err = ctx.Init(inSuite, &gTransportMgr); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + ctx.SetSourceNodeId(kAnyNodeId); + ctx.SetDestinationNodeId(kAnyNodeId); + ctx.SetLocalKeyId(0); + ctx.SetPeerKeyId(0); + ctx.SetAdminId(kUndefinedAdminId); + + ctx.GetInetLayer().SystemLayer()->Init(nullptr); + + chip::System::PacketBufferHandle buffer = chip::MessagePacketBuffer::NewWithData(PAYLOAD, sizeof(PAYLOAD)); + NL_TEST_ASSERT(inSuite, !buffer.IsNull()); + + MockSessionEstablishmentDelegate mockReceiver; + err = ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(Echo::MsgType::EchoRequest, &mockReceiver); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + mockReceiver.mTestSuite = inSuite; + + MockSessionEstablishmentDelegate mockSender; + ExchangeContext * exchange = ctx.NewExchangeToPeer(&mockSender); + NL_TEST_ASSERT(inSuite, exchange != nullptr); + + ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); + ReliableMessageContext * rc = exchange->GetReliableMessageContext(); + NL_TEST_ASSERT(inSuite, rm != nullptr); + NL_TEST_ASSERT(inSuite, rc != nullptr); + + rc->SetConfig({ + 1, // CHIP_CONFIG_RMP_DEFAULT_INITIAL_RETRY_INTERVAL + 1, // CHIP_CONFIG_RMP_DEFAULT_ACTIVE_RETRY_INTERVAL + }); + + err = mockSender.mMessageDispatch.Init(rm); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + // Let's drop the initial message + gLoopback.mSendMessageCount = 0; + gLoopback.mNumMessagesToDrop = 1; + gLoopback.mDroppedMessageCount = 0; + + // Ensure the retransmit table is empty right now + NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0); + + err = exchange->SendMessage(Echo::MsgType::EchoRequest, std::move(buffer)); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + exchange->Close(); + + // Ensure the message was dropped, and was added to retransmit table + NL_TEST_ASSERT(inSuite, gLoopback.mNumMessagesToDrop == 0); + NL_TEST_ASSERT(inSuite, gLoopback.mDroppedMessageCount == 1); + NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 1); + NL_TEST_ASSERT(inSuite, !mockReceiver.IsOnMessageReceivedCalled); + + // 1 tick is 64 ms, sleep 65 ms to trigger first re-transmit + test_os_sleep_ms(65); + ReliableMessageMgr::Timeout(&ctx.GetSystemLayer(), rm, CHIP_SYSTEM_NO_ERROR); + + // Ensure the retransmit message was not dropped, and is no longer in the retransmit table + NL_TEST_ASSERT(inSuite, gLoopback.mSendMessageCount >= 2); + NL_TEST_ASSERT(inSuite, gLoopback.mDroppedMessageCount == 1); + NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0); + NL_TEST_ASSERT(inSuite, mockReceiver.IsOnMessageReceivedCalled); + + mockReceiver.mTestSuite = nullptr; + + err = ctx.GetExchangeManager().UnregisterUnsolicitedMessageHandlerForType(Echo::MsgType::EchoRequest); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + rm->ClearRetransTable(rc); + ctx.Shutdown(); + + // This test didn't use the global test context because the session establishment messages + // do not carry encryption key IDs (as the messages are not encrypted), or node IDs (as these + // are not assigned yet). A temporary context is created with default values for these + // parameters. + // Let's reset the state of transport manager so that other tests are not impacted + // as those could be using the global test context. + TestContext & inctx = *static_cast(inContext); + gTransportMgr.SetSecureSessionMgr(&inctx.GetSecureSessionManager()); +} + void CheckSendStandaloneAckMessage(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); @@ -498,6 +603,7 @@ const nlTest sTests[] = NL_TEST_DEF("Test ReliableMessageMgr::CheckCloseExchangeAndResendApplicationMessage", CheckCloseExchangeAndResendApplicationMessage), NL_TEST_DEF("Test ReliableMessageMgr::CheckFailedMessageRetainOnSend", CheckFailedMessageRetainOnSend), NL_TEST_DEF("Test ReliableMessageMgr::CheckResendApplicationMessageWithPeerExchange", CheckResendApplicationMessageWithPeerExchange), + NL_TEST_DEF("Test ReliableMessageMgr::CheckResendSessionEstablishmentMessageWithPeerExchange", CheckResendSessionEstablishmentMessageWithPeerExchange), NL_TEST_DEF("Test ReliableMessageMgr::CheckSendStandaloneAckMessage", CheckSendStandaloneAckMessage), NL_TEST_SENTINEL() diff --git a/src/protocols/secure_channel/CASEServer.cpp b/src/protocols/secure_channel/CASEServer.cpp index 25adca9cbf1794..0176bb250727e6 100644 --- a/src/protocols/secure_channel/CASEServer.cpp +++ b/src/protocols/secure_channel/CASEServer.cpp @@ -41,7 +41,7 @@ CHIP_ERROR CASEServer::ListenForSessionEstablishment(Messaging::ExchangeManager mAdmins = admins; mExchangeManager = exchangeManager; - ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(transportMgr)); + ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(mExchangeManager->GetReliableMessageMgr(), transportMgr)); ExchangeDelegate * delegate = this; ReturnErrorOnFailure( diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp index e0b38323325c2b..78fac185ec66f5 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp @@ -32,17 +32,29 @@ CHIP_ERROR SessionEstablishmentExchangeDispatch::SendMessageImpl(SecureSessionHa System::PacketBufferHandle && message, EncryptedPacketBufferHandle * retainedMessage) { + ReturnErrorCodeIf(mTransportMgr == nullptr, CHIP_ERROR_INCORRECT_STATE); PacketHeader packetHeader; ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message)); ReturnErrorOnFailure(packetHeader.EncodeBeforeData(message)); - if (mTransportMgr != nullptr) + if (retainedMessage != nullptr) { - return mTransportMgr->SendMessage(mPeerAddress, std::move(message)); + *retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(message.Retain()); } + return mTransportMgr->SendMessage(mPeerAddress, std::move(message)); +} + +CHIP_ERROR SessionEstablishmentExchangeDispatch::ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, + EncryptedPacketBufferHandle * retainedMessage) const +{ + ReturnErrorCodeIf(mTransportMgr == nullptr, CHIP_ERROR_INCORRECT_STATE); - return CHIP_ERROR_INCORRECT_STATE; + if (retainedMessage != nullptr) + { + *retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(message.Retain()); + } + return mTransportMgr->SendMessage(mPeerAddress, std::move(message)); } CHIP_ERROR SessionEstablishmentExchangeDispatch::OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, @@ -60,6 +72,7 @@ bool SessionEstablishmentExchangeDispatch::MessagePermitted(uint16_t protocol, u case Protocols::SecureChannel::Id.GetProtocolId(): switch (type) { + case static_cast(Protocols::SecureChannel::MsgType::StandaloneAck): case static_cast(Protocols::SecureChannel::MsgType::PBKDFParamRequest): case static_cast(Protocols::SecureChannel::MsgType::PBKDFParamResponse): case static_cast(Protocols::SecureChannel::MsgType::PASE_Spake2p1): diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h index b222d7b318ab7b..a6e9a669727fb1 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h @@ -36,13 +36,16 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi virtual ~SessionEstablishmentExchangeDispatch() {} - CHIP_ERROR Init(TransportMgrBase * transportMgr) + CHIP_ERROR Init(Messaging::ReliableMessageMgr * reliableMessageMgr, TransportMgrBase * transportMgr) { ReturnErrorCodeIf(transportMgr == nullptr, CHIP_ERROR_INVALID_ARGUMENT); mTransportMgr = transportMgr; - return CHIP_NO_ERROR; + return ExchangeMessageDispatch::Init(reliableMessageMgr); } + CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, + EncryptedPacketBufferHandle * retainedMessage) const override; + CHIP_ERROR OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId, const Transport::PeerAddress & peerAddress, Messaging::ReliableMessageContext * reliableMessageContext) override; diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index 57d5b6f41a4b48..642e8e81537dd9 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -120,7 +120,8 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) TestCASESecurePairingDelegate delegate; CASESession pairing; - NL_TEST_ASSERT(inSuite, pairing.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); + NL_TEST_ASSERT( + inSuite, pairing.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == CHIP_NO_ERROR); ExchangeContext * context = ctx.NewExchangeToLocal(&pairing); NL_TEST_ASSERT(inSuite, @@ -135,7 +136,9 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) gLoopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST; CASESession pairing1; - NL_TEST_ASSERT(inSuite, pairing1.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, + pairing1.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == + CHIP_NO_ERROR); gLoopback.mSentMessageCount = 0; gLoopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST; @@ -159,8 +162,12 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte CASESessionSerializable serializableAccessory; gLoopback.mSentMessageCount = 0; - NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, pairingAccessory.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, + pairingCommissioner.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == + CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, + pairingAccessory.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == + CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType( diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp index e58b57ad24a958..a7512f2d543110 100644 --- a/src/protocols/secure_channel/tests/TestPASESession.cpp +++ b/src/protocols/secure_channel/tests/TestPASESession.cpp @@ -41,26 +41,70 @@ using namespace chip::Protocols; using TestContext = chip::Test::MessagingContext; -class LoopbackTransport : public Transport::Base +static void test_os_sleep_ms(uint64_t millisecs) +{ + struct timespec sleep_time; + uint64_t s = millisecs / 1000; + + millisecs -= s * 1000; + sleep_time.tv_sec = static_cast(s); + sleep_time.tv_nsec = static_cast(millisecs * 1000000); + + nanosleep(&sleep_time, nullptr); +} + +class PASETestLoopbackTransport : public Transport::Base { public: CHIP_ERROR SendMessage(const PeerAddress & address, System::PacketBufferHandle && msgBuf) override { ReturnErrorOnFailure(mMessageSendError); mSentMessageCount++; - HandleMessageReceived(address, std::move(msgBuf)); + + if (mNumMessagesToDrop == 0) + { + // The msgBuf is also being used for retransmission. So we cannot hand over the same buffer + // to the receive handler. The receive handler modifies the buffer for extracting headers etc. + // So the buffer passed to receive handler cannot be used for retransmission afterwards. + // Let's clone the message, and provide cloned message to the receive handler. + System::PacketBufferHandle receivedMessage = msgBuf.CloneData(); + HandleMessageReceived(address, std::move(receivedMessage)); + } + else + { + mNumMessagesToDrop--; + mDroppedMessageCount++; + if (mContext != nullptr) + { + test_os_sleep_ms(65); + ReliableMessageMgr * rm = mContext->GetExchangeManager().GetReliableMessageMgr(); + ReliableMessageMgr::Timeout(&mContext->GetSystemLayer(), rm, CHIP_SYSTEM_NO_ERROR); + } + } return CHIP_NO_ERROR; } bool CanSendToPeer(const PeerAddress & address) override { return true; } - uint32_t mSentMessageCount = 0; - CHIP_ERROR mMessageSendError = CHIP_NO_ERROR; + void Reset() + { + mNumMessagesToDrop = 0; + mDroppedMessageCount = 0; + mSentMessageCount = 0; + mMessageSendError = CHIP_NO_ERROR; + mContext = nullptr; + } + + uint32_t mNumMessagesToDrop = 0; + uint32_t mDroppedMessageCount = 0; + uint32_t mSentMessageCount = 0; + CHIP_ERROR mMessageSendError = CHIP_NO_ERROR; + TestContext * mContext = nullptr; }; TransportMgrBase gTransportMgr; -LoopbackTransport gLoopback; +PASETestLoopbackTransport gLoopback; class TestSecurePairingDelegate : public SessionEstablishmentDelegate { @@ -91,6 +135,8 @@ void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) TestSecurePairingDelegate delegate; PASESession pairing; + gLoopback.Reset(); + NL_TEST_ASSERT(inSuite, pairing.WaitForPairing(1234, 500, nullptr, 0, 0, &delegate) == CHIP_ERROR_INVALID_ARGUMENT); NL_TEST_ASSERT(inSuite, pairing.WaitForPairing(1234, 500, (const uint8_t *) "saltSalt", 8, 0, nullptr) == CHIP_ERROR_INVALID_ARGUMENT); @@ -106,21 +152,29 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) PASESession pairing; - NL_TEST_ASSERT(inSuite, pairing.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); + gLoopback.Reset(); + + NL_TEST_ASSERT( + inSuite, pairing.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == CHIP_NO_ERROR); ExchangeContext * context = ctx.NewExchangeToLocal(&pairing); NL_TEST_ASSERT(inSuite, pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, nullptr, nullptr) != CHIP_NO_ERROR); + + gLoopback.Reset(); NL_TEST_ASSERT(inSuite, pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, context, &delegate) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 1); + gLoopback.Reset(); gLoopback.mSentMessageCount = 0; gLoopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST; PASESession pairing1; - NL_TEST_ASSERT(inSuite, pairing1.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, + pairing1.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == + CHIP_NO_ERROR); ExchangeContext * context1 = ctx.NewExchangeToLocal(&pairing1); NL_TEST_ASSERT(inSuite, pairing1.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, context1, &delegate) == @@ -138,15 +192,36 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P gLoopback.mSentMessageCount = 0; - NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, pairingAccessory.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, + pairingCommissioner.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == + CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, + pairingAccessory.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == + CHIP_NO_ERROR); + + ExchangeContext * contextCommissioner = ctx.NewExchangeToLocal(&pairingCommissioner); + + if (gLoopback.mNumMessagesToDrop != 0) + { + pairingCommissioner.MessageDispatch().SetPeerAddress(PeerAddress(Type::kUdp)); + pairingAccessory.MessageDispatch().SetPeerAddress(PeerAddress(Type::kUdp)); + + ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); + ReliableMessageContext * rc = contextCommissioner->GetReliableMessageContext(); + NL_TEST_ASSERT(inSuite, rm != nullptr); + NL_TEST_ASSERT(inSuite, rc != nullptr); + + rc->SetConfig({ + 1, // CHIP_CONFIG_RMP_DEFAULT_INITIAL_RETRY_INTERVAL + 1, // CHIP_CONFIG_RMP_DEFAULT_ACTIVE_RETRY_INTERVAL + }); + gLoopback.mContext = &ctx; + } NL_TEST_ASSERT(inSuite, ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType( Protocols::SecureChannel::MsgType::PBKDFParamRequest, &pairingAccessory) == CHIP_NO_ERROR); - ExchangeContext * contextCommissioner = ctx.NewExchangeToLocal(&pairingCommissioner); - NL_TEST_ASSERT(inSuite, pairingAccessory.WaitForPairing(1234, 500, (const uint8_t *) "saltSALT", 8, 0, &delegateAccessory) == CHIP_NO_ERROR); @@ -154,22 +229,40 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P pairingCommissioner.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 5); + // Standalone acks also increment the mSentMessageCount. But some messages could be acked + // via piggybacked acks. So we cannot check for a specific value of mSentMessageCount. + // Let's make sure atleast number is >= than the minimum messages required to complete the + // handshake. + NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount >= 5); NL_TEST_ASSERT(inSuite, delegateAccessory.mNumPairingComplete == 1); NL_TEST_ASSERT(inSuite, delegateCommissioner.mNumPairingComplete == 1); + gLoopback.mContext = nullptr; } void SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext) { TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; + gLoopback.Reset(); + SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, delegateCommissioner); +} + +void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inContext) +{ + TestSecurePairingDelegate delegateCommissioner; + PASESession pairingCommissioner; + gLoopback.Reset(); + gLoopback.mNumMessagesToDrop = 2; SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, delegateCommissioner); + NL_TEST_ASSERT(inSuite, gLoopback.mDroppedMessageCount == 2); + NL_TEST_ASSERT(inSuite, gLoopback.mNumMessagesToDrop == 0); } void SecurePairingDeserialize(nlTestSuite * inSuite, void * inContext, PASESession & pairingCommissioner, PASESession & deserialized) { PASESessionSerialized serialized; + gLoopback.Reset(); NL_TEST_ASSERT(inSuite, pairingCommissioner.Serialize(serialized) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, deserialized.Deserialize(serialized) == CHIP_NO_ERROR); @@ -189,6 +282,8 @@ void SecurePairingSerializeTest(nlTestSuite * inSuite, void * inContext) auto * testPairingSession1 = chip::Platform::New(); auto * testPairingSession2 = chip::Platform::New(); + gLoopback.Reset(); + SecurePairingHandshakeTestCommon(inSuite, inContext, *testPairingSession1, delegateCommissioner); SecurePairingDeserialize(inSuite, inContext, *testPairingSession1, *testPairingSession2); @@ -234,6 +329,7 @@ static const nlTest sTests[] = NL_TEST_DEF("WaitInit", SecurePairingWaitTest), NL_TEST_DEF("Start", SecurePairingStartTest), NL_TEST_DEF("Handshake", SecurePairingHandshakeTest), + NL_TEST_DEF("Handshake with packet loss", SecurePairingHandshakeWithPacketLossTest), NL_TEST_DEF("Serialize", SecurePairingSerializeTest), NL_TEST_SENTINEL()