Skip to content

Commit

Permalink
Add CRMP test for retransmission of session establishment messages (#…
Browse files Browse the repository at this point in the history
…7234)

* Add CRMP test for retransmission of session establishment messages

* Fix test

* Add packet loss test to TestPASESession

* add some comments

* reduce stack usage due to nrfconnect platform limits

* clear transport state before running test

* fixes for testing on nrfconnect
  • Loading branch information
pan-apple authored Jun 1, 2021
1 parent d6ec067 commit 0f94665
Show file tree
Hide file tree
Showing 10 changed files with 253 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/app/server/RendezvousServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ CHIP_ERROR RendezvousServer::WaitForPairing(const RendezvousParameters & params,
strlen(kSpake2pKeyExchangeSalt), mNextKeyId++, this));
}

ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(transportMgr));
ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(mExchangeManager->GetReliableMessageMgr(), transportMgr));
mPairingSession.MessageDispatch().SetPeerAddress(params.GetPeerAddress());

return CHIP_NO_ERROR;
Expand Down
3 changes: 2 additions & 1 deletion src/controller/CHIPDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,8 @@ CHIP_ERROR Device::EstablishCASESession()
Messaging::ExchangeContext * exchange = mExchangeMgr->NewContext(SecureSessionHandle(), &mCASESession);
VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_INTERNAL);

ReturnErrorOnFailure(mCASESession.MessageDispatch().Init(mSessionManager->GetTransportManager()));
ReturnErrorOnFailure(
mCASESession.MessageDispatch().Init(mExchangeMgr->GetReliableMessageMgr(), mSessionManager->GetTransportManager()));
mCASESession.MessageDispatch().SetPeerAddress(mDeviceAddress);

ReturnErrorOnFailure(mCASESession.EstablishSession(mDeviceAddress, mCredentials, mDeviceId, 0, exchange, this));
Expand Down
2 changes: 1 addition & 1 deletion src/controller/CHIPDeviceController.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam

mIsIPRendezvous = (params.GetPeerAddress().GetTransportType() != Transport::Type::kBle);

err = mPairingSession.MessageDispatch().Init(mTransportMgr);
err = mPairingSession.MessageDispatch().Init(mExchangeMgr->GetReliableMessageMgr(), mTransportMgr);
SuccessOrExit(err);
mPairingSession.MessageDispatch().SetPeerAddress(params.GetPeerAddress());

Expand Down
3 changes: 2 additions & 1 deletion src/messaging/ReliableMessageMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,8 @@ CHIP_ERROR ReliableMessageMgr::AddToRetransTable(ReliableMessageContext * rc, Re

void ReliableMessageMgr::StartRetransmision(RetransTableEntry * entry)
{
VerifyOrDie(entry != nullptr && entry->rc != nullptr);
VerifyOrReturn(entry != nullptr && entry->rc != nullptr,
ChipLogError(ExchangeManager, "StartRetransmission was called for invalid entry"));

entry->nextRetransTimeTick = static_cast<uint16_t>(entry->rc->GetInitialRetransmitTimeoutTick() +
GetTickCounterFromTimeDelta(System::Timer::GetCurrentEpoch()));
Expand Down
106 changes: 106 additions & 0 deletions src/messaging/tests/TestReliableMessageProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,16 @@ class MockSessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessa
return gTransportMgr.SendMessage(Transport::PeerAddress(), std::move(message));
}

CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message,
EncryptedPacketBufferHandle * retainedMessage) const override
{
if (retainedMessage != nullptr && mRetainMessageOnSend)
{
*retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(message.Retain());
}
return gTransportMgr.SendMessage(Transport::PeerAddress(), std::move(message));
}

bool MessagePermitted(uint16_t protocol, uint8_t type) override { return true; }

bool mRetainMessageOnSend = true;
Expand All @@ -140,6 +150,12 @@ class MockSessionEstablishmentDelegate : public ExchangeDelegate
System::PacketBufferHandle && buffer) override
{
IsOnMessageReceivedCalled = true;
ec->Close();
if (mTestSuite != nullptr)
{
NL_TEST_ASSERT(mTestSuite, buffer->TotalLength() == sizeof(PAYLOAD));
NL_TEST_ASSERT(mTestSuite, memcmp(buffer->Start(), PAYLOAD, buffer->TotalLength()) == 0);
}
}

void OnResponseTimeout(ExchangeContext * ec) override {}
Expand All @@ -151,6 +167,7 @@ class MockSessionEstablishmentDelegate : public ExchangeDelegate

bool IsOnMessageReceivedCalled = false;
MockSessionEstablishmentExchangeDispatch mMessageDispatch;
nlTestSuite * mTestSuite = nullptr;
};

void test_os_sleep_ms(uint64_t millisecs)
Expand Down Expand Up @@ -464,6 +481,94 @@ void CheckResendApplicationMessageWithPeerExchange(nlTestSuite * inSuite, void *
rm->ClearRetransTable(rc);
}

void CheckResendSessionEstablishmentMessageWithPeerExchange(nlTestSuite * inSuite, void * inContext)
{
// Making this static to reduce stack usage, as some platforms have limits on stack size.
static TestContext ctx;

CHIP_ERROR err = ctx.Init(inSuite, &gTransportMgr);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

ctx.SetSourceNodeId(kAnyNodeId);
ctx.SetDestinationNodeId(kAnyNodeId);
ctx.SetLocalKeyId(0);
ctx.SetPeerKeyId(0);
ctx.SetAdminId(kUndefinedAdminId);

ctx.GetInetLayer().SystemLayer()->Init(nullptr);

chip::System::PacketBufferHandle buffer = chip::MessagePacketBuffer::NewWithData(PAYLOAD, sizeof(PAYLOAD));
NL_TEST_ASSERT(inSuite, !buffer.IsNull());

MockSessionEstablishmentDelegate mockReceiver;
err = ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(Echo::MsgType::EchoRequest, &mockReceiver);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

mockReceiver.mTestSuite = inSuite;

MockSessionEstablishmentDelegate mockSender;
ExchangeContext * exchange = ctx.NewExchangeToPeer(&mockSender);
NL_TEST_ASSERT(inSuite, exchange != nullptr);

ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
ReliableMessageContext * rc = exchange->GetReliableMessageContext();
NL_TEST_ASSERT(inSuite, rm != nullptr);
NL_TEST_ASSERT(inSuite, rc != nullptr);

rc->SetConfig({
1, // CHIP_CONFIG_RMP_DEFAULT_INITIAL_RETRY_INTERVAL
1, // CHIP_CONFIG_RMP_DEFAULT_ACTIVE_RETRY_INTERVAL
});

err = mockSender.mMessageDispatch.Init(rm);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

// Let's drop the initial message
gLoopback.mSendMessageCount = 0;
gLoopback.mNumMessagesToDrop = 1;
gLoopback.mDroppedMessageCount = 0;

// Ensure the retransmit table is empty right now
NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0);

err = exchange->SendMessage(Echo::MsgType::EchoRequest, std::move(buffer));
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
exchange->Close();

// Ensure the message was dropped, and was added to retransmit table
NL_TEST_ASSERT(inSuite, gLoopback.mNumMessagesToDrop == 0);
NL_TEST_ASSERT(inSuite, gLoopback.mDroppedMessageCount == 1);
NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 1);
NL_TEST_ASSERT(inSuite, !mockReceiver.IsOnMessageReceivedCalled);

// 1 tick is 64 ms, sleep 65 ms to trigger first re-transmit
test_os_sleep_ms(65);
ReliableMessageMgr::Timeout(&ctx.GetSystemLayer(), rm, CHIP_SYSTEM_NO_ERROR);

// Ensure the retransmit message was not dropped, and is no longer in the retransmit table
NL_TEST_ASSERT(inSuite, gLoopback.mSendMessageCount >= 2);
NL_TEST_ASSERT(inSuite, gLoopback.mDroppedMessageCount == 1);
NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0);
NL_TEST_ASSERT(inSuite, mockReceiver.IsOnMessageReceivedCalled);

mockReceiver.mTestSuite = nullptr;

err = ctx.GetExchangeManager().UnregisterUnsolicitedMessageHandlerForType(Echo::MsgType::EchoRequest);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

rm->ClearRetransTable(rc);
ctx.Shutdown();

// This test didn't use the global test context because the session establishment messages
// do not carry encryption key IDs (as the messages are not encrypted), or node IDs (as these
// are not assigned yet). A temporary context is created with default values for these
// parameters.
// Let's reset the state of transport manager so that other tests are not impacted
// as those could be using the global test context.
TestContext & inctx = *static_cast<TestContext *>(inContext);
gTransportMgr.SetSecureSessionMgr(&inctx.GetSecureSessionManager());
}

void CheckSendStandaloneAckMessage(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
Expand Down Expand Up @@ -498,6 +603,7 @@ const nlTest sTests[] =
NL_TEST_DEF("Test ReliableMessageMgr::CheckCloseExchangeAndResendApplicationMessage", CheckCloseExchangeAndResendApplicationMessage),
NL_TEST_DEF("Test ReliableMessageMgr::CheckFailedMessageRetainOnSend", CheckFailedMessageRetainOnSend),
NL_TEST_DEF("Test ReliableMessageMgr::CheckResendApplicationMessageWithPeerExchange", CheckResendApplicationMessageWithPeerExchange),
NL_TEST_DEF("Test ReliableMessageMgr::CheckResendSessionEstablishmentMessageWithPeerExchange", CheckResendSessionEstablishmentMessageWithPeerExchange),
NL_TEST_DEF("Test ReliableMessageMgr::CheckSendStandaloneAckMessage", CheckSendStandaloneAckMessage),

NL_TEST_SENTINEL()
Expand Down
2 changes: 1 addition & 1 deletion src/protocols/secure_channel/CASEServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ CHIP_ERROR CASEServer::ListenForSessionEstablishment(Messaging::ExchangeManager
mAdmins = admins;
mExchangeManager = exchangeManager;

ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(transportMgr));
ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(mExchangeManager->GetReliableMessageMgr(), transportMgr));

ExchangeDelegate * delegate = this;
ReturnErrorOnFailure(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,29 @@ CHIP_ERROR SessionEstablishmentExchangeDispatch::SendMessageImpl(SecureSessionHa
System::PacketBufferHandle && message,
EncryptedPacketBufferHandle * retainedMessage)
{
ReturnErrorCodeIf(mTransportMgr == nullptr, CHIP_ERROR_INCORRECT_STATE);
PacketHeader packetHeader;

ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message));
ReturnErrorOnFailure(packetHeader.EncodeBeforeData(message));

if (mTransportMgr != nullptr)
if (retainedMessage != nullptr)
{
return mTransportMgr->SendMessage(mPeerAddress, std::move(message));
*retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(message.Retain());
}
return mTransportMgr->SendMessage(mPeerAddress, std::move(message));
}

CHIP_ERROR SessionEstablishmentExchangeDispatch::ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message,
EncryptedPacketBufferHandle * retainedMessage) const
{
ReturnErrorCodeIf(mTransportMgr == nullptr, CHIP_ERROR_INCORRECT_STATE);

return CHIP_ERROR_INCORRECT_STATE;
if (retainedMessage != nullptr)
{
*retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(message.Retain());
}
return mTransportMgr->SendMessage(mPeerAddress, std::move(message));
}

CHIP_ERROR SessionEstablishmentExchangeDispatch::OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId,
Expand All @@ -60,6 +72,7 @@ bool SessionEstablishmentExchangeDispatch::MessagePermitted(uint16_t protocol, u
case Protocols::SecureChannel::Id.GetProtocolId():
switch (type)
{
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::StandaloneAck):
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::PBKDFParamRequest):
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::PBKDFParamResponse):
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::PASE_Spake2p1):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,16 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi

virtual ~SessionEstablishmentExchangeDispatch() {}

CHIP_ERROR Init(TransportMgrBase * transportMgr)
CHIP_ERROR Init(Messaging::ReliableMessageMgr * reliableMessageMgr, TransportMgrBase * transportMgr)
{
ReturnErrorCodeIf(transportMgr == nullptr, CHIP_ERROR_INVALID_ARGUMENT);
mTransportMgr = transportMgr;
return CHIP_NO_ERROR;
return ExchangeMessageDispatch::Init(reliableMessageMgr);
}

CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message,
EncryptedPacketBufferHandle * retainedMessage) const override;

CHIP_ERROR OnMessageReceived(const PayloadHeader & payloadHeader, uint32_t messageId,
const Transport::PeerAddress & peerAddress,
Messaging::ReliableMessageContext * reliableMessageContext) override;
Expand Down
15 changes: 11 additions & 4 deletions src/protocols/secure_channel/tests/TestCASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)
TestCASESecurePairingDelegate delegate;
CASESession pairing;

NL_TEST_ASSERT(inSuite, pairing.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR);
NL_TEST_ASSERT(
inSuite, pairing.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == CHIP_NO_ERROR);
ExchangeContext * context = ctx.NewExchangeToLocal(&pairing);

NL_TEST_ASSERT(inSuite,
Expand All @@ -135,7 +136,9 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)
gLoopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST;

CASESession pairing1;
NL_TEST_ASSERT(inSuite, pairing1.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite,
pairing1.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) ==
CHIP_NO_ERROR);

gLoopback.mSentMessageCount = 0;
gLoopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST;
Expand All @@ -159,8 +162,12 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte
CASESessionSerializable serializableAccessory;

gLoopback.mSentMessageCount = 0;
NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, pairingAccessory.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite,
pairingCommissioner.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) ==
CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite,
pairingAccessory.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) ==
CHIP_NO_ERROR);

NL_TEST_ASSERT(inSuite,
ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(
Expand Down
Loading

0 comments on commit 0f94665

Please sign in to comment.