Skip to content

feat: Knowledge base generation problem #2760

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions apps/dataset/serializers/common_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,26 @@ def get_embedding_model_id_by_dataset_id_list(dataset_id_list: List):
if len(dataset_list) == 0:
raise Exception(_('Knowledge base setting error, please reset the knowledge base'))
return str(dataset_list[0].embedding_mode_id)


class GenerateRelatedSerializer(ApiMixin, serializers.Serializer):
    """Validate the request body of a "generate related" call.

    Fields:
        model_id: required UUID of the model to use.
        prompt: required prompt text.
        state_list: optional list of state strings.
    """
    model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('Model id')))
    # The prompt is free text, so it takes the character-field messages rather than the UUID ones.
    prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Prompt word')))
    # The label is translated here as well, matching the schema title below.
    state_list = serializers.ListField(required=False, child=serializers.CharField(required=True),
                                       error_messages=ErrMessage.list(_('state list')))

    @staticmethod
    def get_request_body_api():
        """Return the OpenAPI object schema describing the three request fields."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            properties={
                'model_id': openapi.Schema(type=openapi.TYPE_STRING,
                                           title=_('Model id'),
                                           description=_('Model id')),
                'prompt': openapi.Schema(type=openapi.TYPE_STRING, title=_('Prompt word'),
                                         description=_("Prompt word")),
                'state_list': openapi.Schema(type=openapi.TYPE_ARRAY,
                                             items=openapi.Schema(type=openapi.TYPE_STRING),
                                             title=_('state list'))
            }
        )
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The provided code seems to be well-written and follows Python best practices. However, there are a few minor improvements that can be made:

  1. Type Hinting: The use of List from the typing module is recommended for better readability and clarity.

  2. Error Messages: Ensure that the error messages are translated properly using _ (gettext), as they might need to be checked against your gettext translation files.

  3. API Documentation: The GenerateRelatedSerializer.get_request_body_api() static method returns an OpenAPI schema, which is useful for API documentation but does not affect the validation logic. Its purpose should be documented in a comment or docstring if necessary.

    # Returns an OpenAPI schema representation of the request body for generate-related requests.

Here's the updated version with some additional comments:

@@ -222,3 +222,27 @@ def get_embedding_model_id_by_dataset_id_list(dataset_id_list: List[str]:
     """Retrieve the embedding mode ID based on dataset IDs."""
     if len(dataset_list) == 0:
         raise Exception(_(u"Knowledge base setting error, please reset the knowledge base"))
     return str(dataset_list[0].embedding_mode_id)

+
+import typing
+from rest_framework import serializers
+from drf_yasg.utils import swagger_auto_schema
+from core.constants import ErrMessage
+
+class GenerateRelatedSerializer(ApiMixin, serializers.Serializer):
+    """
+    Serializer class for the related generation API endpoint.
+    
+    Fields:
+       model_id (UUIDField): Required UUID representing the AI model ID.
+       prompt (CharField): Required string prompting the model to generate text.
+       state_list (ListField): Optional list of strings indicating states for further processing.
+       
+    Methods:
+       get_request_body_api() -> openapi.Schema: Returns an OpenAPI Schema instance describing the request format.
+    """
+
+    model_id = serializers.UUIDField(required=True,
+                                      error_messages=ErrMessage.uuid(_("Model id")))
         
+    prompt = serializers.CharField(required=True,
+                                   error_messages=ErrMessage.uuid(_("Prompt word")))
     
     

Summary of Recommendations:

  • Use List[str] instead of just List.
  • Ensure proper translation of error messages.
  • Add docstrings to explain the purpose and usage of the methods and fields.

31 changes: 29 additions & 2 deletions apps/dataset/serializers/dataset_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from django.core import validators
from django.db import transaction, models
from django.db.models import QuerySet
from django.db.models.functions import Reverse, Substr
from django.http import HttpResponse
from drf_yasg import openapi
from rest_framework import serializers
Expand All @@ -42,9 +43,10 @@
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, TaskType, \
State, File, Image
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id, write_image, zip_dir
get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id, write_image, zip_dir, \
GenerateRelatedSerializer
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from dataset.task import sync_web_dataset, sync_replace_web_dataset
from dataset.task import sync_web_dataset, sync_replace_web_dataset, generate_related_by_dataset_id
from embedding.models import SearchMode
from embedding.task import embedding_by_dataset, delete_embedding_by_dataset
from setting.models import AuthOperate, Model
Expand Down Expand Up @@ -814,6 +816,31 @@ def re_embedding(self, with_valid=True):
except AlreadyQueued as e:
raise AppApiException(500, _('Failed to send the vectorization task, please try again later!'))

def generate_related(self, instance: Dict, with_valid=True):
    """Mark a dataset for problem generation and enqueue the background task.

    :param instance: request payload; read for ``model_id``, ``prompt`` and the
        optional ``state_list``.
    :param with_valid: when true, validate this serializer and check ``instance``
        against ``GenerateRelatedSerializer`` before doing anything else.
    :raises AppApiException: with code 500 when the task is already queued.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
        GenerateRelatedSerializer(data=instance).is_valid(raise_exception=True)
    dataset_id = self.data.get('id')
    model_id = instance.get("model_id")
    prompt = instance.get("prompt")
    state_list = instance.get('state_list')
    # Every document of the dataset is flagged as pending for this task type.
    ListenerManagement.update_status(QuerySet(Document).filter(dataset_id=dataset_id),
                                     TaskType.GENERATE_PROBLEM,
                                     State.PENDING)
    paragraph_query_set = QuerySet(Paragraph).filter(dataset_id=dataset_id)
    if state_list is not None:
        # `state_list` is optional in GenerateRelatedSerializer; an `__in` lookup
        # cannot take None, so the status filter only applies when it was sent.
        # The status string is reversed and one character is read at position
        # TaskType.GENERATE_PROBLEM.value (Substr positions are 1-based).
        # NOTE(review): presumably `status` holds one state character per task
        # type, counted from the right -- confirm against the model.
        paragraph_query_set = paragraph_query_set.annotate(
            reversed_status=Reverse('status'),
            task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
                                    1),
        ).filter(task_type_status__in=state_list)
    ListenerManagement.update_status(paragraph_query_set.values('id'),
                                     TaskType.GENERATE_PROBLEM,
                                     State.PENDING)
    # NOTE(review): presumably refreshes the aggregated document status of the
    # dataset; the helper returns a callable that is invoked immediately.
    ListenerManagement.get_aggregation_document_status_by_dataset_id(dataset_id)()
    try:
        generate_related_by_dataset_id.delay(dataset_id, model_id, prompt, state_list)
    except AlreadyQueued as e:
        # NOTE(review): the message mentions "vectorization" although this is the
        # problem-generation task; kept unchanged because it is a translated msgid.
        raise AppApiException(500, _('Failed to send the vectorization task, please try again later!')) from e

def list_application(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code you provided contains some general improvements and optimizations that can be made:

  1. Imports: The import statements for QuerySet, Reverse, and Substr should sit at the top of the module, grouped with the other Django imports, rather than inside the functions that use them.

  2. Docstring Style: It's good practice to have docstrings for classes and methods, especially when they're complex or perform multiple actions. Ensure all functions have clear explanations.

  3. Function Separation: You might consider separating the generate_related and other related logic into separate modules or within a more logical class hierarchy.

  4. Error Handling: The exception handling in re_embedding already uses AlreadyQueued. If there are any specific exceptions or error messages needed, consider adding them.

  5. Variable Naming: Ensure variable names are descriptive and adhere to PEP 8 standards. For instance, use state_list instead of 'state_list'.

Here’s an improved version of the relevant section of code with these suggestions:

from django.core import validators
from django.db import transaction, models
from django.db.models import QuerySet
from django.db.models.functions import Reverse, Substr

from drf_yasg import openapi
from rest_framework import serializers

from dataset.models import Dataset, Document, Paragraph, Problem, Type, ProblemParagraphMapping, TaskType, \
    State, File, Image
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
    get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id, write_image, zip_dir
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from.dataset.task import sync_web_dataset, sync_replace_web_dataset, generate_related_by_dataset_id
from embedding.models import SearchMode
from embedding.task import embedding_by_dataset, delete_embedding_by_dataset
from setting.models import AuthOperate, Model


class YourClassName(serializers.Serializer):
    data = serializers.DictField(required=True)

    class Meta:
        fields = ('data')

    def re_embedding(self, with_valid=True):
        if with_valid:
            self.is_valid(raise_exception=True)
        dataset_id = self.data['id']
        
        try:
            # Your existing implementation here
        except AlreadyQueued as e:
            raise AppApiException(500, _('Failed to send the vectorization task, please try again later!'))

    def generate_related(self, with_valid=True):
        if with_valid:
            self.is_valid(raise_exception=True)
            GenerateRelatedSerializer.data = self.validated_data
        
        dataset_id = self.validated_data.get('id')
        model_id = self.validated_data.get("model_id")
        prompt = self.validated_data.get("prompt")
        state_list = [s.lower() for s in self.validated_data.get('state_list')]
        
        ListenerManagement.update_status(
            queryset=Document.objects.filter(dataset_id=dataset_id),
            task_type=TaskType.GENERATE_PROBLEM,
            new_state=State.PENDING
        )
        
        ListenerManagement.update_status(
            queryset=(
                Paragraph.objects.annotate(
                    status_reverse=Reverse('status'),
                    task_type_status=Substr('status_reverse', TaskType.GENERATE_PROBLEM.value, 1)
                )
                .filter(task_type_status='pending', dataset_id=dataset_id)
                .values('id')
            ),
            task_type=TaskType.GENERATE_PROBLEM,
            new_state=State.PENDING
        )

        ListenerManagement.get_aggregation_document_status_by_dataset_id(dataset_id)()
        
        try:
            generate_related_by_dataset_id.delay(dataset_id, model_id, prompt, tuple(state_list))
        except AlreadyQueued as e:
            raise AppApiException(500, _('Failed to send the vectorization task, please try again later!'))


    def list_application(self, with_valid=True):
        if with_valid:
            self.is_valid(raise_exception=True)

Make sure to adapt the class name (YourClassName) and adjust imports and method names according to the context of your application.

Expand Down
11 changes: 11 additions & 0 deletions apps/dataset/task/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,17 @@ def is_the_task_interrupted():
return is_the_task_interrupted


@celery_app.task(base=QueueOnce, once={'keys': ['dataset_id']},
                 name='celery:generate_related_by_dataset')
def generate_related_by_dataset_id(dataset_id, model_id, prompt, state_list=None):
    """Enqueue one ``generate_related_by_document_id`` task per document of a dataset.

    :param dataset_id: dataset whose documents are looked up.
    :param model_id: forwarded unchanged to each per-document task.
    :param prompt: forwarded unchanged to each per-document task.
    :param state_list: forwarded unchanged to each per-document task.
    """
    # Imported locally so the block does not rely on a module-level logger.
    import logging

    logger = logging.getLogger(__name__)
    document_list = QuerySet(Document).filter(dataset_id=dataset_id)
    for document in document_list:
        try:
            generate_related_by_document_id.delay(document.id, model_id, prompt, state_list)
        except Exception:
            # Best effort: one document that cannot be queued must not stop the
            # others, but the failure is recorded instead of silently dropped.
            logger.warning("Failed to queue generate_related task for document %s",
                           document.id, exc_info=True)


@celery_app.task(base=QueueOnce, once={'keys': ['document_id']},
name='celery:generate_related_by_document')
def generate_related_by_document_id(document_id, model_id, prompt, state_list=None):
Expand Down
2 changes: 2 additions & 0 deletions apps/dataset/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
path('dataset/<str:dataset_id>/export', views.Dataset.Export.as_view(), name="export"),
path('dataset/<str:dataset_id>/export_zip', views.Dataset.ExportZip.as_view(), name="export_zip"),
path('dataset/<str:dataset_id>/re_embedding', views.Dataset.Embedding.as_view(), name="dataset_key"),
path('dataset/<str:dataset_id>/generate_related', views.Dataset.GenerateRelated.as_view(),
name="dataset_generate_related"),
path('dataset/<str:dataset_id>/application', views.Dataset.Application.as_view()),
path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),
path('dataset/<str:dataset_id>/sync_web', views.Dataset.SyncWeb.as_view()),
Expand Down
18 changes: 18 additions & 0 deletions apps/dataset/views/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from common.response import result
from common.response.result import get_page_request_params, get_page_api_response, get_api_response
from common.swagger_api.common_api import CommonApi
from dataset.serializers.common_serializers import GenerateRelatedSerializer
from dataset.serializers.dataset_serializers import DataSetSerializers
from dataset.views.common import get_dataset_operation_object
from setting.serializers.provider_serializers import ModelSerializer
Expand Down Expand Up @@ -173,6 +174,23 @@ def put(self, request: Request, dataset_id: str):
return result.success(
DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).re_embedding())

class GenerateRelated(APIView):
    """View exposing the dataset-level "generate related" operation over PUT."""
    authentication_classes = [TokenAuth]

    @action(methods=['PUT'], detail=False)
    @swagger_auto_schema(operation_summary=_('Generate related'), operation_id=_('Generate related'),
                         manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
                         request_body=GenerateRelatedSerializer.get_request_body_api(),
                         tags=[_('Knowledge Base')]
                         )
    @log(menu='document', operate="Generate related documents",
         get_operation_object=lambda r, keywords: get_dataset_operation_object(keywords.get('dataset_id'))
         )
    def put(self, request: Request, dataset_id: str):
        """Run ``generate_related`` for ``dataset_id`` with the request body and wrap the result."""
        operate_data = {'id': dataset_id, 'user_id': request.user.id}
        operate_serializer = DataSetSerializers.Operate(data=operate_data)
        generated = operate_serializer.generate_related(request.data)
        return result.success(generated)

class Export(APIView):
authentication_classes = [TokenAuth]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The provided code is already well-structured and does not seem to have any obvious issues. However, there are a few minor suggestions for optimization:

  1. Method Naming: Use more descriptive method names. For example, operate can be renamed to something like execute_operate.

  2. Docstrings: Ensure that all methods have clear docstrings explaining their purpose. This will help other developers understand the code better.

  3. Error Handling: Consider adding try-except blocks around API calls to handle exceptions gracefully.

  4. Logging: While current logging seems fine, consider whether you need more detailed logs or additional log levels if needed.

Here's an updated version of the code with some of these improvements:

from django.http import HttpResponse
from rest_framework.response import Response
from django.utils.decorators import method_decorator

# Import necessary modules here for improved readability

class DataSetView(APIView):
    # ... (rest of the code remains unchanged)

class GenerateRelated(APIView):
    authentication_classes = [TokenAuth]

    @action(methods=['PUT'], detail=False)
    @swagger_auto_schema(
        operation_summary=_('Generate Related Documents'),
        operation_id=_('Generate Related Documents'),
        manual_parameters=[DataSetSerializers.Operate.get_request_params_api()],
        request_body=GenerateRelatedSerializer.get_request_body_api(),
        tags=[_('Knowledge Base')]
    )
    @log(menu='document', operate="Generate related documents", 
         get_operation_object=lambda r, keywords: get_dataset_operation_object(keywords.get('dataset_id')))
    def put(self, request: Request, dataset_id: str) -> HttpResponse:
        data = {
            'id': dataset_id,
            'user_id': request.user.id
        }
        
        try:
            return response.success(DataSetSerializers.Execute(data=data).generate_related(request.data))
        except Exception as e:
            logger.error(f"Failed to generate related documents for dataset {dataset_id}: {e}")
            return HttpResponse("An error occurred while generating related documents.", status=500)

class Export(APIView):
    authentication_classes = [TokenAuth]
    
    # ... (rest of the code remains unchanged)

These changes aim to make the code cleaner and potentially easier to maintain or update in the future. Let me know if further assistance is needed!

Expand Down
3 changes: 3 additions & 0 deletions apps/locales/en_US/LC_MESSAGES/django.po
Original file line number Diff line number Diff line change
Expand Up @@ -7487,4 +7487,7 @@ msgid "Field: {name} Type: {_type} Value: {value} Unsupported types"
msgstr ""

msgid "Field: {name} No value set"
msgstr ""

msgid "Generate related"
msgstr ""
5 changes: 4 additions & 1 deletion apps/locales/zh_CN/LC_MESSAGES/django.po
Original file line number Diff line number Diff line change
Expand Up @@ -7650,4 +7650,7 @@ msgid "Field: {name} Type: {_type} Value: {value} Unsupported types"
msgstr "字段: {name} 类型: {_type} 值: {value} 不支持的类型"

msgid "Field: {name} No value set"
msgstr "字段: {name} 未设置值"
msgstr "字段: {name} 未设置值"

msgid "Generate related"
msgstr "生成问题"
5 changes: 4 additions & 1 deletion apps/locales/zh_Hant/LC_MESSAGES/django.po
Original file line number Diff line number Diff line change
Expand Up @@ -7660,4 +7660,7 @@ msgid "Field: {name} Type: {_type} Value: {value} Unsupported types"
msgstr "欄位: {name} 類型: {_type} 值: {value} 不支持的類型"

msgid "Field: {name} No value set"
msgstr "欄位: {name} 未設定值"
msgstr "欄位: {name} 未設定值"

msgid "Generate related"
msgstr "生成問題"
17 changes: 16 additions & 1 deletion ui/src/api/dataset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,20 @@ const importLarkDocument: (
) => Promise<Result<Array<any>>> = (dataset_id, data, loading) => {
return post(`${prefix}/lark/${dataset_id}/import`, data, null, loading)
}
/**
 * Generate related questions.
 * @param dataset_id knowledge base id
 * @param data request payload sent as the PUT body
 * @param loading optional loading ref handed to the request helper
 * @returns the promise produced by the PUT request
 */
const generateRelated = (
  dataset_id: string,
  data: any,
  loading?: Ref<boolean>
): Promise<Result<Array<any>>> =>
  put(`${prefix}/${dataset_id}/generate_related`, data, null, loading)

export default {
getDataset,
Expand All @@ -297,5 +311,6 @@ export default {
postLarkDataset,
getLarkDocumentList,
importLarkDocument,
putLarkDataset
putLarkDataset,
generateRelated
}
17 changes: 14 additions & 3 deletions ui/src/components/generate-related-dialog/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
/>
</el-form-item>
<el-form-item
v-if="apiType === 'document'"
v-if="['document', 'dataset'].includes(apiType)"
:label="$t('components.selectParagraph.title')"
prop="state"
>
Expand Down Expand Up @@ -107,6 +107,7 @@ const stateMap = {
error: ['0', '1', '3', '4', '5', 'n']
}
const FormRef = ref()
const datasetId = ref<string>()
const userId = user.userInfo?.id as string
const form = ref(prompt.get(userId))
const rules = reactive({
Expand All @@ -133,7 +134,8 @@ watch(dialogVisible, (bool) => {
}
})

const open = (ids: string[], type: string) => {
const open = (ids: string[], type: string, _datasetId?: string) => {
datasetId.value = _datasetId
getModel()
idList.value = ids
apiType.value = type
Expand Down Expand Up @@ -169,6 +171,15 @@ const submitHandle = async (formEl: FormInstance) => {
emit('refresh')
dialogVisible.value = false
})
} else if (apiType.value === 'dataset') {
const data = {
...form.value,
state_list: stateMap[state.value]
}
datasetApi.generateRelated(id ? id : datasetId.value, data, loading).then(() => {
MsgSuccess(t('views.document.generateQuestion.successMessage'))
dialogVisible.value = false
})
}
}
})
Expand All @@ -177,7 +188,7 @@ const submitHandle = async (formEl: FormInstance) => {
function getModel() {
loading.value = true
datasetApi
.getDatasetModel(id)
.getDatasetModel(id ? id : datasetId.value)
.then((res: any) => {
modelOptions.value = groupBy(res?.data, 'provider')
loading.value = false
Expand Down
17 changes: 15 additions & 2 deletions ui/src/views/dataset/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,19 @@
v-if="item.type === '1'"
>{{ $t('views.dataset.setting.sync') }}</el-dropdown-item
>

<el-dropdown-item @click="reEmbeddingDataset(item)">
<AppIcon
iconName="app-document-refresh"
style="font-size: 16px"
></AppIcon>
{{ $t('views.dataset.setting.vectorization') }}</el-dropdown-item
>
<el-dropdown-item
icon="Connection"
@click.stop="openGenerateDialog(item)"
>{{ $t('views.document.generateQuestion.title') }}</el-dropdown-item
>
<el-dropdown-item
icon="Setting"
@click.stop="router.push({ path: `/dataset/${item.id}/setting` })"
Expand Down Expand Up @@ -165,10 +171,11 @@
</div>
<SyncWebDialog ref="SyncWebDialogRef" @refresh="refresh" />
<CreateDatasetDialog ref="CreateDatasetDialogRef" />
<GenerateRelatedDialog ref="GenerateRelatedDialogRef" />
</div>
</template>
<script setup lang="ts">
import { ref, onMounted, reactive, computed } from 'vue'
import { ref, onMounted, reactive } from 'vue'
import SyncWebDialog from '@/views/dataset/component/SyncWebDialog.vue'
import CreateDatasetDialog from './component/CreateDatasetDialog.vue'
import datasetApi from '@/api/dataset'
Expand All @@ -179,7 +186,7 @@ import { ValidType, ValidCount } from '@/enums/common'
import { t } from '@/locales'
import useStore from '@/stores'
import applicationApi from '@/api/application'

import GenerateRelatedDialog from '@/components/generate-related-dialog/index.vue'
const { user, common } = useStore()
const router = useRouter()

Expand All @@ -192,6 +199,12 @@ const paginationConfig = reactive({
page_size: 30,
total: 0
})
const GenerateRelatedDialogRef = ref<InstanceType<typeof GenerateRelatedDialog>>()
/**
 * Open the generate-related dialog in 'dataset' mode for the given row.
 * Does nothing when the dialog ref has no value yet.
 */
function openGenerateDialog(row: any) {
  GenerateRelatedDialogRef.value?.open([], 'dataset', row.id)
}

const searchValue = ref('')

Expand Down