Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions api/routers/knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,24 @@ class FileListResponse(BaseModel):
"""Response model for file list"""
files: List[FileInfo]
total: int

class UploadResponse(BaseModel):
    """Body returned by the upload endpoint once a file has been accepted."""

    # Job identifier for this upload; the endpoint fills it from the
    # handler's result.
    job_id: str
    # Identifier of the stored file; optional and None when not provided.
    file_id: str | None = None
    # State reported to the caller; defaults to "accepted".
    status: str = "accepted"

@router.post(
"",
response_model=str,
response_model=UploadResponse,
status_code=status.HTTP_201_CREATED,
)
def upload_file(
file: UploadFile,
user: Annotated[User | None, Depends(get_current_user)],
callback_url: str | None = Form(default=None),
job_id: str | None = Form(default=None),
callback_secret: str | None = Form(default=None),
):
"""
Upload a file to the knowledge base
Expand All @@ -66,9 +75,17 @@ def upload_file(
)
try:
print(f"Uploading file: {file.filename} for owner_id: {user.id}")
# Convert string UUID to UUID object
doc_id = knowledge_handler.upload_file(file, user.id)
return doc_id
upload_result = knowledge_handler.upload_file(
file,
user.id,
callback_url=callback_url,
job_id=job_id,
callback_secret=callback_secret,
)
return UploadResponse(
job_id=upload_result["job_id"],
file_id=upload_result.get("file_id"),
)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand All @@ -80,7 +97,6 @@ def upload_file(
detail=f"Failed to upload file: {str(e)}",
)


@router.get("/{file_id}/download")
async def download_file(file_id: str, user: Annotated[User | None, Depends(get_current_user)]):
if user is None:
Expand Down
139 changes: 130 additions & 9 deletions application/knowledge/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Dict, Any

import httpx

if TYPE_CHECKING:
from config.application.knowledge_config import KnowledgeConfig

Expand All @@ -24,8 +26,19 @@ def __init__(self, config: 'KnowledgeConfig'):
# Semaphore to control concurrent indexing operations
max_concurrent_indexing = config.max_concurrent_indexing
self.indexing_semaphore = asyncio.Semaphore(max_concurrent_indexing)

# Semaphore to control concurrent indexing operations
max_concurrent_indexing = config.max_concurrent_indexing
self.indexing_semaphore = asyncio.Semaphore(max_concurrent_indexing)

def upload_file(self, file: UploadFile, user_id: uuid.UUID) -> str:
def upload_file(
self,
file: UploadFile,
user_id: uuid.UUID,
callback_url: str | None = None,
job_id: str | None = None,
callback_secret: str | None = None,
) -> Dict[str, Any]:
try:
doc_id = self.file_storage.upload_file(
filename=file.filename,
Expand All @@ -34,43 +47,151 @@ def upload_file(self, file: UploadFile, user_id: uuid.UUID) -> str:
content_type=file.content_type
)
# Start indexing in background (fire-and-forget)
self._start_background_indexing(doc_id)
job_identifier = job_id or str(uuid.uuid4())
logger.info(
"Upload accepted for file_id=%s (job_id=%s, callback_url=%s)",
doc_id,
job_identifier,
callback_url,
)
self._start_background_indexing(
doc_id,
callback_url,
job_identifier,
callback_secret,
)
logger.info(f"File {file.filename} uploaded with ID {doc_id}, indexing started in background")
return doc_id
return {
"file_id": doc_id,
"job_id": job_identifier,
}

except Exception as e:
logger.error(e)
raise

def _start_background_indexing(self, doc_id: str):
def _start_background_indexing(
self,
doc_id: str,
callback_url: str | None = None,
job_id: str | None = None,
callback_secret: str | None = None,
):
"""Start background indexing task safely"""
try:
# Try to get the current event loop
loop = asyncio.get_running_loop()
# If we're in an async context, create the task
loop.create_task(self._index_file_background(doc_id))
loop.create_task(
self._index_file_background(
doc_id,
callback_url,
job_id,
callback_secret,
)
)
except RuntimeError:
# No event loop running, start a new one in a thread
import threading
def run_async():
asyncio.run(self._index_file_background(doc_id))
asyncio.run(
self._index_file_background(
doc_id,
callback_url,
job_id,
callback_secret,
)
)
thread = threading.Thread(target=run_async, daemon=True)
thread.start()

async def _index_file_background(self, doc_id: str):
async def _index_file_background(
self,
doc_id: str,
callback_url: str | None = None,
job_id: str | None = None,
callback_secret: str | None = None,
):
"""Background task for indexing files with semaphore control"""
callback_payload: Dict[str, Any] | None = None
status: str = "processing"
async with self.indexing_semaphore:
try:
logger.info(f"Starting background indexing for file_id: {doc_id} (semaphore acquired)")
result = await self.file_index.index_file(doc_id)
if result.get("success"):
success = bool(result.get("success"))
error_message = result.get("error_message")
status = "succeeded" if success else "failed"
metadata = {
key: value
for key, value in (result or {}).items()
if key not in {"success", "error_message"}
} if isinstance(result, dict) else {}
if success:
logger.info(f"Background indexing completed successfully for file_id: {doc_id}")
else:
logger.error(f"Background indexing failed for file_id: {doc_id}, error: {result.get('error_message')}")
logger.error(f"Background indexing failed for file_id: {doc_id}, error: {error_message}")
callback_payload = {
"file_id": doc_id,
"rag_file_id": doc_id,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference between `file_id` and `rag_file_id`? Both are set to `doc_id` here.

"success": success,
"status": status,
"error": None if success else error_message,
"metadata": metadata or None,
}
except Exception as e:
logger.error(f"Background indexing failed for file_id: {doc_id}, exception: {str(e)}")
status = "failed"
callback_payload = {
"file_id": doc_id,
"rag_file_id": doc_id,
"success": False,
"status": status,
"error": str(e),
"metadata": None,
}
finally:
logger.debug(f"Background indexing semaphore released for file_id: {doc_id}")
if callback_url and callback_payload is not None:
payload_with_job = dict(callback_payload)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the purpose of this line? Is it to create a copy of the payload? If so, why not call `copy()` directly?

payload_with_job["job_id"] = job_id or str(uuid.uuid4())
await self._send_callback(
callback_url,
payload_with_job,
callback_secret=callback_secret,
)

async def _send_callback(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. If the user is behind a NAT, how can we make sure this callback_url is reachable from the server?
  2. This function feels a bit too invasive for this layer. My idea is to shield the application layer from access-layer concerns such as HTTP. Instead of defining the function here, would it be better to move it to the API layer?

self,
callback_url: str,
payload: Dict[str, Any],
callback_secret: str | None = None,
):
"""Send callback notification with indexing results."""
try:
headers = {}
if callback_secret:
headers["X-Callback-Secret"] = callback_secret
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.post(callback_url, json=payload, headers=headers)
response.raise_for_status()
logger.info(
"Sent callback for job_id=%s file_id=%s to %s (status=%s)",
payload.get("job_id"),
payload.get("file_id"),
callback_url,
response.status_code,
)
except httpx.HTTPStatusError as e:
logger.error(
"Callback endpoint returned error for job_id=%s file_id=%s (status=%s, body=%s)",
payload.get("job_id"),
payload.get("file_id"),
e.response.status_code if e.response else "unknown",
e.response.text if e.response else "no-body",
)
except Exception as e:
logger.error(f"Failed to send callback to {callback_url}: {e}")

def get_file(self, doc_id: str, user_id: uuid.UUID) -> Response:
metadata = self.file_storage.get_file_metadata(doc_id)
Expand Down