From f87e7242cd28ca59fb71f88d190a1f29fd2c0311 Mon Sep 17 00:00:00 2001 From: LiuHua <10215101452@stu.ecnu.edu.cn> Date: Thu, 29 Aug 2024 14:31:31 +0800 Subject: [PATCH] complete implementation of dataset SDK (#2147) ### What problem does this PR solve? Complete implementation of dataset SDK. #1102 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Feiue <10215101452@stu.ecun.edu.cn> Co-authored-by: Kevin Hu --- api/apps/sdk/dataset.py | 162 +++++++++++++++++++------- api/utils/api_utils.py | 47 +++++--- sdk/python/ragflow/modules/base.py | 12 +- sdk/python/ragflow/modules/dataset.py | 28 ++++- sdk/python/ragflow/ragflow.py | 54 ++++++--- sdk/python/test/t_dataset.py | 41 ++++++- 6 files changed, 259 insertions(+), 85 deletions(-) diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index 7a885ab38f..3d131f6074 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -15,82 +15,156 @@ # from flask import request -from api.db import StatusEnum -from api.db.db_models import APIToken +from api.db import StatusEnum, FileSource +from api.db.db_models import File +from api.db.services.document_service import DocumentService +from api.db.services.file2document_service import File2DocumentService +from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.user_service import TenantService from api.settings import RetCode from api.utils import get_uuid -from api.utils.api_utils import get_data_error_result -from api.utils.api_utils import get_json_result +from api.utils.api_utils import get_json_result, token_required, get_data_error_result @manager.route('/save', methods=['POST']) -def save(): +@token_required +def save(tenant_id): req = request.json - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) - tenant_id = objs[0].tenant_id e, t = TenantService.get_by_id(tenant_id) - if not e: - return get_data_error_result(retmsg="Tenant not found.") if "id" not in req: + if "tenant_id" in req or "embd_id" in req: + return get_data_error_result( + retmsg="Tenant_id or embedding_model must not be provided") + if "name" not in req: + return get_data_error_result( + retmsg="Name is not empty!") req['id'] = get_uuid() req["name"] = req["name"].strip() if req["name"] == "": return get_data_error_result( - retmsg="Name is not empty") - if KnowledgebaseService.query(name=req["name"]): + retmsg="Name is not empty string!") + if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value): return get_data_error_result( - retmsg="Duplicated knowledgebase name") + retmsg="Duplicated knowledgebase name in creating dataset.") req["tenant_id"] = tenant_id req['created_by'] = tenant_id req['embd_id'] = t.embd_id if not KnowledgebaseService.save(**req): - return get_data_error_result(retmsg="Data saving error") - req.pop('created_by') - keys_to_rename = {'embd_id': "embedding_model", 'parser_id': 'parser_method', - 'chunk_num': 'chunk_count', 'doc_num': 'document_count'} - for old_key,new_key in keys_to_rename.items(): - if old_key in req: - req[new_key]=req.pop(old_key) + return get_data_error_result(retmsg="Create dataset error.(Database error)") return get_json_result(data=req) else: - if req["tenant_id"] != tenant_id or req["embd_id"] != t.embd_id: - return get_data_error_result( - retmsg="Can't change tenant_id or embedding_model") + if "tenant_id" in req: + if req["tenant_id"] != tenant_id: + return get_data_error_result( + retmsg="Can't change tenant_id.") - e, kb = KnowledgebaseService.get_by_id(req["id"]) - if not e: - return get_data_error_result( - retmsg="Can't find this knowledgebase!") + if "embd_id" in req: + if req["embd_id"] != t.embd_id: + return get_data_error_result( + retmsg="Can't change embedding_model.") if not KnowledgebaseService.query( created_by=tenant_id, id=req["id"]): return get_json_result( - data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', + data=False, retmsg='You do not own the dataset.', retcode=RetCode.OPERATING_ERROR) - if req["chunk_num"] != kb.chunk_num or req['doc_num'] != kb.doc_num: - return get_data_error_result( - retmsg="Can't change document_count or chunk_count ") + e, kb = KnowledgebaseService.get_by_id(req["id"]) - if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id: - return get_data_error_result( - retmsg="if chunk count is not 0, parser method is not changable. ") + if "chunk_num" in req: + if req["chunk_num"] != kb.chunk_num: + return get_data_error_result( + retmsg="Can't change chunk_count.") + if "doc_num" in req: + if req['doc_num'] != kb.doc_num: + return get_data_error_result( + retmsg="Can't change document_count.") - if req["name"].lower() != kb.name.lower() \ - and len(KnowledgebaseService.query(name=req["name"], tenant_id=req['tenant_id'], - status=StatusEnum.VALID.value)) > 0: - return get_data_error_result( - retmsg="Duplicated knowledgebase name.") + if "parser_id" in req: + if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id: + return get_data_error_result( + retmsg="if chunk count is not 0, parse method is not changable.") + if "name" in req: + if req["name"].lower() != kb.name.lower() \ + and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, + status=StatusEnum.VALID.value)) > 0: + return get_data_error_result( + retmsg="Duplicated knowledgebase name in updating dataset.") del req["id"] - req['created_by'] = tenant_id if not KnowledgebaseService.update_by_id(kb.id, req): - return get_data_error_result(retmsg="Data update error ") + return get_data_error_result(retmsg="Update dataset error.(Database error)") return get_json_result(data=True) + + +@manager.route('/delete', methods=['DELETE']) +@token_required +def delete(tenant_id): + req = request.args + kbs = KnowledgebaseService.query( + created_by=tenant_id, id=req["id"]) + if not kbs: + return get_json_result( + data=False, retmsg='You do not own the dataset', + retcode=RetCode.OPERATING_ERROR) + + for doc in DocumentService.query(kb_id=req["id"]): + if not DocumentService.remove_document(doc, kbs[0].tenant_id): + return get_data_error_result( + retmsg="Remove document error.(Database error)") + f2d = File2DocumentService.get_by_document_id(doc.id) + FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id]) + File2DocumentService.delete_by_document_id(doc.id) + + if not KnowledgebaseService.delete_by_id(req["id"]): + return get_data_error_result( + retmsg="Delete dataset error.(Database error)") + return get_json_result(data=True) + + +@manager.route('/list', methods=['GET']) +@token_required +def list_datasets(tenant_id): + page_number = int(request.args.get("page", 1)) + items_per_page = int(request.args.get("page_size", 1024)) + orderby = request.args.get("orderby", "create_time") + desc = bool(request.args.get("desc", True)) + tenants = TenantService.get_joined_tenants_by_user_id(tenant_id) + kbs = KnowledgebaseService.get_by_tenant_ids( + [m["tenant_id"] for m in tenants], tenant_id, page_number, items_per_page, orderby, desc) + return get_json_result(data=kbs) + + +@manager.route('/detail', methods=['GET']) +@token_required +def detail(tenant_id): + req = request.args + if "id" in req: + id = req["id"] + kb = KnowledgebaseService.query(created_by=tenant_id, id=req["id"]) + if not kb: + return get_json_result( + data=False, retmsg='You do not own the dataset', + retcode=RetCode.OPERATING_ERROR) + if "name" in req: + name = req["name"] + if kb[0].name != name: + return get_json_result( + data=False, retmsg='You do not own the dataset', + retcode=RetCode.OPERATING_ERROR) + e, k = KnowledgebaseService.get_by_id(id) + return get_json_result(data=k.to_dict()) + else: + if "name" in req: + name = req["name"] + e, k = KnowledgebaseService.get_by_name(kb_name=name, tenant_id=tenant_id) + if not e: + return get_json_result( + data=False, retmsg='You do not own the dataset', + retcode=RetCode.OPERATING_ERROR) + return get_json_result(data=k.to_dict()) + else: + return get_data_error_result( + retmsg="At least one of `id` or `name` must be provided.") diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index b7f51369bc..c5b93d56f0 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -13,30 +13,32 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import functools import json import random import time +from base64 import b64encode from functools import wraps +from hmac import HMAC from io import BytesIO +from urllib.parse import quote, urlencode +from uuid import uuid1 + +import requests from flask import ( Response, jsonify, send_file, make_response, request as flask_request, ) from werkzeug.http import HTTP_STATUS_CODES -from api.utils import json_dumps -from api.settings import RetCode +from api.db.db_models import APIToken from api.settings import ( REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC, stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY ) -import requests -import functools +from api.settings import RetCode from api.utils import CustomJSONEncoder -from uuid import uuid1 -from base64 import b64encode -from hmac import HMAC -from urllib.parse import quote, urlencode +from api.utils import json_dumps requests.models.complexjson.dumps = functools.partial( json.dumps, cls=CustomJSONEncoder) @@ -96,7 +98,6 @@ def get_exponential_backoff_interval(retries, full_jitter=False): def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id=None, meta=None): - import re result_dict = { "retcode": retcode, "retmsg": retmsg, @@ -145,7 +146,8 @@ def server_error_response(e): return get_json_result( retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1]) if repr(e).find("index_not_found_exception") >= 0: - return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg="No chunk found, please upload file and parse it.") + return get_json_result(retcode=RetCode.EXCEPTION_ERROR, + retmsg="No chunk found, please upload file and parse it.") return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e)) @@ -190,7 +192,9 @@ def decorated_function(*_args, **_kwargs): return get_json_result( retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string) return func(*_args, **_kwargs) + return decorated_function + return wrapper @@ -217,7 +221,7 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None): def construct_response(retcode=RetCode.SUCCESS, - retmsg='success', data=None, auth=None): + retmsg='success', data=None, auth=None): result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data} response_dict = {} for key, value in result_dict.items(): @@ -235,6 +239,7 @@ def construct_response(retcode=RetCode.SUCCESS, response.headers["Access-Control-Expose-Headers"] = "Authorization" return response + def construct_result(code=RetCode.DATA_ERROR, message='data is missing'): import re result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)} @@ -263,7 +268,23 @@ def construct_error_response(e): pass if len(e.args) > 1: return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) - if repr(e).find("index_not_found_exception") >=0: - return construct_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.") + if repr(e).find("index_not_found_exception") >= 0: + return construct_json_result(code=RetCode.EXCEPTION_ERROR, + message="No chunk found, please upload file and parse it.") return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) + + +def token_required(func): + @wraps(func) + def decorated_function(*args, **kwargs): + token = flask_request.headers.get('Authorization').split()[1] + objs = APIToken.query(token=token) + if not objs: + return get_json_result( + data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR + ) + kwargs['tenant_id'] = objs[0].tenant_id + return func(*args, **kwargs) + + return decorated_function diff --git a/sdk/python/ragflow/modules/base.py b/sdk/python/ragflow/modules/base.py index fe22e55654..641a5fb5ee 100644 --- a/sdk/python/ragflow/modules/base.py +++ b/sdk/python/ragflow/modules/base.py @@ -18,13 +18,17 @@ def to_json(self): pr[name] = value return pr - def post(self, path, param): - res = self.rag.post(path,param) + res = self.rag.post(path, param) return res - def get(self, path, params=''): - res = self.rag.get(path,params) + def get(self, path, params): + res = self.rag.get(path, params) return res + def rm(self, path, params): + res = self.rag.delete(path, params) + return res + def __str__(self): + return str(self.to_json()) diff --git a/sdk/python/ragflow/modules/dataset.py b/sdk/python/ragflow/modules/dataset.py index 7689cf7fe0..753dbaa8b7 100644 --- a/sdk/python/ragflow/modules/dataset.py +++ b/sdk/python/ragflow/modules/dataset.py @@ -21,18 +21,36 @@ def __init__(self, rag, res_dict): self.permission = "me" self.document_count = 0 self.chunk_count = 0 - self.parser_method = "naive" + self.parse_method = "naive" self.parser_config = None + for k in list(res_dict.keys()): + if k == "embd_id": + res_dict["embedding_model"] = res_dict[k] + if k == "parser_id": + res_dict['parse_method'] = res_dict[k] + if k == "doc_num": + res_dict["document_count"] = res_dict[k] + if k == "chunk_num": + res_dict["chunk_count"] = res_dict[k] + if k not in self.__dict__: + res_dict.pop(k) super().__init__(rag, res_dict) - def save(self): + def save(self) -> bool: res = self.post('/dataset/save', {"id": self.id, "name": self.name, "avatar": self.avatar, "tenant_id": self.tenant_id, "description": self.description, "language": self.language, "embd_id": self.embedding_model, "permission": self.permission, - "doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parser_method, + "doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parse_method, "parser_config": self.parser_config.to_json() }) res = res.json() - if not res.get("retmsg"): return True - raise Exception(res["retmsg"]) \ No newline at end of file + if res.get("retmsg") == "success": return True + raise Exception(res["retmsg"]) + + def delete(self) -> bool: + res = self.rm('/dataset/delete', + {"id": self.id}) + res = res.json() + if res.get("retmsg") == "success": return True + raise Exception(res["retmsg"]) diff --git a/sdk/python/ragflow/ragflow.py b/sdk/python/ragflow/ragflow.py index ff3dba7da3..f7a238834a 100644 --- a/sdk/python/ragflow/ragflow.py +++ b/sdk/python/ragflow/ragflow.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + import requests from .modules.dataset import DataSet @@ -25,30 +27,54 @@ def __init__(self, user_key, base_url, version='v1'): """ self.user_key = user_key self.api_url = f"{base_url}/api/{version}" - self.authorization_header = {"Authorization": "{} {}".format("Bearer",self.user_key)} + self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)} def post(self, path, param): res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header) return res - def get(self, path, params=''): - res = requests.get(self.api_url + path, params=params, headers=self.authorization_header) + def get(self, path, params=None): + res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header) + return res + + def delete(self, path, params): + res = requests.delete(url=self.api_url + path, params=params, headers=self.authorization_header) return res - def create_dataset(self, name:str,avatar:str="",description:str="",language:str="English",permission:str="me", - document_count:int=0,chunk_count:int=0,parser_method:str="naive", - parser_config:DataSet.ParserConfig=None): + def create_dataset(self, name: str, avatar: str = "", description: str = "", language: str = "English", + permission: str = "me", + document_count: int = 0, chunk_count: int = 0, parse_method: str = "naive", + parser_config: DataSet.ParserConfig = None) -> DataSet: if parser_config is None: - parser_config = DataSet.ParserConfig(self, {"chunk_token_count":128,"layout_recognize": True, "delimiter":"\n!?。;!?","task_page_size":12}) - parser_config=parser_config.to_json() - res=self.post("/dataset/save",{"name":name,"avatar":avatar,"description":description,"language":language,"permission":permission, - "doc_num": document_count,"chunk_num":chunk_count,"parser_id":parser_method, - "parser_config":parser_config - } - ) + parser_config = DataSet.ParserConfig(self, {"chunk_token_count": 128, "layout_recognize": True, + "delimiter": "\n!?。;!?", "task_page_size": 12}) + parser_config = parser_config.to_json() + res = self.post("/dataset/save", + {"name": name, "avatar": avatar, "description": description, "language": language, + "permission": permission, + "doc_num": document_count, "chunk_num": chunk_count, "parser_id": parse_method, + "parser_config": parser_config + } + ) res = res.json() - if not res.get("retmsg"): + if res.get("retmsg") == "success": return DataSet(self, res["data"]) raise Exception(res["retmsg"]) + def list_datasets(self, page: int = 1, page_size: int = 150, orderby: str = "create_time", desc: bool = True) -> \ + List[DataSet]: + res = self.get("/dataset/list", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc}) + res = res.json() + result_list = [] + if res.get("retmsg") == "success": + for data in res['data']: + result_list.append(DataSet(self, data)) + return result_list + raise Exception(res["retmsg"]) + def get_dataset(self, id: str = None, name: str = None) -> DataSet: + res = self.get("/dataset/detail", {"id": id, "name": name}) + res = res.json() + if res.get("retmsg") == "success": + return DataSet(self, res['data']) + raise Exception(res["retmsg"]) diff --git a/sdk/python/test/t_dataset.py b/sdk/python/test/t_dataset.py index 1466233a19..eddae95ac0 100644 --- a/sdk/python/test/t_dataset.py +++ b/sdk/python/test/t_dataset.py @@ -7,7 +7,7 @@ class TestDataset(TestSdk): def test_create_dataset_with_success(self): """ - Test creating dataset with success + Test creating a dataset with success """ rag = RAGFlow(API_KEY, HOST_ADDRESS) ds = rag.create_dataset("God") @@ -18,15 +18,46 @@ def test_create_dataset_with_success(self): def test_update_dataset_with_success(self): """ - Test updating dataset with success. + Test updating a dataset with success. """ rag = RAGFlow(API_KEY, HOST_ADDRESS) ds = rag.create_dataset("ABC") if isinstance(ds, DataSet): - assert ds.name == "ABC", "Name does not match." + assert ds.name == "ABC", "Name does not match." ds.name = 'DEF' res = ds.save() - assert res is True, f"Failed to update dataset, error: {res}" + assert res is True, f"Failed to update dataset, error: {res}" + else: + assert False, f"Failed to create dataset, error: {ds}" + def test_delete_dataset_with_success(self): + """ + Test deleting a dataset with success + """ + rag = RAGFlow(API_KEY, HOST_ADDRESS) + ds = rag.create_dataset("MA") + if isinstance(ds, DataSet): + assert ds.name == "MA", "Name does not match." + res = ds.delete() + assert res is True, f"Failed to delete dataset, error: {res}" else: - assert False, f"Failed to create dataset, error: {ds}" \ No newline at end of file + assert False, f"Failed to create dataset, error: {ds}" + + def test_list_datasets_with_success(self): + """ + Test listing datasets with success + """ + rag = RAGFlow(API_KEY, HOST_ADDRESS) + list_datasets = rag.list_datasets() + assert len(list_datasets) > 0, "Do not exist any dataset" + for ds in list_datasets: + assert isinstance(ds, DataSet), "Existence type is not dataset." + + def test_get_detail_dataset_with_success(self): + """ + Test getting a dataset's detail with success + """ + rag = RAGFlow(API_KEY, HOST_ADDRESS) + ds = rag.get_dataset(name="God") + assert isinstance(ds, DataSet), f"Failed to get dataset, error: {ds}." + assert ds.name == "God", "Name does not match"