Skip to content

Commit

Permalink
Move onServerMetadataPush to rsocketClient
Browse files Browse the repository at this point in the history
Summary:
`RocketClient` is a higher level abstract - it should not concern itself with lower-level server metadata.
This is a job for the lower-level `rsocketClient`.

Reviewed By: podtserkovskiy

Differential Revision: D67164810

fbshipit-source-id: 50181a59fbc1fa9412551c143a4bb6dfb7614e78
  • Loading branch information
echistyakov authored and facebook-github-bot committed Dec 13, 2024
1 parent 17b64cc commit 1ecfb2f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 38 deletions.
18 changes: 4 additions & 14 deletions thrift/lib/go/thrift/rocket_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,14 @@ import (
"time"

"github.com/facebook/fbthrift/thrift/lib/go/thrift/types"
"github.com/facebook/fbthrift/thrift/lib/thrift/rpcmetadata"
)

type rocketClient struct {
types.Encoder
types.Decoder

// rsocket client state
client *rsocketClient
client RSocketClient

resultData []byte
resultErr error
Expand All @@ -58,15 +57,15 @@ var _ types.RequestHeaders = (*rocketClient)(nil)
var _ types.ResponseHeaderGetter = (*rocketClient)(nil)

// NewRocketClient creates a new Rocket client given an RSocketClient.
func NewRocketClient(client *rsocketClient, protoID types.ProtocolID, ioTimeout time.Duration, persistentHeaders map[string]string) (types.Protocol, error) {
func NewRocketClient(client RSocketClient, protoID types.ProtocolID, ioTimeout time.Duration, persistentHeaders map[string]string) (types.Protocol, error) {
return newRocketClientFromRsocket(client, protoID, ioTimeout, persistentHeaders)
}

func newRocketClient(conn net.Conn, protoID types.ProtocolID, ioTimeout time.Duration, persistentHeaders map[string]string) (types.Protocol, error) {
return newRocketClientFromRsocket(newRSocketClient(conn), protoID, ioTimeout, persistentHeaders)
}

func newRocketClientFromRsocket(client *rsocketClient, protoID types.ProtocolID, ioTimeout time.Duration, persistentHeaders map[string]string) (types.Protocol, error) {
func newRocketClientFromRsocket(client RSocketClient, protoID types.ProtocolID, ioTimeout time.Duration, persistentHeaders map[string]string) (types.Protocol, error) {
p := &rocketClient{
client: client,
protoID: protoID,
Expand Down Expand Up @@ -111,7 +110,7 @@ func (p *rocketClient) Flush() (err error) {
defer cancel()
}

if err := p.client.SendSetup(ctx, p.onServerMetadataPush); err != nil {
if err := p.client.SendSetup(ctx); err != nil {
return err
}
headers := unionMaps(p.reqHeaders, p.persistentHeaders)
Expand All @@ -137,15 +136,6 @@ func unionMaps(dst, src map[string]string) map[string]string {
return dst
}

func (p *rocketClient) onServerMetadataPush(metadata *rpcmetadata.ServerPushMetadata) {
if metadata.SetupResponse != nil {
setupResponse := metadata.SetupResponse
serverSupportsZstd := (setupResponse.ZstdSupported != nil && *setupResponse.ZstdSupported)
// zstd is only supported if both the client and the server support it.
p.client.useZstd = p.client.useZstd && serverSupportsZstd
}
}

func (p *rocketClient) ReadMessageBegin() (string, types.MessageType, int32, error) {
name := p.messageName
if p.resultErr != nil {
Expand Down
47 changes: 23 additions & 24 deletions thrift/lib/go/thrift/rocket_rsocket_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,14 @@ import (
"time"

"github.com/facebook/fbthrift/thrift/lib/go/thrift/types"
"github.com/facebook/fbthrift/thrift/lib/thrift/rpcmetadata"
rsocket "github.com/rsocket/rsocket-go"
"github.com/rsocket/rsocket-go/core/transport"
"github.com/rsocket/rsocket-go/payload"
)

// RSocketClient is a client that uses a rsocket library.
type RSocketClient interface {
SendSetup(ctx context.Context, onServerMetadataPush OnServerMetadataPush) error
SendSetup(ctx context.Context) error
FireAndForget(messageName string, protoID types.ProtocolID, typeID types.MessageType, headers map[string]string, dataBytes []byte) error
RequestResponse(ctx context.Context, messageName string, protoID types.ProtocolID, typeID types.MessageType, headers map[string]string, dataBytes []byte) (map[string]string, []byte, error)
Close() error
Expand All @@ -43,14 +42,11 @@ type rsocketClient struct {
useZstd bool
}

// OnServerMetadataPush is called when the server sends a metadata push.
type OnServerMetadataPush func(metadata *rpcmetadata.ServerPushMetadata)

func newRSocketClient(conn net.Conn) *rsocketClient {
func newRSocketClient(conn net.Conn) RSocketClient {
return &rsocketClient{conn: conn}
}

func (r *rsocketClient) SendSetup(_ context.Context, onServerMetadataPush OnServerMetadataPush) error {
func (r *rsocketClient) SendSetup(_ context.Context) error {
if r.client != nil {
// already setup
return nil
Expand All @@ -66,7 +62,17 @@ func (r *rsocketClient) SendSetup(_ context.Context, onServerMetadataPush OnServ
MetadataMimeType(RocketMetadataCompactMimeType).
SetupPayload(setupPayload).
OnClose(func(error) {})
clientStarter := clientBuilder.Acceptor(acceptor(onServerMetadataPush))

clientStarter := clientBuilder.Acceptor(
func(_ context.Context, _ rsocket.RSocket) rsocket.RSocket {
return rsocket.NewAbstractSocket(
rsocket.MetadataPush(
r.onServerMetadataPush,
),
)
},
)

client, err := clientStarter.Transport(transporter(r.conn)).Start(context.Background())
if err != nil {
return err
Expand All @@ -75,23 +81,16 @@ func (r *rsocketClient) SendSetup(_ context.Context, onServerMetadataPush OnServ
return nil
}

func acceptor(onServerMetadataPush OnServerMetadataPush) func(_ context.Context, socket rsocket.RSocket) rsocket.RSocket {
return func(_ context.Context, socket rsocket.RSocket) rsocket.RSocket {
return rsocket.NewAbstractSocket(
rsocket.MetadataPush(
metadataPush(onServerMetadataPush),
),
)
func (r *rsocketClient) onServerMetadataPush(pay payload.Payload) {
metadata, err := decodeServerMetadataPush(pay)
if err != nil {
panic(err)
}
}

func metadataPush(onServerMetadataPush OnServerMetadataPush) func(pay payload.Payload) {
return func(pay payload.Payload) {
metadata, err := decodeServerMetadataPush(pay)
if err != nil {
panic(err)
}
onServerMetadataPush(metadata)
if metadata.SetupResponse != nil {
setupResponse := metadata.SetupResponse
serverSupportsZstd := (setupResponse.ZstdSupported != nil && *setupResponse.ZstdSupported)
// zstd is only supported if both the client and the server support it.
r.useZstd = r.useZstd && serverSupportsZstd
}
}

Expand Down

0 comments on commit 1ecfb2f

Please sign in to comment.