Skip to content

Commit

Permalink
Use refcounter for secure session
Browse files Browse the repository at this point in the history
  • Loading branch information
kghost committed May 18, 2022
1 parent a777a80 commit c15a17e
Show file tree
Hide file tree
Showing 16 changed files with 230 additions and 290 deletions.
17 changes: 11 additions & 6 deletions src/lib/core/ReferenceCounted.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,26 @@ class DeleteDeletor
static void Release(T * obj) { chip::Platform::Delete(obj); }
};

template <class T>
class NoopDeletor
{
public:
static void Release(T * obj) {}
};

/**
* A reference counted object maintains a count of usages and when the usage
* count drops to 0, it deletes itself.
*/
template <class Subclass, class Deletor = DeleteDeletor<Subclass>, int kInitRefCount = 1>
template <class Subclass, class Deletor = DeleteDeletor<Subclass>, int kInitRefCount = 1, typename CounterType = uint32_t>
class ReferenceCounted
{
public:
using count_type = uint32_t;

/** Adds one to the usage count of this class */
Subclass * Retain()
{
VerifyOrDie(!kInitRefCount || mRefCount > 0);
VerifyOrDie(mRefCount < std::numeric_limits<count_type>::max());
VerifyOrDie(mRefCount < std::numeric_limits<CounterType>::max());
++mRefCount;

return static_cast<Subclass *>(this);
Expand All @@ -71,10 +76,10 @@ class ReferenceCounted
}

/** Get the current reference counter value */
count_type GetReferenceCount() const { return mRefCount; }
CounterType GetReferenceCount() const { return mRefCount; }

private:
count_type mRefCount = kInitRefCount;
CounterType mRefCount = kInitRefCount;
};

} // namespace chip
5 changes: 4 additions & 1 deletion src/messaging/ExchangeMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ CHIP_ERROR ExchangeManager::Shutdown()

ExchangeContext * ExchangeManager::NewContext(const SessionHandle & session, ExchangeDelegate * delegate)
{
// Disallow creating exchange on an inactive session
VerifyOrReturnError(session->IsActiveSession(), nullptr);
return mContextPool.CreateObject(this, mNextExchangeId++, session, true, delegate);
}

Expand Down Expand Up @@ -230,10 +232,11 @@ void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const
packetHeader.GetDestinationGroupId().Value());
}

// Do not handle unsolicited messages on a inactive session.
// If it's not a duplicate message, search for an unsolicited message handler if it is marked as being sent by an initiator.
// Since we didn't find an existing exchange that matches the message, it must be an unsolicited message. However all
// unsolicited messages must be marked as being from an initiator.
if (!msgFlags.Has(MessageFlagValues::kDuplicateMessage) && payloadHeader.IsInitiator())
if (session->IsActiveSession() && !msgFlags.Has(MessageFlagValues::kDuplicateMessage) && payloadHeader.IsInitiator())
{
// Search for an unsolicited message handler that can handle the message. Prefer handlers that can explicitly
// handle the message type over handlers that handle all messages for a profile.
Expand Down
21 changes: 4 additions & 17 deletions src/protocols/secure_channel/PairingSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ namespace chip {

CHIP_ERROR PairingSession::AllocateSecureSession(SessionManager & sessionManager)
{
auto handle = sessionManager.AllocateSession();
auto handle = sessionManager.AllocateSession(GetSecureSessionType());
VerifyOrReturnError(handle.HasValue(), CHIP_ERROR_NO_MEMORY);
mSecureSessionHolder.Grab(handle.Value());
mSecureSessionHolder.GrabPairing(handle.Value());
mSessionManager = &sessionManager;
return CHIP_NO_ERROR;
}
Expand All @@ -48,8 +48,7 @@ CHIP_ERROR PairingSession::ActivateSecureSession(const Transport::PeerAddress &

// Call Activate last, otherwise errors on anything after would lead to
// a partially valid session.
secureSession->Activate(GetSecureSessionType(), GetLocalScopedNodeId(), GetPeer(), GetPeerCATs(), peerSessionId,
mRemoteMRPConfig);
secureSession->Activate(GetLocalScopedNodeId(), GetPeer(), GetPeerCATs(), peerSessionId, mRemoteMRPConfig);

ChipLogDetail(Inet, "New secure session created for device " ChipLogFormatScopedNodeId ", LSID:%d PSID:%d!",
ChipLogValueScopedNodeId(GetPeer()), secureSession->GetLocalSessionId(), peerSessionId);
Expand Down Expand Up @@ -153,19 +152,7 @@ void PairingSession::Clear()
mExchangeCtxt = nullptr;
}

if (mSecureSessionHolder)
{
SessionHandle session = mSecureSessionHolder.Get();
// Call Release before ExpirePairing because we don't want to receive OnSessionReleased() event here
mSecureSessionHolder.Release();
if (!session->AsSecureSession()->IsActiveSession() && mSessionManager != nullptr)
{
// Make sure to clean up our pending session, since we're the only
// ones who have access to it do do so.
mSessionManager->ExpirePairing(session);
}
}

mSecureSessionHolder.Release();
mPeerSessionId.ClearValue();
mSessionManager = nullptr;
}
Expand Down
23 changes: 19 additions & 4 deletions src/transport/GroupSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,28 @@

#include <app/util/basic-types.h>
#include <lib/core/GroupId.h>
#include <lib/core/ReferenceCounted.h>
#include <lib/support/Pool.h>
#include <transport/Session.h>

namespace chip {
namespace Transport {

class IncomingGroupSession : public Session
class IncomingGroupSession : public Session, public ReferenceCounted<IncomingGroupSession, NoopDeletor<IncomingGroupSession>, 0>
{
public:
IncomingGroupSession(GroupId group, FabricIndex fabricIndex, NodeId peerNodeId) : mGroupId(group), mPeerNodeId(peerNodeId)
{
SetFabricIndex(fabricIndex);
}
~IncomingGroupSession() override { NotifySessionReleased(); }
~IncomingGroupSession() override
{
NotifySessionReleased();
VerifyOrDie(GetReferenceCount() == 0);
}

void Retain() override { ReferenceCounted<IncomingGroupSession, NoopDeletor<IncomingGroupSession>, 0>::Retain(); }
void Release() override { ReferenceCounted<IncomingGroupSession, NoopDeletor<IncomingGroupSession>, 0>::Release(); }

Session::SessionType GetSessionType() const override { return Session::SessionType::kGroupIncoming; }
#if CHIP_PROGRESS_LOGGING
Expand Down Expand Up @@ -74,11 +82,18 @@ class IncomingGroupSession : public Session
const NodeId mPeerNodeId;
};

class OutgoingGroupSession : public Session
class OutgoingGroupSession : public Session, public ReferenceCounted<OutgoingGroupSession, NoopDeletor<OutgoingGroupSession>, 0>
{
public:
OutgoingGroupSession(GroupId group, FabricIndex fabricIndex) : mGroupId(group) { SetFabricIndex(fabricIndex); }
~OutgoingGroupSession() override { NotifySessionReleased(); }
~OutgoingGroupSession() override
{
NotifySessionReleased();
VerifyOrDie(GetReferenceCount() == 0);
}

void Retain() override { ReferenceCounted<OutgoingGroupSession, NoopDeletor<OutgoingGroupSession>, 0>::Retain(); }
void Release() override { ReferenceCounted<OutgoingGroupSession, NoopDeletor<OutgoingGroupSession>, 0>::Release(); }

Session::SessionType GetSessionType() const override { return Session::SessionType::kGroupOutgoing; }
#if CHIP_PROGRESS_LOGGING
Expand Down
29 changes: 29 additions & 0 deletions src/transport/SecureSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,39 @@

#include <access/AuthMode.h>
#include <transport/SecureSession.h>
#include <transport/SecureSessionTable.h>

namespace chip {
namespace Transport {

void SecureSessionDeleter::Release(SecureSession * entry)
{
entry->mTable.ReleaseSession(entry);
}

void SecureSession::MarkForRemoval()
{
ChipLogDetail(Inet, "SecureSession MarkForRemoval %p Type:%d LSID:%d", this, to_underlying(mSecureSessionType),
mLocalSessionId);
ReferenceCountedHandle<Transport::Session> ref(*this);
switch (mState)
{
case State::kPairing:
mState = State::kPendingRemoval;
// Interrupt the pairing
NotifySessionReleased();
return;
case State::kActive:
Release(); // Decrease the ref which is retained at Activate
mState = State::kPendingRemoval;
NotifySessionReleased();
return;
case State::kPendingRemoval:
// Do nothing
return;
}
}

Access::SubjectDescriptor SecureSession::GetSubjectDescriptor() const
{
Access::SubjectDescriptor subjectDescriptor;
Expand Down
Loading

0 comments on commit c15a17e

Please sign in to comment.