Skip to content

Commit

Permalink
Implement the missing part of Exchange Header in Transport layer (#4017)
Browse files Browse the repository at this point in the history
* Implement the missing part of Exchange Header in Transport layer

* Revert comment 'if' back to 'iff'("if and only if")
  • Loading branch information
yufengwangca authored Dec 1, 2020
1 parent f382625 commit dedd15f
Show file tree
Hide file tree
Showing 25 changed files with 187 additions and 146 deletions.
3 changes: 1 addition & 2 deletions src/messaging/tests/TestExchangeMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ class LoopbackTransport : public Transport::Base
/// Transports are required to have a constructor that takes exactly one argument
CHIP_ERROR Init(const char * unused) { return CHIP_NO_ERROR; }

CHIP_ERROR SendMessage(const PacketHeader & header, Header::Flags payloadFlags, const PeerAddress & address,
System::PacketBuffer * msgBuf) override
CHIP_ERROR SendMessage(const PacketHeader & header, const PeerAddress & address, System::PacketBuffer * msgBuf) override
{
System::PacketBufferHandle msg_ForNow;
msg_ForNow.Adopt(msgBuf);
Expand Down
5 changes: 2 additions & 3 deletions src/transport/BLE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ CHIP_ERROR BLE::DelegateConnection(const uint16_t connDiscriminator)
return err;
}

CHIP_ERROR BLE::SendMessage(const PacketHeader & header, Header::Flags payloadFlags, const Transport::PeerAddress & address,
System::PacketBuffer * msgIn)
CHIP_ERROR BLE::SendMessage(const PacketHeader & header, const Transport::PeerAddress & address, System::PacketBuffer * msgIn)
{
CHIP_ERROR err = CHIP_NO_ERROR;
const uint16_t headerSize = header.EncodeSizeBytes();
Expand All @@ -150,7 +149,7 @@ CHIP_ERROR BLE::SendMessage(const PacketHeader & header, Header::Flags payloadFl

msgBuf->SetStart(msgBuf->Start() - headerSize);

err = header.Encode(msgBuf->Start(), msgBuf->DataLength(), &actualEncodedHeaderSize, payloadFlags);
err = header.Encode(msgBuf->Start(), msgBuf->DataLength(), &actualEncodedHeaderSize);
SuccessOrExit(err);

VerifyOrExit(headerSize == actualEncodedHeaderSize, err = CHIP_ERROR_INTERNAL);
Expand Down
2 changes: 1 addition & 1 deletion src/transport/BLE.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class DLL_EXPORT BLE : public Base
*/
CHIP_ERROR Init(RendezvousSessionDelegate * delegate, const RendezvousParameters & params);

CHIP_ERROR SendMessage(const PacketHeader & header, Header::Flags payloadFlags, const Transport::PeerAddress & address,
CHIP_ERROR SendMessage(const PacketHeader & header, const Transport::PeerAddress & address,
System::PacketBuffer * msgBuf) override;

bool CanSendToPeer(const Transport::PeerAddress & address) override
Expand Down
17 changes: 8 additions & 9 deletions src/transport/RendezvousSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ RendezvousSession::~RendezvousSession()
mDelegate = nullptr;
}

CHIP_ERROR RendezvousSession::SendPairingMessage(const PacketHeader & header, Header::Flags payloadFlags,
const Transport::PeerAddress & peerAddress, System::PacketBuffer * msgIn)
CHIP_ERROR RendezvousSession::SendPairingMessage(const PacketHeader & header, const Transport::PeerAddress & peerAddress,
System::PacketBuffer * msgIn)
{
if (mCurrentState != State::kSecurePairing)
{
Expand All @@ -102,11 +102,11 @@ CHIP_ERROR RendezvousSession::SendPairingMessage(const PacketHeader & header, He

if (peerAddress.GetTransportType() == Transport::Type::kBle)
{
return mTransport->SendMessage(header, payloadFlags, peerAddress, msgIn);
return mTransport->SendMessage(header, peerAddress, msgIn);
}
else if (mTransportMgr != nullptr)
{
return mTransportMgr->SendMessage(header, payloadFlags, peerAddress, msgIn);
return mTransportMgr->SendMessage(header, peerAddress, msgIn);
}
else
{
Expand Down Expand Up @@ -150,7 +150,7 @@ CHIP_ERROR RendezvousSession::SendSecureMessage(Protocols::CHIPProtocolId protoc
uint16_t totalLen = msgBuf->TotalLength();

ReturnErrorOnFailure(payloadHeader.Encode(data, totalLen, &actualEncodedHeaderSize));
ReturnErrorOnFailure(mSecureSession.Encrypt(data, totalLen, data, packetHeader, payloadHeader.GetEncodePacketFlags(), mac));
ReturnErrorOnFailure(mSecureSession.Encrypt(data, totalLen, data, packetHeader, mac));

uint16_t taglen = 0;
ReturnErrorOnFailure(mac.Encode(packetHeader, &data[totalLen], kMaxTagLen, &taglen));
Expand All @@ -159,8 +159,7 @@ CHIP_ERROR RendezvousSession::SendSecureMessage(Protocols::CHIPProtocolId protoc

msgBuf->SetDataLength(static_cast<uint16_t>(totalLen + taglen));

ReturnErrorOnFailure(mTransport->SendMessage(packetHeader, payloadHeader.GetEncodePacketFlags(), Transport::PeerAddress::BLE(),
msgBuf.Release_ForNow()));
ReturnErrorOnFailure(mTransport->SendMessage(packetHeader, Transport::PeerAddress::BLE(), msgBuf.Release_ForNow()));

mSecureMessageIndex++;

Expand Down Expand Up @@ -378,7 +377,7 @@ CHIP_ERROR RendezvousSession::HandleSecureMessage(const PacketHeader & packetHea
len = static_cast<uint16_t>(len - taglen);
msgBuf->SetDataLength(len);

ReturnErrorOnFailure(mSecureSession.Decrypt(data, len, plainText, packetHeader, payloadHeader.GetEncodePacketFlags(), mac));
ReturnErrorOnFailure(mSecureSession.Decrypt(data, len, plainText, packetHeader, mac));

// Use the node IDs from the packet header only after it's successfully decrypted
if (packetHeader.GetDestinationNodeId().HasValue() && !mParams.HasLocalNodeId())
Expand All @@ -394,7 +393,7 @@ CHIP_ERROR RendezvousSession::HandleSecureMessage(const PacketHeader & packetHea
}

uint16_t decodedSize = 0;
ReturnErrorOnFailure(payloadHeader.Decode(packetHeader.GetFlags(), plainText, len, &decodedSize));
ReturnErrorOnFailure(payloadHeader.Decode(plainText, len, &decodedSize));

ReturnErrorCodeIf(headerSize != decodedSize, CHIP_ERROR_INCORRECT_STATE);

Expand Down
4 changes: 2 additions & 2 deletions src/transport/RendezvousSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ class RendezvousSession : public SecurePairingSessionDelegate,
Optional<NodeId> GetRemoteNodeId() const { return mParams.GetRemoteNodeId(); }

//////////// SecurePairingSessionDelegate Implementation ///////////////
CHIP_ERROR SendPairingMessage(const PacketHeader & header, Header::Flags payloadFlags,
const Transport::PeerAddress & peerAddress, System::PacketBuffer * msgBuf) override;
CHIP_ERROR SendPairingMessage(const PacketHeader & header, const Transport::PeerAddress & peerAddress,
System::PacketBuffer * msgBuf) override;
void OnPairingError(CHIP_ERROR err) override;
void OnPairingComplete() override;

Expand Down
6 changes: 3 additions & 3 deletions src/transport/SecurePairingSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ CHIP_ERROR SecurePairingSession::AttachHeaderAndSend(uint8_t msgType, System::Pa
SuccessOrExit(err);
VerifyOrExit(headerSize == actualEncodedHeaderSize, err = CHIP_ERROR_INTERNAL);

err = mDelegate->SendPairingMessage(PacketHeader().SetSourceNodeId(mLocalNodeId).SetEncryptionKeyID(mLocalKeyId),
payloadHeader.GetEncodePacketFlags(), mPeerAddress, msgBuf.Release_ForNow());
err = mDelegate->SendPairingMessage(PacketHeader().SetSourceNodeId(mLocalNodeId).SetEncryptionKeyID(mLocalKeyId), mPeerAddress,
msgBuf.Release_ForNow());
SuccessOrExit(err);

exit:
Expand Down Expand Up @@ -437,7 +437,7 @@ CHIP_ERROR SecurePairingSession::HandlePeerMessage(const PacketHeader & packetHe

VerifyOrExit(!msg.IsNull(), err = CHIP_ERROR_INVALID_ARGUMENT);

err = payloadHeader.Decode(packetHeader.GetFlags(), msg->Start(), msg->DataLength(), &headerSize);
err = payloadHeader.Decode(msg->Start(), msg->DataLength(), &headerSize);
SuccessOrExit(err);

msg->ConsumeHead(headerSize);
Expand Down
5 changes: 2 additions & 3 deletions src/transport/SecurePairingSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,12 @@ class DLL_EXPORT SecurePairingSessionDelegate
* Called when pairing session generates a new message that should be sent to peer.
*
* @param header the message header for the sent message
* @param payloadFlags payload encoding flags
* @param peerAddress the destination of the message
* @param msgBuf the raw data for the message being sent
* @return CHIP_ERROR Error thrown when sending the message
*/
virtual CHIP_ERROR SendPairingMessage(const PacketHeader & header, Header::Flags payloadFlags,
const Transport::PeerAddress & peerAddress, System::PacketBuffer * msgBuf)
virtual CHIP_ERROR SendPairingMessage(const PacketHeader & header, const Transport::PeerAddress & peerAddress,
System::PacketBuffer * msgBuf)
{
return CHIP_ERROR_NOT_IMPLEMENTED;
}
Expand Down
13 changes: 6 additions & 7 deletions src/transport/SecureSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,15 @@ CHIP_ERROR SecureSession::GetIV(const PacketHeader & header, uint8_t * iv, size_
return bbuf.Fit() ? CHIP_NO_ERROR : CHIP_ERROR_NO_MEMORY;
}

CHIP_ERROR SecureSession::GetAdditionalAuthData(const PacketHeader & header, const Header::Flags payloadEncodeFlags, uint8_t * aad,
uint16_t & len)
CHIP_ERROR SecureSession::GetAdditionalAuthData(const PacketHeader & header, uint8_t * aad, uint16_t & len)
{
VerifyOrReturnError(len >= header.EncodeSizeBytes(), CHIP_ERROR_INVALID_ARGUMENT);

// Use unencrypted part of header as AAD. This will help
// integrity protect the whole message
uint16_t actualEncodedHeaderSize;

ReturnErrorOnFailure(header.Encode(aad, len, &actualEncodedHeaderSize, payloadEncodeFlags));
ReturnErrorOnFailure(header.Encode(aad, len, &actualEncodedHeaderSize));
VerifyOrReturnError(len >= actualEncodedHeaderSize, CHIP_ERROR_INVALID_ARGUMENT);

len = actualEncodedHeaderSize;
Expand All @@ -114,7 +113,7 @@ CHIP_ERROR SecureSession::GetAdditionalAuthData(const PacketHeader & header, con
}

CHIP_ERROR SecureSession::Encrypt(const uint8_t * input, size_t input_length, uint8_t * output, PacketHeader & header,
Header::Flags payloadFlags, MessageAuthenticationCode & mac)
MessageAuthenticationCode & mac)
{

constexpr Header::EncryptionType encType = Header::EncryptionType::kAESCCMTagLen16;
Expand All @@ -133,7 +132,7 @@ CHIP_ERROR SecureSession::Encrypt(const uint8_t * input, size_t input_length, ui
uint8_t tag[kMaxTagLen];

ReturnErrorOnFailure(GetIV(header, IV, sizeof(IV)));
ReturnErrorOnFailure(GetAdditionalAuthData(header, payloadFlags, AAD, aadLen));
ReturnErrorOnFailure(GetAdditionalAuthData(header, AAD, aadLen));
ReturnErrorOnFailure(
AES_CCM_encrypt(input, input_length, AAD, aadLen, mKey, sizeof(mKey), IV, sizeof(IV), output, tag, taglen));

Expand All @@ -143,7 +142,7 @@ CHIP_ERROR SecureSession::Encrypt(const uint8_t * input, size_t input_length, ui
}

CHIP_ERROR SecureSession::Decrypt(const uint8_t * input, size_t input_length, uint8_t * output, const PacketHeader & header,
Header::Flags payloadFlags, const MessageAuthenticationCode & mac)
const MessageAuthenticationCode & mac)
{
const size_t taglen = MessageAuthenticationCode::TagLenForEncryptionType(header.GetEncryptionType());
const uint8_t * tag = mac.GetTag();
Expand All @@ -157,7 +156,7 @@ CHIP_ERROR SecureSession::Decrypt(const uint8_t * input, size_t input_length, ui
VerifyOrReturnError(output != nullptr, CHIP_ERROR_INVALID_ARGUMENT);

ReturnErrorOnFailure(GetIV(header, IV, sizeof(IV)));
ReturnErrorOnFailure(GetAdditionalAuthData(header, payloadFlags, AAD, aadLen));
ReturnErrorOnFailure(GetAdditionalAuthData(header, AAD, aadLen));

return AES_CCM_decrypt(input, input_length, AAD, aadLen, tag, taglen, mKey, sizeof(mKey), IV, sizeof(IV), output);
}
Expand Down
9 changes: 3 additions & 6 deletions src/transport/SecureSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,12 @@ class DLL_EXPORT SecureSession
* @param input_length Length of the input data
* @param output Output buffer for encrypted data
* @param header message header structure. Encryption type will be set on the header.
* @param payloadFlags extra flags for packet header encryption
* @param mac - output the resulting mac
*
* @return CHIP_ERROR The result of encryption
*/
CHIP_ERROR Encrypt(const uint8_t * input, size_t input_length, uint8_t * output, PacketHeader & header,
Header::Flags payloadFlags, MessageAuthenticationCode & mac);
MessageAuthenticationCode & mac);

/**
* @brief
Expand All @@ -96,12 +95,11 @@ class DLL_EXPORT SecureSession
* @param input_length Length of the input data
* @param output Output buffer for decrypted data
* @param header message header structure
* @param payloadFlags extra flags for packet header decryption
* @return CHIP_ERROR The result of decryption
* @param mac Input mac
*/
CHIP_ERROR Decrypt(const uint8_t * input, size_t input_length, uint8_t * output, const PacketHeader & header,
Header::Flags payloadFlags, const MessageAuthenticationCode & mac);
const MessageAuthenticationCode & mac);

/**
* @brief
Expand All @@ -128,8 +126,7 @@ class DLL_EXPORT SecureSession
// Use unencrypted header as additional authenticated data (AAD) during encryption and decryption.
// The encryption operations includes AAD when message authentication tag is generated. This tag
// is used at the time of decryption to integrity check the received data.
static CHIP_ERROR GetAdditionalAuthData(const PacketHeader & header, Header::Flags payloadEncodeFlags, uint8_t * aad,
uint16_t & len);
static CHIP_ERROR GetAdditionalAuthData(const PacketHeader & header, uint8_t * aad, uint16_t & len);
};

} // namespace chip
26 changes: 13 additions & 13 deletions src/transport/SecureSessionMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, NodeId p
.SetPayloadLength(static_cast<uint16_t>(payloadLength));
packetHeader.GetFlags().Set(Header::FlagValues::kSecure);

ChipLogProgress(Inet, "Sending msg from %llu to %llu\n", mLocalNodeId, peerNodeId);
ChipLogProgress(Inet, "Sending msg from %llu to %llu", mLocalNodeId, peerNodeId);

VerifyOrExit(msgBuf->EnsureReservedSize(headerSize), err = CHIP_ERROR_NO_MEMORY);

Expand All @@ -146,7 +146,7 @@ CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, NodeId p
err = payloadHeader.Encode(data, totalLen, &actualEncodedHeaderSize);
SuccessOrExit(err);

err = state->GetSecureSession().Encrypt(data, totalLen, data, packetHeader, payloadHeader.GetEncodePacketFlags(), mac);
err = state->GetSecureSession().Encrypt(data, totalLen, data, packetHeader, mac);
SuccessOrExit(err);

err = mac.Encode(packetHeader, &data[totalLen], kMaxTagLen, &taglen);
Expand All @@ -157,8 +157,7 @@ CHIP_ERROR SecureSessionMgr::SendMessage(PayloadHeader & payloadHeader, NodeId p

ChipLogDetail(Inet, "Secure transport transmitting msg %u after encryption", state->GetSendMessageIndex());

err = mTransportMgr->SendMessage(packetHeader, payloadHeader.GetEncodePacketFlags(), state->GetPeerAddress(),
msgBuf.Release_ForNow());
err = mTransportMgr->SendMessage(packetHeader, state->GetPeerAddress(), msgBuf.Release_ForNow());
}
SuccessOrExit(err);
state->IncrementSendMessageIndex();
Expand Down Expand Up @@ -306,13 +305,13 @@ void SecureSessionMgr::OnMessageReceived(const PacketHeader & packetHeader, cons
PayloadHeader payloadHeader;
MessageAuthenticationCode mac;

uint8_t * data = msg->Start();
uint8_t * plainText = nullptr;
uint16_t len = msg->TotalLength();
const uint16_t headerSize = payloadHeader.EncodeSizeBytes();
uint16_t decodedSize = 0;
uint16_t taglen = 0;
uint16_t payloadlen = 0;
uint8_t * data = msg->Start();
uint8_t * plainText = nullptr;
uint16_t len = msg->TotalLength();
uint16_t headerSize = 0;
uint16_t decodedSize = 0;
uint16_t taglen = 0;
uint16_t payloadlen = 0;

#if CHIP_SYSTEM_CONFIG_USE_LWIP
/* This is a workaround for the case where PacketBuffer payload is not
Expand All @@ -333,10 +332,11 @@ void SecureSessionMgr::OnMessageReceived(const PacketHeader & packetHeader, cons
len = static_cast<uint16_t>(len - taglen);
msg->SetDataLength(len, nullptr);

err = state->GetSecureSession().Decrypt(data, len, plainText, packetHeader, payloadHeader.GetEncodePacketFlags(), mac);
err = state->GetSecureSession().Decrypt(data, len, plainText, packetHeader, mac);
VerifyOrExit(err == CHIP_NO_ERROR, ChipLogError(Inet, "Secure transport failed to decrypt msg: err %d", err));

err = payloadHeader.Decode(packetHeader.GetFlags(), plainText, len, &decodedSize);
err = payloadHeader.Decode(plainText, len, &decodedSize);
headerSize = payloadHeader.EncodeSizeBytes();
VerifyOrExit(err == CHIP_NO_ERROR, ChipLogError(Inet, "Secure transport failed to decode encrypted header: err %d", err));
VerifyOrExit(headerSize == decodedSize, ChipLogError(Inet, "Secure transport decode encrypted header length mismatched"));

Expand Down
5 changes: 2 additions & 3 deletions src/transport/TransportMgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,9 @@ class TransportMgrBase
public:
CHIP_ERROR Init(Transport::Base * transport);

CHIP_ERROR SendMessage(const PacketHeader & header, Header::Flags payloadFlags, const Transport::PeerAddress & address,
System::PacketBuffer * msgBuf)
CHIP_ERROR SendMessage(const PacketHeader & header, const Transport::PeerAddress & address, System::PacketBuffer * msgBuf)
{
return mTransport->SendMessage(header, payloadFlags, address, msgBuf);
return mTransport->SendMessage(header, address, msgBuf);
}

void Disconnect(const Transport::PeerAddress & address) { mTransport->Disconnect(address); }
Expand Down
3 changes: 1 addition & 2 deletions src/transport/raw/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ class Base
* This method calls <tt>chip::System::PacketBuffer::Free</tt> on
* behalf of the caller regardless of the return status.
*/
virtual CHIP_ERROR SendMessage(const PacketHeader & header, Header::Flags payloadFlags, const PeerAddress & address,
System::PacketBuffer * msgBuf) = 0;
virtual CHIP_ERROR SendMessage(const PacketHeader & header, const PeerAddress & address, System::PacketBuffer * msgBuf) = 0;

/**
* Determine if this transport can SendMessage to the specified peer address.
Expand Down
Loading

0 comments on commit dedd15f

Please sign in to comment.