@@ -61,6 +61,9 @@ type Server struct {
61
61
// authMethods maps authentication methods to their respective
62
62
// Authenticator implementations.
63
63
authMethods map [uint8 ]Authenticator
64
+
65
+ // isIPAllowed is a function that determines whether an IP address is allowed to connect.
66
+ isIPAllowed func (net.IP ) bool
64
67
}
65
68
66
69
// New creates a new Server instance and potentially returns an error if the configuration is invalid.
@@ -97,24 +100,49 @@ func New(conf *Config) (*Server, error) {
97
100
conf .Rules = PermitAll ()
98
101
}
99
102
100
- // Ensure we have a log target
103
+ // Ensure a log target is configured. If not, default to logging to standard output.
101
104
if conf .Logger == nil {
102
105
conf .Logger = log .New (os .Stdout , "" , log .LstdFlags )
103
106
}
104
107
108
+ // Initialize the server with the provided configuration.
105
109
server := & Server {
106
110
config : conf ,
107
111
}
108
112
113
+ // Initialize the authentication methods map.
109
114
server .authMethods = make (map [uint8 ]Authenticator )
110
115
116
+ // Populate the authentication methods map with the configured authenticators.
111
117
for _ , a := range conf .AuthMethods {
112
118
server .authMethods [a .GetCode ()] = a
113
119
}
114
120
121
+ // Set a default IP allowlist function that allows all IPs.
122
+ server .isIPAllowed = func (ip net.IP ) bool {
123
+ return true // By default, allow all IPs
124
+ }
125
+
115
126
return server , nil
116
127
}
117
128
129
+ // SetIPAllowlist sets the function to check if a given IP is allowed.
130
+ // It takes a list of allowed IPs and updates the server's IP allowlist function accordingly.
131
+ func (s * Server ) SetIPAllowlist (allowedIPs []net.IP ) {
132
+ // Update the IP allowlist function to check if the given IP is in the list of allowed IPs.
133
+ s .isIPAllowed = func (ip net.IP ) bool {
134
+ // Iterate through the list of allowed IPs.
135
+ for _ , allowedIP := range allowedIPs {
136
+ if ip .Equal (allowedIP ) {
137
+ // Return true if the given IP matches any allowed IP.
138
+ return true
139
+ }
140
+ }
141
+ // Return false if the given IP is not in the list of allowed IPs.
142
+ return false
143
+ }
144
+ }
145
+
118
146
// ListenAndServe creates a listener on the specified network address and starts serving connections.
119
147
// It is a convenience function that calls net.Listen and then Serve.
120
148
//
@@ -149,17 +177,32 @@ func (s *Server) Serve(l net.Listener) error {
149
177
// It reads from the connection, processes the SOCKS5 protocol, and handles the request.
150
178
//
151
179
// ServeConn performs the following steps:
152
- // 1. Reads the version byte from the connection.
153
- // 2. Checks if the version is compatible with SOCKS5.
154
- // 3. Authenticates the connection based on the server's configuration.
155
- // 4. Reads the client's request.
156
- // 5. Processes the client's request and sends the appropriate response.
180
+ // - Check the IP allowlist
181
+ // - Reads the version byte from the connection.
182
+ // - Checks if the version is compatible with SOCKS5.
183
+ // - Authenticates the connection based on the server's configuration.
184
+ // - Reads the client's request.
185
+ // - Processes the client's request and sends the appropriate response.
157
186
//
158
187
// ServeConn returns an error if any step fails.
159
188
func (s * Server ) ServeConn (conn net.Conn ) error {
160
189
defer conn .Close ()
161
190
bufConn := bufio .NewReader (conn )
162
191
192
+ // Check client IP against allowlist
193
+ clientIP , _ , err := net .SplitHostPort (conn .RemoteAddr ().String ())
194
+ if err != nil {
195
+ s .config .Logger .Printf ("[ERR] socks: Failed to get client IP address: %v" , err )
196
+ return err
197
+ }
198
+ ip := net .ParseIP (clientIP )
199
+ if s .isIPAllowed (ip ) {
200
+ s .config .Logger .Printf ("[INFO] socks: Connection from allowed IP address: %s" , clientIP )
201
+ } else {
202
+ s .config .Logger .Printf ("[WARN] socks: Connection from not allowed IP address: %s" , clientIP )
203
+ return fmt .Errorf ("connection from not allowed IP address" )
204
+ }
205
+
163
206
// Read the version byte
164
207
version := []byte {0 }
165
208
if _ , err := bufConn .Read (version ); err != nil {
0 commit comments