Skip to content

Commit

Permalink
Support synchronous chat response (#1331)
Browse files Browse the repository at this point in the history
* initial refactor of chat

* fixes

* test fixes

* add test

* moar

* tweaks

* words
  • Loading branch information
vangheem authored Sep 13, 2023
1 parent 36efaac commit 7884805
Show file tree
Hide file tree
Showing 10 changed files with 399 additions and 361 deletions.
148 changes: 86 additions & 62 deletions nucliadb/nucliadb/search/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,41 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
from typing import Union
import base64
from typing import Optional, Union

import pydantic
from fastapi import Body, Header, Request, Response
from fastapi_versioning import version
from starlette.responses import StreamingResponse

from nucliadb.models.responses import HTTPClientError
from nucliadb.search import predict
from nucliadb.search.api.v1.find import find
from nucliadb.search.api.v1.router import KB_PREFIX, api
from nucliadb.search.search.chat.query import chat, rephrase_query_from_chat_history
from nucliadb.search.predict import AnswerStatusCode
from nucliadb.search.search.chat.query import chat, get_relations_results
from nucliadb.search.search.exceptions import IncompleteFindResultsError
from nucliadb_models.resource import NucliaDBRoles
from nucliadb_models.search import (
ChatOptions,
ChatRequest,
FindRequest,
KnowledgeboxFindResults,
NucliaDBClientType,
SearchOptions,
Relations,
)
from nucliadb_utils.authentication import requires
from nucliadb_utils.exceptions import LimitsExceededError

END_OF_STREAM = "_END_"


class SyncChatResponse(pydantic.BaseModel):
answer: str
relations: Optional[Relations]
results: KnowledgeboxFindResults
status: AnswerStatusCode


CHAT_EXAMPLES = {
"search_and_chat": {
"summary": "Ask who won the league final",
Expand All @@ -63,16 +75,20 @@
@version(1)
async def chat_knowledgebox_endpoint(
request: Request,
response: Response,
kbid: str,
item: ChatRequest = Body(examples=CHAT_EXAMPLES),
x_ndb_client: NucliaDBClientType = Header(NucliaDBClientType.API),
x_nucliadb_user: str = Header(""),
x_forwarded_for: str = Header(""),
) -> Union[StreamingResponse, HTTPClientError]:
x_synchronous: bool = Header(
False,
description="Output response as JSON in a non-streaming way. "
"This is slower and requires waiting for entire answer to be ready.",
),
) -> Union[StreamingResponse, HTTPClientError, Response]:
try:
return await chat_knowledgebox(
response, kbid, item, x_ndb_client, x_nucliadb_user, x_forwarded_for
return await create_chat_response(
kbid, item, x_nucliadb_user, x_ndb_client, x_forwarded_for, x_synchronous
)
except LimitsExceededError as exc:
return HTTPClientError(status_code=exc.status_code, detail=exc.detail)
Expand All @@ -95,60 +111,68 @@ async def chat_knowledgebox_endpoint(
)


async def chat_knowledgebox(
response: Response,
async def create_chat_response(
kbid: str,
item: ChatRequest,
x_ndb_client: NucliaDBClientType,
x_nucliadb_user: str,
x_forwarded_for: str,
):
user_query = item.query
rephrased_query = None
if item.context is not None and len(item.context) > 0:
rephrased_query = await rephrase_query_from_chat_history(
kbid, item.context, item.query, x_nucliadb_user
)

find_request = FindRequest()
find_request.features = [SearchOptions.VECTOR]
if ChatOptions.PARAGRAPHS in item.features:
find_request.features.append(SearchOptions.PARAGRAPH)
find_request.query = rephrased_query or user_query
find_request.fields = item.fields
find_request.filters = item.filters
find_request.field_type_filter = item.field_type_filter
find_request.min_score = item.min_score
find_request.range_creation_start = item.range_creation_start
find_request.range_creation_end = item.range_creation_end
find_request.range_modification_start = item.range_modification_start
find_request.range_modification_end = item.range_modification_end
find_request.show = item.show
find_request.extracted = item.extracted
find_request.shards = item.shards
find_request.autofilter = item.autofilter
find_request.highlight = item.highlight
find_request.with_duplicates = False

find_results, incomplete = await find(
response,
chat_request: ChatRequest,
user_id: str,
client_type: NucliaDBClientType,
origin: str,
x_synchronous: bool,
) -> Response:
chat_result = await chat(
kbid,
find_request,
x_ndb_client,
x_nucliadb_user,
x_forwarded_for,
do_audit=True,
chat_request,
user_id,
client_type,
origin,
)
if incomplete:
raise IncompleteFindResultsError()
if x_synchronous:
text_answer = b""
async for chunk in chat_result.answer_stream:
text_answer += chunk

return await chat(
kbid,
user_query,
rephrased_query,
find_results,
item,
x_nucliadb_user,
x_ndb_client,
x_forwarded_for,
)
relations_results = None
if ChatOptions.RELATIONS in chat_request.features:
relations_results = await get_relations_results(
kbid=kbid, chat_request=chat_request, text_answer=text_answer
)

return Response(
content=SyncChatResponse(
answer=text_answer.decode(),
relations=relations_results,
results=chat_result.find_results,
status=chat_result.status_code.value,
).json(),
headers={
"NUCLIA-LEARNING-ID": chat_result.nuclia_learning_id or "unknown",
"Access-Control-Expose-Headers": "NUCLIA-LEARNING-ID",
},
)
else:

async def _streaming_response():
bytes_results = base64.b64encode(chat_result.find_results.json().encode())
yield len(bytes_results).to_bytes(length=4, byteorder="big", signed=False)
yield bytes_results

text_answer = b""
async for chunk in chat_result.answer_stream:
text_answer += chunk
yield chunk

yield END_OF_STREAM.encode()
if ChatOptions.RELATIONS in chat_request.features:
relations_results = await get_relations_results(
kbid=kbid, chat_request=chat_request, text_answer=text_answer
)
yield base64.b64encode(relations_results.json().encode())

return StreamingResponse(
_streaming_response(),
media_type="application/octet-stream",
headers={
"NUCLIA-LEARNING-ID": chat_result.nuclia_learning_id or "unknown",
"Access-Control-Expose-Headers": "NUCLIA-LEARNING-ID",
},
)
94 changes: 4 additions & 90 deletions nucliadb/nucliadb/search/api/v1/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
#
import json
from datetime import datetime
from time import time
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Union

from fastapi import Body, Header, Request, Response
from fastapi_versioning import version
Expand All @@ -30,10 +29,7 @@
from nucliadb.models.responses import HTTPClientError
from nucliadb.search.api.v1.router import KB_PREFIX, api
from nucliadb.search.api.v1.utils import fastapi_query
from nucliadb.search.requesters.utils import Method, node_query
from nucliadb.search.search.find_merge import find_merge_results
from nucliadb.search.search.query import get_default_min_score, global_query_to_pb
from nucliadb.search.search.utils import should_disable_vector_search
from nucliadb.search.search.find import find
from nucliadb_models.common import FieldTypeName
from nucliadb_models.resource import ExtractedDataTypeName, NucliaDBRoles
from nucliadb_models.search import (
Expand All @@ -46,7 +42,6 @@
)
from nucliadb_utils.authentication import requires
from nucliadb_utils.exceptions import LimitsExceededError
from nucliadb_utils.utilities import get_audit

FIND_EXAMPLES = {
"find_hybrid_search": {
Expand Down Expand Up @@ -147,7 +142,7 @@ async def find_knowledgebox(
return HTTPClientError(status_code=422, detail=detail)
try:
results, _ = await find(
response, kbid, item, x_ndb_client, x_nucliadb_user, x_forwarded_for
kbid, item, x_ndb_client, x_nucliadb_user, x_forwarded_for
)
return results
except KnowledgeBoxNotFound:
Expand Down Expand Up @@ -178,90 +173,9 @@ async def find_post_knowledgebox(
) -> Union[KnowledgeboxFindResults, HTTPClientError]:
try:
results, incomplete = await find(
response, kbid, item, x_ndb_client, x_nucliadb_user, x_forwarded_for
kbid, item, x_ndb_client, x_nucliadb_user, x_forwarded_for
)
response.status_code = 206 if incomplete else 200
return results
except LimitsExceededError as exc:
return HTTPClientError(status_code=exc.status_code, detail=exc.detail)


async def find(
response: Response,
kbid: str,
item: FindRequest,
x_ndb_client: NucliaDBClientType,
x_nucliadb_user: str,
x_forwarded_for: str,
do_audit: bool = True,
) -> Tuple[KnowledgeboxFindResults, bool]:
audit = get_audit()
start_time = time()

if SearchOptions.VECTOR in item.features:
if should_disable_vector_search(item):
item.features.remove(SearchOptions.VECTOR)

min_score = item.min_score
if min_score is None:
min_score = await get_default_min_score(kbid)

# We need to query all nodes
pb_query, incomplete_results, autofilters = await global_query_to_pb(
kbid,
features=item.features,
query=item.query,
filters=item.filters,
faceted=item.faceted,
sort=None,
page_number=item.page_number,
page_size=item.page_size,
min_score=min_score,
range_creation_start=item.range_creation_start,
range_creation_end=item.range_creation_end,
range_modification_start=item.range_modification_start,
range_modification_end=item.range_modification_end,
fields=item.fields,
user_vector=item.vector,
vectorset=item.vectorset,
with_duplicates=item.with_duplicates,
with_synonyms=item.with_synonyms,
autofilter=item.autofilter,
key_filters=item.resource_filters,
)
results, query_incomplete_results, queried_nodes, queried_shards = await node_query(
kbid, Method.SEARCH, pb_query, item.shards
)

incomplete_results = incomplete_results or query_incomplete_results

# We need to merge
search_results = await find_merge_results(
results,
count=item.page_size,
page=item.page_number,
kbid=kbid,
show=item.show,
field_type_filter=item.field_type_filter,
extracted=item.extracted,
requested_relations=pb_query.relation_subgraph,
min_score=min_score,
highlight=item.highlight,
)

if audit is not None and do_audit:
await audit.search(
kbid,
x_nucliadb_user,
x_ndb_client.to_proto(),
x_forwarded_for,
pb_query,
time() - start_time,
len(search_results.resources),
)
if item.debug:
search_results.nodes = queried_nodes

search_results.shards = queried_shards
search_results.autofilters = autofilters
return search_results, incomplete_results
Loading

1 comment on commit 7884805

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 7884805 Previous: 7e0355d Ratio
nucliadb/tests/benchmarks/test_search.py::test_search_returns_labels 57.369426309019616 iter/sec (stddev: 0.000685873958150432) 56.953839502297 iter/sec (stddev: 0.0005406884968190274) 0.99
nucliadb/tests/benchmarks/test_search.py::test_search_relations 151.13798854926245 iter/sec (stddev: 0.0004635463620446412) 146.52385643519392 iter/sec (stddev: 0.00046154405885098483) 0.97

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.