diff --git a/app/backend/approaches/approach.py b/app/backend/approaches/approach.py
index 894be43f50..900cf7dfb2 100644
--- a/app/backend/approaches/approach.py
+++ b/app/backend/approaches/approach.py
@@ -123,6 +123,8 @@ async def search(
         vectors: List[VectorQuery],
         use_semantic_ranker: bool,
         use_semantic_captions: bool,
+        minimum_search_score: Optional[float],
+        minimum_reranker_score: Optional[float],
     ) -> List[Document]:
         # Use semantic ranker if requested and if retrieval mode is text or hybrid (vectors + text)
         if use_semantic_ranker and query_text:
@@ -161,7 +163,17 @@ async def search(
                         reranker_score=document.get("@search.reranker_score"),
                     )
                 )
-        return documents
+
+        qualified_documents = [
+            doc
+            for doc in documents
+            if (
+                (doc.score or 0) >= (minimum_search_score or 0)
+                and (doc.reranker_score or 0) >= (minimum_reranker_score or 0)
+            )
+        ]
+
+        return qualified_documents
 
     def get_sources_content(
         self, results: List[Document], use_semantic_captions: bool, use_image_citation: bool
diff --git a/app/backend/approaches/chatreadretrieveread.py b/app/backend/approaches/chatreadretrieveread.py
index 9cc6ad8abe..a707157fcf 100644
--- a/app/backend/approaches/chatreadretrieveread.py
+++ b/app/backend/approaches/chatreadretrieveread.py
@@ -89,6 +89,9 @@ async def run_until_final_call(
         has_vector = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
         use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
         top = overrides.get("top", 3)
+        minimum_search_score = overrides.get("minimum_search_score", 0.0)
+        minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
+
         filter = self.build_filter(overrides, auth_claims)
         use_semantic_ranker = True if overrides.get("semantic_ranker") and has_text else False
@@ -149,7 +152,16 @@ async def run_until_final_call(
         if not has_text:
             query_text = None
 
-        results = await self.search(top, query_text, filter, vectors, use_semantic_ranker, use_semantic_captions)
+        results = await self.search(
+            top,
+            query_text,
+            filter,
+            vectors,
+            use_semantic_ranker,
+            use_semantic_captions,
+            minimum_search_score,
+            minimum_reranker_score,
+        )
 
         sources_content = self.get_sources_content(results, use_semantic_captions, use_image_citation=False)
         content = "\n".join(sources_content)
diff --git a/app/backend/approaches/chatreadretrievereadvision.py b/app/backend/approaches/chatreadretrievereadvision.py
index dfaa80c32f..190fb9b823 100644
--- a/app/backend/approaches/chatreadretrievereadvision.py
+++ b/app/backend/approaches/chatreadretrievereadvision.py
@@ -87,6 +87,8 @@ async def run_until_final_call(
         vector_fields = overrides.get("vector_fields", ["embedding"])
         use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
         top = overrides.get("top", 3)
+        minimum_search_score = overrides.get("minimum_search_score", 0.0)
+        minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
         filter = self.build_filter(overrides, auth_claims)
         use_semantic_ranker = True if overrides.get("semantic_ranker") and has_text else False
@@ -134,7 +136,16 @@ async def run_until_final_call(
         if not has_text:
             query_text = None
 
-        results = await self.search(top, query_text, filter, vectors, use_semantic_ranker, use_semantic_captions)
+        results = await self.search(
+            top,
+            query_text,
+            filter,
+            vectors,
+            use_semantic_ranker,
+            use_semantic_captions,
+            minimum_search_score,
+            minimum_reranker_score,
+        )
 
         sources_content = self.get_sources_content(results, use_semantic_captions, use_image_citation=True)
         content = "\n".join(sources_content)
diff --git a/app/backend/approaches/retrievethenread.py b/app/backend/approaches/retrievethenread.py
index 80cfe2ee02..3860abb02b 100644
--- a/app/backend/approaches/retrievethenread.py
+++ b/app/backend/approaches/retrievethenread.py
@@ -86,6 +86,8 @@ async def run(
 
         use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
         top = overrides.get("top", 3)
+        minimum_search_score = overrides.get("minimum_search_score", 0.0)
+        minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
         filter = self.build_filter(overrides, auth_claims)
         # If retrieval mode includes vectors, compute an embedding for the query
         vectors: list[VectorQuery] = []
@@ -95,7 +97,16 @@ async def run(
         # Only keep the text query if the retrieval mode uses text, otherwise drop it
         query_text = q if has_text else None
 
-        results = await self.search(top, query_text, filter, vectors, use_semantic_ranker, use_semantic_captions)
+        results = await self.search(
+            top,
+            query_text,
+            filter,
+            vectors,
+            use_semantic_ranker,
+            use_semantic_captions,
+            minimum_search_score,
+            minimum_reranker_score,
+        )
 
         user_content = [q]
diff --git a/app/backend/approaches/retrievethenreadvision.py b/app/backend/approaches/retrievethenreadvision.py
index 4c3d4c3e73..b6fb1e105a 100644
--- a/app/backend/approaches/retrievethenreadvision.py
+++ b/app/backend/approaches/retrievethenreadvision.py
@@ -89,6 +89,8 @@ async def run(
 
         use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
         top = overrides.get("top", 3)
+        minimum_search_score = overrides.get("minimum_search_score", 0.0)
+        minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
         filter = self.build_filter(overrides, auth_claims)
         use_semantic_ranker = overrides.get("semantic_ranker") and has_text
@@ -107,7 +109,16 @@ async def run(
         # Only keep the text query if the retrieval mode uses text, otherwise drop it
         query_text = q if has_text else None
 
-        results = await self.search(top, query_text, filter, vectors, use_semantic_ranker, use_semantic_captions)
+        results = await self.search(
+            top,
+            query_text,
+            filter,
+            vectors,
+            use_semantic_ranker,
+            use_semantic_captions,
+            minimum_search_score,
+            minimum_reranker_score,
+        )
 
         image_list: list[ChatCompletionContentPartImageParam] = []
         user_content: list[ChatCompletionContentPartParam] = [{"text": q, "type": "text"}]
diff --git a/app/frontend/src/api/models.ts b/app/frontend/src/api/models.ts
index f0bb70a3fe..56e92a1fcb 100644
--- a/app/frontend/src/api/models.ts
+++ b/app/frontend/src/api/models.ts
@@ -23,6 +23,8 @@ export type ChatAppRequestOverrides = {
     exclude_category?: string;
     top?: number;
     temperature?: number;
+    minimum_search_score?: number;
+    minimum_reranker_score?: number;
     prompt_template?: string;
     prompt_template_prefix?: string;
     prompt_template_suffix?: string;
diff --git a/app/frontend/src/pages/ask/Ask.module.css b/app/frontend/src/pages/ask/Ask.module.css
index 22880d8a3a..ac280baa17 100644
--- a/app/frontend/src/pages/ask/Ask.module.css
+++ b/app/frontend/src/pages/ask/Ask.module.css
@@ -56,6 +56,8 @@
 }
 
 .askSettingsSeparator {
+    display: flex;
+    flex-direction: column;
     margin-top: 15px;
 }
 
diff --git a/app/frontend/src/pages/ask/Ask.tsx b/app/frontend/src/pages/ask/Ask.tsx
index 804710223b..713459d29d 100644
--- a/app/frontend/src/pages/ask/Ask.tsx
+++ b/app/frontend/src/pages/ask/Ask.tsx
@@ -22,6 +22,8 @@ export function Component(): JSX.Element {
     const [promptTemplatePrefix, setPromptTemplatePrefix] = useState("");
     const [promptTemplateSuffix, setPromptTemplateSuffix] = useState("");
     const [temperature, setTemperature] = useState(0.3);
+    const [minimumRerankerScore, setMinimumRerankerScore] = useState(0);
+    const [minimumSearchScore, setMinimumSearchScore] = useState(0);
     const [retrievalMode, setRetrievalMode] = useState(RetrievalMode.Hybrid);
     const [retrieveCount, setRetrieveCount] = useState(3);
     const [useSemanticRanker, setUseSemanticRanker] = useState(true);
@@ -92,6 +94,8 @@ export function Component(): JSX.Element {
                     exclude_category: excludeCategory.length === 0 ? undefined : excludeCategory,
                     top: retrieveCount,
                     temperature: temperature,
+                    minimum_reranker_score: minimumRerankerScore,
+                    minimum_search_score: minimumSearchScore,
                     retrieval_mode: retrievalMode,
                     semantic_ranker: useSemanticRanker,
                     semantic_captions: useSemanticCaptions,
@@ -134,6 +138,13 @@ export function Component(): JSX.Element {
         setTemperature(newValue);
     };
 
+    const onMinimumSearchScoreChange = (_ev?: React.SyntheticEvent, newValue?: string) => {
+        setMinimumSearchScore(parseFloat(newValue || "0"));
+    };
+
+    const onMinimumRerankerScoreChange = (_ev?: React.SyntheticEvent, newValue?: string) => {
+        setMinimumRerankerScore(parseFloat(newValue || "0"));
+    };
     const onRetrieveCountChange = (_ev?: React.SyntheticEvent, newValue?: string) => {
         setRetrieveCount(parseInt(newValue || "3"));
     };
@@ -259,6 +270,25 @@ export function Component(): JSX.Element {
                         snapToStep
                     />
+
+
+
+
diff --git a/app/frontend/src/pages/chat/Chat.tsx b/app/frontend/src/pages/chat/Chat.tsx
--- a/app/frontend/src/pages/chat/Chat.tsx
+++ b/app/frontend/src/pages/chat/Chat.tsx
@@ ... @@ const Chat = () => {
     const [isConfigPanelOpen, setIsConfigPanelOpen] = useState(false);
     const [promptTemplate, setPromptTemplate] = useState("");
     const [temperature, setTemperature] = useState(0.3);
+    const [minimumRerankerScore, setMinimumRerankerScore] = useState(0);
+    const [minimumSearchScore, setMinimumSearchScore] = useState(0);
     const [retrieveCount, setRetrieveCount] = useState(3);
     const [retrievalMode, setRetrievalMode] = useState(RetrievalMode.Hybrid);
     const [useSemanticRanker, setUseSemanticRanker] = useState(true);
@@ -147,6 +149,8 @@ const Chat = () => {
                     exclude_category: excludeCategory.length === 0 ? undefined : excludeCategory,
                     top: retrieveCount,
                     temperature: temperature,
+                    minimum_reranker_score: minimumRerankerScore,
+                    minimum_search_score: minimumSearchScore,
                     retrieval_mode: retrievalMode,
                     semantic_ranker: useSemanticRanker,
                     semantic_captions: useSemanticCaptions,
@@ -212,6 +216,14 @@ const Chat = () => {
         setTemperature(newValue);
     };
 
+    const onMinimumSearchScoreChange = (_ev?: React.SyntheticEvent, newValue?: string) => {
+        setMinimumSearchScore(parseFloat(newValue || "0"));
+    };
+
+    const onMinimumRerankerScoreChange = (_ev?: React.SyntheticEvent, newValue?: string) => {
+        setMinimumRerankerScore(parseFloat(newValue || "0"));
+    };
+
     const onRetrieveCountChange = (_ev?: React.SyntheticEvent, newValue?: string) => {
         setRetrieveCount(parseInt(newValue || "3"));
     };
@@ -395,6 +407,25 @@ const Chat = () => {
                         snapToStep
                     />
+
+
+
+
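The final hunks in Ask.tsx and Chat.tsx (the bare "+" lines after the snapToStep / "/>" context) add the settings-panel controls for the two new overrides; the JSX markup itself is not reproduced above. Below is a minimal sketch of what the Ask.tsx addition presumably looks like, reusing the minimumSearchScore / minimumRerankerScore state and the onMinimumSearchScoreChange / onMinimumRerankerScoreChange handlers defined earlier in this diff, and assuming the same Fluent UI SpinButton component used for the existing retrieve-count control. The labels and the min/max/step values are illustrative assumptions, not taken from the diff.

    {/* Sketch only: label, min, max, and step are assumptions, not taken from the diff */}
    <SpinButton
        className={styles.askSettingsSeparator}
        label="Minimum search score"
        min={0}
        step={0.01}
        defaultValue={minimumSearchScore.toString()}
        onChange={onMinimumSearchScoreChange}
    />
    <SpinButton
        className={styles.askSettingsSeparator}
        label="Minimum reranker score"
        min={1}
        max={4}
        step={0.1}
        defaultValue={minimumRerankerScore.toString()}
        onChange={onMinimumRerankerScoreChange}
    />

Chat.tsx presumably mirrors this with its own settings-panel class. End to end, the values travel as the optional minimum_search_score and minimum_reranker_score fields of ChatAppRequestOverrides (models.ts), are read from overrides in each approach with a default of 0.0, and are applied in Approach.search, which now returns only documents whose score and reranker_score meet both thresholds (missing scores and unset thresholds are treated as 0).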