Skip to content

Commit

Permalink
Merge branch 'main' into feat/azure-tts
Browse files Browse the repository at this point in the history
  • Loading branch information
leslie2046 committed Mar 12, 2024
2 parents 38ed8d2 + f734cca commit 0db8d9f
Show file tree
Hide file tree
Showing 104 changed files with 2,162 additions and 393 deletions.
11 changes: 0 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,6 @@
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>

<p align="center">
<a href="https://discord.com/events/1082486657678311454/1211724120996188220" target="_blank">
Dify.AI Upcoming Meetup Event [👉 Click to Join the Event Here 👈]
</a>
<ul align="center" style="text-decoration: none; list-style: none;">
<li> US EST: 09:00 (9:00 AM)</li>
<li> CET: 15:00 (3:00 PM)</li>
<li> CST: 22:00 (10:00 PM)</li>
</ul>
</p>

<p align="center">
<a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
Expand Down
6 changes: 3 additions & 3 deletions api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
1. Start the docker-compose stack

The backend require some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.

```bash
cd ../docker
docker-compose -f docker-compose.middleware.yaml -p dify up -d
Expand All @@ -15,7 +15,7 @@
3. Generate a `SECRET_KEY` in the `.env` file.

```bash
openssl rand -base64 42
sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
```
3.5 If you use annaconda, create a new environment and activate it
```bash
Expand Down Expand Up @@ -46,7 +46,7 @@
```
pip install -r requirements.txt --upgrade --force-reinstall
```

6. Start backend:
```bash
flask run --host 0.0.0.0 --port=5001 --debug
Expand Down
2 changes: 1 addition & 1 deletion api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(self):
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.5.8"
self.CURRENT_VERSION = "0.5.9"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
Expand Down
41 changes: 39 additions & 2 deletions api/controllers/console/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
from libs.login import login_required
from models.model import App, AppModelConfig, Site
from services.app_model_config_service import AppModelConfigService

from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.tool_manager import ToolManager
from core.entities.application_entities import AgentToolEntity

def _get_app(app_id, tenant_id):
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
Expand Down Expand Up @@ -236,7 +238,42 @@ class AppApi(Resource):
def get(self, app_id):
"""Get app detail"""
app_id = str(app_id)
app = _get_app(app_id, current_user.current_tenant_id)
app: App = _get_app(app_id, current_user.current_tenant_id)

# get original app model config
model_config: AppModelConfig = app.app_model_config
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)

# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
masked_parameter = {}

# override tool parameters
tool['tool_parameters'] = masked_parameter
except Exception as e:
pass

# override agent mode
model_config.agent_mode = json.dumps(agent_mode)

return app

Expand Down
86 changes: 86 additions & 0 deletions api/controllers/console/app/model_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json

from flask import request
from flask_login import current_user
Expand All @@ -7,6 +8,9 @@
from controllers.console.app import _get_app
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.entities.application_entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.login import login_required
Expand Down Expand Up @@ -38,6 +42,88 @@ def post(self, app_id):
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)

# get original app model config
original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
AppModelConfig.id == app.app_model_config_id
).first()
agent_mode = original_app_model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
parameter_map = {}
masked_parameter_map = {}
tool_map = {}
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
except Exception as e:
continue

# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
parameters = {}
masked_parameter = {}

key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
masked_parameter_map[key] = masked_parameter
parameter_map[key] = parameters
tool_map[key] = tool_runtime

# encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)

# get tool
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
if key in tool_map:
tool_runtime = tool_map[key]
else:
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
except Exception as e:
continue

manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
manager.delete_tool_parameters_cache()

# override parameters if it equals to masked parameters
if agent_tool_entity.tool_parameters:
if key not in masked_parameter_map:
continue

if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
agent_tool_entity.tool_parameters = parameter_map[key]

# encrypt parameters
if agent_tool_entity.tool_parameters:
tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})

# update app model config
new_app_model_config.agent_mode = json.dumps(agent_mode)

db.session.add(new_app_model_config)
db.session.flush()

Expand Down
26 changes: 26 additions & 0 deletions api/controllers/console/workspace/tool_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,30 @@ def get(self, provider):
icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider)
return send_file(io.BytesIO(icon_bytes), mimetype=minetype)

class ToolModelProviderIconApi(Resource):
@setup_required
def get(self, provider):
icon_bytes, mimetype = ToolManageService.get_model_tool_provider_icon(provider)
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype)

class ToolModelProviderListToolsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id

parser = reqparse.RequestParser()
parser.add_argument('provider', type=str, required=True, nullable=False, location='args')

args = parser.parse_args()

return ToolManageService.list_model_tool_provider_tools(
user_id,
tenant_id,
args['provider'],
)

class ToolApiProviderAddApi(Resource):
@setup_required
Expand Down Expand Up @@ -283,6 +307,8 @@ def post(self):
api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
api.add_resource(ToolModelProviderIconApi, '/workspaces/current/tool-provider/model/<provider>/icon')
api.add_resource(ToolModelProviderListToolsApi, '/workspaces/current/tool-provider/model/tools')
api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
Expand Down
4 changes: 2 additions & 2 deletions api/controllers/service_api/dataset/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ def post(self, tenant_id, dataset_id, document_id, segment_id):
parser.add_argument('segments', type=dict, required=False, nullable=True, location='json')
args = parser.parse_args()

SegmentService.segment_create_args_validate(args['segments'], document)
segment = SegmentService.update_segment(args['segments'], segment, document, dataset)
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(args, segment, document, dataset)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
Expand Down
4 changes: 4 additions & 0 deletions api/core/app_runner/assistant_app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ def run(self, application_generate_entity: ApplicationGenerateEntity,
if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING

db.session.refresh(conversation)
db.session.refresh(message)
db.session.close()

# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
assistant_cot_runner = AssistantCotApplicationRunner(
Expand Down
2 changes: 2 additions & 0 deletions api/core/app_runner/basic_app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ def run(self, application_generate_entity: ApplicationGenerateEntity,
model=app_orchestration_config.model_config.model
)

db.session.close()

invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,
Expand Down
8 changes: 8 additions & 0 deletions api/core/app_runner/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def process(self, stream: bool) -> Union[dict, Generator]:
Process generate task pipeline.
:return:
"""
db.session.refresh(self._conversation)
db.session.refresh(self._message)
db.session.close()

if stream:
return self._process_stream_response()
else:
Expand Down Expand Up @@ -303,6 +307,7 @@ def _process_stream_response(self) -> Generator:
.first()
)
db.session.refresh(agent_thought)
db.session.close()

if agent_thought:
response = {
Expand Down Expand Up @@ -330,6 +335,8 @@ def _process_stream_response(self) -> Generator:
.filter(MessageFile.id == event.message_file_id)
.first()
)
db.session.close()

# get extension
if '.' in message_file.url:
extension = f'.{message_file.url.split(".")[-1]}'
Expand Down Expand Up @@ -413,6 +420,7 @@ def _save_message(self, llm_result: LLMResult) -> None:
usage = llm_result.usage

self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()

self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
self._message.message_tokens = usage.prompt_tokens
Expand Down
6 changes: 3 additions & 3 deletions api/core/application_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def _generate_worker(self, flask_app: Flask,
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
db.session.close()

def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager,
Expand Down Expand Up @@ -233,8 +233,6 @@ def _handle_response(self, application_generate_entity: ApplicationGenerateEntit
else:
logger.exception(e)
raise e
finally:
db.session.remove()

def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
-> AppOrchestrationConfigEntity:
Expand Down Expand Up @@ -651,6 +649,7 @@ def _init_generate_records(self, application_generate_entity: ApplicationGenerat

db.session.add(conversation)
db.session.commit()
db.session.refresh(conversation)
else:
conversation = (
db.session.query(Conversation)
Expand Down Expand Up @@ -689,6 +688,7 @@ def _init_generate_records(self, application_generate_entity: ApplicationGenerat

db.session.add(message)
db.session.commit()
db.session.refresh(message)

for file in application_generate_entity.files:
message_file = MessageFile(
Expand Down
Loading

0 comments on commit 0db8d9f

Please sign in to comment.