Skip to content

Commit

Permalink
groupinfo: look up ACI from local devices when receiving PNI group me…
Browse files Browse the repository at this point in the history
…mber (#528)
  • Loading branch information
maltee1 authored Aug 9, 2024
1 parent f3a286f commit 77cf7e7
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 12 deletions.
42 changes: 30 additions & 12 deletions pkg/connector/groupinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"time"

"github.com/google/uuid"
"github.com/rs/zerolog"
"go.mau.fi/util/ptr"
"maunium.net/go/mautrix/bridgev2"
Expand Down Expand Up @@ -120,11 +121,12 @@ func (s *SignalClient) getGroupInfo(ctx context.Context, groupID types.GroupIden
}
}
for _, member := range groupInfo.PendingMembers {
if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI {
aci := s.maybeResolvePNItoACI(ctx, &member.ServiceID)
if aci == nil {
continue
}
members.Members = append(members.Members, bridgev2.ChatMember{
EventSender: s.makeEventSender(member.ServiceID.UUID),
EventSender: s.makeEventSender(*aci),
PowerLevel: roleToPL(member.Role),
Membership: event.MembershipInvite,
})
Expand All @@ -136,11 +138,12 @@ func (s *SignalClient) getGroupInfo(ctx context.Context, groupID types.GroupIden
})
}
for _, member := range groupInfo.BannedMembers {
if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI {
aci := s.maybeResolvePNItoACI(ctx, &member.ServiceID)
if aci == nil {
continue
}
members.Members = append(members.Members, bridgev2.ChatMember{
EventSender: s.makeEventSender(member.ServiceID.UUID),
EventSender: s.makeEventSender(*aci),
Membership: event.MembershipBan,
})
}
Expand Down Expand Up @@ -243,21 +246,23 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint
})
}
for _, member := range groupChange.AddPendingMembers {
if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI {
aci := s.maybeResolvePNItoACI(ctx, &member.ServiceID)
if aci == nil {
continue
}
mc = append(mc, bridgev2.ChatMember{
EventSender: s.makeEventSender(member.ServiceID.UUID),
EventSender: s.makeEventSender(*aci),
PowerLevel: roleToPL(member.Role),
Membership: event.MembershipInvite,
})
}
for _, memberServiceID := range groupChange.DeletePendingMembers {
if memberServiceID.Type != libsignalgo.ServiceIDTypeACI {
aci := s.maybeResolvePNItoACI(ctx, memberServiceID)
if aci == nil {
continue
}
mc = append(mc, bridgev2.ChatMember{
EventSender: s.makeEventSender(memberServiceID.UUID),
EventSender: s.makeEventSender(*aci),
Membership: event.MembershipLeave,
PrevMembership: event.MembershipInvite,
})
Expand All @@ -276,20 +281,22 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint
})
}
for _, member := range groupChange.AddBannedMembers {
if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI {
aci := s.maybeResolvePNItoACI(ctx, &member.ServiceID)
if aci == nil {
continue
}
mc = append(mc, bridgev2.ChatMember{
EventSender: s.makeEventSender(member.ServiceID.UUID),
EventSender: s.makeEventSender(*aci),
Membership: event.MembershipBan,
})
}
for _, memberServiceID := range groupChange.DeleteBannedMembers {
if memberServiceID.Type != libsignalgo.ServiceIDTypeACI {
aci := s.maybeResolvePNItoACI(ctx, memberServiceID)
if aci == nil {
continue
}
mc = append(mc, bridgev2.ChatMember{
EventSender: s.makeEventSender(memberServiceID.UUID),
EventSender: s.makeEventSender(*aci),
Membership: event.MembershipLeave,
PrevMembership: event.MembershipBan,
})
Expand All @@ -300,6 +307,17 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint
return ic
}

func (s *SignalClient) maybeResolvePNItoACI(ctx context.Context, serviceID *libsignalgo.ServiceID) *uuid.UUID {
if serviceID.Type == libsignalgo.ServiceIDTypeACI {
return &serviceID.UUID
}
device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, serviceID.UUID)
if err != nil || device == nil {
return nil
}
return &device.ACI
}

func (s *SignalClient) catchUpGroup(ctx context.Context, portal *bridgev2.Portal, fromRevision, toRevision uint32, ts uint64) {
if fromRevision >= toRevision {
return
Expand Down
10 changes: 10 additions & 0 deletions pkg/signalmeow/store/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ var _ DeviceStore = (*Container)(nil)
type DeviceStore interface {
PutDevice(ctx context.Context, dd *DeviceData) error
DeviceByACI(ctx context.Context, aci uuid.UUID) (*Device, error)
DeviceByPNI(ctx context.Context, pni uuid.UUID) (*Device, error)
}

// Container is a wrapper for a SQL database that can contain multiple signalmeow sessions.
Expand All @@ -39,6 +40,7 @@ FROM signalmeow_device
`

const getDeviceQuery = getAllDevicesQuery + " WHERE aci_uuid=$1"
const deviceByPNIQuery = getAllDevicesQuery + "WHERE pni_uuid=$1"

func (c *Container) Upgrade(ctx context.Context) error {
return c.db.Upgrade(ctx)
Expand Down Expand Up @@ -122,6 +124,14 @@ func (c *Container) DeviceByACI(ctx context.Context, aci uuid.UUID) (*Device, er
return sess, err
}

func (c *Container) DeviceByPNI(ctx context.Context, pni uuid.UUID) (*Device, error) {
sess, err := c.scanDevice(c.db.QueryRow(ctx, deviceByPNIQuery, pni))
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return sess, err
}

const (
insertDeviceQuery = `
INSERT INTO signalmeow_device (
Expand Down

0 comments on commit 77cf7e7

Please sign in to comment.