Commit

Improve the answer using ChatGPT.

yjcyxky committed Aug 13, 2024
1 parent 7d7227b commit 3eb9860
Showing 12 changed files with 533 additions and 38 deletions.
1 change: 0 additions & 1 deletion src/api/mod.rs
@@ -3,4 +3,3 @@
pub mod route;
pub mod schema;
pub mod auth;
pub mod publication;
6 changes: 4 additions & 2 deletions src/api/route.rs
@@ -1,7 +1,6 @@
//! This module defines the routes of the API.

use crate::api::auth::{CustomSecurityScheme, USERNAME_PLACEHOLDER};
use crate::api::publication::Publication;
use crate::api::schema::{
ApiTags, DeleteResponse, GetConsensusResultResponse, GetEntityAttrResponse,
GetEntityColorMapResponse, GetGraphResponse, GetPromptResponse, GetPublicationsResponse,
@@ -19,6 +18,7 @@ use crate::model::graph::Graph;
use crate::model::init_db::get_kg_score_table_name;
use crate::model::kge::DEFAULT_MODEL_NAME;
use crate::model::llm::{ChatBot, Context, LlmResponse, PROMPTS};
use crate::model::publication::Publication;
use crate::model::util::match_color;
use crate::query_builder::cypher_builder::{query_nhops, query_shared_nodes};
use crate::query_builder::sql_builder::{get_all_field_pairs, make_order_clause_by_pairs};
@@ -91,11 +91,13 @@ impl BiomedgpsApi {
&self,
publications: Json<Vec<Publication>>,
question: Query<String>,
pool: Data<&Arc<sqlx::PgPool>>,
_token: CustomSecurityScheme,
) -> GetPublicationsSummaryResponse {
let question = question.0;
let publications = publications.0;
match Publication::fetch_summary_by_chatgpt(&question, &publications).await {
let pool_arc = pool.clone();
match Publication::fetch_summary_by_chatgpt(&question, &publications, Some(&pool_arc)).await {
Ok(result) => GetPublicationsSummaryResponse::ok(result),
Err(e) => {
let err = format!("Failed to fetch publications summary: {}", e);
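For orientation, here is a minimal sketch of how a caller uses the updated three-argument signature introduced above. The `summarize` wrapper and its bindings are hypothetical, and the idea that `None` simply skips persistence is inferred from the `Option` parameter rather than stated in the diff:

```rust
// Hypothetical call site for Publication::fetch_summary_by_chatgpt after this
// change; only the signature and the `summary` field come from the diff itself.
use crate::model::publication::{Publication, PublicationsSummary};

async fn summarize(
    question: &str,
    publications: &Vec<Publication>,
    pool: &sqlx::PgPool,
) -> Result<PublicationsSummary, anyhow::Error> {
    // Passing Some(pool) lets the LLM layer persist the exchange (assumption);
    // None should still produce a summary without touching the database.
    Publication::fetch_summary_by_chatgpt(question, publications, Some(pool)).await
}
```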
2 changes: 1 addition & 1 deletion src/api/schema.rs
@@ -12,7 +12,7 @@ use serde::{Deserialize, Serialize};
use validator::Validate;
use validator::ValidationErrors;

use super::publication::{Publication, PublicationRecords, PublicationsSummary, ConsensusResult};
use crate::model::publication::{Publication, PublicationRecords, PublicationsSummary, ConsensusResult};

#[derive(Tags)]
pub enum ApiTags {
2 changes: 1 addition & 1 deletion src/model/llm.rs
@@ -46,7 +46,7 @@ pub struct LlmResponse {
pub created_at: DateTime<Utc>,
}

/// The context is used to store the context for the LLM. The context can be an entity, an expanded relation, or treatments with disease context.
/// The context is used to store the context for the LLM. The context can be an entity, an expanded relation, or treatments with disease context. The context is used to identify the specific context from the request. See the route.rs for more details.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Object)]
pub struct Context {
pub entity: Option<Entity>,
3 changes: 2 additions & 1 deletion src/model/mod.rs
@@ -7,4 +7,5 @@ pub mod llm;
pub mod kge;
pub mod init_db;
pub mod entity;
pub mod entity_attr;
pub mod entity_attr;
pub mod publication;
88 changes: 77 additions & 11 deletions src/api/publication.rs → src/model/publication.rs
@@ -1,9 +1,10 @@
use crate::model::llm::ChatBot;
use crate::model::llm::{ChatBot, LlmContext, LlmMessage, PROMPTS, PROMPT_TEMPLATE};
use anyhow;
use log::info;
use poem_openapi::Object;
use reqwest;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use urlencoding;

const GUIDESCOPER_PUBLICATIONS_API: &str = "/api/paper_search/";
@@ -35,6 +36,51 @@ pub struct Publication {
pub provider_url: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Object)]
pub struct PublicationsContext {
pub publications: Vec<Publication>,
pub question: String,
}

impl LlmContext for PublicationsContext {
fn get_context(&self) -> Self {
self.clone()
}

fn render_prompt(
&self,
prompt_template_category: &str,
prompt_template: &str,
) -> Result<String, anyhow::Error> {
let mut prompt = prompt_template.to_string();
let publications = self.publications.iter().map(|p| {
format!("Title: {}\nAuthors: {}\nJournal: {}\nYear: {}\nSummary: {}\nAbstract: {}\nDOI: {}\n", p.title, p.authors.join(", "), p.journal, p.year.unwrap_or(0), p.summary, p.article_abstract.as_ref().unwrap_or(&"".to_string()), p.doi.as_ref().unwrap_or(&"".to_string()))
}).collect::<Vec<String>>();
prompt = prompt.replace("{{publications}}", &publications.join("\n"));
prompt = prompt.replace("{{question}}", &self.question);
Ok(prompt)
}

fn register_prompt_template() {
let mut prompt_templates = PROMPT_TEMPLATE.lock().unwrap();
prompt_templates.insert("answer_question_with_publications", "I have a collection of papers wrapped by the ```:\n```\n{{publications}}\n```\n\nPlease carefully analyze these papers to answer the following question: \n{{question}}\n\nIn your response, please provide a well-integrated analysis that directly answers the question. Include citations from specific papers to support your answer, and ensure that the reasoning behind your answer is clearly explained. Reference relevant details from the papers' summaries or abstracts as needed.");

let mut prompts = PROMPTS.lock().unwrap();

let mut m2 = HashMap::new();
m2.insert("key", "answer_question_with_publications");
m2.insert("label", "Answer question with publications");
m2.insert("type", "question");

// Does it exist?
if prompts.contains(&m2) {
return;
} else {
prompts.push(m2);
}
}
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Object)]
pub struct PublicationsSummary {
pub summary: String,
@@ -220,27 +266,47 @@ impl Publication {
pub async fn fetch_summary_by_chatgpt(
question: &str,
publications: &Vec<Publication>,
pool: Option<&sqlx::PgPool>,
) -> Result<PublicationsSummary, anyhow::Error> {
let openai_api_key = std::env::var("OPENAI_API_KEY").unwrap();
if openai_api_key.is_empty() {
return Err(anyhow::Error::msg("OPENAI_API_KEY not found"));
}

let chatbot = ChatBot::new("GPT4", &openai_api_key);
let publications_context = PublicationsContext {
publications: publications.clone(),
question: question.to_string(),
};

let publications = publications.iter().map(|p| {
format!("Title: {}\nAuthors: {}\nJournal: {}\nYear: {}\nSummary: {}\nAbstract: {}\nDOI: {}\n", p.title, p.authors.join(", "), p.journal, p.year.unwrap_or(0), p.summary, p.article_abstract.as_ref().unwrap_or(&"".to_string()), p.doi.as_ref().unwrap_or(&"".to_string()))
}).collect::<Vec<String>>();
PublicationsContext::register_prompt_template();

let mut llm_msg = match LlmMessage::new(
"answer_question_with_publications",
publications_context,
None,
) {
Ok(msg) => msg,
Err(e) => {
return Err(anyhow::Error::msg(format!(
"Failed to create LLM message: {}",
e
)));
}
};

let prompt = format!(
"I have a collection of papers wrappered by the ```:\n```\n{}\n```\n\nPlease carefully analyze these papers to answer the following question: \n{}\n\nIn your response, please provide a well-integrated analysis that directly answers the question. Include citations from specific papers to support your answer, and ensure that the reasoning behind your answer is clearly explained. Reference relevant details from the papers' summaries or abstracts as needed.",
publications.join("\n"),
question,
);
let response = match llm_msg.answer(&chatbot, pool).await {
Ok(resp) => resp,
Err(e) => {
return Err(anyhow::Error::msg(format!(
"Failed to get response from LLM: {}",
e
)));
}
};

let response = chatbot.answer(prompt).await?;
Ok(PublicationsSummary {
summary: response,
summary: response.message.clone(),
daily_limit_reached: false,
is_disputed: false,
is_incomplete: false,
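The `PublicationsContext` added above plugs into the generic prompt machinery in `src/model/llm.rs`, whose `LlmContext` trait is not shown in this diff. Below is a minimal sketch of the trait shape the impl appears to satisfy, inferred purely from the method signatures above; the real definition may carry extra bounds or default methods:

```rust
// Assumed shape of LlmContext in src/model/llm.rs, inferred from the
// PublicationsContext impl in this file; not part of the commit.
pub trait LlmContext: Clone {
    /// Return the context payload to attach to the LLM message.
    fn get_context(&self) -> Self;

    /// Render a prompt template using data from this context.
    fn render_prompt(
        &self,
        prompt_template_category: &str,
        prompt_template: &str,
    ) -> Result<String, anyhow::Error>;

    /// Register this context's template in the global PROMPT_TEMPLATE/PROMPTS
    /// registries so LlmMessage::new can resolve it by key.
    fn register_prompt_template();
}
```

Under that assumption, `fetch_summary_by_chatgpt` no longer formats the prompt inline: it registers the `answer_question_with_publications` template once, wraps the question and papers in a `PublicationsContext`, and lets `LlmMessage::answer` render the prompt, call the `ChatBot`, and, when a pool is supplied, presumably record the exchange.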
1 change: 1 addition & 0 deletions studio/package.json
@@ -33,6 +33,7 @@
"@mlc-ai/web-llm": "0.2.15",
"@sentry/react": "^7.108.0",
"@textea/json-viewer": "^2.9.0",
"@uiw/react-markdown-preview": "^5.1.2",
"@umijs/max": "^4.1.2",
"@umijs/route-utils": "^2.0.0",
"antd": "5.8.0",
4 changes: 4 additions & 0 deletions studio/src/EdgeInfoPanel/PublicationDesc.tsx
@@ -29,6 +29,10 @@ const Desc: React.FC<{
const highlightWords = (text: string, words: string[]): string => {
let newText = text;
words.forEach(word => {
if (!word) {
return;
}

let escapedWord = escapeRegExp(word);
let regex = new RegExp(`(${escapedWord})(?![^<]*>|[^<>]*<\/)`, 'gi');
newText = newText.replace(regex, '<span class="highlight">$1</span>');
37 changes: 21 additions & 16 deletions studio/src/EdgeInfoPanel/PublicationPanel.tsx
@@ -1,6 +1,6 @@
import React, { useEffect, useState } from 'react';
import { MarkdownViewer } from 'biominer-components';
import RehypeRaw from 'rehype-raw';
import MarkdownPreview from '@uiw/react-markdown-preview';
import { Button, List, message, Row, Col, Tag } from 'antd';
import { FileProtectOutlined } from '@ant-design/icons';
import type { Publication, PublicationDetail } from 'biominer-components/dist/typings';
@@ -24,6 +24,7 @@ const PublicationPanel: React.FC<PublicationPanelProps> = (props) => {
const [abstractMap, setAbstractMap] = useState<Record<string, string>>({});
const [searchId, setSearchId] = useState<string>('');
const [publicationSummary, setPublicationSummary] = useState<string>('Loading...');
const [publicationSummaryByChatGPT, setPublicationSummaryByChatGPT] = useState<string>('');
const [generating, setGenerating] = useState<boolean>(false);

const showAbstract = async (doc_id: string): Promise<PublicationDetail> => {
@@ -82,7 +83,7 @@ const PublicationPanel: React.FC<PublicationPanelProps> = (props) => {
const docId = docIds[i];
if (!publicationMap[docId].article_abstract) {
const msg = `Load ${i} publication...`;
setPublicationSummary(msg);
setPublicationSummaryByChatGPT(msg);
await showAbstract(docId).then((publication) => {
tempAbstractMap[docId] = publication.article_abstract || '';
}).catch((error) => {
Expand All @@ -94,7 +95,7 @@ const PublicationPanel: React.FC<PublicationPanelProps> = (props) => {
}

setAbstractMap(tempAbstractMap);
setPublicationSummary('Publications loaded, answering question...');
setPublicationSummaryByChatGPT('Publications loaded, answering question...');

answerQuestionWithPublications(
{
@@ -104,11 +105,11 @@ const PublicationPanel: React.FC<PublicationPanelProps> = (props) => {
Object.values(publicationMap)
).then((response) => {
console.log('Answer: ', response);
setPublicationSummary(response.summary);
setPublicationSummaryByChatGPT(response.summary);
}).catch((error) => {
setGenerating(false);
console.error('Error: ', error);
setPublicationSummary('Failed to answer question, because of the following error: ' + error);
setPublicationSummaryByChatGPT('Failed to answer question, because of the following error: ' + error);
});
}

@@ -158,19 +159,23 @@ const PublicationPanel: React.FC<PublicationPanelProps> = (props) => {
<Col className='publication-panel-header'>
<span>
<Tag>Question</Tag>
{props.queryStr} <Button type="primary" onClick={() => {
setGenerating(true);
const docIds = Object.keys(publicationMap);
if (docIds.length > 0) {
loadAbstractsAndAnswer(docIds);
}
}} disabled={generating || Object.keys(publicationMap).length == 0} size='small'>
Generate Detailed Answer
</Button>
{props.queryStr}
</span>
<p>
<p style={{ marginTop: '5px' }}>
{/* <Tag>Answer by AI</Tag> */}
<MarkdownViewer markdown={publicationSummary} rehypePlugins={[RehypeRaw]} />
<span>
<Tag>Short Answer</Tag> {publicationSummary} <Button type="primary" onClick={() => {
setGenerating(true);
const docIds = Object.keys(publicationMap);
if (docIds.length > 0) {
loadAbstractsAndAnswer(docIds);
}
}} disabled={generating || Object.keys(publicationMap).length == 0} size='small'>
Generate Detailed Answer
</Button>
</span>
{publicationSummaryByChatGPT &&
<MarkdownPreview className='markdown-viewer' source={publicationSummaryByChatGPT} rehypePlugins={[RehypeRaw]} />}
</p>
</Col>

24 changes: 24 additions & 0 deletions studio/src/EdgeInfoPanel/index.less
@@ -28,6 +28,10 @@
flex-direction: column;
flex-wrap: nowrap;

.ant-tag {
font-size: 0.85rem;
}

.publication-tag {
font-size: 1rem;
margin-left: 10px;
@@ -38,6 +42,26 @@
font-size: 1rem;
padding: 20px;
min-height: unset;

.markdown-viewer,
.ant-empty {
border: 1px solid #d9d9d9;
padding: 10px;
border-radius: 4px;
margin: 10px 0 0;
max-height: 500px;
overflow: auto;

h1,
h2,
h3,
h4,
h5,
h6 {
margin-top: 5px;
margin-bottom: 5px;
}
}
}

.publication-panel-content {
2 changes: 1 addition & 1 deletion studio/src/EdgeInfoPanel/index.tsx
@@ -36,7 +36,7 @@ const EdgeInfoPanel: React.FC<EdgeInfoPanelProps> = (props) => {

if (queryStr) {
return <CommonPanel edgeInfo={props.edgeInfo} relationType={relationType}>
<PublicationPanel queryStr={queryStr} />
<PublicationPanel queryStr={queryStr} startNode={startNode.data.name} endNode={endNode.data.name} />
</CommonPanel>;
}
}