Skip to content

Commit

Permalink
Use refcounter for secure session
Browse files Browse the repository at this point in the history
  • Loading branch information
kghost committed Apr 21, 2022
1 parent 4b4e035 commit 98a4066
Show file tree
Hide file tree
Showing 21 changed files with 225 additions and 299 deletions.
7 changes: 7 additions & 0 deletions src/lib/core/ReferenceCounted.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ class DeleteDeletor
static void Release(T * obj) { chip::Platform::Delete(obj); }
};

template <class T>
class NoopDeletor
{
public:
static void Release(T * obj) {}
};

/**
* A reference counted object maintains a count of usages and when the usage
* count drops to 0, it deletes itself.
Expand Down
5 changes: 5 additions & 0 deletions src/protocols/secure_channel/CASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ CASESession::~CASESession()
Clear();
}

void CASESession::OnSessionReleased()
{
// TODO: interrupt pairing procedure, call OnSessionEstablishmentError, if the pairing is not finished
}

void CASESession::Finish()
{
mCASESessionEstablished = true;
Expand Down
3 changes: 3 additions & 0 deletions src/protocols/secure_channel/CASESession.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler,
void OnResponseTimeout(Messaging::ExchangeContext * ec) override;
Messaging::ExchangeMessageDispatch & GetMessageDispatch() override { return SessionEstablishmentExchangeDispatch::Instance(); }

//// SessionReleaseDelegate ////
void OnSessionReleased() override;

FabricIndex GetFabricIndex() const { return mFabricInfo != nullptr ? mFabricInfo->GetFabricIndex() : kUndefinedFabricIndex; }

// TODO: remove Clear, we should create a new instance instead reset the old instance.
Expand Down
5 changes: 5 additions & 0 deletions src/protocols/secure_channel/PASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ PASESession::~PASESession()
Clear();
}

void PASESession::OnSessionReleased()
{
// TODO: interrupt pairing procedure, call OnSessionEstablishmentError, if the pairing is not finished
}

void PASESession::Finish()
{
mPairingComplete = true;
Expand Down
3 changes: 3 additions & 0 deletions src/protocols/secure_channel/PASESession.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ class DLL_EXPORT PASESession : public Messaging::UnsolicitedMessageHandler,

Messaging::ExchangeMessageDispatch & GetMessageDispatch() override { return SessionEstablishmentExchangeDispatch::Instance(); }

//// SessionReleaseDelegate ////
void OnSessionReleased() override;

private:
enum Spake2pErrorType : uint8_t
{
Expand Down
23 changes: 10 additions & 13 deletions src/protocols/secure_channel/tests/TestCASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,12 @@ CHIP_ERROR InitCredentialSets()

void CASE_SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;

// Test all combinations of invalid parameters
TestCASESecurePairingDelegate delegate;
CASESession pairing;
FabricTable fabrics;
SessionManager sessionManager;

NL_TEST_ASSERT(inSuite, pairing.GetSecureSessionType() == SecureSession::Type::kCASE);

Expand All @@ -202,6 +203,7 @@ void CASE_SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext)
void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
SessionManager sessionManager;

// Test all combinations of invalid parameters
TestCASESecurePairingDelegate delegate;
Expand All @@ -210,7 +212,6 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)

FabricInfo * fabric = gCommissionerFabrics.FindFabricWithIndex(gCommissionerFabricIndex);
NL_TEST_ASSERT(inSuite, fabric != nullptr);
SessionManager sessionManager;

ExchangeContext * context = ctx.NewUnauthenticatedExchangeToBob(&pairing);

Expand Down Expand Up @@ -250,15 +251,19 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)
gLoopback.mMessageSendError = CHIP_NO_ERROR;
}

void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, CASESession & pairingCommissioner,
TestCASESecurePairingDelegate & delegateCommissioner)

void CASE_SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
SessionManager sessionManager;

TestCASESecurePairingDelegate delegateCommissioner;
CASESession pairingCommissioner;
pairingCommissioner.SetGroupDataProvider(&gCommissionerGroupDataProvider);

// Test all combinations of invalid parameters
TestCASESecurePairingDelegate delegateAccessory;
CASESession pairingAccessory;
SessionManager sessionManager;

gLoopback.mSentMessageCount = 0;

Expand All @@ -285,14 +290,6 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte
NL_TEST_ASSERT(inSuite, delegateCommissioner.mNumPairingComplete == 1);
}

void CASE_SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext)
{
TestCASESecurePairingDelegate delegateCommissioner;
CASESession pairingCommissioner;
pairingCommissioner.SetGroupDataProvider(&gCommissionerGroupDataProvider);
CASE_SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, delegateCommissioner);
}

CASEServerForTest gPairingServer;

void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inContext)
Expand Down
26 changes: 15 additions & 11 deletions src/protocols/secure_channel/tests/TestPASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ using namespace System::Clock::Literals;
void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
SessionManager sessionManager;

// Test all combinations of invalid parameters
TestSecurePairingDelegate delegate;
PASESession pairing;
SessionManager sessionManager;

NL_TEST_ASSERT(inSuite, pairing.GetSecureSessionType() == SecureSession::Type::kPASE);

Expand Down Expand Up @@ -156,11 +156,11 @@ void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext)
void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
SessionManager sessionManager;

// Test all combinations of invalid parameters
TestSecurePairingDelegate delegate;
PASESession pairing;
SessionManager sessionManager;

gLoopback.Reset();

Expand Down Expand Up @@ -197,7 +197,8 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)
gLoopback.mMessageSendError = CHIP_NO_ERROR;
}

void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, PASESession & pairingCommissioner,
void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, SessionManager & sessionManager,
PASESession & pairingCommissioner,
Optional<ReliableMessageProtocolConfig> mrpCommissionerConfig,
Optional<ReliableMessageProtocolConfig> mrpAccessoryConfig,
TestSecurePairingDelegate & delegateCommissioner)
Expand All @@ -206,7 +207,6 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P

TestSecurePairingDelegate delegateAccessory;
PASESession pairingAccessory;
SessionManager sessionManager;

gLoopback.mSentMessageCount = 0;

Expand Down Expand Up @@ -279,53 +279,58 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P

void SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;
gLoopback.Reset();
SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, Optional<ReliableMessageProtocolConfig>::Missing(),
SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, Optional<ReliableMessageProtocolConfig>::Missing(),
Optional<ReliableMessageProtocolConfig>::Missing(), delegateCommissioner);
}

void SecurePairingHandshakeWithCommissionerMRPTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;
gLoopback.Reset();
ReliableMessageProtocolConfig config(1000_ms32, 10000_ms32);
SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner,
SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner,
Optional<ReliableMessageProtocolConfig>::Value(config),
Optional<ReliableMessageProtocolConfig>::Missing(), delegateCommissioner);
}

void SecurePairingHandshakeWithDeviceMRPTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;
gLoopback.Reset();
ReliableMessageProtocolConfig config(1000_ms32, 10000_ms32);
SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, Optional<ReliableMessageProtocolConfig>::Missing(),
SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, Optional<ReliableMessageProtocolConfig>::Missing(),
Optional<ReliableMessageProtocolConfig>::Value(config), delegateCommissioner);
}

void SecurePairingHandshakeWithAllMRPTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;
gLoopback.Reset();
ReliableMessageProtocolConfig commissionerConfig(1000_ms32, 10000_ms32);
ReliableMessageProtocolConfig deviceConfig(2000_ms32, 7000_ms32);
SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner,
SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner,
Optional<ReliableMessageProtocolConfig>::Value(commissionerConfig),
Optional<ReliableMessageProtocolConfig>::Value(deviceConfig), delegateCommissioner);
}

void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;
gLoopback.Reset();
gLoopback.mNumMessagesToDrop = 2;
SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, Optional<ReliableMessageProtocolConfig>::Missing(),
SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, Optional<ReliableMessageProtocolConfig>::Missing(),
Optional<ReliableMessageProtocolConfig>::Missing(), delegateCommissioner);
NL_TEST_ASSERT(inSuite, gLoopback.mDroppedMessageCount == 2);
NL_TEST_ASSERT(inSuite, gLoopback.mNumMessagesToDrop == 0);
Expand All @@ -334,15 +339,14 @@ void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inCo
void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
SessionManager sessionManager;

TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;

TestSecurePairingDelegate delegateAccessory;
PASESession pairingAccessory;

SessionManager sessionManager;

gLoopback.Reset();
gLoopback.mSentMessageCount = 0;

Expand Down
23 changes: 19 additions & 4 deletions src/transport/GroupSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,28 @@

#include <app/util/basic-types.h>
#include <lib/core/GroupId.h>
#include <lib/core/ReferenceCounted.h>
#include <lib/support/Pool.h>
#include <transport/Session.h>

namespace chip {
namespace Transport {

class IncomingGroupSession : public Session
class IncomingGroupSession : public Session, public ReferenceCounted<IncomingGroupSession, NoopDeletor<IncomingGroupSession>, 0>
{
public:
IncomingGroupSession(GroupId group, FabricIndex fabricIndex, NodeId sourceNodeId) : mGroupId(group), mSourceNodeId(sourceNodeId)
{
SetFabricIndex(fabricIndex);
}
~IncomingGroupSession() override { NotifySessionReleased(); }
~IncomingGroupSession() override
{
NotifySessionReleased();
VerifyOrDie(GetReferenceCount() == 0);
}

void Retain() override { ReferenceCounted<IncomingGroupSession, NoopDeletor<IncomingGroupSession>, 0>::Retain(); }
void Release() override { ReferenceCounted<IncomingGroupSession, NoopDeletor<IncomingGroupSession>, 0>::Release(); }

Session::SessionType GetSessionType() const override { return Session::SessionType::kGroupIncoming; }
#if CHIP_PROGRESS_LOGGING
Expand Down Expand Up @@ -75,11 +83,18 @@ class IncomingGroupSession : public Session
const NodeId mSourceNodeId;
};

class OutgoingGroupSession : public Session
class OutgoingGroupSession : public Session, public ReferenceCounted<OutgoingGroupSession, NoopDeletor<OutgoingGroupSession>, 0>
{
public:
OutgoingGroupSession(GroupId group, FabricIndex fabricIndex) : mGroupId(group) { SetFabricIndex(fabricIndex); }
~OutgoingGroupSession() override { NotifySessionReleased(); }
~OutgoingGroupSession() override
{
NotifySessionReleased();
VerifyOrDie(GetReferenceCount() == 0);
}

void Retain() override { ReferenceCounted<OutgoingGroupSession, NoopDeletor<OutgoingGroupSession>, 0>::Retain(); }
void Release() override { ReferenceCounted<OutgoingGroupSession, NoopDeletor<OutgoingGroupSession>, 0>::Release(); }

Session::SessionType GetSessionType() const override { return Session::SessionType::kGroupOutgoing; }
#if CHIP_PROGRESS_LOGGING
Expand Down
19 changes: 2 additions & 17 deletions src/transport/PairingSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@ namespace chip {

CHIP_ERROR PairingSession::AllocateSecureSession(SessionManager & sessionManager)
{
auto handle = sessionManager.AllocateSession();
auto handle = sessionManager.AllocateSession(GetSecureSessionType());
VerifyOrReturnError(handle.HasValue(), CHIP_ERROR_NO_MEMORY);
mSecureSessionHolder.Grab(handle.Value());
mSessionManager = &sessionManager;
return CHIP_NO_ERROR;
}

Expand All @@ -39,7 +38,7 @@ CHIP_ERROR PairingSession::ActivateSecureSession(const Transport::PeerAddress &
uint16_t peerSessionId = GetPeerSessionId();
ChipLogDetail(Inet, "New secure session created for device " ChipLogFormatScopedNodeId ", LSID:%d PSID:%d!",
ChipLogValueScopedNodeId(GetPeer()), secureSession->GetLocalSessionId(), peerSessionId);
secureSession->Activate(GetSecureSessionType(), GetPeer(), GetPeerCATs(), peerSessionId, mRemoteMRPConfig);
secureSession->Activate(GetPeer(), GetPeerCATs(), peerSessionId, mRemoteMRPConfig);
secureSession->SetPeerAddress(peerAddress);
ReturnErrorOnFailure(DeriveSecureSession(secureSession->GetCryptoContext()));
secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(LocalSessionMessageCounter::kInitialSyncValue);
Expand Down Expand Up @@ -102,22 +101,8 @@ CHIP_ERROR PairingSession::DecodeMRPParametersIfPresent(TLV::Tag expectedTag, TL

void PairingSession::Clear()
{
if (mSessionManager != nullptr)
{
if (mSecureSessionHolder && !mSecureSessionHolder->AsSecureSession()->IsActiveSession())
{
// Make sure to clean up our pending session, since we're the only
// ones who have access to it do do so.
mSessionManager->ExpirePairing(mSecureSessionHolder.Get());
}
}

mPeerSessionId.ClearValue();
// If we called ExpirePairing above, the holder has already released the
// session (due to it being destroyed). If not, we need to release it.
// Release is idempotent, so it's OK to just call it here.
mSecureSessionHolder.Release();
mSessionManager = nullptr;
}

} // namespace chip
8 changes: 3 additions & 5 deletions src/transport/PairingSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ namespace chip {

class SessionManager;

class DLL_EXPORT PairingSession
class DLL_EXPORT PairingSession : public SessionReleaseDelegate
{
public:
PairingSession() : mSecureSessionHolder(*this) {}
virtual ~PairingSession() { Clear(); }

virtual Transport::SecureSession::Type GetSecureSessionType() const = 0;
Expand Down Expand Up @@ -164,10 +165,7 @@ class DLL_EXPORT PairingSession

protected:
CryptoContext::SessionRole mRole;
SessionHolder mSecureSessionHolder;
// mSessionManager is set if we actually allocate a secure session, so we
// can clean it up later as needed.
SessionManager * mSessionManager = nullptr;
SessionHolderWithDelegate mSecureSessionHolder;

// mLocalMRPConfig is our config which is sent to the other end and used by the peer session.
// mRemoteMRPConfig is received from other end and set to our session.
Expand Down
6 changes: 6 additions & 0 deletions src/transport/SecureSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,16 @@

#include <access/AuthMode.h>
#include <transport/SecureSession.h>
#include <transport/SecureSessionTable.h>

namespace chip {
namespace Transport {

void SecureSessionDeleter::Release(SecureSession * entry)
{
entry->mTable.ReleaseSession(entry);
}

ScopedNodeId SecureSession::GetPeer() const
{
return ScopedNodeId(mPeerNodeId, GetFabricIndex());
Expand Down
Loading

0 comments on commit 98a4066

Please sign in to comment.