diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index 849beb5e5cd2..3dc56cc7dd53 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -2,8 +2,12 @@ package websocket import ( "context" + _ "embed" "encoding/base64" + "fmt" "io" + "net/http" + "os" "time" "github.com/gorilla/websocket" @@ -15,6 +19,27 @@ import ( "github.com/xtls/xray-core/transport/internet/tls" ) +//go:embed dialer.html +var webpage []byte +var conns chan *websocket.Conn + +func init() { + if addr := os.Getenv("XRAY_BROWSER_DIALER"); addr != "" { + conns = make(chan *websocket.Conn, 256) + go http.ListenAndServe(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/websocket" { + if conn, err := upgrader.Upgrade(w, r, nil); err == nil { + conns <- conn + } else { + fmt.Println("unexpected error") + } + } else { + w.Write(webpage) + } + })) + } +} + // Dial dials a WebSocket connection to the given destination. func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx)) @@ -66,6 +91,30 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in } uri := protocol + "://" + host + wsSettings.GetNormalizedPath() + if conns != nil { + data := []byte(uri) + if ed != nil { + data = append(data, " "+base64.RawURLEncoding.EncodeToString(ed)...) + } + var conn *websocket.Conn + for { + conn = <-conns + if conn.WriteMessage(websocket.TextMessage, data) != nil { + conn.Close() + } else { + break + } + } + if _, p, err := conn.ReadMessage(); err != nil { + conn.Close() + return nil, err + } else if s := string(p); s != "ok" { + conn.Close() + return nil, newError(s) + } + return newConnection(conn, conn.RemoteAddr(), nil), nil + } + header := wsSettings.GetRequestHeader() if ed != nil { header.Set("Sec-WebSocket-Protocol", base64.StdEncoding.EncodeToString(ed)) diff --git a/transport/internet/websocket/dialer.html b/transport/internet/websocket/dialer.html new file mode 100644 index 000000000000..bbeaec0cff2c --- /dev/null +++ b/transport/internet/websocket/dialer.html @@ -0,0 +1,55 @@ + + + + Browser Dialer + + + + + diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index 0fd84a526211..c1848345364a 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -7,6 +7,7 @@ import ( "encoding/base64" "io" "net/http" + "strings" "sync" "time" @@ -25,6 +26,8 @@ type requestHandler struct { ln *Listener } +var replacer = strings.NewReplacer("+", "-", "/", "_", "=", "") + var upgrader = &websocket.Upgrader{ ReadBufferSize: 4 * 1024, WriteBufferSize: 4 * 1024, @@ -39,7 +42,17 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req writer.WriteHeader(http.StatusNotFound) return } - conn, err := upgrader.Upgrade(writer, request, nil) + + var extraReader io.Reader + var responseHeader = http.Header{} + if str := request.Header.Get("Sec-WebSocket-Protocol"); str != "" { + if ed, err := base64.RawURLEncoding.DecodeString(replacer.Replace(str)); err == nil && len(ed) > 0 { + extraReader = bytes.NewReader(ed) + responseHeader.Set("Sec-WebSocket-Protocol", str) + } + } + + conn, err := upgrader.Upgrade(writer, request, responseHeader) if err != nil { newError("failed to convert to WebSocket connection").Base(err).WriteToLog() return @@ -54,12 +67,6 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req } } - var extraReader io.Reader - if str := request.Header.Get("Sec-WebSocket-Protocol"); str != "" { - if ed, err := base64.StdEncoding.DecodeString(str); err == nil && len(ed) > 0 { - extraReader = bytes.NewReader(ed) - } - } h.ln.addConn(newConnection(conn, remoteAddr, extraReader)) } @@ -128,7 +135,7 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSet ln: l, }, ReadHeaderTimeout: time.Second * 4, - MaxHeaderBytes: 2048, + MaxHeaderBytes: 4096, } go func() {