Skip to content

Commit

Permalink
Make sure we handle BUSY responses during CASE properly. (#28901)
Browse files Browse the repository at this point in the history
If we got a BUSY response during a CASE handshake and successfully read the wait
time, we would treat that as a success case, not a failure case, and not realize
that our exchange has been closed.  That could lead to use-after-free when we
later tried to abort an already-closed exchange.

The problem was introduced in 59a0b2f
(PR #28153).

The new unit test fails (both with ASAN failures and with incorrect state
because the client that got BUSY does not think the handshake failed) without
this fix.
  • Loading branch information
bzbarsky-apple authored and pull[bot] committed Feb 22, 2024
1 parent acb9a27 commit 1154284
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 25 deletions.
44 changes: 20 additions & 24 deletions src/protocols/secure_channel/PairingSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,44 +162,40 @@ class DLL_EXPORT PairingSession : public SessionDelegate
CHIP_ERROR HandleStatusReport(System::PacketBufferHandle && msg, bool successExpected)
{
Protocols::SecureChannel::StatusReport report;
CHIP_ERROR err = report.Parse(std::move(msg));
ReturnErrorOnFailure(err);
ReturnErrorOnFailure(report.Parse(std::move(msg)));
VerifyOrReturnError(report.GetProtocolId() == Protocols::SecureChannel::Id, CHIP_ERROR_INVALID_ARGUMENT);

if (report.GetGeneralCode() == Protocols::SecureChannel::GeneralStatusCode::kSuccess &&
report.GetProtocolCode() == Protocols::SecureChannel::kProtocolCodeSuccess && successExpected)
{
OnSuccessStatusReport();
return CHIP_NO_ERROR;
}
else
{
err = OnFailureStatusReport(report.GetGeneralCode(), report.GetProtocolCode());

if (report.GetGeneralCode() == Protocols::SecureChannel::GeneralStatusCode::kBusy &&
report.GetProtocolCode() == Protocols::SecureChannel::kProtocolCodeBusy)
if (report.GetGeneralCode() == Protocols::SecureChannel::GeneralStatusCode::kBusy &&
report.GetProtocolCode() == Protocols::SecureChannel::kProtocolCodeBusy)
{
if (!report.GetProtocolData().IsNull())
{
if (!report.GetProtocolData().IsNull())
Encoding::LittleEndian::Reader reader(report.GetProtocolData()->Start(), report.GetProtocolData()->DataLength());

uint16_t minimumWaitTime = 0;
CHIP_ERROR waitTimeErr = reader.Read16(&minimumWaitTime).StatusCode();
if (waitTimeErr != CHIP_NO_ERROR)
{
ChipLogError(SecureChannel, "Failed to read the minimum wait time: %" CHIP_ERROR_FORMAT, waitTimeErr.Format());
}
else
{
Encoding::LittleEndian::Reader reader(report.GetProtocolData()->Start(),
report.GetProtocolData()->DataLength());

uint16_t minimumWaitTime = 0;
err = reader.Read16(&minimumWaitTime).StatusCode();
if (err != CHIP_NO_ERROR)
{
ChipLogError(SecureChannel, "Failed to read the minimum wait time: %" CHIP_ERROR_FORMAT, err.Format());
}
else
{
// TODO: CASE: Notify minimum wait time to clients on receiving busy status report #28290
ChipLogProgress(SecureChannel, "Received busy status report with minimum wait time: %u ms",
minimumWaitTime);
}
// TODO: CASE: Notify minimum wait time to clients on receiving busy status report #28290
ChipLogProgress(SecureChannel, "Received busy status report with minimum wait time: %u ms", minimumWaitTime);
}
}
}

return err;
// It's very important that we propagate the return value from
// OnFailureStatusReport out to the caller. Make sure we return it directly.
return OnFailureStatusReport(report.GetGeneralCode(), report.GetProtocolCode());
}

/**
Expand Down
64 changes: 63 additions & 1 deletion src/protocols/secure_channel/tests/TestCASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,14 @@ CHIP_ERROR InitFabricTable(chip::FabricTable & fabricTable, chip::TestPersistent
class TestCASESecurePairingDelegate : public SessionEstablishmentDelegate
{
public:
void OnSessionEstablishmentError(CHIP_ERROR error) override { mNumPairingErrors++; }
void OnSessionEstablishmentError(CHIP_ERROR error) override
{
mNumPairingErrors++;
if (error == CHIP_ERROR_BUSY)
{
mNumBusyResponses++;
}
}

void OnSessionEstablished(const SessionHandle & session) override
{
Expand All @@ -137,6 +144,7 @@ class TestCASESecurePairingDelegate : public SessionEstablishmentDelegate
// TODO: Rename mNumPairing* to mNumEstablishment*
uint32_t mNumPairingErrors = 0;
uint32_t mNumPairingComplete = 0;
uint32_t mNumBusyResponses = 0;
};

class TestOperationalKeystore : public chip::Crypto::OperationalKeystore
Expand Down Expand Up @@ -314,6 +322,7 @@ class TestCASESession
static void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext);
static void SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext);
static void SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inContext);
static void ClientReceivesBusyTest(nlTestSuite * inSuite, void * inContext);
static void Sigma1ParsingTest(nlTestSuite * inSuite, void * inContext);
static void DestinationIdTest(nlTestSuite * inSuite, void * inContext);
static void SessionResumptionStorage(nlTestSuite * inSuite, void * inContext);
Expand Down Expand Up @@ -536,6 +545,58 @@ void TestCASESession::SecurePairingHandshakeServerTest(nlTestSuite * inSuite, vo

chip::Platform::Delete(pairingCommissioner);
chip::Platform::Delete(pairingCommissioner1);

gPairingServer.Shutdown();
}

void TestCASESession::ClientReceivesBusyTest(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
TemporarySessionManager sessionManager(inSuite, ctx);

TestCASESecurePairingDelegate delegateCommissioner1, delegateCommissioner2;
CASESession pairingCommissioner1, pairingCommissioner2;

pairingCommissioner1.SetGroupDataProvider(&gCommissionerGroupDataProvider);
pairingCommissioner2.SetGroupDataProvider(&gCommissionerGroupDataProvider);

auto & loopback = ctx.GetLoopback();
loopback.mSentMessageCount = 0;

NL_TEST_ASSERT(inSuite,
gPairingServer.ListenForSessionEstablishment(&ctx.GetExchangeManager(), &ctx.GetSecureSessionManager(),
&gDeviceFabrics, nullptr, nullptr,
&gDeviceGroupDataProvider) == CHIP_NO_ERROR);

ExchangeContext * contextCommissioner1 = ctx.NewUnauthenticatedExchangeToBob(&pairingCommissioner1);
ExchangeContext * contextCommissioner2 = ctx.NewUnauthenticatedExchangeToBob(&pairingCommissioner2);

NL_TEST_ASSERT(inSuite,
pairingCommissioner1.EstablishSession(sessionManager, &gCommissionerFabrics,
ScopedNodeId{ Node01_01, gCommissionerFabricIndex }, contextCommissioner1,
nullptr, nullptr, &delegateCommissioner1, NullOptional) == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite,
pairingCommissioner2.EstablishSession(sessionManager, &gCommissionerFabrics,
ScopedNodeId{ Node01_01, gCommissionerFabricIndex }, contextCommissioner2,
nullptr, nullptr, &delegateCommissioner2, NullOptional) == CHIP_NO_ERROR);

ServiceEvents(ctx);

// We should have one full handshake and one Sigma1 + Busy + ack. If that
// ever changes (e.g. because our server starts supporting multiple parallel
// handshakes), this test needs to be fixed so that the server is still
// responding BUSY to the client.
NL_TEST_ASSERT(inSuite, loopback.mSentMessageCount == sTestCaseMessageCount + 3);
NL_TEST_ASSERT(inSuite, delegateCommissioner1.mNumPairingComplete == 1);
NL_TEST_ASSERT(inSuite, delegateCommissioner2.mNumPairingComplete == 0);

NL_TEST_ASSERT(inSuite, delegateCommissioner1.mNumPairingErrors == 0);
NL_TEST_ASSERT(inSuite, delegateCommissioner2.mNumPairingErrors == 1);

NL_TEST_ASSERT(inSuite, delegateCommissioner1.mNumBusyResponses == 0);
NL_TEST_ASSERT(inSuite, delegateCommissioner2.mNumBusyResponses == 1);

gPairingServer.Shutdown();
}

struct Sigma1Params
Expand Down Expand Up @@ -1115,6 +1176,7 @@ static const nlTest sTests[] =
NL_TEST_DEF("Start", chip::TestCASESession::SecurePairingStartTest),
NL_TEST_DEF("Handshake", chip::TestCASESession::SecurePairingHandshakeTest),
NL_TEST_DEF("ServerHandshake", chip::TestCASESession::SecurePairingHandshakeServerTest),
NL_TEST_DEF("ClientReceivesBusy", chip::TestCASESession::ClientReceivesBusyTest),
NL_TEST_DEF("Sigma1Parsing", chip::TestCASESession::Sigma1ParsingTest),
NL_TEST_DEF("DestinationId", chip::TestCASESession::DestinationIdTest),
NL_TEST_DEF("SessionResumptionStorage", chip::TestCASESession::SessionResumptionStorage),
Expand Down

0 comments on commit 1154284

Please sign in to comment.