refactor: optimize database usage (langgenius#12071)
Signed-off-by: -LAN- <laipz8200@outlook.com>
laipz8200 authored Dec 25, 2024
1 parent b281a80 commit 83ea931
Showing 12 changed files with 587 additions and 574 deletions.
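
The change this commit makes, visible in the hunks below, is to stop reading and writing through the long-lived global `db.session` (and through cached ORM objects such as `self._user` and `self._workflow`), and instead keep plain ids, open a short-lived `sqlalchemy.orm.Session` where database work is needed, pass it explicitly into helpers, and commit once per unit of work. The following is a minimal, self-contained sketch of that pattern; the `WorkflowRunRecord` model and the helper names are illustrative stand-ins, not Dify's actual classes.

# Minimal sketch of the session-handling pattern this commit adopts:
# helpers receive an explicit Session and the caller owns the single commit.
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class WorkflowRunRecord(Base):  # illustrative stand-in, not a Dify model
    __tablename__ = "workflow_run_record"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(String(32), default="running")


engine = create_engine("sqlite+pysqlite:///:memory:")
Base.metadata.create_all(engine)


def start_run(*, session: Session) -> WorkflowRunRecord:
    # The helper writes through the session it was given; it does not commit.
    run = WorkflowRunRecord()
    session.add(run)
    session.flush()  # assign the primary key without ending the transaction
    return run


def finish_run(*, session: Session, run_id: int, status: str) -> WorkflowRunRecord:
    run = session.scalars(select(WorkflowRunRecord).where(WorkflowRunRecord.id == run_id)).one()
    run.status = status
    return run


# The caller scopes the session narrowly and commits exactly once per unit of work,
# instead of reading and writing the long-lived global db.session.
with Session(engine) as session:
    run = start_run(session=session)
    finish_run(session=session, run_id=run.id, status="succeeded")
    session.commit()

Keeping the helpers commit-free and letting the caller own the transaction is what lets the pipelines below batch several writes into a single commit per event.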
352 changes: 180 additions & 172 deletions api/core/app/apps/advanced_chat/generate_task_pipeline.py

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion api/core/app/apps/message_based_app_generator.py
@@ -70,7 +70,6 @@ def _handle_response(
queue_manager=queue_manager,
conversation=conversation,
message=message,
-user=user,
stream=stream,
)

192 changes: 106 additions & 86 deletions api/core/app/apps/workflow/generate_task_pipeline.py
@@ -3,6 +3,8 @@
from collections.abc import Generator
from typing import Any, Optional, Union

+from sqlalchemy.orm import Session

from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -50,6 +52,7 @@
from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db
from models.account import Account
+from models.enums import CreatedByRole
from models.model import EndUser
from models.workflow import (
Workflow,
@@ -68,8 +71,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""

-_workflow: Workflow
-_user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
@@ -83,25 +84,27 @@ def __init__(
user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param user: user
:param stream: is streamed
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
super().__init__(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
)

-if isinstance(self._user, EndUser):
-user_id = self._user.session_id
+if isinstance(user, EndUser):
+self._user_id = user.session_id
+self._created_by_role = CreatedByRole.END_USER
+elif isinstance(user, Account):
+self._user_id = user.id
+self._created_by_role = CreatedByRole.ACCOUNT
else:
-user_id = self._user.id
+raise ValueError(f"Invalid user type: {type(user)}")

+self._workflow_id = workflow.id
+self._workflow_features_dict = workflow.features_dict

-self._workflow = workflow
self._workflow_system_variables = {
SystemVariableKey.FILES: application_generate_entity.files,
-SystemVariableKey.USER_ID: user_id,
+SystemVariableKey.USER_ID: self._user_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
@@ -115,10 +118,6 @@ def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStr
Process generate task pipeline.
:return:
"""
-db.session.refresh(self._workflow)
-db.session.refresh(self._user)
-db.session.close()

generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
return self._to_stream_response(generator)
@@ -185,7 +184,7 @@ def _wrapper_process_stream_response(
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
-features_dict = self._workflow.features_dict
+features_dict = self._workflow_features_dict

if (
features_dict.get("text_to_speech")
@@ -242,18 +241,26 @@ def _process_stream_response(
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
-err = self._handle_error(event)
+err = self._handle_error(event=event)
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state

-# init workflow run
-workflow_run = self._handle_workflow_run_start()
-yield self._workflow_start_to_stream_response(
-task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
-)
+with Session(db.engine) as session:
+# init workflow run
+workflow_run = self._handle_workflow_run_start(
+session=session,
+workflow_id=self._workflow_id,
+user_id=self._user_id,
+created_by_role=self._created_by_role,
+)
+start_resp = self._workflow_start_to_stream_response(
+session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+)
+session.commit()
+yield start_resp
elif isinstance(
event,
QueueNodeRetryEvent,
@@ -350,72 +357,87 @@ def _process_stream_response(
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")

-workflow_run = self._handle_workflow_run_success(
-workflow_run=workflow_run,
-start_at=graph_runtime_state.start_at,
-total_tokens=graph_runtime_state.total_tokens,
-total_steps=graph_runtime_state.node_run_steps,
-outputs=event.outputs,
-conversation_id=None,
-trace_manager=trace_manager,
-)
-
-# save workflow app log
-self._save_workflow_app_log(workflow_run)
-
-yield self._workflow_finish_to_stream_response(
-task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
-)
+with Session(db.engine) as session:
+workflow_run = self._handle_workflow_run_success(
+session=session,
+workflow_run=workflow_run,
+start_at=graph_runtime_state.start_at,
+total_tokens=graph_runtime_state.total_tokens,
+total_steps=graph_runtime_state.node_run_steps,
+outputs=event.outputs,
+conversation_id=None,
+trace_manager=trace_manager,
+)
+
+# save workflow app log
+self._save_workflow_app_log(session=session, workflow_run=workflow_run)
+
+workflow_finish_resp = self._workflow_finish_to_stream_response(
+session=session,
+task_id=self._application_generate_entity.task_id,
+workflow_run=workflow_run,
+)
+session.commit()
+yield workflow_finish_resp
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise ValueError("workflow run not initialized.")

if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")

-workflow_run = self._handle_workflow_run_partial_success(
-workflow_run=workflow_run,
-start_at=graph_runtime_state.start_at,
-total_tokens=graph_runtime_state.total_tokens,
-total_steps=graph_runtime_state.node_run_steps,
-outputs=event.outputs,
-exceptions_count=event.exceptions_count,
-conversation_id=None,
-trace_manager=trace_manager,
-)
-
-# save workflow app log
-self._save_workflow_app_log(workflow_run)
-
-yield self._workflow_finish_to_stream_response(
-task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
-)
+with Session(db.engine) as session:
+workflow_run = self._handle_workflow_run_partial_success(
+session=session,
+workflow_run=workflow_run,
+start_at=graph_runtime_state.start_at,
+total_tokens=graph_runtime_state.total_tokens,
+total_steps=graph_runtime_state.node_run_steps,
+outputs=event.outputs,
+exceptions_count=event.exceptions_count,
+conversation_id=None,
+trace_manager=trace_manager,
+)
+
+# save workflow app log
+self._save_workflow_app_log(session=session, workflow_run=workflow_run)
+
+workflow_finish_resp = self._workflow_finish_to_stream_response(
+session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+)
+session.commit()
+
+yield workflow_finish_resp
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run:
raise ValueError("workflow run not initialized.")

if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
-workflow_run = self._handle_workflow_run_failed(
-workflow_run=workflow_run,
-start_at=graph_runtime_state.start_at,
-total_tokens=graph_runtime_state.total_tokens,
-total_steps=graph_runtime_state.node_run_steps,
-status=WorkflowRunStatus.FAILED
-if isinstance(event, QueueWorkflowFailedEvent)
-else WorkflowRunStatus.STOPPED,
-error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
-conversation_id=None,
-trace_manager=trace_manager,
-exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
-)
-
-# save workflow app log
-self._save_workflow_app_log(workflow_run)
-
-yield self._workflow_finish_to_stream_response(
-task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
-)
+with Session(db.engine) as session:
+workflow_run = self._handle_workflow_run_failed(
+session=session,
+workflow_run=workflow_run,
+start_at=graph_runtime_state.start_at,
+total_tokens=graph_runtime_state.total_tokens,
+total_steps=graph_runtime_state.node_run_steps,
+status=WorkflowRunStatus.FAILED
+if isinstance(event, QueueWorkflowFailedEvent)
+else WorkflowRunStatus.STOPPED,
+error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
+conversation_id=None,
+trace_manager=trace_manager,
+exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
+)
+
+# save workflow app log
+self._save_workflow_app_log(session=session, workflow_run=workflow_run)
+
+workflow_finish_resp = self._workflow_finish_to_stream_response(
+session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+)
+session.commit()
+yield workflow_finish_resp
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
@@ -435,7 +457,7 @@ def _process_stream_response(
if tts_publisher:
tts_publisher.publish(None)

-def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
+def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None:
"""
Save workflow app log.
:return:
@@ -457,12 +479,10 @@ def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
workflow_app_log.workflow_id = workflow_run.workflow_id
workflow_app_log.workflow_run_id = workflow_run.id
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user"
workflow_app_log.created_by = self._user.id
workflow_app_log.created_by_role = self._created_by_role
workflow_app_log.created_by = self._user_id

-db.session.add(workflow_app_log)
-db.session.commit()
-db.session.close()
+session.add(workflow_app_log)

def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None
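In the event loop above, each queue event that touches the database now opens its own `Session(db.engine)`, performs its writes and builds the stream response while the session is open, commits, and then yields the response. A rough, self-contained sketch of that per-event shape, using stand-in event and response types rather than Dify's:

# Illustrative sketch (not Dify's actual classes) of the per-event session pattern.
from collections.abc import Generator
from dataclasses import dataclass

from sqlalchemy import create_engine
from sqlalchemy.orm import Session

engine = create_engine("sqlite+pysqlite:///:memory:")


@dataclass
class StreamResponse:  # stand-in for the real stream response types
    task_id: str
    payload: dict


def handle_events(events: list[dict], task_id: str) -> Generator[StreamResponse, None, None]:
    for event in events:
        with Session(engine) as session:
            # ... write workflow-run rows through `session` here ...
            # build the response for this event; in the real pipeline this reads ORM state via `session`
            resp = StreamResponse(task_id=task_id, payload=event)
            session.commit()  # one commit per event's unit of work
        yield resp  # the session is already closed by the time the consumer sees the event


for resp in handle_events([{"type": "workflow_started"}], task_id="t-1"):
    print(resp)

Scoping the session to a single event keeps transactions short even when the generator's consumer is slow.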
36 changes: 15 additions & 21 deletions api/core/app/task_pipeline/based_generate_task_pipeline.py
@@ -1,6 +1,9 @@
import logging
import time
-from typing import Optional, Union
+from typing import Optional

+from sqlalchemy import select
+from sqlalchemy.orm import Session

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import (
@@ -17,9 +20,7 @@
from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.moderation.output_moderation import ModerationRule, OutputModeration
-from extensions.ext_database import db
-from models.account import Account
-from models.model import EndUser, Message
+from models.model import Message

logger = logging.getLogger(__name__)

@@ -36,7 +37,6 @@ def __init__(
self,
application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
-user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
@@ -48,18 +48,11 @@
"""
self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager
-self._user = user
self._start_at = time.perf_counter()
self._output_moderation_handler = self._init_output_moderation()
self._stream = stream

-def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None):
-"""
-Handle error event.
-:param event: event
-:param message: message
-:return:
-"""
+def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
logger.debug("error: %s", event.error)
e = event.error
err: Exception
@@ -71,16 +64,17 @@ def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = Non
else:
err = Exception(e.description if getattr(e, "description", None) is not None else str(e))

-if message:
-refetch_message = db.session.query(Message).filter(Message.id == message.id).first()

-if refetch_message:
-err_desc = self._error_to_desc(err)
-refetch_message.status = "error"
-refetch_message.error = err_desc
+if not message_id or not session:
+return err

-db.session.commit()
+stmt = select(Message).where(Message.id == message_id)
+message = session.scalar(stmt)
+if not message:
+return err

+err_desc = self._error_to_desc(err)
+message.status = "error"
+message.error = err_desc
return err

def _error_to_desc(self, e: Exception) -> str:
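The `_handle_error` rewrite above also replaces the legacy `db.session.query(Message).filter(...).first()` call with a 2.0-style `select()` executed through the session the caller passes in. A small self-contained sketch of that query pattern; the `Message` model here is a stand-in, not Dify's:

# Sketch of the query-style change: 2.0-style select() through an explicitly passed session.
from typing import Optional

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Message(Base):  # stand-in model for illustration only
    __tablename__ = "message"
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    status: Mapped[str] = mapped_column(String(16), default="normal")
    error: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)


engine = create_engine("sqlite+pysqlite:///:memory:")
Base.metadata.create_all(engine)


def mark_message_errored(*, session: Session, message_id: str, error: str) -> Optional[Message]:
    # Build a select() statement and fetch a single row (or None) as a scalar,
    # then mutate it through the caller's session; the caller decides when to commit.
    message = session.scalar(select(Message).where(Message.id == message_id))
    if message is None:
        return None
    message.status = "error"
    message.error = error
    return message


with Session(engine) as session:
    session.add(Message(id="m-1"))
    session.flush()
    mark_message_errored(session=session, message_id="m-1", error="quota exceeded")
    session.commit()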