Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for re-using CASE sessions #17266

Merged
merged 4 commits into from
Apr 15, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions examples/tv-casting-app/linux/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,21 +328,14 @@ class TargetVideoPlayerInfo
.clientPool = &gCASEClientPool,
};

PeerId peerID = fabric->GetPeerIdForNode(nodeId);
mOperationalDeviceProxy = chip::Platform::New<chip::OperationalDeviceProxy>(initParams, peerID);
PeerId peerID = fabric->GetPeerIdForNode(nodeId);

// TODO: figure out why this doesn't work so that we can remove OperationalDeviceProxy creation above,
// and remove the FindSecureSessionForNode and SetConnectedSession calls below
// mOperationalDeviceProxy = server->GetCASESessionManager()->FindExistingSession(nodeId);
mOperationalDeviceProxy = server->GetCASESessionManager()->FindExistingSession(peerID);
mrjerryjohns marked this conversation as resolved.
Show resolved Hide resolved
if (mOperationalDeviceProxy == nullptr)
{
ChipLogError(AppServer, "Failed in creating an instance of OperationalDeviceProxy");
ChipLogError(AppServer, "Failed to find an existing instance of OperationalDeviceProxy to the peer");
return CHIP_ERROR_INVALID_ARGUMENT;
}
ChipLogError(AppServer, "Created an instance of OperationalDeviceProxy");

SessionHandle handle = server->GetSecureSessionManager().FindSecureSessionForNode(nodeId);
mOperationalDeviceProxy->SetConnectedSession(handle);

mInitialized = true;
return CHIP_NO_ERROR;
Expand Down
2 changes: 1 addition & 1 deletion src/app/CASESessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ CHIP_ERROR CASESessionManager::FindOrEstablishSession(PeerId peerId, Callback::C
OperationalDeviceProxy * session = FindExistingSession(peerId);
if (session == nullptr)
{
ChipLogDetail(CASESessionManager, "FindOrEstablishSession: No existing session found");
ChipLogDetail(CASESessionManager, "FindOrEstablishSession: No existing OperationalDeviceProxy instance found");

session = mConfig.devicePool->Allocate(mConfig.sessionInitParams, peerId);

Expand Down
132 changes: 86 additions & 46 deletions src/app/OperationalDeviceProxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "CASEClient.h"
#include "CommandSender.h"
#include "ReadPrepareParams.h"
#include "transport/SecureSession.h"

#include <lib/address_resolve/AddressResolve.h>
#include <lib/core/CHIPCore.h>
Expand Down Expand Up @@ -57,10 +58,36 @@ void OperationalDeviceProxy::MoveToState(State aTargetState)
}
}

bool OperationalDeviceProxy::CheckAndLoadExistingSession()
mrjerryjohns marked this conversation as resolved.
Show resolved Hide resolved
mrjerryjohns marked this conversation as resolved.
Show resolved Hide resolved
{
VerifyOrReturnError(mState == State::NeedsAddress || mState == State::Initialized, false);

SessionHolder existingSession;
mrjerryjohns marked this conversation as resolved.
Show resolved Hide resolved
ScopedNodeId peerNodeId(mPeerId.GetNodeId(), mFabricInfo->GetFabricIndex());

mInitParams.sessionManager->FindSecureSessionForNode(mSecureSession, peerNodeId, Transport::SecureSession::Type::kCASE);
if (mSecureSession)
{
ChipLogProgress(Controller, "Found an existing secure session to [" ChipLogFormatX64 ":" ChipLogFormatX64 "]!",
mrjerryjohns marked this conversation as resolved.
Show resolved Hide resolved
ChipLogValueX64(mPeerId.GetCompressedFabricId()), ChipLogValueX64(mPeerId.GetNodeId()));
return true;
}

return false;
}

CHIP_ERROR OperationalDeviceProxy::Connect(Callback::Callback<OnDeviceConnected> * onConnection,
Callback::Callback<OnDeviceConnectionFailure> * onFailure)
{
CHIP_ERROR err = CHIP_NO_ERROR;
CHIP_ERROR err = CHIP_NO_ERROR;
bool isConnected = false;

//
// Always enqueue our user provided callbacks into our callback list.
// If anything goes wrong below, we'll trigger failures (including any queued from
// a previous iteration which in theory shouldn't happen, but this is written to be more defensive)
//
EnqueueConnectionCallbacks(onConnection, onFailure);

switch (mState)
{
Expand All @@ -69,35 +96,47 @@ CHIP_ERROR OperationalDeviceProxy::Connect(Callback::Callback<OnDeviceConnected>
break;

case State::NeedsAddress:
err = LookupPeerAddress();
EnqueueConnectionCallbacks(onConnection, onFailure);
isConnected = CheckAndLoadExistingSession();
if (!isConnected)
{
err = LookupPeerAddress();
}

break;

case State::Initialized:
err = EstablishConnection();
if (err == CHIP_NO_ERROR)
isConnected = CheckAndLoadExistingSession();
if (!isConnected)
{
EnqueueConnectionCallbacks(onConnection, onFailure);
err = EstablishConnection();
}

break;

case State::Connecting:
EnqueueConnectionCallbacks(onConnection, onFailure);
break;

case State::SecureConnected:
if (onConnection != nullptr)
{
onConnection->mCall(onConnection->mContext, this);
}
isConnected = true;
break;

default:
err = CHIP_ERROR_INCORRECT_STATE;
}

if (err != CHIP_NO_ERROR && onFailure != nullptr)
if (isConnected)
{
MoveToState(State::SecureConnected);
}

//
// Dequeue all our callbacks on either encountering an error
// or if we successfully connected. Both should not be set
// simultaneously.
//
if (err != CHIP_NO_ERROR || isConnected)
{
onFailure->mCall(onFailure->mContext, mPeerId, err);
DequeueConnectionCallbacks(err);
}

return err;
Expand Down Expand Up @@ -133,7 +172,7 @@ CHIP_ERROR OperationalDeviceProxy::UpdateDeviceData(const Transport::PeerAddress
err = EstablishConnection();
if (err != CHIP_NO_ERROR)
{
OnSessionEstablishmentError(err);
DequeueConnectionCallbacks(err);
}
}
else
Expand Down Expand Up @@ -194,35 +233,37 @@ void OperationalDeviceProxy::EnqueueConnectionCallbacks(Callback::Callback<OnDev
}
}

void OperationalDeviceProxy::DequeueConnectionSuccessCallbacks(bool executeCallback)
void OperationalDeviceProxy::DequeueConnectionCallbacks(CHIP_ERROR error)
{
Cancelable ready;
mConnectionSuccess.DequeueAll(ready);
mConnectionFailure.DequeueAll(ready);
mrjerryjohns marked this conversation as resolved.
Show resolved Hide resolved

//
// If we encountered no error, go ahead and call all success callbacks. Otherwise,
// call the failure callbacks.
//
while (ready.mNext != &ready)
{
Callback::Callback<OnDeviceConnected> * cb = Callback::Callback<OnDeviceConnected>::FromCancelable(ready.mNext);
Callback::Callback<OnDeviceConnectionFailure> * cb =
Callback::Callback<OnDeviceConnectionFailure>::FromCancelable(ready.mNext);

cb->Cancel();
if (executeCallback)

if (error != CHIP_NO_ERROR)
{
cb->mCall(cb->mContext, this);
cb->mCall(cb->mContext, mPeerId, error);
}
}
}

void OperationalDeviceProxy::DequeueConnectionFailureCallbacks(CHIP_ERROR error, bool executeCallback)
{
Cancelable ready;
mConnectionFailure.DequeueAll(ready);
mConnectionSuccess.DequeueAll(ready);
while (ready.mNext != &ready)
{
Callback::Callback<OnDeviceConnectionFailure> * cb =
Callback::Callback<OnDeviceConnectionFailure>::FromCancelable(ready.mNext);
Callback::Callback<OnDeviceConnected> * cb = Callback::Callback<OnDeviceConnected>::FromCancelable(ready.mNext);

cb->Cancel();
if (executeCallback)
if (error == CHIP_NO_ERROR)
{
cb->mCall(cb->mContext, mPeerId, error);
cb->mCall(cb->mContext, this);
}
}
}
Expand All @@ -234,13 +275,20 @@ void OperationalDeviceProxy::HandleCASEConnectionFailure(void * context, CASECli
ChipLogError(Controller, "HandleCASEConnectionFailure was called while the device was not initialized"));
VerifyOrReturn(client == device->mCASEClient, ChipLogError(Controller, "HandleCASEConnectionFailure for unknown CASEClient"));

//
// We don't need to reset the state all the way back to NeedsAddress since all that transpired
// was just CASE connection failure. So let's re-use the cached address to re-do CASE again
// if need-be.
//
device->MoveToState(State::Initialized);

device->CloseCASESession();
device->DequeueConnectionSuccessCallbacks(/* executeCallback */ false);
device->DequeueConnectionFailureCallbacks(error, /* executeCallback */ true);
// Do not touch device anymore; it might have been destroyed by a failure
device->DequeueConnectionCallbacks(error);

//
// Do not touch this instance anymore; it might have been destroyed by a failure
mrjerryjohns marked this conversation as resolved.
Show resolved Hide resolved
// callback.
//
}

void OperationalDeviceProxy::HandleCASEConnected(void * context, CASEClient * client)
Expand All @@ -254,19 +302,18 @@ void OperationalDeviceProxy::HandleCASEConnected(void * context, CASEClient * cl
if (err != CHIP_NO_ERROR)
{
device->HandleCASEConnectionFailure(context, client, err);
// Do not touch device anymore; it might have been destroyed by a
// HandleCASEConnectionFailure.
}
else
{
device->MoveToState(State::SecureConnected);

device->CloseCASESession();
device->DequeueConnectionFailureCallbacks(CHIP_NO_ERROR, /* executeCallback */ false);
device->DequeueConnectionSuccessCallbacks(/* executeCallback */ true);
// Do not touch device anymore; it might have been destroyed by a
// success callback.
device->DequeueConnectionCallbacks(CHIP_NO_ERROR);
}

//
// Do not touch this instance anymore; it might have been destroyed by a failure
// callback.
mrjerryjohns marked this conversation as resolved.
Show resolved Hide resolved
//
}

CHIP_ERROR OperationalDeviceProxy::Disconnect()
Expand All @@ -285,12 +332,6 @@ CHIP_ERROR OperationalDeviceProxy::Disconnect()
return CHIP_NO_ERROR;
}

void OperationalDeviceProxy::SetConnectedSession(const SessionHandle & handle)
{
mSecureSession.Grab(handle);
MoveToState(State::SecureConnected);
}

void OperationalDeviceProxy::Clear()
{
if (mCASEClient)
Expand Down Expand Up @@ -367,8 +408,7 @@ void OperationalDeviceProxy::OnNodeAddressResolutionFailed(const PeerId & peerId
ChipLogError(Discovery, "Operational discovery failed for 0x" ChipLogFormatX64 ": %" CHIP_ERROR_FORMAT,
ChipLogValueX64(peerId.GetNodeId()), reason.Format());

DequeueConnectionSuccessCallbacks(/* executeCallback */ false);
DequeueConnectionFailureCallbacks(reason, /* executeCallback */ true);
DequeueConnectionCallbacks(reason);
}

} // namespace chip
33 changes: 22 additions & 11 deletions src/app/OperationalDeviceProxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy,
{
public:
~OperationalDeviceProxy() override;

//
// TODO: Should not be PeerId, but rather, ScopedNodeId
//
OperationalDeviceProxy(DeviceProxyInitParams & params, PeerId peerId) : mSecureSession(*this)
{
mInitParams = params;
Expand Down Expand Up @@ -159,15 +163,6 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy,
*/
CHIP_ERROR Disconnect() override;

/**
* Use SetConnectedSession if 'this' object is a newly allocated device proxy.
* It will take an existing session, such as the one established
* during commissioning, and use it for this device proxy.
*
* Note: Avoid using this function generally as it is Deprecated
*/
void SetConnectedSession(const SessionHandle & handle);

NodeId GetDeviceId() const override { return mPeerId.GetNodeId(); }

/**
Expand Down Expand Up @@ -268,6 +263,15 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy,

CHIP_ERROR EstablishConnection();

/*
* This checks to see if an existing CASE session exists to the peer within the SessionManager
mrjerryjohns marked this conversation as resolved.
Show resolved Hide resolved
* and if one exists, to load that into mSecureSession.
*
* Returns true if a valid session was found, false otherwise.
*
*/
bool CheckAndLoadExistingSession();

bool IsSecureConnected() const override { return mState == State::SecureConnected; }

static void HandleCASEConnected(void * context, CASEClient * client);
Expand All @@ -280,8 +284,15 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy,
void EnqueueConnectionCallbacks(Callback::Callback<OnDeviceConnected> * onConnection,
Callback::Callback<OnDeviceConnectionFailure> * onFailure);

void DequeueConnectionSuccessCallbacks(bool executeCallback);
void DequeueConnectionFailureCallbacks(CHIP_ERROR error, bool executeCallback);
/*
* This dequeues all failure and success callbacks and appropriately
* invokes either set depending on the value of error.
*
* If error == CHIP_NO_ERROR, only success callbacks are invoked.
* Otherwise, only failure callbacks are invoked.
*
*/
void DequeueConnectionCallbacks(CHIP_ERROR error);
};

} // namespace chip
16 changes: 11 additions & 5 deletions src/transport/SessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -814,20 +814,26 @@ void SessionManager::ExpiryTimerCallback(System::Layer * layer, void * param)
mgr->ScheduleExpiryTimer(); // re-schedule the oneshot timer
}

SessionHandle SessionManager::FindSecureSessionForNode(NodeId peerNodeId)
void SessionManager::FindSecureSessionForNode(SessionHolder & sessionHolder, ScopedNodeId peerNodeId,
Transport::SecureSession::Type type)
{
SecureSession * found = nullptr;
mSecureSessions.ForEachSession([&](auto session) {
if (session->GetPeerNodeId() == peerNodeId)
mSecureSessions.ForEachSession([&peerNodeId, type, &found](auto session) {
if (session->GetPeer() == peerNodeId &&
(type == SecureSession::Type::kUndefined || type == session->GetSecureSessionType()))
{
found = session;
return Loop::Break;
}
return Loop::Continue;
});

VerifyOrDie(found != nullptr);
return SessionHandle(*found);
sessionHolder.Release();

if (found)
{
sessionHolder.Grab(SessionHandle(*found));
}
}

/**
Expand Down
9 changes: 6 additions & 3 deletions src/transport/SessionManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,12 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate
return mUnauthenticatedSessions.AllocInitiator(ephemeralInitiatorNodeID, peerAddress, config);
}

// TODO: this is a temporary solution for legacy tests which use nodeId to send packets
// and tv-casting-app that uses the TV's node ID to find the associated secure session
SessionHandle FindSecureSessionForNode(NodeId peerNodeId);
//
// Find an existing secure session given a peer's scoped NodeId and a type of session to match against.
// If matching against all types of sessions is desired, kUnDefined should be passed into type.
mrjerryjohns marked this conversation as resolved.
Show resolved Hide resolved
//
void FindSecureSessionForNode(SessionHolder & sessionHolder, ScopedNodeId peerNodeId,
mrjerryjohns marked this conversation as resolved.
Show resolved Hide resolved
Transport::SecureSession::Type type = Transport::SecureSession::Type::kUndefined);

using SessionHandleCallback = bool (*)(void * context, SessionHandle & sessionHandle);
CHIP_ERROR ForEachSessionHandle(void * context, SessionHandleCallback callback);
Expand Down