Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion hindsight-api/hindsight_api/engine/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,17 @@ async def retain_batch_async(
contents: list[dict[str, Any]],
*,
request_context: "RequestContext",
document_tags: list[str] | None = None,
) -> dict[str, Any]:
"""
Retain a batch of memory items.

Args:
bank_id: The memory bank ID.
contents: List of content dicts with 'content', optional 'event_date',
'context', 'metadata', 'document_id'.
'context', 'metadata', 'document_id', and per-item 'tags'.
request_context: Request context for authentication.
document_tags: Optional tags applied to all items in the batch.

Returns:
Dict with processing results.
Expand Down Expand Up @@ -561,6 +563,7 @@ async def submit_async_retain(
contents: list[dict[str, Any]],
*,
request_context: "RequestContext",
document_tags: list[str] | None = None,
) -> dict[str, Any]:
"""
Submit a batch retain operation to run asynchronously.
Expand All @@ -569,6 +572,7 @@ async def submit_async_retain(
bank_id: The memory bank ID.
contents: List of content dicts to retain.
request_context: Request context for authentication.
document_tags: Optional tags applied to all items in the async batch.

Returns:
Dict with operation_id and items_count.
Expand Down
8 changes: 7 additions & 1 deletion hindsight-api/hindsight_api/engine/memory_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ async def _handle_batch_retain(self, task_dict: dict[str, Any]):
if not bank_id:
raise ValueError("bank_id is required for batch retain task")
contents = task_dict.get("contents", [])
document_tags = task_dict.get("document_tags")

logger.info(
f"[BATCH_RETAIN_TASK] Starting background batch retain for bank_id={bank_id}, {len(contents)} items"
Expand All @@ -557,7 +558,12 @@ async def _handle_batch_retain(self, task_dict: dict[str, Any]):
tenant_id=task_dict.get("_tenant_id"),
api_key_id=task_dict.get("_api_key_id"),
)
await self.retain_batch_async(bank_id=bank_id, contents=contents, request_context=context)
await self.retain_batch_async(
bank_id=bank_id,
contents=contents,
document_tags=document_tags,
request_context=context,
)

logger.info(f"[BATCH_RETAIN_TASK] Completed background batch retain for bank_id={bank_id}")

Expand Down
70 changes: 70 additions & 0 deletions hindsight-api/tests/test_async_retain_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Unit tests for async retain tag propagation."""

from unittest.mock import AsyncMock

import pytest

from hindsight_api.engine.memory_engine import MemoryEngine
from hindsight_api.models import RequestContext


@pytest.mark.asyncio
async def test_submit_async_retain_includes_document_tags_in_task_payload():
    """Check that submit_async_retain puts document_tags into the queued task payload."""
    # Allocate the engine without running __init__, then stub the two
    # collaborators this test inspects.
    stub_engine = MemoryEngine.__new__(MemoryEngine)
    stub_engine._authenticate_tenant = AsyncMock()
    stub_engine._submit_async_operation = AsyncMock(return_value={"operation_id": "op-1"})

    ctx = RequestContext(tenant_id="tenant-a", api_key_id="key-a")
    items = [{"content": "Async retain payload test."}]
    tags = ["scope:tools", "user:alice"]

    outcome = await MemoryEngine.submit_async_retain(
        stub_engine,
        bank_id="bank-1",
        contents=items,
        document_tags=tags,
        request_context=ctx,
    )

    # The stubbed operation id is returned together with the item count.
    assert outcome == {"operation_id": "op-1", "items_count": 1}
    stub_engine._authenticate_tenant.assert_awaited_once_with(ctx)
    stub_engine._submit_async_operation.assert_awaited_once()

    # Inspect the keyword arguments of the single submission call.
    submitted = stub_engine._submit_async_operation.await_args.kwargs
    assert submitted["bank_id"] == "bank-1"
    assert submitted["operation_type"] == "retain"
    assert submitted["task_type"] == "batch_retain"

    payload = submitted["task_payload"]
    assert payload["contents"] == items
    assert payload["document_tags"] == tags
    assert payload["_tenant_id"] == "tenant-a"
    assert payload["_api_key_id"] == "key-a"


@pytest.mark.asyncio
async def test_handle_batch_retain_forwards_document_tags_to_retain_batch_async():
    """Check that the worker handler passes the payload's document_tags on to retain_batch_async."""
    # Allocate the engine without running __init__ and replace the
    # downstream call with a mock so its arguments can be inspected.
    stub_engine = MemoryEngine.__new__(MemoryEngine)
    stub_engine.retain_batch_async = AsyncMock(return_value={"items_count": 1})

    queued_task = {
        "bank_id": "bank-1",
        "contents": [{"content": "Forward tags test."}],
        "document_tags": ["scope:client"],
        "_tenant_id": "tenant-a",
        "_api_key_id": "key-a",
    }

    await MemoryEngine._handle_batch_retain(stub_engine, queued_task)

    stub_engine.retain_batch_async.assert_awaited_once()
    forwarded = stub_engine.retain_batch_async.await_args.kwargs
    assert forwarded["bank_id"] == "bank-1"
    assert forwarded["contents"] == queued_task["contents"]
    assert forwarded["document_tags"] == ["scope:client"]

    # The handler is expected to pass a context carrying the tenant and
    # API key identifiers from the task payload.
    rebuilt_ctx = forwarded["request_context"]
    assert rebuilt_ctx.internal is True
    assert rebuilt_ctx.user_initiated is True
    assert rebuilt_ctx.tenant_id == "tenant-a"
    assert rebuilt_ctx.api_key_id == "key-a"