diff --git a/src/transport/raw/MessageHeader.cpp b/src/transport/raw/MessageHeader.cpp index 77474cf9941c01..1b8ef89680d75f 100644 --- a/src/transport/raw/MessageHeader.cpp +++ b/src/transport/raw/MessageHeader.cpp @@ -202,8 +202,17 @@ CHIP_ERROR PacketHeader::Decode(const uint8_t * const data, uint16_t size, uint1 mDestinationGroupId.ClearValue(); } + if (mSecFlags.Has(Header::SecFlagValues::kMsgExtensionFlag)) + { + // If present, skip over Message Extension block. + // Spec 4.4.1.8. Message Extensions (variable) + uint16_t mxLength; + SuccessOrExit(err = reader.Read16(&mxLength).StatusCode()); + VerifyOrExit(mxLength <= reader.Remaining(), err = CHIP_ERROR_INTERNAL); + reader.Skip(mxLength); + } + octets_read = static_cast(reader.OctetsRead()); - VerifyOrExit(octets_read == EncodeSizeBytes(), err = CHIP_ERROR_INTERNAL); *decode_len = octets_read; exit: @@ -258,8 +267,17 @@ CHIP_ERROR PayloadHeader::Decode(const uint8_t * const data, uint16_t size, uint mAckMessageCounter.ClearValue(); } + if (mExchangeFlags.Has(Header::ExFlagValues::kExchangeFlag_SecuredExtension)) + { + // If present, skip over Secured Extension block. + // Spec 4.4.3.7. Secured Extensions (variable) + uint16_t sxLength; + SuccessOrExit(err = reader.Read16(&sxLength).StatusCode()); + VerifyOrExit(sxLength <= reader.Remaining(), err = CHIP_ERROR_INTERNAL); + reader.Skip(sxLength); + } + octets_read = static_cast(reader.OctetsRead()); - VerifyOrExit(octets_read == EncodeSizeBytes(), err = CHIP_ERROR_INTERNAL); *decode_len = octets_read; exit: diff --git a/src/transport/raw/MessageHeader.h b/src/transport/raw/MessageHeader.h index 60255f78d19629..f253aa4f928751 100644 --- a/src/transport/raw/MessageHeader.h +++ b/src/transport/raw/MessageHeader.h @@ -71,6 +71,9 @@ enum class ExFlagValues : uint8_t /// Set when current message is requesting an acknowledgment from the recipient. kExchangeFlag_NeedsAck = 0x04, + /// Secured Extension block is present. + kExchangeFlag_SecuredExtension = 0x08, + /// Set when a vendor id is prepended to the Message Protocol Id field. kExchangeFlag_VendorIdPresent = 0x10, }; diff --git a/src/transport/raw/tests/TestMessageHeader.cpp b/src/transport/raw/tests/TestMessageHeader.cpp index f335b61e45911b..ff50e104f7579f 100644 --- a/src/transport/raw/tests/TestMessageHeader.cpp +++ b/src/transport/raw/tests/TestMessageHeader.cpp @@ -22,6 +22,8 @@ * the Message Header class within the transport layer * */ + +#include #include #include #include @@ -304,9 +306,21 @@ void TestPayloadHeaderEncodeDecodeBounds(nlTestSuite * inSuite, void * inContext } } +constexpr size_t HDR_LEN = 8; ///< Message header length +constexpr size_t SRC_LEN = 8; ///< Source Node ID length +constexpr size_t DST_LEN = 8; ///< Destination Node ID length +constexpr size_t GID_LEN = 2; ///< Group ID length +constexpr size_t MX_LEN = 6; ///< Message Exchange block length +constexpr size_t SX_LEN = 6; ///< Security Exchange block length +constexpr size_t PRO_LEN = 6; ///< Protocol header length +constexpr size_t APP_LEN = 2; ///< App payload length + +/// Size of fixed portion of message header + max source node id + max destination node id. +constexpr size_t MAX_FIXED_HEADER_SIZE = (HDR_LEN + SRC_LEN + DST_LEN); + struct SpecComplianceTestVector { - uint8_t encoded[8 + 8 + 8]; // Fixed header + max source id + max dest id + uint8_t encoded[MAX_FIXED_HEADER_SIZE]; // Fixed header + max source id + max dest id uint8_t messageFlags; uint16_t sessionId; uint8_t sessionType; @@ -363,12 +377,10 @@ struct SpecComplianceTestVector theSpecComplianceTestVector[] = { const unsigned theSpecComplianceTestVectorLength = sizeof(theSpecComplianceTestVector) / sizeof(struct SpecComplianceTestVector); -#define MAX_HEADER_SIZE (8 + 8 + 8) - void TestSpecComplianceEncode(nlTestSuite * inSuite, void * inContext) { struct SpecComplianceTestVector * testEntry; - uint8_t buffer[MAX_HEADER_SIZE]; + uint8_t buffer[MAX_FIXED_HEADER_SIZE]; uint16_t encodeSize; for (unsigned i = 0; i < theSpecComplianceTestVectorLength; i++) @@ -412,6 +424,126 @@ void TestSpecComplianceDecode(nlTestSuite * inSuite, void * inContext) } } +struct TestVectorMsgExtensions +{ + uint8_t payloadOffset; + uint8_t appPayloadOffset; + uint16_t msgLength; + const char * msg; +}; + +struct TestVectorMsgExtensions theTestVectorMsgExtensions[] = { + { + // SRC=none, DST=none, MX=0, SX=0 + .payloadOffset = HDR_LEN, + .appPayloadOffset = PRO_LEN, + .msgLength = HDR_LEN + PRO_LEN + APP_LEN, + .msg = "\x00\x00\x00\x00\xCC\xCC\xCC\xCC" + "\x01\xCC\xEE\xEE\x66\x66\xBB\xBB", + }, + // ================== Test MX ================== + { + // SRC=none, DST=none, MX=1, SX=0 + .payloadOffset = HDR_LEN + MX_LEN, + .appPayloadOffset = PRO_LEN, + .msgLength = HDR_LEN + MX_LEN + PRO_LEN + APP_LEN, + .msg = "\x00\x00\x00\x20\xCC\xCC\xCC\xCC\x04\x00\xE4\xE3\xE2\xE1" + "\x01\xCC\xEE\xEE\x66\x66\xBB\xBB", + }, + { + // SRC=1, DST=none, MX=1, SX=0 + .payloadOffset = HDR_LEN + MX_LEN + SRC_LEN, + .appPayloadOffset = PRO_LEN, + .msgLength = HDR_LEN + MX_LEN + SRC_LEN + PRO_LEN + APP_LEN, + .msg = "\x04\x00\x00\x20\xCC\xCC\xCC\xCC\x11\x11\x11\x11\x11\x11\x11\x11\x04\x00\xE4\xE3\xE2\xE1" + "\x01\xCC\xEE\xEE\x66\x66\xBB\xBB", + }, + { + // SRC=none, DST=1, MX=1, SX=0 + .payloadOffset = HDR_LEN + MX_LEN + DST_LEN, + .appPayloadOffset = PRO_LEN, + .msgLength = HDR_LEN + MX_LEN + DST_LEN + PRO_LEN + APP_LEN, + .msg = "\x01\x00\x00\x20\xCC\xCC\xCC\xCC\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\x04\x00\xE4\xE3\xE2\xE1" + "\x01\xCC\xEE\xEE\x66\x66\xBB\xBB", + }, + { + // SRC=1, DST=1, MX=1, SX=0 + .payloadOffset = HDR_LEN + MX_LEN + SRC_LEN + DST_LEN, + .appPayloadOffset = PRO_LEN, + .msgLength = HDR_LEN + MX_LEN + SRC_LEN + DST_LEN + PRO_LEN + APP_LEN, + .msg = "\x05\x00\x00\x20\xCC\xCC\xCC\xCC\x11\x11\x11\x11\x11\x11\x11\x11\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\x04\x00\xE4\xE3" + "\xE2\xE1" + "\x01\xCC\xEE\xEE\x66\x66\xBB\xBB", + }, + { + // SRC=none, DST=group, MX=1, SX=0 + .payloadOffset = HDR_LEN + MX_LEN + GID_LEN, + .appPayloadOffset = PRO_LEN, + .msgLength = HDR_LEN + MX_LEN + GID_LEN + PRO_LEN + APP_LEN, + .msg = "\x02\x00\x00\x21\xCC\xCC\xCC\xCC\xDD\xDD\x04\x00\xE4\xE3\xE2\xE1" + "\x01\xCC\xEE\xEE\x66\x66\xBB\xBB", + }, + { + // SRC=1, DST=group, MX=1, SX=0 + .payloadOffset = HDR_LEN + MX_LEN + SRC_LEN + GID_LEN, + .appPayloadOffset = PRO_LEN, + .msgLength = HDR_LEN + MX_LEN + SRC_LEN + GID_LEN + PRO_LEN + APP_LEN, + .msg = "\x06\x00\x00\x21\xCC\xCC\xCC\xCC\x11\x11\x11\x11\x11\x11\x11\x11\xDD\xDD\x04\x00\xE4\xE3\xE2\xE1" + "\x01\xCC\xEE\xEE\x66\x66\xBB\xBB", + }, + // ================== Test SX ================== + { + // SRC=none, DST=none, MX=0, SX=1 + .payloadOffset = HDR_LEN, + .appPayloadOffset = PRO_LEN + SX_LEN, + .msgLength = HDR_LEN + PRO_LEN + SX_LEN + APP_LEN, + .msg = "\x00\x00\x00\x00\xCC\xCC\xCC\xCC" + "\x08\xCC\xEE\xEE\x66\x66\x04\x00\xE4\xE3\xE2\xE1\xBB\xBB", + }, + { + // SRC=none, DST=none, MX=1, SX=1 + .payloadOffset = HDR_LEN + MX_LEN, + .appPayloadOffset = PRO_LEN + SX_LEN, + .msgLength = HDR_LEN + MX_LEN + PRO_LEN + SX_LEN + APP_LEN, + .msg = "\x00\x00\x00\x20\xCC\xCC\xCC\xCC\x04\x00\xE4\xE3\xE2\xE1" + "\x08\xCC\xEE\xEE\x66\x66\x04\x00\xE4\xE3\xE2\xE1\xBB\xBB", + }, + { + // SRC=1, DST=1, MX=1, SX=1 + .payloadOffset = HDR_LEN + MX_LEN + SRC_LEN + DST_LEN, + .appPayloadOffset = PRO_LEN + SX_LEN, + .msgLength = HDR_LEN + MX_LEN + SRC_LEN + DST_LEN + PRO_LEN + SX_LEN + APP_LEN, + .msg = "\x05\x00\x00\x20\xCC\xCC\xCC\xCC\x11\x11\x11\x11\x11\x11\x11\x11\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\x04\x00\xE4\xE3" + "\xE2\xE1" + "\x09\xCC\xEE\xEE\x66\x66\x04\x00\xE4\xE3\xE2\xE1\xBB\xBB", + }, +}; + +const unsigned theTestVectorMsgExtensionsLength = sizeof(theTestVectorMsgExtensions) / sizeof(struct TestVectorMsgExtensions); + +void TestMsgExtensionsDecode(nlTestSuite * inSuite, void * inContext) +{ + struct TestVectorMsgExtensions * testEntry; + PacketHeader packetHeader; + PayloadHeader payloadHeader; + uint16_t decodeSize; + + NL_TEST_ASSERT(inSuite, chip::Platform::MemoryInit() == CHIP_NO_ERROR); + + for (unsigned i = 0; i < theTestVectorMsgExtensionsLength; i++) + { + testEntry = &theTestVectorMsgExtensions[i]; + + System::PacketBufferHandle msg = System::PacketBufferHandle::NewWithData(testEntry->msg, testEntry->msgLength); + + NL_TEST_ASSERT(inSuite, packetHeader.Decode(msg->Start(), msg->DataLength(), &decodeSize) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, decodeSize == testEntry->payloadOffset); + + NL_TEST_ASSERT(inSuite, payloadHeader.Decode(msg->Start() + decodeSize, msg->DataLength(), &decodeSize) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, decodeSize == testEntry->appPayloadOffset); + } +} + } // namespace // clang-format off @@ -425,6 +557,7 @@ static const nlTest sTests[] = NL_TEST_DEF("PayloadEncodeDecodeBounds", TestPayloadHeaderEncodeDecodeBounds), NL_TEST_DEF("SpecComplianceEncode", TestSpecComplianceEncode), NL_TEST_DEF("SpecComplianceDecode", TestSpecComplianceDecode), + NL_TEST_DEF("TestMsgExtensionsDecode", TestMsgExtensionsDecode), NL_TEST_SENTINEL() }; // clang-format on