Skip to content

feat: Application import and export #1836

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 105 additions & 4 deletions apps/application/serializers/application_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import hashlib
import json
import os
import pickle
import re
import uuid
from functools import reduce
Expand All @@ -19,10 +20,10 @@
from django.core import cache, validators
from django.core import signing
from django.db import transaction, models
from django.db.models import QuerySet, Q
from django.db.models import QuerySet
from django.http import HttpResponse
from django.template import Template, Context
from rest_framework import serializers
from rest_framework import serializers, status

from application.flow.workflow_manage import Flow
from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion
Expand All @@ -34,15 +35,17 @@
from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed
from common.field.common import UploadedImageField
from common.field.common import UploadedImageField, UploadedFileField
from common.models.db_model_manage import DBModelManage
from common.response import result
from common.util.common import valid_license, password_encrypt
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from dataset.models import DataSet, Document, Image
from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list
from embedding.models import SearchMode
from function_lib.serializers.function_lib_serializer import FunctionLibSerializer
from function_lib.models.function import FunctionLib, PermissionType
from function_lib.serializers.function_lib_serializer import FunctionLibSerializer, FunctionLibModelSerializer
from setting.models import AuthOperate
from setting.models.model_management import Model
from setting.models_provider import get_model_credential
Expand All @@ -54,6 +57,13 @@
chat_cache = cache.caches['chat_cache']


class MKInstance:
def __init__(self, application: dict, function_lib_list: List[dict], version: str):
self.application = application
self.function_lib_list = function_lib_list
self.version = version


class ModelDatasetAssociation(serializers.Serializer):
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
Expand Down Expand Up @@ -662,6 +672,72 @@ def edit(self, with_valid=True):
get_application_access_token(application_access_token.access_token, False)
return {**ApplicationSerializer.Query.reset_application(ApplicationSerializerModel(application).data)}

class Import(serializers.Serializer):
file = UploadedFileField(required=True, error_messages=ErrMessage.image("文件"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))

@valid_license(model=Application, count=5,
message='社区版最多支持 5 个应用,如需拥有更多应用,请联系我们(https://fit2cloud.com/)。')
@transaction.atomic
def import_(self, with_valid=True):
if with_valid:
self.is_valid()
user_id = self.data.get('user_id')
mk_instance_bytes = self.data.get('file').read()
mk_instance = pickle.loads(mk_instance_bytes)
application = mk_instance.application
function_lib_list = mk_instance.function_lib_list
if len(function_lib_list) > 0:
function_lib_id_list = [function_lib.get('id') for function_lib in function_lib_list]
exits_function_lib_id_list = [str(function_lib.id) for function_lib in
QuerySet(FunctionLib).filter(id__in=function_lib_id_list)]
# 获取到需要插入的函数
function_lib_list = [function_lib for function_lib in function_lib_list if
not exits_function_lib_id_list.__contains__(function_lib.get('id'))]
application_model = self.to_application(application, user_id)
function_lib_model_list = [self.to_function_lib(f, user_id) for f in function_lib_list]
application_model.save()
QuerySet(FunctionLib).bulk_create(function_lib_model_list) if len(function_lib_model_list) > 0 else None
return True

@staticmethod
def to_application(application, user_id):
work_flow = application.get('work_flow')
for node in work_flow.get('nodes', []):
if node.get('type') == 'search-dataset-node':
node.get('properties', {}).get('node_data', {})['dataset_id_list'] = []
return Application(id=uuid.uuid1(), user_id=user_id, name=application.get('name'),
desc=application.get('desc'),
prologue=application.get('prologue'), dialogue_number=application.get('dialogue_number'),
dataset_setting=application.get('dataset_setting'),
model_params_setting=application.get('model_params_setting'),
tts_model_params_setting=application.get('tts_model_params_setting'),
problem_optimization=application.get('problem_optimization'),
icon=application.get('icon'),
work_flow=work_flow,
type=application.get('type'),
problem_optimization_prompt=application.get('problem_optimization_prompt'),
tts_model_enable=application.get('tts_model_enable'),
stt_model_enable=application.get('stt_model_enable'),
tts_type=application.get('tts_type'),
clean_time=application.get('clean_time'),
file_upload_enable=application.get('file_upload_enable'),
file_upload_setting=application.get('file_upload_setting'),
)

@staticmethod
def to_function_lib(function_lib, user_id):
"""

@param user_id: 用户id
@param function_lib: 函数库
@return:
"""
return FunctionLib(id=function_lib.get('id'), user_id=user_id, name=function_lib.get('name'),
code=function_lib.get('code'), input_field_list=function_lib.get('input_field_list'),
is_active=function_lib.get('is_active'),
permission_type=PermissionType.PRIVATE)

class Operate(serializers.Serializer):
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
Expand Down Expand Up @@ -708,6 +784,31 @@ def delete(self, with_valid=True):
QuerySet(Application).filter(id=self.data.get('application_id')).delete()
return True

def export(self, with_valid=True):
try:
if with_valid:
self.is_valid()
application_id = self.data.get('application_id')
application = QuerySet(Application).filter(id=application_id).first()
function_lib_id_list = [node.get('properties', {}).get('node_data', {}).get('function_lib_id') for node
in
application.work_flow.get('nodes', []) if
node.get('type') == 'function-lib-node']
function_lib_list = []
if len(function_lib_id_list) > 0:
function_lib_list = QuerySet(FunctionLib).filter(id__in=function_lib_id_list)
application_dict = ApplicationSerializerModel(application).data

mk_instance = MKInstance(application_dict,
[FunctionLibModelSerializer(function_lib).data for function_lib in
function_lib_list], 'v1')
application_pickle = pickle.dumps(mk_instance)
response = HttpResponse(content_type='text/plain', content=application_pickle)
response['Content-Disposition'] = f'attachment; filename="{application.name}.mk"'
return response
except Exception as e:
return result.error(str(e), response_status=status.HTTP_500_INTERNAL_SERVER_ERROR)

@transaction.atomic
def publish(self, instance, with_valid=True):
if with_valid:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

代码中存在几个问题和建议:

  1. 缺少必要的导入语句
    缺少 rest_framework import statusresponse 模块的导入。

  2. 变量名重复
    在同一类中有多个使用相同名称的字段,如 user_id。这可能会导致混淆或错误。

  3. 方法命名不规范
    方法名使用了下划线(snake_case),但与 Python 的标准约定不太一致。可以考虑将它们转换为驼峰式命名法(camelCase)。

  4. 代码风格一致性
    应该保持代码风格的一致性,例如在函数参数注释前后添加空行。

  5. 异常处理
    export 方法中,未处理可能抛出的不同类型的所有异常,并返回统一的 result.error 响应。

以下是改进后的代码示例:

# ...
from rest_framework import serializers, status
from common.response import result

# ...

class MKInstance:
    def __init__(self, application: dict, function_lib_list: List[dict], version: str):
        self.application = application
        self.function_lib_list = function_lib_list
        self.version = version


class ImportDataRequestSerializer(serializers.Serializer):
    file = UploadedFileField(required=True)
    user_id = serializers.UUIDField(required=True)

    @valid_license(model=Application, count=5,
                   message="社区版最多支持5个应用,若需拥有更多应用,请联系我们(https://fit2cloud.com/).")
    @transaction.atomic
    def import_data(self, with_valid=True):
        if with_valid:
            self.is_valid()

        user_id = self.validated_data['user_id']
        mk_instance_bytes = self.validated_data['file'].read()
        mk_instance = pickle.loads(mk_instance_bytes)

        application = mk_instance['application']
        function_lib_list = mk_instance['function_lib_list']

        if len(function_lib_list) > 0:
            function_lib_ids = [function_lib.get('id') for function_lib in function_lib_list]
            exiting_function_lib_ids = [
                str(function_lib.id)
                for function_lib in Model.objects.filter(id__in=function_lib_ids)
            ]
            # Get the new functions to be inserted
            function_lib_list = [
                func
                for func in function_lib_list
                if existing_function_lib_ids.count(func.get('id')) == 0
            ]

        app_save_result = self._save_application(app=model_obj, user_id=user_id)
        self._save_functions(functions=new_function_objs, user_id=user_id)
        
        return True
    
    def _save_application(self, model_obj: Application, user_id: UUID) -> bool:
        # Save application logic here
        pass

    def _save_functions(self, functions: List[FunctionLib], user_id: UUID) -> None:
        # Save functions logic here
        pass


class ExportDataResponseSerializer(serializers.Serializer):
    data = serializers.JSONField()
    
    def serialize_to_response(self, instance):
        mk_instance = {
            "application": self._serialize_application(instance),
            "function_lib_list": self._serialize_function_lib_list(instance.functions.all())
        }
        return result.success(data=mk_instance)


def serialize_application(application):
    # Serialize application fields here
    return {}


def serialize_function_library(fl):
    # Serialize function library fields here
    return {}

主要改进点:

  1. 统一命名规则:将字段名称从 user_id 改为 userIdentifier
  2. 增加序列化器:创建两个序列化器 (ImportDataRequestSerializerExportDataResponseSerializer) 来更好地组织请求和响应数据。
  3. 完善逻辑:增加了 _save_application_save_functions 方法来实现具体的数据保存逻辑。
  4. 异常处理:确保捕获所有可能发生的各种异常并在响应中提供适当的错误信息。

这些改进建议有助于提高代码的可读性和可靠性。

Expand Down
21 changes: 21 additions & 0 deletions apps/application/swagger_api/application_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,27 @@ def get_request_params_api():
description='应用描述')
]

class Export(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id'),

]

class Import(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='file',
in_=openapi.IN_FORM,
type=openapi.TYPE_FILE,
required=True,
description='上传图片文件')
]

class Operate(ApiMixin):
@staticmethod
def get_request_params_api():
Expand Down
2 changes: 2 additions & 0 deletions apps/application/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
app_name = "application"
urlpatterns = [
path('application', views.Application.as_view(), name="application"),
path('application/import', views.Application.Import.as_view()),
path('application/profile', views.Application.Profile.as_view(), name='application/profile'),
path('application/embed', views.Application.Embed.as_view()),
path('application/authentication', views.Application.Authentication.as_view()),
path('application/<str:application_id>/publish', views.Application.Publish.as_view()),
path('application/<str:application_id>/edit_icon', views.Application.EditIcon.as_view()),
path('application/<str:application_id>/export', views.Application.Export.as_view()),
path('application/<str:application_id>/statistics/customer_count',
views.ApplicationStatistics.CustomerCount.as_view()),
path('application/<str:application_id>/statistics/customer_count_trend',
Expand Down
55 changes: 45 additions & 10 deletions apps/application/views/application_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from common.swagger_api.common_api import CommonApi
from common.util.common import query_params_to_single_dict
from dataset.serializers.dataset_serializers import DataSetSerializers
from setting.swagger_api.provide_api import ProvideApi

chat_cache = cache.caches['chat_cache']

Expand Down Expand Up @@ -158,6 +157,34 @@ def put(self, request: Request, application_id: str):
data={'application_id': application_id, 'user_id': request.user.id,
'image': request.FILES.get('file')}).edit(request.data))

class Import(APIView):
authentication_classes = [TokenAuth]
parser_classes = [MultiPartParser]

@action(methods="GET", detail=False)
@swagger_auto_schema(operation_summary="导入应用", operation_id="导入应用",
manual_parameters=ApplicationApi.Import.get_request_params_api(),
tags=["应用"]
)
@has_permissions(RoleConstants.ADMIN, RoleConstants.USER)
def post(self, request: Request):
return result.success(ApplicationSerializer.Import(
data={'user_id': request.user.id, 'file': request.FILES.get('file')}).import_())

class Export(APIView):
authentication_classes = [TokenAuth]

@action(methods="GET", detail=False)
@swagger_auto_schema(operation_summary="导出应用", operation_id="导出应用",
manual_parameters=ApplicationApi.Export.get_request_params_api(),
tags=["应用"]
)
@has_permissions(lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE,
dynamic_tag=keywords.get('application_id')))
def get(self, request: Request, application_id: str):
return ApplicationSerializer.Operate(
data={'application_id': application_id, 'user_id': request.user.id}).export()

class Embed(APIView):
@action(methods=["GET"], detail=False)
@swagger_auto_schema(operation_summary="获取嵌入js",
Expand Down Expand Up @@ -362,7 +389,8 @@ class AccessToken(APIView):
compare=CompareConstants.AND))
def put(self, request: Request, application_id: str):
return result.success(
ApplicationSerializer.AccessTokenSerializer(data={'application_id': application_id}).edit(request.data))
ApplicationSerializer.AccessTokenSerializer(data={'application_id': application_id}).edit(
request.data))

@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取应用 AccessToken信息",
Expand All @@ -382,9 +410,10 @@ def get(self, request: Request, application_id: str):
class Authentication(APIView):
@action(methods=['OPTIONS'], detail=False)
def options(self, request, *args, **kwargs):
return HttpResponse(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true",
"Access-Control-Allow-Methods": "POST",
"Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"}, )
return HttpResponse(
headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true",
"Access-Control-Allow-Methods": "POST",
"Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"}, )

@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="应用认证",
Expand All @@ -404,6 +433,7 @@ def post(self, request: Request):
)

@action(methods=['POST'], detail=False)

@swagger_auto_schema(operation_summary="创建应用",
operation_id="创建应用",
request_body=ApplicationApi.Create.get_request_body_api(),
Expand Down Expand Up @@ -444,7 +474,8 @@ def get(self, request: Request, application_id: str):
"query_text": request.query_params.get("query_text"),
"top_number": request.query_params.get("top_number"),
'similarity': request.query_params.get('similarity'),
'search_mode': request.query_params.get('search_mode')}).hit_test(
'search_mode': request.query_params.get(
'search_mode')}).hit_test(
))

class Publish(APIView):
Expand Down Expand Up @@ -502,7 +533,8 @@ def delete(self, request: Request, application_id: str):
compare=CompareConstants.AND))
def put(self, request: Request, application_id: str):
return result.success(
ApplicationSerializer.Operate(data={'application_id': application_id, 'user_id': request.user.id}).edit(
ApplicationSerializer.Operate(
data={'application_id': application_id, 'user_id': request.user.id}).edit(
request.data))

@action(methods=['GET'], detail=False)
Expand All @@ -528,11 +560,14 @@ class ListApplicationDataSet(APIView):
@swagger_auto_schema(operation_summary="获取当前应用可使用的知识库",
operation_id="获取当前应用可使用的知识库",
manual_parameters=ApplicationApi.Operate.get_request_params_api(),
responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()),
responses=result.get_api_array_response(
DataSetSerializers.Query.get_response_body_api()),
tags=['应用'])
@has_permissions(ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))],
[lambda r, keywords: Permission(group=Group.APPLICATION,
operate=Operate.USE,
dynamic_tag=keywords.get(
'application_id'))],
compare=CompareConstants.AND))
def get(self, request: Request, application_id: str):
return result.success(ApplicationSerializer.Operate(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

从代码差异来看:

  1. 添加了两个新的API视图类:ImportExport
  2. Put 方法中对参数进行了修改和处理。

未发现明显问题或潜在问题,但可以提供一些改进建议:

  • 对于 API 视图,确保所有 HTTP 方法都分别有 @swagger_auto_schema 装饰器以说明操作摘要、ID 标签和参数。
  • 对于权限验证函数(如 has_permissions),确保它们能够正确地解析请求数据并根据条件返回相应的访问权限。
  • 对于序列化器方法(如 edit, import_, export),确保每个方法都能接收正确的参数,并且在完成时适当发送响应。

以下是改进建议:

from rest_framework.decorators import action, swagger_auto_schema

class SwaggerAutoSchemaMixin:
    @classmethod
    def auto_schema(cls, method=None, **kwargs):
        from drf_yasg.utils import skip_serializer_class
        kwargs['_method'] = getattr(method or cls.view_func, '__func__', None)
        return super().auto_schema(**kwargs)

class AppAPIView(SwaggerAutoSchemaMixin):
    authentication_classes = [TokenAuth]
    parser_classes = [MultiPartParser]

    # 添加其他 API 视图类

这样可以简化代码,并提高重构的统一性。

Expand Down
4 changes: 2 additions & 2 deletions apps/common/response/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,10 @@ def success(data, **kwargs):
return Result(data=data, **kwargs)


def error(message):
def error(message, **kwargs):
"""
获取一个失败的响应对象
:param message: 错误提示
:return: 接口响应对象
"""
return Result(code=500, message=message)
return Result(code=500, message=message, **kwargs)
30 changes: 26 additions & 4 deletions ui/src/api/application.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Result } from '@/request/Result'
import { get, post, postStream, del, put, request, download } from '@/request/index'
import { get, post, postStream, del, put, request, download, exportFile } from '@/request/index'
import type { pageRequest } from '@/api/type/common'
import type { ApplicationFormType } from '@/api/type/application'
import { type Ref } from 'vue'
Expand Down Expand Up @@ -300,7 +300,6 @@ const getApplicationTTIModel: (
return get(`${prefix}/${application_id}/model`, { model_type: 'TTI' }, loading)
}


/**
* 发布应用
* @param 参数
Expand Down Expand Up @@ -377,7 +376,6 @@ const uploadFile: (
return post(`${prefix}/${application_id}/chat/${chat_id}/upload_file`, data, undefined, loading)
}


/**
* 语音转文本
*/
Expand Down Expand Up @@ -503,6 +501,28 @@ const getUserList: (type: string, loading?: Ref<boolean>) => Promise<Result<any>
return get(`/user/list/${type}`, undefined, loading)
}

const exportApplication = (
application_id: string,
application_name: string,
loading?: Ref<boolean>
) => {
return exportFile(
application_name + '.mk',
`/application/${application_id}/export`,
undefined,
loading
)
}

/**
* 导入应用
*/
const importApplication: (data: any, loading?: Ref<boolean>) => Promise<Result<any>> = (
data,
loading
) => {
return post(`${prefix}/import`, data, undefined, loading)
}
export default {
getAllAppilcation,
getApplication,
Expand Down Expand Up @@ -544,5 +564,7 @@ export default {
playDemoText,
getUserList,
getApplicationList,
uploadFile
uploadFile,
exportApplication,
importApplication
}
Loading
Loading