Skip to content

Commit

Permalink
Make sure we don't use un-initialized reader when a command has no fi…
Browse files Browse the repository at this point in the history
…elds struct. (project-chip#20056)

* Make sure we don't use un-initialized reader when a command has no fields struct.

Fixes project-chip#10501

* Address review comments.
  • Loading branch information
bzbarsky-apple authored Jul 1, 2022
1 parent 3b92283 commit a295bec
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 28 deletions.
19 changes: 17 additions & 2 deletions src/app/CommandHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <app/RequiredPrivilege.h>
#include <app/util/MatterCallbacks.h>
#include <credentials/GroupDataProvider.h>
#include <lib/core/CHIPTLVData.hpp>
#include <lib/core/CHIPTLVUtilities.hpp>
#include <lib/support/TypeTraits.h>
#include <protocols/secure_channel/Constants.h>
Expand Down Expand Up @@ -244,6 +245,18 @@ CHIP_ERROR CommandHandler::SendCommandResponse()
return CHIP_NO_ERROR;
}

namespace {
// We use this when the sender did not actually provide a CommandFields struct,
// to avoid downstream consumers having to worry about cases when there is or is
// not a struct available. We use an empty struct with anonymous tag, since we
// can't use a context tag at top level, and consumers should not care about the
// tag here).
constexpr uint8_t sNoFields[] = {
CHIP_TLV_STRUCTURE(CHIP_TLV_TAG_ANONYMOUS),
CHIP_TLV_END_OF_CONTAINER,
};
} // anonymous namespace

CHIP_ERROR CommandHandler::ProcessCommandDataIB(CommandDataIB::Parser & aCommandElement)
{
CHIP_ERROR err = CHIP_NO_ERROR;
Expand Down Expand Up @@ -308,7 +321,8 @@ CHIP_ERROR CommandHandler::ProcessCommandDataIB(CommandDataIB::Parser & aCommand
ChipLogDetail(DataManagement,
"Received command without data for Endpoint=%u Cluster=" ChipLogFormatMEI " Command=" ChipLogFormatMEI,
concretePath.mEndpointId, ChipLogValueMEI(concretePath.mClusterId), ChipLogValueMEI(concretePath.mCommandId));
err = CHIP_NO_ERROR;
commandDataReader.Init(sNoFields);
err = commandDataReader.Next();
}
if (CHIP_NO_ERROR == err)
{
Expand Down Expand Up @@ -365,7 +379,8 @@ CHIP_ERROR CommandHandler::ProcessGroupCommandDataIB(CommandDataIB::Parser & aCo
ChipLogDetail(DataManagement,
"Received command without data for Group=%u Cluster=" ChipLogFormatMEI " Command=" ChipLogFormatMEI, groupId,
ChipLogValueMEI(clusterId), ChipLogValueMEI(commandId));
err = CHIP_NO_ERROR;
commandDataReader.Init(sNoFields);
err = commandDataReader.Next();
}
SuccessOrExit(err);

Expand Down
87 changes: 63 additions & 24 deletions src/app/tests/TestCommandInteraction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <lib/core/CHIPTLV.h>
#include <lib/core/CHIPTLVDebug.hpp>
#include <lib/core/CHIPTLVUtilities.hpp>
#include <lib/core/Optional.h>
#include <lib/support/ErrorStr.h>
#include <lib/support/UnitTestContext.h>
#include <lib/support/UnitTestRegistration.h>
Expand All @@ -57,10 +58,14 @@ bool isCommandDispatched = false;
bool sendResponse = true;
bool asyncCommand = false;

// Allow us to do test asserts from arbitrary places.
nlTestSuite * gSuite = nullptr;

constexpr EndpointId kTestEndpointId = 1;
constexpr ClusterId kTestClusterId = 3;
constexpr CommandId kTestCommandId = 4;
constexpr CommandId kTestCommandIdCommandSpecificResponse = 5;
constexpr CommandId kTestCommandIdWithData = 4;
constexpr CommandId kTestCommandIdNoData = 5;
constexpr CommandId kTestCommandIdCommandSpecificResponse = 6;
constexpr CommandId kTestNonExistCommandId = 0;
} // namespace

Expand Down Expand Up @@ -97,6 +102,36 @@ void DispatchSingleClusterCommand(const ConcreteCommandPath & aCommandPath, chip
ChipLogDetail(Controller, "Received Cluster Command: Endpoint=%x Cluster=" ChipLogFormatMEI " Command=" ChipLogFormatMEI,
aCommandPath.mEndpointId, ChipLogValueMEI(aCommandPath.mClusterId), ChipLogValueMEI(aCommandPath.mCommandId));

// Duplicate what our normal command-field-decode code does, in terms of
// checking for a struct and then entering it before getting the fields.
if (aReader.GetType() != TLV::kTLVType_Structure)
{
apCommandObj->AddStatus(aCommandPath, Protocols::InteractionModel::Status::InvalidAction);
return;
}

TLV::TLVType outerContainerType;
CHIP_ERROR err = aReader.EnterContainer(outerContainerType);
NL_TEST_ASSERT(gSuite, err == CHIP_NO_ERROR);

err = aReader.Next();
if (aCommandPath.mCommandId == kTestCommandIdNoData)
{
NL_TEST_ASSERT(gSuite, err == CHIP_ERROR_END_OF_TLV);
}
else
{
NL_TEST_ASSERT(gSuite, err == CHIP_NO_ERROR);
NL_TEST_ASSERT(gSuite, aReader.GetTag() == TLV::ContextTag(1));
bool val;
err = aReader.Get(val);
NL_TEST_ASSERT(gSuite, err == CHIP_NO_ERROR);
NL_TEST_ASSERT(gSuite, val);
}

err = aReader.ExitContainer(outerContainerType);
NL_TEST_ASSERT(gSuite, err == CHIP_NO_ERROR);

if (asyncCommand)
{
asyncCommandHandle = apCommandObj;
Expand All @@ -105,7 +140,7 @@ void DispatchSingleClusterCommand(const ConcreteCommandPath & aCommandPath, chip

if (sendResponse)
{
if (aCommandPath.mCommandId == kTestCommandId)
if (aCommandPath.mCommandId == kTestCommandIdNoData || aCommandPath.mCommandId == kTestCommandIdWithData)
{
apCommandObj->AddStatus(aCommandPath, Protocols::InteractionModel::Status::Success);
}
Expand Down Expand Up @@ -200,16 +235,20 @@ class TestCommandInteraction
}

private:
// Generate an invoke request. If aCommandId is kTestCommandIdWithData, a
// payload will be included. Otherwise no payload will be included.
static void GenerateInvokeRequest(nlTestSuite * apSuite, void * apContext, System::PacketBufferHandle & aPayload,
bool aNeedCommandData, bool aIsTimedRequest, EndpointId aEndpointId = kTestEndpointId,
ClusterId aClusterId = kTestClusterId, CommandId aCommandId = kTestCommandId);
bool aIsTimedRequest, CommandId aCommandId, ClusterId aClusterId = kTestClusterId,
EndpointId aEndpointId = kTestEndpointId);
// Generate an invoke response. If aCommandId is kTestCommandIdWithData, a
// payload will be included. Otherwise no payload will be included.
static void GenerateInvokeResponse(nlTestSuite * apSuite, void * apContext, System::PacketBufferHandle & aPayload,
bool aNeedCommandData, EndpointId aEndpointId = kTestEndpointId,
ClusterId aClusterId = kTestClusterId, CommandId aCommandId = kTestCommandId);
CommandId aCommandId, ClusterId aClusterId = kTestClusterId,
EndpointId aEndpointId = kTestEndpointId);
static void AddInvokeRequestData(nlTestSuite * apSuite, void * apContext, CommandSender * apCommandSender,
CommandId aCommandId = kTestCommandId);
CommandId aCommandId = kTestCommandIdWithData);
static void AddInvokeResponseData(nlTestSuite * apSuite, void * apContext, CommandHandler * apCommandHandler,
bool aNeedStatusCode, CommandId aCommandId = kTestCommandId);
bool aNeedStatusCode, CommandId aCommandId = kTestCommandIdWithData);
static void ValidateCommandHandlerWithSendCommand(nlTestSuite * apSuite, void * apContext, bool aNeedStatusCode);
};

Expand All @@ -224,14 +263,14 @@ class TestExchangeDelegate : public Messaging::ExchangeDelegate
void OnResponseTimeout(Messaging::ExchangeContext * ec) override {}
};

CommandPathParams MakeTestCommandPath(CommandId aCommandId = kTestCommandId)
CommandPathParams MakeTestCommandPath(CommandId aCommandId = kTestCommandIdWithData)
{
return CommandPathParams(kTestEndpointId, 0, kTestClusterId, aCommandId, (chip::app::CommandPathFlags::kEndpointIdValid));
}

void TestCommandInteraction::GenerateInvokeRequest(nlTestSuite * apSuite, void * apContext, System::PacketBufferHandle & aPayload,
bool aNeedCommandData, bool aIsTimedRequest, EndpointId aEndpointId,
ClusterId aClusterId, CommandId aCommandId)
bool aIsTimedRequest, CommandId aCommandId, ClusterId aClusterId,
EndpointId aEndpointId)

{
CHIP_ERROR err = CHIP_NO_ERROR;
Expand All @@ -255,7 +294,7 @@ void TestCommandInteraction::GenerateInvokeRequest(nlTestSuite * apSuite, void *
commandPathBuilder.EndpointId(aEndpointId).ClusterId(aClusterId).CommandId(aCommandId).EndOfCommandPathIB();
NL_TEST_ASSERT(apSuite, commandPathBuilder.GetError() == CHIP_NO_ERROR);

if (aNeedCommandData)
if (aCommandId == kTestCommandIdWithData)
{
chip::TLV::TLVWriter * pWriter = commandDataIBBuilder.GetWriter();
chip::TLV::TLVType dummyType = chip::TLV::kTLVType_NotSpecified;
Expand Down Expand Up @@ -284,8 +323,7 @@ void TestCommandInteraction::GenerateInvokeRequest(nlTestSuite * apSuite, void *
}

void TestCommandInteraction::GenerateInvokeResponse(nlTestSuite * apSuite, void * apContext, System::PacketBufferHandle & aPayload,
bool aNeedCommandData, EndpointId aEndpointId, ClusterId aClusterId,
CommandId aCommandId)
CommandId aCommandId, ClusterId aClusterId, EndpointId aEndpointId)

{
CHIP_ERROR err = CHIP_NO_ERROR;
Expand All @@ -312,7 +350,7 @@ void TestCommandInteraction::GenerateInvokeResponse(nlTestSuite * apSuite, void
commandPathBuilder.EndpointId(aEndpointId).ClusterId(aClusterId).CommandId(aCommandId).EndOfCommandPathIB();
NL_TEST_ASSERT(apSuite, commandPathBuilder.GetError() == CHIP_NO_ERROR);

if (aNeedCommandData)
if (aCommandId == kTestCommandIdWithData)
{
chip::TLV::TLVWriter * pWriter = commandDataIBBuilder.GetWriter();
chip::TLV::TLVType dummyType = chip::TLV::kTLVType_NotSpecified;
Expand Down Expand Up @@ -406,7 +444,7 @@ void TestCommandInteraction::TestCommandHandlerWithWrongState(nlTestSuite * apSu
{
TestContext & ctx = *static_cast<TestContext *>(apContext);
CHIP_ERROR err = CHIP_NO_ERROR;
ConcreteCommandPath path = { kTestEndpointId, kTestClusterId, kTestCommandId };
ConcreteCommandPath path = { kTestEndpointId, kTestClusterId, kTestCommandIdNoData };

app::CommandHandler commandHandler(&mockCommandHandlerDelegate);

Expand Down Expand Up @@ -435,7 +473,7 @@ void TestCommandInteraction::TestCommandSenderWithSendCommand(nlTestSuite * apSu

ctx.DrainAndServiceIO();

GenerateInvokeResponse(apSuite, apContext, buf, true /*aNeedCommandData*/);
GenerateInvokeResponse(apSuite, apContext, buf, kTestCommandIdWithData);
err = commandSender.ProcessInvokeResponse(std::move(buf));
NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
}
Expand All @@ -444,7 +482,7 @@ void TestCommandInteraction::TestCommandHandlerWithSendEmptyCommand(nlTestSuite
{
TestContext & ctx = *static_cast<TestContext *>(apContext);
CHIP_ERROR err = CHIP_NO_ERROR;
ConcreteCommandPath path = { kTestEndpointId, kTestClusterId, kTestCommandId };
ConcreteCommandPath path = { kTestEndpointId, kTestClusterId, kTestCommandIdNoData };

app::CommandHandler commandHandler(&mockCommandHandlerDelegate);
System::PacketBufferHandle commandDatabuf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize);
Expand All @@ -470,7 +508,7 @@ void TestCommandInteraction::TestCommandSenderWithProcessReceivedMsg(nlTestSuite

System::PacketBufferHandle buf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize);

GenerateInvokeResponse(apSuite, apContext, buf, true /*aNeedCommandData*/);
GenerateInvokeResponse(apSuite, apContext, buf, kTestCommandIdWithData);
err = commandSender.ProcessInvokeResponse(std::move(buf));
NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
}
Expand Down Expand Up @@ -636,7 +674,7 @@ void TestCommandInteraction::TestCommandHandlerWithProcessReceivedMsg(nlTestSuit
TestExchangeDelegate delegate;
commandHandler.mpExchangeCtx = ctx.NewExchangeToAlice(&delegate);

GenerateInvokeRequest(apSuite, apContext, commandDatabuf, true /*aNeedCommandData*/, /* aIsTimedRequest = */ false);
GenerateInvokeRequest(apSuite, apContext, commandDatabuf, /* aIsTimedRequest = */ false, kTestCommandIdWithData);
err = commandHandler.ProcessInvokeRequest(std::move(commandDatabuf), false);

ChipLogDetail(DataManagement, "###################################### %s", err.AsString());
Expand All @@ -650,8 +688,8 @@ void TestCommandInteraction::TestCommandHandlerWithProcessReceivedNotExistComman
System::PacketBufferHandle commandDatabuf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize);

// Use some invalid endpoint / cluster / command.
GenerateInvokeRequest(apSuite, apContext, commandDatabuf, false /*aNeedCommandData*/, /* aIsTimedRequest = */ false,
0xDE /* endpoint */, 0xADBE /* cluster */, 0xEF /* command */);
GenerateInvokeRequest(apSuite, apContext, commandDatabuf, /* aIsTimedRequest = */ false, 0xEF /* command */,
0xADBE /* cluster */, 0xDE /* endpoint */);

// TODO: Need to find a way to get the response instead of only check if a function on key path is called.
// We should not reach CommandDispatch if requested command does not exist.
Expand All @@ -676,7 +714,7 @@ void TestCommandInteraction::TestCommandHandlerWithProcessReceivedEmptyDataMsg(n
commandHandler.mpExchangeCtx = ctx.NewExchangeToAlice(&delegate);

chip::isCommandDispatched = false;
GenerateInvokeRequest(apSuite, apContext, commandDatabuf, false /*aNeedCommandData*/, messageIsTimed);
GenerateInvokeRequest(apSuite, apContext, commandDatabuf, messageIsTimed, kTestCommandIdNoData);
err = commandHandler.ProcessInvokeRequest(std::move(commandDatabuf), transactionIsTimed);
NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR);
NL_TEST_ASSERT(apSuite, chip::isCommandDispatched == (messageIsTimed == transactionIsTimed));
Expand Down Expand Up @@ -934,6 +972,7 @@ nlTestSuite sSuite =

int TestCommandInteraction()
{
chip::gSuite = &sSuite;
return chip::ExecuteTestsWithContext<TestContext>(&sSuite);
}

Expand Down
4 changes: 2 additions & 2 deletions src/lib/core/CHIPTLV.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@
namespace chip {
namespace TLV {

inline uint8_t operator|(TLVElementType lhs, TLVTagControl rhs)
constexpr inline uint8_t operator|(TLVElementType lhs, TLVTagControl rhs)
{
return static_cast<uint8_t>(lhs) | static_cast<uint8_t>(rhs);
}

inline uint8_t operator|(TLVTagControl lhs, TLVElementType rhs)
constexpr inline uint8_t operator|(TLVTagControl lhs, TLVElementType rhs)
{
return static_cast<uint8_t>(lhs) | static_cast<uint8_t>(rhs);
}
Expand Down

0 comments on commit a295bec

Please sign in to comment.