Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions hindsight-api/hindsight_api/api/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def FieldWithDefault(default_factory: Callable, **kwargs) -> Any:
from hindsight_api.engine.db_utils import acquire_with_retry
from hindsight_api.engine.memory_engine import Budget, _get_tiktoken_encoding, fq_table
from hindsight_api.engine.reflect.observations import Observation
from hindsight_api.engine.response_models import VALID_RECALL_FACT_TYPES, TokenUsage
from hindsight_api.engine.response_models import VALID_RECALL_FACT_TYPES, MemoryFact, TokenUsage
from hindsight_api.engine.search.tags import TagsMatch
from hindsight_api.extensions import HttpExtension, OperationValidationError, load_extension
from hindsight_api.metrics import create_metrics_collector, get_metrics_collector, initialize_metrics
Expand All @@ -97,6 +97,12 @@ class ChunkIncludeOptions(BaseModel):
max_tokens: int = Field(default=8192, description="Maximum tokens for chunks (chunks may be truncated)")


class SourceFactsIncludeOptions(BaseModel):
    """Controls retrieval of the underlying source facts behind observation results.

    Passing an (optionally empty) instance of this model in ``IncludeOptions``
    enables source-fact expansion; ``max_tokens`` caps how much source-fact
    text is returned.
    """

    # Token budget for the returned source facts; defaults to 4096.
    max_tokens: int = Field(default=4096, description="Maximum tokens for source facts")

class IncludeOptions(BaseModel):
"""Options for including additional data in recall results."""

Expand All @@ -107,6 +113,10 @@ class IncludeOptions(BaseModel):
chunks: ChunkIncludeOptions | None = Field(
default=None, description="Include raw chunks. Set to {} to enable, null to disable (default: disabled)."
)
source_facts: SourceFactsIncludeOptions | None = Field(
default=None,
description="Include source facts for observation-type results. Set to {} to enable, null to disable (default: disabled).",
)


class RecallRequest(BaseModel):
Expand Down Expand Up @@ -189,6 +199,9 @@ class RecallResult(BaseModel):
metadata: dict[str, str] | None = None # User-defined metadata
chunk_id: str | None = None # Chunk this fact was extracted from
tags: list[str] | None = None # Visibility scope tags
source_fact_ids: list[str] | None = (
None # IDs of source facts (observation type only, when source_facts is enabled)
)


class EntityObservationResponse(BaseModel):
Expand Down Expand Up @@ -340,6 +353,9 @@ class RecallResponse(BaseModel):
default=None, description="Entity states for entities mentioned in results"
)
chunks: dict[str, ChunkData] | None = Field(default=None, description="Chunks for facts, keyed by chunk_id")
source_facts: dict[str, RecallResult] | None = Field(
default=None, description="Source facts for observation-type results, keyed by fact ID"
)


class EntityInput(BaseModel):
Expand Down Expand Up @@ -1959,6 +1975,10 @@ async def api_recall(
include_chunks = request.include.chunks is not None
max_chunk_tokens = request.include.chunks.max_tokens if include_chunks else 8192

# Determine source facts inclusion settings
include_source_facts = request.include.source_facts is not None
max_source_facts_tokens = request.include.source_facts.max_tokens if include_source_facts else 4096

pre_recall = time.time() - handler_start
# Run recall with tracing (record metrics)
with metrics.record_operation(
Expand All @@ -1977,14 +1997,16 @@ async def api_recall(
max_entity_tokens=max_entity_tokens,
include_chunks=include_chunks,
max_chunk_tokens=max_chunk_tokens,
include_source_facts=include_source_facts,
max_source_facts_tokens=max_source_facts_tokens,
request_context=request_context,
tags=request.tags,
tags_match=request.tags_match,
)

# Convert core MemoryFact objects to API RecallResult objects (excluding internal metrics)
recall_results = [
RecallResult(
def _fact_to_result(fact: "MemoryFact") -> RecallResult:
return RecallResult(
id=fact.id,
text=fact.text,
type=fact.fact_type,
Expand All @@ -1996,9 +2018,10 @@ async def api_recall(
document_id=fact.document_id,
chunk_id=fact.chunk_id,
tags=fact.tags,
source_fact_ids=fact.source_fact_ids,
)
for fact in core_result.results
]

recall_results = [_fact_to_result(fact) for fact in core_result.results]

# Convert chunks from engine to HTTP API format
chunks_response = None
Expand Down Expand Up @@ -2026,11 +2049,19 @@ async def api_recall(
],
)

# Convert source facts dict to API format
source_facts_response = None
if core_result.source_facts:
source_facts_response = {
fact_id: _fact_to_result(fact) for fact_id, fact in core_result.source_facts.items()
}

response = RecallResponse(
results=recall_results,
trace=core_result.trace,
entities=entities_response,
chunks=chunks_response,
source_facts=source_facts_response,
)

handler_duration = time.time() - handler_start
Expand Down
83 changes: 82 additions & 1 deletion hindsight-api/hindsight_api/engine/memory_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2020,6 +2020,8 @@ async def recall_async(
max_entity_tokens: int = 500,
include_chunks: bool = False,
max_chunk_tokens: int = 8192,
include_source_facts: bool = False,
max_source_facts_tokens: int = 4096,
request_context: "RequestContext",
tags: list[str] | None = None,
tags_match: TagsMatch = "any",
Expand Down Expand Up @@ -2159,6 +2161,8 @@ async def recall_async(
tags_match=tags_match,
connection_budget=_connection_budget,
quiet=_quiet,
include_source_facts=include_source_facts,
max_source_facts_tokens=max_source_facts_tokens,
)
break # Success - exit retry loop
except Exception as e:
Expand Down Expand Up @@ -2283,6 +2287,8 @@ async def _search_with_retries(
tags_match: TagsMatch = "any",
connection_budget: int | None = None,
quiet: bool = False,
include_source_facts: bool = False,
max_source_facts_tokens: int = 4096,
) -> RecallResultModel:
"""
Search implementation with modular retrieval and reranking.
Expand Down Expand Up @@ -2879,6 +2885,74 @@ def to_tuple_format(results):
)
top_results_dicts.append(result_dict)

# Fetch source facts for observation-type results (mirrors chunks pattern)
source_fact_ids_by_obs: dict[str, list[str]] = {} # obs_id -> [source_id, ...]
source_facts_dict: dict[str, MemoryFact] | None = None
if include_source_facts:
observation_ids = [uuid.UUID(sr.id) for sr in top_scored if sr.retrieval.fact_type == "observation"]
if observation_ids:
async with acquire_with_retry(pool) as sf_conn:
# Fetch source_memory_ids for all observation results
obs_rows = await sf_conn.fetch(
f"""
SELECT id, source_memory_ids
FROM {fq_table("memory_units")}
WHERE id = ANY($1::uuid[]) AND fact_type = 'observation'
""",
observation_ids,
)

# Collect unique source IDs in order of first appearance
seen_source_ids: set[str] = set()
source_ids_ordered: list[str] = []
for obs_row in obs_rows:
obs_id = str(obs_row["id"])
sids = [str(s) for s in (obs_row["source_memory_ids"] or [])]
source_fact_ids_by_obs[obs_id] = sids
for sid in sids:
if sid not in seen_source_ids:
source_ids_ordered.append(sid)
seen_source_ids.add(sid)

# Fetch source fact content up to token budget
if source_ids_ordered:
import uuid as uuid_module

source_rows = await sf_conn.fetch(
f"""
SELECT id, text, fact_type, context, occurred_start, occurred_end,
mentioned_at, document_id, chunk_id, tags
FROM {fq_table("memory_units")}
WHERE id = ANY($1::uuid[])
""",
[uuid_module.UUID(sid) for sid in source_ids_ordered],
)
source_row_by_id = {str(r["id"]): r for r in source_rows}

encoding = _get_tiktoken_encoding()
source_facts_dict = {}
total_source_tokens = 0
for sid in source_ids_ordered:
if sid not in source_row_by_id:
continue
r = source_row_by_id[sid]
fact_tokens = len(encoding.encode(r["text"]))
if total_source_tokens + fact_tokens > max_source_facts_tokens:
break
source_facts_dict[sid] = MemoryFact(
id=sid,
text=r["text"],
fact_type=r["fact_type"],
context=r["context"],
occurred_start=r["occurred_start"].isoformat() if r["occurred_start"] else None,
occurred_end=r["occurred_end"].isoformat() if r["occurred_end"] else None,
mentioned_at=r["mentioned_at"].isoformat() if r["mentioned_at"] else None,
document_id=r["document_id"],
chunk_id=str(r["chunk_id"]) if r["chunk_id"] else None,
tags=r["tags"] or None,
)
total_source_tokens += fact_tokens

# Get entities for each fact if include_entities is requested
fact_entity_map = {} # unit_id -> list of (entity_id, entity_name)
if include_entities and top_scored:
Expand Down Expand Up @@ -2924,6 +2998,7 @@ def to_tuple_format(results):
document_id=result_dict.get("document_id"),
chunk_id=result_dict.get("chunk_id"),
tags=result_dict.get("tags"),
source_fact_ids=source_fact_ids_by_obs.get(result_id) if include_source_facts else None,
)
)

Expand Down Expand Up @@ -2977,7 +3052,13 @@ def to_tuple_format(results):
if not quiet:
logger.info("\n" + "\n".join(log_buffer))

return RecallResultModel(results=memory_facts, trace=trace_dict, entities=entities_dict, chunks=chunks_dict)
return RecallResultModel(
results=memory_facts,
trace=trace_dict,
entities=entities_dict,
chunks=chunks_dict,
source_facts=source_facts_dict,
)

except Exception as e:
log_buffer.append(f"[RECALL {recall_id}] ERROR after {time.time() - recall_start:.3f}s: {str(e)}")
Expand Down
7 changes: 7 additions & 0 deletions hindsight-api/hindsight_api/engine/response_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ class MemoryFact(BaseModel):
None, description="ID of the chunk this fact was extracted from (format: bank_id_document_id_chunk_index)"
)
tags: list[str] | None = Field(None, description="Visibility scope tags associated with this fact")
source_fact_ids: list[str] | None = Field(
None,
description="IDs of source facts this observation was derived from (observation type only, when source_facts is enabled)",
)


class ChunkInfo(BaseModel):
Expand Down Expand Up @@ -226,6 +230,9 @@ class RecallResult(BaseModel):
chunks: dict[str, ChunkInfo] | None = Field(
None, description="Chunks for facts, keyed by '{document_id}_{chunk_index}'"
)
source_facts: dict[str, MemoryFact] | None = Field(
None, description="Source facts for observation-type results, keyed by fact ID"
)


class ReflectResult(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions hindsight-cli/src/commands/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ pub fn recall(
max_tokens: chunk_max_tokens,
}),
entities: None,
source_facts: None,
})
} else {
None
Expand Down
20 changes: 20 additions & 0 deletions hindsight-clients/go/api/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3226,6 +3226,8 @@ components:
$ref: '#/components/schemas/EntityIncludeOptions'
chunks:
$ref: '#/components/schemas/ChunkIncludeOptions'
source_facts:
$ref: '#/components/schemas/SourceFactsIncludeOptions'
title: IncludeOptions
ListDocumentsResponse:
description: Response model for list documents endpoint.
Expand Down Expand Up @@ -3724,6 +3726,10 @@ components:
additionalProperties:
$ref: '#/components/schemas/ChunkData'
nullable: true
source_facts:
additionalProperties:
$ref: '#/components/schemas/RecallResult'
nullable: true
required:
- results
title: RecallResponse
Expand Down Expand Up @@ -3789,6 +3795,11 @@ components:
type: string
nullable: true
type: array
source_fact_ids:
items:
type: string
nullable: true
type: array
required:
- id
- text
Expand Down Expand Up @@ -4148,6 +4159,15 @@ components:
- items_count
- success
title: RetainResponse
SourceFactsIncludeOptions:
description: Options for including source facts for observation-type results.
properties:
max_tokens:
default: 4096
description: Maximum tokens for source facts
title: Max Tokens
type: integer
title: SourceFactsIncludeOptions
TagItem:
description: Single tag with usage count.
properties:
Expand Down
46 changes: 46 additions & 0 deletions hindsight-clients/go/model_include_options.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading