Skip to content

Commit 2c9864d

Browse files
authored
feat: add max-sigterm-delay flag (#64)
This is a port of GoogleCloudPlatform/cloud-sql-proxy#1256
1 parent 0fd062c commit 2c9864d

File tree

4 files changed

+140
-39
lines changed

4 files changed

+140
-39
lines changed

cmd/root.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,12 @@ without having to manage any client SSL certificates.`,
139139
cmd.PersistentFlags().Uint64Var(&c.conf.MaxConnections, "max-connections", 0,
140140
`Limits the number of connections by refusing any additional connections.
141141
When this flag is not set, there is no limit.`)
142+
cmd.PersistentFlags().DurationVar(&c.conf.WaitOnClose, "max-sigterm-delay", 0,
143+
`Maximum amount of time to wait after for any open connections
144+
to close after receiving a TERM signal. The proxy will shut
145+
down when the number of open connections reaches 0 or when
146+
the maximum time has passed. Defaults to 0s.`)
147+
142148
cmd.PersistentFlags().StringVar(&c.telemetryProject, "telemetry-project", "",
143149
"Enable Cloud Monitoring and Cloud Trace integration with the provided project ID.")
144150
cmd.PersistentFlags().BoolVar(&c.disableTraces, "disable-traces", false,
@@ -389,7 +395,7 @@ func runSignalWrapper(cmd *Command) error {
389395
cmd.Println("The proxy has started successfully and is ready for new connections!")
390396
defer func() {
391397
if cErr := p.Close(); cErr != nil {
392-
cmd.PrintErrf("error during shutdown: %v\n", cErr)
398+
cmd.PrintErrf("The proxy failed to close cleanly: %v\n", cErr)
393399
}
394400
}()
395401

@@ -400,9 +406,9 @@ func runSignalWrapper(cmd *Command) error {
400406
err := <-shutdownCh
401407
switch {
402408
case errors.Is(err, errSigInt):
403-
cmd.PrintErrln("SIGINT signal received. Shuting down...")
409+
cmd.PrintErrln("SIGINT signal received. Shutting down...")
404410
case errors.Is(err, errSigTerm):
405-
cmd.PrintErrln("SIGTERM signal received. Shuting down...")
411+
cmd.PrintErrln("SIGTERM signal received. Shutting down...")
406412
default:
407413
cmd.PrintErrf("The proxy has encountered a terminal error: %v\n", err)
408414
}

cmd/root_test.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,13 @@ func TestNewCommandArguments(t *testing.T) {
170170
MaxConnections: 1,
171171
}),
172172
},
173+
{
174+
desc: "using wait after signterm flag",
175+
args: []string{"--max-sigterm-delay", "10s", "/projects/proj/locations/region/clusters/clust/instances/inst"},
176+
want: withDefaults(&proxy.Config{
177+
WaitOnClose: 10 * time.Second,
178+
}),
179+
},
173180
}
174181

175182
for _, tc := range tcs {

internal/proxy/proxy.go

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,11 @@ type Config struct {
8585
// connections. A zero-value indicates no limit.
8686
MaxConnections uint64
8787

88+
// WaitOnClose sets the duration to wait for connections to close before
89+
// shutting down. Not setting this field means to close immediately
90+
// regardless of any open connections.
91+
WaitOnClose time.Duration
92+
8893
// Dialer specifies the dialer to use when connecting to AlloyDB
8994
// instances.
9095
Dialer alloydb.Dialer
@@ -172,6 +177,10 @@ type Client struct {
172177

173178
// mnts is a list of all mounted sockets for this client
174179
mnts []*socketMount
180+
181+
// waitOnClose is the maximum duration to wait for open connections to close
182+
// when shutting down.
183+
waitOnClose time.Duration
175184
}
176185

177186
// NewClient completes the initial setup required to get the proxy to a "steady" state.
@@ -210,10 +219,11 @@ func NewClient(ctx context.Context, cmd *cobra.Command, conf *Config) (*Client,
210219
}
211220

212221
c := &Client{
213-
mnts: mnts,
214-
cmd: cmd,
215-
dialer: d,
216-
maxConns: conf.MaxConnections,
222+
mnts: mnts,
223+
cmd: cmd,
224+
dialer: d,
225+
maxConns: conf.MaxConnections,
226+
waitOnClose: conf.WaitOnClose,
217227
}
218228
return c, nil
219229
}
@@ -262,16 +272,40 @@ func (m MultiErr) Error() string {
262272

263273
func (c *Client) Close() error {
264274
var mErr MultiErr
275+
// First, close all open socket listeners to prevent additional connections.
265276
for _, m := range c.mnts {
266277
err := m.Close()
267278
if err != nil {
268279
mErr = append(mErr, err)
269280
}
270281
}
282+
// Next, close the dialer to prevent any additional refreshes.
271283
cErr := c.dialer.Close()
272284
if cErr != nil {
273285
mErr = append(mErr, cErr)
274286
}
287+
if c.waitOnClose == 0 {
288+
if len(mErr) > 0 {
289+
return mErr
290+
}
291+
return nil
292+
}
293+
timeout := time.After(c.waitOnClose)
294+
tick := time.Tick(100 * time.Millisecond)
295+
for {
296+
select {
297+
case <-tick:
298+
if atomic.LoadUint64(&c.connCount) > 0 {
299+
continue
300+
}
301+
case <-timeout:
302+
}
303+
break
304+
}
305+
open := atomic.LoadUint64(&c.connCount)
306+
if open > 0 {
307+
mErr = append(mErr, fmt.Errorf("%d connection(s) still open after waiting %v", open, c.waitOnClose))
308+
}
275309
if len(mErr) > 0 {
276310
return mErr
277311
}

internal/proxy/proxy_test.go

Lines changed: 86 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ type errorDialer struct {
6565
fakeDialer
6666
}
6767

68-
func (errorDialer) Close() error {
68+
func (*errorDialer) Close() error {
6969
return errors.New("errorDialer returns error on Close")
7070
}
7171

@@ -143,15 +143,15 @@ func TestClientInitialization(t *testing.T) {
143143
desc: "with incrementing automatic port selection",
144144
in: &proxy.Config{
145145
Addr: "127.0.0.1",
146-
Port: 5432, // default port
146+
Port: 6000,
147147
Instances: []proxy.InstanceConnConfig{
148148
{Name: inst1},
149149
{Name: inst2},
150150
},
151151
},
152152
wantTCPAddrs: []string{
153-
"127.0.0.1:5432",
154-
"127.0.0.1:5433",
153+
"127.0.0.1:6000",
154+
"127.0.0.1:6001",
155155
},
156156
},
157157
{
@@ -238,25 +238,6 @@ func TestClientInitialization(t *testing.T) {
238238
}
239239
}
240240

241-
func tryTCPDial(t *testing.T, addr string) net.Conn {
242-
attempts := 10
243-
var (
244-
conn net.Conn
245-
err error
246-
)
247-
for i := 0; i < attempts; i++ {
248-
conn, err = net.Dial("tcp", addr)
249-
if err != nil {
250-
time.Sleep(100 * time.Millisecond)
251-
continue
252-
}
253-
return conn
254-
}
255-
256-
t.Fatalf("failed to dial in %v attempts: %v", attempts, err)
257-
return nil
258-
}
259-
260241
func TestClientLimitsMaxConnections(t *testing.T) {
261242
d := &fakeDialer{}
262243
in := &proxy.Config{
@@ -291,17 +272,92 @@ func TestClientLimitsMaxConnections(t *testing.T) {
291272
// wait only a second for the result (since nothing is writing to the
292273
// socket)
293274
conn2.SetReadDeadline(time.Now().Add(time.Second))
294-
_, rErr := conn2.Read(make([]byte, 1))
295-
if rErr != io.EOF {
296-
t.Fatalf("conn.Read should return io.EOF, got = %v", rErr)
275+
276+
wantEOF := func(t *testing.T, c net.Conn) {
277+
var got error
278+
for i := 0; i < 10; i++ {
279+
_, got = c.Read(make([]byte, 1))
280+
if got == io.EOF {
281+
return
282+
}
283+
time.Sleep(100 * time.Millisecond)
284+
}
285+
t.Fatalf("conn.Read should return io.EOF, got = %v", got)
297286
}
298287

288+
wantEOF(t, conn2)
289+
299290
want := 1
300291
if got := d.dialAttempts(); got != want {
301292
t.Fatalf("dial attempts did not match expected, want = %v, got = %v", want, got)
302293
}
303294
}
304295

296+
func tryTCPDial(t *testing.T, addr string) net.Conn {
297+
attempts := 10
298+
var (
299+
conn net.Conn
300+
err error
301+
)
302+
for i := 0; i < attempts; i++ {
303+
conn, err = net.Dial("tcp", addr)
304+
if err != nil {
305+
time.Sleep(100 * time.Millisecond)
306+
continue
307+
}
308+
return conn
309+
}
310+
311+
t.Fatalf("failed to dial in %v attempts: %v", attempts, err)
312+
return nil
313+
}
314+
315+
func TestClientCloseWaitsForActiveConnections(t *testing.T) {
316+
in := &proxy.Config{
317+
Addr: "127.0.0.1",
318+
Port: 5000,
319+
Instances: []proxy.InstanceConnConfig{
320+
{Name: "proj:region:pg"},
321+
},
322+
Dialer: &fakeDialer{},
323+
}
324+
c, err := proxy.NewClient(context.Background(), &cobra.Command{}, in)
325+
if err != nil {
326+
t.Fatalf("proxy.NewClient error: %v", err)
327+
}
328+
go c.Serve(context.Background())
329+
330+
conn := tryTCPDial(t, "127.0.0.1:5000")
331+
_ = conn.Close()
332+
333+
if err := c.Close(); err != nil {
334+
t.Fatalf("c.Close error: %v", err)
335+
}
336+
337+
in.WaitOnClose = time.Second
338+
in.Port = 5001
339+
c, err = proxy.NewClient(context.Background(), &cobra.Command{}, in)
340+
if err != nil {
341+
t.Fatalf("proxy.NewClient error: %v", err)
342+
}
343+
go c.Serve(context.Background())
344+
345+
var open []net.Conn
346+
for i := 0; i < 5; i++ {
347+
conn = tryTCPDial(t, "127.0.0.1:5001")
348+
open = append(open, conn)
349+
}
350+
defer func() {
351+
for _, o := range open {
352+
o.Close()
353+
}
354+
}()
355+
356+
if err := c.Close(); err == nil {
357+
t.Fatal("c.Close should error, got = nil")
358+
}
359+
}
360+
305361
func TestClientClosesCleanly(t *testing.T) {
306362
in := &proxy.Config{
307363
Addr: "127.0.0.1",
@@ -316,12 +372,8 @@ func TestClientClosesCleanly(t *testing.T) {
316372
t.Fatalf("proxy.NewClient error want = nil, got = %v", err)
317373
}
318374
go c.Serve(context.Background())
319-
time.Sleep(time.Second) // allow the socket to start listening
320375

321-
conn, dErr := net.Dial("tcp", "127.0.0.1:5000")
322-
if dErr != nil {
323-
t.Fatalf("net.Dial error = %v", dErr)
324-
}
376+
conn := tryTCPDial(t, "127.0.0.1:5000")
325377
_ = conn.Close()
326378

327379
if err := c.Close(); err != nil {
@@ -343,7 +395,9 @@ func TestClosesWithError(t *testing.T) {
343395
t.Fatalf("proxy.NewClient error want = nil, got = %v", err)
344396
}
345397
go c.Serve(context.Background())
346-
time.Sleep(time.Second) // allow the socket to start listening
398+
399+
conn := tryTCPDial(t, "127.0.0.1:5000")
400+
defer conn.Close()
347401

348402
if err = c.Close(); err == nil {
349403
t.Fatal("c.Close() should error, got nil")

0 commit comments

Comments
 (0)