Skip to content

Commit

Permalink
fix(http_listener_v2): fix panic on close (#10132)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmalek authored Dec 10, 2021
1 parent 039c968 commit 1b95720
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 29 deletions.
78 changes: 51 additions & 27 deletions plugins/inputs/http_listener_v2/http_listener_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"compress/gzip"
"crypto/subtle"
"crypto/tls"
"errors"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -50,12 +51,15 @@ type HTTPListenerV2 struct {
BasicUsername string `toml:"basic_username"`
BasicPassword string `toml:"basic_password"`
HTTPHeaderTags map[string]string `toml:"http_header_tags"`

tlsint.ServerConfig
tlsConf *tls.Config

TimeFunc
Log telegraf.Logger

wg sync.WaitGroup
wg sync.WaitGroup
close chan struct{}

listener net.Listener

Expand Down Expand Up @@ -154,44 +158,34 @@ func (h *HTTPListenerV2) Start(acc telegraf.Accumulator) error {

h.acc = acc

tlsConf, err := h.ServerConfig.TLSConfig()
if err != nil {
return err
}

server := &http.Server{
Addr: h.ServiceAddress,
Handler: h,
ReadTimeout: time.Duration(h.ReadTimeout),
WriteTimeout: time.Duration(h.WriteTimeout),
TLSConfig: tlsConf,
}

var listener net.Listener
if tlsConf != nil {
listener, err = tls.Listen("tcp", h.ServiceAddress, tlsConf)
} else {
listener, err = net.Listen("tcp", h.ServiceAddress)
}
if err != nil {
return err
}
h.listener = listener
h.Port = listener.Addr().(*net.TCPAddr).Port
server := h.createHTTPServer()

h.wg.Add(1)
go func() {
defer h.wg.Done()
if err := server.Serve(h.listener); err != nil {
h.Log.Errorf("Serve failed: %v", err)
if !errors.Is(err, net.ErrClosed) {
h.Log.Errorf("Serve failed: %v", err)
}
close(h.close)
}
}()

h.Log.Infof("Listening on %s", listener.Addr().String())
h.Log.Infof("Listening on %s", h.listener.Addr().String())

return nil
}

func (h *HTTPListenerV2) createHTTPServer() *http.Server {
return &http.Server{
Addr: h.ServiceAddress,
Handler: h,
ReadTimeout: time.Duration(h.ReadTimeout),
WriteTimeout: time.Duration(h.WriteTimeout),
TLSConfig: h.tlsConf,
}
}

// Stop cleans up all resources
func (h *HTTPListenerV2) Stop() {
if h.listener != nil {
Expand All @@ -202,6 +196,28 @@ func (h *HTTPListenerV2) Stop() {
h.wg.Wait()
}

func (h *HTTPListenerV2) Init() error {
tlsConf, err := h.ServerConfig.TLSConfig()
if err != nil {
return err
}

var listener net.Listener
if tlsConf != nil {
listener, err = tls.Listen("tcp", h.ServiceAddress, tlsConf)
} else {
listener, err = net.Listen("tcp", h.ServiceAddress)
}
if err != nil {
return err
}
h.tlsConf = tlsConf
h.listener = listener
h.Port = listener.Addr().(*net.TCPAddr).Port

return nil
}

func (h *HTTPListenerV2) ServeHTTP(res http.ResponseWriter, req *http.Request) {
handler := h.serveWrite

Expand All @@ -213,6 +229,13 @@ func (h *HTTPListenerV2) ServeHTTP(res http.ResponseWriter, req *http.Request) {
}

func (h *HTTPListenerV2) serveWrite(res http.ResponseWriter, req *http.Request) {
select {
case <-h.close:
res.WriteHeader(http.StatusGone)
return
default:
}

// Check that the content length is not too large for us to handle.
if req.ContentLength > int64(h.MaxBodySize) {
if err := tooLarge(res); err != nil {
Expand Down Expand Up @@ -393,6 +416,7 @@ func init() {
Paths: []string{"/telegraf"},
Methods: []string{"POST", "PUT"},
DataSource: body,
close: make(chan struct{}),
}
})
}
28 changes: 26 additions & 2 deletions plugins/inputs/http_listener_v2/http_listener_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ func newTestHTTPListenerV2() *HTTPListenerV2 {
TimeFunc: time.Now,
MaxBodySize: config.Size(70000),
DataSource: "body",
close: make(chan struct{}),
}
return listener
}
Expand All @@ -78,6 +79,7 @@ func newTestHTTPSListenerV2() *HTTPListenerV2 {
Parser: parser,
ServerConfig: *pki.TLSServerConfig(),
TimeFunc: time.Now,
close: make(chan struct{}),
}

return listener
Expand Down Expand Up @@ -117,10 +119,10 @@ func TestInvalidListenerConfig(t *testing.T) {
TimeFunc: time.Now,
MaxBodySize: config.Size(70000),
DataSource: "body",
close: make(chan struct{}),
}

acc := &testutil.Accumulator{}
require.Error(t, listener.Start(acc))
require.Error(t, listener.Init())

// Stop is called when any ServiceInput fails to start; it must succeed regardless of state
listener.Stop()
Expand All @@ -131,6 +133,7 @@ func TestWriteHTTPSNoClientAuth(t *testing.T) {
listener.TLSAllowedCACerts = nil

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand All @@ -155,6 +158,7 @@ func TestWriteHTTPSWithClientAuth(t *testing.T) {
listener := newTestHTTPSListenerV2()

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand All @@ -169,6 +173,7 @@ func TestWriteHTTPBasicAuth(t *testing.T) {
listener := newTestHTTPAuthListener()

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand All @@ -187,6 +192,7 @@ func TestWriteHTTP(t *testing.T) {
listener := newTestHTTPListenerV2()

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand Down Expand Up @@ -237,6 +243,7 @@ func TestWriteHTTPWithPathTag(t *testing.T) {
listener.PathTag = true

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand All @@ -260,6 +267,7 @@ func TestWriteHTTPWithMultiplePaths(t *testing.T) {
listener.PathTag = true

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand Down Expand Up @@ -292,6 +300,7 @@ func TestWriteHTTPNoNewline(t *testing.T) {
listener := newTestHTTPListenerV2()

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand Down Expand Up @@ -319,9 +328,11 @@ func TestWriteHTTPExactMaxBodySize(t *testing.T) {
Parser: parser,
MaxBodySize: config.Size(len(hugeMetric)),
TimeFunc: time.Now,
close: make(chan struct{}),
}

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand All @@ -342,9 +353,11 @@ func TestWriteHTTPVerySmallMaxBody(t *testing.T) {
Parser: parser,
MaxBodySize: config.Size(4096),
TimeFunc: time.Now,
close: make(chan struct{}),
}

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand All @@ -359,6 +372,8 @@ func TestWriteHTTPGzippedData(t *testing.T) {
listener := newTestHTTPListenerV2()

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand Down Expand Up @@ -391,6 +406,7 @@ func TestWriteHTTPSnappyData(t *testing.T) {
listener := newTestHTTPListenerV2()

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand Down Expand Up @@ -429,6 +445,7 @@ func TestWriteHTTPHighTraffic(t *testing.T) {
listener := newTestHTTPListenerV2()

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand Down Expand Up @@ -464,6 +481,7 @@ func TestReceive404ForInvalidEndpoint(t *testing.T) {
listener := newTestHTTPListenerV2()

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand All @@ -478,6 +496,7 @@ func TestWriteHTTPInvalid(t *testing.T) {
listener := newTestHTTPListenerV2()

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand All @@ -492,6 +511,7 @@ func TestWriteHTTPEmpty(t *testing.T) {
listener := newTestHTTPListenerV2()

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand All @@ -507,6 +527,7 @@ func TestWriteHTTPTransformHeaderValuesToTagsSingleWrite(t *testing.T) {
listener.HTTPHeaderTags = map[string]string{"Present_http_header_1": "presentMeasurementKey1", "present_http_header_2": "presentMeasurementKey2", "NOT_PRESENT_HEADER": "notPresentMeasurementKey"}

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand Down Expand Up @@ -545,6 +566,7 @@ func TestWriteHTTPTransformHeaderValuesToTagsBulkWrite(t *testing.T) {
listener.HTTPHeaderTags = map[string]string{"Present_http_header_1": "presentMeasurementKey1", "Present_http_header_2": "presentMeasurementKey2", "NOT_PRESENT_HEADER": "notPresentMeasurementKey"}

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand Down Expand Up @@ -576,6 +598,7 @@ func TestWriteHTTPQueryParams(t *testing.T) {
listener.Parser = parser

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand All @@ -597,6 +620,7 @@ func TestWriteHTTPFormData(t *testing.T) {
listener.Parser = parser

acc := &testutil.Accumulator{}
require.NoError(t, listener.Init())
require.NoError(t, listener.Start(acc))
defer listener.Stop()

Expand Down

0 comments on commit 1b95720

Please sign in to comment.