Skip to content

Adds search functionality #140

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions business_objects/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import datetime
from .. import daemon
from threading import Lock
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql import Select


__THREAD_LOCK = Lock()
Expand Down Expand Up @@ -278,3 +280,18 @@ def simple_selection_builder(
{where}
{order_by_s}
"""


def print_orm_query(
query, bound_params: bool = True, return_as_str: bool = False
) -> Optional[str]:
# helper method for dev to print the query that is generated by the ORM
if not isinstance(query, Select):
query = query.statement
return_str = query.compile(
dialect=postgresql.dialect(), compile_kwargs={"literal_binds": bound_params}
)
if return_as_str:
return return_str

print(return_str, flush=True)
95 changes: 92 additions & 3 deletions cognition_objects/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@
from ..cognition_objects import message
from ..business_objects import general
from ..session import session
from ..models import CognitionConversation
from ..models import CognitionConversation, CognitionMessage
from ..util import prevent_sql_injection
from sqlalchemy.sql.expression import Subquery
from sqlalchemy import or_
from sqlalchemy.sql.expression import cast
from sqlalchemy import String as sqlalchemy_string
from sqlalchemy import select


def get(project_id: str, conversation_id: str) -> CognitionConversation:
Expand Down Expand Up @@ -103,17 +108,29 @@ def get_overview_list(


def get_all_paginated_by_project_id(
project_id: str, page: int, limit: int, order_asc: bool = True, user_id: str = None
project_id: str,
page: int,
limit: int,
order_asc: bool = True,
user_id: Optional[str] = None,
filter_dict: Optional[Dict[str, Any]] = None,
) -> Tuple[int, int, List[CognitionConversation]]:
total_count_query = session.query(CognitionConversation.id).filter(
CognitionConversation.project_id == project_id
)
subquery = None
if filter_dict is not None:
subquery = __get_conversation_ids_by_filter(project_id, **filter_dict)
total_count_query = total_count_query.filter(
CognitionConversation.id.in_(subquery)
)

if user_id is not None:
total_count_query = total_count_query.filter(
CognitionConversation.created_by == user_id
)
total_count = total_count_query.count()

total_count = total_count_query.count()
if total_count == 0:
num_pages = 0
else:
Expand All @@ -130,6 +147,8 @@ def get_all_paginated_by_project_id(
)
if user_id is not None:
query = query.filter(CognitionConversation.created_by == user_id)
if subquery is not None:
query = query.filter(CognitionConversation.id.in_(subquery))
if order_asc:
query = query.order_by(CognitionConversation.created_at.asc())
else:
Expand All @@ -140,6 +159,76 @@ def get_all_paginated_by_project_id(
return total_count, num_pages, paginated_result


def __get_conversation_ids_by_filter(
project_id: str,
user_id: Optional[str] = None,
has_error: Optional[bool] = None,
has_tmp_files: Optional[bool] = None,
tmp_file_name: Optional[str] = None,
question_or_answer: Optional[str] = None,
fact_contains: Optional[str] = None,
feedback_value: Optional[str] = None,
feedback_message_contains: Optional[str] = None,
created_at_from: Optional[str] = None,
created_at_to: Optional[str] = None,
) -> Subquery:
query = select(CognitionConversation.id).filter(
CognitionConversation.project_id == project_id
)
if created_at_from is not None:
query = query.filter(CognitionConversation.created_at >= created_at_from)
if created_at_to is not None:
query = query.filter(CognitionConversation.created_at <= created_at_to)
if user_id is not None:
query = query.filter(CognitionConversation.created_by == user_id)
if has_error is not None:
if has_error:
query = query.filter(CognitionConversation.error.isnot(None))
else:
query = query.filter(CognitionConversation.error.is_(None))
if has_tmp_files is not None:
query = query.filter(CognitionConversation.has_tmp_files == has_tmp_files)
if tmp_file_name is not None:
tmp_file_name = "%" + tmp_file_name + "%"
query = query.filter(
CognitionConversation.scope_dict.op("->>")("parsed_documents").ilike(
tmp_file_name
)
)
if (
question_or_answer is not None
or fact_contains is not None
or feedback_value is not None
or feedback_message_contains is not None
):
query = query.join(
CognitionMessage,
(CognitionMessage.project_id == CognitionConversation.project_id)
& (CognitionMessage.conversation_id == CognitionConversation.id),
)
if question_or_answer is not None:
question_or_answer = "%" + question_or_answer + "%"
query = query.filter(
or_(
CognitionMessage.question.ilike(question_or_answer),
CognitionMessage.answer.ilike(question_or_answer),
)
)
if fact_contains is not None:
fact_contains = "%" + fact_contains + "%"
query = query.filter(
cast(CognitionMessage.facts, sqlalchemy_string).ilike(fact_contains)
)
if feedback_value is not None:
query = query.filter(CognitionMessage.feedback_value == feedback_value)
if feedback_message_contains is not None:
feedback_message_contains = "%" + feedback_message_contains + "%"
query = query.filter(
CognitionMessage.feedback_message.ilike(feedback_message_contains)
)
return query


def has_error(project_id: str, conversation_id: str) -> bool:
conversation_item = get(project_id, conversation_id)
if conversation_item is None:
Expand Down