-
Notifications
You must be signed in to change notification settings - Fork 3.6k
/
Copy pathserver.go
147 lines (126 loc) · 3.77 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
package http
import (
"context"
"net"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/influxdata/influxdb/logger"
"go.uber.org/zap"
)
// DefaultShutdownTimeout is the value NewServer assigns to
// Server.ShutdownTimeout: the time allowed for a graceful shutdown before
// the wait is abandoned.
const DefaultShutdownTimeout time.Duration = 20 * time.Second
// Server is an abstraction around the http.Server that handles a server process.
// It manages the full lifecycle of a server by serving a handler on a socket.
// If signals have been registered, it will attempt to terminate the server using
// Shutdown if a signal is received and will force a shutdown if a second signal
// is received.
//
// A Server contains a sync.WaitGroup and therefore must not be copied after
// first use; always handle it through the *Server returned by NewServer.
type Server struct {
	// ShutdownTimeout bounds how long the graceful shutdown phase may run
	// before its context expires. NewServer sets it to DefaultShutdownTimeout.
	ShutdownTimeout time.Duration

	// srv is the underlying standard-library server that does the serving.
	srv *http.Server
	// signals is the set of OS signals that trigger shutdown. It is filled in
	// by ListenForSignals and only subscribed to once Serve runs.
	signals map[os.Signal]struct{}
	// logger receives lifecycle messages; NewServer substitutes a no-op
	// logger when given nil.
	logger *zap.Logger
	// wg tracks goroutines spawned by Serve/shutdown so Serve can wait for
	// all of them before returning.
	wg sync.WaitGroup
}
// NewServer builds a Server that serves handler once Serve is called.
//
// A nil logger is replaced by a no-op logger, so callers may omit logging.
// The same logger also backs the http.Server's ErrorLog. ShutdownTimeout
// starts out as DefaultShutdownTimeout and may be changed before Serve.
func NewServer(handler http.Handler, logger *zap.Logger) *Server {
	if logger == nil {
		// Discard everything rather than dereferencing a nil logger later.
		logger = zap.NewNop()
	}

	httpServer := &http.Server{
		Handler:  handler,
		ErrorLog: zap.NewStdLog(logger),
	}

	s := &Server{
		srv:    httpServer,
		logger: logger,
	}
	s.ShutdownTimeout = DefaultShutdownTimeout
	return s
}
// Serve accepts connections on listener and blocks until the server stops.
//
// If the underlying http.Server stops on its own, its error is returned.
// If a signal registered through ListenForSignals arrives first, a graceful
// shutdown is started and its result is returned instead. Before returning,
// Serve stops signal delivery and waits for every goroutine it started.
func (s *Server) Serve(listener net.Listener) error {
	// Deferred calls run last-in first-out: signal delivery is stopped
	// first, then we block until all helper goroutines have exited.
	defer s.wg.Wait()

	sigCh, stopSignals := s.notifyOnSignals()
	defer stopSignals()

	serveErr := s.serve(listener)

	// With no registered signals sigCh is nil and that case never fires.
	select {
	case <-sigCh:
		return s.shutdown(sigCh)
	case err := <-serveErr:
		return err
	}
}
// serve runs the underlying http.Server on listener in a background
// goroutine tracked by s.wg. The returned channel yields the server's error
// (if non-nil) and is then closed.
func (s *Server) serve(listener net.Listener) <-chan error {
	result := make(chan error, 1) // buffered so the goroutine never blocks

	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		defer close(result)

		err := s.srv.Serve(listener)
		if err == nil {
			return
		}
		result <- err
	}()

	return result
}
// shutdown gracefully stops the server after a first signal has been
// received on signalCh.
//
// In-flight requests get up to s.ShutdownTimeout to complete. A further
// signal on signalCh cuts that wait short. If the graceful phase ends
// because the timeout expired or was cancelled, any connections still open
// are closed forcibly, so the "hard shutdown" really terminates them.
//
// The returned error is the one reported by http.Server.Shutdown. It is the
// context's error when the graceful phase was cut short.
func (s *Server) shutdown(signalCh <-chan os.Signal) error {
	s.logger.Info("Shutting down server", logger.DurationLiteral("timeout", s.ShutdownTimeout))

	// Bound the graceful phase by the configured timeout.
	ctx, cancel := context.WithTimeout(context.Background(), s.ShutdownTimeout)
	defer cancel()

	// Watch for another signal that cancels the graceful phase. done is
	// closed when shutdown returns so the watcher always exits.
	done := make(chan struct{})
	defer close(done)

	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		select {
		case <-signalCh:
			s.logger.Info("Initializing hard shutdown")
			cancel()
		case <-done:
		}
	}()

	err := s.srv.Shutdown(ctx)
	if err != nil && ctx.Err() != nil {
		// Shutdown returns as soon as ctx is done, but it does not close
		// connections that are still active. Close them now so a timeout or a
		// second signal actually forces the server down. Close's own error is
		// intentionally dropped: err already explains why we got here.
		_ = s.srv.Close()
	}
	return err
}
// ListenForSignals registers the given signals as shutdown triggers for the
// server. The signals are not captured until Serve is called. Registering
// the same signal more than once has no additional effect.
func (s *Server) ListenForSignals(signals ...os.Signal) {
	set := s.signals
	if set == nil {
		set = make(map[os.Signal]struct{})
		s.signals = set
	}

	for i := range signals {
		set[signals[i]] = struct{}{}
	}
}
// notifyOnSignals subscribes to every signal registered via ListenForSignals.
// It returns the delivery channel plus a function that unsubscribes it.
//
// With nothing registered it returns a nil channel, which never becomes
// ready in a select, along with a no-op cancel function.
func (s *Server) notifyOnSignals() (_ <-chan os.Signal, cancel func()) {
	count := len(s.signals)
	if count == 0 {
		return nil, func() {}
	}

	wanted := make([]os.Signal, 0, count)
	for sig := range s.signals {
		wanted = append(wanted, sig)
	}

	// Buffer two deliveries per signal type: one to start the shutdown and
	// one to escalate it, so neither is dropped while we are busy.
	ch := make(chan os.Signal, 2*len(wanted))
	signal.Notify(ch, wanted...)

	stop := func() { signal.Stop(ch) }
	return ch, stop
}
// ListenAndServe opens a TCP listener on addr and serves handler on it until
// the server stops. It registers os.Interrupt and SIGTERM as shutdown
// signals, so the first such signal triggers a graceful shutdown.
func ListenAndServe(addr string, handler http.Handler, logger *zap.Logger) error {
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}

	srv := NewServer(handler, logger)
	srv.ListenForSignals(os.Interrupt, syscall.SIGTERM)
	return srv.Serve(ln)
}