Skip to content

Commit

Permalink
Use exchange to send CASE session establishment messages (#6271)
Browse files Browse the repository at this point in the history
* Send CASE messages using exchange

* cleanup

* enable CASESession unit test

* fix test

* remove local node ID

* Address review comments

* Fix test failures

* Update src/protocols/secure_channel/CASESession.cpp

Co-authored-by: Boris Zbarsky <bzbarsky@apple.com>

Co-authored-by: Boris Zbarsky <bzbarsky@apple.com>
  • Loading branch information
2 people authored and pull[bot] committed Sep 16, 2021
1 parent 446598d commit 21e81b2
Show file tree
Hide file tree
Showing 10 changed files with 289 additions and 179 deletions.
11 changes: 6 additions & 5 deletions src/channel/ChannelContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,22 +261,23 @@ CHIP_ERROR ChannelContext::SendSessionEstablishmentMessage(const PacketHeader &
CHIP_ERROR ChannelContext::HandlePairingMessage(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress,
System::PacketBufferHandle && msg)
{
if (IsCasePairing())
return mStateVars.mPreparing.mCasePairingSession->HandlePeerMessage(packetHeader, peerAddress, std::move(msg));
return CHIP_ERROR_INCORRECT_STATE;
}

void ChannelContext::EnterCasePairingState()
{
mStateVars.mPreparing.mState = PrepareState::kCasePairing;
mStateVars.mPreparing.mCasePairingSession = Platform::New<CASESession>();

ExchangeContext * ctxt = mExchangeManager->NewContext(SecureSessionHandle(), mStateVars.mPreparing.mCasePairingSession);
VerifyOrReturn(ctxt != nullptr);

// TODO: currently only supports IP/UDP paring
Transport::PeerAddress addr;
addr.SetTransportType(Transport::Type::kUdp).SetIPAddress(mStateVars.mPreparing.mAddress);
CHIP_ERROR err = mStateVars.mPreparing.mCasePairingSession->EstablishSession(
addr, &mStateVars.mPreparing.mBuilder.GetOperationalCredentialSet(),
Optional<NodeId>::Value(mExchangeManager->GetSessionMgr()->GetLocalNodeId()),
mStateVars.mPreparing.mBuilder.GetPeerNodeId(), mExchangeManager->GetNextKeyId(), this);
addr, &mStateVars.mPreparing.mBuilder.GetOperationalCredentialSet(), mStateVars.mPreparing.mBuilder.GetPeerNodeId(),
mExchangeManager->GetNextKeyId(), ctxt, this);
if (err != CHIP_NO_ERROR)
{
ExitCasePairingState();
Expand Down
2 changes: 1 addition & 1 deletion src/channel/ChannelContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include <channel/Channel.h>
#include <lib/core/ReferenceCounted.h>
#include <lib/mdns/platform/Mdns.h>
#include <transport/CASESession.h>
#include <protocols/secure_channel/CASESession.h>
#include <transport/PeerConnectionState.h>
#include <transport/SecureSessionMgr.h>

Expand Down
2 changes: 2 additions & 0 deletions src/protocols/secure_channel/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ static_library("secure_channel") {
output_name = "libSecureChannel"

sources = [
"CASESession.cpp",
"CASESession.h",
"NetworkProvisioning.cpp",
"NetworkProvisioning.h",
"PASESession.cpp",
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@
#include <credentials/CHIPCert.h>
#include <credentials/CHIPOperationalCredentials.h>
#include <crypto/CHIPCryptoPAL.h>
#include <messaging/ExchangeContext.h>
#include <messaging/ExchangeDelegate.h>
#include <protocols/secure_channel/Constants.h>
#include <protocols/secure_channel/SessionEstablishmentExchangeDispatch.h>
#include <support/Base64.h>
#include <system/SystemPacketBuffer.h>
#include <transport/PairingSession.h>
#include <transport/PeerConnectionState.h>
#include <transport/SecureSession.h>
#include <transport/SessionEstablishmentDelegate.h>
Expand Down Expand Up @@ -59,13 +63,12 @@ struct CASESessionSerializable
uint16_t mMessageDigestLen;
uint8_t mMessageDigest[kSHA256_Hash_Length];
uint8_t mPairingComplete;
NodeId mLocalNodeId;
NodeId mPeerNodeId;
uint16_t mLocalKeyId;
uint16_t mPeerKeyId;
};

class DLL_EXPORT CASESession
class DLL_EXPORT CASESession : public Messaging::ExchangeDelegateBase, public PairingSession
{
public:
CASESession();
Expand All @@ -82,14 +85,13 @@ class DLL_EXPORT CASESession
*
* @param operationalCredentialSet CHIP Certificate Set used to store the chain root of trust an validate peer node
* certificates
* @param myNodeId Node id of local node
* @param myKeyId Key ID to be assigned to the secure session on the peer node
* @param delegate Callback object
*
* @return CHIP_ERROR The result of initialization
*/
CHIP_ERROR WaitForSessionEstablishment(OperationalCredentialSet * operationalCredentialSet, Optional<NodeId> myNodeId,
uint16_t myKeyId, SessionEstablishmentDelegate * delegate);
CHIP_ERROR WaitForSessionEstablishment(OperationalCredentialSet * operationalCredentialSet, uint16_t myKeyId,
SessionEstablishmentDelegate * delegate);

/**
* @brief
Expand All @@ -98,15 +100,15 @@ class DLL_EXPORT CASESession
* @param peerAddress Address of peer with which to establish a session.
* @param operationalCredentialSet CHIP Certificate Set used to store the chain root of trust an validate peer node
* certificates
* @param myNodeId Node id of local node
* @param peerNodeId Node id of the peer node
* @param myKeyId Key ID to be assigned to the secure session on the peer node
* @param exchangeCtxt The exchange context to send and receive messages with the peer
* @param delegate Callback object
*
* @return CHIP_ERROR The result of initialization
*/
CHIP_ERROR EstablishSession(const Transport::PeerAddress peerAddress, OperationalCredentialSet * operationalCredentialSet,
Optional<NodeId> myNodeId, NodeId peerNodeId, uint16_t myKeyId,
NodeId peerNodeId, uint16_t myKeyId, Messaging::ExchangeContext * exchangeCtxt,
SessionEstablishmentDelegate * delegate);

/**
Expand All @@ -120,19 +122,7 @@ class DLL_EXPORT CASESession
* initialized once session establishment is complete
* @return CHIP_ERROR The result of session derivation
*/
virtual CHIP_ERROR DeriveSecureSession(const uint8_t * info, size_t info_len, SecureSession & session);

/**
* @brief
* Handler for peer's messages, exchanged during pairing handshake.
*
* @param packetHeader Message header for the received message
* @param peerAddress Source of the message
* @param msg Message sent by the peer
* @return CHIP_ERROR The result of message processing
*/
virtual CHIP_ERROR HandlePeerMessage(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress,
System::PacketBufferHandle msg);
virtual CHIP_ERROR DeriveSecureSession(const uint8_t * info, size_t info_len, SecureSession & session) override;

/**
* @brief
Expand All @@ -148,15 +138,19 @@ class DLL_EXPORT CASESession
*
* @return uint16_t The associated peer key id
*/
uint16_t GetPeerKeyId() { return mConnectionState.GetPeerKeyID(); }
uint16_t GetPeerKeyId() override { return mConnectionState.GetPeerKeyID(); }

/**
* @brief
* Return the associated local key id
*
* @return uint16_t The assocated local key id
*/
uint16_t GetLocalKeyId() { return mConnectionState.GetLocalKeyID(); }
uint16_t GetLocalKeyId() override { return mConnectionState.GetLocalKeyID(); }

const char * GetI2RSessionInfo() const override { return "Sigma I2R Key"; }

const char * GetR2ISessionInfo() const override { return "Sigma R2I Key"; }

Transport::PeerConnectionState & PeerConnection() { return mConnectionState; }

Expand All @@ -180,6 +174,18 @@ class DLL_EXPORT CASESession
**/
CHIP_ERROR FromSerializable(const CASESessionSerializable & output);

SessionEstablishmentExchangeDispatch & MessageDispatch() { return mMessageDispatch; }

//// ExchangeDelegate Implementation ////
void OnMessageReceived(Messaging::ExchangeContext * ec, const PacketHeader & packetHeader, const PayloadHeader & payloadHeader,
System::PacketBufferHandle payload) override;
void OnResponseTimeout(Messaging::ExchangeContext * ec) override;
Messaging::ExchangeMessageDispatch * GetMessageDispatch(Messaging::ReliableMessageMgr * rmMgr,
SecureSessionMgr * sessionMgr) override
{
return &mMessageDispatch;
}

private:
enum SigmaErrorType : uint8_t
{
Expand All @@ -190,17 +196,16 @@ class DLL_EXPORT CASESession
kUnexpected = 0xff,
};

CHIP_ERROR Init(OperationalCredentialSet * operationalCredentialSet, Optional<NodeId> myNodeId, uint16_t myKeyId,
SessionEstablishmentDelegate * delegate);
CHIP_ERROR Init(OperationalCredentialSet * operationalCredentialSet, uint16_t myKeyId, SessionEstablishmentDelegate * delegate);

CHIP_ERROR SendSigmaR1();
CHIP_ERROR HandleSigmaR1_and_SendSigmaR2(const PacketHeader & header, const System::PacketBufferHandle & msg);
CHIP_ERROR HandleSigmaR1(const PacketHeader & header, const System::PacketBufferHandle & msg);
CHIP_ERROR HandleSigmaR1_and_SendSigmaR2(const System::PacketBufferHandle & msg);
CHIP_ERROR HandleSigmaR1(const System::PacketBufferHandle & msg);
CHIP_ERROR SendSigmaR2();
CHIP_ERROR HandleSigmaR2_and_SendSigmaR3(const PacketHeader & header, const System::PacketBufferHandle & msg);
CHIP_ERROR HandleSigmaR2(const PacketHeader & header, const System::PacketBufferHandle & msg);
CHIP_ERROR HandleSigmaR2_and_SendSigmaR3(const System::PacketBufferHandle & msg);
CHIP_ERROR HandleSigmaR2(const System::PacketBufferHandle & msg);
CHIP_ERROR SendSigmaR3();
CHIP_ERROR HandleSigmaR3(const PacketHeader & header, const System::PacketBufferHandle & msg);
CHIP_ERROR HandleSigmaR3(const System::PacketBufferHandle & msg);

CHIP_ERROR SendSigmaR1Resume();
CHIP_ERROR HandleSigmaR1Resume_and_SendSigmaR2Resume(const PacketHeader & header, const System::PacketBufferHandle & msg);
Expand All @@ -217,7 +222,7 @@ class DLL_EXPORT CASESession
CHIP_ERROR ComputeIPK(const uint16_t sessionID, uint8_t * ipk, size_t ipkLen);

void SendErrorMsg(SigmaErrorType errorCode);
void HandleErrorMsg(const PacketHeader & header, const System::PacketBufferHandle & msg);
void HandleErrorMsg(const System::PacketBufferHandle & msg);

// TODO: Remove this and replace with system method to retrieve current time
CHIP_ERROR SetEffectiveTime(void);
Expand All @@ -226,6 +231,9 @@ class DLL_EXPORT CASESession

void Clear();

CHIP_ERROR ValidateReceivedMessage(Messaging::ExchangeContext * ec, const PacketHeader & packetHeader,
const PayloadHeader & payloadHeader, System::PacketBufferHandle & msg);

SessionEstablishmentDelegate * mDelegate = nullptr;

Protocols::SecureChannel::MsgType mNextExpectedMsg = Protocols::SecureChannel::MsgType::CASE_SigmaErr;
Expand All @@ -242,14 +250,15 @@ class DLL_EXPORT CASESession
uint8_t mIPK[kIPKSize];
uint8_t mRemoteIPK[kIPKSize];

Messaging::ExchangeContext * mExchangeCtxt = nullptr;
SessionEstablishmentExchangeDispatch mMessageDispatch;

struct SigmaErrorMsg
{
SigmaErrorType error;
};

protected:
NodeId mLocalNodeId = kUndefinedNodeId;

bool mPairingComplete = false;

Transport::PeerConnectionState mConnectionState;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ CHIP_ERROR SessionEstablishmentExchangeDispatch::SendMessageImpl(SecureSessionHa
System::PacketBufferHandle && message,
EncryptedPacketBufferHandle * retainedMessage)
{
ChipLogProgress(ExchangeManager, "SessionEstablishmentExchangeDispatch::SendMessageImpl mTransportMgr %p", mTransportMgr);
ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message));
if (mTransportMgr != nullptr)
{
Expand Down
2 changes: 2 additions & 0 deletions src/protocols/secure_channel/tests/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ chip_test_suite("tests") {
output_name = "libSecureChannelTests"

test_sources = [
"TestCASESession.cpp",
"TestPASESession.cpp",
"TestStatusReport.cpp",
]

public_deps = [
"${chip_root}/src/credentials/tests:cert_test_vectors",
"${chip_root}/src/lib/core",
"${chip_root}/src/lib/support",
"${chip_root}/src/messaging/tests:helpers",
Expand Down
Loading

0 comments on commit 21e81b2

Please sign in to comment.