From 1df584a904f6e9a941dbbdacf78255ae9abdb72f Mon Sep 17 00:00:00 2001 From: damoye Date: Sat, 10 Jun 2017 20:34:16 +0800 Subject: [PATCH] fix shakehand --- socks/socks.go | 50 ++++++++++++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/socks/socks.go b/socks/socks.go index 46318325..136ba380 100644 --- a/socks/socks.go +++ b/socks/socks.go @@ -43,9 +43,6 @@ const ( // MaxAddrLen is the maximum size of SOCKS address in bytes. const MaxAddrLen = 1 + 1 + 255 + 2 -// MaxReqLen is the maximum size of SOCKS request in bytes. -const MaxReqLen = 1 + 1 + 1 + MaxAddrLen - // Addr represents a SOCKS address as defined in RFC 1928 section 5. type Addr []byte @@ -68,9 +65,10 @@ func (a Addr) String() string { return net.JoinHostPort(host, port) } -// ReadAddr reads just enough bytes from r to get a valid Addr. -func ReadAddr(r io.Reader) (Addr, error) { - b := make([]byte, MaxAddrLen) +func readAddr(r io.Reader, b []byte) (Addr, error) { + if len(b) < MaxAddrLen { + return nil, io.ErrShortBuffer + } _, err := io.ReadFull(r, b[:1]) // read 1st byte for address type if err != nil { return nil, err @@ -95,6 +93,11 @@ func ReadAddr(r io.Reader) (Addr, error) { return nil, ErrAddressNotSupported } +// ReadAddr reads just enough bytes from r to get a valid Addr. +func ReadAddr(r io.Reader) (Addr, error) { + return readAddr(r, make([]byte, MaxAddrLen)) +} + // SplitAddr slices a SOCKS address from beginning of b. Returns nil if failed. func SplitAddr(b []byte) Addr { addrLen := 1 @@ -163,29 +166,32 @@ func ParseAddr(s string) Addr { // Handshake fast-tracks SOCKS initialization to get target address to connect. func Handshake(rw io.ReadWriter) (Addr, error) { - // Read RFC 1928 section 4 for request and reply structure and sizes - buf := make([]byte, MaxReqLen) - - _, err := rw.Read(buf) // SOCKS version and auth methods - if err != nil { + // Read RFC 1928 for request and reply structure and sizes. + buf := make([]byte, MaxAddrLen) + // read VER, NMETHODS, METHODS + if _, err := io.ReadFull(rw, buf[:2]); err != nil { return nil, err } - - _, err = rw.Write([]byte{5, 0}) // SOCKS v5, no auth required - if err != nil { + nmethods := buf[1] + if _, err := io.ReadFull(rw, buf[:nmethods]); err != nil { return nil, err } - - n, err := rw.Read(buf) // SOCKS request: VER, CMD, RSV, Addr - if err != nil { + // write VER METHOD + if _, err := rw.Write([]byte{5, 0}); err != nil { + return nil, err + } + // read VER CMD RSV ATYP DST.ADDR DST.PORT + if _, err := io.ReadFull(rw, buf[:3]); err != nil { return nil, err } - buf = buf[:n] - if buf[1] != CmdConnect { return nil, ErrCommandNotSupported } - - _, err = rw.Write([]byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}) // SOCKS v5, reply succeeded - return buf[3:], err // skip VER, CMD, RSV fields + addr, err := readAddr(rw, buf) + if err != nil { + return nil, err + } + // write VER REP RSV ATYP BND.ADDR BND.PORT + _, err = rw.Write([]byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}) + return addr, err }