Skip to content

Commit

Permalink
optimize(ttheader): set remote address for client-side after decoding…
Browse files Browse the repository at this point in the history
… TTHeader (cloudwego#465)
  • Loading branch information
YangruiEmma authored May 25, 2022
1 parent e7691de commit 0fe9a04
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 33 deletions.
8 changes: 4 additions & 4 deletions pkg/remote/codec/default_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func TestDefaultCodec_Encode_Decode(t *testing.T) {
ctx := context.Background()
intKVInfo := prepareIntKVInfo()
strKVInfo := prepareStrKVInfo()
sendMsg := initSendMsg(transport.TTHeader)
sendMsg := initClientSendMsg(transport.TTHeader)
sendMsg.TransInfo().PutTransIntInfo(intKVInfo)
sendMsg.TransInfo().PutTransStrInfo(strKVInfo)

Expand All @@ -168,7 +168,7 @@ func TestDefaultCodec_Encode_Decode(t *testing.T) {
test.Assert(t, err == nil, err)

// decode
recvMsg := initRecvMsg()
recvMsg := initServerRecvMsg()
buf, err := out.Bytes()
test.Assert(t, err == nil, err)
in := remote.NewReaderBuffer(buf)
Expand All @@ -190,7 +190,7 @@ func TestDefaultSizedCodec_Encode_Decode(t *testing.T) {
ctx := context.Background()
intKVInfo := prepareIntKVInfo()
strKVInfo := prepareStrKVInfo()
sendMsg := initSendMsg(transport.TTHeader)
sendMsg := initClientSendMsg(transport.TTHeader)
sendMsg.TransInfo().PutTransIntInfo(intKVInfo)
sendMsg.TransInfo().PutTransStrInfo(strKVInfo)

Expand All @@ -203,7 +203,7 @@ func TestDefaultSizedCodec_Encode_Decode(t *testing.T) {
test.Assert(t, err == nil, err)

// decode
recvMsg := initRecvMsg()
recvMsg := initServerRecvMsg()
smallBuf, _ := smallOut.Bytes()
largeBuf, _ := largeOut.Bytes()
err = smallDc.Decode(ctx, recvMsg, remote.NewReaderBuffer(smallBuf))
Expand Down
31 changes: 20 additions & 11 deletions pkg/remote/codec/header_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
"github.com/cloudwego/kitex/pkg/remote/transmeta"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
"github.com/cloudwego/kitex/pkg/utils"
)
Expand Down Expand Up @@ -196,9 +197,8 @@ func (t ttHeader) decode(ctx context.Context, message remote.Message, in remote.
if err := readKVInfo(hdIdx, headerInfo, message); err != nil {
return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("ttHeader read kv info failed, %s", err.Error()))
}
if message.RPCRole() == remote.Server {
fillBasicFromInfoOfTTHeader(message)
}
fillBasicInfoOfTTHeader(message)

message.SetPayloadLen(int(totalLen - uint32(headerInfoSize) + Size32 - TTHeaderMetaSize))
return err
}
Expand Down Expand Up @@ -444,15 +444,24 @@ func (m meshHeader) decode(ctx context.Context, message remote.Message, in remot
// Fill basic from_info(from service, from address) which carried by ttheader to rpcinfo.
// It is better to fill rpcinfo in matahandlers in terms of design,
// but metahandlers are executed after payloadDecode, we don't know from_info when error happen in payloadDecode.
// So 'fillBasicFromInfoOfTTHeader' is just for getting more info to output log when decode error happen.
func fillBasicFromInfoOfTTHeader(svrMsg remote.Message) {
fi := rpcinfo.AsMutableEndpointInfo(svrMsg.RPCInfo().From())
if fi != nil {
if v := svrMsg.TransInfo().TransStrInfo()[transmeta.HeaderTransRemoteAddr]; v != "" {
fi.SetAddress(utils.NewNetAddr("tcp", v))
// So 'fillBasicInfoOfTTHeader' is just for getting more info to output log when decode error happen.
func fillBasicInfoOfTTHeader(msg remote.Message) {
if msg.RPCRole() == remote.Server {
fi := rpcinfo.AsMutableEndpointInfo(msg.RPCInfo().From())
if fi != nil {
if v := msg.TransInfo().TransStrInfo()[transmeta.HeaderTransRemoteAddr]; v != "" {
fi.SetAddress(utils.NewNetAddr("tcp", v))
}
if v := msg.TransInfo().TransIntInfo()[transmeta.FromService]; v != "" {
fi.SetServiceName(v)
}
}
if v := svrMsg.TransInfo().TransIntInfo()[transmeta.FromService]; v != "" {
fi.SetServiceName(v)
} else {
ti := remoteinfo.AsRemoteInfo(msg.RPCInfo().To())
if ti != nil {
if v := msg.TransInfo().TransStrInfo()[transmeta.HeaderTransRemoteAddr]; v != "" {
ti.SetRemoteAddr(utils.NewNetAddr("tcp", v))
}
}
}
}
132 changes: 114 additions & 18 deletions pkg/remote/codec/header_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,25 @@ package codec
import (
"context"
"encoding/binary"
"net"
"testing"

"github.com/cloudwego/kitex/internal/mocks"
"github.com/cloudwego/kitex/internal/test"
"github.com/cloudwego/kitex/pkg/discovery"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
"github.com/cloudwego/kitex/pkg/remote/transmeta"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo"
"github.com/cloudwego/kitex/transport"
)

var mockPayloadLen = 100

func TestTTHeaderCodec(t *testing.T) {
ctx := context.Background()
sendMsg := initSendMsg(transport.TTHeader)
sendMsg := initClientSendMsg(transport.TTHeader)

// encode
out := remote.NewWriterBuffer(256)
Expand All @@ -43,7 +46,7 @@ func TestTTHeaderCodec(t *testing.T) {
test.Assert(t, err == nil, err)

// decode
recvMsg := initRecvMsg()
recvMsg := initServerRecvMsg()
buf, err := out.Bytes()
test.Assert(t, err == nil, err)
in := remote.NewReaderBuffer(buf)
Expand All @@ -56,7 +59,7 @@ func TestTTHeaderCodecWithTransInfo(t *testing.T) {
ctx := context.Background()
intKVInfo := prepareIntKVInfo()
strKVInfo := prepareStrKVInfo()
sendMsg := initSendMsg(transport.TTHeader)
sendMsg := initClientSendMsg(transport.TTHeader)
sendMsg.TransInfo().PutTransIntInfo(intKVInfo)
sendMsg.TransInfo().PutTransStrInfo(strKVInfo)
sendMsg.Tags()[HeaderFlagsKey] = HeaderFlagSupportOutOfOrder
Expand All @@ -68,7 +71,7 @@ func TestTTHeaderCodecWithTransInfo(t *testing.T) {
test.Assert(t, err == nil, err)

// decode
recvMsg := initRecvMsg()
recvMsg := initServerRecvMsg()
buf, err := out.Bytes()
test.Assert(t, err == nil, err)
in := remote.NewReaderBuffer(buf)
Expand All @@ -85,20 +88,64 @@ func TestTTHeaderCodecWithTransInfo(t *testing.T) {
test.Assert(t, flag == uint16(HeaderFlagSupportOutOfOrder))
}

func TestFillBasicInfoOfTTHeader(t *testing.T) {
ctx := context.Background()
mockAddr := "mock address"
// 1. server side fill from address
// encode
sendMsg := initClientSendMsg(transport.TTHeader)
sendMsg.TransInfo().TransStrInfo()[transmeta.HeaderTransRemoteAddr] = mockAddr
sendMsg.TransInfo().TransIntInfo()[transmeta.FromService] = mockServiceName
out := remote.NewWriterBuffer(256)
totalLenField, err := ttHeaderCodec.encode(ctx, sendMsg, out)
binary.BigEndian.PutUint32(totalLenField, uint32(out.MallocLen()-Size32+mockPayloadLen))
test.Assert(t, err == nil, err)
// decode
recvMsg := initServerRecvMsg()
buf, err := out.Bytes()
test.Assert(t, err == nil, err)
in := remote.NewReaderBuffer(buf)
err = ttHeaderCodec.decode(ctx, recvMsg, in)
test.Assert(t, err == nil, err)
test.Assert(t, recvMsg.TransInfo().TransStrInfo()[transmeta.HeaderTransRemoteAddr] == mockAddr)
test.Assert(t, recvMsg.RPCInfo().From().Address().String() == mockAddr)
test.Assert(t, recvMsg.RPCInfo().From().ServiceName() == mockServiceName)

// 2. client side fill to address
// encode
sendMsg = initServerSendMsg(transport.TTHeader)
sendMsg.TransInfo().TransStrInfo()[transmeta.HeaderTransRemoteAddr] = mockAddr
out = remote.NewWriterBuffer(256)
totalLenField, err = ttHeaderCodec.encode(ctx, sendMsg, out)
binary.BigEndian.PutUint32(totalLenField, uint32(out.MallocLen()-Size32+mockPayloadLen))
test.Assert(t, err == nil, err)
// decode
recvMsg = initClientRecvMsg()
toInfo := remoteinfo.AsRemoteInfo(recvMsg.RPCInfo().To())
toInfo.SetInstance(&mockInst{})
buf, err = out.Bytes()
test.Assert(t, err == nil, err)
in = remote.NewReaderBuffer(buf)
err = ttHeaderCodec.decode(ctx, recvMsg, in)
test.Assert(t, err == nil, err)
test.Assert(t, recvMsg.TransInfo().TransStrInfo()[transmeta.HeaderTransRemoteAddr] == mockAddr)
test.Assert(t, toInfo.Address().String() == mockAddr, toInfo.Address())
}

func BenchmarkTTHeaderCodec(b *testing.B) {
ctx := context.Background()

b.ResetTimer()
for i := 0; i < b.N; i++ {
sendMsg := initSendMsg(transport.TTHeader)
sendMsg := initClientSendMsg(transport.TTHeader)
// encode
out := remote.NewWriterBuffer(256)
totalLenField, err := ttHeaderCodec.encode(ctx, sendMsg, out)
binary.BigEndian.PutUint32(totalLenField, uint32(out.MallocLen()-Size32+mockPayloadLen))
test.Assert(b, err == nil, err)

// decode
recvMsg := initRecvMsg()
recvMsg := initServerRecvMsg()
buf, err := out.Bytes()
test.Assert(b, err == nil, err)
in := remote.NewReaderBuffer(buf)
Expand All @@ -116,7 +163,7 @@ func BenchmarkTTHeaderWithTransInfoParallel(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
sendMsg := initSendMsg(transport.TTHeader)
sendMsg := initClientSendMsg(transport.TTHeader)
sendMsg.TransInfo().PutTransIntInfo(intKVInfo)
sendMsg.TransInfo().PutTransStrInfo(strKVInfo)
sendMsg.Tags()[HeaderFlagsKey] = HeaderFlagSupportOutOfOrder
Expand All @@ -128,7 +175,7 @@ func BenchmarkTTHeaderWithTransInfoParallel(b *testing.B) {
test.Assert(b, err == nil, err)

// decode
recvMsg := initRecvMsg()
recvMsg := initServerRecvMsg()
buf, err := out.Bytes()
test.Assert(b, err == nil, err)
in := remote.NewReaderBuffer(buf)
Expand All @@ -153,15 +200,15 @@ func BenchmarkTTHeaderCodecParallel(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
sendMsg := initSendMsg(transport.TTHeader)
sendMsg := initClientSendMsg(transport.TTHeader)
// encode
out := remote.NewWriterBuffer(256)
totalLenField, err := ttHeaderCodec.encode(ctx, sendMsg, out)
binary.BigEndian.PutUint32(totalLenField, uint32(out.MallocLen()-Size32+mockPayloadLen))
test.Assert(b, err == nil, err)

// decode
recvMsg := initRecvMsg()
recvMsg := initServerRecvMsg()
buf, err := out.Bytes()
test.Assert(b, err == nil, err)
in := remote.NewReaderBuffer(buf)
Expand All @@ -172,24 +219,73 @@ func BenchmarkTTHeaderCodecParallel(b *testing.B) {
})
}

func initRecvMsg() remote.Message {
var (
mockServiceName = "mock service"
mockMethod = "mock"

mockCliRPCInfo = rpcinfo.NewRPCInfo(
rpcinfo.EmptyEndpointInfo(),
remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{ServiceName: mockServiceName}, mockMethod).ImmutableView(),
rpcinfo.NewInvocation("", mockMethod),
rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats())

mockSvrRPCInfo = rpcinfo.NewRPCInfo(rpcinfo.EmptyEndpointInfo(),
rpcinfo.FromBasicInfo(&rpcinfo.EndpointBasicInfo{ServiceName: mockServiceName}),
rpcinfo.NewServerInvocation(),
rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats())
)

func initServerRecvMsg() remote.Message {
var req interface{}
ink := rpcinfo.NewInvocation("", "mock")
ri := rpcinfo.NewRPCInfo(nil, nil, ink, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats())
msg := remote.NewMessage(req, mocks.ServiceInfo(), ri, remote.Call, remote.Server)
msg := remote.NewMessage(req, mocks.ServiceInfo(), mockSvrRPCInfo, remote.Call, remote.Server)
return msg
}

func initSendMsg(tp transport.Protocol) remote.Message {
func initClientSendMsg(tp transport.Protocol) remote.Message {
var req interface{}
svcInfo := mocks.ServiceInfo()
ink := rpcinfo.NewInvocation("", "mock")
ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, rpcinfo.NewRPCStats())
msg := remote.NewMessage(req, svcInfo, ri, remote.Call, remote.Client)
msg := remote.NewMessage(req, svcInfo, mockCliRPCInfo, remote.Call, remote.Client)
msg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec))
return msg
}

func initServerSendMsg(tp transport.Protocol) remote.Message {
var resp interface{}
msg := remote.NewMessage(resp, mocks.ServiceInfo(), mockSvrRPCInfo, remote.Reply, remote.Server)
msg.SetProtocolInfo(remote.NewProtocolInfo(tp, mocks.ServiceInfo().PayloadCodec))
return msg
}

func initClientRecvMsg() remote.Message {
var resp interface{}
svcInfo := mocks.ServiceInfo()
msg := remote.NewMessage(resp, svcInfo, mockCliRPCInfo, remote.Reply, remote.Client)
return msg
}

var _ discovery.Instance = &mockInst{}

type mockInst struct {
addr net.Addr
}

func (m *mockInst) Address() net.Addr {
return m.addr
}

func (m *mockInst) SetRemoteAddr(addr net.Addr) (ok bool) {
m.addr = addr
return true
}

func (m *mockInst) Weight() int {
return 10
}

func (m *mockInst) Tag(key string) (value string, exist bool) {
return
}

func prepareIntKVInfo() map[uint16]string {
kvInfo := map[uint16]string{
transmeta.FromService: "mockFromService",
Expand Down

0 comments on commit 0fe9a04

Please sign in to comment.