diff --git a/src/app/WriteClient.cpp b/src/app/WriteClient.cpp index 211d8db96214c0..26699e402b76a3 100644 --- a/src/app/WriteClient.cpp +++ b/src/app/WriteClient.cpp @@ -244,12 +244,21 @@ CHIP_ERROR WriteClient::SendWriteRequest(SessionHandle session, System::Clock::T // Create a new exchange context. mpExchangeCtx = mpExchangeMgr->NewContext(session, this); VerifyOrExit(mpExchangeCtx != nullptr, err = CHIP_ERROR_NO_MEMORY); - mpExchangeCtx->SetResponseTimeout(timeout); + if (session.IsGroupSession()) + { + // Exchange will be closed by WriteClientHandle::SendWriteRequest for group messages + err = mpExchangeCtx->SendMessage(Protocols::InteractionModel::MsgType::WriteRequest, std::move(packet), + Messaging::SendFlags(Messaging::SendMessageFlags::kNoAutoRequestAck)); + } + else + { + mpExchangeCtx->SetResponseTimeout(timeout); - err = mpExchangeCtx->SendMessage(Protocols::InteractionModel::MsgType::WriteRequest, std::move(packet), - Messaging::SendFlags(Messaging::SendMessageFlags::kExpectResponse)); - SuccessOrExit(err); - MoveToState(State::AwaitingResponse); + err = mpExchangeCtx->SendMessage(Protocols::InteractionModel::MsgType::WriteRequest, std::move(packet), + Messaging::SendFlags(Messaging::SendMessageFlags::kExpectResponse)); + SuccessOrExit(err); + MoveToState(State::AwaitingResponse); + } exit: if (err != CHIP_NO_ERROR) @@ -350,10 +359,11 @@ CHIP_ERROR WriteClientHandle::SendWriteRequest(SessionHandle session, System::Cl { CHIP_ERROR err = mpWriteClient->SendWriteRequest(session, timeout); - if (err == CHIP_NO_ERROR) + // Transferring ownership of the underlying WriteClient to the IM layer. IM will manage its lifetime. + // For groupcast writes, there is no transfer of ownership since the interaction is done upon transmission of the action + if (err == CHIP_NO_ERROR && !session.IsGroupSession()) { - // On success, the InteractionModelEngine will be responible to take care of the lifecycle of the WriteClient, so we release - // the WriteClient without closing it. + // Release the WriteClient without closing it. mpWriteClient = nullptr; } else diff --git a/src/app/tests/TestWriteInteraction.cpp b/src/app/tests/TestWriteInteraction.cpp index 1a5067e5cc0976..fd83dd21491365 100644 --- a/src/app/tests/TestWriteInteraction.cpp +++ b/src/app/tests/TestWriteInteraction.cpp @@ -44,6 +44,7 @@ class TestWriteInteraction { public: static void TestWriteClient(nlTestSuite * apSuite, void * apContext); + static void TestWriteClientGroup(nlTestSuite * apSuite, void * apContext); static void TestWriteHandler(nlTestSuite * apSuite, void * apContext); static void TestWriteRoundtrip(nlTestSuite * apSuite, void * apContext); static void TestWriteRoundtripWithClusterObjects(nlTestSuite * apSuite, void * apContext); @@ -233,6 +234,33 @@ void TestWriteInteraction::TestWriteClient(nlTestSuite * apSuite, void * apConte NL_TEST_ASSERT(apSuite, rm->TestGetCountRetransTable() == 0); } +void TestWriteInteraction::TestWriteClientGroup(nlTestSuite * apSuite, void * apContext) +{ + TestContext & ctx = *static_cast(apContext); + + CHIP_ERROR err = CHIP_NO_ERROR; + + app::WriteClient writeClient; + app::WriteClientHandle writeClientHandle; + writeClientHandle.SetWriteClient(&writeClient); + + System::PacketBufferHandle buf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize); + TestWriteClientCallback callback; + err = writeClient.Init(&ctx.GetExchangeManager(), &callback); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + AddAttributeDataIB(apSuite, apContext, writeClientHandle); + + SessionHandle groupSession = ctx.GetSessionBobToFriends(); + NL_TEST_ASSERT(apSuite, groupSession.IsGroupSession()); + + err = writeClientHandle.SendWriteRequest(groupSession); + + // Write will fail until issue #11078 is completed + NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_NOT_CONNECTED); + // The internal WriteClient should be shutdown once we SendWriteRequest for group. + NL_TEST_ASSERT(apSuite, nullptr == writeClientHandle.mpWriteClient); +} + void TestWriteInteraction::TestWriteHandler(nlTestSuite * apSuite, void * apContext) { TestContext & ctx = *static_cast(apContext); @@ -384,6 +412,7 @@ namespace { const nlTest sTests[] = { NL_TEST_DEF("CheckWriteClient", chip::app::TestWriteInteraction::TestWriteClient), + NL_TEST_DEF("CheckWriteClientGroup", chip::app::TestWriteInteraction::TestWriteClientGroup), NL_TEST_DEF("CheckWriteHandler", chip::app::TestWriteInteraction::TestWriteHandler), NL_TEST_DEF("CheckWriteRoundtrip", chip::app::TestWriteInteraction::TestWriteRoundtrip), NL_TEST_DEF("TestWriteRoundtripWithClusterObjects", chip::app::TestWriteInteraction::TestWriteRoundtripWithClusterObjects), diff --git a/src/messaging/tests/MessagingContext.cpp b/src/messaging/tests/MessagingContext.cpp index f680c40586a0dc..fa09f93cefcd4a 100644 --- a/src/messaging/tests/MessagingContext.cpp +++ b/src/messaging/tests/MessagingContext.cpp @@ -65,6 +65,11 @@ SessionHandle MessagingContext::GetSessionAliceToBob() return SessionHandle(GetBobNodeId(), GetAliceKeyId(), GetBobKeyId(), mDestFabricIndex); } +SessionHandle MessagingContext::GetSessionBobToFriends() +{ + return SessionHandle(GetBobKeyId(), GetFriendsGroupId(), GetFabricIndex()); +} + Messaging::ExchangeContext * MessagingContext::NewUnauthenticatedExchangeToAlice(Messaging::ExchangeDelegate * delegate) { return mExchangeManager.NewContext(mSessionManager.CreateUnauthenticatedSession(mAliceAddress).Value(), delegate); diff --git a/src/messaging/tests/MessagingContext.h b/src/messaging/tests/MessagingContext.h index 4c8e9755fb02b1..0e9159648fd135 100644 --- a/src/messaging/tests/MessagingContext.h +++ b/src/messaging/tests/MessagingContext.h @@ -61,6 +61,7 @@ class MessagingContext uint16_t GetBobKeyId() const { return mBobKeyId; } uint16_t GetAliceKeyId() const { return mAliceKeyId; } + GroupId GetFriendsGroupId() const { return mFriendsGroupId; } void SetBobKeyId(uint16_t id) { mBobKeyId = id; } void SetAliceKeyId(uint16_t id) { mAliceKeyId = id; } @@ -78,6 +79,7 @@ class MessagingContext SessionHandle GetSessionBobToAlice(); SessionHandle GetSessionAliceToBob(); + SessionHandle GetSessionBobToFriends(); Messaging::ExchangeContext * NewUnauthenticatedExchangeToAlice(Messaging::ExchangeDelegate * delegate); Messaging::ExchangeContext * NewUnauthenticatedExchangeToBob(Messaging::ExchangeDelegate * delegate); @@ -94,10 +96,11 @@ class MessagingContext secure_channel::MessageCounterManager mMessageCounterManager; IOContext * mIOContext; - NodeId mBobNodeId = 123654; - NodeId mAliceNodeId = 111222333; - uint16_t mBobKeyId = 1; - uint16_t mAliceKeyId = 2; + NodeId mBobNodeId = 123654; + NodeId mAliceNodeId = 111222333; + uint16_t mBobKeyId = 1; + uint16_t mAliceKeyId = 2; + GroupId mFriendsGroupId = 517; Transport::PeerAddress mAliceAddress; Transport::PeerAddress mBobAddress; SecurePairingUsingTestSecret mPairingAliceToBob;