Skip to content

Commit

Permalink
Add minimum score criteria for AI search results (#1417)
Browse files Browse the repository at this point in the history
* Add minimum score criteria for AI search results

* Adjust input to support precise filtering in different search modes.

* Resolve comparison issue

* Update class style

* Fix parsing

* Add test

* Lint

* Format

* Fix tests

---------

Co-authored-by: Pamela Fox <pamela.fox@gmail.com>
  • Loading branch information
sogue and pamelafox authored Mar 20, 2024
1 parent 40e9887 commit ccf2494
Show file tree
Hide file tree
Showing 11 changed files with 186 additions and 5 deletions.
14 changes: 13 additions & 1 deletion app/backend/approaches/approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ async def search(
vectors: List[VectorQuery],
use_semantic_ranker: bool,
use_semantic_captions: bool,
minimum_search_score: Optional[float],
minimum_reranker_score: Optional[float],
) -> List[Document]:
# Use semantic ranker if requested and if retrieval mode is text or hybrid (vectors + text)
if use_semantic_ranker and query_text:
Expand Down Expand Up @@ -161,7 +163,17 @@ async def search(
reranker_score=document.get("@search.reranker_score"),
)
)
return documents

qualified_documents = [
doc
for doc in documents
if (
(doc.score or 0) >= (minimum_search_score or 0)
and (doc.reranker_score or 0) >= (minimum_reranker_score or 0)
)
]

return qualified_documents

def get_sources_content(
self, results: List[Document], use_semantic_captions: bool, use_image_citation: bool
Expand Down
14 changes: 13 additions & 1 deletion app/backend/approaches/chatreadretrieveread.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ async def run_until_final_call(
has_vector = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
top = overrides.get("top", 3)
minimum_search_score = overrides.get("minimum_search_score", 0.0)
minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)

filter = self.build_filter(overrides, auth_claims)
use_semantic_ranker = True if overrides.get("semantic_ranker") and has_text else False

Expand Down Expand Up @@ -149,7 +152,16 @@ async def run_until_final_call(
if not has_text:
query_text = None

results = await self.search(top, query_text, filter, vectors, use_semantic_ranker, use_semantic_captions)
results = await self.search(
top,
query_text,
filter,
vectors,
use_semantic_ranker,
use_semantic_captions,
minimum_search_score,
minimum_reranker_score,
)

sources_content = self.get_sources_content(results, use_semantic_captions, use_image_citation=False)
content = "\n".join(sources_content)
Expand Down
13 changes: 12 additions & 1 deletion app/backend/approaches/chatreadretrievereadvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ async def run_until_final_call(
vector_fields = overrides.get("vector_fields", ["embedding"])
use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
top = overrides.get("top", 3)
minimum_search_score = overrides.get("minimum_search_score", 0.0)
minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
filter = self.build_filter(overrides, auth_claims)
use_semantic_ranker = True if overrides.get("semantic_ranker") and has_text else False

Expand Down Expand Up @@ -134,7 +136,16 @@ async def run_until_final_call(
if not has_text:
query_text = None

results = await self.search(top, query_text, filter, vectors, use_semantic_ranker, use_semantic_captions)
results = await self.search(
top,
query_text,
filter,
vectors,
use_semantic_ranker,
use_semantic_captions,
minimum_search_score,
minimum_reranker_score,
)
sources_content = self.get_sources_content(results, use_semantic_captions, use_image_citation=True)
content = "\n".join(sources_content)

Expand Down
13 changes: 12 additions & 1 deletion app/backend/approaches/retrievethenread.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ async def run(

use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
top = overrides.get("top", 3)
minimum_search_score = overrides.get("minimum_search_score", 0.0)
minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
filter = self.build_filter(overrides, auth_claims)
# If retrieval mode includes vectors, compute an embedding for the query
vectors: list[VectorQuery] = []
Expand All @@ -95,7 +97,16 @@ async def run(
# Only keep the text query if the retrieval mode uses text, otherwise drop it
query_text = q if has_text else None

results = await self.search(top, query_text, filter, vectors, use_semantic_ranker, use_semantic_captions)
results = await self.search(
top,
query_text,
filter,
vectors,
use_semantic_ranker,
use_semantic_captions,
minimum_search_score,
minimum_reranker_score,
)

user_content = [q]

Expand Down
13 changes: 12 additions & 1 deletion app/backend/approaches/retrievethenreadvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ async def run(

use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
top = overrides.get("top", 3)
minimum_search_score = overrides.get("minimum_search_score", 0.0)
minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
filter = self.build_filter(overrides, auth_claims)
use_semantic_ranker = overrides.get("semantic_ranker") and has_text

Expand All @@ -107,7 +109,16 @@ async def run(
# Only keep the text query if the retrieval mode uses text, otherwise drop it
query_text = q if has_text else None

results = await self.search(top, query_text, filter, vectors, use_semantic_ranker, use_semantic_captions)
results = await self.search(
top,
query_text,
filter,
vectors,
use_semantic_ranker,
use_semantic_captions,
minimum_search_score,
minimum_reranker_score,
)

image_list: list[ChatCompletionContentPartImageParam] = []
user_content: list[ChatCompletionContentPartParam] = [{"text": q, "type": "text"}]
Expand Down
2 changes: 2 additions & 0 deletions app/frontend/src/api/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ export type ChatAppRequestOverrides = {
exclude_category?: string;
top?: number;
temperature?: number;
minimum_search_score?: number;
minimum_reranker_score?: number;
prompt_template?: string;
prompt_template_prefix?: string;
prompt_template_suffix?: string;
Expand Down
2 changes: 2 additions & 0 deletions app/frontend/src/pages/ask/Ask.module.css
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
}

.askSettingsSeparator {
display: flex;
flex-direction: column;
margin-top: 15px;
}

Expand Down
30 changes: 30 additions & 0 deletions app/frontend/src/pages/ask/Ask.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ export function Component(): JSX.Element {
const [promptTemplatePrefix, setPromptTemplatePrefix] = useState<string>("");
const [promptTemplateSuffix, setPromptTemplateSuffix] = useState<string>("");
const [temperature, setTemperature] = useState<number>(0.3);
const [minimumRerankerScore, setMinimumRerankerScore] = useState<number>(0);
const [minimumSearchScore, setMinimumSearchScore] = useState<number>(0);
const [retrievalMode, setRetrievalMode] = useState<RetrievalMode>(RetrievalMode.Hybrid);
const [retrieveCount, setRetrieveCount] = useState<number>(3);
const [useSemanticRanker, setUseSemanticRanker] = useState<boolean>(true);
Expand Down Expand Up @@ -92,6 +94,8 @@ export function Component(): JSX.Element {
exclude_category: excludeCategory.length === 0 ? undefined : excludeCategory,
top: retrieveCount,
temperature: temperature,
minimum_reranker_score: minimumRerankerScore,
minimum_search_score: minimumSearchScore,
retrieval_mode: retrievalMode,
semantic_ranker: useSemanticRanker,
semantic_captions: useSemanticCaptions,
Expand Down Expand Up @@ -134,6 +138,13 @@ export function Component(): JSX.Element {
setTemperature(newValue);
};

// Handles edits to the "Minimum search score" SpinButton.
// The text is parsed as a float; an empty, missing or non-numeric value falls back to 0
// so the state (forwarded as the `minimum_search_score` override) never holds NaN/Infinity.
const onMinimumSearchScoreChange = (_ev?: React.SyntheticEvent<HTMLElement, Event>, newValue?: string) => {
    const parsedScore = parseFloat(newValue || "0");
    setMinimumSearchScore(Number.isFinite(parsedScore) ? parsedScore : 0);
};

// Handles edits to the "Minimum reranker score" SpinButton.
// The text is parsed as a float; an empty, missing or non-numeric value falls back to 0
// so the state (forwarded as the `minimum_reranker_score` override) never holds NaN/Infinity.
const onMinimumRerankerScoreChange = (_ev?: React.SyntheticEvent<HTMLElement, Event>, newValue?: string) => {
    const parsedScore = parseFloat(newValue || "0");
    setMinimumRerankerScore(Number.isFinite(parsedScore) ? parsedScore : 0);
};
// Handles edits to the "Retrieve this many search results:" SpinButton.
// An empty or missing value is treated as "3" before being parsed as an integer.
const onRetrieveCountChange = (_ev?: React.SyntheticEvent<HTMLElement, Event>, newValue?: string) => {
    const rawCount = newValue || "3";
    const parsedCount = parseInt(rawCount);
    setRetrieveCount(parsedCount);
};
Expand Down Expand Up @@ -259,6 +270,25 @@ export function Component(): JSX.Element {
snapToStep
/>

<SpinButton
className={styles.askSettingsSeparator}
label="Minimum search score"
min={0}
step={0.01}
defaultValue={minimumSearchScore.toString()}
onChange={onMinimumSearchScoreChange}
/>

<SpinButton
className={styles.askSettingsSeparator}
label="Minimum reranker score"
min={1}
max={4}
step={0.1}
defaultValue={minimumRerankerScore.toString()}
onChange={onMinimumRerankerScoreChange}
/>

<SpinButton
className={styles.askSettingsSeparator}
label="Retrieve this many search results:"
Expand Down
2 changes: 2 additions & 0 deletions app/frontend/src/pages/chat/Chat.module.css
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@
}

.chatSettingsSeparator {
display: flex;
flex-direction: column;
margin-top: 15px;
}

Expand Down
31 changes: 31 additions & 0 deletions app/frontend/src/pages/chat/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ const Chat = () => {
const [isConfigPanelOpen, setIsConfigPanelOpen] = useState(false);
const [promptTemplate, setPromptTemplate] = useState<string>("");
const [temperature, setTemperature] = useState<number>(0.3);
const [minimumRerankerScore, setMinimumRerankerScore] = useState<number>(0);
const [minimumSearchScore, setMinimumSearchScore] = useState<number>(0);
const [retrieveCount, setRetrieveCount] = useState<number>(3);
const [retrievalMode, setRetrievalMode] = useState<RetrievalMode>(RetrievalMode.Hybrid);
const [useSemanticRanker, setUseSemanticRanker] = useState<boolean>(true);
Expand Down Expand Up @@ -147,6 +149,8 @@ const Chat = () => {
exclude_category: excludeCategory.length === 0 ? undefined : excludeCategory,
top: retrieveCount,
temperature: temperature,
minimum_reranker_score: minimumRerankerScore,
minimum_search_score: minimumSearchScore,
retrieval_mode: retrievalMode,
semantic_ranker: useSemanticRanker,
semantic_captions: useSemanticCaptions,
Expand Down Expand Up @@ -212,6 +216,14 @@ const Chat = () => {
setTemperature(newValue);
};

// Handles edits to the "Minimum search score" SpinButton.
// The text is parsed as a float; an empty, missing or non-numeric value falls back to 0
// so the state (forwarded as the `minimum_search_score` override) never holds NaN/Infinity.
const onMinimumSearchScoreChange = (_ev?: React.SyntheticEvent<HTMLElement, Event>, newValue?: string) => {
    const parsedScore = parseFloat(newValue || "0");
    setMinimumSearchScore(Number.isFinite(parsedScore) ? parsedScore : 0);
};

// Handles edits to the "Minimum reranker score" SpinButton.
// The text is parsed as a float; an empty, missing or non-numeric value falls back to 0
// so the state (forwarded as the `minimum_reranker_score` override) never holds NaN/Infinity.
const onMinimumRerankerScoreChange = (_ev?: React.SyntheticEvent<HTMLElement, Event>, newValue?: string) => {
    const parsedScore = parseFloat(newValue || "0");
    setMinimumRerankerScore(Number.isFinite(parsedScore) ? parsedScore : 0);
};

// Handles edits to the "Retrieve this many search results:" SpinButton.
// An empty or missing value is treated as "3" before being parsed as an integer.
const onRetrieveCountChange = (_ev?: React.SyntheticEvent<HTMLElement, Event>, newValue?: string) => {
    const rawCount = newValue || "3";
    const parsedCount = parseInt(rawCount);
    setRetrieveCount(parsedCount);
};
Expand Down Expand Up @@ -395,6 +407,25 @@ const Chat = () => {
snapToStep
/>

<SpinButton
className={styles.chatSettingsSeparator}
label="Minimum search score"
min={0}
step={0.01}
defaultValue={minimumSearchScore.toString()}
onChange={onMinimumSearchScoreChange}
/>

<SpinButton
className={styles.chatSettingsSeparator}
label="Minimum reranker score"
min={1}
max={4}
step={0.1}
defaultValue={minimumRerankerScore.toString()}
onChange={onMinimumRerankerScoreChange}
/>

<SpinButton
className={styles.chatSettingsSeparator}
label="Retrieve this many search results:"
Expand Down
57 changes: 57 additions & 0 deletions tests/test_chatapproach.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import json

import pytest
from azure.core.credentials import AzureKeyCredential
from azure.search.documents.aio import SearchClient
from openai.types.chat import ChatCompletion

from approaches.chatreadretrieveread import ChatReadRetrieveReadApproach

from .mocks import MockAsyncSearchResultsIterator


async def mock_search(*args, **kwargs):
    """Return a MockAsyncSearchResultsIterator built from the search keyword arguments.

    Only the ``search_text`` and ``vector_queries`` keyword arguments are read;
    positional arguments and any other keywords are ignored.
    """
    search_text = kwargs.get("search_text")
    vector_queries = kwargs.get("vector_queries")
    return MockAsyncSearchResultsIterator(search_text, vector_queries)


@pytest.fixture
def chat_approach():
Expand Down Expand Up @@ -297,3 +305,52 @@ def test_get_messages_from_history_few_shots(chat_approach):
assert messages[4]["role"] == "assistant"
assert messages[5]["role"] == "user"
assert messages[5]["content"] == user_query_request


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "minimum_search_score,minimum_reranker_score,expected_result_count",
    [
        (0, 0, 1),
        (0, 2, 1),
        (0.03, 0, 1),
        (0.03, 2, 1),
        (1, 0, 0),
        (0, 4, 0),
        (1, 4, 0),
    ],
)
async def test_search_results_filtering_by_scores(
    monkeypatch, minimum_search_score, minimum_reranker_score, expected_result_count
):
    """Verify that ``search`` returns the expected number of results per score threshold.

    Each parametrized case passes a minimum search score and a minimum reranker
    score to ``ChatReadRetrieveReadApproach.search`` and checks how many of the
    mocked search results come back.
    """
    # Route every SearchClient.search call to the mock so no service is contacted.
    monkeypatch.setattr(SearchClient, "search", mock_search)

    search_client = SearchClient(endpoint="", index_name="", credential=AzureKeyCredential(""))
    approach = ChatReadRetrieveReadApproach(
        search_client=search_client,
        auth_helper=None,
        openai_client=None,
        chatgpt_model="gpt-35-turbo",
        chatgpt_deployment="chat",
        embedding_deployment="embeddings",
        embedding_model="text-",
        sourcepage_field="",
        content_field="",
        query_language="en-us",
        query_speller="lexicon",
    )

    results = await approach.search(
        top=10,
        query_text="test query",
        filter=None,
        vectors=[],
        use_semantic_ranker=True,
        use_semantic_captions=True,
        minimum_search_score=minimum_search_score,
        minimum_reranker_score=minimum_reranker_score,
    )
    result_count = len(results)

    assert (
        result_count == expected_result_count
    ), f"Expected {expected_result_count} results with minimum_search_score={minimum_search_score} and minimum_reranker_score={minimum_reranker_score}"

0 comments on commit ccf2494

Please sign in to comment.