Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/dataset service api #1245

Merged
merged 16 commits into from
Sep 27, 2023
Prev Previous commit
Next Next commit
dataset service api
  • Loading branch information
JohnJyong committed Sep 19, 2023
commit 55d140aef06ec3476dfaf4c94f7b1dfdf778ff58
160 changes: 160 additions & 0 deletions api/controllers/service_api/dataset/segment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import datetime
import uuid

from flask import current_app, request
from flask_login import current_user
from flask_restful import reqparse, marshal, fields
from werkzeug.exceptions import NotFound

import services.dataset_service
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
DatasetNotInitedError
from controllers.service_api.wraps import DatasetApiResource
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.helper import TimestampField
from models.dataset import Dataset
from models.model import UploadFile
from services.dataset_service import DocumentService, SegmentService
from services.file_service import FileService



class SegmentApi(DatasetApiResource):
"""Resource for segments."""

def post(self, document_id, dataset):
"""Create single segment."""
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound('Document not found.')
# check embedding model setting
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
parser = reqparse.RequestParser()
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.create_segment(args, document, dataset)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
}, 200


class DocumentAddByFileApi(DatasetApiResource):
"""Resource for documents."""

def post(self, dataset):
"""Create document by upload file."""
parser = reqparse.RequestParser()
parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
parser.add_argument('original_document_id', type=str, required=False, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
location='json')
parser.add_argument('doc_type', type=str, required=False, nullable=True, location='json')
parser.add_argument('doc_metadata', type=str, required=False, nullable=True, location='json')
args = parser.parse_args()

if not dataset.indexing_technique and not args['indexing_technique']:
raise ValueError('indexing_technique is required.')

# validate args
DocumentService.document_create_args_validate(args)

doc_type = args.get('doc_type')
doc_metadata = args.get('doc_metadata')

if doc_type and doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA:
raise ValueError('Invalid doc_type.')
# save file info
file = request.files['file']
upload_file = FileService.upload_file(file)
data_source = {
'type': 'upload_file',
'info_list': {
'file_info_list': {
'file_ids': [upload_file.id]
}
}
}
args['data_source'] = data_source

try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule,
created_from='api'
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
if doc_type and doc_metadata:
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]

document.doc_metadata = {}

for key, value_type in metadata_schema.items():
value = doc_metadata.get(key)
if value is not None and isinstance(value, value_type):
document.doc_metadata[key] = value

document.doc_type = doc_type
document.updated_at = datetime.datetime.utcnow()
db.session.commit()

return {'id': document.id}


class DocumentApi(DatasetApiResource):
def delete(self, dataset, document_id):
"""Delete document."""
document_id = str(document_id)

document = DocumentService.get_document(dataset.id, document_id)

# 404 if document not found
if document is None:
raise NotFound("Document Not Exists.")

# 403 if document is archived
if DocumentService.check_archived(document):
raise ArchivedDocumentImmutableError()

try:
# delete document
DocumentService.delete_document(document)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError('Cannot delete document during indexing.')

return {'result': 'success'}, 204


api.add_resource(DocumentAddByTextApi, '/text/documents')
api.add_resource(DocumentAddByFileApi, '/file/documents')
api.add_resource(DocumentApi, '/documents/<uuid:document_id>')
Empty file added api/fields/__init__.py
Empty file.
137 changes: 137 additions & 0 deletions api/fields/app_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from flask_restful import fields

from libs.helper import TimestampField

app_detail_kernel_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
}

related_app_list = {
'data': fields.List(fields.Nested(app_detail_kernel_fields)),
'total': fields.Integer,
}

model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
'model': fields.Raw(attribute='model_dict'),
'user_input_form': fields.Raw(attribute='user_input_form_list'),
'pre_prompt': fields.String,
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
}

app_detail_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'api_rpm': fields.Integer,
'api_rph': fields.Integer,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
'created_at': TimestampField
}

prompt_config_fields = {
'prompt_template': fields.String,
}

model_config_partial_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}

app_partial_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'),
'created_at': TimestampField
}

app_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(app_partial_fields), attribute='items')
}

template_fields = {
'name': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'mode': fields.String,
'model_config': fields.Nested(model_config_fields),
}

template_list_fields = {
'data': fields.List(fields.Nested(template_fields)),
}

site_fields = {
'access_token': fields.String(attribute='code'),
'code': fields.String,
'title': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'default_language': fields.String,
'customize_domain': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean,
'app_base_url': fields.String,
}

app_detail_fields_with_site = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'api_rpm': fields.Integer,
'api_rph': fields.Integer,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
'site': fields.Nested(site_fields),
'api_base_url': fields.String,
'created_at': TimestampField
}

app_site_fields = {
'app_id': fields.String,
'access_token': fields.String(attribute='code'),
'code': fields.String,
'title': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'default_language': fields.String,
'customize_domain': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean
}
Loading