Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion hindsight-api/hindsight_api/api/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -1922,11 +1922,16 @@ async def api_graph(
bank_id: str,
type: str | None = None,
limit: int = 1000,
q: str | None = None,
tags: list[str] | None = Query(None),
tags_match: str = "all_strict",
request_context: RequestContext = Depends(get_request_context),
):
"""Get graph data from database, filtered by bank_id and optionally by type."""
try:
data = await app.state.memory.get_graph_data(bank_id, type, limit=limit, request_context=request_context)
data = await app.state.memory.get_graph_data(
bank_id, type, limit=limit, q=q, tags=tags, tags_match=tags_match, request_context=request_context
)
return data
except (AuthenticationError, HTTPException):
raise
Expand Down
23 changes: 22 additions & 1 deletion hindsight-api/hindsight_api/engine/memory_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3508,6 +3508,9 @@ async def get_graph_data(
fact_type: str | None = None,
*,
limit: int = 1000,
q: str | None = None,
tags: list[str] | None = None,
tags_match: str = "all_strict",
request_context: "RequestContext",
):
"""
Expand All @@ -3517,6 +3520,9 @@ async def get_graph_data(
bank_id: Filter by bank ID
fact_type: Filter by fact type (world, experience, opinion)
limit: Maximum number of items to return (default: 1000)
q: Full-text search query (searches text and context fields)
tags: Filter by tags
tags_match: Tag matching mode (default: all_strict)
request_context: Request context for authentication.

Returns:
Expand All @@ -3540,6 +3546,20 @@ async def get_graph_data(
query_conditions.append(f"fact_type = ${param_count}")
query_params.append(fact_type)

if q:
param_count += 1
query_conditions.append(f"(text ILIKE ${param_count} OR context ILIKE ${param_count})")
query_params.append(f"%{q}%")

if tags:
from .search.tags import build_tags_where_clause_simple

tag_clause = build_tags_where_clause_simple(tags, param_count + 1, match=tags_match)
if tag_clause:
query_conditions.append(tag_clause.removeprefix("AND "))
param_count += 1
query_params.append(tags)

where_clause = "WHERE " + " AND ".join(query_conditions) if query_conditions else ""

# Get total count first
Expand Down Expand Up @@ -3855,7 +3875,7 @@ async def list_memory_units(

units = await conn.fetch(
f"""
SELECT id, text, event_date, context, fact_type, mentioned_at, occurred_start, occurred_end, chunk_id, proof_count
SELECT id, text, event_date, context, fact_type, mentioned_at, occurred_start, occurred_end, chunk_id, proof_count, tags
FROM {fq_table("memory_units")}
{where_clause}
ORDER BY mentioned_at DESC NULLS LAST, created_at DESC
Expand Down Expand Up @@ -3908,6 +3928,7 @@ async def list_memory_units(
"entities": ", ".join(entities) if entities else "",
"chunk_id": row["chunk_id"] if row["chunk_id"] else None,
"proof_count": row["proof_count"] if row["proof_count"] is not None else 1,
"tags": list(row["tags"]) if row["tags"] else [],
}
)

Expand Down
171 changes: 171 additions & 0 deletions hindsight-api/tests/test_graph_filtering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""
Tests for server-side filtering in the graph API endpoint.

Verifies that q (text search) and tags filters work correctly
when passed as query parameters to GET /v1/default/banks/{bank_id}/graph.
"""
from datetime import datetime

import httpx
import pytest
import pytest_asyncio

from hindsight_api.api import create_app


@pytest_asyncio.fixture
async def api_client(memory):
"""Create an async test client for the FastAPI app."""
app = create_app(memory, initialize_memory=False)
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
yield client


@pytest.fixture
def test_bank_id():
"""Provide a unique bank ID for this test run."""
return f"graph_filter_test_{datetime.now().timestamp()}"


@pytest.mark.asyncio
async def test_graph_no_filter_returns_all(api_client, test_bank_id):
"""Without filters the graph endpoint returns all memories."""
response = await api_client.post(
f"/v1/default/banks/{test_bank_id}/memories",
json={
"items": [
{"content": "Alice loves hiking in the mountains.", "tags": ["user_alice"]},
{"content": "Bob enjoys swimming at the beach.", "tags": ["user_bob"]},
]
},
)
assert response.status_code == 200

response = await api_client.get(f"/v1/default/banks/{test_bank_id}/graph")
assert response.status_code == 200
data = response.json()
assert "table_rows" in data
texts = [row["text"] for row in data["table_rows"]]
assert any("Alice" in t for t in texts)
assert any("Bob" in t for t in texts)


@pytest.mark.asyncio
async def test_graph_q_filter_returns_matching(api_client, test_bank_id):
"""The q parameter filters memories by text content."""
response = await api_client.post(
f"/v1/default/banks/{test_bank_id}/memories",
json={
"items": [
{"content": "Alice loves hiking in the mountains."},
{"content": "Bob enjoys swimming at the beach."},
]
},
)
assert response.status_code == 200

response = await api_client.get(f"/v1/default/banks/{test_bank_id}/graph", params={"q": "Alice"})
assert response.status_code == 200
data = response.json()
texts = [row["text"] for row in data["table_rows"]]
assert all("Alice" in t or "alice" in t.lower() for t in texts), (
f"Expected only Alice memories, got: {texts}"
)
assert not any("Bob" in t for t in texts)


@pytest.mark.asyncio
async def test_graph_q_filter_case_insensitive(api_client, test_bank_id):
"""The q filter is case-insensitive."""
response = await api_client.post(
f"/v1/default/banks/{test_bank_id}/memories",
json={
"items": [
{"content": "Alice loves hiking in the mountains."},
{"content": "Bob enjoys swimming at the beach."},
]
},
)
assert response.status_code == 200

response = await api_client.get(f"/v1/default/banks/{test_bank_id}/graph", params={"q": "alice"})
assert response.status_code == 200
data = response.json()
texts = [row["text"] for row in data["table_rows"]]
assert any("Alice" in t for t in texts)
assert not any("Bob" in t for t in texts)


@pytest.mark.asyncio
async def test_graph_tags_filter_returns_matching(api_client, test_bank_id):
"""The tags parameter filters memories to only those with matching tags."""
response = await api_client.post(
f"/v1/default/banks/{test_bank_id}/memories",
json={
"items": [
{"content": "Alice loves hiking.", "tags": ["user_alice"]},
{"content": "Bob enjoys swimming.", "tags": ["user_bob"]},
]
},
)
assert response.status_code == 200

response = await api_client.get(
f"/v1/default/banks/{test_bank_id}/graph",
params={"tags": "user_alice", "tags_match": "all_strict"},
)
assert response.status_code == 200
data = response.json()
texts = [row["text"] for row in data["table_rows"]]
assert any("Alice" in t for t in texts)
assert not any("Bob" in t for t in texts)


@pytest.mark.asyncio
async def test_graph_q_and_tags_filter_combined(api_client, test_bank_id):
"""Combining q and tags filters applies both server-side."""
response = await api_client.post(
f"/v1/default/banks/{test_bank_id}/memories",
json={
"items": [
{"content": "Alice loves hiking.", "tags": ["user_alice"]},
{"content": "Alice also loves coding.", "tags": ["user_alice"]},
{"content": "Bob enjoys swimming.", "tags": ["user_bob"]},
]
},
)
assert response.status_code == 200

response = await api_client.get(
f"/v1/default/banks/{test_bank_id}/graph",
params={"q": "hiking", "tags": "user_alice", "tags_match": "all_strict"},
)
assert response.status_code == 200
data = response.json()
texts = [row["text"] for row in data["table_rows"]]
assert any("hiking" in t.lower() for t in texts)
assert not any("coding" in t.lower() for t in texts)
assert not any("Bob" in t for t in texts)


@pytest.mark.asyncio
async def test_graph_q_filter_empty_results(api_client, test_bank_id):
"""The q filter returns empty results when no memory matches."""
response = await api_client.post(
f"/v1/default/banks/{test_bank_id}/memories",
json={
"items": [
{"content": "Alice loves hiking."},
]
},
)
assert response.status_code == 200

response = await api_client.get(
f"/v1/default/banks/{test_bank_id}/graph",
params={"q": "zzznomatchzzz"},
)
assert response.status_code == 200
data = response.json()
assert data["table_rows"] == []
37 changes: 37 additions & 0 deletions hindsight-api/tests/test_tags_visibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,3 +890,40 @@ async def test_list_tags_ordered_by_count(api_client):
# common (3) should come before medium (2) which should come before rare (1)
assert tags.index("common") < tags.index("medium")
assert tags.index("medium") < tags.index("rare")


@pytest.mark.asyncio
async def test_list_memories_includes_tags(api_client, test_bank_id):
"""Test that list memories endpoint returns tags for each memory unit.

Regression test: tags were previously omitted from the SELECT query in
list_memory_units, causing the memory dialog in the UI to show no tags
even when memories had been stored with tags.
"""
tags = ["user_alice", "session_xyz", "project_alpha", "team_eng", "env_prod", "region_us"]

response = await api_client.post(
f"/v1/default/banks/{test_bank_id}/memories",
json={
"items": [
{
"content": "Alice is a senior engineer on the platform team.",
"tags": tags,
}
]
},
)
assert response.status_code == 200

# List memories and verify all tags are returned
response = await api_client.get(f"/v1/default/banks/{test_bank_id}/memories/list")
assert response.status_code == 200
result = response.json()

assert result["total"] > 0
memory_item = next((item for item in result["items"] if "Alice" in item["text"]), None)
assert memory_item is not None, "Should find the stored memory"
assert "tags" in memory_item, "Memory item must include a 'tags' field"
assert set(memory_item["tags"]) == set(tags), (
f"All {len(tags)} tags should be returned, got: {memory_item['tags']}"
)
2 changes: 1 addition & 1 deletion hindsight-cli/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ impl ApiClient {
_verbose: bool,
) -> Result<types::GraphDataResponse> {
self.runtime.block_on(async {
let response = self.client.get_graph(bank_id, limit, type_filter, None).await?;
let response = self.client.get_graph(bank_id, limit, type_filter, None, None, None, None).await?;
Ok(response.into_inner())
})
}
Expand Down
28 changes: 28 additions & 0 deletions hindsight-clients/go/api/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,34 @@ paths:
title: Limit
type: integer
style: form
- explode: true
in: query
name: q
required: false
schema:
nullable: true
type: string
style: form
- explode: true
in: query
name: tags
required: false
schema:
items:
nullable: true
type: string
nullable: true
type: array
style: form
- explode: true
in: query
name: tags_match
required: false
schema:
default: all_strict
title: Tags Match
type: string
style: form
- explode: false
in: header
name: authorization
Expand Down
39 changes: 39 additions & 0 deletions hindsight-clients/go/api_memory.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading