Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file renamed .coverage → src/code/agent/.coverage
Binary file not shown.
3 changes: 3 additions & 0 deletions src/code/agent/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@
# 日志配置
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")

# 多租户/多用户模式配置
ENABLE_MULTI_USER = os.getenv('ENABLE_COMFYUI_MULTI_USER', '').lower() == 'true'

class ERROR_CODE(Enum):
UNCLASSIFY = "UNCLASSIFY"
INVALID_PARAMS = "INVALID_PARAMS"
Expand Down
23 changes: 22 additions & 1 deletion src/code/agent/routes/gateway_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
import traceback

from flask import Blueprint, Flask, jsonify, request
from flask import Blueprint, Flask, jsonify, request, g
from flask_sock import Sock
import websocket

Expand Down Expand Up @@ -58,6 +58,7 @@ def register(self, app: Flask):
def setup_routes(self):
"""设置所有路由"""
self._register_backend_status_middleware()
self._register_user_identity_middleware()
self._register_reboot_handler()

# 只在 CPU 模式下注册这些路由
Expand Down Expand Up @@ -92,6 +93,26 @@ def check_backend_status():
status_code=500
)

def _register_user_identity_middleware(self):
"""注册用户身份识别中间件,在每个请求前识别用户"""
@self.bp.before_request
def identify_user():
"""
识别用户身份

在多租户模式下从请求头提取用户信息,单租户模式使用默认用户
"""
if not constants.ENABLE_MULTI_USER:
from utils.user_identity import DEFAULT_USER_ID
g.user_id = DEFAULT_USER_ID
else:
from flask import abort
from utils.user_identity import extract_user_from_header
uid = extract_user_from_header()
if uid is None:
abort(403, "User context required but not provided")
g.user_id = uid

def _register_websocket(self):
@self.sock.route("/ws")
def comfyui_compatible_ws(ws):
Expand Down
14 changes: 12 additions & 2 deletions src/code/agent/routes/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import threading
import traceback

from flask import Flask, jsonify, request, Response
from flask import Flask, jsonify, request, Response, g
from flask_sock import Sock
import requests

Expand All @@ -13,6 +13,7 @@
from services.management_service import ManagementService, Action, BackendStatus
from utils.logger import log
from utils.error_handler import ErrorResponse
from utils.user_identity import identify_user_or_default
from .management_routes import ManagementRoutes
from .serverless_api_routes import ServerlessApiRoutes
from .gateway_routes import GatewayRoutes
Expand Down Expand Up @@ -202,6 +203,7 @@ def do_save(result_queue, target_snapshot_name):

@self.app.route("/<path:path>", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
@self.app.route("/", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
@identify_user_or_default
def proxy(path=""):
backend_status = self.management.service.status
if backend_status not in (BackendStatus.RUNNING, BackendStatus.SAVING):
Expand All @@ -216,10 +218,18 @@ def proxy(path=""):
target_url = f"http://{constants.APP_HOST}{original_uri}"
# print(f"Forwarding http request to path: {target_url}")

# 准备转发的 headers(添加用户标识用于多租户支持)
forward_headers = dict(request.headers)
forward_headers.pop('X-Art-Comfy-User', None)

user_id = g.user_id
if user_id:
forward_headers['X-Art-Comfy-User'] = user_id

resp = requests.request(
method=request.method,
url=target_url,
headers=dict(request.headers),
headers=forward_headers,
params=request.args,
data=request.get_data(),
cookies=request.cookies,
Expand Down
Loading