@@ -176,76 +176,21 @@ func NewClient(ctx context.Context, cmd *cobra.Command, conf *Config) (*Client,
176
176
}
177
177
}
178
178
179
- pc := newPortConfig (conf .Port )
180
179
var mnts []* socketMount
180
+ pc := newPortConfig (conf .Port )
181
181
for _ , inst := range conf .Instances {
182
- var (
183
- // network is one of "tcp" or "unix"
184
- network string
185
- // address is either a TCP host port, or a Unix socket
186
- address string
187
- )
188
- // IF
189
- // a global Unix socket directory is NOT set AND
190
- // an instance-level Unix socket is NOT set
191
- // (e.g., I didn't set a Unix socket globally or for this instance)
192
- // OR
193
- // an instance-level TCP address or port IS set
194
- // (e.g., I'm overriding any global settings to use TCP for this
195
- // instance)
196
- // use a TCP listener.
197
- // Otherwise, use a Unix socket.
198
- if (conf .UnixSocket == "" && inst .UnixSocket == "" ) ||
199
- (inst .Addr != "" || inst .Port != 0 ) {
200
- network = "tcp"
201
-
202
- a := conf .Addr
203
- if inst .Addr != "" {
204
- a = inst .Addr
205
- }
206
-
207
- var np int
208
- switch {
209
- case inst .Port != 0 :
210
- np = inst .Port
211
- case conf .Port != 0 :
212
- np = pc .nextPort ()
213
- default :
214
- np = pc .nextPort ()
215
- }
216
-
217
- address = net .JoinHostPort (a , fmt .Sprint (np ))
218
- } else {
219
- network = "unix"
220
-
221
- dir := conf .UnixSocket
222
- if dir == "" {
223
- dir = inst .UnixSocket
224
- }
225
- ud , err := UnixSocketDir (dir , inst .Name )
226
- if err != nil {
227
- return nil , err
228
- }
229
- // Create the parent directory that will hold the socket.
230
- if _ , err := os .Stat (ud ); err != nil {
231
- if err = os .Mkdir (ud , 0777 ); err != nil {
232
- return nil , err
233
- }
234
- }
235
- // use the Postgres-specific socket name
236
- address = filepath .Join (ud , ".s.PGSQL.5432" )
237
- }
238
-
239
- m := & socketMount {inst : inst .Name }
240
- addr , err := m .listen (ctx , network , address )
182
+ m , err := newSocketMount (ctx , conf , pc , inst )
241
183
if err != nil {
242
184
for _ , m := range mnts {
243
- m .close ()
185
+ mErr := m .Close ()
186
+ if mErr != nil {
187
+ cmd .PrintErrf ("failed to close mount: %v" , mErr )
188
+ }
244
189
}
245
190
return nil , fmt .Errorf ("[%v] Unable to mount socket: %v" , inst .Name , err )
246
191
}
247
192
248
- cmd .Printf ("[%s] Listening on %s\n " , inst .Name , addr . String ())
193
+ cmd .Printf ("[%s] Listening on %s\n " , inst .Name , m . Addr ())
249
194
mnts = append (mnts , m )
250
195
}
251
196
@@ -277,22 +222,45 @@ func (c *Client) Serve(ctx context.Context) error {
277
222
return <- exitCh
278
223
}
279
224
280
- // Close triggers the proxyClient to shutdown.
281
- func (c * Client ) Close () {
282
- defer c .dialer .Close ()
225
+ // MultiErr is a group of errors wrapped into one.
226
+ type MultiErr []error
227
+
228
+ // Error returns a single string representing one or more errors.
229
+ func (m MultiErr ) Error () string {
230
+ l := len (m )
231
+ if l == 1 {
232
+ return m [0 ].Error ()
233
+ }
234
+ var errs []string
235
+ for _ , e := range m {
236
+ errs = append (errs , e .Error ())
237
+ }
238
+ return strings .Join (errs , ", " )
239
+ }
240
+
241
+ func (c * Client ) Close () error {
242
+ var mErr MultiErr
283
243
for _ , m := range c .mnts {
284
- m .close ()
244
+ err := m .Close ()
245
+ if err != nil {
246
+ mErr = append (mErr , err )
247
+ }
248
+ }
249
+ cErr := c .dialer .Close ()
250
+ if cErr != nil {
251
+ mErr = append (mErr , cErr )
252
+ }
253
+ if len (mErr ) > 0 {
254
+ return mErr
285
255
}
256
+ return nil
286
257
}
287
258
288
259
// serveSocketMount persistently listens to the socketMounts listener and proxies connections to a
289
260
// given AlloyDB instance.
290
261
func (c * Client ) serveSocketMount (ctx context.Context , s * socketMount ) error {
291
- if s .listener == nil {
292
- return fmt .Errorf ("[%s] mount doesn't have a listener set" , s .inst )
293
- }
294
262
for {
295
- cConn , err := s .listener . Accept ()
263
+ cConn , err := s .Accept ()
296
264
if err != nil {
297
265
if nerr , ok := err .(net.Error ); ok && nerr .Temporary () {
298
266
c .cmd .PrintErrf ("[%s] Error accepting connection: %v\n " , s .inst , err )
@@ -327,22 +295,82 @@ type socketMount struct {
327
295
listener net.Listener
328
296
}
329
297
330
- // listen causes a socketMount to create a Listener at the specified network address.
331
- func (s * socketMount ) listen (ctx context.Context , network string , address string ) (net.Addr , error ) {
298
+ func newSocketMount (ctx context.Context , conf * Config , pc * portConfig , inst InstanceConnConfig ) (* socketMount , error ) {
299
+ var (
300
+ // network is one of "tcp" or "unix"
301
+ network string
302
+ // address is either a TCP host port, or a Unix socket
303
+ address string
304
+ )
305
+ // IF
306
+ // a global Unix socket directory is NOT set AND
307
+ // an instance-level Unix socket is NOT set
308
+ // (e.g., I didn't set a Unix socket globally or for this instance)
309
+ // OR
310
+ // an instance-level TCP address or port IS set
311
+ // (e.g., I'm overriding any global settings to use TCP for this
312
+ // instance)
313
+ // use a TCP listener.
314
+ // Otherwise, use a Unix socket.
315
+ if (conf .UnixSocket == "" && inst .UnixSocket == "" ) ||
316
+ (inst .Addr != "" || inst .Port != 0 ) {
317
+ network = "tcp"
318
+
319
+ a := conf .Addr
320
+ if inst .Addr != "" {
321
+ a = inst .Addr
322
+ }
323
+
324
+ var np int
325
+ switch {
326
+ case inst .Port != 0 :
327
+ np = inst .Port
328
+ default :
329
+ np = pc .nextPort ()
330
+ }
331
+
332
+ address = net .JoinHostPort (a , fmt .Sprint (np ))
333
+ } else {
334
+ network = "unix"
335
+
336
+ dir := conf .UnixSocket
337
+ if dir == "" {
338
+ dir = inst .UnixSocket
339
+ }
340
+ ud , err := UnixSocketDir (dir , inst .Name )
341
+ if err != nil {
342
+ return nil , err
343
+ }
344
+ // Create the parent directory that will hold the socket.
345
+ if _ , err := os .Stat (ud ); err != nil {
346
+ if err = os .Mkdir (ud , 0777 ); err != nil {
347
+ return nil , err
348
+ }
349
+ }
350
+ // use the Postgres-specific socket name
351
+ address = filepath .Join (ud , ".s.PGSQL.5432" )
352
+ }
353
+
332
354
lc := net.ListenConfig {KeepAlive : 30 * time .Second }
333
- l , err := lc .Listen (ctx , network , address )
355
+ ln , err := lc .Listen (ctx , network , address )
334
356
if err != nil {
335
357
return nil , err
336
358
}
337
- s .listener = l
338
- return s .listener .Addr (), nil
359
+ m := & socketMount {inst : inst .Name , listener : ln }
360
+ return m , nil
361
+ }
362
+
363
+ func (s * socketMount ) Addr () net.Addr {
364
+ return s .listener .Addr ()
365
+ }
366
+
367
+ func (s * socketMount ) Accept () (net.Conn , error ) {
368
+ return s .listener .Accept ()
339
369
}
340
370
341
371
// close stops the mount from listening for any more connections
342
- func (s * socketMount ) close () error {
343
- err := s .listener .Close ()
344
- s .listener = nil
345
- return err
372
+ func (s * socketMount ) Close () error {
373
+ return s .listener .Close ()
346
374
}
347
375
348
376
// proxyConn sets up a bidirectional copy between two open connections
0 commit comments