Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions hindsight-api/hindsight_api/api/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,13 +1398,18 @@ async def lifespan(app: FastAPI):

# Start worker poller if enabled (standalone mode)
if config.worker_enabled and memory._pool is not None:
from ..config import DEFAULT_DATABASE_SCHEMA

worker_id = config.worker_id or socket.gethostname()
# Convert default schema to None for SQL compatibility (no schema prefix)
schema = None if config.database_schema == DEFAULT_DATABASE_SCHEMA else config.database_schema
poller = WorkerPoller(
pool=memory._pool,
worker_id=worker_id,
executor=memory.execute_task,
poll_interval_ms=config.worker_poll_interval_ms,
max_retries=config.worker_max_retries,
schema=schema,
tenant_extension=getattr(memory, "_tenant_extension", None),
max_slots=config.worker_max_slots,
consolidation_max_slots=config.worker_consolidation_max_slots,
Expand Down
7 changes: 6 additions & 1 deletion hindsight-api/hindsight_api/worker/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,15 +200,20 @@ async def run():
if tenant_extension:
print("Tenant extension loaded - schemas will be discovered dynamically on each poll")
else:
print("No tenant extension configured, using public schema only")
print(f"No tenant extension configured, using schema: {config.database_schema}")

# Create a single poller that handles all schemas dynamically
# Convert default schema to None for SQL compatibility (no schema prefix)
from hindsight_api.config import DEFAULT_DATABASE_SCHEMA

schema = None if config.database_schema == DEFAULT_DATABASE_SCHEMA else config.database_schema
poller = WorkerPoller(
pool=memory._pool,
worker_id=args.worker_id,
executor=memory.execute_task,
poll_interval_ms=args.poll_interval,
max_retries=args.max_retries,
schema=schema,
tenant_extension=tenant_extension,
max_slots=config.worker_max_slots,
consolidation_max_slots=config.worker_consolidation_max_slots,
Expand Down
22 changes: 15 additions & 7 deletions hindsight-api/hindsight_api/worker/poller.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,13 @@ def __init__(
self._in_flight_by_type: dict[str, int] = {}

async def _get_schemas(self) -> list[str | None]:
"""Get list of schemas to poll. Returns [None] for public schema."""
"""Get list of schemas to poll. Returns [None] for default schema (no prefix)."""
if self._tenant_extension is not None:
from ..config import DEFAULT_DATABASE_SCHEMA

tenants = await self._tenant_extension.list_tenants()
# Convert "public" to None for SQL compatibility, keep others as-is
return [t.schema if t.schema != "public" else None for t in tenants]
# Convert default schema to None for SQL compatibility (no prefix), keep others as-is
return [t.schema if t.schema != DEFAULT_DATABASE_SCHEMA else None for t in tenants]
# Single schema mode
return [self._schema]

Expand Down Expand Up @@ -194,7 +196,9 @@ async def _claim_batch_for_schema(
try:
return await self._claim_batch_for_schema_inner(schema, limit, consolidation_limit)
except Exception as e:
logger.warning(f"Worker {self._worker_id} failed to claim tasks for schema {schema or 'public'}: {e}")
# Format schema for logging: custom schemas in quotes, None as-is
schema_display = f'"{schema}"' if schema else str(schema)
logger.warning(f"Worker {self._worker_id} failed to claim tasks for schema {schema_display}: {e}")
return []

async def _claim_batch_for_schema_inner(
Expand Down Expand Up @@ -418,7 +422,9 @@ async def recover_own_tasks(self) -> int:
count = int(result.split()[-1]) if result else 0
total_count += count
except Exception as e:
logger.warning(f"Worker {self._worker_id} failed to recover tasks for schema {schema or 'public'}: {e}")
# Format schema for logging: custom schemas in quotes, None as-is
schema_display = f'"{schema}"' if schema else str(schema)
logger.warning(f"Worker {self._worker_id} failed to recover tasks for schema {schema_display}: {e}")

if total_count > 0:
logger.info(f"Worker {self._worker_id} recovered {total_count} stale tasks from previous run")
Expand Down Expand Up @@ -457,7 +463,8 @@ async def run(self):
consolidation_count += 1

types_str = ", ".join(f"{k}:{v}" for k, v in task_types.items())
schemas_str = ", ".join(s or "public" for s in schemas_seen)
# Display None as "default" in logs
schemas_str = ", ".join(s if s else "default" for s in schemas_seen)
logger.info(
f"Worker {self._worker_id} claimed {len(tasks)} tasks "
f"({consolidation_count} consolidation): {types_str} (schemas: {schemas_str})"
Expand Down Expand Up @@ -591,7 +598,8 @@ async def _log_progress_if_due(self):
other_workers.append(f"{wid}:{cnt}")
others_str = ", ".join(other_workers) if other_workers else "none"

schemas_str = ", ".join(s or "public" for s in schemas)
# Display None as "default" in logs
schemas_str = ", ".join(s if s else "default" for s in schemas)
logger.info(
f"[WORKER_STATS] worker={self._worker_id} "
f"slots={in_flight}/{self._max_slots} (consolidation={consolidation_count}/{self._consolidation_max_slots}) | "
Expand Down
75 changes: 75 additions & 0 deletions hindsight-api/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,81 @@ async def test_poller_without_tenant_extension_uses_public(self, pool, clean_ope
for task in claimed:
assert task.schema is None

@pytest.mark.asyncio
async def test_poller_with_custom_schema(self, pool):
    """Check that a poller given an explicit ``schema`` claims tasks from that schema.

    A throwaway schema is created with an ``async_operations`` table copied from
    the public one and seeded with three pending operations. The test then
    verifies that ``claim_batch`` returns exactly those operations tagged with
    the custom schema, and that their rows are marked as processing by this
    worker. The schema is dropped afterwards regardless of the outcome.
    """
    from hindsight_api.worker import WorkerPoller

    custom_schema = "test_custom_schema"
    ops_table = f'"{custom_schema}".async_operations'
    worker_name = "test-worker-custom-schema"

    try:
        # Build the schema with an operations table shaped like the public one.
        await pool.execute(f'CREATE SCHEMA IF NOT EXISTS "{custom_schema}"')
        await pool.execute(
            f"""
            CREATE TABLE {ops_table} (
                LIKE public.async_operations INCLUDING ALL
            )
            """
        )

        # Seed three pending operations that exist only in the custom schema.
        bank_id = f"test-worker-{uuid.uuid4().hex[:8]}"
        operation_ids = [uuid.uuid4() for _ in range(3)]
        for index, operation_id in enumerate(operation_ids):
            await pool.execute(
                f"""
                INSERT INTO {ops_table} (operation_id, bank_id, operation_type, status, task_payload)
                VALUES ($1, $2, 'test', 'pending', $3::jsonb)
                """,
                operation_id,
                bank_id,
                json.dumps({"type": "test_task", "index": index, "bank_id": bank_id}),
            )
        expected_ids = {str(operation_id) for operation_id in operation_ids}

        # Poller bound to the custom schema; the executor is a no-op because
        # only claiming is exercised here.
        poller = WorkerPoller(
            pool=pool,
            worker_id=worker_name,
            executor=lambda x: None,
            schema=custom_schema,
        )

        claimed = await poller.claim_batch()
        assert len(claimed) == 3, f"Expected 3 tasks, got {len(claimed)}"

        # Every claimed task must carry the custom schema.
        for task in claimed:
            assert task.schema == custom_schema, f"Expected schema '{custom_schema}', got '{task.schema}'"

        # The claimed operations are exactly the ones seeded above.
        assert {task.operation_id for task in claimed} == expected_ids

        # The rows in the custom schema must now be owned by this worker.
        rows = await pool.fetch(
            f"""
            SELECT operation_id, status, worker_id
            FROM {ops_table}
            WHERE operation_id = ANY($1)
            """,
            operation_ids,
        )
        assert len(rows) == 3
        for row in rows:
            assert row["status"] == "processing"
            assert row["worker_id"] == worker_name

    finally:
        # Clean up: drop the custom schema
        await pool.execute(f'DROP SCHEMA IF EXISTS "{custom_schema}" CASCADE')


async def test_worker_fire_and_forget_nonblocking(pool, clean_operations):
"""
Expand Down
Loading