diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index 3c4d2572fb926b..9dfca878346929 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -628,6 +628,7 @@ CHIP_ERROR CASESession::HandleSigma1(System::PacketBufferHandle && msg) ChipLogError(SecureChannel, "CASE failed to match destination ID with local fabrics"); ChipLogByteSpan(SecureChannel, destinationIdentifier); } + SuccessOrExit(err); // ParseSigma1 ensures that: // mRemotePubKey.Length() == initiatorPubKey.size() == kP256_PublicKey_Length. diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index 9c62a48e046908..8af555796e5bec 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -310,6 +310,7 @@ class TestCASESession #if CONFIG_BUILD_FOR_HOST_UNIT_TEST static void SimulateUpdateNOCInvalidatePendingEstablishment(nlTestSuite * inSuite, void * inContext); #endif // CONFIG_BUILD_FOR_HOST_UNIT_TEST + static void Sigma1BadDestinationIdTest(nlTestSuite * inSuite, void * inContext); }; void TestCASESession::SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) @@ -1005,6 +1006,88 @@ void TestCASESession::SimulateUpdateNOCInvalidatePendingEstablishment(nlTestSuit } #endif // CONFIG_BUILD_FOR_HOST_UNIT_TEST +namespace { +class ExpectErrorExchangeDelegate : public ExchangeDelegate +{ +public: + ExpectErrorExchangeDelegate(nlTestSuite * suite, uint16_t expectedProtocolCode) : + mSuite(suite), mExpectedProtocolCode(expectedProtocolCode) + {} + +private: + CHIP_ERROR OnMessageReceived(ExchangeContext * ec, const PayloadHeader & payloadHeader, + System::PacketBufferHandle && buf) override + { + using namespace SecureChannel; + + NL_TEST_ASSERT(mSuite, payloadHeader.HasMessageType(MsgType::StatusReport)); + + SecureChannel::StatusReport statusReport; + CHIP_ERROR err = statusReport.Parse(std::move(buf)); + NL_TEST_ASSERT(mSuite, err == CHIP_NO_ERROR); + + NL_TEST_ASSERT(mSuite, statusReport.GetProtocolId() == SecureChannel::Id); + NL_TEST_ASSERT(mSuite, statusReport.GetGeneralCode() == GeneralStatusCode::kFailure); + NL_TEST_ASSERT(mSuite, statusReport.GetProtocolCode() == mExpectedProtocolCode); + return CHIP_NO_ERROR; + } + + void OnResponseTimeout(ExchangeContext * ec) override {} + + Messaging::ExchangeMessageDispatch & GetMessageDispatch() override { return SessionEstablishmentExchangeDispatch::Instance(); } + + nlTestSuite * mSuite; + uint16_t mExpectedProtocolCode; +}; +} // anonymous namespace + +void TestCASESession::Sigma1BadDestinationIdTest(nlTestSuite * inSuite, void * inContext) +{ + using SecureChannel::MsgType; + + TestContext & ctx = *reinterpret_cast(inContext); + + SessionManager & sessionManager = ctx.GetSecureSessionManager(); + + constexpr size_t bufferSize = 600; + System::PacketBufferHandle data = chip::System::PacketBufferHandle::New(bufferSize); + NL_TEST_ASSERT(inSuite, !data.IsNull()); + + MutableByteSpan buf(data->Start(), data->AvailableDataLength()); + // This uses a bogus destination id that is not going to match anything in practice. + CHIP_ERROR err = EncodeSigma1(buf); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + data->SetDataLength(static_cast(buf.size())); + + Optional session = sessionManager.CreateUnauthenticatedSession(ctx.GetAliceAddress(), GetDefaultMRPConfig()); + NL_TEST_ASSERT(inSuite, session.HasValue()); + + TestCASESecurePairingDelegate caseDelegate; + CASESession caseSession; + caseSession.SetGroupDataProvider(&gDeviceGroupDataProvider); + err = caseSession.PrepareForSessionEstablishment(sessionManager, &gDeviceFabrics, nullptr, nullptr, &caseDelegate, + ScopedNodeId(), NullOptional); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + err = ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(MsgType::CASE_Sigma1, &caseSession); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + ExpectErrorExchangeDelegate delegate(inSuite, SecureChannel::kProtocolCodeNoSharedRoot); + ExchangeContext * exchange = ctx.GetExchangeManager().NewContext(session.Value(), &delegate); + NL_TEST_ASSERT(inSuite, exchange != nullptr); + + err = exchange->SendMessage(MsgType::CASE_Sigma1, std::move(data), SendMessageFlags::kExpectResponse); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + ctx.DrainAndServiceIO(); + + NL_TEST_ASSERT(inSuite, caseDelegate.mNumPairingErrors == 1); + NL_TEST_ASSERT(inSuite, caseDelegate.mNumPairingComplete == 0); + + ctx.GetExchangeManager().UnregisterUnsolicitedMessageHandlerForType(MsgType::CASE_Sigma1); + caseSession.Clear(); +} + } // namespace chip // Test Suite @@ -1027,6 +1110,7 @@ static const nlTest sTests[] = // CASESession that are in the process of establishing. NL_TEST_DEF("InvalidatePendingSessionEstablishment", chip::TestCASESession::SimulateUpdateNOCInvalidatePendingEstablishment), #endif // CONFIG_BUILD_FOR_HOST_UNIT_TEST + NL_TEST_DEF("Sigma1BadDestinationId", chip::TestCASESession::Sigma1BadDestinationIdTest), NL_TEST_SENTINEL() };