@@ -11,6 +11,8 @@ use std::sync::{Arc, RwLock};
1111
1212use log:: { debug, info, warn} ;
1313use regex:: Regex ;
14+ use std:: net:: IpAddr ;
15+ use std:: str:: FromStr ;
1416
1517type FnMatch = Box < dyn Send + Sync + Fn ( & str , u16 ) -> bool > ;
1618const SORT_MATCH_RULES_COUNT_THRESHOLD : usize = 10 ;
@@ -152,6 +154,40 @@ impl ProxyRuleManager {
152154 return None ;
153155 }
154156
157+ // IPv6 may start with "::", but we will simply ignore it here
158+ if rule. chars ( ) . nth ( 0 ) . unwrap ( ) . is_numeric ( ) {
159+ // Handle CIDR notation first
160+ if let Some ( ( ip_str, prefix_len) ) = rule. split_once ( '/' ) {
161+ if let ( Ok ( ip) , Ok ( prefix_len) ) =
162+ ( IpAddr :: from_str ( ip_str) , prefix_len. parse :: < u8 > ( ) )
163+ {
164+ let cidr = IpCidr :: new ( ip, prefix_len) ;
165+ return Some ( Box :: new ( move |host, _port| {
166+ if host. is_empty ( ) || !host. chars ( ) . nth ( 0 ) . unwrap ( ) . is_numeric ( ) {
167+ return false ;
168+ }
169+ if let Ok ( host_ip) = IpAddr :: from_str ( host) {
170+ return cidr. contains ( & host_ip) ;
171+ }
172+ false
173+ } ) ) ;
174+ }
175+ }
176+
177+ // Handle direct IP addresses
178+ if let Ok ( ip) = IpAddr :: from_str ( rule) {
179+ return Some ( Box :: new ( move |host, _port| {
180+ if host. is_empty ( ) || !host. chars ( ) . nth ( 0 ) . unwrap ( ) . is_numeric ( ) {
181+ return false ;
182+ }
183+ if let Ok ( host_ip) = IpAddr :: from_str ( host) {
184+ return host_ip == ip;
185+ }
186+ false
187+ } ) ) ;
188+ }
189+ }
190+
155191 let rule_len = rule. len ( ) ;
156192 let bytes = rule. as_bytes ( ) ;
157193
@@ -287,3 +323,28 @@ impl Default for ProxyRuleManager {
287323 Self :: new ( )
288324 }
289325}
326+
327+ struct IpCidr {
328+ ip : IpAddr ,
329+ prefix_len : u8 ,
330+ }
331+
332+ impl IpCidr {
333+ fn new ( ip : IpAddr , prefix_len : u8 ) -> Self {
334+ Self { ip, prefix_len }
335+ }
336+
337+ fn contains ( & self , ip : & IpAddr ) -> bool {
338+ match ( self . ip , ip) {
339+ ( IpAddr :: V4 ( network) , IpAddr :: V4 ( ip) ) => {
340+ let mask = !( ( 1u32 << ( 32 - self . prefix_len ) ) - 1 ) ;
341+ ( u32:: from ( network) & mask) == ( u32:: from ( * ip) & mask)
342+ }
343+ ( IpAddr :: V6 ( network) , IpAddr :: V6 ( ip) ) => {
344+ let mask = !( ( 1u128 << ( 128 - self . prefix_len ) ) - 1 ) ;
345+ ( u128:: from ( network) & mask) == ( u128:: from ( * ip) & mask)
346+ }
347+ _ => false ,
348+ }
349+ }
350+ }
0 commit comments