Skip to content

Commit

Permalink
Add support for CASE session caching for session resume use cases (#1…
Browse files Browse the repository at this point in the history
…1937)

- Add tests for the CASE session cache

- Update the CASESessionCachable struct to have only necessary members
  and rename the struct and APIs appropriately

- Remove the tests that serilaize and deserilaize the CASE Session as its outdated
  • Loading branch information
nivi-apple authored and pull[bot] committed Feb 20, 2024
1 parent dccd066 commit 1070088
Show file tree
Hide file tree
Showing 9 changed files with 450 additions and 163 deletions.
10 changes: 10 additions & 0 deletions src/lib/core/CHIPConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -2722,6 +2722,16 @@ extern const char CHIP_NON_PRODUCTION_MARKER[];
#define CHIP_CONFIG_MAX_SESSION_RECOVERY_DELEGATES 3
#endif

/**
* @def CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE
*
* @brief
* Maximum number of CASE sessions that a device caches, that can be resumed
*/
#ifndef CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE
#define CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE 4
#endif

/**
* @}
*/
2 changes: 2 additions & 0 deletions src/protocols/secure_channel/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ static_library("secure_channel") {
"CASEServer.h",
"CASESession.cpp",
"CASESession.h",
"CASESessionCache.cpp",
"CASESessionCache.h",
"PASESession.cpp",
"PASESession.h",
"RendezvousParameters.h",
Expand Down
92 changes: 21 additions & 71 deletions src/protocols/secure_channel/CASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ constexpr uint8_t kTBEData3_Nonce[] =
constexpr size_t kTBEDataNonceLength = sizeof(kTBEData2_Nonce);
static_assert(sizeof(kTBEData2_Nonce) == sizeof(kTBEData3_Nonce), "TBEData2_Nonce and TBEData3_Nonce must be same size");

constexpr uint8_t kCASESessionVersion = 1;

enum
{
kTag_TBEData_SenderNOC = 1,
Expand Down Expand Up @@ -124,96 +122,48 @@ void CASESession::CloseExchange()
}
}

CHIP_ERROR CASESession::Serialize(CASESessionSerialized & output)
{
uint16_t serializedLen = 0;
CASESessionSerializable serializable;

VerifyOrReturnError(BASE64_ENCODED_LEN(sizeof(serializable)) <= sizeof(output.inner), CHIP_ERROR_INVALID_ARGUMENT);

ReturnErrorOnFailure(ToSerializable(serializable));

serializedLen = chip::Base64Encode(Uint8::to_const_uchar(reinterpret_cast<uint8_t *>(&serializable)),
static_cast<uint16_t>(sizeof(serializable)), Uint8::to_char(output.inner));
VerifyOrReturnError(serializedLen > 0, CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrReturnError(serializedLen < sizeof(output.inner), CHIP_ERROR_INVALID_ARGUMENT);
output.inner[serializedLen] = '\0';

return CHIP_NO_ERROR;
}

CHIP_ERROR CASESession::Deserialize(CASESessionSerialized & input)
{
CASESessionSerializable serializable;
size_t maxlen = BASE64_ENCODED_LEN(sizeof(serializable));
size_t len = strnlen(Uint8::to_char(input.inner), maxlen);
uint16_t deserializedLen = 0;

VerifyOrReturnError(len < sizeof(CASESessionSerialized), CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrReturnError(CanCastTo<uint16_t>(len), CHIP_ERROR_INVALID_ARGUMENT);

memset(&serializable, 0, sizeof(serializable));
deserializedLen =
Base64Decode(Uint8::to_const_char(input.inner), static_cast<uint16_t>(len), Uint8::to_uchar((uint8_t *) &serializable));

VerifyOrReturnError(deserializedLen > 0, CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrReturnError(deserializedLen <= sizeof(serializable), CHIP_ERROR_INVALID_ARGUMENT);

ReturnErrorOnFailure(FromSerializable(serializable));

return CHIP_NO_ERROR;
}

CHIP_ERROR CASESession::ToSerializable(CASESessionSerializable & serializable)
CHIP_ERROR CASESession::ToCachable(CASESessionCachable & cachableSession)
{
const NodeId peerNodeId = GetPeerNodeId();
VerifyOrReturnError(CanCastTo<uint16_t>(mSharedSecret.Length()), CHIP_ERROR_INTERNAL);
VerifyOrReturnError(CanCastTo<uint16_t>(sizeof(mMessageDigest)), CHIP_ERROR_INTERNAL);
VerifyOrReturnError(CanCastTo<uint64_t>(peerNodeId), CHIP_ERROR_INTERNAL);

memset(&serializable, 0, sizeof(serializable));
serializable.mSharedSecretLen = LittleEndian::HostSwap16(static_cast<uint16_t>(mSharedSecret.Length()));
serializable.mMessageDigestLen = LittleEndian::HostSwap16(static_cast<uint16_t>(sizeof(mMessageDigest)));
serializable.mVersion = kCASESessionVersion;
serializable.mPeerNodeId = LittleEndian::HostSwap64(peerNodeId);
for (size_t i = 0; i < serializable.mPeerCATs.size(); i++)
memset(&cachableSession, 0, sizeof(cachableSession));
cachableSession.mSharedSecretLen = LittleEndian::HostSwap16(static_cast<uint16_t>(mSharedSecret.Length()));
cachableSession.mPeerNodeId = LittleEndian::HostSwap64(peerNodeId);
for (size_t i = 0; i < cachableSession.mPeerCATs.size(); i++)
{
serializable.mPeerCATs.val[i] = LittleEndian::HostSwap32(GetPeerCATs().val[i]);
cachableSession.mPeerCATs.val[i] = LittleEndian::HostSwap32(GetPeerCATs().val[i]);
}
serializable.mLocalSessionId = LittleEndian::HostSwap16(GetLocalSessionId());
serializable.mPeerSessionId = LittleEndian::HostSwap16(GetPeerSessionId());
// TODO: Get the fabric index
cachableSession.mLocalFabricIndex = 0;
cachableSession.mSessionSetupTimeStamp = LittleEndian::HostSwap64(mSessionSetupTimeStamp);

memcpy(serializable.mResumptionId, mResumptionId, sizeof(mResumptionId));
memcpy(serializable.mSharedSecret, mSharedSecret, mSharedSecret.Length());
memcpy(serializable.mMessageDigest, mMessageDigest, sizeof(mMessageDigest));
memcpy(cachableSession.mResumptionId, mResumptionId, sizeof(mResumptionId));
memcpy(cachableSession.mSharedSecret, mSharedSecret, mSharedSecret.Length());

return CHIP_NO_ERROR;
}

CHIP_ERROR CASESession::FromSerializable(const CASESessionSerializable & serializable)
CHIP_ERROR CASESession::FromCachable(const CASESessionCachable & cachableSession)
{
VerifyOrReturnError(serializable.mVersion == kCASESessionVersion, CHIP_ERROR_VERSION_MISMATCH);

uint16_t length = LittleEndian::HostSwap16(serializable.mSharedSecretLen);
uint16_t length = LittleEndian::HostSwap16(cachableSession.mSharedSecretLen);
ReturnErrorOnFailure(mSharedSecret.SetLength(static_cast<size_t>(length)));
memset(mSharedSecret, 0, sizeof(mSharedSecret.Capacity()));
memcpy(mSharedSecret, serializable.mSharedSecret, length);

length = LittleEndian::HostSwap16(serializable.mMessageDigestLen);
VerifyOrReturnError(length <= sizeof(mMessageDigest), CHIP_ERROR_INVALID_ARGUMENT);
memcpy(mMessageDigest, serializable.mMessageDigest, length);
memcpy(mSharedSecret, cachableSession.mSharedSecret, length);

SetPeerNodeId(LittleEndian::HostSwap64(serializable.mPeerNodeId));
SetPeerNodeId(LittleEndian::HostSwap64(cachableSession.mPeerNodeId));
Credentials::CATValues peerCATs;
for (size_t i = 0; i < serializable.mPeerCATs.size(); i++)
for (size_t i = 0; i < cachableSession.mPeerCATs.size(); i++)
{
peerCATs.val[i] = LittleEndian::HostSwap32(serializable.mPeerCATs.val[i]);
peerCATs.val[i] = LittleEndian::HostSwap32(cachableSession.mPeerCATs.val[i]);
}
SetPeerCATs(peerCATs);
SetLocalSessionId(LittleEndian::HostSwap16(serializable.mLocalSessionId));
SetPeerSessionId(LittleEndian::HostSwap16(serializable.mPeerSessionId));
SetSessionTimeStamp(LittleEndian::HostSwap64(cachableSession.mSessionSetupTimeStamp));
// TODO: Set the fabric index correctly
mLocalFabricIndex = cachableSession.mLocalFabricIndex;

memcpy(mResumptionId, serializable.mResumptionId, sizeof(mResumptionId));
memcpy(mResumptionId, cachableSession.mResumptionId, sizeof(mResumptionId));

const ByteSpan * ipkListSpan = GetIPKList();
VerifyOrReturnError(ipkListSpan->size() == sizeof(mIPK), CHIP_ERROR_INVALID_ARGUMENT);
Expand Down
40 changes: 12 additions & 28 deletions src/protocols/secure_channel/CASESession.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,20 +58,15 @@ constexpr size_t kCASEResumptionIDSize = 16;
#define CASE_EPHEMERAL_KEY 0xCA5EECD0
#endif

struct CASESessionSerialized;

struct CASESessionSerializable
struct CASESessionCachable
{
uint8_t mVersion;
uint16_t mSharedSecretLen;
uint8_t mSharedSecret[Crypto::kMax_ECDH_Secret_Length];
uint16_t mMessageDigestLen;
uint8_t mMessageDigest[Crypto::kSHA256_Hash_Length];
FabricIndex mLocalFabricIndex;
NodeId mPeerNodeId;
Credentials::CATValues mPeerCATs;
uint16_t mLocalSessionId;
uint16_t mPeerSessionId;
uint8_t mResumptionId[kCASEResumptionIDSize];
uint64_t mSessionSetupTimeStamp;
};

class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public PairingSession
Expand Down Expand Up @@ -154,24 +149,14 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin
const char * GetR2ISessionInfo() const override { return "Sigma R2I Key"; }

/**
* @brief Serialize the Pairing Session to a string.
**/
CHIP_ERROR Serialize(CASESessionSerialized & output);

/**
* @brief Deserialize the Pairing Session from the string.
**/
CHIP_ERROR Deserialize(CASESessionSerialized & input);

/**
* @brief Serialize the CASESession to the given serializable data structure for secure pairing
* @brief Serialize the CASESession to the given cachableSession data structure for secure pairing
**/
CHIP_ERROR ToSerializable(CASESessionSerializable & output);
CHIP_ERROR ToCachable(CASESessionCachable & output);

/**
* @brief Reconstruct secure pairing class from the serializable data structure.
* @brief Reconstruct secure pairing class from the cachableSession data structure.
**/
CHIP_ERROR FromSerializable(const CASESessionSerializable & output);
CHIP_ERROR FromCachable(const CASESessionCachable & output);

SessionEstablishmentExchangeDispatch & MessageDispatch() { return mMessageDispatch; }

Expand Down Expand Up @@ -277,6 +262,9 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin

State mState;

uint8_t mLocalFabricIndex = 0;
uint64_t mSessionSetupTimeStamp = 0;

protected:
bool mCASESessionEstablished = false;

Expand All @@ -290,12 +278,8 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin
return ipkListSpan;
}
virtual size_t GetIPKListEntries() const { return 1; }
};

typedef struct CASESessionSerialized
{
// Extra uint64_t to account for padding bytes (NULL termination, and some decoding overheads)
uint8_t inner[BASE64_ENCODED_LEN(sizeof(CASESessionSerializable) + sizeof(uint64_t))];
} CASESessionSerialized;
void SetSessionTimeStamp(uint64_t timestamp) { mSessionSetupTimeStamp = timestamp; }
};

} // namespace chip
105 changes: 105 additions & 0 deletions src/protocols/secure_channel/CASESessionCache.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
*
* Copyright (c) 2021 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <protocols/secure_channel/CASESessionCache.h>

namespace chip {

CASESessionCache::CASESessionCache() {}

CASESessionCache::~CASESessionCache()
{
mCachePool.ForEachActiveObject([&](auto * ec) {
mCachePool.ReleaseObject(ec);
return true;
});
}

CASESessionCachable * CASESessionCache::GetLRUSession()
{
uint64_t minTimeStamp = UINT64_MAX;
CASESessionCachable * lruSession = nullptr;
mCachePool.ForEachActiveObject([&](auto * ec) {
if (minTimeStamp > ec->mSessionSetupTimeStamp)
{
minTimeStamp = ec->mSessionSetupTimeStamp;
lruSession = ec;
}
return true;
});
return lruSession;
}

CHIP_ERROR CASESessionCache::Add(CASESessionCachable & cachableSession)
{
// It's not an error if a device doesn't have cache for storing the sessions.
VerifyOrReturnError(mCachePool.Capacity() > 0, CHIP_NO_ERROR);

// If the cache is full, get the least recently used session index and release that.
if (mCachePool.Exhausted())
{
mCachePool.ReleaseObject(GetLRUSession());
}

mCachePool.CreateObject(cachableSession);
return CHIP_NO_ERROR;
}

CHIP_ERROR CASESessionCache::Remove(ResumptionID resumptionID)
{
CHIP_ERROR err = CHIP_NO_ERROR;
CASESession session;
mCachePool.ForEachActiveObject([&](auto * ec) {
if (resumptionID.data_equal(ResumptionID(ec->mResumptionId)))
{
mCachePool.ReleaseObject(ec);
}
return true;
});

return err;
}

CHIP_ERROR CASESessionCache::Get(ResumptionID resumptionID, CASESessionCachable & outSessionCachable)
{
CHIP_ERROR err = CHIP_NO_ERROR;
bool found = false;
mCachePool.ForEachActiveObject([&](auto * ec) {
if (resumptionID.data_equal(ResumptionID(ec->mResumptionId)))
{
found = true;
outSessionCachable = *ec;
return false;
}
return true;
});

if (!found)
{
err = CHIP_ERROR_PERSISTED_STORAGE_VALUE_NOT_FOUND;
}

return err;
}

CHIP_ERROR CASESessionCache::Get(const PeerId & peer, CASESessionCachable & outSessionCachable)
{
// TODO: Implement this based on peer id
return CHIP_NO_ERROR;
}

} // namespace chip
44 changes: 44 additions & 0 deletions src/protocols/secure_channel/CASESessionCache.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
*
* Copyright (c) 2021 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <lib/core/CHIPError.h>
#include <lib/core/PeerId.h>
#include <protocols/secure_channel/CASESession.h>

namespace chip {

using ResumptionID = FixedByteSpan<kCASEResumptionIDSize>;

class CASESessionCache
{
public:
CASESessionCache();
virtual ~CASESessionCache();

CHIP_ERROR Add(CASESessionCachable & cachableSession);
CHIP_ERROR Remove(ResumptionID resumptionID);
CHIP_ERROR Get(ResumptionID resumptionID, CASESessionCachable & outCachableSession);
CHIP_ERROR Get(const PeerId & peer, CASESessionCachable & outCachableSession);

private:
BitMapObjectPool<CASESessionCachable, CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE> mCachePool;
CASESessionCachable * GetLRUSession();
};

} // namespace chip
1 change: 1 addition & 0 deletions src/protocols/secure_channel/tests/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ chip_test_suite("tests") {

test_sources = [
"TestCASESession.cpp",
"TestCASESessionCache.cpp",

# TODO - Fix Message Counter Sync to use group key
# "TestMessageCounterManager.cpp",
Expand Down
Loading

0 comments on commit 1070088

Please sign in to comment.