Skip to content

Commit

Permalink
feat: support for visualizing citation results (via embeddings)
Browse files Browse the repository at this point in the history
Signed-off-by: Kennywu <jdlow@live.cn>
  • Loading branch information
KKenny0 committed Nov 4, 2024
1 parent bd2490b commit cd39cc5
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 3 deletions.
2 changes: 2 additions & 0 deletions libs/kotaemon/kotaemon/indices/qa/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .citation import CitationPipeline
from .text_based import CitationQAPipeline
from .visualize_cited import CreateCitationVizPipeline

# Names re-exported as the public API of this package; each one is imported
# from a sibling module above.
__all__ = [
"CitationPipeline",
"CitationQAPipeline",
"CreateCitationVizPipeline",
]
145 changes: 145 additions & 0 deletions libs/kotaemon/kotaemon/indices/qa/visualize_cited.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""
This module aims to project high-dimensional embeddings
into a lower-dimensional space for visualization.
Refs:
1. [RAGxplorer](https://github.com/gabrielchua/RAGxplorer)
2. [RAGVizExpander](https://github.com/KKenny0/RAGVizExpander)
"""
from typing import List, Tuple

import numpy as np
import pandas as pd
import plotly.graph_objs as go
import umap
from ktem.embeddings.manager import embedding_models_manager as embeddings

from kotaemon.base import BaseComponent, Node
from kotaemon.embeddings import BaseEmbeddings

# Marker styling per point category. Keys are the values stored in the
# "category" column of the visualization DataFrame; each entry supplies the
# Plotly marker colour, opacity, symbol and size for that category.
VISUALIZATION_SETTINGS = {
    "Original Query": dict(color="red", opacity=1, symbol="cross", size=15),
    "Retrieved": dict(color="green", opacity=1, symbol="circle", size=10),
    "Chunks": dict(color="blue", opacity=0.4, symbol="circle", size=10),
    "Sub-Questions": dict(color="purple", opacity=1, symbol="star", size=15),
}


class CreateCitationVizPipeline(BaseComponent):
    """Build a Plotly scatter plot of a question and its context passages.

    The context texts are embedded, a UMAP projector is fitted on those
    embeddings, and both the contexts and the question are pushed through the
    fitted projector to obtain 2-D coordinates for plotting.
    """

    embedding: BaseEmbeddings = Node(
        default_callback=lambda _: embeddings.get_default()
    )
    # Projector fitted by the most recent `run` call; None until `run` is called.
    projector: umap.UMAP = None

    def _set_up_umap(self, embeddings: np.ndarray):
        """Fit a UMAP model (library defaults) on `embeddings` and return it."""
        return umap.UMAP().fit(embeddings)

    def _project_embeddings(self, embeddings, umap_transform) -> np.ndarray:
        """Project `embeddings` through the fitted `umap_transform`.

        Args:
            embeddings: Sequence of embedding vectors, one per row.
            umap_transform: Fitted projector exposing `transform`.

        Returns:
            Float array with one projected row per input embedding.
        """
        if len(embeddings) == 0:
            # Nothing to project; keep the (0, 2) shape the callers index into.
            return np.empty((0, 2))
        # One batched call instead of one `transform` call per row: the
        # projector handles a 2-D input directly, so the Python-level loop
        # only repeated the per-call overhead for every embedding.
        projected = umap_transform.transform(np.asarray(embeddings))
        return np.asarray(projected, dtype=float)

    def _get_projections(self, embeddings, umap_transform):
        """Return the projected X and Y coordinates as two 1-D arrays."""
        projections = self._project_embeddings(embeddings, umap_transform)
        return projections[:, 0], projections[:, 1]

    def _format_hover_text(self, text: str, max_chars: int = 512) -> str:
        """Turn wrapped text into hover-label markup.

        The text is truncated *before* newlines become `<br>` so the cut can
        never land in the middle of a tag, and the trailing ellipsis is only
        added when something was actually cut off.
        """
        shortened = text[:max_chars].replace("\n", "<br>")
        if len(text) > max_chars:
            shortened += "..."
        return shortened

    def _prepare_projection_df(
        self,
        document_projections: Tuple[np.ndarray, np.ndarray],
        document_text: List[str],
        plot_size: int = 3,
    ) -> pd.DataFrame:
        """Prepares a DataFrame for visualization from projections and texts.

        Args:
            document_projections (Tuple[np.ndarray, np.ndarray]):
                Tuple of X and Y coordinates of document projections.
            document_text (List[str]): List of document texts.
            plot_size (int): Value stored in the "size" column of every row.

        Returns:
            pd.DataFrame: Columns "x", "y", "document", "document_cleaned"
            (text wrapped at 50 characters and formatted for hover labels),
            "size" and "category" (always "Retrieved").
        """
        df = pd.DataFrame({"x": document_projections[0], "y": document_projections[1]})
        df["document"] = document_text
        df["document_cleaned"] = df.document.str.wrap(50).apply(
            self._format_hover_text
        )
        df["size"] = plot_size
        df["category"] = "Retrieved"
        return df

    def _plot_embeddings(self, df: pd.DataFrame) -> go.Figure:
        """
        Creates a Plotly figure to visualize the embeddings.

        One scatter trace is added per distinct value of the "category"
        column, styled from `VISUALIZATION_SETTINGS` (grey circles for
        categories that have no entry there).

        Args:
            df (pd.DataFrame): DataFrame containing the data to visualize.
                Must provide "x", "y", "category" and "document_cleaned".
        Returns:
            go.Figure: A Plotly figure object for visualization.
        """
        fallback_settings = {
            "color": "grey",
            "opacity": 1,
            "symbol": "circle",
            "size": 10,
        }
        fig = go.Figure()

        for category in df["category"].unique():
            category_df = df[df["category"] == category]
            settings = VISUALIZATION_SETTINGS.get(category, fallback_settings)
            fig.add_trace(
                go.Scatter(
                    x=category_df["x"],
                    y=category_df["y"],
                    mode="markers",
                    name=category,
                    marker=dict(
                        color=settings["color"],
                        opacity=settings["opacity"],
                        symbol=settings["symbol"],
                        size=settings["size"],
                        line_width=0,
                    ),
                    # Show only the prepared text on hover, not the coordinates.
                    hoverinfo="text",
                    text=category_df["document_cleaned"],
                )
            )

        # NOTE(review): y=100 is far outside the usual normalized legend
        # range; presumably Plotly clamps it to the top - confirm.
        fig.update_layout(
            height=500,
            legend=dict(y=100, x=0.5, xanchor="center", yanchor="top", orientation="h"),
        )
        return fig

    def run(self, context: List[str], question: str) -> go.Figure:
        """Embed, project and plot `context` passages together with `question`.

        Args:
            context (List[str]): Texts of the passages to visualize. The UMAP
                projector is fitted on their embeddings.
            question (str): The user query, plotted as "Original Query".

        Returns:
            go.Figure: Scatter plot of the projected passages and query.
        """
        embed_contexts = self.embedding(context)
        context_embeddings = np.array([d.embedding for d in embed_contexts])

        # The projector is fitted on the contexts only; the query is projected
        # into that space afterwards.
        self.projector = self._set_up_umap(embeddings=context_embeddings)

        embed_query = self.embedding(question)
        query_x, query_y = self._get_projections(
            embeddings=[embed_query[0].embedding], umap_transform=self.projector
        )
        viz_query_df = pd.DataFrame(
            {
                "x": [query_x[0]],
                "y": [query_y[0]],
                "document_cleaned": question,
                "category": "Original Query",
                "size": 5,
            }
        )

        context_projections = self._get_projections(
            embeddings=context_embeddings, umap_transform=self.projector
        )
        viz_base_df = self._prepare_projection_df(
            document_projections=context_projections, document_text=context
        )

        # ignore_index avoids duplicate row labels in the combined frame.
        visualization_df = pd.concat(
            [viz_base_df, viz_query_df], axis=0, ignore_index=True
        )
        return self._plot_embeddings(visualization_df)
56 changes: 53 additions & 3 deletions libs/ktem/ktem/reasoning/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import tiktoken
from ktem.embeddings.manager import embedding_models_manager as embeddings
from ktem.llms.manager import llms
from ktem.reasoning.prompt_optimization import (
CreateMindmapPipeline,
Expand All @@ -16,6 +17,7 @@
)
from ktem.utils.plantuml import PlantUML
from ktem.utils.render import Render
from plotly.io import to_json
from theflow.settings import settings as flowsettings

from kotaemon.base import (
Expand All @@ -28,6 +30,7 @@
SystemMessage,
)
from kotaemon.indices.qa.citation import CitationPipeline
from kotaemon.indices.qa.visualize_cited import CreateCitationVizPipeline
from kotaemon.indices.splitters import TokenSplitter
from kotaemon.llms import ChatLLM, PromptTemplate

Expand Down Expand Up @@ -240,6 +243,7 @@ class AnswerWithContextPipeline(BaseComponent):

enable_citation: bool = False
enable_mindmap: bool = False
enable_citation_viz: bool = False

system_prompt: str = ""
lang: str = "English" # support English and Japanese
Expand Down Expand Up @@ -409,7 +413,12 @@ def mindmap_call():

answer = Document(
text=output,
metadata={"mindmap": mindmap, "citation": citation, "qa_score": qa_score},
metadata={
"citation_viz": self.enable_citation_viz,
"mindmap": mindmap,
"citation": citation,
"qa_score": qa_score,
},
)

return answer
Expand Down Expand Up @@ -474,6 +483,11 @@ class Config:
evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
rewrite_pipeline: RewriteQuestionPipeline | None = None
create_citation_viz_pipeline: CreateCitationVizPipeline = Node(
default_callback=lambda _: CreateCitationVizPipeline(
embedding=embeddings.get_default()
)
)
add_query_context: AddQueryContextPipeline = AddQueryContextPipeline.withx()

def retrieve(
Expand Down Expand Up @@ -641,10 +655,36 @@ def prepare_mindmap(self, answer) -> Document | None:

return mindmap_content

def show_citations_and_addons(self, answer, docs):
def prepare_citation_viz(self, answer, question, docs) -> Document | None:
    """Build the citation-visualization plot document, if applicable.

    Args:
        answer: Answer whose `metadata["citation_viz"]` flag says whether the
            visualization is enabled.
        question: The user question passed to the visualization pipeline.
        docs: Retrieved documents; their `.text` is what gets visualized.

    Returns:
        A `Document` on the "plot" channel holding the JSON-serialized
        figure, or None when the feature is disabled, fewer than two
        documents are available, or the figure could not be built.
    """
    if not (answer.metadata["citation_viz"] and len(docs) > 1):
        print(
            "The visualization feat was not enabled or "
            "the number of documents cited did not meet "
            "the presentation requirements."
        )
        return None

    doc_texts = [doc.text for doc in docs]
    citation_plot = None

    def citation_viz_call():
        nonlocal citation_plot
        citation_plot = self.create_citation_viz_pipeline(doc_texts, question)

    citation_plot_thread = threading.Thread(target=citation_viz_call)
    citation_plot_thread.start()
    citation_plot_thread.join()

    if citation_plot is None:
        # An exception raised inside the worker does not propagate through
        # `join()`, so the result simply stays None. Skip the plot instead
        # of passing None on to the serializer.
        print("Failed to create the citation visualization; skipping the plot.")
        return None

    return Document(channel="plot", content=to_json(citation_plot))

def show_citations_and_addons(self, answer, docs, question):
# show the evidence
with_citation, without_citation = self.prepare_citations(answer, docs)
mindmap_output = self.prepare_mindmap(answer)
citation_plot_output = self.prepare_citation_viz(answer, question, docs)

if not with_citation and not without_citation:
yield Document(channel="info", content="<h5><b>No evidence found.</b></h5>")
Expand All @@ -661,6 +701,10 @@ def show_citations_and_addons(self, answer, docs):
if mindmap_output:
yield mindmap_output

# yield citation plot output
if citation_plot_output:
yield citation_plot_output

# yield warning message
if has_llm_score and max_llm_rerank_score < CONTEXT_RELEVANT_WARNING_SCORE:
yield Document(
Expand Down Expand Up @@ -733,7 +777,7 @@ def generate_relevant_scores():
if scoring_thread:
scoring_thread.join()

yield from self.show_citations_and_addons(answer, docs)
yield from self.show_citations_and_addons(answer, docs, message)

return answer

Expand Down Expand Up @@ -767,6 +811,7 @@ def get_pipeline(cls, settings, states, retrievers):
answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
answer_pipeline.enable_mindmap = settings[f"{prefix}.create_mindmap"]
answer_pipeline.enable_citation_viz = settings[f"{prefix}.create_citation_viz"]
answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
answer_pipeline.qa_template = settings[f"{prefix}.qa_prompt"]
answer_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get(
Expand Down Expand Up @@ -820,6 +865,11 @@ def get_user_settings(cls) -> dict:
"value": False,
"component": "checkbox",
},
"create_citation_viz": {
"name": "Create Visualization of the retrieved docs",
"value": False,
"component": "checkbox",
},
"system_prompt": {
"name": "System Prompt",
"value": "This is a question answering system",
Expand Down

0 comments on commit cd39cc5

Please sign in to comment.