diff --git a/api/commands.py b/api/commands.py
index 35b5c5d5f8a257..deaa2e1675b7bc 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -28,7 +28,7 @@
 from libs.rsa import generate_key_pair
 from models.account import InvitationCode, Tenant, TenantAccountJoin
 from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
-from models.model import Account, AppModelConfig, App
+from models.model import Account, AppModelConfig, App, MessageAnnotation, Message
 
 import secrets
 import base64
@@ -752,6 +752,32 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
         pbar.update(len(data_batch))
 
 
+@click.command('add-annotation-question-field-value', help='Add the question field value for existing message annotations.')
+def add_annotation_question_field_value():
+    click.echo(click.style('Start adding annotation question values.', fg='green'))
+    message_annotations = db.session.query(MessageAnnotation).all()
+    message_annotation_deal_count = 0
+    if message_annotations:
+        for message_annotation in message_annotations:
+            try:
+                if message_annotation.message_id and not message_annotation.question:
+                    message = db.session.query(Message).filter(
+                        Message.id == message_annotation.message_id
+                    ).first()
+                    # the referenced message may have been deleted; skip it instead of failing
+                    if message:
+                        message_annotation.question = message.query
+                        db.session.add(message_annotation)
+                        db.session.commit()
+                        message_annotation_deal_count += 1
+            except Exception as e:
+                click.echo(
+                    click.style('Add annotation question value error: {} {}'.format(e.__class__.__name__, str(e)),
+                                fg='red'))
+    click.echo(
+        click.style(f'Congratulations! Annotation question values added successfully. Deal count: {message_annotation_deal_count}', fg='green'))
+
+
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
@@ -766,3 +790,4 @@ def register_commands(app):
     app.cli.add_command(normalization_collections)
     app.cli.add_command(migrate_default_input_to_dataset_query_variable)
     app.cli.add_command(add_qdrant_full_text_index)
+    app.cli.add_command(add_annotation_question_field_value)
diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index 46fa0e79dc0abf..21dcbd62be8b8f 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -9,7 +9,7 @@ from . 
import extension, setup, version, apikey, admin # Import app controllers -from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio +from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio, annotation # Import auth controllers from .auth import login, oauth, data_source_oauth, activate diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py new file mode 100644 index 00000000000000..a21a2eae6446ef --- /dev/null +++ b/api/controllers/console/app/annotation.py @@ -0,0 +1,291 @@ +from flask_login import current_user +from flask_restful import Resource, reqparse, marshal_with, marshal +from werkzeug.exceptions import Forbidden + +from controllers.console import api +from controllers.console.app.error import NoFileUploadedError +from controllers.console.datasets.error import TooManyFilesError +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from extensions.ext_redis import redis_client +from fields.annotation_fields import annotation_list_fields, annotation_hit_history_list_fields, annotation_fields, \ + annotation_hit_history_fields +from libs.login import login_required +from services.annotation_service import AppAnnotationService +from flask import request + + +class AnnotationReplyActionApi(Resource): + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_resource_check('annotation') + def post(self, app_id, action): + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + app_id = str(app_id) + parser = reqparse.RequestParser() + parser.add_argument('score_threshold', required=True, type=float, location='json') + parser.add_argument('embedding_provider_name', required=True, type=str, location='json') + parser.add_argument('embedding_model_name', required=True, type=str, location='json') + args = parser.parse_args() + if action == 'enable': + result = AppAnnotationService.enable_app_annotation(args, app_id) + elif action == 'disable': + result = AppAnnotationService.disable_app_annotation(app_id) + else: + raise ValueError('Unsupported annotation reply action') + return result, 200 + + +class AppAnnotationSettingDetailApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, app_id): + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + app_id = str(app_id) + result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id) + return result, 200 + + +class AppAnnotationSettingUpdateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, app_id, annotation_setting_id): + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + app_id = str(app_id) + annotation_setting_id = str(annotation_setting_id) + + parser = reqparse.RequestParser() + parser.add_argument('score_threshold', required=True, type=float, location='json') + args = parser.parse_args() + + result = AppAnnotationService.update_app_annotation_setting(app_id, 
annotation_setting_id, args) + return result, 200 + + +class AnnotationReplyActionStatusApi(Resource): + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_resource_check('annotation') + def get(self, app_id, job_id, action): + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + job_id = str(job_id) + app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id)) + cache_result = redis_client.get(app_annotation_job_key) + if cache_result is None: + raise ValueError("The job is not exist.") + + job_status = cache_result.decode() + error_msg = '' + if job_status == 'error': + app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id)) + error_msg = redis_client.get(app_annotation_error_key).decode() + + return { + 'job_id': job_id, + 'job_status': job_status, + 'error_msg': error_msg + }, 200 + + +class AnnotationListApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, app_id): + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + page = request.args.get('page', default=1, type=int) + limit = request.args.get('limit', default=20, type=int) + keyword = request.args.get('keyword', default=None, type=str) + + app_id = str(app_id) + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) + response = { + 'data': marshal(annotation_list, annotation_fields), + 'has_more': len(annotation_list) == limit, + 'limit': limit, + 'total': total, + 'page': page + } + return response, 200 + + +class AnnotationExportApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, app_id): + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + app_id = str(app_id) + annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) + response = { + 'data': marshal(annotation_list, annotation_fields) + } + return response, 200 + + +class AnnotationCreateApi(Resource): + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_resource_check('annotation') + @marshal_with(annotation_fields) + def post(self, app_id): + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + app_id = str(app_id) + parser = reqparse.RequestParser() + parser.add_argument('question', required=True, type=str, location='json') + parser.add_argument('answer', required=True, type=str, location='json') + args = parser.parse_args() + annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id) + return annotation + + +class AnnotationUpdateDeleteApi(Resource): + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_resource_check('annotation') + @marshal_with(annotation_fields) + def post(self, app_id, annotation_id): + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + app_id = str(app_id) + annotation_id = str(annotation_id) + parser = 
reqparse.RequestParser()
+        parser.add_argument('question', required=True, type=str, location='json')
+        parser.add_argument('answer', required=True, type=str, location='json')
+        args = parser.parse_args()
+        annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
+        return annotation
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @cloud_edition_billing_resource_check('annotation')
+    def delete(self, app_id, annotation_id):
+        # The role of the current user in the ta table must be admin or owner
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        app_id = str(app_id)
+        annotation_id = str(annotation_id)
+        AppAnnotationService.delete_app_annotation(app_id, annotation_id)
+        return {'result': 'success'}, 200
+
+
+class AnnotationBatchImportApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @cloud_edition_billing_resource_check('annotation')
+    def post(self, app_id):
+        # The role of the current user in the ta table must be admin or owner
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        app_id = str(app_id)
+        # check file
+        if 'file' not in request.files:
+            raise NoFileUploadedError()
+
+        if len(request.files) > 1:
+            raise TooManyFilesError()
+        # get file from request
+        file = request.files['file']
+        # check file type
+        if not file.filename.endswith('.csv'):
+            raise ValueError("Invalid file type. Only CSV files are allowed")
+        return AppAnnotationService.batch_import_app_annotations(app_id, file)
+
+
+class AnnotationBatchImportStatusApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @cloud_edition_billing_resource_check('annotation')
+    def get(self, app_id, job_id):
+        # The role of the current user in the ta table must be admin or owner
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        job_id = str(job_id)
+        indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
+        cache_result = redis_client.get(indexing_cache_key)
+        if cache_result is None:
+            raise ValueError("The job does not exist.")
+        job_status = cache_result.decode()
+        error_msg = ''
+        if job_status == 'error':
+            indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
+            error_msg = redis_client.get(indexing_error_msg_key).decode()
+
+        return {
+            'job_id': job_id,
+            'job_status': job_status,
+            'error_msg': error_msg
+        }, 200
+
+
+class AnnotationHitHistoryListApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, app_id, annotation_id):
+        # The role of the current user in the ta table must be admin or owner
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        page = request.args.get('page', default=1, type=int)
+        limit = request.args.get('limit', default=20, type=int)
+        app_id = str(app_id)
+        annotation_id = str(annotation_id)
+        annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id,
+                                                                                               page, limit)
+        response = {
+            'data': marshal(annotation_hit_history_list, annotation_hit_history_fields),
+            'has_more': len(annotation_hit_history_list) == limit,
+            'limit': limit,
+            'total': total,
+            'page': page
+        }
+        return response
+
+
+api.add_resource(AnnotationReplyActionApi, '/apps/<uuid:app_id>/annotation-reply/<string:action>')
+api.add_resource(AnnotationReplyActionStatusApi,
+                 '/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>')
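+# Route URL parameters: app_id, annotation_id, job_id and annotation_setting_id are UUIDs;
+# action is a string ('enable' or 'disable').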
+api.add_resource(AnnotationListApi, '/apps/<uuid:app_id>/annotations')
+api.add_resource(AnnotationExportApi, '/apps/<uuid:app_id>/annotations/export')
+api.add_resource(AnnotationUpdateDeleteApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>')
+api.add_resource(AnnotationBatchImportApi, '/apps/<uuid:app_id>/annotations/batch-import')
+api.add_resource(AnnotationBatchImportStatusApi, '/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>')
+api.add_resource(AnnotationHitHistoryListApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories')
+api.add_resource(AppAnnotationSettingDetailApi, '/apps/<uuid:app_id>/annotation-setting')
+api.add_resource(AppAnnotationSettingUpdateApi, '/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>')
diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py
index b6086b68b833c9..d7b31906c8de21 100644
--- a/api/controllers/console/app/error.py
+++ b/api/controllers/console/app/error.py
@@ -72,4 +72,16 @@ class UnsupportedAudioTypeError(BaseHTTPException):
 class ProviderNotSupportSpeechToTextError(BaseHTTPException):
     error_code = 'provider_not_support_speech_to_text'
     description = "Provider not support speech to text."
-    code = 400
\ No newline at end of file
+    code = 400
+
+
+class NoFileUploadedError(BaseHTTPException):
+    error_code = 'no_file_uploaded'
+    description = "Please upload your file."
+    code = 400
+
+
+class TooManyFilesError(BaseHTTPException):
+    error_code = 'too_many_files'
+    description = "Only one file is allowed."
+    code = 400
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index e73bed7e320aa1..b26287fd4bd27d 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -6,22 +6,23 @@
 from flask_login import current_user
 from flask_restful import Resource, reqparse, marshal_with, fields
 from flask_restful.inputs import int_range
-from werkzeug.exceptions import InternalServerError, NotFound
+from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
 
 from controllers.console import api
 from controllers.console.app import _get_app
 from controllers.console.app.error import CompletionRequestError, ProviderNotInitializeError, \
     AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.console.setup import setup_required
-from controllers.console.wraps import account_initialization_required
+from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
 from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.login import login_required
-from fields.conversation_fields import message_detail_fields
+from fields.conversation_fields import message_detail_fields, annotation_fields
 from libs.helper import uuid_value
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from extensions.ext_database import db
 from models.model import MessageAnnotation, Conversation, Message, MessageFeedback
+from services.annotation_service import AppAnnotationService
 from services.completion_service import CompletionService
 from services.errors.app import MoreLikeThisDisabledError
 from services.errors.conversation import ConversationNotExistsError
@@ -151,44 +152,24 @@ class MessageAnnotationApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_resource_check('annotation')
+    @marshal_with(annotation_fields)
     def post(self, app_id): - 
app_id = str(app_id) + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() - # get app info - app = _get_app(app_id) + app_id = str(app_id) parser = reqparse.RequestParser() - parser.add_argument('message_id', required=True, type=uuid_value, location='json') - parser.add_argument('content', type=str, location='json') + parser.add_argument('message_id', required=False, type=uuid_value, location='json') + parser.add_argument('question', required=True, type=str, location='json') + parser.add_argument('answer', required=True, type=str, location='json') + parser.add_argument('annotation_reply', required=False, type=dict, location='json') args = parser.parse_args() + annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id) - message_id = str(args['message_id']) - - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app.id - ).first() - - if not message: - raise NotFound("Message Not Exists.") - - annotation = message.annotation - - if annotation: - annotation.content = args['content'] - else: - annotation = MessageAnnotation( - app_id=app.id, - conversation_id=message.conversation_id, - message_id=message.id, - content=args['content'], - account_id=current_user.id - ) - db.session.add(annotation) - - db.session.commit() - - return {'result': 'success'} + return annotation class MessageAnnotationCountApi(Resource): diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 75ee2aaafe6682..5103fddca3fe44 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -24,29 +24,29 @@ def post(self, app_id): """Modify app model config""" app_id = str(app_id) - app_model = _get_app(app_id) + app = _get_app(app_id) # validate config model_configuration = AppModelConfigService.validate_configuration( tenant_id=current_user.current_tenant_id, account=current_user, config=request.json, - mode=app_model.mode + mode=app.mode ) new_app_model_config = AppModelConfig( - app_id=app_model.id, + app_id=app.id, ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) db.session.add(new_app_model_config) db.session.flush() - app_model.app_model_config_id = new_app_model_config.id + app.app_model_config_id = new_app_model_config.id db.session.commit() app_model_config_was_updated.send( - app_model, + app, app_model_config=new_app_model_config ) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 63066f9f56aa49..435ebaee7c2212 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -30,6 +30,7 @@ class AppParameterApi(InstalledAppResource): 'suggested_questions_after_answer': fields.Raw, 'speech_to_text': fields.Raw, 'retriever_resource': fields.Raw, + 'annotation_reply': fields.Raw, 'more_like_this': fields.Raw, 'user_input_form': fields.Raw, 'sensitive_word_avoidance': fields.Raw, @@ -49,6 +50,7 @@ def get(self, installed_app: InstalledApp): 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, 'speech_to_text': app_model_config.speech_to_text_dict, 'retriever_resource': app_model_config.retriever_resource_dict, + 'annotation_reply': app_model_config.annotation_reply_dict, 'more_like_this': app_model_config.more_like_this_dict, 'user_input_form': 
app_model_config.user_input_form_list, 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict, diff --git a/api/controllers/console/universal_chat/parameter.py b/api/controllers/console/universal_chat/parameter.py index fb00ca12cf6334..31bb9f27901edf 100644 --- a/api/controllers/console/universal_chat/parameter.py +++ b/api/controllers/console/universal_chat/parameter.py @@ -17,6 +17,7 @@ class UniversalChatParameterApi(UniversalChatResource): 'suggested_questions_after_answer': fields.Raw, 'speech_to_text': fields.Raw, 'retriever_resource': fields.Raw, + 'annotation_reply': fields.Raw } @marshal_with(parameters_fields) @@ -32,6 +33,7 @@ def get(self, universal_app: App): 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, 'speech_to_text': app_model_config.speech_to_text_dict, 'retriever_resource': app_model_config.retriever_resource_dict, + 'annotation_reply': app_model_config.annotation_reply_dict, } diff --git a/api/controllers/console/universal_chat/wraps.py b/api/controllers/console/universal_chat/wraps.py index 1fd1747848e896..b940d4f1ad07b8 100644 --- a/api/controllers/console/universal_chat/wraps.py +++ b/api/controllers/console/universal_chat/wraps.py @@ -47,6 +47,7 @@ def decorated(*args, **kwargs): suggested_questions=json.dumps([]), suggested_questions_after_answer=json.dumps({'enabled': True}), speech_to_text=json.dumps({'enabled': True}), + annotation_reply=json.dumps({'enabled': False}), retriever_resource=json.dumps({'enabled': True}), more_like_this=None, sensitive_word_avoidance=None, diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index efe86ea8c360dc..f12802f1e470f1 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -55,6 +55,7 @@ def decorated(*args, **kwargs): members = billing_info['members'] apps = billing_info['apps'] vector_space = billing_info['vector_space'] + annotation_quota_limit = billing_info['annotation_quota_limit'] if resource == 'members' and 0 < members['limit'] <= members['size']: abort(403, error_msg) @@ -62,6 +63,8 @@ def decorated(*args, **kwargs): abort(403, error_msg) elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']: abort(403, error_msg) + elif resource == 'annotation' and 0 < annotation_quota_limit['limit'] <= annotation_quota_limit['size']: + abort(403, error_msg) else: return view(*args, **kwargs) diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index f38f60cf5bda43..409709b812a790 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -31,6 +31,7 @@ class AppParameterApi(AppApiResource): 'suggested_questions_after_answer': fields.Raw, 'speech_to_text': fields.Raw, 'retriever_resource': fields.Raw, + 'annotation_reply': fields.Raw, 'more_like_this': fields.Raw, 'user_input_form': fields.Raw, 'sensitive_word_avoidance': fields.Raw, @@ -49,6 +50,7 @@ def get(self, app_model: App, end_user): 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, 'speech_to_text': app_model_config.speech_to_text_dict, 'retriever_resource': app_model_config.retriever_resource_dict, + 'annotation_reply': app_model_config.annotation_reply_dict, 'more_like_this': app_model_config.more_like_this_dict, 'user_input_form': app_model_config.user_input_form_list, 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict, diff --git a/api/controllers/web/app.py 
b/api/controllers/web/app.py index 45213c4c754010..07f200ddc2517d 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -30,6 +30,7 @@ class AppParameterApi(WebApiResource): 'suggested_questions_after_answer': fields.Raw, 'speech_to_text': fields.Raw, 'retriever_resource': fields.Raw, + 'annotation_reply': fields.Raw, 'more_like_this': fields.Raw, 'user_input_form': fields.Raw, 'sensitive_word_avoidance': fields.Raw, @@ -48,6 +49,7 @@ def get(self, app_model: App, end_user): 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, 'speech_to_text': app_model_config.speech_to_text_dict, 'retriever_resource': app_model_config.retriever_resource_dict, + 'annotation_reply': app_model_config.annotation_reply_dict, 'more_like_this': app_model_config.more_like_this_dict, 'user_input_form': app_model_config.user_input_form_list, 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict, diff --git a/api/core/completion.py b/api/core/completion.py index 30ad23e6295fd5..0f8c78a9fe3230 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -12,8 +12,10 @@ from core.callback_handler.llm_callback_handler import LLMCallbackHandler from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \ ConversationTaskInterruptException +from core.embedding.cached_embedding import CacheEmbedding from core.external_data_tool.factory import ExternalDataToolFactory from core.file.file_obj import FileObj +from core.index.vector_index.vector_index import VectorIndex from core.model_providers.error import LLMBadRequestError from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ ReadOnlyConversationTokenDBBufferSharedMemory @@ -23,9 +25,12 @@ from core.orchestrator_rule_parser import OrchestratorRuleParser from core.prompt.prompt_template import PromptTemplateParser from core.prompt.prompt_transform import PromptTransform +from models.dataset import Dataset from models.model import App, AppModelConfig, Account, Conversation, EndUser from core.moderation.base import ModerationException, ModerationAction from core.moderation.factory import ModerationFactory +from services.annotation_service import AppAnnotationService +from services.dataset_service import DatasetCollectionBindingService class Completion: @@ -33,7 +38,7 @@ class Completion: def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict, files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool, is_override: bool = False, retriever_from: str = 'dev', - auto_generate_name: bool = True): + auto_generate_name: bool = True, from_source: str = 'console'): """ errors: ProviderTokenNotInitError """ @@ -109,7 +114,10 @@ def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, quer fake_response=str(e) ) return - + # check annotation reply + annotation_reply = cls.query_app_annotations_to_reply(conversation_message_task, from_source) + if annotation_reply: + return # fill in variable inputs from external data tools if exists external_data_tools = app_model_config.external_data_tools_list if external_data_tools: @@ -166,17 +174,18 @@ def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, quer except ChunkedEncodingError as e: # Interrupt by LLM (like OpenAI), handle it. 
logging.warning(f'ChunkedEncodingError: {e}') - conversation_message_task.end() return @classmethod - def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str): + def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, + query: str): if not app_model_config.sensitive_word_avoidance_dict['enabled']: return inputs, query type = app_model_config.sensitive_word_avoidance_dict['type'] - moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config']) + moderation = ModerationFactory(type, app_id, tenant_id, + app_model_config.sensitive_word_avoidance_dict['config']) moderation_result = moderation.moderation_for_inputs(inputs, query) if not moderation_result.flagged: @@ -324,6 +333,76 @@ def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBuf external_context = memory.load_memory_variables({}) return external_context[memory_key] + @classmethod + def query_app_annotations_to_reply(cls, conversation_message_task: ConversationMessageTask, + from_source: str) -> bool: + """Get memory messages.""" + app_model_config = conversation_message_task.app_model_config + app = conversation_message_task.app + annotation_reply = app_model_config.annotation_reply_dict + if annotation_reply['enabled']: + score_threshold = annotation_reply.get('score_threshold', 1) + embedding_provider_name = annotation_reply['embedding_model']['embedding_provider_name'] + embedding_model_name = annotation_reply['embedding_model']['embedding_model_name'] + # get embedding model + embedding_model = ModelFactory.get_embedding_model( + tenant_id=app.tenant_id, + model_provider_name=embedding_provider_name, + model_name=embedding_model_name + ) + embeddings = CacheEmbedding(embedding_model) + + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_provider_name, + embedding_model_name, + 'annotation' + ) + + dataset = Dataset( + id=app.id, + tenant_id=app.tenant_id, + indexing_technique='high_quality', + embedding_model_provider=embedding_provider_name, + embedding_model=embedding_model_name, + collection_binding_id=dataset_collection_binding.id + ) + + vector_index = VectorIndex( + dataset=dataset, + config=current_app.config, + embeddings=embeddings + ) + + documents = vector_index.search( + conversation_message_task.query, + search_type='similarity_score_threshold', + search_kwargs={ + 'k': 1, + 'score_threshold': score_threshold, + 'filter': { + 'group_id': [dataset.id] + } + } + ) + if documents: + annotation_id = documents[0].metadata['annotation_id'] + score = documents[0].metadata['score'] + annotation = AppAnnotationService.get_annotation_by_id(annotation_id) + if annotation: + conversation_message_task.annotation_end(annotation.content, annotation.id, annotation.account.name) + # insert annotation history + AppAnnotationService.add_annotation_history(annotation.id, + app.id, + annotation.question, + annotation.content, + conversation_message_task.query, + conversation_message_task.user.id, + conversation_message_task.message.id, + from_source, + score) + return True + return False + @classmethod def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig, conversation: Conversation, diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py index 614fed273d0815..bf700b3bec5a7f 100644 --- a/api/core/conversation_message_task.py +++ 
b/api/core/conversation_message_task.py @@ -319,6 +319,10 @@ def end(self): self._pub_handler.pub_message_end(self.retriever_resource) self._pub_handler.pub_end() + def annotation_end(self, text: str, annotation_id: str, annotation_author_name: str): + self._pub_handler.pub_annotation(text, annotation_id, annotation_author_name, self.start_at) + self._pub_handler.pub_end() + class PubHandler: def __init__(self, user: Union[Account, EndUser], task_id: str, @@ -435,7 +439,7 @@ def pub_message_end(self, retriever_resource: List): 'task_id': self._task_id, 'message_id': self._message.id, 'mode': self._conversation.mode, - 'conversation_id': self._conversation.id + 'conversation_id': self._conversation.id, } } if retriever_resource: @@ -446,6 +450,30 @@ def pub_message_end(self, retriever_resource: List): self.pub_end() raise ConversationTaskStoppedException() + def pub_annotation(self, text: str, annotation_id: str, annotation_author_name: str, start_at: float): + content = { + 'event': 'annotation', + 'data': { + 'task_id': self._task_id, + 'message_id': self._message.id, + 'mode': self._conversation.mode, + 'conversation_id': self._conversation.id, + 'text': text, + 'annotation_id': annotation_id, + 'annotation_author_name': annotation_author_name + } + } + self._message.answer = text + self._message.provider_response_latency = time.perf_counter() - start_at + + db.session.commit() + + redis_client.publish(self._channel, json.dumps(content)) + + if self._is_stopped(): + self.pub_end() + raise ConversationTaskStoppedException() + def pub_end(self): content = { 'event': 'end', diff --git a/api/core/index/base.py b/api/core/index/base.py index 025c94bbe56baf..166b2d65c0dc71 100644 --- a/api/core/index/base.py +++ b/api/core/index/base.py @@ -32,6 +32,10 @@ def text_exists(self, id: str) -> bool: def delete_by_ids(self, ids: list[str]) -> None: raise NotImplementedError + @abstractmethod + def delete_by_metadata_field(self, key: str, value: str) -> None: + raise NotImplementedError + @abstractmethod def delete_by_group_id(self, group_id: str) -> None: raise NotImplementedError diff --git a/api/core/index/keyword_table_index/keyword_table_index.py b/api/core/index/keyword_table_index/keyword_table_index.py index b315de6191ec58..81e8145bcb2651 100644 --- a/api/core/index/keyword_table_index/keyword_table_index.py +++ b/api/core/index/keyword_table_index/keyword_table_index.py @@ -107,6 +107,9 @@ def delete_by_document_id(self, document_id: str): self._save_dataset_keyword_table(keyword_table) + def delete_by_metadata_field(self, key: str, value: str): + pass + def get_retriever(self, **kwargs: Any) -> BaseRetriever: return KeywordTableRetriever(index=self, **kwargs) diff --git a/api/core/index/vector_index/milvus_vector_index.py b/api/core/index/vector_index/milvus_vector_index.py index a8bba763d4f905..69fbc6beef3969 100644 --- a/api/core/index/vector_index/milvus_vector_index.py +++ b/api/core/index/vector_index/milvus_vector_index.py @@ -121,6 +121,16 @@ def delete_by_document_id(self, document_id: str): 'filter': f'id in {ids}' }) + def delete_by_metadata_field(self, key: str, value: str): + + vector_store = self._get_vector_store() + vector_store = cast(self._get_vector_store_class(), vector_store) + ids = vector_store.get_ids_by_metadata_field(key, value) + if ids: + vector_store.del_texts({ + 'filter': f'id in {ids}' + }) + def delete_by_ids(self, doc_ids: list[str]) -> None: vector_store = self._get_vector_store() diff --git a/api/core/index/vector_index/qdrant_vector_index.py 
b/api/core/index/vector_index/qdrant_vector_index.py index dbadab118eaaa2..fdb0b49bb1c1a5 100644 --- a/api/core/index/vector_index/qdrant_vector_index.py +++ b/api/core/index/vector_index/qdrant_vector_index.py @@ -138,6 +138,22 @@ def delete_by_document_id(self, document_id: str): ], )) + def delete_by_metadata_field(self, key: str, value: str): + + vector_store = self._get_vector_store() + vector_store = cast(self._get_vector_store_class(), vector_store) + + from qdrant_client.http import models + + vector_store.del_texts(models.Filter( + must=[ + models.FieldCondition( + key=f"metadata.{key}", + match=models.MatchValue(value=value), + ), + ], + )) + def delete_by_ids(self, ids: list[str]) -> None: vector_store = self._get_vector_store() diff --git a/api/core/index/vector_index/weaviate_vector_index.py b/api/core/index/vector_index/weaviate_vector_index.py index 23fa4ec8f537cb..0d51ab8d90a710 100644 --- a/api/core/index/vector_index/weaviate_vector_index.py +++ b/api/core/index/vector_index/weaviate_vector_index.py @@ -141,6 +141,17 @@ def delete_by_document_id(self, document_id: str): "valueText": document_id }) + def delete_by_metadata_field(self, key: str, value: str): + + vector_store = self._get_vector_store() + vector_store = cast(self._get_vector_store_class(), vector_store) + + vector_store.del_texts({ + "operator": "Equal", + "path": [key], + "valueText": value + }) + def delete_by_group_id(self, group_id: str): if self._is_origin(): self.recreate_dataset(self.dataset) diff --git a/api/core/vector_store/milvus_vector_store.py b/api/core/vector_store/milvus_vector_store.py index 0055d76c94177d..67b958ded02137 100644 --- a/api/core/vector_store/milvus_vector_store.py +++ b/api/core/vector_store/milvus_vector_store.py @@ -30,6 +30,16 @@ def get_ids_by_document_id(self, document_id: str): else: return None + def get_ids_by_metadata_field(self, key: str, value: str): + result = self.col.query( + expr=f'metadata["{key}"] == "{value}"', + output_fields=["id"] + ) + if result: + return [item["id"] for item in result] + else: + return None + def get_ids_by_doc_ids(self, doc_ids: list): result = self.col.query( expr=f'metadata["doc_id"] in {doc_ids}', diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index d165b014d6191e..4f784c6648aeba 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -6,13 +6,13 @@ @app_model_config_was_updated.connect def handle(sender, **kwargs): - app_model = sender + app = sender app_model_config = kwargs.get('app_model_config') dataset_ids = get_dataset_ids_from_model_config(app_model_config) app_dataset_joins = db.session.query(AppDatasetJoin).filter( - AppDatasetJoin.app_id == app_model.id + AppDatasetJoin.app_id == app.id ).all() removed_dataset_ids = [] @@ -29,14 +29,14 @@ def handle(sender, **kwargs): if removed_dataset_ids: for dataset_id in removed_dataset_ids: db.session.query(AppDatasetJoin).filter( - AppDatasetJoin.app_id == app_model.id, + AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id ).delete() if added_dataset_ids: for dataset_id in added_dataset_ids: app_dataset_join = AppDatasetJoin( - app_id=app_model.id, + app_id=app.id, dataset_id=dataset_id ) db.session.add(app_dataset_join) diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py new 
file mode 100644 index 00000000000000..67cb8e4ea9fca4 --- /dev/null +++ b/api/fields/annotation_fields.py @@ -0,0 +1,36 @@ +from flask_restful import fields +from libs.helper import TimestampField + +account_fields = { + 'id': fields.String, + 'name': fields.String, + 'email': fields.String +} + + +annotation_fields = { + "id": fields.String, + "question": fields.String, + "answer": fields.Raw(attribute='content'), + "hit_count": fields.Integer, + "created_at": TimestampField, + # 'account': fields.Nested(account_fields, allow_null=True) +} + +annotation_list_fields = { + "data": fields.List(fields.Nested(annotation_fields)), +} + +annotation_hit_history_fields = { + "id": fields.String, + "source": fields.String, + "score": fields.Float, + "question": fields.String, + "created_at": TimestampField, + "match": fields.String(attribute='annotation_question'), + "response": fields.String(attribute='annotation_content') +} + +annotation_hit_history_list_fields = { + "data": fields.List(fields.Nested(annotation_hit_history_fields)), +} diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 2c8b5eb10907f6..e9db1aca504a15 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -21,6 +21,7 @@ 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), 'speech_to_text': fields.Raw(attribute='speech_to_text_dict'), 'retriever_resource': fields.Raw(attribute='retriever_resource_dict'), + 'annotation_reply': fields.Raw(attribute='annotation_reply_dict'), 'more_like_this': fields.Raw(attribute='more_like_this_dict'), 'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'), 'external_data_tools': fields.Raw(attribute='external_data_tools_list'), diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 49a96c2751dd8e..5e8412ef84de36 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -23,11 +23,18 @@ def format(self, value): } annotation_fields = { + 'id': fields.String, + 'question': fields.String, 'content': fields.String, 'account': fields.Nested(account_fields, allow_null=True), 'created_at': TimestampField } +annotation_hit_history_fields = { + 'annotation_id': fields.String, + 'annotation_create_account': fields.Nested(account_fields, allow_null=True) +} + message_file_fields = { 'id': fields.String, 'type': fields.String, @@ -49,6 +56,7 @@ def format(self, value): 'from_account_id': fields.String, 'feedbacks': fields.List(fields.Nested(feedback_fields)), 'annotation': fields.Nested(annotation_fields, allow_null=True), + 'annotation_hit_history': fields.Nested(annotation_hit_history_fields, allow_null=True), 'created_at': TimestampField, 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), } diff --git a/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py b/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py new file mode 100644 index 00000000000000..2a8a9abcb4d21f --- /dev/null +++ b/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py @@ -0,0 +1,50 @@ +"""add_app_anntation_setting + +Revision ID: 246ba09cbbdb +Revises: 714aafe25d39 +Create Date: 2023-12-14 11:26:12.287264 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = '246ba09cbbdb' +down_revision = '714aafe25d39' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('app_annotation_settings', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('score_threshold', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('collection_binding_id', postgresql.UUID(), nullable=False), + sa.Column('created_user_id', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_user_id', postgresql.UUID(), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey') + ) + with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op: + batch_op.create_index('app_annotation_settings_app_idx', ['app_id'], unique=False) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_column('annotation_reply') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('annotation_reply', sa.TEXT(), autoincrement=False, nullable=True)) + + with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op: + batch_op.drop_index('app_annotation_settings_app_idx') + + op.drop_table('app_annotation_settings') + # ### end Alembic commands ### diff --git a/api/migrations/versions/46976cc39132_add_annotation_histoiry_score.py b/api/migrations/versions/46976cc39132_add_annotation_histoiry_score.py new file mode 100644 index 00000000000000..286e2d3e092575 --- /dev/null +++ b/api/migrations/versions/46976cc39132_add_annotation_histoiry_score.py @@ -0,0 +1,32 @@ +"""add-annotation-histoiry-score + +Revision ID: 46976cc39132 +Revises: e1901f623fd0 +Create Date: 2023-12-13 04:39:59.302971 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '46976cc39132' +down_revision = 'e1901f623fd0' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: + batch_op.add_column(sa.Column('score', sa.Float(), server_default=sa.text('0'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: + batch_op.drop_column('score') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py b/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py new file mode 100644 index 00000000000000..5e0eba623b7d8a --- /dev/null +++ b/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py @@ -0,0 +1,34 @@ +"""add_anntation_history_match_response + +Revision ID: 714aafe25d39 +Revises: f2a6fc85e260 +Create Date: 2023-12-14 06:38:02.972527 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = '714aafe25d39' +down_revision = 'f2a6fc85e260' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: + batch_op.add_column(sa.Column('annotation_question', sa.Text(), nullable=False)) + batch_op.add_column(sa.Column('annotation_content', sa.Text(), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: + batch_op.drop_column('annotation_content') + batch_op.drop_column('annotation_question') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/e1901f623fd0_add_annotation_reply.py b/api/migrations/versions/e1901f623fd0_add_annotation_reply.py new file mode 100644 index 00000000000000..94200de9d4ccd4 --- /dev/null +++ b/api/migrations/versions/e1901f623fd0_add_annotation_reply.py @@ -0,0 +1,79 @@ +"""add-annotation-reply + +Revision ID: e1901f623fd0 +Revises: fca025d3b60f +Create Date: 2023-12-12 06:58:41.054544 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'e1901f623fd0' +down_revision = 'fca025d3b60f' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('app_annotation_hit_histories', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('annotation_id', postgresql.UUID(), nullable=False), + sa.Column('source', sa.Text(), nullable=False), + sa.Column('question', sa.Text(), nullable=False), + sa.Column('account_id', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='app_annotation_hit_histories_pkey') + ) + with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: + batch_op.create_index('app_annotation_hit_histories_account_idx', ['account_id'], unique=False) + batch_op.create_index('app_annotation_hit_histories_annotation_idx', ['annotation_id'], unique=False) + batch_op.create_index('app_annotation_hit_histories_app_idx', ['app_id'], unique=False) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('annotation_reply', sa.Text(), nullable=True)) + + with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: + batch_op.add_column(sa.Column('type', sa.String(length=40), server_default=sa.text("'dataset'::character varying"), nullable=False)) + + with op.batch_alter_table('message_annotations', schema=None) as batch_op: + batch_op.add_column(sa.Column('question', sa.Text(), nullable=True)) + batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False)) + batch_op.alter_column('conversation_id', + existing_type=postgresql.UUID(), + nullable=True) + batch_op.alter_column('message_id', + existing_type=postgresql.UUID(), + nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('message_annotations', schema=None) as batch_op: + batch_op.alter_column('message_id', + existing_type=postgresql.UUID(), + nullable=False) + batch_op.alter_column('conversation_id', + existing_type=postgresql.UUID(), + nullable=False) + batch_op.drop_column('hit_count') + batch_op.drop_column('question') + + with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: + batch_op.drop_column('type') + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_column('annotation_reply') + + with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: + batch_op.drop_index('app_annotation_hit_histories_app_idx') + batch_op.drop_index('app_annotation_hit_histories_annotation_idx') + batch_op.drop_index('app_annotation_hit_histories_account_idx') + + op.drop_table('app_annotation_hit_histories') + # ### end Alembic commands ### diff --git a/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py b/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py new file mode 100644 index 00000000000000..b85a8fd023bdea --- /dev/null +++ b/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py @@ -0,0 +1,34 @@ +"""add_anntation_history_message_id + +Revision ID: f2a6fc85e260 +Revises: 46976cc39132 +Create Date: 2023-12-13 11:09:29.329584 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'f2a6fc85e260' +down_revision = '46976cc39132' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: + batch_op.add_column(sa.Column('message_id', postgresql.UUID(), nullable=False)) + batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: + batch_op.drop_index('app_annotation_hit_histories_message_idx') + batch_op.drop_column('message_id') + + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index a03a5cd38d0d09..a40af353453ab6 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -475,5 +475,6 @@ class DatasetCollectionBinding(db.Model): id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) provider_name = db.Column(db.String(40), nullable=False) model_name = db.Column(db.String(40), nullable=False) + type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) collection_name = db.Column(db.String(64), nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) diff --git a/api/models/model.py b/api/models/model.py index b3570f7f4272d8..1805f0bb339a92 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -2,6 +2,7 @@ from flask import current_app, request from flask_login import UserMixin +from sqlalchemy import Float from sqlalchemy.dialects.postgresql import UUID from core.file.upload_file_parser import UploadFileParser @@ -128,6 +129,25 @@ def retriever_resource_dict(self) -> dict: return json.loads(self.retriever_resource) if self.retriever_resource \ else {"enabled": False} + @property + def annotation_reply_dict(self) -> dict: + annotation_setting = db.session.query(AppAnnotationSetting).filter( + AppAnnotationSetting.app_id == self.app_id).first() + if annotation_setting: + collection_binding_detail = annotation_setting.collection_binding_detail + return { + "id": annotation_setting.id, + "enabled": True, + "score_threshold": annotation_setting.score_threshold, + "embedding_model": { + "embedding_provider_name": collection_binding_detail.provider_name, + "embedding_model_name": collection_binding_detail.model_name + } + } + + else: + return {"enabled": False} + @property def more_like_this_dict(self) -> dict: return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} @@ -170,7 +190,9 @@ def dataset_configs_dict(self) -> dict: @property def file_upload_dict(self) -> dict: - return json.loads(self.file_upload) if self.file_upload else {"image": {"enabled": False, "number_limits": 3, "detail": "high", "transfer_methods": ["remote_url", "local_file"]}} + return json.loads(self.file_upload) if self.file_upload else { + "image": {"enabled": False, "number_limits": 3, "detail": "high", + "transfer_methods": ["remote_url", "local_file"]}} def to_dict(self) -> dict: return { @@ -182,6 +204,7 @@ def to_dict(self) -> dict: "suggested_questions_after_answer": self.suggested_questions_after_answer_dict, "speech_to_text": self.speech_to_text_dict, "retriever_resource": self.retriever_resource_dict, + "annotation_reply": self.annotation_reply_dict, "more_like_this": self.more_like_this_dict, "sensitive_word_avoidance": self.sensitive_word_avoidance_dict, "external_data_tools": self.external_data_tools_list, @@ -504,6 +527,12 @@ def annotation(self): annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == self.id).first() return annotation + @property + def annotation_hit_history(self): + annotation_history = (db.session.query(AppAnnotationHitHistory) + .filter(AppAnnotationHitHistory.message_id == self.id).first()) + return annotation_history + @property def app_model_config(self): conversation = 
db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first() @@ -616,9 +645,11 @@ class MessageAnnotation(db.Model): id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) app_id = db.Column(UUID, nullable=False) - conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False) - message_id = db.Column(UUID, nullable=False) + conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=True) + message_id = db.Column(UUID, nullable=True) + question = db.Column(db.Text, nullable=True) content = db.Column(db.Text, nullable=False) + hit_count = db.Column(db.Integer, nullable=False, server_default=db.text('0')) account_id = db.Column(UUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -629,6 +660,79 @@ def account(self): return account +class AppAnnotationHitHistory(db.Model): + __tablename__ = 'app_annotation_hit_histories' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='app_annotation_hit_histories_pkey'), + db.Index('app_annotation_hit_histories_app_idx', 'app_id'), + db.Index('app_annotation_hit_histories_account_idx', 'account_id'), + db.Index('app_annotation_hit_histories_annotation_idx', 'annotation_id'), + db.Index('app_annotation_hit_histories_message_idx', 'message_id'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(UUID, nullable=False) + annotation_id = db.Column(UUID, nullable=False) + source = db.Column(db.Text, nullable=False) + question = db.Column(db.Text, nullable=False) + account_id = db.Column(UUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + score = db.Column(Float, nullable=False, server_default=db.text('0')) + message_id = db.Column(UUID, nullable=False) + annotation_question = db.Column(db.Text, nullable=False) + annotation_content = db.Column(db.Text, nullable=False) + + @property + def account(self): + account = (db.session.query(Account) + .join(MessageAnnotation, MessageAnnotation.account_id == Account.id) + .filter(MessageAnnotation.id == self.annotation_id).first()) + return account + + @property + def annotation_create_account(self): + account = db.session.query(Account).filter(Account.id == self.account_id).first() + return account + + +class AppAnnotationSetting(db.Model): + __tablename__ = 'app_annotation_settings' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey'), + db.Index('app_annotation_settings_app_idx', 'app_id') + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(UUID, nullable=False) + score_threshold = db.Column(Float, nullable=False, server_default=db.text('0')) + collection_binding_id = db.Column(UUID, nullable=False) + created_user_id = db.Column(UUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_user_id = db.Column(UUID, nullable=False) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + @property + def created_account(self): + account = (db.session.query(Account) + .join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id) + .filter(AppAnnotationSetting.id == self.annotation_id).first()) + return account + + @property + def 
updated_account(self): + account = (db.session.query(Account) + .join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id) + .filter(AppAnnotationSetting.id == self.annotation_id).first()) + return account + + @property + def collection_binding_detail(self): + from .dataset import DatasetCollectionBinding + collection_binding_detail = (db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == self.collection_binding_id).first()) + return collection_binding_detail + + class OperationLog(db.Model): __tablename__ = 'operation_logs' __table_args__ = ( diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py new file mode 100644 index 00000000000000..b84e87bf44783c --- /dev/null +++ b/api/services/annotation_service.py @@ -0,0 +1,426 @@ +import datetime +import json +import uuid + +import pandas as pd +from flask_login import current_user +from sqlalchemy import or_ +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import NotFound + +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.model import MessageAnnotation, Message, App, AppAnnotationHitHistory, AppAnnotationSetting +from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task +from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task +from tasks.annotation.disable_annotation_reply_task import disable_annotation_reply_task +from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task +from tasks.annotation.delete_annotation_index_task import delete_annotation_index_task +from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task + + +class AppAnnotationService: + @classmethod + def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: + # get app info + app = db.session.query(App).filter( + App.id == app_id, + App.tenant_id == current_user.current_tenant_id, + App.status == 'normal' + ).first() + + if not app: + raise NotFound("App not found") + if 'message_id' in args and args['message_id']: + message_id = str(args['message_id']) + # get message info + message = db.session.query(Message).filter( + Message.id == message_id, + Message.app_id == app.id + ).first() + + if not message: + raise NotFound("Message Not Exists.") + + annotation = message.annotation + # save the message annotation + if annotation: + annotation.content = args['answer'] + annotation.question = args['question'] + else: + annotation = MessageAnnotation( + app_id=app.id, + conversation_id=message.conversation_id, + message_id=message.id, + content=args['answer'], + question=args['question'], + account_id=current_user.id + ) + else: + annotation = MessageAnnotation( + app_id=app.id, + content=args['answer'], + question=args['question'], + account_id=current_user.id + ) + db.session.add(annotation) + db.session.commit() + # if annotation reply is enabled , add annotation to index + annotation_setting = db.session.query(AppAnnotationSetting).filter( + AppAnnotationSetting.app_id == app_id).first() + if annotation_setting: + add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id, + app_id, annotation_setting.collection_binding_id) + return annotation + + @classmethod + def enable_app_annotation(cls, args: dict, app_id: str) -> dict: + enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id)) + cache_result = 
redis_client.get(enable_app_annotation_key) + if cache_result is not None: + return { + 'job_id': cache_result, + 'job_status': 'processing' + } + + # async job + job_id = str(uuid.uuid4()) + enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id)) + # send batch add segments task + redis_client.setnx(enable_app_annotation_job_key, 'waiting') + enable_annotation_reply_task.delay(str(job_id), app_id, current_user.id, current_user.current_tenant_id, + args['score_threshold'], + args['embedding_provider_name'], args['embedding_model_name']) + return { + 'job_id': job_id, + 'job_status': 'waiting' + } + + @classmethod + def disable_app_annotation(cls, app_id: str) -> dict: + disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id)) + cache_result = redis_client.get(disable_app_annotation_key) + if cache_result is not None: + return { + 'job_id': cache_result, + 'job_status': 'processing' + } + + # async job + job_id = str(uuid.uuid4()) + disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id)) + # send batch add segments task + redis_client.setnx(disable_app_annotation_job_key, 'waiting') + disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id) + return { + 'job_id': job_id, + 'job_status': 'waiting' + } + + @classmethod + def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): + # get app info + app = db.session.query(App).filter( + App.id == app_id, + App.tenant_id == current_user.current_tenant_id, + App.status == 'normal' + ).first() + + if not app: + raise NotFound("App not found") + if keyword: + annotations = (db.session.query(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) + .filter( + or_( + MessageAnnotation.question.ilike('%{}%'.format(keyword)), + MessageAnnotation.content.ilike('%{}%'.format(keyword)) + ) + ) + .order_by(MessageAnnotation.created_at.desc()) + .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) + else: + annotations = (db.session.query(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) + .order_by(MessageAnnotation.created_at.desc()) + .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) + return annotations.items, annotations.total + + @classmethod + def export_annotation_list_by_app_id(cls, app_id: str): + # get app info + app = db.session.query(App).filter( + App.id == app_id, + App.tenant_id == current_user.current_tenant_id, + App.status == 'normal' + ).first() + + if not app: + raise NotFound("App not found") + annotations = (db.session.query(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) + .order_by(MessageAnnotation.created_at.desc()).all()) + return annotations + + @classmethod + def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: + # get app info + app = db.session.query(App).filter( + App.id == app_id, + App.tenant_id == current_user.current_tenant_id, + App.status == 'normal' + ).first() + + if not app: + raise NotFound("App not found") + + annotation = MessageAnnotation( + app_id=app.id, + content=args['answer'], + question=args['question'], + account_id=current_user.id + ) + db.session.add(annotation) + db.session.commit() + # if annotation reply is enabled , add annotation to index + annotation_setting = db.session.query(AppAnnotationSetting).filter( + AppAnnotationSetting.app_id == app_id).first() + if annotation_setting: + add_annotation_to_index_task.delay(annotation.id, 
args['question'], current_user.current_tenant_id, + app_id, annotation_setting.collection_binding_id) + return annotation + + @classmethod + def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): + # get app info + app = db.session.query(App).filter( + App.id == app_id, + App.tenant_id == current_user.current_tenant_id, + App.status == 'normal' + ).first() + + if not app: + raise NotFound("App not found") + + annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + + if not annotation: + raise NotFound("Annotation not found") + + annotation.content = args['answer'] + annotation.question = args['question'] + + db.session.commit() + # if annotation reply is enabled , add annotation to index + app_annotation_setting = db.session.query(AppAnnotationSetting).filter( + AppAnnotationSetting.app_id == app_id + ).first() + + if app_annotation_setting: + update_annotation_to_index_task.delay(annotation.id, annotation.question, + current_user.current_tenant_id, + app_id, app_annotation_setting.collection_binding_id) + + return annotation + + @classmethod + def delete_app_annotation(cls, app_id: str, annotation_id: str): + # get app info + app = db.session.query(App).filter( + App.id == app_id, + App.tenant_id == current_user.current_tenant_id, + App.status == 'normal' + ).first() + + if not app: + raise NotFound("App not found") + + annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + + if not annotation: + raise NotFound("Annotation not found") + + db.session.delete(annotation) + + annotation_hit_histories = (db.session.query(AppAnnotationHitHistory) + .filter(AppAnnotationHitHistory.annotation_id == annotation_id) + .all() + ) + if annotation_hit_histories: + for annotation_hit_history in annotation_hit_histories: + db.session.delete(annotation_hit_history) + + db.session.commit() + # if annotation reply is enabled , delete annotation index + app_annotation_setting = db.session.query(AppAnnotationSetting).filter( + AppAnnotationSetting.app_id == app_id + ).first() + + if app_annotation_setting: + delete_annotation_index_task.delay(annotation.id, app_id, + current_user.current_tenant_id, + app_annotation_setting.collection_binding_id) + + @classmethod + def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict: + # get app info + app = db.session.query(App).filter( + App.id == app_id, + App.tenant_id == current_user.current_tenant_id, + App.status == 'normal' + ).first() + + if not app: + raise NotFound("App not found") + + try: + # Skip the first row + df = pd.read_csv(file) + result = [] + for index, row in df.iterrows(): + content = { + 'question': row[0], + 'answer': row[1] + } + result.append(content) + if len(result) == 0: + raise ValueError("The CSV file is empty.") + # async job + job_id = str(uuid.uuid4()) + indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) + # send batch add segments task + redis_client.setnx(indexing_cache_key, 'waiting') + batch_import_annotations_task.delay(str(job_id), result, app_id, + current_user.current_tenant_id, current_user.id) + except Exception as e: + return { + 'error_msg': str(e) + } + return { + 'job_id': job_id, + 'job_status': 'waiting' + } + + @classmethod + def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): + # get app info + app = db.session.query(App).filter( + App.id == app_id, + App.tenant_id == current_user.current_tenant_id, + App.status == 
'normal' + ).first() + + if not app: + raise NotFound("App not found") + + annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + + if not annotation: + raise NotFound("Annotation not found") + + annotation_hit_histories = (db.session.query(AppAnnotationHitHistory) + .filter(AppAnnotationHitHistory.app_id == app_id, + AppAnnotationHitHistory.annotation_id == annotation_id, + ) + .order_by(AppAnnotationHitHistory.created_at.desc()) + .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) + return annotation_hit_histories.items, annotation_hit_histories.total + + @classmethod + def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None: + annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + + if not annotation: + return None + return annotation + + @classmethod + def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_question: str, + annotation_content: str, query: str, user_id: str, + message_id: str, from_source: str, score: float): + # add hit count to annotation + db.session.query(MessageAnnotation).filter( + MessageAnnotation.id == annotation_id + ).update( + {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, + synchronize_session=False + ) + + annotation_hit_history = AppAnnotationHitHistory( + annotation_id=annotation_id, + app_id=app_id, + account_id=user_id, + question=query, + source=from_source, + score=score, + message_id=message_id, + annotation_question=annotation_question, + annotation_content=annotation_content + ) + db.session.add(annotation_hit_history) + db.session.commit() + + @classmethod + def get_app_annotation_setting_by_app_id(cls, app_id: str): + # get app info + app = db.session.query(App).filter( + App.id == app_id, + App.tenant_id == current_user.current_tenant_id, + App.status == 'normal' + ).first() + + if not app: + raise NotFound("App not found") + + annotation_setting = db.session.query(AppAnnotationSetting).filter( + AppAnnotationSetting.app_id == app_id).first() + if annotation_setting: + collection_binding_detail = annotation_setting.collection_binding_detail + return { + "id": annotation_setting.id, + "enabled": True, + "score_threshold": annotation_setting.score_threshold, + "embedding_model": { + "embedding_provider_name": collection_binding_detail.provider_name, + "embedding_model_name": collection_binding_detail.model_name + } + } + return { + "enabled": False + } + + @classmethod + def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict): + # get app info + app = db.session.query(App).filter( + App.id == app_id, + App.tenant_id == current_user.current_tenant_id, + App.status == 'normal' + ).first() + + if not app: + raise NotFound("App not found") + + annotation_setting = db.session.query(AppAnnotationSetting).filter( + AppAnnotationSetting.app_id == app_id, + AppAnnotationSetting.id == annotation_setting_id, + ).first() + if not annotation_setting: + raise NotFound("App annotation not found") + annotation_setting.score_threshold = args['score_threshold'] + annotation_setting.updated_user_id = current_user.id + annotation_setting.updated_at = datetime.datetime.utcnow() + db.session.add(annotation_setting) + db.session.commit() + + collection_binding_detail = annotation_setting.collection_binding_detail + + return { + "id": annotation_setting.id, + "enabled": True, + "score_threshold": annotation_setting.score_threshold, + "embedding_model": { + 
"embedding_provider_name": collection_binding_detail.provider_name, + "embedding_model_name": collection_binding_detail.model_name + } + } diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 3ffd8b0431e9ae..97a96aab993692 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -138,7 +138,22 @@ def validate_configuration(cls, tenant_id: str, account: Account, config: dict, config["retriever_resource"]["enabled"] = False if not isinstance(config["retriever_resource"]["enabled"], bool): - raise ValueError("enabled in speech_to_text must be of boolean type") + raise ValueError("enabled in retriever_resource must be of boolean type") + + # annotation reply + if 'annotation_reply' not in config or not config["annotation_reply"]: + config["annotation_reply"] = { + "enabled": False + } + + if not isinstance(config["annotation_reply"], dict): + raise ValueError("annotation_reply must be of dict type") + + if "enabled" not in config["annotation_reply"] or not config["annotation_reply"]["enabled"]: + config["annotation_reply"]["enabled"] = False + + if not isinstance(config["annotation_reply"]["enabled"], bool): + raise ValueError("enabled in annotation_reply must be of boolean type") # more_like_this if 'more_like_this' not in config or not config["more_like_this"]: @@ -325,6 +340,7 @@ def validate_configuration(cls, tenant_id: str, account: Account, config: dict, "suggested_questions_after_answer": config["suggested_questions_after_answer"], "speech_to_text": config["speech_to_text"], "retriever_resource": config["retriever_resource"], + "annotation_reply": config["annotation_reply"], "more_like_this": config["more_like_this"], "sensitive_word_avoidance": config["sensitive_word_avoidance"], "external_data_tools": config["external_data_tools"], diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 249766236f61c2..977acd9c9a6d3c 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -165,7 +165,8 @@ def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any, 'streaming': streaming, 'is_model_config_override': is_model_config_override, 'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev', - 'auto_generate_name': auto_generate_name + 'auto_generate_name': auto_generate_name, + 'from_source': from_source }) generate_worker_thread.start() @@ -193,7 +194,7 @@ def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_m query: str, inputs: dict, files: List[PromptMessageFile], detached_user: Union[Account, EndUser], detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool, - retriever_from: str = 'dev', auto_generate_name: bool = True): + retriever_from: str = 'dev', auto_generate_name: bool = True, from_source: str = 'console'): with flask_app.app_context(): # fixed the state of the model object when it detached from the original session user = db.session.merge(detached_user) @@ -218,7 +219,8 @@ def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_m streaming=streaming, is_override=is_model_config_override, retriever_from=retriever_from, - auto_generate_name=auto_generate_name + auto_generate_name=auto_generate_name, + from_source=from_source ) except (ConversationTaskInterruptException, ConversationTaskStoppedException): pass @@ -385,6 +387,9 @@ def compact_response(cls, pubsub: PubSub, streaming: bool = 
False) -> Union[dict result = json.loads(result) if result.get('error'): cls.handle_error(result) + if result['event'] == 'annotation' and 'data' in result: + message_result['annotation'] = result.get('data') + return cls.get_blocking_annotation_message_response_data(message_result) if result['event'] == 'message' and 'data' in result: message_result['message'] = result.get('data') if result['event'] == 'message_end' and 'data' in result: @@ -427,6 +432,9 @@ def generate() -> Generator: elif event == 'agent_thought': yield "data: " + json.dumps( cls.get_agent_thought_response_data(result.get('data'))) + "\n\n" + elif event == 'annotation': + yield "data: " + json.dumps( + cls.get_annotation_response_data(result.get('data'))) + "\n\n" elif event == 'message_end': yield "data: " + json.dumps( cls.get_message_end_data(result.get('data'))) + "\n\n" @@ -499,6 +507,25 @@ def get_blocking_message_response_data(cls, data: dict): return response_data + @classmethod + def get_blocking_annotation_message_response_data(cls, data: dict): + message = data.get('annotation') + response_data = { + 'event': 'annotation', + 'task_id': message.get('task_id'), + 'id': message.get('message_id'), + 'answer': message.get('text'), + 'metadata': {}, + 'created_at': int(time.time()), + 'annotation_id': message.get('annotation_id'), + 'annotation_author_name': message.get('annotation_author_name') + } + + if message.get('mode') == 'chat': + response_data['conversation_id'] = message.get('conversation_id') + + return response_data + @classmethod def get_message_end_data(cls, data: dict): response_data = { @@ -551,6 +578,23 @@ def get_agent_thought_response_data(cls, data: dict): return response_data + @classmethod + def get_annotation_response_data(cls, data: dict): + response_data = { + 'event': 'annotation', + 'task_id': data.get('task_id'), + 'id': data.get('message_id'), + 'answer': data.get('text'), + 'created_at': int(time.time()), + 'annotation_id': data.get('annotation_id'), + 'annotation_author_name': data.get('annotation_author_name'), + } + + if data.get('mode') == 'chat': + response_data['conversation_id'] = data.get('conversation_id') + + return response_data + @classmethod def handle_error(cls, result: dict): logging.debug("error: %s", result) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index f6b57321b17481..92740609a6a477 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -33,10 +33,7 @@ from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task from tasks.document_indexing_task import document_indexing_task from tasks.document_indexing_update_task import document_indexing_update_task -from tasks.create_segment_to_index_task import create_segment_to_index_task -from tasks.update_segment_index_task import update_segment_index_task from tasks.recover_document_indexing_task import recover_document_indexing_task -from tasks.update_segment_keyword_index_task import update_segment_keyword_index_task from tasks.delete_segment_from_index_task import delete_segment_from_index_task @@ -1175,10 +1172,12 @@ def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: D class DatasetCollectionBindingService: @classmethod - def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> DatasetCollectionBinding: + def get_dataset_collection_binding(cls, provider_name: str, model_name: str, + collection_type: str = 'dataset') -> DatasetCollectionBinding: dataset_collection_binding = 
db.session.query(DatasetCollectionBinding). \ filter(DatasetCollectionBinding.provider_name == provider_name, - DatasetCollectionBinding.model_name == model_name). \ + DatasetCollectionBinding.model_name == model_name, + DatasetCollectionBinding.type == collection_type). \ order_by(DatasetCollectionBinding.created_at). \ first() @@ -1186,8 +1185,20 @@ def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> dataset_collection_binding = DatasetCollectionBinding( provider_name=provider_name, model_name=model_name, - collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node' + collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node', + type=collection_type ) db.session.add(dataset_collection_binding) - db.session.flush() + db.session.commit() + return dataset_collection_binding + + @classmethod + def get_dataset_collection_binding_by_id_and_type(cls, collection_binding_id: str, + collection_type: str = 'dataset') -> DatasetCollectionBinding: + dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ + filter(DatasetCollectionBinding.id == collection_binding_id, + DatasetCollectionBinding.type == collection_type). \ + order_by(DatasetCollectionBinding.created_at). \ + first() + return dataset_collection_binding diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py new file mode 100644 index 00000000000000..84d94f39cacbd8 --- /dev/null +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -0,0 +1,59 @@ +import logging +import time + +import click +from celery import shared_task +from langchain.schema import Document + +from core.index.index import IndexBuilder + +from models.dataset import Dataset +from services.dataset_service import DatasetCollectionBindingService + + +@shared_task(queue='dataset') +def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str, + collection_binding_id: str): + """ + Add annotation to index. 
+ :param annotation_id: annotation id + :param question: question + :param tenant_id: tenant id + :param app_id: app id + :param collection_binding_id: embedding binding id + + Usage: add_annotation_to_index_task.delay(annotation_id, question, tenant_id, app_id, collection_binding_id) + """ + logging.info(click.style('Start build index for annotation: {}'.format(annotation_id), fg='green')) + start_at = time.perf_counter() + + try: + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + collection_binding_id, + 'annotation' + ) + dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + indexing_technique='high_quality', + collection_binding_id=dataset_collection_binding.id + ) + + document = Document( + page_content=question, + metadata={ + "annotation_id": annotation_id, + "app_id": app_id, + "doc_id": annotation_id + } + ) + index = IndexBuilder.get_index(dataset, 'high_quality') + if index: + index.add_texts([document]) + end_at = time.perf_counter() + logging.info( + click.style( + 'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at), + fg='green')) + except Exception: + logging.exception("Build index for annotation failed") diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py new file mode 100644 index 00000000000000..8c908bf97e6b16 --- /dev/null +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -0,0 +1,99 @@ +import json +import logging +import time + +import click +from celery import shared_task +from langchain.schema import Document +from werkzeug.exceptions import NotFound + +from core.index.index import IndexBuilder +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset +from models.model import MessageAnnotation, App, AppAnnotationSetting +from services.dataset_service import DatasetCollectionBindingService + + +@shared_task(queue='dataset') +def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str, + user_id: str): + """ + Batch import annotations and build their index. 
+ :param job_id: job_id + :param content_list: content list + :param tenant_id: tenant id + :param app_id: app id + :param user_id: user_id + + """ + logging.info(click.style('Start batch import annotation: {}'.format(job_id), fg='green')) + start_at = time.perf_counter() + indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) + # get app info + app = db.session.query(App).filter( + App.id == app_id, + App.tenant_id == tenant_id, + App.status == 'normal' + ).first() + + if app: + try: + documents = [] + for content in content_list: + annotation = MessageAnnotation( + app_id=app.id, + content=content['answer'], + question=content['question'], + account_id=user_id + ) + db.session.add(annotation) + db.session.flush() + + document = Document( + page_content=content['question'], + metadata={ + "annotation_id": annotation.id, + "app_id": app_id, + "doc_id": annotation.id + } + ) + documents.append(document) + # if annotation reply is enabled , batch add annotations' index + app_annotation_setting = db.session.query(AppAnnotationSetting).filter( + AppAnnotationSetting.app_id == app_id + ).first() + + if app_annotation_setting: + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + app_annotation_setting.collection_binding_id, + 'annotation' + ) + if not dataset_collection_binding: + raise NotFound("App annotation setting not found") + dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + indexing_technique='high_quality', + embedding_model_provider=dataset_collection_binding.provider_name, + embedding_model=dataset_collection_binding.model_name, + collection_binding_id=dataset_collection_binding.id + ) + + index = IndexBuilder.get_index(dataset, 'high_quality') + if index: + index.add_texts(documents) + + db.session.commit() + redis_client.setex(indexing_cache_key, 600, 'completed') + end_at = time.perf_counter() + logging.info( + click.style( + 'Build index successful for batch import annotation: {} latency: {}'.format(job_id, end_at - start_at), + fg='green')) + except Exception as e: + db.session.rollback() + redis_client.setex(indexing_cache_key, 600, 'error') + indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id)) + redis_client.setex(indexing_error_msg_key, 600, str(e)) + logging.exception("Build index for batch import annotations failed") diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py new file mode 100644 index 00000000000000..f4afb2383f6cbc --- /dev/null +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -0,0 +1,45 @@ +import datetime +import logging +import time + +import click +from celery import shared_task +from core.index.index import IndexBuilder +from models.dataset import Dataset +from services.dataset_service import DatasetCollectionBindingService + + +@shared_task(queue='dataset') +def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str, + collection_binding_id: str): + """ + Async delete annotation index task + """ + logging.info(click.style('Start delete app annotation index: {}'.format(app_id), fg='green')) + start_at = time.perf_counter() + try: + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + collection_binding_id, + 'annotation' + ) + + dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + indexing_technique='high_quality', + collection_binding_id=dataset_collection_binding.id + ) + + 
vector_index = IndexBuilder.get_default_high_quality_index(dataset) + if vector_index: + try: + vector_index.delete_by_metadata_field('annotation_id', annotation_id) + except Exception: + logging.exception("Delete annotation index failed when annotation deleted.") + end_at = time.perf_counter() + logging.info( + click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at), + fg='green')) + except Exception as e: + logging.exception("Annotation deleted index failed:{}".format(str(e))) + diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py new file mode 100644 index 00000000000000..ee665fbb92adec --- /dev/null +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -0,0 +1,74 @@ +import datetime +import logging +import time + +import click +from celery import shared_task +from werkzeug.exceptions import NotFound + +from core.index.index import IndexBuilder +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset +from models.model import MessageAnnotation, App, AppAnnotationSetting + + +@shared_task(queue='dataset') +def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): + """ + Async disable annotation reply task + """ + logging.info(click.style('Start delete app annotations index: {}'.format(app_id), fg='green')) + start_at = time.perf_counter() + # get app info + app = db.session.query(App).filter( + App.id == app_id, + App.tenant_id == tenant_id, + App.status == 'normal' + ).first() + + if not app: + raise NotFound("App not found") + + app_annotation_setting = db.session.query(AppAnnotationSetting).filter( + AppAnnotationSetting.app_id == app_id + ).first() + + if not app_annotation_setting: + raise NotFound("App annotation setting not found") + + disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id)) + disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id)) + + try: + + dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + indexing_technique='high_quality', + collection_binding_id=app_annotation_setting.collection_binding_id + ) + + vector_index = IndexBuilder.get_default_high_quality_index(dataset) + if vector_index: + try: + vector_index.delete_by_metadata_field('app_id', app_id) + except Exception: + logging.exception("Delete annotation index failed when annotation reply was disabled.") + redis_client.setex(disable_app_annotation_job_key, 600, 'completed') + + # delete annotation setting + db.session.delete(app_annotation_setting) + db.session.commit() + + end_at = time.perf_counter() + logging.info( + click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at), + fg='green')) + except Exception as e: + logging.exception("Annotation batch deleted index failed:{}".format(str(e))) + redis_client.setex(disable_app_annotation_job_key, 600, 'error') + disable_app_annotation_error_key = 'disable_app_annotation_error_{}'.format(str(job_id)) + redis_client.setex(disable_app_annotation_error_key, 600, str(e)) + finally: + redis_client.delete(disable_app_annotation_key) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py new file mode 100644 index 00000000000000..5a68b9285c7227 --- /dev/null +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -0,0 +1,106 @@ +import datetime +import logging +import time + +import click +from celery import shared_task +from 
langchain.schema import Document +from werkzeug.exceptions import NotFound + +from core.index.index import IndexBuilder +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset +from models.model import MessageAnnotation, App, AppAnnotationSetting +from services.dataset_service import DatasetCollectionBindingService + + +@shared_task(queue='dataset') +def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_id: str, score_threshold: float, + embedding_provider_name: str, embedding_model_name: str): + """ + Async enable annotation reply task + """ + logging.info(click.style('Start add app annotation to index: {}'.format(app_id), fg='green')) + start_at = time.perf_counter() + # get app info + app = db.session.query(App).filter( + App.id == app_id, + App.tenant_id == tenant_id, + App.status == 'normal' + ).first() + + if not app: + raise NotFound("App not found") + + annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all() + enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id)) + enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id)) + + try: + documents = [] + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_provider_name, + embedding_model_name, + 'annotation' + ) + annotation_setting = db.session.query(AppAnnotationSetting).filter( + AppAnnotationSetting.app_id == app_id).first() + if annotation_setting: + annotation_setting.score_threshold = score_threshold + annotation_setting.collection_binding_id = dataset_collection_binding.id + annotation_setting.updated_user_id = user_id + annotation_setting.updated_at = datetime.datetime.utcnow() + db.session.add(annotation_setting) + else: + new_app_annotation_setting = AppAnnotationSetting( + app_id=app_id, + score_threshold=score_threshold, + collection_binding_id=dataset_collection_binding.id, + created_user_id=user_id, + updated_user_id=user_id + ) + db.session.add(new_app_annotation_setting) + + dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + indexing_technique='high_quality', + embedding_model_provider=embedding_provider_name, + embedding_model=embedding_model_name, + collection_binding_id=dataset_collection_binding.id + ) + if annotations: + for annotation in annotations: + document = Document( + page_content=annotation.question, + metadata={ + "annotation_id": annotation.id, + "app_id": app_id, + "doc_id": annotation.id + } + ) + documents.append(document) + index = IndexBuilder.get_index(dataset, 'high_quality') + if index: + try: + index.delete_by_metadata_field('app_id', app_id) + except Exception as e: + logging.info( + click.style('Delete annotation index error: {}'.format(str(e)), + fg='red')) + index.add_texts(documents) + db.session.commit() + redis_client.setex(enable_app_annotation_job_key, 600, 'completed') + end_at = time.perf_counter() + logging.info( + click.style('App annotations added to index: {} latency: {}'.format(app_id, end_at - start_at), + fg='green')) + except Exception as e: + logging.exception("Annotation batch created index failed:{}".format(str(e))) + redis_client.setex(enable_app_annotation_job_key, 600, 'error') + enable_app_annotation_error_key = 'enable_app_annotation_error_{}'.format(str(job_id)) + redis_client.setex(enable_app_annotation_error_key, 600, str(e)) + db.session.rollback() + finally: + redis_client.delete(enable_app_annotation_key) diff --git 
a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py new file mode 100644 index 00000000000000..e477b8c2c8f4c1 --- /dev/null +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -0,0 +1,63 @@ +import logging +import time + +import click +from celery import shared_task +from langchain.schema import Document + +from core.index.index import IndexBuilder + +from models.dataset import Dataset +from services.dataset_service import DatasetCollectionBindingService + + +@shared_task(queue='dataset') +def update_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str, + collection_binding_id: str): + """ + Update annotation to index. + :param annotation_id: annotation id + :param question: question + :param tenant_id: tenant id + :param app_id: app id + :param collection_binding_id: embedding binding id + + Usage: update_annotation_to_index_task.delay(annotation_id, question, tenant_id, app_id, collection_binding_id) + """ + logging.info(click.style('Start update index for annotation: {}'.format(annotation_id), fg='green')) + start_at = time.perf_counter() + + try: + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + collection_binding_id, + 'annotation' + ) + + dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + indexing_technique='high_quality', + embedding_model_provider=dataset_collection_binding.provider_name, + embedding_model=dataset_collection_binding.model_name, + collection_binding_id=dataset_collection_binding.id + ) + + document = Document( + page_content=question, + metadata={ + "annotation_id": annotation_id, + "app_id": app_id, + "doc_id": annotation_id + } + ) + index = IndexBuilder.get_index(dataset, 'high_quality') + if index: + index.delete_by_metadata_field('annotation_id', annotation_id) + index.add_texts([document]) + end_at = time.perf_counter() + logging.info( + click.style( + 'Update index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at), + fg='green')) + except Exception: + logging.exception("Update index for annotation failed")
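
Note on the new 'annotation_reply' config block: the sketch below mirrors the shape returned by AppModelConfig.annotation_reply_dict and by AppAnnotationService.get_app_annotation_setting_by_app_id in this patch; the id, provider and model values are made-up placeholders, not values from the change.

# Sketch of the 'annotation_reply' section when an AppAnnotationSetting row exists
# (placeholder values; the real ids come from the database):
annotation_reply = {
    "id": "annotation-setting-id",                          # AppAnnotationSetting.id
    "enabled": True,
    "score_threshold": 0.8,                                 # AppAnnotationSetting.score_threshold
    "embedding_model": {
        "embedding_provider_name": "openai",                # DatasetCollectionBinding.provider_name
        "embedding_model_name": "text-embedding-ada-002",   # DatasetCollectionBinding.model_name
    },
}
# When no setting exists, both code paths return {"enabled": False}, which is also the
# default that app_model_config_service now injects when the config omits the block.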
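
On the streaming path, completion_service now forwards an 'annotation' event instead of a normal completion when an annotation reply fires. A sketch of the payload assembled by get_annotation_response_data, using placeholder values, looks like this:

# Illustrative 'annotation' event as built by get_annotation_response_data
# (placeholder values; 'conversation_id' is only added in chat mode):
annotation_event = {
    "event": "annotation",
    "task_id": "task-id",
    "id": "message-id",
    "answer": "Reply text taken from the matched annotation",
    "created_at": 1700000000,
    "annotation_id": "annotation-id",
    "annotation_author_name": "annotation-author",
    "conversation_id": "conversation-id",
}
# On the SSE stream it is emitted as a single 'data: <json>' line followed by a blank line,
# exactly like the existing 'message' and 'agent_thought' events.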
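
The enable/disable/batch-import flows are fire-and-forget Celery jobs whose progress is tracked only through Redis keys ('waiting' on submit, then 'completed' or 'error', plus a separate error-message key). A minimal polling sketch is shown below; poll_annotation_job is a hypothetical helper, not part of this patch, and it assumes it runs where extensions.ext_redis.redis_client is initialised.

from extensions.ext_redis import redis_client


def poll_annotation_job(action: str, job_id: str) -> dict:
    # Hypothetical helper: read the status key written by enable_annotation_reply_task /
    # disable_annotation_reply_task ('<action>_app_annotation_job_{job_id}').
    status = redis_client.get('{}_app_annotation_job_{}'.format(action, job_id))
    if status is None:
        return {'job_id': job_id, 'job_status': 'not_found'}
    status = status.decode() if isinstance(status, bytes) else status
    result = {'job_id': job_id, 'job_status': status}
    if status == 'error':
        # The tasks store the exception text under a matching error key.
        error_msg = redis_client.get('{}_app_annotation_error_{}'.format(action, job_id))
        if error_msg is not None:
            result['error_msg'] = error_msg.decode() if isinstance(error_msg, bytes) else error_msg
    return result

poll_annotation_job('enable', job_id) and poll_annotation_job('disable', job_id) mirror the key names used in the tasks above; the batch-import task uses 'app_annotation_batch_import_{job_id}' and 'app_annotation_batch_import_error_msg_{job_id}' instead.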
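
batch_import_app_annotations reads the upload with pandas, treats the first row as a header, and takes column 0 as the question and column 1 as the answer. An illustrative call, with made-up file content, could look like:

import io

from werkzeug.datastructures import FileStorage

# Made-up CSV content; only the column positions matter, the header text is arbitrary.
csv_content = (
    "question,answer\n"
    "How do I reset my password?,Open Settings > Account and choose 'Reset password'.\n"
    "What is the refund policy?,Refunds are available within 14 days of purchase.\n"
)
file = FileStorage(stream=io.BytesIO(csv_content.encode('utf-8')), filename='annotations.csv')
# result = AppAnnotationService.batch_import_app_annotations(app_id, file)
# -> {'job_id': '<uuid4>', 'job_status': 'waiting'} on success, or {'error_msg': '...'} on failure.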