Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement the missing part of Exchange Header in Transport layer #4017

Merged
merged 2 commits into from
Dec 1, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/messaging/tests/TestExchangeMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ class LoopbackTransport : public Transport::Base
/// Transports are required to have a constructor that takes exactly one argument
CHIP_ERROR Init(const char * unused) { return CHIP_NO_ERROR; }

CHIP_ERROR SendMessage(const PacketHeader & header, Header::Flags payloadFlags, const PeerAddress & address,
System::PacketBuffer * msgBuf) override
CHIP_ERROR SendMessage(const PacketHeader & header, const PeerAddress & address, System::PacketBuffer * msgBuf) override
{
System::PacketBufferHandle msg_ForNow;
msg_ForNow.Adopt(msgBuf);
Expand Down
5 changes: 2 additions & 3 deletions src/transport/BLE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ CHIP_ERROR BLE::DelegateConnection(const uint16_t connDiscriminator)
return err;
}

CHIP_ERROR BLE::SendMessage(const PacketHeader & header, Header::Flags payloadFlags, const Transport::PeerAddress & address,
System::PacketBuffer * msgIn)
CHIP_ERROR BLE::SendMessage(const PacketHeader & header, const Transport::PeerAddress & address, System::PacketBuffer * msgIn)
{
CHIP_ERROR err = CHIP_NO_ERROR;
const uint16_t headerSize = header.EncodeSizeBytes();
Expand All @@ -150,7 +149,7 @@ CHIP_ERROR BLE::SendMessage(const PacketHeader & header, Header::Flags payloadFl

msgBuf->SetStart(msgBuf->Start() - headerSize);

err = header.Encode(msgBuf->Start(), msgBuf->DataLength(), &actualEncodedHeaderSize, payloadFlags);
err = header.Encode(msgBuf->Start(), msgBuf->DataLength(), &actualEncodedHeaderSize);
SuccessOrExit(err);

VerifyOrExit(headerSize == actualEncodedHeaderSize, err = CHIP_ERROR_INTERNAL);
Expand Down
2 changes: 1 addition & 1 deletion src/transport/BLE.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class DLL_EXPORT BLE : public Base
*/
CHIP_ERROR Init(RendezvousSessionDelegate * delegate, const RendezvousParameters & params);

CHIP_ERROR SendMessage(const PacketHeader & header, Header::Flags payloadFlags, const Transport::PeerAddress & address,
CHIP_ERROR SendMessage(const PacketHeader & header, const Transport::PeerAddress & address,
System::PacketBuffer * msgBuf) override;

bool CanSendToPeer(const Transport::PeerAddress & address) override
Expand Down
17 changes: 8 additions & 9 deletions src/transport/RendezvousSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ RendezvousSession::~RendezvousSession()
mDelegate = nullptr;
}

CHIP_ERROR RendezvousSession::SendPairingMessage(const PacketHeader & header, Header::Flags payloadFlags,
const Transport::PeerAddress & peerAddress, System::PacketBuffer * msgIn)
CHIP_ERROR RendezvousSession::SendPairingMessage(const PacketHeader & header, const Transport::PeerAddress & peerAddress,
System::PacketBuffer * msgIn)
{
if (mCurrentState != State::kSecurePairing)
{
Expand All @@ -102,11 +102,11 @@ CHIP_ERROR RendezvousSession::SendPairingMessage(const PacketHeader & header, He

if (peerAddress.GetTransportType() == Transport::Type::kBle)
{
return mTransport->SendMessage(header, payloadFlags, peerAddress, msgIn);
return mTransport->SendMessage(header, peerAddress, msgIn);
}
else if (mTransportMgr != nullptr)
{
return mTransportMgr->SendMessage(header, payloadFlags, peerAddress, msgIn);
return mTransportMgr->SendMessage(header, peerAddress, msgIn);
}
else
{
Expand Down Expand Up @@ -150,7 +150,7 @@ CHIP_ERROR RendezvousSession::SendSecureMessage(Protocols::CHIPProtocolId protoc
uint16_t totalLen = msgBuf->TotalLength();

ReturnErrorOnFailure(payloadHeader.Encode(data, totalLen, &actualEncodedHeaderSize));
ReturnErrorOnFailure(mSecureSession.Encrypt(data, totalLen, data, packetHeader, payloadHeader.GetEncodePacketFlags(), mac));
ReturnErrorOnFailure(mSecureSession.Encrypt(data, totalLen, data, packetHeader, mac));

uint16_t taglen = 0;
ReturnErrorOnFailure(mac.Encode(packetHeader, &data[totalLen], kMaxTagLen, &taglen));
Expand All @@ -159,8 +159,7 @@ CHIP_ERROR RendezvousSession::SendSecureMessage(Protocols::CHIPProtocolId protoc

msgBuf->SetDataLength(static_cast<uint16_t>(totalLen + taglen));

ReturnErrorOnFailure(mTransport->SendMessage(packetHeader, payloadHeader.GetEncodePacketFlags(), Transport::PeerAddress::BLE(),
msgBuf.Release_ForNow()));
ReturnErrorOnFailure(mTransport->SendMessage(packetHeader, Transport::PeerAddress::BLE(), msgBuf.Release_ForNow()));

mSecureMessageIndex++;

Expand Down Expand Up @@ -378,7 +377,7 @@ CHIP_ERROR RendezvousSession::HandleSecureMessage(const PacketHeader & packetHea
len = static_cast<uint16_t>(len - taglen);
msgBuf->SetDataLength(len);

ReturnErrorOnFailure(mSecureSession.Decrypt(data, len, plainText, packetHeader, payloadHeader.GetEncodePacketFlags(), mac));
ReturnErrorOnFailure(mSecureSession.Decrypt(data, len, plainText, packetHeader, mac));

// Use the node IDs from the packet header only after it's successfully decrypted
if (packetHeader.GetDestinationNodeId().HasValue() && !mParams.HasLocalNodeId())
Expand All @@ -394,7 +393,7 @@ CHIP_ERROR RendezvousSession::HandleSecureMessage(const PacketHeader & packetHea
}

uint16_t decodedSize = 0;
ReturnErrorOnFailure(payloadHeader.Decode(packetHeader.GetFlags(), plainText, len, &decodedSize));
ReturnErrorOnFailure(payloadHeader.Decode(plainText, len, &decodedSize));

ReturnErrorCodeIf(headerSize != decodedSize, CHIP_ERROR_INCORRECT_STATE);

Expand Down
4 changes: 2 additions & 2 deletions src/transport/RendezvousSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ class RendezvousSession : public SecurePairingSessionDelegate,
Optional<NodeId> GetRemoteNodeId() const { return mParams.GetRemoteNodeId(); }

//////////// SecurePairingSessionDelegate Implementation ///////////////
CHIP_ERROR SendPairingMessage(const PacketHeader & header, Header::Flags payloadFlags,
const Transport::PeerAddress & peerAddress, System::PacketBuffer * msgBuf) override;
CHIP_ERROR SendPairingMessage(const PacketHeader & header, const Transport::PeerAddress & peerAddress,
System::PacketBuffer * msgBuf) override;
void OnPairingError(CHIP_ERROR err) override;
void OnPairingComplete() override;

Expand Down
6 changes: 3 additions & 3 deletions src/transport/SecurePairingSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ CHIP_ERROR SecurePairingSession::AttachHeaderAndSend(uint8_t msgType, System::Pa
SuccessOrExit(err);
VerifyOrExit(headerSize == actualEncodedHeaderSize, err = CHIP_ERROR_INTERNAL);

err = mDelegate->SendPairingMessage(PacketHeader().SetSourceNodeId(mLocalNodeId).SetEncryptionKeyID(mLocalKeyId),
payloadHeader.GetEncodePacketFlags(), mPeerAddress, msgBuf.Release_ForNow());
err = mDelegate->SendPairingMessage(PacketHeader().SetSourceNodeId(mLocalNodeId).SetEncryptionKeyID(mLocalKeyId), mPeerAddress,
msgBuf.Release_ForNow());
SuccessOrExit(err);

exit:
Expand Down Expand Up @@ -437,7 +437,7 @@ CHIP_ERROR SecurePairingSession::HandlePeerMessage(const PacketHeader & packetHe

VerifyOrExit(!msg.IsNull(), err = CHIP_ERROR_INVALID_ARGUMENT);

err = payloadHeader.Decode(packetHeader.GetFlags(), msg->Start(), msg->DataLength(), &headerSize);
err = payloadHeader.Decode(msg->Start(), msg->DataLength(), &headerSize);
SuccessOrExit(err);

msg->ConsumeHead(headerSize);
Expand Down
5 changes: 2 additions & 3 deletions src/transport/SecurePairingSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,12 @@ class DLL_EXPORT SecurePairingSessionDelegate
* Called when pairing session generates a new message that should be sent to peer.
*
* @param header the message header for the sent message
* @param payloadFlags payload encoding flags
* @param peerAddress the destination of the message
* @param msgBuf the raw data for the message being sent
* @return CHIP_ERROR Error thrown when sending the message
*/
virtual CHIP_ERROR SendPairingMessage(const PacketHeader & header, Header::Flags payloadFlags,
const Transport::PeerAddress & peerAddress, System::PacketBuffer * msgBuf)
virtual CHIP_ERROR SendPairingMessage(const PacketHeader & header, const Transport::PeerAddress & peerAddress,
System::PacketBuffer * msgBuf)
{
return CHIP_ERROR_NOT_IMPLEMENTED;
}
Expand Down
13 changes: 6 additions & 7 deletions src/transport/SecureSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,15 @@ CHIP_ERROR SecureSession::GetIV(const PacketHeader & header, uint8_t * iv, size_
return bbuf.Fit() ? CHIP_NO_ERROR : CHIP_ERROR_NO_MEMORY;
}

CHIP_ERROR SecureSession::GetAdditionalAuthData(const PacketHeader & header, const Header::Flags payloadEncodeFlags, uint8_t * aad,
uint16_t & len)
CHIP_ERROR SecureSession::GetAdditionalAuthData(const PacketHeader & header, uint8_t * aad, uint16_t & len)
{
VerifyOrReturnError(len >= header.EncodeSizeBytes(), CHIP_ERROR_INVALID_ARGUMENT);

// Use unencrypted part of header as AAD. This will help
// integrity protect the whole message
uint16_t actualEncodedHeaderSize;

ReturnErrorOnFailure(header.Encode(aad, len, &actualEncodedHeaderSize, payloadEncodeFlags));
ReturnErrorOnFailure(header.Encode(aad, len, &actualEncodedHeaderSize));
VerifyOrReturnError(len >= actualEncodedHeaderSize, CHIP_ERROR_INVALID_ARGUMENT);

len = actualEncodedHeaderSize;
Expand All @@ -114,7 +113,7 @@ CHIP_ERROR SecureSession::GetAdditionalAuthData(const PacketHeader & header, con
}

CHIP_ERROR SecureSession::Encrypt(const uint8_t * input, size_t input_length, uint8_t * output, PacketHeader & header,
Header::Flags payloadFlags, MessageAuthenticationCode & mac)
MessageAuthenticationCode & mac)
{

constexpr Header::EncryptionType encType = Header::EncryptionType::kAESCCMTagLen16;
Expand All @@ -133,7 +132,7 @@ CHIP_ERROR SecureSession::Encrypt(const uint8_t * input, size_t input_length, ui
uint8_t tag[kMaxTagLen];

ReturnErrorOnFailure(GetIV(header, IV, sizeof(IV)));
ReturnErrorOnFailure(GetAdditionalAuthData(header, payloadFlags, AAD, aadLen));
ReturnErrorOnFailure(GetAdditionalAuthData(header, AAD, aadLen));
ReturnErrorOnFailure(
AES_CCM_encrypt(input, input_length, AAD, aadLen, mKey, sizeof(mKey), IV, sizeof(IV), output, tag, taglen));

Expand All @@ -143,7 +142,7 @@ CHIP_ERROR SecureSession::Encrypt(const uint8_t * input, size_t input_length, ui
}

CHIP_ERROR SecureSession::Decrypt(const uint8_t * input, size_t input_length, uint8_t * output, const PacketHeader & header,
Header::Flags payloadFlags, const MessageAuthenticationCode & mac)
const MessageAuthenticationCode & mac)
{
const size_t taglen = MessageAuthenticationCode::TagLenForEncryptionType(header.GetEncryptionType());
const uint8_t * tag = mac.GetTag();
Expand All @@ -157,7 +156,7 @@ CHIP_ERROR SecureSession::Decrypt(const uint8_t * input, size_t input_length, ui
VerifyOrReturnError(output != nullptr, CHIP_ERROR_INVALID_ARGUMENT);

ReturnErrorOnFailure(GetIV(header, IV, sizeof(IV)));
ReturnErrorOnFailure(GetAdditionalAuthData(header, payloadFlags, AAD, aadLen));
ReturnErrorOnFailure(GetAdditionalAuthData(header, AAD, aadLen));

return AES_CCM_decrypt(input, input_length, AAD, aadLen, tag, taglen, mKey, sizeof(mKey), IV, sizeof(IV), output);
}
Expand Down
9 changes: 3 additions & 6 deletions src/transport/SecureSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,12 @@ class DLL_EXPORT SecureSession
* @param input_length Length of the input data
* @param output Output buffer for encrypted data
* @param header message header structure. Encryption type will be set on the header.
* @param payloadFlags extra flags for packet header encryption
* @param mac - output the resulting mac
*
* @return CHIP_ERROR The result of encryption
*/
CHIP_ERROR Encrypt(const uint8_t * input, size_t input_length, uint8_t * output, PacketHeader & header,
Header::Flags payloadFlags, MessageAuthenticationCode & mac);
MessageAuthenticationCode & mac);

/**
* @brief
Expand All @@ -96,12 +95,11 @@ class DLL_EXPORT SecureSession
* @param input_length Length of the input data
* @param output Output buffer for decrypted data
* @param header message header structure
* @param payloadFlags extra flags for packet header decryption
* @return CHIP_ERROR The result of decryption
* @param mac Input mac
*/
CHIP_ERROR Decrypt(const uint8_t * input, size_t input_length, uint8_t * output, const PacketHeader & header,
Header::Flags payloadFlags, const MessageAuthenticationCode & mac);
const MessageAuthenticationCode & mac);

/**
* @brief
Expand All @@ -128,8 +126,7 @@ class DLL_EXPORT SecureSession
// Use unencrypted header as additional authenticated data (AAD) during encryption and decryption.
// The encryption operations includes AAD when message authentication tag is generated. This tag
// is used at the time of decryption to integrity check the received data.
static CHIP_ERROR GetAdditionalAuthData(const PacketHeader & header, Header::Flags payloadEncodeFlags, uint8_t * aad,
uint16_t & len);
static CHIP_ERROR GetAdditionalAuthData(const PacketHeader & header, uint8_t * aad, uint16_t & len);
};

} // namespace chip
24 changes: 12 additions & 12 deletions src/transport/SecureSessionMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, NodeId p
err = payloadHeader.Encode(data, totalLen, &actualEncodedHeaderSize);
SuccessOrExit(err);

err = state->GetSecureSession().Encrypt(data, totalLen, data, packetHeader, payloadHeader.GetEncodePacketFlags(), mac);
err = state->GetSecureSession().Encrypt(data, totalLen, data, packetHeader, mac);
SuccessOrExit(err);

err = mac.Encode(packetHeader, &data[totalLen], kMaxTagLen, &taglen);
Expand All @@ -157,8 +157,7 @@ CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, NodeId p

ChipLogDetail(Inet, "Secure transport transmitting msg %u after encryption", state->GetSendMessageIndex());

err = mTransportMgr->SendMessage(packetHeader, payloadHeader.GetEncodePacketFlags(), state->GetPeerAddress(),
msgBuf.Release_ForNow());
err = mTransportMgr->SendMessage(packetHeader, state->GetPeerAddress(), msgBuf.Release_ForNow());
}
SuccessOrExit(err);
state->IncrementSendMessageIndex();
Expand Down Expand Up @@ -306,13 +305,13 @@ void SecureSessionMgr::OnMessageReceived(const PacketHeader & packetHeader, cons
PayloadHeader payloadHeader;
MessageAuthenticationCode mac;

uint8_t * data = msg->Start();
uint8_t * plainText = nullptr;
uint16_t len = msg->TotalLength();
const uint16_t headerSize = payloadHeader.EncodeSizeBytes();
uint16_t decodedSize = 0;
uint16_t taglen = 0;
uint16_t payloadlen = 0;
uint8_t * data = msg->Start();
uint8_t * plainText = nullptr;
uint16_t len = msg->TotalLength();
uint16_t headerSize = 0;
uint16_t decodedSize = 0;
uint16_t taglen = 0;
uint16_t payloadlen = 0;

#if CHIP_SYSTEM_CONFIG_USE_LWIP
/* This is a workaround for the case where PacketBuffer payload is not
Expand All @@ -333,10 +332,11 @@ void SecureSessionMgr::OnMessageReceived(const PacketHeader & packetHeader, cons
len = static_cast<uint16_t>(len - taglen);
msg->SetDataLength(len, nullptr);

err = state->GetSecureSession().Decrypt(data, len, plainText, packetHeader, payloadHeader.GetEncodePacketFlags(), mac);
err = state->GetSecureSession().Decrypt(data, len, plainText, packetHeader, mac);
VerifyOrExit(err == CHIP_NO_ERROR, ChipLogError(Inet, "Secure transport failed to decrypt msg: err %d", err));

err = payloadHeader.Decode(packetHeader.GetFlags(), plainText, len, &decodedSize);
err = payloadHeader.Decode(plainText, len, &decodedSize);
headerSize = payloadHeader.EncodeSizeBytes();
VerifyOrExit(err == CHIP_NO_ERROR, ChipLogError(Inet, "Secure transport failed to decode encrypted header: err %d", err));
VerifyOrExit(headerSize == decodedSize, ChipLogError(Inet, "Secure transport decode encrypted header length mismatched"));

Expand Down
5 changes: 2 additions & 3 deletions src/transport/TransportMgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,9 @@ class TransportMgrBase
public:
CHIP_ERROR Init(Transport::Base * transport);

CHIP_ERROR SendMessage(const PacketHeader & header, Header::Flags payloadFlags, const Transport::PeerAddress & address,
System::PacketBuffer * msgBuf)
CHIP_ERROR SendMessage(const PacketHeader & header, const Transport::PeerAddress & address, System::PacketBuffer * msgBuf)
{
return mTransport->SendMessage(header, payloadFlags, address, msgBuf);
return mTransport->SendMessage(header, address, msgBuf);
}

void Disconnect(const Transport::PeerAddress & address) { mTransport->Disconnect(address); }
Expand Down
3 changes: 1 addition & 2 deletions src/transport/raw/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ class Base
* This method calls <tt>chip::System::PacketBuffer::Free</tt> on
* behalf of the caller regardless of the return status.
*/
virtual CHIP_ERROR SendMessage(const PacketHeader & header, Header::Flags payloadFlags, const PeerAddress & address,
System::PacketBuffer * msgBuf) = 0;
virtual CHIP_ERROR SendMessage(const PacketHeader & header, const PeerAddress & address, System::PacketBuffer * msgBuf) = 0;

/**
* Determine if this transport can SendMessage to the specified peer address.
Expand Down
Loading