Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: find question message from reply message #1085

Merged
merged 5 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 22 additions & 0 deletions libs/client-api/src/http_chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,28 @@ impl Client {
.into_data()
}

pub async fn get_question_message_from_answer_id(
&self,
workspace_id: &str,
chat_id: &str,
answer_message_id: i64,
) -> Result<Option<ChatMessage>, AppResponseError> {
let url = format!(
"{}/api/chat/{workspace_id}/{chat_id}/message/find_question",
self.base_url
);

let resp = self
.http_client_with_auth(Method::GET, &url)
.await?
richardshiue marked this conversation as resolved.
Show resolved Hide resolved
.query(&[("answer_message_id", answer_message_id)])
.send()
.await?;
AppResponse::<Option<ChatMessage>>::from_response(resp)
.await?
.into_data()
}

pub async fn calculate_similarity(
&self,
params: CalculateSimilarityParams,
Expand Down
37 changes: 37 additions & 0 deletions libs/database/src/chat/chat_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -669,3 +669,40 @@ pub async fn select_chat_message_content<'a, E: Executor<'a, Database = Postgres
.await?;
Ok((row.content, row.meta_data))
}

pub async fn select_chat_message_matching_reply_message_id(
txn: &mut Transaction<'_, Postgres>,
chat_id: &str,
reply_message_id: i64,
) -> Result<Option<ChatMessage>, AppError> {
let chat_id = Uuid::from_str(chat_id)?;
let row = sqlx::query!(
r#"
SELECT message_id, content, created_at, author, meta_data, reply_message_id
FROM af_chat_messages
WHERE chat_id = $1
AND reply_message_id = $2
"#,
&chat_id,
reply_message_id
)
.fetch_one(txn.deref_mut())
.await?;

let message = match serde_json::from_value::<ChatAuthor>(row.author) {
Ok(author) => Some(ChatMessage {
author,
message_id: row.message_id,
content: row.content,
created_at: row.created_at,
meta_data: row.meta_data,
reply_message_id: row.reply_message_id,
}),
Err(err) => {
warn!("Failed to deserialize author: {}", err);
None
},
};

Ok(message)
}
23 changes: 22 additions & 1 deletion src/api/chat.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use crate::biz::chat::ops::{
create_chat, create_chat_message, delete_chat, generate_chat_message_answer, get_chat_messages,
update_chat_message,
get_question_message, update_chat_message,
};
use crate::state::AppState;
use actix_web::web::{Data, Json};
use actix_web::{web, HttpRequest, HttpResponse, Scope};
use serde::Deserialize;

use crate::api::util::ai_model_from_header;
use app_error::AppError;
Expand Down Expand Up @@ -69,6 +70,10 @@ pub fn chat_scope() -> Scope {
web::resource("/{chat_id}/message/answer")
.route(web::post().to(save_answer_handler))
)
.service(
web::resource("/{chat_id}/message/find_question")
.route(web::get().to(get_chat_question_message_handler))
)

// AI response generation
.service(
Expand Down Expand Up @@ -349,6 +354,17 @@ async fn get_chat_message_handler(
Ok(AppResponse::Ok().with_data(messages).into())
}

#[instrument(level = "debug", skip_all, err)]
async fn get_chat_question_message_handler(
path: web::Path<(String, String)>,
query: web::Query<FindQuestionParams>,
state: Data<AppState>,
) -> actix_web::Result<JsonAppResponse<Option<ChatMessage>>> {
let (_workspace_id, chat_id) = path.into_inner();
let message = get_question_message(&state.pg_pool, &chat_id, query.0.answer_message_id).await?;
Ok(AppResponse::Ok().with_data(message).into())
}

#[instrument(level = "debug", skip_all, err)]
async fn get_chat_settings_handler(
path: web::Path<(String, String)>,
Expand Down Expand Up @@ -501,3 +517,8 @@ where
}
}
}

#[derive(Debug, Deserialize)]
struct FindQuestionParams {
answer_message_id: i64,
}
14 changes: 13 additions & 1 deletion src/biz/chat/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use database::chat;
use database::chat::chat_ops::{
delete_answer_message_by_question_message_id, insert_answer_message,
insert_answer_message_with_transaction, insert_chat, insert_question_message,
select_chat_messages,
select_chat_message_matching_reply_message_id, select_chat_messages,
};
use futures::stream::Stream;
use serde_json::json;
Expand Down Expand Up @@ -232,3 +232,15 @@ pub async fn get_chat_messages(
txn.commit().await?;
Ok(messages)
}

pub async fn get_question_message(
pg_pool: &PgPool,
chat_id: &str,
answer_message_id: i64,
) -> Result<Option<ChatMessage>, AppError> {
let mut txn = pg_pool.begin().await?;
let message =
select_chat_message_matching_reply_message_id(&mut txn, chat_id, answer_message_id).await?;
txn.commit().await?;
Ok(message)
}
76 changes: 69 additions & 7 deletions tests/ai_test/chat_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use client_api_test::{ai_test_enabled, TestClient};
use futures_util::StreamExt;
use serde_json::json;
use shared_entity::dto::chat_dto::{
ChatMessageMetadata, ChatRAGData, CreateChatMessageParams, CreateChatParams, MessageCursor,
UpdateChatParams,
ChatMessageMetadata, ChatRAGData, CreateAnswerMessageParams, CreateChatMessageParams,
CreateChatParams, MessageCursor, UpdateChatParams,
};

#[tokio::test]
Expand Down Expand Up @@ -344,6 +344,10 @@ async fn create_chat_context_test() {

// #[tokio::test]
// async fn update_chat_message_test() {
// if !ai_test_enabled() {
// return;
// }

// let test_client = TestClient::new_user_without_ws_conn().await;
// let workspace_id = test_client.workspace_id().await;
// let chat_id = uuid::Uuid::new_v4().to_string();
Expand All @@ -352,13 +356,13 @@ async fn create_chat_context_test() {
// name: "my second chat".to_string(),
// rag_ids: vec![],
// };
//

// test_client
// .api_client
// .create_chat(&workspace_id, params)
// .await
// .unwrap();
//

// let params = CreateChatMessageParams::new_user("where is singapore?");
// let stream = test_client
// .api_client
Expand All @@ -367,7 +371,7 @@ async fn create_chat_context_test() {
// .unwrap();
// let messages: Vec<ChatMessage> = stream.map(|message| message.unwrap()).collect().await;
// assert_eq!(messages.len(), 2);
//

// let params = UpdateChatMessageContentParams {
// chat_id: chat_id.clone(),
// message_id: messages[0].message_id,
Expand All @@ -378,7 +382,7 @@ async fn create_chat_context_test() {
// .update_chat_message(&workspace_id, &chat_id, params)
// .await
// .unwrap();
//

// let remote_messages = test_client
// .api_client
// .get_chat_messages(&workspace_id, &chat_id, MessageCursor::NextBack, 2)
Expand All @@ -387,11 +391,69 @@ async fn create_chat_context_test() {
// .messages;
// assert_eq!(remote_messages[0].content, "where is China?");
// assert_eq!(remote_messages.len(), 2);
//

// // when the question was updated, the answer should be different
// assert_ne!(remote_messages[1].content, messages[1].content);
// }

#[tokio::test]
async fn get_question_message_test() {
if !ai_test_enabled() {
return;
}

let test_client = TestClient::new_user_without_ws_conn().await;
let workspace_id = test_client.workspace_id().await;
let chat_id = uuid::Uuid::new_v4().to_string();
let params = CreateChatParams {
chat_id: chat_id.clone(),
name: "my ai chat".to_string(),
rag_ids: vec![],
};

test_client
.api_client
.create_chat(&workspace_id, params)
.await
.unwrap();

let params = CreateChatMessageParams::new_user("where is singapore?");
let question = test_client
.api_client
.create_question(&workspace_id, &chat_id, params)
.await
.unwrap();

let answer = test_client
.api_client
.get_answer(&workspace_id, &chat_id, question.message_id)
.await
.unwrap();

test_client
.api_client
.save_answer(
&workspace_id,
&chat_id,
CreateAnswerMessageParams {
content: answer.content,
metadata: None,
question_message_id: question.message_id,
},
)
.await
.unwrap();

let find_question = test_client
.api_client
.get_question_message_from_answer_id(&workspace_id, &chat_id, answer.message_id)
.await
.unwrap()
.unwrap();

assert_eq!(find_question.reply_message_id.unwrap(), answer.message_id);
}

async fn collect_answer(mut stream: QuestionStream) -> String {
let mut answer = String::new();
while let Some(value) = stream.next().await {
Expand Down
Loading