Skip to content

Commit

Permalink
look up aci when receiving pni
Browse files Browse the repository at this point in the history
  • Loading branch information
maltee1 committed Aug 8, 2024
1 parent 24ca68a commit 5ff230d
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 13 deletions.
66 changes: 53 additions & 13 deletions pkg/connector/groupinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"time"

"github.com/google/uuid"
"github.com/rs/zerolog"
"go.mau.fi/util/ptr"
"maunium.net/go/mautrix/bridgev2"
Expand Down Expand Up @@ -120,11 +121,18 @@ func (s *SignalClient) getGroupInfo(ctx context.Context, groupID types.GroupIden
}
}
for _, member := range groupInfo.PendingMembers {
if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI {
continue
var aci uuid.UUID
if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI {
device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID)
if err != nil {
continue
}
aci = device.ACI
} else {
aci = member.ServiceID.UUID
}
members.Members = append(members.Members, bridgev2.ChatMember{
EventSender: s.makeEventSender(member.ServiceID.UUID),
EventSender: s.makeEventSender(aci),
PowerLevel: roleToPL(member.Role),
Membership: event.MembershipInvite,
})
Expand All @@ -136,11 +144,18 @@ func (s *SignalClient) getGroupInfo(ctx context.Context, groupID types.GroupIden
})
}
for _, member := range groupInfo.BannedMembers {
if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI {
continue
var aci uuid.UUID
if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI {
device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID)
if err != nil {
continue
}
aci = device.ACI
} else {
aci = member.ServiceID.UUID
}
members.Members = append(members.Members, bridgev2.ChatMember{
EventSender: s.makeEventSender(member.ServiceID.UUID),
EventSender: s.makeEventSender(aci),
Membership: event.MembershipBan,
})
}
Expand Down Expand Up @@ -246,18 +261,36 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint
if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI {
continue
}
var aci uuid.UUID
if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI {
device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID)
if err != nil {
continue
}
aci = device.ACI
} else {
aci = member.ServiceID.UUID
}

mc = append(mc, bridgev2.ChatMember{
EventSender: s.makeEventSender(member.ServiceID.UUID),
EventSender: s.makeEventSender(aci),
PowerLevel: roleToPL(member.Role),
Membership: event.MembershipInvite,
})
}
for _, memberServiceID := range groupChange.DeletePendingMembers {
if memberServiceID.Type != libsignalgo.ServiceIDTypeACI {
continue
var aci uuid.UUID
if memberServiceID.Type == libsignalgo.ServiceIDTypePNI {
device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, memberServiceID.UUID)
if err != nil {
continue
}
aci = device.ACI
} else {
aci = memberServiceID.UUID
}
mc = append(mc, bridgev2.ChatMember{
EventSender: s.makeEventSender(memberServiceID.UUID),
EventSender: s.makeEventSender(aci),
Membership: event.MembershipLeave,
PrevMembership: event.MembershipInvite,
})
Expand All @@ -276,11 +309,18 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint
})
}
for _, member := range groupChange.AddBannedMembers {
if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI {
continue
var aci uuid.UUID
if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI {
device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID)
if err != nil {
continue
}
aci = device.ACI
} else {
aci = member.ServiceID.UUID
}
mc = append(mc, bridgev2.ChatMember{
EventSender: s.makeEventSender(member.ServiceID.UUID),
EventSender: s.makeEventSender(aci),
Membership: event.MembershipBan,
})
}
Expand Down
10 changes: 10 additions & 0 deletions pkg/signalmeow/store/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ var _ DeviceStore = (*Container)(nil)
type DeviceStore interface {
PutDevice(ctx context.Context, dd *DeviceData) error
DeviceByACI(ctx context.Context, aci uuid.UUID) (*Device, error)
DeviceByPNI(ctx context.Context, pni uuid.UUID) (*Device, error)
}

// Container is a wrapper for a SQL database that can contain multiple signalmeow sessions.
Expand All @@ -39,6 +40,7 @@ FROM signalmeow_device
`

const getDeviceQuery = getAllDevicesQuery + " WHERE aci_uuid=$1"
const deviceByPNIQuery = getAllDevicesQuery + "Where pni_uuid=$1"

func (c *Container) Upgrade(ctx context.Context) error {
return c.db.Upgrade(ctx)
Expand Down Expand Up @@ -122,6 +124,14 @@ func (c *Container) DeviceByACI(ctx context.Context, aci uuid.UUID) (*Device, er
return sess, err
}

func (c *Container) DeviceByPNI(ctx context.Context, pni uuid.UUID) (*Device, error) {
sess, err := c.scanDevice(c.db.QueryRow(ctx, deviceByPNIQuery, pni))
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return sess, err
}

const (
insertDeviceQuery = `
INSERT INTO signalmeow_device (
Expand Down

0 comments on commit 5ff230d

Please sign in to comment.