1
1
#!/usr/bin/env python3
2
2
3
- import os , sys , socket , logging , inspect , random , signal , asyncio , contextlib as cl
3
+ import os , sys , socket , logging , inspect , random , signal , asyncio
4
4
5
5
6
6
err_fmt = lambda err : f'[{ err .__class__ .__name__ } ] { err } '
@@ -9,8 +9,8 @@ class HTTPSFragProxy:
9
9
10
10
buff_sz = 1500
11
11
12
- def __init__ (self , sni_list = None , idle_timeout = None , ** bind ):
13
- self .idle_timeout , self .bind = idle_timeout , bind
12
+ def __init__ (self , sni_list = None , idle_timeout = None , block_http = False , ** bind ):
13
+ self .idle_timeout , self .block_http , self . bind = idle_timeout , block_http , bind
14
14
if sni_list is not None :
15
15
if not sni_list : sni_list = [b'\0 ' ] # match nothing
16
16
else : sni_list = list (s .strip ().encode () for s in sni_list )
@@ -33,47 +33,48 @@ class HTTPSFragProxy:
33
33
return os .kill (os .getpid (), signal .SIGTERM )
34
34
self .activity .clear ()
35
35
36
- async def conn_wrap (self , func , reader , writer , close_writer = False ):
36
+ async def conn_wrap (self , func , cid , reader , writer , close_writer = True ):
37
37
try : await func (reader , writer )
38
- except Exception as err :
38
+ except Exception as err : # handles client closing connection too
39
39
writer .close ()
40
- self .log .exception ( 'Failed handing connection : %s' , err_fmt (err ))
40
+ self .log .info ( 'Connection error [%s] : %s', ':' . join ( cid ) , err_fmt (err ))
41
41
finally :
42
42
if close_writer : writer .close ()
43
43
44
44
def conn_handle (self , reader , writer ):
45
- return self .conn_wrap (self ._conn_handle , reader , writer )
46
- async def _conn_handle (self , reader , writer ):
45
+ cid = list (map (str , writer .get_extra_info ('peername' )))
46
+ return self .conn_wrap (lambda * a : self ._conn_handle (cid , * a ), cid , reader , writer )
47
+
48
+ async def _conn_handle (self , cid , reader , writer ):
47
49
self .activity .set ()
48
- cid = ':' .join (map (str , writer .get_extra_info ('peername' )))
49
50
http_data = await reader .read (self .buff_sz )
50
51
if not http_data : writer .close (); return
51
- method , url = ( headers : = http_data .split (b'\r \n ' ) )[0 ].decode ().split ()[:2 ]
52
+ method , url = dst = http_data .split (b'\r \n ' )[0 ].decode ().split ()[:2 ]
52
53
self .log .debug ( 'Connection [%s] ::'
53
- ' %s %s [buff=%d]' , cid , method , url , len (http_data ) )
54
+ ' %s %s [buff=%d]' , ':' .join (cid ), method , url , len (http_data ) )
55
+ cid .extend (dst )
54
56
55
- if https := (method == 'CONNECT' ): # https
56
- host , _ , port = url .partition (':' )
57
- port = int (port ) if port else 443
58
- else : # paintext http
59
- host_header = next ((h for h in headers if h .startswith (b'Host: ' )), None )
60
- if not host_header : raise ValueError ('Missing Host header' )
61
- host , _ , port = host_header [6 :].partition (b':' )
62
- host , port = host .decode (), int (port ) if port else 80
57
+ if https := (method == 'CONNECT' ): port_default = 443
58
+ elif not url .startswith ('http://' ) or self .block_http : raise ValueError (url )
59
+ else : url , port_default = url [7 :].split ('/' , 1 )[0 ], 80
60
+ host , _ , port = url .rpartition (':' )
61
+ if host and port and port .isdigit (): port = int (port )
62
+ else : host , port = url , port_default
63
+ if host [0 ] == '[' and host [- 1 ] == ']' : host = host [1 :- 1 ] # raw ipv6
63
64
64
65
try : xreader , xwriter = await asyncio .open_connection (host , port )
65
- except OSError as err : return self .log .info ( 'Connection [%s]'
66
- ' to %s:%s failed (tls=%s): %s' , cid , host , port , int (https ), err_fmt (err ) )
66
+ except OSError as err : return self .log .info ( 'Connection [%s] to %s:%s '
67
+ ' failed (tls=%s): %s' , ':' . join ( cid ) , host , port , int (https ), err_fmt (err ) )
67
68
self .activity .set ()
68
69
if not https : xwriter .write (http_data ); await xwriter .drain ()
69
70
else :
70
71
writer .write (b'HTTP/1.1 200 Connection Established\r \n \r \n ' )
71
72
await writer .drain ()
72
- await self .conn_wrap (self .fragment_data , reader , xwriter )
73
+ await self .conn_wrap (self .fragment_data , cid , reader , xwriter , False )
73
74
74
75
for task in asyncio .as_completed ([
75
- self .conn_wrap (self .pipe , reader , xwriter , True ),
76
- self .conn_wrap (self .pipe , xreader , writer , True ) ]):
76
+ self .conn_wrap (self .pipe , cid , reader , xwriter ),
77
+ self .conn_wrap (self .pipe , cid , xreader , writer ) ]):
77
78
await task
78
79
79
80
async def fragment_data (self , reader , writer ):
@@ -121,22 +122,20 @@ def main(args=None):
121
122
Default is to apply such fragmentation to all processed connections.''' ))
122
123
parser .add_argument ('-t' , '--idle-timeout' , type = float , metavar = 'seconds' , help = dd ('''
123
124
Stop after number of seconds being idle. Useful if started from systemd socket.''' ))
125
+ parser .add_argument ('--block-http' , action = 'store_true' , help = dd ('''
126
+ Reject/close insecure plaintext http connections made through the proxy.''' ))
124
127
parser .add_argument ('--debug' , action = 'store_true' , help = 'Verbose logging to stderr.' )
125
128
opts = parser .parse_args (sys .argv [1 :] if args is None else args )
126
129
127
- @cl .contextmanager
128
- def in_file (path ):
129
- if not path or path == '-' : return (yield sys .stdin )
130
- if path [0 ] == '%' : path = int (path [1 :])
131
- with open (path ) as src : yield src
132
-
133
130
logging .basicConfig ( format = '%(levelname)s :: %(message)s' ,
134
131
level = logging .WARNING if not opts .debug else logging .DEBUG )
135
132
log = logging .getLogger ('nhp.main' )
136
133
137
134
sni_list = None
138
- if opts .frag_domains :
139
- with in_file (opts .frag_domains ) as src : sni_list = src .read ().split ()
135
+ if p := opts .frag_domains :
136
+ if not p or p == '-' : src = sys .stdin
137
+ else : src = open (p if p [0 ] != '%' else int (p [1 :]))
138
+ sni_list = src .read ().split ()
140
139
141
140
sd_pid , sd_fds = (int (os .environ .get (f'LISTEN_{ k } ' , 0 )) for k in ['PID' , 'FDS' ])
142
141
if sd_pid == os .getpid () and sd_fds :
@@ -147,7 +146,8 @@ def main(args=None):
147
146
else :
148
147
host , _ , port = opts .bind .partition (':' ); port = int (port or 8101 )
149
148
proxy = dict (host = host , port = port )
150
- proxy = HTTPSFragProxy (sni_list = sni_list , idle_timeout = opts .idle_timeout , ** proxy )
149
+ proxy = HTTPSFragProxy ( sni_list = sni_list ,
150
+ idle_timeout = opts .idle_timeout , block_http = opts .block_http , ** proxy )
151
151
152
152
log .debug ( 'Starting proxy (%s)...' ,
153
153
f'{ len (sni_list ):,d} SNI domains' if sni_list is not None else 'fragmenting any SNI' )
0 commit comments