Skip to content

Commit bb044e4

Browse files
author
AiQL.com
committed
Support IP Allowlist
1 parent 21ffcec commit bb044e4

File tree

1 file changed

+49
-6
lines changed

1 file changed

+49
-6
lines changed

socks5.go

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ type Server struct {
6161
// authMethods maps authentication methods to their respective
6262
// Authenticator implementations.
6363
authMethods map[uint8]Authenticator
64+
65+
// isIPAllowed is a function that determines whether an IP address is allowed to connect.
66+
isIPAllowed func(net.IP) bool
6467
}
6568

6669
// New creates a new Server instance and potentially returns an error if the configuration is invalid.
@@ -97,24 +100,49 @@ func New(conf *Config) (*Server, error) {
97100
conf.Rules = PermitAll()
98101
}
99102

100-
// Ensure we have a log target
103+
// Ensure a log target is configured. If not, default to logging to standard output.
101104
if conf.Logger == nil {
102105
conf.Logger = log.New(os.Stdout, "", log.LstdFlags)
103106
}
104107

108+
// Initialize the server with the provided configuration.
105109
server := &Server{
106110
config: conf,
107111
}
108112

113+
// Initialize the authentication methods map.
109114
server.authMethods = make(map[uint8]Authenticator)
110115

116+
// Populate the authentication methods map with the configured authenticators.
111117
for _, a := range conf.AuthMethods {
112118
server.authMethods[a.GetCode()] = a
113119
}
114120

121+
// Set a default IP allowlist function that allows all IPs.
122+
server.isIPAllowed = func(ip net.IP) bool {
123+
return true // By default, allow all IPs
124+
}
125+
115126
return server, nil
116127
}
117128

129+
// SetIPAllowlist sets the function to check if a given IP is allowed.
130+
// It takes a list of allowed IPs and updates the server's IP allowlist function accordingly.
131+
func (s *Server) SetIPAllowlist(allowedIPs []net.IP) {
132+
// Update the IP allowlist function to check if the given IP is in the list of allowed IPs.
133+
s.isIPAllowed = func(ip net.IP) bool {
134+
// Iterate through the list of allowed IPs.
135+
for _, allowedIP := range allowedIPs {
136+
if ip.Equal(allowedIP) {
137+
// Return true if the given IP matches any allowed IP.
138+
return true
139+
}
140+
}
141+
// Return false if the given IP is not in the list of allowed IPs.
142+
return false
143+
}
144+
}
145+
118146
// ListenAndServe creates a listener on the specified network address and starts serving connections.
119147
// It is a convenience function that calls net.Listen and then Serve.
120148
//
@@ -149,17 +177,32 @@ func (s *Server) Serve(l net.Listener) error {
149177
// It reads from the connection, processes the SOCKS5 protocol, and handles the request.
150178
//
151179
// ServeConn performs the following steps:
152-
// 1. Reads the version byte from the connection.
153-
// 2. Checks if the version is compatible with SOCKS5.
154-
// 3. Authenticates the connection based on the server's configuration.
155-
// 4. Reads the client's request.
156-
// 5. Processes the client's request and sends the appropriate response.
180+
// - Check the IP allowlist
181+
// - Reads the version byte from the connection.
182+
// - Checks if the version is compatible with SOCKS5.
183+
// - Authenticates the connection based on the server's configuration.
184+
// - Reads the client's request.
185+
// - Processes the client's request and sends the appropriate response.
157186
//
158187
// ServeConn returns an error if any step fails.
159188
func (s *Server) ServeConn(conn net.Conn) error {
160189
defer conn.Close()
161190
bufConn := bufio.NewReader(conn)
162191

192+
// Check client IP against allowlist
193+
clientIP, _, err := net.SplitHostPort(conn.RemoteAddr().String())
194+
if err != nil {
195+
s.config.Logger.Printf("[ERR] socks: Failed to get client IP address: %v", err)
196+
return err
197+
}
198+
ip := net.ParseIP(clientIP)
199+
if s.isIPAllowed(ip) {
200+
s.config.Logger.Printf("[INFO] socks: Connection from allowed IP address: %s", clientIP)
201+
} else {
202+
s.config.Logger.Printf("[WARN] socks: Connection from not allowed IP address: %s", clientIP)
203+
return fmt.Errorf("connection from not allowed IP address")
204+
}
205+
163206
// Read the version byte
164207
version := []byte{0}
165208
if _, err := bufConn.Read(version); err != nil {

0 commit comments

Comments
 (0)