Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions internal/api/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ func TestAuthServer(t *testing.T) {
}
assert.True(t, token.ExpiresAt.Before(timeNow))

// 现在 Token 有效期应该是 24 小时前
// Now Token expiration should be 24 hours ago
token, err = tokenService.GetTokenByToken(context.Background(), token.Token)
if err != nil {
t.Fatalf("Failed to get refresh token: %v", err)
}
assert.True(t, token.ExpiresAt.Before(timeNow))

// 这时候 RefreshToken 应该失效
// At this point RefreshToken should be invalid
resp, err := authServer.RefreshToken(context.Background(),
&authv1.RefreshTokenRequest{
RefreshToken: token.Token,
Expand All @@ -71,15 +71,15 @@ func TestAuthServer(t *testing.T) {
assert.Error(t, err)
assert.Nil(t, resp)

// 更新 Token 有效期 到 24 小时候
// Update Token expiration to 24 hours later
token.ExpiresAt = timeNow.Add(time.Hour * 24)
token, err = tokenService.UpdateToken(context.Background(), token)
if err != nil {
t.Fatalf("Failed to update refresh token: %v", err)
}
assert.True(t, token.ExpiresAt.After(timeNow))

// 这时候 RefreshToken 应该有效
// At this point RefreshToken should be valid
resp, err = authServer.RefreshToken(context.Background(),
&authv1.RefreshTokenRequest{
RefreshToken: token.Token,
Expand All @@ -88,7 +88,7 @@ func TestAuthServer(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, resp)

// 刚刚 RefreshToken 之后,有效期应该刷新到一个月后
// After RefreshToken, expiration should be refreshed to one month later
token, err = tokenService.GetTokenByToken(context.Background(), resp.RefreshToken)
if err != nil {
t.Fatalf("Failed to get refresh token: %v", err)
Expand Down
251 changes: 0 additions & 251 deletions internal/api/chat/create_conversation_message.go
Original file line number Diff line number Diff line change
@@ -1,252 +1 @@
package chat

import (
"context"

"paperdebugger/internal/libs/contextutil"
"paperdebugger/internal/libs/shared"
"paperdebugger/internal/models"
chatv1 "paperdebugger/pkg/gen/api/chat/v1"

"github.com/google/uuid"
"github.com/openai/openai-go/v2/responses"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"google.golang.org/protobuf/encoding/protojson"
)

// 设计理念:
// 发送给 GPT 之前,消息列表已经构造进 Conversation 对象中(也保存在数据库里)
// 我们发送给 GPT 的就是从数据库里拿到的 Conversation 对象里面的内容(InputItemList)

// buildUserMessage constructs both the user-facing message and the OpenAI input message
func (s *ChatServer) buildUserMessage(ctx context.Context, userMessage, userSelectedText string, conversationType chatv1.ConversationType) (*chatv1.Message, *responses.ResponseInputItemUnionParam, error) {
userPrompt, err := s.chatService.GetPrompt(ctx, userMessage, userSelectedText, conversationType)
if err != nil {
return nil, nil, err
}

var inappMessage *chatv1.Message
switch conversationType {
case chatv1.ConversationType_CONVERSATION_TYPE_DEBUG:
inappMessage = &chatv1.Message{
MessageId: "pd_msg_user_" + uuid.New().String(),
Payload: &chatv1.MessagePayload{
MessageType: &chatv1.MessagePayload_User{
User: &chatv1.MessageTypeUser{
Content: userPrompt,
},
},
},
}
default:
inappMessage = &chatv1.Message{
MessageId: "pd_msg_user_" + uuid.New().String(),
Payload: &chatv1.MessagePayload{
MessageType: &chatv1.MessagePayload_User{
User: &chatv1.MessageTypeUser{
Content: userMessage,
SelectedText: &userSelectedText,
},
},
},
}
}

openaiMessage := &responses.ResponseInputItemUnionParam{
OfInputMessage: &responses.ResponseInputItemMessageParam{
Role: "user",
Content: responses.ResponseInputMessageContentListParam{
responses.ResponseInputContentParamOfInputText(userPrompt),
},
},
}

return inappMessage, openaiMessage, nil
}

// buildSystemMessage constructs both the user-facing system message and the OpenAI input message
func (s *ChatServer) buildSystemMessage(systemPrompt string) (*chatv1.Message, *responses.ResponseInputItemUnionParam) {
inappMessage := &chatv1.Message{
MessageId: "pd_msg_system_" + uuid.New().String(),
Payload: &chatv1.MessagePayload{
MessageType: &chatv1.MessagePayload_System{
System: &chatv1.MessageTypeSystem{
Content: systemPrompt,
},
},
},
}

openaiMessage := &responses.ResponseInputItemUnionParam{
OfInputMessage: &responses.ResponseInputItemMessageParam{
Role: "system",
Content: responses.ResponseInputMessageContentListParam{
responses.ResponseInputContentParamOfInputText(systemPrompt),
},
},
}

return inappMessage, openaiMessage
}

// convertToBSON converts a protobuf message to BSON
func convertToBSON(msg *chatv1.Message) (bson.M, error) {
jsonBytes, err := protojson.Marshal(msg)
if err != nil {
return nil, err
}
var bsonMsg bson.M
if err := bson.UnmarshalExtJSON(jsonBytes, true, &bsonMsg); err != nil {
return nil, err
}
return bsonMsg, nil
}

// 创建对话并写入数据库
// 返回 Conversation 对象
func (s *ChatServer) createConversation(
ctx context.Context,
userId bson.ObjectID,
projectId string,
latexFullSource string,
projectInstructions string,
userInstructions string,
userMessage string,
userSelectedText string,
modelSlug string,
conversationType chatv1.ConversationType,
) (*models.Conversation, error) {
systemPrompt, err := s.chatService.GetSystemPrompt(ctx, latexFullSource, projectInstructions, userInstructions, conversationType)
if err != nil {
return nil, err
}

_, openaiSystemMsg := s.buildSystemMessage(systemPrompt)
inappUserMsg, openaiUserMsg, err := s.buildUserMessage(ctx, userMessage, userSelectedText, conversationType)
if err != nil {
return nil, err
}

messages := []*chatv1.Message{inappUserMsg}
oaiHistory := responses.ResponseNewParamsInputUnion{
OfInputItemList: responses.ResponseInputParam{*openaiSystemMsg, *openaiUserMsg},
}

return s.chatService.InsertConversationToDB(
ctx, userId, projectId, modelSlug, messages, oaiHistory.OfInputItemList,
)
}

// 追加消息到对话并写入数据库
// 返回 Conversation 对象
func (s *ChatServer) appendConversationMessage(
ctx context.Context,
userId bson.ObjectID,
conversationId string,
userMessage string,
userSelectedText string,
conversationType chatv1.ConversationType,
) (*models.Conversation, error) {
objectID, err := bson.ObjectIDFromHex(conversationId)
if err != nil {
return nil, err
}

conversation, err := s.chatService.GetConversation(ctx, userId, objectID)
if err != nil {
return nil, err
}

userMsg, userOaiMsg, err := s.buildUserMessage(ctx, userMessage, userSelectedText, conversationType)
if err != nil {
return nil, err
}

bsonMsg, err := convertToBSON(userMsg)
if err != nil {
return nil, err
}
conversation.InappChatHistory = append(conversation.InappChatHistory, bsonMsg)
conversation.OpenaiChatHistory = append(conversation.OpenaiChatHistory, *userOaiMsg)

if err := s.chatService.UpdateConversation(conversation); err != nil {
return nil, err
}

return conversation, nil
}

// 如果 conversationId 是 "", 就创建新对话,否则就追加消息到对话
// conversationType 可以在一次 conversation 中多次切换
func (s *ChatServer) prepare(ctx context.Context, projectId string, conversationId string, userMessage string, userSelectedText string, modelSlug string, conversationType chatv1.ConversationType) (context.Context, *models.Conversation, *models.Settings, error) {
actor, err := contextutil.GetActor(ctx)
if err != nil {
return ctx, nil, nil, err
}

project, err := s.projectService.GetProject(ctx, actor.ID, projectId)
if err != nil && err != mongo.ErrNoDocuments {
return ctx, nil, nil, err
}

userInstructions, err := s.userService.GetUserInstructions(ctx, actor.ID)
if err != nil {
return ctx, nil, nil, err
}

var latexFullSource string
switch conversationType {
case chatv1.ConversationType_CONVERSATION_TYPE_DEBUG:
latexFullSource = "latex_full_source is not available in debug mode"
default:
if project == nil || project.IsOutOfDate() {
return ctx, nil, nil, shared.ErrProjectOutOfDate("project is out of date")
}

latexFullSource, err = project.GetFullContent()
if err != nil {
return ctx, nil, nil, err
}
}

var conversation *models.Conversation

if conversationId == "" {
conversation, err = s.createConversation(
ctx,
actor.ID,
projectId,
latexFullSource,
project.Instructions,
userInstructions,
userMessage,
userSelectedText,
modelSlug,
conversationType,
)
} else {
conversation, err = s.appendConversationMessage(
ctx,
actor.ID,
conversationId,
userMessage,
userSelectedText,
conversationType,
)
}

if err != nil {
return ctx, nil, nil, err
}

ctx = contextutil.SetProjectID(ctx, conversation.ProjectID)
ctx = contextutil.SetConversationID(ctx, conversation.ID.Hex())

settings, err := s.userService.GetUserSettings(ctx, actor.ID)
if err != nil {
return ctx, conversation, nil, err
}

return ctx, conversation, settings, nil
}
Loading
Loading