Skip to content

Commit 7a3f1e6

Browse files
committed
authmailbox: add clientRegistry for thread-safe client handling
Extract mutex-based client management into a new clientRegistry struct to ensure thread-safe access and modifications to the client subscriptions.
1 parent fd9fa23 commit 7a3f1e6

File tree

1 file changed

+123
-43
lines changed

1 file changed

+123
-43
lines changed

authmailbox/multi_subscription.go

Lines changed: 123 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,95 @@ type clientSubscriptions struct {
2626
cancels map[asset.SerializedKey]context.CancelFunc
2727
}
2828

29+
// clientRegistry is a thread-safe registry for managing mailbox clients.
30+
// It encapsulates the clients map and provides a safe API for accessing
31+
// and modifying client subscriptions.
32+
type clientRegistry struct {
33+
sync.RWMutex
34+
35+
// clients holds the active mailbox clients, keyed by their server URL.
36+
clients map[url.URL]*clientSubscriptions
37+
}
38+
39+
// newClientRegistry creates a new client registry instance.
40+
func newClientRegistry() *clientRegistry {
41+
return &clientRegistry{
42+
clients: make(map[url.URL]*clientSubscriptions),
43+
}
44+
}
45+
46+
// Get retrieves an existing client or creates a new one if it doesn't
47+
// exist. It returns the client and a boolean indicating whether the client
48+
// was newly created.
49+
func (r *clientRegistry) Get(serverURL url.URL,
50+
cfgCopy ClientConfig) (*clientSubscriptions, bool, error) {
51+
52+
r.Lock()
53+
defer r.Unlock()
54+
55+
client, ok := r.clients[serverURL]
56+
if ok {
57+
return client, false, nil
58+
}
59+
60+
// Create a new client connection.
61+
cfgCopy.ServerAddress = serverURL.Host
62+
mboxClient := NewClient(&cfgCopy)
63+
64+
client = &clientSubscriptions{
65+
client: mboxClient,
66+
subscriptions: make(
67+
map[asset.SerializedKey]ReceiveSubscription,
68+
),
69+
cancels: make(
70+
map[asset.SerializedKey]context.CancelFunc,
71+
),
72+
}
73+
r.clients[serverURL] = client
74+
75+
return client, true, nil
76+
}
77+
78+
// RemoveClient removes a client from the registry.
79+
func (r *clientRegistry) RemoveClient(serverURL url.URL) {
80+
r.Lock()
81+
defer r.Unlock()
82+
83+
delete(r.clients, serverURL)
84+
}
85+
86+
// AddSubscription adds a subscription and its cancel function to a client. If
87+
// the client does not exist, an error is returned.
88+
func (r *clientRegistry) AddSubscription(serverURL url.URL,
89+
key asset.SerializedKey, subscription ReceiveSubscription,
90+
cancel context.CancelFunc) error {
91+
92+
r.Lock()
93+
defer r.Unlock()
94+
95+
client, ok := r.clients[serverURL]
96+
if !ok {
97+
return fmt.Errorf("no client found for %s", serverURL.String())
98+
}
99+
100+
client.subscriptions[key] = subscription
101+
client.cancels[key] = cancel
102+
103+
return nil
104+
}
105+
106+
// ForEach executes a function for each client in the registry. The function
107+
// receives a copy of the client subscriptions to avoid holding the lock
108+
// during potentially long operations.
109+
func (r *clientRegistry) ForEach(fn func(*clientSubscriptions)) {
110+
r.RLock()
111+
defer r.RUnlock()
112+
113+
for _, client := range r.clients {
114+
fn(client)
115+
}
116+
}
117+
29118
// MultiSubscription is a subscription manager that can handle multiple mailbox
30119
// clients, allowing subscriptions to different accounts across different
31120
// mailbox servers. It manages subscriptions and message queues for each client
@@ -34,16 +123,14 @@ type MultiSubscription struct {
34123
// cfg holds the configuration for the MultiSubscription instance.
35124
cfg MultiSubscriptionConfig
36125

37-
// clients holds the active mailbox clients, keyed by their server URL.
38-
clients map[url.URL]*clientSubscriptions
126+
// registry manages the active mailbox clients in a thread-safe manner.
127+
registry *clientRegistry
39128

40129
// msgQueue is the concurrent queue that holds received messages from
41130
// all subscriptions across all clients. This allows for a unified
42131
// message channel that can be used to receive messages from any
43132
// subscribed account, regardless of which mailbox server it belongs to.
44133
msgQueue *lfn.ConcurrentQueue[*ReceivedMessages]
45-
46-
sync.RWMutex
47134
}
48135

49136
// MultiSubscriptionConfig holds the configuration parameters for creating a
@@ -65,7 +152,7 @@ func NewMultiSubscription(cfg MultiSubscriptionConfig) *MultiSubscription {
65152

66153
return &MultiSubscription{
67154
cfg: cfg,
68-
clients: make(map[url.URL]*clientSubscriptions),
155+
registry: newClientRegistry(),
69156
msgQueue: queue,
70157
}
71158
}
@@ -77,41 +164,33 @@ func NewMultiSubscription(cfg MultiSubscriptionConfig) *MultiSubscription {
77164
func (m *MultiSubscription) Subscribe(ctx context.Context, serverURL url.URL,
78165
receiverKey keychain.KeyDescriptor, filter MessageFilter) error {
79166

80-
// We hold the mutex for access to common resources.
81-
m.Lock()
167+
// Get or create a client for the given server URL. This call is
168+
// thread-safe and will handle locking internally.
82169
cfgCopy := m.cfg.BaseClientConfig
83-
client, ok := m.clients[serverURL]
170+
client, isNewClient, err := m.registry.Get(serverURL, cfgCopy)
171+
if err != nil {
172+
return err
173+
}
84174

85-
// If this is the first time we're seeing a server URL, we first create
86-
// a network connection to the mailbox server.
87-
if !ok {
88-
cfgCopy.ServerAddress = serverURL.Host
89-
90-
mboxClient := NewClient(&cfgCopy)
91-
client = &clientSubscriptions{
92-
client: mboxClient,
93-
subscriptions: make(
94-
map[asset.SerializedKey]ReceiveSubscription,
95-
),
96-
cancels: make(
97-
map[asset.SerializedKey]context.CancelFunc,
98-
),
99-
}
100-
m.clients[serverURL] = client
175+
// Start the mailbox client if it's not already started. This is safe to
176+
// do without holding any locks since the client itself manages its own
177+
// state.
178+
if isNewClient {
179+
log.Debugf("Starting new mailbox client for %s",
180+
serverURL.String())
101181

102-
err := mboxClient.Start()
182+
err = client.client.Start()
103183
if err != nil {
104-
m.Unlock()
105-
return fmt.Errorf("unable to create mailbox client: %w",
184+
// Remove the client from the map if we failed to start
185+
// it.
186+
m.registry.RemoveClient(serverURL)
187+
return fmt.Errorf("unable to start mailbox client: %w",
106188
err)
107189
}
108190
}
109191

110-
// We release the lock here again, because StartAccountSubscription
111-
// might block for a while, and we don't want to hold the lock
112-
// unnecessarily long.
113-
m.Unlock()
114-
192+
// Start the subscription. We don't hold any locks during this call
193+
// since StartAccountSubscription might block for a while.
115194
ctx, cancel := context.WithCancel(ctx)
116195
subscription, err := client.client.StartAccountSubscription(
117196
ctx, m.msgQueue.ChanIn(), receiverKey, filter,
@@ -122,13 +201,15 @@ func (m *MultiSubscription) Subscribe(ctx context.Context, serverURL url.URL,
122201
err)
123202
}
124203

125-
// We hold the lock again to safely add the subscription and cancel
126-
// function to the client's maps.
127-
m.Lock()
204+
// Add the subscription and cancel function to the client's maps.
205+
// This is thread-safe and handled internally by the registry.
128206
key := asset.ToSerialized(receiverKey.PubKey)
129-
client.subscriptions[key] = subscription
130-
client.cancels[key] = cancel
131-
m.Unlock()
207+
err = m.registry.AddSubscription(serverURL, key, subscription, cancel)
208+
if err != nil {
209+
cancel()
210+
return fmt.Errorf("unable to add subscription to registry: %w",
211+
err)
212+
}
132213

133214
return nil
134215
}
@@ -148,11 +229,10 @@ func (m *MultiSubscription) Stop() error {
148229

149230
log.Info("Stopping all mailbox clients and subscriptions...")
150231

151-
m.RLock()
152-
defer m.RUnlock()
153-
154232
var lastErr error
155-
for _, client := range m.clients {
233+
234+
// Iterate through all clients in a thread-safe manner and stop them.
235+
m.registry.ForEach(func(client *clientSubscriptions) {
156236
for _, cancel := range client.cancels {
157237
cancel()
158238
}
@@ -170,7 +250,7 @@ func (m *MultiSubscription) Stop() error {
170250
log.Errorf("Error stopping client: %v", err)
171251
lastErr = err
172252
}
173-
}
253+
})
174254

175255
return lastErr
176256
}

0 commit comments

Comments
 (0)