Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
yqwang96 authored Oct 22, 2024
2 parents e9e8540 + adb0a93 commit d9caaa3
Show file tree
Hide file tree
Showing 10 changed files with 214 additions and 73 deletions.
23 changes: 11 additions & 12 deletions agent/component/crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@
from crawl4ai import AsyncWebCrawler
from agent.component.base import ComponentBase, ComponentParamBase


class CrawlerParam(ComponentParamBase):
    """
    Define the Crawler component parameters.
    """

    def __init__(self):
        super().__init__()
        # Optional proxy handed to the crawler; None means no proxy.
        self.proxy = None
        # Which representation of the crawled page to return.
        self.extract_type = "markdown"

    def check(self):
        """
        Validate the parameters.

        The extract type must be validated *before* returning; an early
        ``return True`` would leave the validation unreachable.
        """
        self.check_valid_value(self.extract_type, "Type of content from the crawler", ['html', 'markdown', 'content'])
        return True


class Crawler(ComponentBase, ABC):
Expand All @@ -46,7 +49,6 @@ def _run(self, history, **kwargs):
except Exception as e:
return Crawler.be_output(f"An unexpected error occurred: {str(e)}")


async def get_web(self, url):
proxy = self._param.proxy if self._param.proxy else None
async with AsyncWebCrawler(verbose=True, proxy=proxy) as crawler:
Expand All @@ -55,16 +57,13 @@ async def get_web(self, url):
bypass_cache=True
)

match self._param.extract_type:
case 'html':
return result.cleaned_html
case 'markdown':
return result.markdown
case 'content':
return result.extracted_content
case _:
return result.markdown
# print(result.markdown)
if self._param.extract_type == 'html':
return result.cleaned_html
elif self._param.extract_type == 'markdown':
return result.markdown
elif self._param.extract_type == 'content':
result.extracted_content
return result.markdown



Expand Down
84 changes: 84 additions & 0 deletions agent/component/invoke.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from abc import ABC

import requests

from agent.component.base import ComponentBase, ComponentParamBase


class InvokeParam(ComponentParamBase):
    """
    Define the Invoke component parameters.
    """

    def __init__(self):
        super().__init__()
        # Optional proxy URL; None means the request is sent directly.
        self.proxy = None
        # Request headers as a JSON-encoded string; empty means no extra headers.
        self.headers = ""
        # HTTP method name; compared case-insensitively in check().
        self.method = "get"
        # List of variable descriptors used to build the request arguments.
        self.variables = []
        # End point URL to invoke.
        self.url = ""
        # Request timeout, validated as a positive integer.
        self.timeout = 60

    def check(self):
        """
        Validate the HTTP method, the end point URL and the timeout.
        """
        self.check_valid_value(self.method.lower(), "HTTP method of the request", ['get', 'post', 'put'])
        self.check_empty(self.url, "End point URL")
        self.check_positive_integer(self.timeout, "Timeout time in second")


class Invoke(ComponentBase, ABC):
    """
    Component that sends an HTTP request to a configured end point and
    outputs the response body text.
    """
    component_name = "Invoke"

    def _run(self, history, **kwargs):
        """
        Build the request arguments from the configured variables, send the
        request with the configured method, and wrap the response text with
        ``Invoke.be_output``.

        Each entry of ``self._param.variables`` supplies a ``key`` and either
        a ``component_id`` (the value is taken from that component's output
        ``content``) or a literal ``value``.
        """
        args = {}
        for para in self._param.variables:
            if para.get("component_id"):
                cpn = self._canvas.get_component(para["component_id"])["obj"]
                _, out = cpn.output(allow_partial=False)
                args[para["key"]] = "\n".join(out["content"])
            else:
                value = para["value"]
                # A plain string must be passed through unchanged: joining it
                # would insert a newline between every single character.
                args[para["key"]] = value if isinstance(value, str) else "\n".join(value)

        url = self._param.url.strip()
        # Default to plain HTTP when no scheme was given.
        if not url.startswith("http"):
            url = "http://" + url

        method = self._param.method.lower()
        headers = {}
        if self._param.headers:
            headers = json.loads(self._param.headers)
        proxies = None
        if self._param.proxy:
            proxies = {"http": self._param.proxy, "https": self._param.proxy}

        common = {
            "url": url,
            "headers": headers,
            "proxies": proxies,
            "timeout": self._param.timeout,
        }

        if method == 'get':
            # GET carries the arguments in the query string.
            response = requests.get(params=args, **common)
        elif method == 'put':
            response = requests.put(data=args, **common)
        elif method == 'post':
            # 'post' is accepted by InvokeParam.check(), so it must be handled
            # here; otherwise no response is ever produced for it.
            response = requests.post(data=args, **common)
        else:
            return Invoke.be_output(f"Unsupported HTTP method: {method}")

        return Invoke.be_output(response.text)
3 changes: 1 addition & 2 deletions api/apps/api_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from api.db.db_models import APIToken, Task, File
from api.db.services import duplicate_name
from api.db.services.api_service import APITokenService, API4ConversationService
from api.db.services.dialog_service import DialogService, chat
from api.db.services.dialog_service import DialogService, chat, keyword_extraction
from api.db.services.document_service import DocumentService, doc_upload_and_parse
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
Expand All @@ -38,7 +38,6 @@
generate_confirmation_token

from api.utils.file_utils import filename_type, thumbnail
from rag.nlp import keyword_extraction
from rag.utils.storage_factory import STORAGE_IMPL

from api.db.services.canvas_service import UserCanvasService
Expand Down
3 changes: 2 additions & 1 deletion api/apps/chunk_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
from flask_login import login_required, current_user
from elasticsearch_dsl import Q

from api.db.services.dialog_service import keyword_extraction
from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import search, rag_tokenizer, keyword_extraction
from rag.nlp import search, rag_tokenizer
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils import rmSpace
from api.db import LLMType, ParserType
Expand Down
5 changes: 2 additions & 3 deletions api/apps/sdk/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,15 @@
from flask import request

from api.db import StatusEnum
from api.db.db_models import TenantLLM
from api.db.services.dialog_service import DialogService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService
from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result, token_required
from api.utils.api_utils import get_result


@manager.route('/chat', methods=['POST'])
@token_required
def create(tenant_id):
Expand Down
25 changes: 20 additions & 5 deletions api/apps/sdk/dify_retrieval.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request, jsonify

from db import LLMType, ParserType
from db.services.knowledgebase_service import KnowledgebaseService
from db.services.llm_service import LLMBundle
from settings import retrievaler, kg_retrievaler, RetCode
from utils.api_utils import validate_request, build_error_result, apikey_required
from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler, kg_retrievaler, RetCode
from api.utils.api_utils import validate_request, build_error_result, apikey_required


@manager.route('/dify/retrieval', methods=['POST'])
Expand Down
61 changes: 23 additions & 38 deletions api/apps/sdk/doc.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,45 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pathlib
import re
import datetime
import json
import traceback

from botocore.docs.method import document_model_driven_method
from flask import request
from flask_login import login_required, current_user
from elasticsearch_dsl import Q
from pygments import highlight
from sphinx.addnodes import document

from api.db.services.dialog_service import keyword_extraction
from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import search, rag_tokenizer, keyword_extraction
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils import rmSpace
from rag.nlp import rag_tokenizer
from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import server_error_response, get_error_data_result, validate_request
from api.db.services.document_service import DocumentService
from api.settings import RetCode, retrievaler, kg_retrievaler
from api.utils.api_utils import get_result
from api.settings import kg_retrievaler
import hashlib
import re
from api.utils.api_utils import get_result, token_required, get_error_data_result

from api.db.db_models import Task, File

from api.utils.api_utils import token_required
from api.db.db_models import Task
from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.user_service import TenantService, UserTenantService

from api.utils.api_utils import server_error_response, get_error_data_result, validate_request

from api.utils.api_utils import get_result, get_result, get_error_data_result

from functools import partial
from api.utils.api_utils import server_error_response
from api.utils.api_utils import get_result, get_error_data_result
from io import BytesIO

from elasticsearch_dsl import Q
from flask import request, send_file
from flask_login import login_required

from api.db import FileSource, TaskStatus, FileType
from api.db.db_models import File
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import RetCode, retrievaler
from api.utils.api_utils import construct_json_result, construct_error_response
from rag.app import book, laws, manual, naive, one, paper, presentation, qa, resume, table, picture, audio, email
from api.utils.api_utils import construct_json_result
from rag.nlp import search
from rag.utils import rmSpace
from rag.utils.es_conn import ELASTICSEARCH
Expand Down Expand Up @@ -365,7 +353,6 @@ def list_chunks(tenant_id,dataset_id,document_id):
return get_result(data=res)



@manager.route('/dataset/<dataset_id>/document/<document_id>/chunk', methods=['POST'])
@token_required
def create(tenant_id,dataset_id,document_id):
Expand Down Expand Up @@ -454,7 +441,6 @@ def rm_chunk(tenant_id,dataset_id,document_id):
return get_result()



@manager.route('/dataset/<dataset_id>/document/<document_id>/chunk/<chunk_id>', methods=['PUT'])
@token_required
def update_chunk(tenant_id,dataset_id,document_id,chunk_id):
Expand Down Expand Up @@ -512,7 +498,6 @@ def update_chunk(tenant_id,dataset_id,document_id,chunk_id):
return get_result()



@manager.route('/retrieval', methods=['POST'])
@token_required
def retrieval_test(tenant_id):
Expand Down
54 changes: 53 additions & 1 deletion api/db/services/dialog_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api.settings import chat_logger, retrievaler, kg_retrievaler
from rag.app.resume import forbidden_select_fields4resume
from rag.nlp import keyword_extraction
from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder
from api.utils.file_utils import get_project_base_directory
Expand Down Expand Up @@ -80,6 +79,7 @@ def get_list(cls,dialog_id,page_number, items_per_page, orderby, desc, id , name

return list(sessions.dicts())


def message_fit_in(msg, max_length=4000):
def count():
nonlocal msg
Expand Down Expand Up @@ -456,6 +456,58 @@ def rewrite(tenant_id, llm_id, question):
return ans


def keyword_extraction(chat_mdl, content, topn=3):
    """
    Ask the chat model for the most important keywords/phrases of `content`.

    The prompt requests the top `topn` keywords, comma-delimited. The
    conversation is passed through `message_fit_in` with the model's
    `max_length` before the model is called with temperature 0.2.

    Returns the model's answer (the first element if the model returns a
    tuple), or an empty string when the answer contains "**ERROR**".
    """
    prompt = f"""
Role: You're a text analyzer.
Task: extract the most important keywords/phrases of a given piece of text content.
Requirements:
- Summarize the text content, and give top {topn} important keywords/phrases.
- The keywords MUST be in language of the given piece of text content.
- The keywords are delimited by ENGLISH COMMA.
- Keywords ONLY in output.
### Text Content
{content}
"""
    conversation = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "Output: "},
    ]
    _used, fitted = message_fit_in(conversation, chat_mdl.max_length)
    # The system prompt is passed separately, so only the remaining turns
    # are forwarded as history.
    answer = chat_mdl.chat(prompt, fitted[1:], {"temperature": 0.2})
    if isinstance(answer, tuple):
        answer = answer[0]
    if "**ERROR**" in answer:
        return ""
    return answer


def question_proposal(chat_mdl, content, topn=3):
    """
    Ask the chat model to propose `topn` questions about `content`.

    The prompt requests one question per line. The conversation is passed
    through `message_fit_in` with the model's `max_length` before the model
    is called with temperature 0.2.

    Returns the model's answer (the first element if the model returns a
    tuple), or an empty string when the answer contains "**ERROR**".
    """
    prompt = f"""
Role: You're a text analyzer.
Task: propose {topn} questions about a given piece of text content.
Requirements:
- Understand and summarize the text content, and propose top {topn} important questions.
- The questions SHOULD NOT have overlapping meanings.
- The questions SHOULD cover the main content of the text as much as possible.
- The questions MUST be in language of the given piece of text content.
- One question per line.
- Question ONLY in output.
### Text Content
{content}
"""
    conversation = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "Output: "},
    ]
    _used, fitted = message_fit_in(conversation, chat_mdl.max_length)
    # The system prompt is passed separately, so only the remaining turns
    # are forwarded as history.
    answer = chat_mdl.chat(prompt, fitted[1:], {"temperature": 0.2})
    if isinstance(answer, tuple):
        answer = answer[0]
    if "**ERROR**" in answer:
        return ""
    return answer


def full_question(tenant_id, llm_id, messages):
if llm_id2llm_type(llm_id) == "image2text":
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
Expand Down
Loading

0 comments on commit d9caaa3

Please sign in to comment.