diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index 07684077880c93..ec6404727290ef 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -190,6 +190,16 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P uint32_t messageCounter = counter.Value(); ReturnErrorOnFailure(counter.Advance()); packetHeader.SetMessageCounter(messageCounter); + Transport::UnauthenticatedSession * session = sessionHandle->AsUnauthenticatedSession(); + switch (session->GetSessionRole()) + { + case Transport::UnauthenticatedSession::SessionRole::kInitiator: + packetHeader.SetSourceNodeId(session->GetEphemeralInitiatorNodeID()); + break; + case Transport::UnauthenticatedSession::SessionRole::kResponder: + packetHeader.SetDestinationNodeId(session->GetEphemeralInitiatorNodeID()); + break; + } // Trace after all headers are settled. CHIP_TRACE_MESSAGE_SENT(payloadHeader, packetHeader, message->Start(), message->TotalLength()); @@ -401,7 +411,7 @@ void SessionManager::OnMessageReceived(const PeerAddress & peerAddress, System:: } else { - MessageDispatch(packetHeader, peerAddress, std::move(msg)); + UnauthenticatedMessageDispatch(packetHeader, peerAddress, std::move(msg)); } } @@ -437,18 +447,45 @@ void SessionManager::RefreshSessionOperationalData(const SessionHandle & session }); } -void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, - System::PacketBufferHandle && msg) +void SessionManager::UnauthenticatedMessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, + System::PacketBufferHandle && msg) { - Optional optionalSession = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress, GetLocalMRPConfig()); - if (!optionalSession.HasValue()) + Optional source = packetHeader.GetSourceNodeId(); + Optional destination = packetHeader.GetDestinationNodeId(); + if ((source.HasValue() && destination.HasValue()) || (!source.HasValue() && !destination.HasValue())) { - ChipLogError(Inet, "UnauthenticatedSession exhausted"); - return; + ChipLogProgress(Inet, + "Received malformed unsecure packet with source 0x" ChipLogFormatX64 " destination 0x" ChipLogFormatX64, + ChipLogValueX64(source.ValueOr(kUndefinedNodeId)), ChipLogValueX64(destination.ValueOr(kUndefinedNodeId))); + return; // ephemeral node id is only assigned to the initiator, there should be one and only one node id exists. + } + + Optional optionalSession; + if (source.HasValue()) + { + // Assume peer is the initiator, we are the responder. + optionalSession = mUnauthenticatedSessions.FindOrAllocateResponder(source.Value(), GetLocalMRPConfig()); + if (!optionalSession.HasValue()) + { + ChipLogError(Inet, "UnauthenticatedSession exhausted"); + return; + } + } + else + { + // Assume peer is the responder, we are the initiator. + optionalSession = mUnauthenticatedSessions.FindInitiator(destination.Value()); + if (!optionalSession.HasValue()) + { + ChipLogProgress(Inet, "Received unknown unsecure packet for initiator 0x" ChipLogFormatX64, + ChipLogValueX64(destination.Value())); + return; + } } const SessionHandle & session = optionalSession.Value(); Transport::UnauthenticatedSession * unsecuredSession = session->AsUnauthenticatedSession(); + unsecuredSession->SetPeerAddress(peerAddress); SessionMessageDelegate::DuplicateMessage isDuplicate = SessionMessageDelegate::DuplicateMessage::No; // Verify message counter diff --git a/src/transport/SessionManager.h b/src/transport/SessionManager.h index 5b4cda34960783..4adfdd90015fe5 100644 --- a/src/transport/SessionManager.h +++ b/src/transport/SessionManager.h @@ -27,6 +27,7 @@ #include +#include #include #include #include @@ -206,7 +207,13 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate Optional CreateUnauthenticatedSession(const Transport::PeerAddress & peerAddress, const ReliableMessageProtocolConfig & config) { - return mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress, config); + // Allocate ephemeralInitiatorNodeID in Operational Node ID range + NodeId ephemeralInitiatorNodeID; + do + { + ephemeralInitiatorNodeID = static_cast(Crypto::GetRandU64()); + } while (!IsOperationalNodeId(ephemeralInitiatorNodeID)); + return mUnauthenticatedSessions.AllocInitiator(ephemeralInitiatorNodeID, peerAddress, config); } // TODO: implements group sessions @@ -279,8 +286,8 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate void SecureGroupMessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg); - void MessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, - System::PacketBufferHandle && msg); + void UnauthenticatedMessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, + System::PacketBufferHandle && msg); void OnReceiveError(CHIP_ERROR error, const Transport::PeerAddress & source); diff --git a/src/transport/UnauthenticatedSessionTable.h b/src/transport/UnauthenticatedSessionTable.h index 8e7cbd1ec371fb..213afb1661235a 100644 --- a/src/transport/UnauthenticatedSessionTable.h +++ b/src/transport/UnauthenticatedSessionTable.h @@ -46,8 +46,15 @@ class UnauthenticatedSessionDeleter class UnauthenticatedSession : public Session, public ReferenceCounted { public: - UnauthenticatedSession(const PeerAddress & address, const ReliableMessageProtocolConfig & config) : - mPeerAddress(address), mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config) + enum class SessionRole + { + kInitiator, + kResponder, + }; + + UnauthenticatedSession(SessionRole sessionRole, NodeId ephemeralInitiatorNodeID, const ReliableMessageProtocolConfig & config) : + mEphemeralInitiatorNodeId(ephemeralInitiatorNodeID), mSessionRole(sessionRole), + mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config) {} ~UnauthenticatedSession() { NotifySessionReleased(); } @@ -88,8 +95,22 @@ class UnauthenticatedSession : public Session, public ReferenceCounted FindOrAllocateEntry(const PeerAddress & address, const ReliableMessageProtocolConfig & config) + Optional FindOrAllocateResponder(NodeId ephemeralInitiatorNodeID, const ReliableMessageProtocolConfig & config) { - UnauthenticatedSession * result = FindEntry(address); + UnauthenticatedSession * result = FindEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID); if (result != nullptr) return MakeOptional(*result); - CHIP_ERROR err = AllocEntry(address, config, result); + CHIP_ERROR err = AllocEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID, config, result); + if (err == CHIP_NO_ERROR) + { + return MakeOptional(*result); + } + else + { + return Optional::Missing(); + } + } + + CHECK_RETURN_VALUE Optional FindInitiator(NodeId ephemeralInitiatorNodeID) + { + UnauthenticatedSession * result = FindEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID); + if (result != nullptr) + { + return MakeOptional(*result); + } + else + { + return Optional::Missing(); + } + } + + CHECK_RETURN_VALUE Optional AllocInitiator(NodeId ephemeralInitiatorNodeID, const PeerAddress & peerAddress, + const ReliableMessageProtocolConfig & config) + { + UnauthenticatedSession * result = nullptr; + CHIP_ERROR err = AllocEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID, config, result); if (err == CHIP_NO_ERROR) { + result->SetPeerAddress(peerAddress); return MakeOptional(*result); } else @@ -148,10 +201,10 @@ class UnauthenticatedSessionTable * CHIP_ERROR_NO_MEMORY). */ CHECK_RETURN_VALUE - CHIP_ERROR AllocEntry(const PeerAddress & address, const ReliableMessageProtocolConfig & config, - UnauthenticatedSession *& entry) + CHIP_ERROR AllocEntry(UnauthenticatedSession::SessionRole sessionRole, NodeId ephemeralInitiatorNodeID, + const ReliableMessageProtocolConfig & config, UnauthenticatedSession *& entry) { - entry = mEntries.CreateObject(address, config); + entry = mEntries.CreateObject(sessionRole, ephemeralInitiatorNodeID, config); if (entry != nullptr) return CHIP_NO_ERROR; @@ -161,21 +214,16 @@ class UnauthenticatedSessionTable return CHIP_ERROR_NO_MEMORY; } - mEntries.ResetObject(entry, address, config); + mEntries.ResetObject(entry, sessionRole, ephemeralInitiatorNodeID, config); return CHIP_NO_ERROR; } - /** - * Get a session using given address - * - * @return the peer found, nullptr if not found - */ - CHECK_RETURN_VALUE - UnauthenticatedSession * FindEntry(const PeerAddress & address) + CHECK_RETURN_VALUE UnauthenticatedSession * FindEntry(UnauthenticatedSession::SessionRole sessionRole, + NodeId ephemeralInitiatorNodeID) { UnauthenticatedSession * result = nullptr; mEntries.ForEachActiveObject([&](UnauthenticatedSession * entry) { - if (MatchPeerAddress(entry->GetPeerAddress(), address)) + if (entry->GetSessionRole() == sessionRole && entry->GetEphemeralInitiatorNodeID() == ephemeralInitiatorNodeID) { result = entry; return Loop::Break; @@ -202,45 +250,6 @@ class UnauthenticatedSessionTable return result; } - // A temporary solution for #11120 - // Enforce interface match if not null - static bool MatchInterface(Inet::InterfaceId i1, Inet::InterfaceId i2) - { - if (i1.IsPresent() && i2.IsPresent()) - { - return i1 == i2; - } - else - { - // One of the interfaces is null. - return true; - } - } - - static bool MatchPeerAddress(const PeerAddress & a1, const PeerAddress & a2) - { - if (a1.GetTransportType() != a2.GetTransportType()) - return false; - - switch (a1.GetTransportType()) - { - case Transport::Type::kUndefined: - return false; - case Transport::Type::kUdp: - case Transport::Type::kTcp: - return a1.GetIPAddress() == a2.GetIPAddress() && a1.GetPort() == a2.GetPort() && - // Enforce interface equal-ness if the address is link-local, otherwise ignore interface - // Use MatchInterface for a temporary solution for #11120 - (a1.GetIPAddress().IsIPv6LinkLocal() ? a1.GetInterface() == a2.GetInterface() - : MatchInterface(a1.GetInterface(), a2.GetInterface())); - case Transport::Type::kBle: - // TODO: complete BLE address comparation - return true; - } - - return false; - } - ObjectPool mEntries; };