Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Also check sockets bind to tcp6 and fail on all closed sockets #824

Merged
merged 4 commits into from
Sep 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/sockstate/netstat_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import (
// elements that satisfy the accept function
func tcpSocks(accept AcceptFn) ([]sockTabEntry, error) {
// (juanjux) TODO: not implemented
logrus.Info("Connection checking not implemented for Darwin")
return []sockTabEntry{}, nil
logrus.Warn("Connection checking not implemented for Darwin")
return nil, ErrSocketCheckNotImplemented.New()
}

func GetConnInode(c *net.TCPConn) (n uint64, err error) {
Expand Down
54 changes: 40 additions & 14 deletions internal/sockstate/netstat_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ import (
)

const (
pathTCPTab = "/proc/net/tcp"
pathTCP4Tab = "/proc/net/tcp"
pathTCP6Tab = "/proc/net/tcp6"
ipv4StrLen = 8
ipv6StrLen = 32
)

type procFd struct {
Expand Down Expand Up @@ -120,6 +122,23 @@ func parseIPv4(s string) (net.IP, error) {
return ip, nil
}

func parseIPv6(s string) (net.IP, error) {
ip := make(net.IP, net.IPv6len)
const grpLen = 4
i, j := 0, 4
for len(s) != 0 {
grp := s[0:8]
u, err := strconv.ParseUint(grp, 16, 32)
binary.LittleEndian.PutUint32(ip[i:j], uint32(u))
if err != nil {
return nil, err
}
i, j = i+grpLen, j+grpLen
s = s[8:]
}
return ip, nil
}

func parseAddr(s string) (*sockAddr, error) {
fields := strings.Split(s, ":")
if len(fields) < 2 {
Expand All @@ -130,6 +149,8 @@ func parseAddr(s string) (*sockAddr, error) {
switch len(fields[0]) {
case ipv4StrLen:
ip, err = parseIPv4(fields[0])
case ipv6StrLen:
ip, err = parseIPv6(fields[0])
default:
log.Fatal("Badly formatted connection address:", s)
}
Expand Down Expand Up @@ -192,21 +213,26 @@ func parseSocktab(r io.Reader, accept AcceptFn) ([]sockTabEntry, error) {
// tcpSocks returns a slice of active TCP sockets containing only those
// elements that satisfy the accept function
func tcpSocks(accept AcceptFn) ([]sockTabEntry, error) {
f, err := os.Open(pathTCPTab)
defer func() {
_ = f.Close()
}()
if err != nil {
return nil, err
}
paths := [2]string{pathTCP4Tab, pathTCP6Tab}
var allTabs []sockTabEntry
for _, p := range paths {
f, err := os.Open(p)
defer func() {
_ = f.Close()
}()
if err != nil {
return nil, err
}

tabs, err := parseSocktab(f, accept)
if err != nil {
return nil, err
}
t, err := parseSocktab(f, accept)
if err != nil {
return nil, err
}
allTabs = append(allTabs, t...)

extractProcInfo(tabs)
return tabs, nil
}
extractProcInfo(allTabs)
return allTabs, nil
}

// GetConnInode returns the Linux inode number of a TCP connection
Expand Down
4 changes: 2 additions & 2 deletions internal/sockstate/netstat_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import (
// elements that satisfy the accept function
func tcpSocks(accept AcceptFn) ([]sockTabEntry, error) {
// (juanjux) TODO: not implemented
logrus.Info("Connection checking not implemented for Windows")
return []sockTabEntry{}, nil
logrus.Warn("Connection checking not implemented for Windows")
return nil, ErrSocketCheckNotImplemented.New()
}

func GetConnInode(c *net.TCPConn) (n uint64, err error) {
Expand Down
21 changes: 16 additions & 5 deletions internal/sockstate/sockstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ import (
type SockState uint8

const (
Finished = iota
Broken
Broken = iota
Other
Error
)
Expand Down Expand Up @@ -37,12 +36,24 @@ func GetInodeSockState(port int, inode uint64) (SockState, error) {

switch len(socks) {
case 0:
return Finished, nil
return Broken, nil
case 1:
if socks[0].State == CloseWait {
switch socks[0].State {
case CloseWait:
fallthrough
case TimeWait:
fallthrough
case FinWait1:
fallthrough
case FinWait2:
fallthrough
case Close:
fallthrough
case Closing:
return Broken, nil
default:
return Other, nil
}
return Other, nil
default: // more than one sock for inode, impossible?
return Error, ErrMultipleSocketsForInode.New()
}
Expand Down
8 changes: 2 additions & 6 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,16 +211,12 @@ func (h *Handler) ComQuery(
for {
select {
case <-quit:
// timeout or other errors detected by the calling routine
return
default:
}

st, err := sockstate.GetInodeSockState(t.Port, inode)
switch st {
case sockstate.Finished:
// Not Linux OSs will also exit here
return
case sockstate.Broken:
errChan <- ErrConnectionWasClosed.New()
return
Expand All @@ -243,6 +239,7 @@ rowLoop:

if r.RowsAffected == rowsBatch {
if err := callback(r); err != nil {
close(quit)
return err
}

Expand Down Expand Up @@ -276,13 +273,12 @@ rowLoop:
}
timer.Reset(waitTime)
}
close(quit)

if err := rows.Close(); err != nil {
return err
}

close(quit)

// Even if r.RowsAffected = 0, the callback must be
// called to update the state in the go-vitess' listener
// and avoid returning errors when the query doesn't
Expand Down
8 changes: 4 additions & 4 deletions server/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ func TestHandlerKill(t *testing.T) {
require.Len(handler.c, 2)
require.Equal(conntainer1, handler.c[1])
require.Equal(conntainer2, handler.c[2])

assertNoConnProcesses(t, e, conn2.ConnectionID)

ctx1 := handler.sm.NewContextWithQuery(conn1, "SELECT 1")
Expand Down Expand Up @@ -256,6 +255,7 @@ func TestHandlerTimeout(t *testing.T) {
})
require.NoError(err)
}

func TestOkClosedConnection(t *testing.T) {
require := require.New(t)
e := setupMemDB(require)
Expand All @@ -282,11 +282,11 @@ func TestOkClosedConnection(t *testing.T) {
0,
)
h.AddNetConnection(&conn)
c2 := newConn(2)
h.NewConnection(c2)
c := newConn(1)
h.NewConnection(c)

q := fmt.Sprintf("SELECT SLEEP(%d)", tcpCheckerSleepTime*4)
err = h.ComQuery(c2, q, func(res *sqltypes.Result) error {
err = h.ComQuery(c, q, func(res *sqltypes.Result) error {
return nil
})
require.NoError(err)
Expand Down