
new status code for not enough data response (#2801)
drf7 authored Jan 27, 2025
1 parent ecb662f commit fc93c48
Showing 12 changed files with 83 additions and 35 deletions.
2 changes: 2 additions & 0 deletions nucliadb/src/nucliadb/search/predict.py
@@ -121,12 +121,14 @@ class AnswerStatusCode(str, Enum):
SUCCESS = "0"
ERROR = "-1"
NO_CONTEXT = "-2"
NO_RETRIEVAL_DATA = "-3"

def prettify(self) -> str:
return {
AnswerStatusCode.SUCCESS: "success",
AnswerStatusCode.ERROR: "error",
AnswerStatusCode.NO_CONTEXT: "no_context",
AnswerStatusCode.NO_RETRIEVAL_DATA: "no_retrieval_data",
}[self]


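For orientation, here is a standalone copy of the extended enum as it reads after this change (not part of the diff itself), showing how the new raw code and its pretty name relate:

```python
from enum import Enum


class AnswerStatusCode(str, Enum):
    SUCCESS = "0"
    ERROR = "-1"
    NO_CONTEXT = "-2"
    NO_RETRIEVAL_DATA = "-3"  # new: retrieval itself returned nothing

    def prettify(self) -> str:
        return {
            AnswerStatusCode.SUCCESS: "success",
            AnswerStatusCode.ERROR: "error",
            AnswerStatusCode.NO_CONTEXT: "no_context",
            AnswerStatusCode.NO_RETRIEVAL_DATA: "no_retrieval_data",
        }[self]


# The raw value is what the streamed status item carries; prettify() is what
# the synchronous response and the audit layer receive.
assert AnswerStatusCode.NO_RETRIEVAL_DATA.value == "-3"
assert AnswerStatusCode.NO_RETRIEVAL_DATA.prettify() == "no_retrieval_data"
```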
30 changes: 28 additions & 2 deletions nucliadb/src/nucliadb/search/search/chat/ask.py
Expand Up @@ -49,6 +49,7 @@
ChatAuditor,
get_find_results,
get_relations_results,
maybe_audit_chat,
rephrase_query,
sorted_prompt_context_list,
tokens_to_chars,
@@ -433,14 +434,14 @@ async def ndjson_stream(self) -> AsyncGenerator[str, None]:
"""
yield self._ndjson_encode(RetrievalAskResponseItem(results=self.main_results))
yield self._ndjson_encode(AnswerAskResponseItem(text=NOT_ENOUGH_CONTEXT_ANSWER))
status = AnswerStatusCode.NO_CONTEXT
status = AnswerStatusCode.NO_RETRIEVAL_DATA
yield self._ndjson_encode(StatusAskResponseItem(code=status.value, status=status.prettify()))

async def json(self) -> str:
return SyncAskResponse(
answer=NOT_ENOUGH_CONTEXT_ANSWER,
retrieval_results=self.main_results,
status=AnswerStatusCode.NO_CONTEXT,
status=AnswerStatusCode.NO_RETRIEVAL_DATA.prettify(),
).model_dump_json()


@@ -487,6 +488,31 @@ async def ask(
resource=resource,
)
except NoRetrievalResultsError as err:
try:
rephrase_time = metrics.elapsed("rephrase")
except KeyError:
# Not all ask requests have a rephrase step
rephrase_time = None

maybe_audit_chat(
kbid=kbid,
user_id=user_id,
client_type=client_type,
origin=origin,
generative_answer_time=0,
generative_answer_first_chunk_time=0,
rephrase_time=rephrase_time,
user_query=user_query,
rephrased_query=rephrased_query,
text_answer=b"",
status_code=AnswerStatusCode.NO_RETRIEVAL_DATA,
chat_history=chat_history,
query_context={},
query_context_order={},
learning_id=None,
model=ask_request.generative_model,
)

# If a retrieval was attempted but no results were found,
# early return the ask endpoint without querying the generative model
return NotEnoughContextAskResult(
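From a client's point of view, a synchronous ask that finds no retrieval results now reports no_retrieval_data instead of no_context, and the generative model is never queried. A hedged sketch of a client-side check (the endpoint path and payload mirror the integration test further down; base_url and kbid are placeholders, and httpx is used only for illustration):

```python
import httpx


async def ask_and_check(base_url: str, kbid: str) -> None:
    # Illustrative only: base_url and kbid are placeholders.
    async with httpx.AsyncClient(base_url=base_url) as client:
        resp = await client.post(
            f"/kb/{kbid}/ask",
            json={"query": "title"},
            headers={"X-Synchronous": "True"},
        )
        resp.raise_for_status()
        body = resp.json()
        if body["status"] == "no_retrieval_data":
            # Retrieval found nothing, so no answer was generated.
            print("No retrieval data:", body["answer"])
```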
6 changes: 3 additions & 3 deletions nucliadb/src/nucliadb/search/search/chat/query.py
@@ -285,8 +285,8 @@ def maybe_audit_chat(
chat_history: list[ChatContextMessage],
query_context: PromptContext,
query_context_order: PromptContextOrder,
learning_id: str,
model: str,
learning_id: Optional[str],
model: Optional[str],
):
audit = get_audit()
if audit is None:
@@ -324,7 +324,7 @@


def parse_audit_answer(raw_text_answer: bytes, status_code: AnswerStatusCode) -> Optional[str]:
if status_code == AnswerStatusCode.NO_CONTEXT:
if status_code == AnswerStatusCode.NO_CONTEXT or status_code == AnswerStatusCode.NO_RETRIEVAL_DATA:
# We don't want to audit "Not enough context to answer this." and instead set a None.
return None
return raw_text_answer.decode()
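As a quick illustration of the audit behaviour after this hunk, a standalone re-implementation of parse_audit_answer (using an equivalent membership test instead of the chained comparison; it assumes the AnswerStatusCode enum shown earlier):

```python
from typing import Optional


def parse_audit_answer(raw_text_answer: bytes, status_code: AnswerStatusCode) -> Optional[str]:
    if status_code in (AnswerStatusCode.NO_CONTEXT, AnswerStatusCode.NO_RETRIEVAL_DATA):
        # "Not enough context/data" answers are placeholders, so the audit stores None.
        return None
    return raw_text_answer.decode()


assert parse_audit_answer(b"foobar", AnswerStatusCode.NO_RETRIEVAL_DATA) is None
assert parse_audit_answer(b"foobar", AnswerStatusCode.SUCCESS) == "foobar"
```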
13 changes: 13 additions & 0 deletions nucliadb/tests/nucliadb/integration/test_ask.py
@@ -214,6 +214,19 @@ async def test_ask_synchronous(nucliadb_reader: AsyncClient, knowledgebox, resou
assert resp_data.status == AnswerStatusCode.SUCCESS.prettify()


async def test_ask_status_code_no_retrieval_data(nucliadb_reader: AsyncClient, knowledgebox):
resp = await nucliadb_reader.post(
f"/kb/{knowledgebox}/ask",
json={"query": "title"},
headers={"X-Synchronous": "True"},
)
assert resp.status_code == 200
resp_data = SyncAskResponse.model_validate_json(resp.content)
assert resp_data.answer == "Not enough data to answer this."
assert len(resp_data.retrieval_results.resources) == 0
assert resp_data.status == AnswerStatusCode.NO_RETRIEVAL_DATA.prettify()


async def test_ask_with_citations(nucliadb_reader: AsyncClient, knowledgebox, resource):
citations = {"foo": [], "bar": []} # type: ignore
citations_gen = CitationsGenerativeResponse(citations=citations)
@@ -106,6 +106,7 @@ async def test_get_find_results_vector_search_is_optional(predict, chat_features
[
(b"foobar", AnswerStatusCode.NO_CONTEXT, None),
(b"foobar", AnswerStatusCode.SUCCESS, "foobar"),
(b"foobar", AnswerStatusCode.NO_RETRIEVAL_DATA, None),
],
)
def test_parse_audit_answer(raw_text_answer, status_code, audit_answer):
4 changes: 2 additions & 2 deletions nucliadb_protos/audit.proto
@@ -54,9 +54,9 @@ message ChatAudit {
// context retrieved on the current ask
repeated ChatContext chat_context = 6;
repeated RetrievedContext retrieved_context = 8;
string learning_id = 5;
optional string learning_id = 5;
int32 status_code = 9;
string model = 10;
optional string model = 10;
}

enum TaskType {
24 changes: 12 additions & 12 deletions nucliadb_protos/python/src/nucliadb_protos/audit_pb2.py

Some generated files are not rendered by default.

12 changes: 8 additions & 4 deletions nucliadb_protos/python/src/nucliadb_protos/audit_pb2.pyi
@@ -196,15 +196,19 @@ class ChatAudit(google.protobuf.message.Message):
context: collections.abc.Iterable[global___ChatContext] | None = ...,
chat_context: collections.abc.Iterable[global___ChatContext] | None = ...,
retrieved_context: collections.abc.Iterable[global___RetrievedContext] | None = ...,
learning_id: builtins.str = ...,
learning_id: builtins.str | None = ...,
status_code: builtins.int = ...,
model: builtins.str = ...,
model: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_answer", b"_answer", "_rephrased_question", b"_rephrased_question", "answer", b"answer", "rephrased_question", b"rephrased_question"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_answer", b"_answer", "_rephrased_question", b"_rephrased_question", "answer", b"answer", "chat_context", b"chat_context", "context", b"context", "learning_id", b"learning_id", "model", b"model", "question", b"question", "rephrased_question", b"rephrased_question", "retrieved_context", b"retrieved_context", "status_code", b"status_code"]) -> None: ...
def HasField(self, field_name: typing.Literal["_answer", b"_answer", "_learning_id", b"_learning_id", "_model", b"_model", "_rephrased_question", b"_rephrased_question", "answer", b"answer", "learning_id", b"learning_id", "model", b"model", "rephrased_question", b"rephrased_question"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_answer", b"_answer", "_learning_id", b"_learning_id", "_model", b"_model", "_rephrased_question", b"_rephrased_question", "answer", b"answer", "chat_context", b"chat_context", "context", b"context", "learning_id", b"learning_id", "model", b"model", "question", b"question", "rephrased_question", b"rephrased_question", "retrieved_context", b"retrieved_context", "status_code", b"status_code"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_answer", b"_answer"]) -> typing.Literal["answer"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_learning_id", b"_learning_id"]) -> typing.Literal["learning_id"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_model", b"_model"]) -> typing.Literal["model"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_rephrased_question", b"_rephrased_question"]) -> typing.Literal["rephrased_question"] | None: ...

global___ChatAudit = ChatAudit
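Because learning_id and model are now proto3 optional, their presence can be checked explicitly from the generated Python bindings. A small sketch, assuming the import path implied by the file listing above:

```python
from nucliadb_protos.audit_pb2 import ChatAudit

audit = ChatAudit(question="what is the title?", status_code=-3)

# Optional scalar fields start out unset and support explicit presence checks.
assert not audit.HasField("learning_id")
assert not audit.HasField("model")

audit.learning_id = "a-learning-id"
assert audit.HasField("learning_id")
```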
8 changes: 4 additions & 4 deletions nucliadb_protos/rust/src/audit.rs
@@ -101,12 +101,12 @@ pub struct ChatAudit {
pub chat_context: ::prost::alloc::vec::Vec<ChatContext>,
#[prost(message, repeated, tag = "8")]
pub retrieved_context: ::prost::alloc::vec::Vec<RetrievedContext>,
#[prost(string, tag = "5")]
pub learning_id: ::prost::alloc::string::String,
#[prost(string, optional, tag = "5")]
pub learning_id: ::core::option::Option<::prost::alloc::string::String>,
#[prost(int32, tag = "9")]
pub status_code: i32,
#[prost(string, tag = "10")]
pub model: ::prost::alloc::string::String,
#[prost(string, optional, tag = "10")]
pub model: ::core::option::Option<::prost::alloc::string::String>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
4 changes: 2 additions & 2 deletions nucliadb_utils/src/nucliadb_utils/audit/audit.py
@@ -84,9 +84,9 @@ def chat(
chat_context: List[ChatContext],
retrieved_context: List[RetrievedContext],
answer: Optional[str],
learning_id: str,
learning_id: Optional[str],
status_code: int,
model: str,
model: Optional[str],
rephrase_time: Optional[float] = None,
generative_answer_time: Optional[float] = None,
generative_answer_first_chunk_time: Optional[float] = None,
4 changes: 2 additions & 2 deletions nucliadb_utils/src/nucliadb_utils/audit/basic.py
@@ -85,9 +85,9 @@ def chat(
chat_context: List[ChatContext],
retrieved_context: List[RetrievedContext],
answer: Optional[str],
learning_id: str,
learning_id: Optional[str],
status_code: int,
model: str,
model: Optional[str],
rephrase_time: Optional[float] = None,
generative_answer_time: Optional[float] = None,
generative_answer_first_chunk_time: Optional[float] = None,