Skip to content

Commit

Permalink
feat: support client
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaost committed Oct 21, 2024
1 parent 0c152ad commit ade607f
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@
.idea
# vscode
.vscode

# Go workspace file
go.work
go.work.sum
80 changes: 80 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package websocket

import (
"bytes"
"errors"
"fmt"
"time"

"github.com/cloudwego/hertz/pkg/protocol"
)

// ErrBadHandshake is returned when the server response to opening handshake is
// invalid.
var ErrBadHandshake = errors.New("websocket: bad handshake")

type ClientUpgrader struct {
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
// size is zero, then buffers allocated by the HTTP server are used. The
// I/O buffer sizes do not limit the size of the messages that can be sent
// or received.
ReadBufferSize, WriteBufferSize int

// WriteBufferPool is a pool of buffers for write operations. If the value
// is not set, then write buffers are allocated to the connection for the
// lifetime of the connection.
//
// A pool is most useful when the application has a modest volume of writes
// across a large number of connections.
//
// Applications should use a single pool for each unique value of
// WriteBufferSize.
WriteBufferPool BufferPool

// EnableCompression specify if the server should attempt to negotiate per
// message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
}

func (p *ClientUpgrader) PrepareRequest(req *protocol.Request) {
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Sec-WebSocket-Version", "13")
req.Header.Set("Sec-WebSocket-Key", generateChallengeKey())
if p.EnableCompression {
req.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
}
}

func (p *ClientUpgrader) UpgradeResponse(req *protocol.Request, resp *protocol.Response) (*Conn, error) {
if resp.StatusCode() != 101 ||
!tokenContainsValue(resp.Header.Get("Upgrade"), "websocket") ||
!tokenContainsValue(resp.Header.Get("Connection"), "Upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKeyBytes(req.Header.Peek("Sec-Websocket-Key")) {
return nil, ErrBadHandshake
}

c, err := resp.Hijack()

Check failure on line 59 in client.go

View workflow job for this annotation

GitHub Actions / lint-and-ut

resp.Hijack undefined (type *protocol.Response has no field or method Hijack)
if err != nil {
return nil, fmt.Errorf("Hijack response connection err: %w", err)
}

c.SetDeadline(time.Time{})
conn := newConn(c, false, p.ReadBufferSize, p.WriteBufferSize, p.WriteBufferPool, nil, nil)

// can not use p.EnableCompression, always follow ext returned from server
compress := false
extensions := parseDataHeader(resp.Header.Peek("Sec-WebSocket-Extensions"))
for _, ext := range extensions {
if bytes.HasPrefix(ext, strPermessageDeflate) {
compress = true
}
}
if compress {
conn.newCompressionWriter = compressNoContextTakeover
conn.newDecompressionReader = decompressNoContextTakeover
}
return conn, nil
}
98 changes: 98 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package websocket

import (
"context"
"fmt"
"log"
"net"
"runtime"
"time"

"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/client"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/cloudwego/hertz/pkg/network/standard"
"github.com/cloudwego/hertz/pkg/protocol"
)

const (
testaddr = "localhost:10012"
testpath = "/echo"
)

func runServer(addr string) {
upgrader := HertzUpgrader{} // use default options
h := server.Default(server.WithHostPorts(addr))
// https://github.com/cloudwego/hertz/issues/121
h.NoHijackConnPool = true
h.GET(testpath, func(_ context.Context, c *app.RequestContext) {
err := upgrader.Upgrade(c, func(conn *Conn) {
for {
mt, message, err := conn.ReadMessage()
if err != nil {
log.Println("read:", err)
break
}
log.Printf("[server] recv: %v %s", mt, message)
err = conn.WriteMessage(mt, message)
if err != nil {
log.Println("write:", err)
break
}
}
})
if err != nil {
log.Print("upgrade:", err)
return
}
})
go h.Run()
}

func waitListener(addr string) {
time.Sleep(5 * time.Millisecond) // likely it's up
_, file, no, _ := runtime.Caller(1)
for i := 0; i < 50; i++ { // 5s
if ln, err := net.Dial("tcp", addr); err == nil {
ln.Close()
log.Printf("[server] %s is up @ %s:%d", addr, file, no)
return
}
log.Printf("waiting server %s @ %s:%d", addr, file, no)
time.Sleep(100 * time.Millisecond)
}
panic("server " + addr + " not ready")
}

func ExampleClient() {
runServer(testaddr)
waitListener(testaddr)

c, err := client.NewClient(client.WithDialer(standard.NewDialer()))
if err != nil {
panic(err)
}

req, resp := protocol.AcquireRequest(), protocol.AcquireResponse()
req.SetRequestURI("http://" + testaddr + testpath)
req.SetMethod("GET")

u := &ClientUpgrader{}
u.PrepareRequest(req)
err = c.Do(context.Background(), req, resp)
if err != nil {
panic(err)
}
conn, err := u.UpgradeResponse(req, resp)
if err != nil {
panic(err)
}

conn.WriteMessage(TextMessage, []byte("hello"))
m, b, err := conn.ReadMessage()
if err != nil {
panic(err)
}
fmt.Println(m, string(b))
// Output: 1 hello
}
9 changes: 9 additions & 0 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,21 @@ import (
"bytes"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"math/rand"
"unicode/utf8"
"unsafe"
)

var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

func generateChallengeKey() string {
b := make([]byte, 16)
binary.BigEndian.PutUint64(b, rand.Uint64())
binary.BigEndian.PutUint64(b[8:], rand.Uint64())
return base64.StdEncoding.EncodeToString(b)
}

// Token octets per RFC 2616.
var isTokenOctet = [256]bool{
'!': true,
Expand Down

0 comments on commit ade607f

Please sign in to comment.