Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions backend/.env-dev
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,39 @@ ALLOWED_HOSTS='["localhost", "127.0.0.1"]'

# AI Provider Settings
ENABLE_TEST_ROUTES=true # Enable test routes in development

## Temporary Settings

# Chat With Your Documents Plugins ENVS
# LLM provider selection
LLM_PROVIDER=ollama

# Embedding provider selection
EMBEDDING_PROVIDER=ollama

# Contextual Retrieval
ENABLE_CONTEXTUAL_RETRIEVAL=true
OLLAMA_CONTEXTUAL_LLM_BASE_URL=https://ollama-llama3-2-3b-979418853698.us-central1.run.app
OLLAMA_CONTEXTUAL_LLM_MODEL=llama3.2:3b

# Ollama LLM
OLLAMA_LLM_BASE_URL=https://ollama-qwen3-8b-979418853698.us-central1.run.app/
OLLAMA_LLM_MODEL=qwen3:8b

# Ollama Embedding
OLLAMA_EMBEDDING_BASE_URL=https://ollama-mxbai-embed-large-979418853698.us-central1.run.app
OLLAMA_EMBEDDING_MODEL=mxbai-embed-large

# Chroma DB

# BM25 Configuration
BM25_PERSIST_DIR=./data/bm25_index
BM25_INDEX_NAME=documents_bm25

# Database
# DATABASE_URL="postgresql://postgres:BpRzDvEHjCEEQiif@db.ukobjmisuhhcvpkqsstg.supabase.co:5432/postgres"

# Document Processor API Configuration
DOCUMENT_PROCESSOR_API_URL=https://braindrive-document-ai-979418853698.us-central1.run.app/documents/
DOCUMENT_PROCESSOR_TIMEOUT=300
DOCUMENT_PROCESSOR_MAX_RETRIES=3
598 changes: 589 additions & 9 deletions backend/app/api/v1/endpoints/auth.py

Large diffs are not rendered by default.

66 changes: 66 additions & 0 deletions backend/app/api/v1/endpoints/auth_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""
Authentication Fix Script
This script provides utilities to fix authentication token mismatches
"""

import asyncio
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))

from app.core.database import db_factory
from app.models.user import User
from sqlalchemy import select
import logging

logger = logging.getLogger(__name__)

async def clear_all_refresh_tokens():
"""Clear all refresh tokens from the database to force fresh login"""
async with db_factory.session_factory() as session:
# Get all users
result = await session.execute(select(User))
users = result.scalars().all()

print(f"Found {len(users)} users in database")

for user in users:
if user.refresh_token:
print(f"Clearing refresh token for user {user.email} (ID: {user.id})")
user.refresh_token = None
user.refresh_token_expires = None
await user.save(session)
else:
print(f"User {user.email} (ID: {user.id}) has no refresh token")

print("All refresh tokens cleared. Users will need to login again.")

async def show_user_tokens():
"""Show current refresh tokens for all users"""
async with db_factory.session_factory() as session:
result = await session.execute(select(User))
users = result.scalars().all()

print(f"Current user tokens:")
for user in users:
token_preview = user.refresh_token[:20] + "..." if user.refresh_token else "None"
print(f" User: {user.email} (ID: {user.id})")
print(f" Token: {token_preview}")
print(f" Expires: {user.refresh_token_expires}")
print()

if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description='Fix authentication token issues')
parser.add_argument('--clear', action='store_true', help='Clear all refresh tokens')
parser.add_argument('--show', action='store_true', help='Show current tokens')

args = parser.parse_args()

if args.clear:
asyncio.run(clear_all_refresh_tokens())
elif args.show:
asyncio.run(show_user_tokens())
else:
print("Use --clear to clear all tokens or --show to display current tokens")
144 changes: 75 additions & 69 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
@@ -1,85 +1,91 @@
# app/core/config.py
import json
import os
from typing import List, Optional
from pydantic_settings import BaseSettings
from pydantic import field_validator

class Settings(BaseSettings):
# Application settings
# Application
APP_NAME: str = "BrainDrive"
APP_ENV: str = os.getenv("APP_ENV", "dev") # 🔥 Set from .env (default to 'dev')
APP_ENV: str = "dev"
API_V1_PREFIX: str = "/api/v1"
DEBUG: bool = os.getenv("DEBUG", "true").lower() == "true"

# Server settings
HOST: str = os.getenv("HOST", "0.0.0.0")
PORT: int = int(os.getenv("PORT", 8005))
RELOAD: bool = os.getenv("RELOAD", "true").lower() == "true"
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "info")
PROXY_HEADERS: bool = os.getenv("PROXY_HEADERS", "true").lower() == "true"
FORWARDED_ALLOW_IPS: str = os.getenv("FORWARDED_ALLOW_IPS", "*")
SSL_KEYFILE: Optional[str] = os.getenv("SSL_KEYFILE", None)
SSL_CERTFILE: Optional[str] = os.getenv("SSL_CERTFILE", None)

# Security settings
SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-here")
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", 30))
REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", 30))
ALGORITHM: str = os.getenv("ALGORITHM", "HS256")

# Database settings
DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite:///braindrive.db")
DATABASE_TYPE: str = os.getenv("DATABASE_TYPE", "sqlite")
USE_JSON_STORAGE: bool = os.getenv("USE_JSON_STORAGE", "false").lower() == "true"
JSON_DB_PATH: str = os.getenv("JSON_DB_PATH", "./storage/database.json")
SQL_LOG_LEVEL: str = os.getenv("SQL_LOG_LEVEL", "WARNING")

# Redis settings
USE_REDIS: bool = os.getenv("USE_REDIS", "false").lower() == "true"
REDIS_HOST: str = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT: int = int(os.getenv("REDIS_PORT", 6379))

# CORS settings (Convert JSON strings to lists)
CORS_ORIGINS: str = os.getenv("CORS_ORIGINS", '["http://localhost:3000", "http://10.0.2.149:3000", "https://braindrive.ijustwantthebox.com"]')
CORS_METHODS: str = os.getenv("CORS_METHODS", '["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD"]')
CORS_HEADERS: str = os.getenv("CORS_HEADERS", '["Authorization", "Content-Type", "Accept", "Origin", "X-Requested-With"]')
CORS_EXPOSE_HEADERS: Optional[str] = os.getenv("CORS_EXPOSE_HEADERS", None)
CORS_MAX_AGE: int = int(os.getenv("CORS_MAX_AGE", 3600))
CORS_ALLOW_CREDENTIALS: bool = os.getenv("CORS_ALLOW_CREDENTIALS", "true").lower() == "true"

# Allowed Hosts
ALLOWED_HOSTS: str = os.getenv("ALLOWED_HOSTS", '["localhost", "127.0.0.1", "10.0.2.149", "braindrive.ijustwantthebox.com"]')

@property
def cors_origins_list(self) -> List[str]:
return json.loads(self.CORS_ORIGINS)

@property
def cors_methods_list(self) -> List[str]:
return json.loads(self.CORS_METHODS)

@property
def cors_headers_list(self) -> List[str]:
return json.loads(self.CORS_HEADERS)

@property
def cors_expose_headers_list(self) -> Optional[List[str]]:
return json.loads(self.CORS_EXPOSE_HEADERS) if self.CORS_EXPOSE_HEADERS else None

@property
def allowed_hosts_list(self) -> List[str]:
return json.loads(self.ALLOWED_HOSTS)

@property
def is_production(self) -> bool:
return self.APP_ENV == "prod" # 🔥 This makes it easy to check if we're in production
DEBUG: bool = True

# Server
HOST: str = "0.0.0.0"
PORT: int = 8005
RELOAD: bool = True
LOG_LEVEL: str = "info"
PROXY_HEADERS: bool = True
FORWARDED_ALLOW_IPS: str = "*"
SSL_KEYFILE: Optional[str] = None
SSL_CERTFILE: Optional[str] = None

# Security
SECRET_KEY: str = "your-secret-key-here"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
REFRESH_TOKEN_EXPIRE_DAYS: int = 30
ALGORITHM: str = "HS256"

# Database
DATABASE_URL: str = "sqlite:///braindrive.db"
DATABASE_TYPE: str = "sqlite"
USE_JSON_STORAGE: bool = False
JSON_DB_PATH: str = "./storage/database.json"
SQL_LOG_LEVEL: str = "WARNING"

# Redis
USE_REDIS: bool = False
REDIS_HOST: str = "localhost"
REDIS_PORT: int = 6379

# CORS Configuration - Revised for cross-platform compatibility
# Production origins (explicit list for security)
CORS_ORIGINS: List[str] = [] # Explicit origins for production only
CORS_ALLOW_CREDENTIALS: bool = True
CORS_MAX_AGE: int = 600
CORS_EXPOSE_HEADERS: List[str] = [] # e.g., ["X-Request-Id", "X-Total-Count"]

# Development CORS hosts (for regex generation)
CORS_DEV_HOSTS: List[str] = ["localhost", "127.0.0.1", "[::1]", "10.0.2.149"] # IPv6 support + network IP

# Allowed hosts
ALLOWED_HOSTS: List[str] = ["localhost", "127.0.0.1"]

@field_validator("CORS_ORIGINS", "CORS_EXPOSE_HEADERS", "CORS_DEV_HOSTS", mode="before")
@classmethod
def parse_cors_list(cls, v):
"""Parse CORS-related list fields from string or list"""
if v is None or v == "":
return []
if isinstance(v, str):
s = v.strip()
if not s: # Empty string
return []
if s.startswith("["): # JSON array
try:
return json.loads(s)
except json.JSONDecodeError:
return []
return [p.strip() for p in s.split(",") if p.strip()] # comma-separated
return v or []

@field_validator("ALLOWED_HOSTS", mode="before")
@classmethod
def parse_hosts(cls, v):
if isinstance(v, str):
s = v.strip()
if s.startswith("["):
return json.loads(s)
return [p.strip() for p in s.split(",") if p.strip()]
return v

model_config = {
"env_file": ".env",
"env_file_encoding": "utf-8",
"case_sensitive": True,
"extra": "ignore" # Allow extra fields in the environment that aren't defined in the Settings class
"extra": "ignore",
}

settings = Settings()

__all__ = ["settings"]
124 changes: 124 additions & 0 deletions backend/app/core/cors_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import re
from typing import List
from urllib.parse import urlparse
import structlog

logger = structlog.get_logger("cors")

def build_dev_origin_regex(hosts: List[str] = None) -> str:
"""
Build regex pattern for development origins.
Supports IPv4, IPv6, and localhost with any port.

Args:
hosts: List of allowed hosts. Defaults to ["localhost", "127.0.0.1", "[::1]"]

Returns:
Regex pattern string for use with CORSMiddleware allow_origin_regex
"""
if not hosts:
hosts = ["localhost", "127.0.0.1", "[::1]"]

# Escape special regex characters and handle IPv6
escaped_hosts = []
for host in hosts:
if host.startswith("[") and host.endswith("]"):
# IPv6 - already bracketed, escape the brackets
escaped_hosts.append(re.escape(host))
else:
# IPv4 or hostname - escape dots and other special chars
escaped_hosts.append(re.escape(host))

# Create regex pattern: ^https?://(host1|host2|host3)(:\d+)?$
host_pattern = "|".join(escaped_hosts)
regex = rf"^https?://({host_pattern})(:\d+)?$"

logger.info("Development CORS regex created",
pattern=regex,
hosts=hosts)

return regex

def validate_production_origins(origins: List[str]) -> List[str]:
"""
Validate production origins are properly formatted URLs.

Args:
origins: List of origin URLs to validate

Returns:
List of validated origins
"""
validated = []
for origin in origins:
try:
parsed = urlparse(origin)
if not parsed.scheme or not parsed.netloc:
logger.warning("Invalid origin format - missing scheme or netloc",
origin=origin)
continue
if parsed.scheme not in ("http", "https"):
logger.warning("Invalid origin scheme - must be http or https",
origin=origin,
scheme=parsed.scheme)
continue
# Additional production checks
if parsed.scheme == "http" and not origin.startswith("http://localhost"):
logger.warning("HTTP origins not recommended for production",
origin=origin)
validated.append(origin)
except Exception as e:
logger.error("Error parsing origin",
origin=origin,
error=str(e))
continue

logger.info("Production origins validated",
total=len(origins),
valid=len(validated),
origins=validated)

return validated

def log_cors_config(app_env: str, **kwargs):
"""
Log CORS configuration for debugging purposes.

Args:
app_env: Application environment (dev, staging, prod)
**kwargs: Additional configuration parameters to log
"""
logger.info("CORS configuration applied",
environment=app_env,
**kwargs)

def get_cors_debug_info(request_origin: str = None, app_env: str = None) -> dict:
"""
Get debugging information for CORS issues.

Args:
request_origin: The origin from the request
app_env: Application environment

Returns:
Dictionary with debug information
"""
debug_info = {
"environment": app_env,
"request_origin": request_origin,
"timestamp": structlog.get_logger().info.__globals__.get("time", "unknown")
}

if request_origin:
try:
parsed = urlparse(request_origin)
debug_info.update({
"origin_scheme": parsed.scheme,
"origin_hostname": parsed.hostname,
"origin_port": parsed.port,
"origin_netloc": parsed.netloc
})
except Exception as e:
debug_info["origin_parse_error"] = str(e)

return debug_info
Loading