Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions hindsight-api/hindsight_api/api/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -2285,16 +2285,55 @@ async def api_get_mental_model(
):
"""Get a mental model by ID."""
try:
# Pre-operation validation hook
validator = app.state.memory._operation_validator
if validator:
from hindsight_api.extensions.operation_validator import MentalModelGetContext

ctx = MentalModelGetContext(
bank_id=bank_id,
mental_model_id=mental_model_id,
request_context=request_context,
)
validation = await validator.validate_mental_model_get(ctx)
if not validation.allowed:
raise OperationValidationError(
validation.reason or "Operation not allowed",
status_code=validation.status_code,
)

mental_model = await app.state.memory.get_mental_model(
bank_id=bank_id,
mental_model_id=mental_model_id,
request_context=request_context,
)
if mental_model is None:
raise HTTPException(status_code=404, detail=f"Mental model '{mental_model_id}' not found")

# Post-operation hook
if validator:
from hindsight_api.extensions.operation_validator import MentalModelGetResult

content = mental_model.get("content", "")
output_tokens = len(content) // 4 if content else 0

result_ctx = MentalModelGetResult(
bank_id=bank_id,
mental_model_id=mental_model_id,
request_context=request_context,
output_tokens=output_tokens,
success=True,
)
try:
await validator.on_mental_model_get_complete(result_ctx)
except Exception as hook_err:
logger.warning(f"Post-mental-model-get hook error (non-fatal): {hook_err}")

return MentalModelResponse(**mental_model)
except (AuthenticationError, HTTPException):
raise
except OperationValidationError as e:
raise HTTPException(status_code=e.status_code, detail=e.reason)
except Exception as e:
import traceback

Expand Down
58 changes: 54 additions & 4 deletions hindsight-api/hindsight_api/engine/memory_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,13 @@ async def _handle_refresh_mental_model(self, task_dict: dict[str, Any]):

from hindsight_api.models import RequestContext

internal_context = RequestContext(internal=True)
# Restore tenant_id/api_key_id from task payload so extensions can
# attribute the mental_model_refresh operation to the correct org.
internal_context = RequestContext(
internal=True,
tenant_id=task_dict.get("_tenant_id"),
api_key_id=task_dict.get("_api_key_id"),
)

# Get the current mental model to get source_query
mental_model = await self.get_mental_model(bank_id, mental_model_id, request_context=internal_context)
Expand Down Expand Up @@ -641,6 +647,42 @@ async def _handle_refresh_mental_model(self, task_dict: dict[str, Any]):
request_context=internal_context,
)

# Call post-operation hook if validator is configured
if self._operation_validator:
from hindsight_api.extensions.operation_validator import MentalModelRefreshResult

# Count facts and mental models from based_on
facts_used = 0
mental_models_used = 0
if reflect_result.based_on:
for fact_type, facts in reflect_result.based_on.items():
if facts:
if fact_type == "mental_models":
mental_models_used += len(facts)
else:
facts_used += len(facts)

# Estimate tokens
query_tokens = len(source_query) // 4 if source_query else 0
output_tokens = len(generated_content) // 4 if generated_content else 0
context_tokens = 0 # refresh doesn't use additional context

result_ctx = MentalModelRefreshResult(
bank_id=bank_id,
mental_model_id=mental_model_id,
request_context=internal_context,
query_tokens=query_tokens,
output_tokens=output_tokens,
context_tokens=context_tokens,
facts_used=facts_used,
mental_models_used=mental_models_used,
success=True,
)
try:
await self._operation_validator.on_mental_model_refresh_complete(result_ctx)
except Exception as hook_err:
logger.warning(f"Post-mental-model-refresh hook error (non-fatal): {hook_err}")

logger.info(f"[REFRESH_MENTAL_MODEL_TASK] Completed for bank_id={bank_id}, mental_model_id={mental_model_id}")

async def execute_task(self, task_dict: dict[str, Any]):
Expand Down Expand Up @@ -5482,13 +5524,21 @@ async def submit_async_refresh_mental_model(
if not mental_model:
raise ValueError(f"Mental model {mental_model_id} not found in bank {bank_id}")

# Pass tenant_id and api_key_id through task payload so the worker
# can provide request context to extension hooks.
task_payload: dict[str, Any] = {
"mental_model_id": mental_model_id,
}
if request_context.tenant_id:
task_payload["_tenant_id"] = request_context.tenant_id
if request_context.api_key_id:
task_payload["_api_key_id"] = request_context.api_key_id

return await self._submit_async_operation(
bank_id=bank_id,
operation_type="refresh_mental_model",
task_type="refresh_mental_model",
task_payload={
"mental_model_id": mental_model_id,
},
task_payload=task_payload,
result_metadata={"mental_model_id": mental_model_id, "name": mental_model["name"]},
dedupe_by_bank=False,
)
8 changes: 8 additions & 0 deletions hindsight-api/hindsight_api/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
# Consolidation operation
ConsolidateContext,
ConsolidateResult,
# Mental Model operations
MentalModelGetContext,
MentalModelGetResult,
MentalModelRefreshResult,
# Core operations
OperationValidationError,
OperationValidatorExtension,
Expand Down Expand Up @@ -65,6 +69,10 @@
# Operation Validator - Consolidation
"ConsolidateContext",
"ConsolidateResult",
# Operation Validator - Mental Model
"MentalModelGetContext",
"MentalModelGetResult",
"MentalModelRefreshResult",
# Tenant/Auth
"ApiKeyTenantExtension",
"AuthenticationError",
Expand Down
103 changes: 103 additions & 0 deletions hindsight-api/hindsight_api/extensions/operation_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,48 @@ class ConsolidateResult:
error: str | None = None


# =============================================================================
# Mental Model Contexts
# =============================================================================


@dataclass
class MentalModelGetContext:
"""Context for a mental model GET operation validation (pre-operation)."""

bank_id: str
mental_model_id: str
request_context: "RequestContext"


@dataclass
class MentalModelGetResult:
"""Result context for post-mental-model-GET hook."""

bank_id: str
mental_model_id: str
request_context: "RequestContext"
output_tokens: int # tokens in the returned content
success: bool = True
error: str | None = None


@dataclass
class MentalModelRefreshResult:
"""Result context for post-mental-model-refresh hook."""

bank_id: str
mental_model_id: str
request_context: "RequestContext"
query_tokens: int # tokens in source_query
output_tokens: int # tokens in generated content
context_tokens: int # tokens in context (if any)
facts_used: int # facts referenced in based_on
mental_models_used: int # mental models referenced in based_on
success: bool = True
error: str | None = None


class OperationValidatorExtension(Extension, ABC):
"""
Validates and hooks into retain/recall/reflect/consolidate operations.
Expand Down Expand Up @@ -402,3 +444,64 @@ async def on_consolidate_complete(self, result: ConsolidateResult) -> None:
- error: Error message (if failed)
"""
pass

# =========================================================================
# Mental Model - Pre-operation validation hook (optional - override to implement)
# =========================================================================

async def validate_mental_model_get(self, ctx: MentalModelGetContext) -> ValidationResult:
"""
Validate a mental model GET operation before execution.

Override to implement custom validation logic for mental model retrieval.

Args:
ctx: Context containing:
- bank_id: Bank identifier
- mental_model_id: Mental model identifier
- request_context: Request context with auth info

Returns:
ValidationResult indicating whether the operation is allowed.
"""
return ValidationResult.accept()

# =========================================================================
# Mental Model - Post-operation hooks (optional - override to implement)
# =========================================================================

async def on_mental_model_get_complete(self, result: MentalModelGetResult) -> None:
"""
Called after a mental model GET operation completes (success or failure).

Override to implement post-operation logic such as tracking or audit logging.

Args:
result: Result context containing:
- bank_id: Bank identifier
- mental_model_id: Mental model identifier
- output_tokens: Token count of the returned content
- success: Whether the operation succeeded
- error: Error message (if failed)
"""
pass

async def on_mental_model_refresh_complete(self, result: MentalModelRefreshResult) -> None:
"""
Called after a mental model refresh operation completes (success or failure).

Override to implement post-operation logic such as tracking or audit logging.

Args:
result: Result context containing:
- bank_id: Bank identifier
- mental_model_id: Mental model identifier
- query_tokens: Tokens in source_query
- output_tokens: Tokens in generated content
- context_tokens: Tokens in context
- facts_used: Number of facts referenced
- mental_models_used: Number of mental models referenced
- success: Whether the operation succeeded
- error: Error message (if failed)
"""
pass
Loading
Loading