Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import sqlalchemy as sa
from alembic import op
from pgvector.sqlalchemy import Vector
from sqlalchemy import text
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
Expand All @@ -23,8 +24,21 @@
def upgrade() -> None:
"""Upgrade schema - create all tables from scratch."""

# Enable required extensions
op.execute("CREATE EXTENSION IF NOT EXISTS vector")
# Note: pgvector extension is installed globally BEFORE migrations run
# See migrations.py:run_migrations() - this ensures the extension is available
# to all schemas, not just the one being migrated

# We keep this here as a fallback for backwards compatibility
# This may fail if user lacks permissions, which is fine if extension already exists
try:
op.execute("CREATE EXTENSION IF NOT EXISTS vector")
except Exception:
# Extension might already exist or user lacks permissions - verify it exists
conn = op.get_bind()
result = conn.execute(text("SELECT 1 FROM pg_extension WHERE extname = 'vector'")).fetchone()
if not result:
# Extension truly doesn't exist - re-raise the error
raise

# Create banks table
op.create_table(
Expand Down
2 changes: 1 addition & 1 deletion hindsight-api/hindsight_api/api/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -1410,7 +1410,7 @@ async def lifespan(app: FastAPI):
poll_interval_ms=config.worker_poll_interval_ms,
max_retries=config.worker_max_retries,
schema=schema,
tenant_extension=getattr(memory, "_tenant_extension", None),
tenant_extension=memory._tenant_extension,
max_slots=config.worker_max_slots,
consolidation_max_slots=config.worker_consolidation_max_slots,
)
Expand Down
66 changes: 35 additions & 31 deletions hindsight-api/hindsight_api/engine/memory_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,11 @@ def __init__(
# Store operation validator extension (optional)
self._operation_validator = operation_validator

# Store tenant extension (optional)
# Store tenant extension (always set, use default if none provided)
if tenant_extension is None:
from ..extensions.builtin.tenant import DefaultTenantExtension

tenant_extension = DefaultTenantExtension(config={})
self._tenant_extension = tenant_extension

async def _validate_operation(self, validation_coro) -> None:
Expand Down Expand Up @@ -497,22 +501,18 @@ async def _authenticate_tenant(self, request_context: "RequestContext | None") -
Raises:
AuthenticationError: If authentication fails or request_context is missing when required.
"""
if self._tenant_extension is None:
_current_schema.set("public")
return "public"

from hindsight_api.extensions import AuthenticationError

if request_context is None:
raise AuthenticationError("RequestContext is required when tenant extension is configured")
raise AuthenticationError("RequestContext is required")

# For internal/background operations (e.g., worker tasks), skip extension authentication.
# The task was already authenticated at submission time, and execute_task sets _current_schema
# from the task's _schema field. For public schema tasks, _current_schema keeps its default "public".
# from the task's _schema field.
if request_context.internal:
return _current_schema.get()

# Let AuthenticationError propagate - HTTP layer will convert to 401
# Authenticate through tenant extension (always set, may be default no-auth extension)
tenant_context = await self._tenant_extension.authenticate(request_context)

_current_schema.set(tenant_context.schema_name)
Expand Down Expand Up @@ -939,30 +939,34 @@ async def verify_llm():

if not self.db_url:
raise ValueError("Database URL is required for migrations")
logger.info("Running database migrations...")
# Use configured database schema for migrations (defaults to "public")
run_migrations(self.db_url, schema=get_config().database_schema)

# Migrate all existing tenant schemas (if multi-tenant)
if self._tenant_extension is not None:
try:
tenants = await self._tenant_extension.list_tenants()
if tenants:
logger.info(f"Running migrations on {len(tenants)} tenant schemas...")
for tenant in tenants:
schema = tenant.schema
if schema and schema != "public":
try:
run_migrations(self.db_url, schema=schema)
except Exception as e:
logger.warning(f"Failed to migrate tenant schema {schema}: {e}")
logger.info("Tenant schema migrations completed")
except Exception as e:
logger.warning(f"Failed to run tenant schema migrations: {e}")

# Ensure embedding column dimension matches the model's dimension
# This is done after migrations and after embeddings.initialize()
ensure_embedding_dimension(self.db_url, self.embeddings.dimension, schema=get_config().database_schema)
# Migrate all schemas from the tenant extension
# The tenant extension is the single source of truth for which schemas exist
logger.info("Running database migrations...")
try:
tenants = await self._tenant_extension.list_tenants()
if tenants:
logger.info(f"Running migrations on {len(tenants)} schema(s)...")
for tenant in tenants:
schema = tenant.schema
if schema:
try:
run_migrations(self.db_url, schema=schema)
except Exception as e:
logger.warning(f"Failed to migrate schema {schema}: {e}")
logger.info("Schema migrations completed")

# Ensure embedding column dimension matches the model's dimension
# This is done after migrations and after embeddings.initialize()
for tenant in tenants:
schema = tenant.schema
if schema:
try:
ensure_embedding_dimension(self.db_url, self.embeddings.dimension, schema=schema)
except Exception as e:
logger.warning(f"Failed to ensure embedding dimension for schema {schema}: {e}")
except Exception as e:
logger.warning(f"Failed to run schema migrations: {e}")

logger.info(f"Connecting to PostgreSQL at {self.db_url}")

Expand Down
36 changes: 36 additions & 0 deletions hindsight-api/hindsight_api/extensions/builtin/tenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,42 @@
from hindsight_api.models import RequestContext


class DefaultTenantExtension(TenantExtension):
    """
    Default single-tenant extension with no authentication.

    This is the extension used when no tenant extension is configured. Every
    request is accepted and mapped to a single schema that is resolved once,
    at construction time:

    1. the ``"schema"`` entry of ``config`` when present and non-empty;
    2. otherwise ``get_config().database_schema`` (configured through
       HINDSIGHT_API_DATABASE_SCHEMA, which defaults to 'public').

    Features:
    - No authentication required (passes all requests)
    - Uses configured schema from environment
    - Suited to single-tenant deployments without auth

    To use custom authentication, configure a different tenant extension:
        HINDSIGHT_API_TENANT_EXTENSION=hindsight_api.extensions.builtin.tenant:ApiKeyTenantExtension
    """

    def __init__(self, config: dict[str, str]):
        """
        Resolve and cache the schema served by this extension.

        Args:
            config: Extension configuration. An optional ``"schema"`` key
                overrides the globally configured database schema.
        """
        super().__init__(config)
        # Cache the schema so authenticate() and list_tenants() always agree.
        # `or` (instead of a .get() default) makes an empty or None "schema"
        # entry fall back to the configured schema rather than producing an
        # empty schema name, and avoids calling get_config() when an explicit
        # override is supplied.
        self._schema = config.get("schema") or get_config().database_schema

    async def authenticate(self, context: RequestContext) -> TenantContext:
        """Return the cached schema without inspecting ``context`` (no authentication)."""
        return TenantContext(schema_name=self._schema)

    async def list_tenants(self) -> list[Tenant]:
        """Return a one-element list holding the cached schema (single-tenant setup)."""
        return [Tenant(schema=self._schema)]


class ApiKeyTenantExtension(TenantExtension):
"""
Built-in tenant extension that validates API key against an environment variable.
Expand Down
75 changes: 75 additions & 0 deletions hindsight-api/hindsight_api/migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,81 @@ def run_migrations(
logger.debug("Migration advisory lock acquired")

try:
# Ensure pgvector extension is installed globally BEFORE schema migrations
# This is critical: the extension must exist database-wide before any schema
# migrations run, otherwise custom schemas won't have access to vector types
logger.debug("Checking pgvector extension availability...")

# First, check if extension already exists
ext_check = conn.execute(
text(
"SELECT extname, nspname FROM pg_extension e "
"JOIN pg_namespace n ON e.extnamespace = n.oid "
"WHERE extname = 'vector'"
)
).fetchone()

if ext_check:
# Extension exists - check if in correct schema
ext_schema = ext_check[1]
if ext_schema == "public":
logger.info("pgvector extension found in public schema - ready to use")
else:
# Extension in wrong schema - try to fix if we have permissions
logger.warning(
f"pgvector extension found in schema '{ext_schema}' instead of 'public'. "
f"Attempting to relocate..."
)
try:
conn.execute(text("DROP EXTENSION vector CASCADE"))
conn.execute(text("SET search_path TO public"))
conn.execute(text("CREATE EXTENSION vector"))
conn.commit()
logger.info("pgvector extension relocated to public schema")
except Exception as e:
# Failed to relocate - log but don't fail if extension exists somewhere
logger.warning(
f"Could not relocate pgvector extension to public schema: {e}. "
f"Continuing with extension in '{ext_schema}' schema."
)
conn.rollback()
else:
# Extension doesn't exist - try to install
logger.info("pgvector extension not found, attempting to install...")
try:
conn.execute(text("SET search_path TO public"))
conn.execute(text("CREATE EXTENSION vector"))
conn.commit()
logger.info("pgvector extension installed in public schema")
except Exception as e:
# Installation failed - this is only fatal if extension truly doesn't exist
# Check one more time in case another process installed it
conn.rollback()
ext_recheck = conn.execute(
text(
"SELECT nspname FROM pg_extension e "
"JOIN pg_namespace n ON e.extnamespace = n.oid "
"WHERE extname = 'vector'"
)
).fetchone()

if ext_recheck:
logger.warning(
f"Could not install pgvector extension (permission denied?), "
f"but extension exists in '{ext_recheck[0]}' schema. Continuing..."
)
else:
# Extension truly doesn't exist and we can't install it
logger.error(
f"pgvector extension is not installed and cannot be installed: {e}. "
f"Please ensure pgvector is installed by a database administrator. "
f"See: https://github.com/pgvector/pgvector#installation"
)
raise RuntimeError(
"pgvector extension is required but not installed. "
"Please install it with: CREATE EXTENSION vector;"
) from e

# Run migrations while holding the lock
_run_migrations_internal(database_url, script_location, schema=schema)
finally:
Expand Down
36 changes: 27 additions & 9 deletions hindsight-api/hindsight_api/worker/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,15 +222,30 @@ async def run():
# Create the HTTP app for metrics/health
app = create_worker_app(poller, memory)

# Setup signal handlers for graceful shutdown
# Setup signal handlers for graceful shutdown using asyncio
shutdown_requested = asyncio.Event()

def signal_handler(signum, frame):
print(f"\nReceived signal {signum}, initiating graceful shutdown...")
shutdown_requested.set()

signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
force_exit = False

loop = asyncio.get_event_loop()

def signal_handler():
nonlocal force_exit
if shutdown_requested.is_set():
# Second signal = force exit
print("\nReceived second signal, forcing immediate exit...")
force_exit = True
# Restore default handler so third signal kills process
loop.remove_signal_handler(signal.SIGINT)
loop.remove_signal_handler(signal.SIGTERM)
sys.exit(1)
else:
print("\nReceived shutdown signal, initiating graceful shutdown...")
print("(Press Ctrl+C again to force immediate exit)")
shutdown_requested.set()

# Use asyncio's signal handlers which work properly with the event loop
loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)

# Create uvicorn config and server
uvicorn_config = uvicorn.Config(
Expand All @@ -249,7 +264,10 @@ def signal_handler(signum, frame):
print(f"Worker started. Metrics available at http://{args.http_host}:{args.http_port}/metrics")

# Wait for shutdown signal
await shutdown_requested.wait()
try:
await shutdown_requested.wait()
except KeyboardInterrupt:
print("\nReceived interrupt, initiating graceful shutdown...")

# Graceful shutdown
print("Shutting down HTTP server...")
Expand Down
26 changes: 15 additions & 11 deletions hindsight-api/hindsight_api/worker/poller.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def __init__(
executor: Async function to execute tasks (typically MemoryEngine.execute_task)
poll_interval_ms: Interval between polls when no tasks found (milliseconds)
max_retries: Maximum retry attempts before marking task as failed
schema: Database schema for single-tenant support (ignored if tenant_extension is set)
tenant_extension: Extension for dynamic multi-tenant discovery. If set, list_tenants()
is called on each poll cycle to discover schemas dynamically.
schema: Database schema for single-tenant support (deprecated, use tenant_extension)
tenant_extension: Extension for dynamic multi-tenant discovery. If None, creates a
DefaultTenantExtension with the configured schema.
max_slots: Maximum concurrent tasks per worker
consolidation_max_slots: Maximum concurrent consolidation tasks per worker
"""
Expand All @@ -84,6 +84,13 @@ def __init__(
self._poll_interval_ms = poll_interval_ms
self._max_retries = max_retries
self._schema = schema
# Always set tenant extension (use DefaultTenantExtension if none provided)
if tenant_extension is None:
from ..extensions.builtin.tenant import DefaultTenantExtension

# Pass schema parameter to DefaultTenantExtension if explicitly provided
config = {"schema": schema} if schema else {}
tenant_extension = DefaultTenantExtension(config=config)
self._tenant_extension = tenant_extension
self._max_slots = max_slots
self._consolidation_max_slots = consolidation_max_slots
Expand All @@ -100,14 +107,11 @@ def __init__(

async def _get_schemas(self) -> list[str | None]:
"""Get list of schemas to poll. Returns [None] for default schema (no prefix)."""
if self._tenant_extension is not None:
from ..config import DEFAULT_DATABASE_SCHEMA

tenants = await self._tenant_extension.list_tenants()
# Convert default schema to None for SQL compatibility (no prefix), keep others as-is
return [t.schema if t.schema != DEFAULT_DATABASE_SCHEMA else None for t in tenants]
# Single schema mode
return [self._schema]
from ..config import DEFAULT_DATABASE_SCHEMA

tenants = await self._tenant_extension.list_tenants()
# Convert default schema to None for SQL compatibility (no prefix), keep others as-is
return [t.schema if t.schema != DEFAULT_DATABASE_SCHEMA else None for t in tenants]

async def _get_available_slots(self) -> tuple[int, int]:
"""
Expand Down
Loading