Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit maximum number of connections #271

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type Server struct {
LMTP bool

Domain string
MaxConnections int
MaxRecipients int
MaxMessageBytes int64
MaxLineLength int
Expand Down Expand Up @@ -127,17 +128,7 @@ func (s *Server) Serve(l net.Listener) error {
}

func (s *Server) handleConn(c *Conn) error {
s.locker.Lock()
s.conns[c] = struct{}{}
s.locker.Unlock()

defer func() {
c.Close()

s.locker.Lock()
delete(s.conns, c)
s.locker.Unlock()
}()
defer c.Close()

if tlsConn, ok := c.conn.(*tls.Conn); ok {
if d := s.ReadTimeout; d != 0 {
Expand All @@ -151,6 +142,29 @@ func (s *Server) handleConn(c *Conn) error {
}
}

// register connection
maxConnsExceeded := false
s.locker.Lock()
if s.MaxConnections > 0 && len(s.conns) >= s.MaxConnections {
maxConnsExceeded = true
} else {
s.conns[c] = struct{}{}
}
s.locker.Unlock()

// limit connections
if maxConnsExceeded {
c.writeResponse(421, EnhancedCode{4, 4, 5}, "Too busy. Try again later.")
return nil
}

// unregister connection
defer func() {
s.locker.Lock()
delete(s.conns, c)
s.locker.Unlock()
}()

c.greet()

for {
Expand Down
46 changes: 46 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1514,3 +1514,49 @@ func TestServerDSNwithSMTPUTF8(t *testing.T) {
t.Fatal("Invalid ORCPT address:", val)
}
}

func TestServer_MaxConnections(t *testing.T) {
cases := []struct {
name string
maxConnections int
expected string
}{
// 0 = unlimited; all connections should be accepted
{name: "MaxConnections set to 0", maxConnections: 0, expected: "220 localhost ESMTP Service Ready"},
// 1 = only one connection is allowed; the second connection should be rejected
{name: "MaxConnections set to 1", maxConnections: 1, expected: "421 4.4.5 Too busy. Try again later."},
// 2 = two connections are allowed; the second connection should be accepted
{name: "MaxConnections set to 2", maxConnections: 2, expected: "220 localhost ESMTP Service Ready"},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
// create server with limited allowed connections
_, s, c, scanner1 := testServer(t, func(s *smtp.Server) {
s.MaxConnections = tc.maxConnections
})
defer s.Close()

// there is already be one connection registered
// and we can read the greeting from it (see testServerGreeted())
scanner1.Scan()
if scanner1.Text() != "220 localhost ESMTP Service Ready" {
t.Fatal("Invalid first greeting:", scanner1.Text())
}

// now we create a second connection
c2, err := net.Dial("tcp", c.RemoteAddr().String())
if err != nil {
t.Fatal("Error creating second connection:", err)
}

// we should get an appropriate greeting now
scanner2 := bufio.NewScanner(c2)
scanner2.Scan()
if scanner2.Text() != tc.expected {
t.Fatal("Invalid second greeting:", scanner2.Text())
}
})
}

}