Skip to content

Commit

Permalink
feat(MLS): allow creating mls conversations with partial success [WPB…
Browse files Browse the repository at this point in the history
…-3694] (#2623)

When calling `establishMLSGroup` inside `MLSConversationRepository`, an additional parameter called `allowSkippingUsersWithoutKeyPackages` can be passed.
It's `false` by default, and the only place where it is using `true` is when creating a group conversation.

(cherry picked from commit 79a7a57)
  • Loading branch information
vitorhugods committed Mar 13, 2024
1 parent ded4598 commit dea98f1
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -165,24 +165,33 @@ internal class ConversationGroupRepositoryImpl(
}.flatMap {
newGroupConversationSystemMessagesCreator.value.conversationStarted(conversationEntity)
}.flatMap {
when (protocol) {
is Conversation.ProtocolInfo.Proteus -> Either.Right(setOf())
is Conversation.ProtocolInfo.MLSCapable -> mlsConversationRepository.establishMLSGroup(
groupID = protocol.groupId,
members = usersList + selfUserId,
allowSkippingUsersWithoutKeyPackages = true
).map { it.notAddedUsers }
}
}.flatMap { protocolSpecificAdditionFailures ->
newConversationMembersRepository.persistMembersAdditionToTheConversation(
conversationEntity.id, conversationResponse
).flatMap {
if (protocolSpecificAdditionFailures.isEmpty()) {
Either.Right(Unit)
} else {
newGroupConversationSystemMessagesCreator.value.conversationFailedToAddMembers(
conversationEntity.id.toModel(), protocolSpecificAdditionFailures.toList(), FailedToAdd.Type.Unknown
)
}
}.flatMap {
when (lastUsersAttempt) {
is LastUsersAttempt.None -> Either.Right(Unit)
is LastUsersAttempt.Failed ->
newGroupConversationSystemMessagesCreator.value.conversationFailedToAddMembers(
conversationEntity.id.toModel(), lastUsersAttempt.failedUsers, lastUsersAttempt.failType
)
}
}.flatMap {
when (protocol) {
is Conversation.ProtocolInfo.Proteus -> Either.Right(Unit)
is Conversation.ProtocolInfo.MLSCapable -> mlsConversationRepository.establishMLSGroup(
groupID = protocol.groupId,
members = usersList + selfUserId
)
}
}
}.flatMap {
wrapStorageRequest {
Expand Down Expand Up @@ -325,6 +334,7 @@ internal class ConversationGroupRepositoryImpl(
is LastUsersAttempt.Failed -> newGroupConversationSystemMessagesCreator.value.conversationFailedToAddMembers(
conversationId, lastUsersAttempt.failedUsers, lastUsersAttempt.failType
)

is LastUsersAttempt.None -> Either.Right(Unit)
}

Expand Down Expand Up @@ -579,8 +589,10 @@ internal class ConversationGroupRepositoryImpl(
): Either<CoreFailure, ValidToInvalidUsers> = when {
failure is NetworkFailure.FederatedBackendFailure.RetryableFailure ->
Either.Right(extractValidUsersForRetryableFederationError(userIdList, failure))

failure.isMissingLegalHoldConsentError ->
fetchAndExtractValidUsersForRetryableLegalHoldError(userIdList)

else ->
Either.Right(ValidToInvalidUsers(userIdList, emptyList(), FailedToAdd.Type.Unknown))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.flatMapLeft
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.getOrElse
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.functional.onSuccess
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.logic.logStructuredJson
Expand Down Expand Up @@ -194,7 +195,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
"protocolInfo" to conversation.protocol.toLogMap(),
)
)
}
}.map { Unit }
}

type == Conversation.Type.ONE_ON_ONE -> {
Expand All @@ -214,7 +215,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
"protocolInfo" to conversation.protocol.toLogMap(),
)
)
}
}.map { Unit }
}

else -> Either.Right(Unit)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.configuration.server.ServerConfig
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.conversation.mls.MLSAdditionResult
import com.wire.kalium.logic.data.e2ei.CertificateRevocationListRepository
import com.wire.kalium.logic.data.event.Event
import com.wire.kalium.logic.data.event.EventDeliveryInfo
Expand Down Expand Up @@ -118,7 +119,28 @@ data class E2EIdentity(
@Suppress("TooManyFunctions", "LongParameterList")
interface MLSConversationRepository {
suspend fun decryptMessage(message: ByteArray, groupID: GroupID): Either<CoreFailure, List<DecryptedMessageBundle>>
suspend fun establishMLSGroup(groupID: GroupID, members: List<UserId>): Either<CoreFailure, Unit>

/**
* Establishes an MLS (Messaging Layer Security) group with the specified group ID and members.
*
* Allows partial addition of members through the [allowSkippingUsersWithoutKeyPackages] parameter.
* If this parameter is set to true, users without key packages will be ignored and the rest will be added to the group.
*
* @param groupID The ID of the group to be established. Must be of type [GroupID].
* @param members The list of user IDs (of type [UserId]) to be added as members to the group.
* @param allowSkippingUsersWithoutKeyPackages Flag indicating whether to allow a partial member list in case of some users
* not having key packages available. Default value is false. If false, will return [Either.Left] containing
* [CoreFailure.MissingKeyPackages] for the missing users.
* @return An instance of [Either] indicating the result of the operation. It can be either [Either.Right] if the
* group was successfully established, or [Either.Left] if an error occurred. If successful, returns [Unit].
* Possible types of [Either.Left] are defined in the sealed interface [CoreFailure].
*/
suspend fun establishMLSGroup(
groupID: GroupID,
members: List<UserId>,
allowSkippingUsersWithoutKeyPackages: Boolean = false
): Either<CoreFailure, MLSAdditionResult>

suspend fun establishMLSSubConversationGroup(groupID: GroupID, parentId: ConversationId): Either<CoreFailure, Unit>
suspend fun hasEstablishedMLSGroup(groupID: GroupID): Either<CoreFailure, Boolean>
suspend fun addMemberToMLSGroup(groupID: GroupID, userIdList: List<UserId>): Either<CoreFailure, Unit>
Expand Down Expand Up @@ -428,25 +450,30 @@ internal class MLSConversationDataSource(
)

override suspend fun addMemberToMLSGroup(groupID: GroupID, userIdList: List<UserId>): Either<CoreFailure, Unit> =
internalAddMemberToMLSGroup(groupID, userIdList, retryOnStaleMessage = true)
internalAddMemberToMLSGroup(
groupID = groupID,
userIdList = userIdList,
retryOnStaleMessage = true,
allowPartialMemberList = false
).map { Unit }

private suspend fun internalAddMemberToMLSGroup(
groupID: GroupID,
userIdList: List<UserId>,
retryOnStaleMessage: Boolean
): Either<CoreFailure, Unit> = withContext(serialDispatcher) {
retryOnStaleMessage: Boolean,
allowPartialMemberList: Boolean = false,
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
commitPendingProposals(groupID).flatMap {
produceAndSendCommitWithRetry(groupID, retryOnStaleMessage = retryOnStaleMessage) {
produceAndSendCommitWithRetryAndResult(groupID, retryOnStaleMessage = retryOnStaleMessage) {
keyPackageRepository.claimKeyPackages(userIdList).flatMap { result ->
if (result.usersWithoutKeyPackagesAvailable.isNotEmpty()) {
if (result.usersWithoutKeyPackagesAvailable.isNotEmpty() && !allowPartialMemberList) {
Either.Left(CoreFailure.MissingKeyPackages(result.usersWithoutKeyPackagesAvailable))
} else {
Either.Right(result)
}
}.flatMap { result ->
val keyPackages = result.successfullyFetchedKeyPackages
val clientKeyPackageList = keyPackages.map { it.keyPackage.decodeBase64Bytes() }

wrapMLSRequest {
if (userIdList.isEmpty()) {
// We are creating a group with only our self client which technically
Expand All @@ -460,6 +487,12 @@ internal class MLSConversationDataSource(
commitBundle?.crlNewDistributionPoints?.let { revocationList ->
checkRevocationList(revocationList)
}
}.map {
val additionResult = MLSAdditionResult(
result.successfullyFetchedKeyPackages.map { user -> UserId(user.userId, user.domain) }.toSet(),
result.usersWithoutKeyPackagesAvailable.toSet()
)
CommitOperationResult(it, additionResult)
}
}
}
Expand Down Expand Up @@ -517,11 +550,17 @@ internal class MLSConversationDataSource(

override suspend fun establishMLSGroup(
groupID: GroupID,
members: List<UserId>
): Either<CoreFailure, Unit> = withContext(serialDispatcher) {
members: List<UserId>,
allowSkippingUsersWithoutKeyPackages: Boolean,
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
mlsPublicKeysRepository.getKeys().flatMap { publicKeys ->
val keys = publicKeys.map { mlsPublicKeysMapper.toCrypto(it) }
establishMLSGroup(groupID, members, keys)
establishMLSGroup(
groupID = groupID,
members = members,
keys = keys,
allowPartialMemberList = allowSkippingUsersWithoutKeyPackages
)
}
}

Expand All @@ -532,16 +571,22 @@ internal class MLSConversationDataSource(
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
conversationDAO.getMLSGroupIdByConversationId(parentId.toDao())?.let { parentGroupId ->
val externalSenderKey = mlsClient.getExternalSenders(GroupID(parentGroupId).toCrypto())
establishMLSGroup(groupID, emptyList(), listOf(mlsPublicKeysMapper.toCrypto(externalSenderKey)))
establishMLSGroup(
groupID = groupID,
members = emptyList(),
keys = listOf(mlsPublicKeysMapper.toCrypto(externalSenderKey)),
allowPartialMemberList = false
).map { Unit }
} ?: Either.Left(StorageFailure.DataNotFound)
}
}

private suspend fun establishMLSGroup(
groupID: GroupID,
members: List<UserId>,
keys: List<Ed22519Key>
): Either<CoreFailure, Unit> = withContext(serialDispatcher) {
keys: List<Ed22519Key>,
allowPartialMemberList: Boolean = false,
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
mlsClient.createConversation(
Expand All @@ -555,18 +600,23 @@ internal class MLSConversationDataSource(
Either.Left(it)
}
}.flatMap {
internalAddMemberToMLSGroup(groupID, members, retryOnStaleMessage = false).onFailure {
internalAddMemberToMLSGroup(
groupID = groupID,
userIdList = members,
retryOnStaleMessage = false,
allowPartialMemberList = allowPartialMemberList
).onFailure {
wrapMLSRequest {
mlsClient.wipeConversation(groupID.toCrypto())
}
}
}.flatMap {
}.flatMap { additionResult ->
wrapStorageRequest {
conversationDAO.updateConversationGroupState(
ConversationEntity.GroupState.ESTABLISHED,
idMapper.toGroupIDEntity(groupID)
)
}
}.map { additionResult }
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import com.wire.kalium.persistence.dao.member.MemberDAO
* Either all users are added or some of them could fail to be added.
*/
internal interface NewConversationMembersRepository {
// TODO(refactor): Use Set<UserId> instead of List to avoid duplications
suspend fun persistMembersAdditionToTheConversation(
conversationId: ConversationIDEntity,
conversationResponse: ConversationResponse,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,8 @@ data class KeyPackageClaimResult(
val successfullyFetchedKeyPackages: List<KeyPackageDTO>,
val usersWithoutKeyPackagesAvailable: Set<UserId>
)

data class MLSAdditionResult(
val successfullyAddedUsers: Set<UserId>,
val notAddedUsers: Set<UserId>
)
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.conversation.mls.MLSAdditionResult
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.id.SelfTeamIdProvider
Expand Down Expand Up @@ -209,7 +210,7 @@ class ConversationGroupRepositoryTest {

verify(newGroupConversationSystemMessagesCreator)
.suspendFunction(newGroupConversationSystemMessagesCreator::conversationFailedToAddMembers)
.with(anything(), eq(listOf(unreachableUserId)), eq(MessageContent.MemberChange.FailedToAdd.Type.Federation))
.with(anything(), eq(listOf(unreachableUserId)), eq(MessageContent.MemberChange.FailedToAdd.Type.Federation))
.wasInvoked(once)
}
}
Expand Down Expand Up @@ -269,7 +270,7 @@ class ConversationGroupRepositoryTest {
.withCreateNewConversationAPIResponses(arrayOf(NetworkResponse.Success(conversationResponse, emptyMap(), 201)))
.withSelfTeamId(Either.Right(TestUser.SELF.teamId))
.withInsertConversationSuccess()
.withMlsConversationEstablished()
.withMlsConversationEstablished(MLSAdditionResult(setOf(TestUser.USER_ID), emptySet()))
.withConversationDetailsById(TestConversation.GROUP_VIEW_ENTITY(PROTEUS_PROTOCOL_INFO))
.withSuccessfulNewConversationGroupStartedHandled()
.withSuccessfulNewConversationMemberHandled()
Expand All @@ -292,7 +293,7 @@ class ConversationGroupRepositoryTest {

verify(mlsConversationRepository)
.suspendFunction(mlsConversationRepository::establishMLSGroup)
.with(anything(), anything())
.with(anything(), anything(), eq(true))
.wasInvoked(once)

verify(newConversationMembersRepository)
Expand All @@ -302,6 +303,56 @@ class ConversationGroupRepositoryTest {
}
}

@Test
fun givenMLSProtocolIsUsedAndSomeUsersAreNotAddedToMLSGroup_whenCallingCreateGroupConversation_thenMissingMembersArePersisted() =
runTest {
val conversationResponse = CONVERSATION_RESPONSE.copy(protocol = MLS)
val missingMembersFromMLSGroup = setOf(TestUser.OTHER_USER_ID, TestUser.OTHER_USER_ID_2)
val successfullyAddedUsers = setOf(TestUser.USER_ID)
val allWantedMembers = successfullyAddedUsers + missingMembersFromMLSGroup
val (arrangement, conversationGroupRepository) = Arrangement()
.withCreateNewConversationAPIResponses(arrayOf(NetworkResponse.Success(conversationResponse, emptyMap(), 201)))
.withSelfTeamId(Either.Right(TestUser.SELF.teamId))
.withInsertConversationSuccess()
.withMlsConversationEstablished(MLSAdditionResult(setOf(TestUser.USER_ID), notAddedUsers = missingMembersFromMLSGroup))
.withConversationDetailsById(TestConversation.GROUP_VIEW_ENTITY(PROTEUS_PROTOCOL_INFO))
.withSuccessfulNewConversationGroupStartedHandled()
.withSuccessfulNewConversationMemberHandled()
.withSuccessfulNewConversationGroupStartedUnverifiedWarningHandled()
.withInsertFailedToAddSystemMessageSuccess()
.arrange()

val result = conversationGroupRepository.createGroupConversation(
GROUP_NAME,
allWantedMembers.toList(),
ConversationOptions(protocol = ConversationOptions.Protocol.MLS)
)

result.shouldSucceed()

with(arrangement) {
verify(conversationDAO)
.suspendFunction(conversationDAO::insertConversation)
.with(anything())
.wasInvoked(once)

verify(mlsConversationRepository)
.suspendFunction(mlsConversationRepository::establishMLSGroup)
.with(anything(), anything(), eq(true))
.wasInvoked(once)

verify(newConversationMembersRepository)
.suspendFunction(newConversationMembersRepository::persistMembersAdditionToTheConversation)
.with(anything(), anything())
.wasInvoked(once)

verify(arrangement.newGroupConversationSystemMessagesCreator)
.suspendFunction(arrangement.newGroupConversationSystemMessagesCreator::conversationFailedToAddMembers)
.with(anything(), matching { it.containsAll(missingMembersFromMLSGroup) })
.wasInvoked(once)
}
}

@Test
fun givenProteusConversation_whenAddingMembersToConversation_thenShouldSucceed() = runTest {
val (arrangement, conversationGroupRepository) = Arrangement()
Expand Down Expand Up @@ -1489,11 +1540,11 @@ class ConversationGroupRepositoryTest {
selfTeamIdProvider
)

fun withMlsConversationEstablished(): Arrangement {
fun withMlsConversationEstablished(additionResult: MLSAdditionResult): Arrangement {
given(mlsConversationRepository)
.suspendFunction(mlsConversationRepository::establishMLSGroup)
.whenInvokedWith(anything(), anything())
.thenReturn(Either.Right(Unit))
.thenReturn(Either.Right(additionResult))
return this
}

Expand Down
Loading

0 comments on commit dea98f1

Please sign in to comment.