Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 57 additions & 13 deletions backend/app/ai_providers/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,23 @@ def provider_name(self) -> str:
return "ollama"

async def initialize(self, config: Dict[str, Any]) -> bool:
    """Initialize the Ollama provider from a configuration dictionary.

    Args:
        config: Provider configuration. Recognized keys:
            - ``server_url``: Base URL of the Ollama server
              (defaults to ``http://localhost:11434``).
            - ``api_key``: Optional bearer token for the server.
            - ``server_name``: Human-readable server label
              (defaults to ``"Default Ollama Server"``).

    Returns:
        Always ``True`` (initialization itself cannot fail here).
    """
    # Local import mirrors the original block; the module top is not
    # visible here, so we avoid assuming a module-level logger exists.
    import logging
    logger = logging.getLogger(__name__)

    self.server_url = config.get("server_url", "http://localhost:11434")
    self.api_key = config.get("api_key", "")
    self.server_name = config.get("server_name", "Default Ollama Server")

    # Never log the raw config: it may contain the API key. Log a
    # redacted copy instead (fixes a credential-leak in the original).
    redacted = {k: ("***" if k == "api_key" and v else v) for k, v in config.items()}
    logger.info("🔧 Ollama provider initializing with config: %s", redacted)

    # Warn when we silently fall back to localhost so misconfigured
    # deployments are easy to spot in the logs.
    if "server_url" not in config:
        logger.warning(
            "⚠️ Ollama provider defaulting to localhost! Config was: %s", redacted
        )
    else:
        logger.info("✅ Ollama provider using server_url: %s", self.server_url)

    logger.info(
        "🎯 Ollama provider initialized - server_name: %s, server_url: %s",
        self.server_name,
        self.server_url,
    )
    return True

async def get_models(self) -> List[Dict[str, Any]]:
Expand All @@ -41,13 +55,12 @@ async def generate_stream(self, prompt: str, model: str, params: Dict[str, Any])
yield chunk

async def chat_completion(self, messages: List[Dict[str, Any]], model: str, params: Dict[str, Any]) -> Dict[str, Any]:
print(f"Ollama chat_completion received {len(messages)} messages")
print(f"DETAILED MESSAGE INSPECTION:")
for i, msg in enumerate(messages):
print(f" Message {i+1}: role={msg.get('role', 'unknown')}, content={msg.get('content', '')}")

print(f"🤖 OLLAMA CHAT_COMPLETION CALLED")
print(f"📊 Server URL: {self.server_url}")
print(f"📊 Server Name: {self.server_name}")
print(f"📊 Model: {model}")
print(f"📊 Messages: {len(messages)} messages")
prompt = self._format_chat_messages(messages)
print(f"Formatted prompt (full):\n{prompt}")

result = await self._call_ollama_api(prompt, model, params, is_streaming=False)
if "error" not in result:
Expand All @@ -61,13 +74,7 @@ async def chat_completion(self, messages: List[Dict[str, Any]], model: str, para
return result

async def chat_completion_stream(self, messages: List[Dict[str, Any]], model: str, params: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]:
print(f"Ollama chat_completion_stream received {len(messages)} messages")
print(f"DETAILED MESSAGE INSPECTION (STREAMING):")
for i, msg in enumerate(messages):
print(f" Message {i+1}: role={msg.get('role', 'unknown')}, content={msg.get('content', '')}")

prompt = self._format_chat_messages(messages)
print(f"Formatted prompt (full):\n{prompt}")

async for chunk in self._stream_ollama_api(prompt, model, params):
if "error" not in chunk:
Expand All @@ -81,16 +88,23 @@ async def chat_completion_stream(self, messages: List[Dict[str, Any]], model: st
yield chunk

async def _call_ollama_api(self, prompt: str, model: str, params: Dict[str, Any], is_streaming: bool = False) -> Dict[str, Any]:
import logging
logger = logging.getLogger(__name__)

payload_params = params.copy() if params else {}
payload_params["stream"] = False
payload = {"model": model, "prompt": prompt, **payload_params}
headers = {'Content-Type': 'application/json'}
if self.api_key:
headers['Authorization'] = f'Bearer {self.api_key}'

# Log the actual URL being called
api_url = f"{self.server_url}/api/generate"
logger.debug(f"Making Ollama API call to: {api_url}")

try:
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(f"{self.server_url}/api/generate", json=payload, headers=headers)
response = await client.post(api_url, json=payload, headers=headers)
response.raise_for_status()
result = response.json()
return {
Expand All @@ -100,6 +114,36 @@ async def _call_ollama_api(self, prompt: str, model: str, params: Dict[str, Any]
"metadata": result,
"finish_reason": result.get("done") and "stop" or None
}
except httpx.ConnectError as e:
logger.error(f"❌ Cannot connect to Ollama server at {api_url}")
return {
"error": f"Cannot connect to Ollama server at {self.server_url}. "
f"Please check if the server is running and accessible.",
"provider": "ollama",
"model": model,
"server_name": self.server_name,
"server_url": self.server_url
}
except httpx.HTTPStatusError as e:
if e.response.status_code == 404:
logger.error(f"❌ Model '{model}' not found on server {self.server_name}")
return {
"error": f"Model '{model}' not found on Ollama server '{self.server_name}'. "
f"Please check if the model is installed or use a different model.",
"provider": "ollama",
"model": model,
"server_name": self.server_name,
"server_url": self.server_url
}
else:
logger.error(f"❌ HTTP error {e.response.status_code} from server {self.server_name}")
return {
"error": f"HTTP {e.response.status_code} error from Ollama server '{self.server_name}': {e.response.text}",
"provider": "ollama",
"model": model,
"server_name": self.server_name,
"server_url": self.server_url
}
except Exception as e:
return self._format_error(e, model)

Expand Down
6 changes: 6 additions & 0 deletions backend/app/ai_providers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def register_provider(self, name: str, provider_class: Type[AIProvider]) -> None

async def get_provider(self, name: str, instance_id: str, config: Dict[str, Any]) -> AIProvider:
"""Get or create a provider instance."""
import logging
logger = logging.getLogger(__name__)
logger.debug(f"Provider registry: get_provider called - {name}, instance: {instance_id}")

if name not in self._providers:
raise ValueError(f"Provider '{name}' not registered")

Expand All @@ -43,9 +47,11 @@ async def get_provider(self, name: str, instance_id: str, config: Dict[str, Any]

# Check if instance exists
if instance_key in self._instances[name]:
logger.debug(f"Using existing provider instance: {instance_key}")
return self._instances[name][instance_key]

# Create new instance
logger.debug(f"Creating new provider instance: {instance_key}")
provider = self._providers[name]()
await provider.initialize(config)
self._instances[name][instance_key] = provider
Expand Down
141 changes: 94 additions & 47 deletions backend/app/api/v1/endpoints/ai_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from app.models.user import User
from app.ai_providers.registry import provider_registry
from app.ai_providers.ollama import OllamaProvider
from app.utils.json_parsing import safe_encrypted_json_parse, validate_ollama_settings_format, create_default_ollama_settings
from app.schemas.ai_providers import (
TextGenerationRequest,
ChatCompletionRequest,
Expand All @@ -41,6 +42,12 @@ async def get_provider_instance_from_request(request, db):
user_id = user_id.replace("-", "")

logger = logging.getLogger(__name__)
print(f"🚀 PROVIDER REQUEST RECEIVED")
print(f"📊 Provider: {request.provider}")
print(f"📊 Settings ID: {request.settings_id}")
print(f"📊 Server ID: {request.server_id}")
print(f"📊 Model: {getattr(request, 'model', 'N/A')}")
print(f"📊 User ID: {user_id}")
logger.info(f"Getting provider instance for: settings_id={request.settings_id}, user_id={user_id}")
logger.info(f"Original user_id from request: {request.user_id}")

Expand Down Expand Up @@ -94,43 +101,63 @@ async def get_provider_instance_from_request(request, db):

# Use the first setting found
setting = settings[0]
logger.info(f"Using setting with ID: {setting['id'] if isinstance(setting, dict) else setting.id}")
logger.debug(f"Using setting with ID: {setting['id'] if isinstance(setting, dict) else setting.id}")

# Extract configuration from settings value
# Parse the JSON string if value is a string
# Extract configuration from settings value using robust parsing
setting_value = setting['value'] if isinstance(setting, dict) else setting.value
setting_id = setting['id'] if isinstance(setting, dict) else setting.id

if isinstance(setting_value, str):
try:
# First parse
logger.info("Parsing settings value as JSON string")
value_dict = json.loads(setting_value)
logger.debug(f"Parsed JSON string into: {value_dict}")

# Check if the result is still a string (double-encoded JSON)
if isinstance(value_dict, str):
logger.info("Value is double-encoded JSON, parsing again")
value_dict = json.loads(value_dict)
logger.debug(f"Parsed double-encoded JSON into: {value_dict}")

# Now value_dict should be a dictionary
if not isinstance(value_dict, dict):
logger.error(f"Error: value_dict is not a dictionary: {type(value_dict)}")
# Use our robust JSON parsing utility that handles encryption issues
try:
value_dict = safe_encrypted_json_parse(
setting_value,
context=f"settings_id={request.settings_id}, user_id={user_id}",
setting_id=setting_id,
definition_id=request.settings_id
)

# Ensure we have a dictionary
if not isinstance(value_dict, dict):
logger.error(f"Parsed value is not a dictionary: {type(value_dict)}")
# For Ollama settings, provide a default structure
if 'ollama' in request.settings_id.lower():
logger.warning("Creating default Ollama settings structure")
value_dict = create_default_ollama_settings()
else:
raise HTTPException(
status_code=500,
detail=f"Error parsing settings value: expected dict, got {type(value_dict)}. "
f"Please check the format of your settings."
detail=f"Settings value must be a dictionary, got {type(value_dict)}. "
f"Setting ID: {setting_id}"
)
except json.JSONDecodeError as e:
logger.error(f"Error parsing JSON string: {e}")

logger.debug(f"Successfully parsed settings value for {request.settings_id}")

except ValueError as e:
logger.error(f"Failed to parse encrypted settings: {e}")
# For Ollama settings, provide helpful error message and fallback
if 'ollama' in request.settings_id.lower():
logger.info("Ollama settings parsing failed, using fallback configuration")
value_dict = create_default_ollama_settings()
else:
raise HTTPException(
status_code=500,
detail=f"Error parsing settings value: {str(e)}. "
f"Please ensure the settings value is valid JSON."
detail=str(e)
)
else:
logger.info("Settings value is already a dictionary")
value_dict = setting_value
except Exception as e:
logger.error(f"Unexpected error parsing settings: {e}")
raise HTTPException(
status_code=500,
detail=f"Unexpected error parsing settings: {str(e)}. Setting ID: {setting_id}"
)

# Add specific validation for Ollama settings
if 'ollama' in request.settings_id.lower():
logger.info("Validating Ollama settings format")
if not validate_ollama_settings_format(value_dict):
logger.warning("Ollama settings format validation failed, using default structure")
value_dict = create_default_ollama_settings()
else:
logger.info("Ollama settings format validation passed")

# Handle different provider configurations
if request.provider == "openai":
Expand Down Expand Up @@ -207,32 +234,45 @@ async def get_provider_instance_from_request(request, db):
logger.info(f"Created Groq config with API key")
else:
# Other providers (like Ollama) use servers array
logger.info("Processing server-based provider configuration")
logger.debug("Processing server-based provider configuration")
servers = value_dict.get("servers", [])
logger.info(f"Found {len(servers)} servers in settings")
logger.debug(f"Found {len(servers)} servers in settings")

logger.debug("Processing server-based provider configuration")
servers = value_dict.get("servers", [])
logger.debug(f"Found {len(servers)} servers in settings")

# Find the specific server by ID
logger.info(f"Looking for server with ID: {request.server_id}")
logger.debug(f"Looking for server with ID: '{request.server_id}'")
server = next((s for s in servers if s.get("id") == request.server_id), None)
if not server and servers:
# If the requested server ID is not found but there are servers available,
# use the first server as a fallback
logger.warning(f"Server with ID {request.server_id} not found, using first available server as fallback")
server = servers[0]
logger.info(f"Using fallback server: {server.get('serverName')} ({server.get('id')})")

if not server:
logger.error(f"No server found with ID: {request.server_id} and no fallback available")
raise HTTPException(
status_code=404,
detail=f"Server not found with ID: {request.server_id}. "
f"Please check your server configuration or use a different server ID."
)
# Provide detailed error message about available servers
if servers:
available_servers = [f"{s.get('serverName', 'Unknown')} (ID: {s.get('id', 'Unknown')})" for s in servers]
available_list = ", ".join(available_servers)
logger.error(f"❌ Server with ID '{request.server_id}' not found")
logger.error(f"📋 Available servers: {available_list}")
raise HTTPException(
status_code=404,
detail=f"Ollama server '{request.server_id}' not found. "
f"Available servers: {available_list}. "
f"Please select a valid server from your Ollama settings."
)
else:
logger.error(f"❌ No Ollama servers configured")
raise HTTPException(
status_code=404,
detail="No Ollama servers are configured. "
"Please add at least one Ollama server in your settings before using this provider."
)

logger.info(f"Found server: {server.get('serverName')}")
logger.debug(f"Found server: {server.get('serverName')} (ID: {server.get('id')})")

# Create provider configuration from server details
server_url = server.get("serverAddress")
logger.debug(f"Server URL from settings: '{server_url}'")

if not server_url:
logger.error(f"Server URL is missing for server: {server.get('id')}")
raise HTTPException(
Expand All @@ -246,11 +286,11 @@ async def get_provider_instance_from_request(request, db):
"api_key": server.get("apiKey", ""),
"server_name": server.get("serverName", "Unknown Server")
}

logger.info(f"Created config with server_url: {config.get('server_url', 'N/A')}")
logger.debug(f"Created server config: {config.get('server_name')} -> {config.get('server_url')}")

# Get provider instance
logger.info(f"Getting provider instance for: {request.provider}, {request.server_id}")
logger.debug(f"Getting provider instance for: {request.provider}, {request.server_id}")
provider_instance = await provider_registry.get_provider(
request.provider,
request.server_id,
Expand Down Expand Up @@ -716,6 +756,13 @@ async def chat_completion(request: ChatCompletionRequest, db: AsyncSession = Dep
"""
logger = logging.getLogger(__name__)
try:
print(f"🎯 CHAT COMPLETION ENDPOINT CALLED")
print(f"📊 Provider: {request.provider}")
print(f"📊 Settings ID: {request.settings_id}")
print(f"📊 Server ID: {request.server_id}")
print(f"📊 Model: {request.model}")
print(f"📊 User ID: {request.user_id}")
print(f"📊 Stream: {request.stream}")
logger.info(f"Production chat endpoint called with: provider={request.provider}, settings_id={request.settings_id}, server_id={request.server_id}, model={request.model}")
logger.debug(f"Messages: {request.messages}")
logger.debug(f"Params: {request.params}")
Expand Down
Loading