Skip to content

Break apart session processor and the running of each session into se… #6382

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 34 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
e51a302
Break apart session processor and the running of each session into se…
brandonrising May 16, 2024
82957bb
Run ruff
brandonrising May 16, 2024
8edc25d
Fix next node calling logic
brandonrising May 16, 2024
efb069d
feat(app): iterate on processor split
psychedelicious May 22, 2024
f7c356d
feat(app): iterate on processor split 2
psychedelicious May 22, 2024
cb8e9e1
feat(app): make things in session runner private
psychedelicious May 22, 2024
cef1585
feat(app): support multiple processor lifecycle callbacks
psychedelicious May 22, 2024
eff3596
tidy(app): rearrange proccessor
psychedelicious May 22, 2024
b1f819a
tidy(app): "outputs" -> "output"
psychedelicious May 22, 2024
d30c1ad
docs(app): explain why errors are handled poorly
psychedelicious May 22, 2024
df54572
feat(app): handle preparation errors as node errors
psychedelicious May 22, 2024
80905ff
fix(app): fix logging of error classes instead of class names
psychedelicious May 22, 2024
23b0534
feat(processor): get user/project from queue item w/ fallback
psychedelicious May 23, 2024
a55b2f0
chore: ruff
psychedelicious May 23, 2024
7652fbc
fix(processor): restore missing update of session
psychedelicious May 23, 2024
0e81e7b
feat(db): add `error_type`, `error_message`, rename `error` -> `error…
psychedelicious May 23, 2024
d6696a7
feat(queue): session queue error handling
psychedelicious May 23, 2024
6a34176
feat(events): add enriched errors to events
psychedelicious May 23, 2024
db0ef8d
feat(processor): update enriched errors & `fail_queue_item()`
psychedelicious May 23, 2024
19227fe
feat(app): update test event callbacks
psychedelicious May 23, 2024
9a4c167
chore(ui): typegen
psychedelicious May 23, 2024
6063487
feat(ui): handle enriched events
psychedelicious May 23, 2024
a98dded
docs(processor): update docstrings, comments
psychedelicious May 24, 2024
7d1844e
chore: ruff
psychedelicious May 24, 2024
c88de18
tidy(queue): delete unused `delete_queue_item` method
psychedelicious May 24, 2024
169b75b
tidy(processor): remove test callbacks
psychedelicious May 24, 2024
350feee
fix(processor): fix race condition related to clearing the queue
psychedelicious May 24, 2024
fb93e68
feat(processor): add debug log stmts to session running callbacks
psychedelicious May 24, 2024
0758e9c
fix(ui): race condition with progress
psychedelicious May 24, 2024
08a42c3
tidy(ui): remove extraneous condition in socketInvocationError
psychedelicious May 24, 2024
dc78a0e
fix(ui): correctly fallback to error message when traceback is empty …
psychedelicious May 24, 2024
65e85d1
tidy: remove unnecessary whitespace changes
psychedelicious May 24, 2024
5edc825
fix(processor): race condition that could result in node errors not g…
psychedelicious May 24, 2024
5ee9ff7
feat(ui): toast on queue item errors, improved error descriptions
psychedelicious May 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions invokeai/app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_processor.session_processor_default import DefaultSessionProcessor, DefaultSessionRunner
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
Expand Down Expand Up @@ -103,7 +103,7 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger
)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
session_processor = DefaultSessionProcessor()
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()
workflow_records = SqliteWorkflowRecordsStorage(db=db)
Expand Down
10 changes: 7 additions & 3 deletions invokeai/app/services/events/events_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def emit_invocation_error(
node: dict,
source_node_id: str,
error_type: str,
error: str,
error_message: str,
error_traceback: str,
user_id: str | None,
project_id: str | None,
) -> None:
Expand All @@ -136,7 +137,8 @@ def emit_invocation_error(
"node": node,
"source_node_id": source_node_id,
"error_type": error_type,
"error": error,
"error_message": error_message,
"error_traceback": error_traceback,
"user_id": user_id,
"project_id": project_id,
},
Expand Down Expand Up @@ -257,7 +259,9 @@ def emit_queue_item_status_changed(
"status": session_queue_item.status,
"batch_id": session_queue_item.batch_id,
"session_id": session_queue_item.session_id,
"error": session_queue_item.error,
"error_type": session_queue_item.error_type,
"error_message": session_queue_item.error_message,
"error_traceback": session_queue_item.error_traceback,
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
Expand Down
125 changes: 125 additions & 0 deletions invokeai/app/services/session_processor/session_processor_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,49 @@
from abc import ABC, abstractmethod
from threading import Event
from typing import Optional, Protocol

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.util.profiler import Profiler


class SessionRunnerBase(ABC):
"""
Base class for session runner.
"""

@abstractmethod
def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None:
"""Starts the session runner.

Args:
services: The invocation services.
cancel_event: The cancel event.
profiler: The profiler to use for session profiling via cProfile. Omit to disable profiling. Basic session
stats will be still be recorded and logged when profiling is disabled.
"""
pass

@abstractmethod
def run(self, queue_item: SessionQueueItem) -> None:
"""Runs a session.

Args:
queue_item: The session to run.
"""
pass

@abstractmethod
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
"""Run a single node in the graph.

Args:
invocation: The invocation to run.
queue_item: The session queue item.
"""
pass


class SessionProcessorBase(ABC):
Expand All @@ -26,3 +69,85 @@ def pause(self) -> SessionProcessorStatus:
def get_status(self) -> SessionProcessorStatus:
"""Gets the status of the session processor"""
pass


class OnBeforeRunNode(Protocol):
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
"""Callback to run before executing a node.

Args:
invocation: The invocation that will be executed.
queue_item: The session queue item.
"""
...


class OnAfterRunNode(Protocol):
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput) -> None:
"""Callback to run before executing a node.

Args:
invocation: The invocation that was executed.
queue_item: The session queue item.
"""
...


class OnNodeError(Protocol):
def __call__(
self,
invocation: BaseInvocation,
queue_item: SessionQueueItem,
error_type: str,
error_message: str,
error_traceback: str,
) -> None:
"""Callback to run when a node has an error.

Args:
invocation: The invocation that errored.
queue_item: The session queue item.
error_type: The type of error, e.g. "ValueError".
error_message: The error message, e.g. "Invalid value".
error_traceback: The stringified error traceback.
"""
...


class OnBeforeRunSession(Protocol):
def __call__(self, queue_item: SessionQueueItem) -> None:
"""Callback to run before executing a session.

Args:
queue_item: The session queue item.
"""
...


class OnAfterRunSession(Protocol):
def __call__(self, queue_item: SessionQueueItem) -> None:
"""Callback to run after executing a session.

Args:
queue_item: The session queue item.
"""
...


class OnNonFatalProcessorError(Protocol):
def __call__(
self,
queue_item: Optional[SessionQueueItem],
error_type: str,
error_message: str,
error_traceback: str,
) -> None:
"""Callback to run when a non-fatal error occurs in the processor.

Args:
queue_item: The session queue item, if one was being executed when the error occurred.
error_type: The type of error, e.g. "ValueError".
error_message: The error message, e.g. "Invalid value".
error_traceback: The stringified error traceback.
"""
...
Loading
Loading