Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support stream mode #253

Merged
merged 2 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion code/config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ OPENAI_KEY: sk-xxx,sk-xxx,sk-xxx
OPENAI_MODEL: gpt-3.5-turbo
# openAI 最大token数 默认为2000
OPENAI_MAX_TOKENS: 2000
# 响应超时时间,单位为毫秒,默认为550毫秒
OPENAI_HTTP_CLIENT_TIMEOUT: 550
# 服务器配置
HTTP_PORT: 9000
HTTPS_PORT: 9001
Expand All @@ -22,7 +24,8 @@ KEY_FILE: key.pem
API_URL: https://api.openai.com
# 代理设置, 例如 "http://127.0.0.1:7890", ""代表不使用代理
HTTP_PROXY: ""

# 是否开启流式接口返回
STREAM_MODE: false # set true to use stream mode
# AZURE OPENAI
AZURE_ON: false # set true to use Azure rather than OpenAI
AZURE_API_VERSION: 2023-03-15-preview # 2023-03-15-preview or 2022-12-01 refer https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#completions
Expand Down
1 change: 1 addition & 0 deletions code/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ require (
github.com/pandodao/tokenizer-go v0.2.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pion/opus v0.0.0-20230123082803-1052c3e89e58
github.com/sashabaranov/go-openai v1.13.0
github.com/sirupsen/logrus v1.9.0
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.14.0
Expand Down
2 changes: 2 additions & 0 deletions code/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/sashabaranov/go-openai v1.13.0 h1:EAusFfnhaMaaUspUZ2+MbB/ZcVeD4epJmTOlZ+8AcAE=
github.com/sashabaranov/go-openai v1.13.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/afero v1.9.3 h1:41FoI0fD7OR7mGcKE/aOiLkGreyf8ifIOQmJANWogMk=
Expand Down
63 changes: 0 additions & 63 deletions code/handlers/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,69 +18,6 @@ func msgFilter(msg string) string {

// Parse rich text json to text
func parsePostContent(content string) string {
/*
{
"title":"我是一个标题",
"content":[
[
{
"tag":"text",
"text":"第一行 :",
"style": ["bold", "underline"]
},
{
"tag":"a",
"href":"http://www.feishu.cn",
"text":"超链接",
"style": ["bold", "italic"]
},
{
"tag":"at",
"user_id":"@_user_1",
"user_name":"",
"style": []
}
],
[
{
"tag":"img",
"image_key":"img_47354fbc-a159-40ed-86ab-2ad0f1acb42g"
}
],
[
{
"tag":"text",
"text":"第二行:",
"style": ["bold", "underline"]
},
{
"tag":"text",
"text":"文本测试",
"style": []
}
],
[
{
"tag":"img",
"image_key":"img_47354fbc-a159-40ed-86ab-2ad0f1acb42g"
}
],
[
{
"tag":"media",
"file_key": "file_v2_0dcdd7d9-fib0-4432-a519-41d25aca542j",
"image_key": "img_7ea74629-9191-4176-998c-2e603c9c5e8g"
}
],
[
{
"tag": "emotion",
"emoji_type": "SMILE"
}
]
]
}
*/
var contentMap map[string]interface{}
err := json.Unmarshal([]byte(content), &contentMap)

Expand Down
137 changes: 131 additions & 6 deletions code/handlers/event_msg_action.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
package handlers

import (
	"encoding/json"
	"fmt"
	"log"
	"strings"
	"sync"
	"time"

	"start-feishubot/services/openai"
)

type MessageAction struct { /*消息*/
}

func (*MessageAction) Execute(a *ActionInfo) bool {
msg := a.handler.sessionCache.GetMsg(*a.info.sessionId)
// 如果没有提示词,默认模拟ChatGPT
func setDefaultPrompt(msg []openai.Messages) []openai.Messages {
if !hasSystemRole(msg) {
msg = append(msg, openai.Messages{
Role: "system", Content: "You are ChatGPT, " +
Expand All @@ -22,6 +20,19 @@ func (*MessageAction) Execute(a *ActionInfo) bool {
"Current date" + time.Now().Format("20060102"),
})
}
return msg
}

type MessageAction struct { /*消息*/
}

func (*MessageAction) Execute(a *ActionInfo) bool {
if a.handler.config.StreamMode {
return true
}
msg := a.handler.sessionCache.GetMsg(*a.info.sessionId)
// 如果没有提示词,默认模拟ChatGPT
msg = setDefaultPrompt(msg)
msg = append(msg, openai.Messages{
Role: "user", Content: a.info.qParsed,
})
Expand Down Expand Up @@ -63,3 +74,117 @@ func hasSystemRole(msg []openai.Messages) bool {
}
return false
}

type StreamMessageAction struct { /*流式消息*/
}

// Execute answers the user's message in stream mode: it posts a
// placeholder card, streams tokens from the chat backend into it with
// periodic updates, and finalizes the card when the stream completes
// or the no-content watchdog fires.
//
// Returns true (continue the action chain) when stream mode is
// disabled; returns false once the message has been handled here.
func (m *StreamMessageAction) Execute(a *ActionInfo) bool {
	if !a.handler.config.StreamMode {
		return true
	}
	msg := a.handler.sessionCache.GetMsg(*a.info.sessionId)
	// Default to a ChatGPT-style system prompt when the session has none.
	msg = setDefaultPrompt(msg)
	msg = append(msg, openai.Messages{
		Role: "user", Content: a.info.qParsed,
	})

	cardId, err2 := sendOnProcess(a)
	if err2 != nil {
		return false
	}

	answer := ""
	var answerMu sync.Mutex // guards answer: written by this goroutine, read by the ticker goroutine
	chatResponseStream := make(chan string)
	done := make(chan struct{}) // signals the worker goroutines to exit
	// done can be reached from three paths (timeout callback, stream
	// error path, normal completion). The original code could close it
	// twice and panic with "close of closed channel"; sync.Once makes
	// the close idempotent.
	var closeDoneOnce sync.Once
	closeDone := func() {
		closeDoneOnce.Do(func() { close(done) })
	}
	noContentTimeout := time.AfterFunc(10*time.Second, func() {
		log.Println("no content timeout")
		closeDone()
		if err := updateFinalCard(*a.ctx, "请求超时", cardId); err != nil {
			return
		}
	})
	defer noContentTimeout.Stop()

	go func() {
		// Always release the done signal, even on panic or early return.
		defer closeDone()
		defer func() {
			if err := recover(); err != nil {
				if err := updateFinalCard(*a.ctx, "聊天失败", cardId); err != nil {
					return
				}
			}
		}()

		aiMode := a.handler.sessionCache.GetAIMode(*a.info.sessionId)
		if err := a.handler.gpt.StreamChat(*a.ctx, msg, aiMode,
			chatResponseStream); err != nil {
			if err := updateFinalCard(*a.ctx, "聊天失败", cardId); err != nil {
				return
			}
		}
	}()

	// Periodically flush the partial answer into the card while tokens
	// are still arriving.
	ticker := time.NewTicker(700 * time.Millisecond)
	defer ticker.Stop()
	go func() {
		for {
			select {
			case <-done:
				return
			case <-ticker.C:
				answerMu.Lock()
				partial := answer
				answerMu.Unlock()
				if err := updateTextCard(*a.ctx, partial, cardId); err != nil {
					return
				}
			}
		}
	}()

	for {
		select {
		case res, ok := <-chatResponseStream:
			if !ok {
				return false
			}
			// First content arrived: cancel the no-content watchdog.
			noContentTimeout.Stop()
			answerMu.Lock()
			answer += res
			answerMu.Unlock()
		case <-done:
			if err := updateFinalCard(*a.ctx, answer, cardId); err != nil {
				return false
			}
			msg := append(msg, openai.Messages{
				Role: "assistant", Content: answer,
			})
			a.handler.sessionCache.SetMsg(*a.info.sessionId, msg)
			// NOTE(review): closing a channel the producer sends on;
			// relies on StreamChat having returned by the time done is
			// closed — confirm, otherwise the producer panics (caught
			// by its recover).
			close(chatResponseStream)
			// Log the full conversation with newlines stripped. The
			// original computed jsonStr but never emitted it.
			jsonByteArray, err := json.Marshal(msg)
			if err != nil {
				log.Println(err)
			}
			jsonStr := strings.ReplaceAll(string(jsonByteArray), "\\n", "")
			jsonStr = strings.ReplaceAll(jsonStr, "\n", "")
			log.Printf("\n\n\n")
			log.Println(jsonStr)
			log.Printf("\n\n\n")
			return false
		}
	}
}

// sendOnProcess posts the "thinking, please wait" placeholder card as
// a reply to the incoming message and returns the new card message id.
func sendOnProcess(a *ActionInfo) (*string, error) {
	return sendOnProcessCard(*a.ctx, a.info.sessionId, a.info.msgId)
}
1 change: 1 addition & 0 deletions code/handlers/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ func (m MessageHandler) msgReceivedHandler(ctx context.Context, event *larkim.P2
&BalanceAction{}, //余额处理
&RolePlayAction{}, //角色扮演处理
&MessageAction{}, //消息处理
&StreamMessageAction{}, //流式消息处理

}
chain(data, actions...)
Expand Down
121 changes: 121 additions & 0 deletions code/handlers/msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -782,3 +782,124 @@ func SendAIModeListsCard(ctx context.Context,
withNote("提醒:选择内置模式,让AI更好的理解您的需求。"))
replyCard(ctx, msgId, newCard)
}

// sendOnProcessCard replies to msgId with a placeholder "thinking"
// card and returns the id of the card message so it can be patched
// with streamed content later. sessionId is currently unused but kept
// for interface stability.
func sendOnProcessCard(ctx context.Context,
	sessionId *string, msgId *string) (*string, error) {
	// The original discarded this error with _; a serialization
	// failure would have sent an empty card.
	newCard, err := newSendCardWithOutHeader(
		withNote("正在思考,请稍等..."))
	if err != nil {
		return nil, err
	}
	return replyCardWithBackId(ctx, msgId, newCard)
}

// updateTextCard patches the card identified by msgId with the
// partial answer msg plus a "still generating" note.
func updateTextCard(ctx context.Context, msg string,
	msgId *string) error {
	// The original discarded this error with _; propagate it instead.
	newCard, err := newSendCardWithOutHeader(
		withMainText(msg),
		withNote("正在生成,请稍等..."))
	if err != nil {
		return err
	}
	return PatchCard(ctx, msgId, newCard)
}
// updateFinalCard patches the card identified by msgId with the final
// answer msg plus a "done" note.
func updateFinalCard(
	ctx context.Context,
	msg string,
	msgId *string,
) error {
	// The original discarded this error with _; propagate it instead.
	newCard, err := newSendCardWithOutHeader(
		withMainText(msg),
		withNote("已完成,您可以继续提问或者选择其他功能。"))
	if err != nil {
		return err
	}
	return PatchCard(ctx, msgId, newCard)
}

// newSendCardWithOutHeader builds a header-less interactive card
// (wide-screen off, forwardable, multi-update enabled) from the given
// elements and returns its JSON string.
func newSendCardWithOutHeader(
	elements ...larkcard.MessageCardElement) (string, error) {
	cfg := larkcard.NewMessageCardConfig().
		WideScreenMode(false).
		EnableForward(true).
		UpdateMulti(true).
		Build()
	// elements is already []larkcard.MessageCardElement; no copy needed.
	return larkcard.NewMessageCard().
		Config(cfg).
		Elements(elements).
		String()
}

// PatchCard overwrites the content of an existing card message
// identified by msgId with cardContent via the Lark IM patch API.
func PatchCard(ctx context.Context, msgId *string,
	cardContent string) error {
	client := initialization.GetLarkClient()

	req := larkim.NewPatchMessageReqBuilder().
		MessageId(*msgId).
		Body(larkim.NewPatchMessageReqBodyBuilder().
			Content(cardContent).
			Build()).
		Build()
	resp, err := client.Im.Message.Patch(ctx, req)

	// Transport-level failure.
	if err != nil {
		fmt.Println(err)
		return err
	}

	// Server-side (API) failure.
	if !resp.Success() {
		fmt.Println(resp.Code, resp.Msg, resp.RequestId())
		return errors.New(resp.Msg)
	}
	return nil
}

// replyCardWithBackId replies to msgId with an interactive card and
// returns the message id of the newly created reply so callers can
// patch it later.
func replyCardWithBackId(ctx context.Context,
	msgId *string,
	cardContent string,
) (*string, error) {
	client := initialization.GetLarkClient()

	req := larkim.NewReplyMessageReqBuilder().
		MessageId(*msgId).
		Body(larkim.NewReplyMessageReqBodyBuilder().
			MsgType(larkim.MsgTypeInteractive).
			Uuid(uuid.New().String()).
			Content(cardContent).
			Build()).
		Build()
	resp, err := client.Im.Message.Reply(ctx, req)

	// Transport-level failure.
	if err != nil {
		fmt.Println(err)
		return nil, err
	}

	// Server-side (API) failure.
	if !resp.Success() {
		fmt.Println(resp.Code, resp.Msg, resp.RequestId())
		return nil, errors.New(resp.Msg)
	}

	return resp.Data.MessageId, nil
}
Loading