From ddadefcb62611720f07d95f0b0004f50e761b29d Mon Sep 17 00:00:00 2001 From: SimFG Date: Tue, 24 Sep 2024 10:19:20 +0800 Subject: [PATCH] enhance: get msg type from the msg header to reduce the Unmarshal usage (#36409) /kind improvement Signed-off-by: SimFG --- pkg/mq/common/message.go | 35 ++++++++++++++++++++++++++++++++ pkg/mq/msgstream/mq_msgstream.go | 17 ++++++---------- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/pkg/mq/common/message.go b/pkg/mq/common/message.go index bd8d231c4874b..e0b215d8d1804 100644 --- a/pkg/mq/common/message.go +++ b/pkg/mq/common/message.go @@ -16,6 +16,14 @@ package common +import ( + "fmt" + + "google.golang.org/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" +) + // ProducerOptions contains the options of a producer type ProducerOptions struct { // The topic that this Producer will publish @@ -65,3 +73,30 @@ const ( // SubscriptionPositionUnkown indicates we don't care about the consumer location, since we are doing another seek or only some meta api over that SubscriptionPositionUnknown ) + +const MsgTypeKey = "msg_type" + +func GetMsgType(msg Message) (commonpb.MsgType, error) { + msgType := commonpb.MsgType_Undefined + properties := msg.Properties() + if properties != nil { + if val, ok := properties[MsgTypeKey]; ok { + msgType = commonpb.MsgType(commonpb.MsgType_value[val]) + } + } + if msgType == commonpb.MsgType_Undefined { + header := commonpb.MsgHeader{} + if msg.Payload() == nil { + return msgType, fmt.Errorf("failed to unmarshal message header, payload is empty") + } + err := proto.Unmarshal(msg.Payload(), &header) + if err != nil { + return msgType, fmt.Errorf("failed to unmarshal message header, err %s", err.Error()) + } + if header.Base == nil { + return msgType, fmt.Errorf("failed to unmarshal message, header is uncomplete") + } + msgType = header.Base.MsgType + } + return msgType, nil +} diff --git a/pkg/mq/msgstream/mq_msgstream.go b/pkg/mq/msgstream/mq_msgstream.go index 846f9d447f491..38f32342540c0 100644 --- a/pkg/mq/msgstream/mq_msgstream.go +++ b/pkg/mq/msgstream/mq_msgstream.go @@ -325,7 +325,9 @@ func (ms *mqMsgStream) Produce(msgPack *MsgPack) error { return err } - msg := &common.ProducerMessage{Payload: m, Properties: map[string]string{}} + msg := &common.ProducerMessage{Payload: m, Properties: map[string]string{ + common.MsgTypeKey: v.Msgs[i].Type().String(), + }} InjectCtx(spanCtx, msg.Properties) ms.producerLock.RLock() @@ -396,18 +398,11 @@ func (ms *mqMsgStream) getTsMsgFromConsumerMsg(msg common.Message) (TsMsg, error // GetTsMsgFromConsumerMsg get TsMsg from consumer message func GetTsMsgFromConsumerMsg(unmarshalDispatcher UnmarshalDispatcher, msg common.Message) (TsMsg, error) { - header := commonpb.MsgHeader{} - if msg.Payload() == nil { - return nil, fmt.Errorf("failed to unmarshal message header, payload is empty") - } - err := proto.Unmarshal(msg.Payload(), &header) + msgType, err := common.GetMsgType(msg) if err != nil { - return nil, fmt.Errorf("failed to unmarshal message header, err %s", err.Error()) - } - if header.Base == nil { - return nil, fmt.Errorf("failed to unmarshal message, header is uncomplete") + return nil, err } - tsMsg, err := unmarshalDispatcher.Unmarshal(msg.Payload(), header.Base.MsgType) + tsMsg, err := unmarshalDispatcher.Unmarshal(msg.Payload(), msgType) if err != nil { return nil, fmt.Errorf("failed to unmarshal tsMsg, err %s", err.Error()) }