Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions hindsight-api/hindsight_api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,14 @@ def normalize_config_dict(config: dict[str, Any]) -> dict[str, Any]:
ENV_RERANKER_LITELLM_API_KEY = "HINDSIGHT_API_RERANKER_LITELLM_API_KEY"
ENV_RERANKER_LITELLM_MODEL = "HINDSIGHT_API_RERANKER_LITELLM_MODEL"

# LiteLLM SDK configuration (direct API access, no proxy needed)
ENV_EMBEDDINGS_LITELLM_SDK_API_KEY = "HINDSIGHT_API_EMBEDDINGS_LITELLM_SDK_API_KEY"
ENV_EMBEDDINGS_LITELLM_SDK_MODEL = "HINDSIGHT_API_EMBEDDINGS_LITELLM_SDK_MODEL"
ENV_EMBEDDINGS_LITELLM_SDK_API_BASE = "HINDSIGHT_API_EMBEDDINGS_LITELLM_SDK_API_BASE"
ENV_RERANKER_LITELLM_SDK_API_KEY = "HINDSIGHT_API_RERANKER_LITELLM_SDK_API_KEY"
ENV_RERANKER_LITELLM_SDK_MODEL = "HINDSIGHT_API_RERANKER_LITELLM_SDK_MODEL"
ENV_RERANKER_LITELLM_SDK_API_BASE = "HINDSIGHT_API_RERANKER_LITELLM_SDK_API_BASE"

# Deprecated: Legacy shared LiteLLM config (for backward compatibility)
ENV_LITELLM_API_BASE = "HINDSIGHT_API_LITELLM_API_BASE"
ENV_LITELLM_API_KEY = "HINDSIGHT_API_LITELLM_API_KEY"
Expand Down Expand Up @@ -328,6 +336,10 @@ def normalize_config_dict(config: dict[str, Any]) -> dict[str, Any]:
DEFAULT_EMBEDDINGS_LITELLM_MODEL = "text-embedding-3-small"
DEFAULT_RERANKER_LITELLM_MODEL = "cohere/rerank-english-v3.0"

# LiteLLM SDK defaults
DEFAULT_EMBEDDINGS_LITELLM_SDK_MODEL = "cohere/embed-english-v3.0"
DEFAULT_RERANKER_LITELLM_SDK_MODEL = "cohere/rerank-english-v3.0"

DEFAULT_HOST = "0.0.0.0"
DEFAULT_PORT = 8888
DEFAULT_BASE_PATH = "" # Empty string = root path
Expand Down Expand Up @@ -521,6 +533,9 @@ class HindsightConfig:
embeddings_litellm_api_base: str
embeddings_litellm_api_key: str | None
embeddings_litellm_model: str
embeddings_litellm_sdk_api_key: str | None
embeddings_litellm_sdk_model: str
embeddings_litellm_sdk_api_base: str | None

# Reranker
reranker_provider: str
Expand All @@ -538,6 +553,9 @@ class HindsightConfig:
reranker_litellm_api_base: str
reranker_litellm_api_key: str | None
reranker_litellm_model: str
reranker_litellm_sdk_api_key: str | None
reranker_litellm_sdk_model: str
reranker_litellm_sdk_api_base: str | None

# Server
host: str
Expand Down Expand Up @@ -820,6 +838,12 @@ def from_env(cls) -> "HindsightConfig":
or os.getenv(ENV_LITELLM_API_BASE, DEFAULT_LITELLM_API_BASE),
embeddings_litellm_api_key=os.getenv(ENV_EMBEDDINGS_LITELLM_API_KEY) or os.getenv(ENV_LITELLM_API_KEY),
embeddings_litellm_model=os.getenv(ENV_EMBEDDINGS_LITELLM_MODEL, DEFAULT_EMBEDDINGS_LITELLM_MODEL),
# LiteLLM SDK embeddings (direct API access)
embeddings_litellm_sdk_api_key=os.getenv(ENV_EMBEDDINGS_LITELLM_SDK_API_KEY),
embeddings_litellm_sdk_model=os.getenv(
ENV_EMBEDDINGS_LITELLM_SDK_MODEL, DEFAULT_EMBEDDINGS_LITELLM_SDK_MODEL
),
embeddings_litellm_sdk_api_base=os.getenv(ENV_EMBEDDINGS_LITELLM_SDK_API_BASE) or None,
# Reranker
reranker_provider=os.getenv(ENV_RERANKER_PROVIDER, DEFAULT_RERANKER_PROVIDER),
reranker_local_model=os.getenv(ENV_RERANKER_LOCAL_MODEL, DEFAULT_RERANKER_LOCAL_MODEL),
Expand Down Expand Up @@ -849,6 +873,10 @@ def from_env(cls) -> "HindsightConfig":
or os.getenv(ENV_LITELLM_API_BASE, DEFAULT_LITELLM_API_BASE),
reranker_litellm_api_key=os.getenv(ENV_RERANKER_LITELLM_API_KEY) or os.getenv(ENV_LITELLM_API_KEY),
reranker_litellm_model=os.getenv(ENV_RERANKER_LITELLM_MODEL, DEFAULT_RERANKER_LITELLM_MODEL),
# LiteLLM SDK reranker (direct API access)
reranker_litellm_sdk_api_key=os.getenv(ENV_RERANKER_LITELLM_SDK_API_KEY),
reranker_litellm_sdk_model=os.getenv(ENV_RERANKER_LITELLM_SDK_MODEL, DEFAULT_RERANKER_LITELLM_SDK_MODEL),
reranker_litellm_sdk_api_base=os.getenv(ENV_RERANKER_LITELLM_SDK_API_BASE) or None,
# Server
host=os.getenv(ENV_HOST, DEFAULT_HOST),
port=int(os.getenv(ENV_PORT, DEFAULT_PORT)),
Expand Down
135 changes: 134 additions & 1 deletion hindsight-api/hindsight_api/engine/cross_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
DEFAULT_RERANKER_FLASHRANK_CACHE_DIR,
DEFAULT_RERANKER_FLASHRANK_MODEL,
DEFAULT_RERANKER_LITELLM_MODEL,
DEFAULT_RERANKER_LITELLM_SDK_MODEL,
DEFAULT_RERANKER_LOCAL_FORCE_CPU,
DEFAULT_RERANKER_LOCAL_MAX_CONCURRENT,
DEFAULT_RERANKER_LOCAL_MODEL,
Expand All @@ -32,6 +33,7 @@
ENV_RERANKER_COHERE_MODEL,
ENV_RERANKER_FLASHRANK_CACHE_DIR,
ENV_RERANKER_FLASHRANK_MODEL,
ENV_RERANKER_LITELLM_SDK_API_KEY,
ENV_RERANKER_LOCAL_FORCE_CPU,
ENV_RERANKER_LOCAL_MAX_CONCURRENT,
ENV_RERANKER_LOCAL_MODEL,
Expand Down Expand Up @@ -828,6 +830,126 @@ async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
return all_scores


class LiteLLMSDKCrossEncoder(CrossEncoderModel):
    """
    LiteLLM SDK cross-encoder for direct API integration.

    Supports reranking via the LiteLLM SDK without requiring a proxy server.
    Supported providers: Cohere, DeepInfra, Together AI, HuggingFace, Jina AI, Voyage AI, AWS Bedrock.

    Example model names:
    - cohere/rerank-english-v3.0
    - deepinfra/Qwen3-reranker-8B
    - together_ai/Salesforce/Llama-Rank-V1
    - huggingface/BAAI/bge-reranker-v2-m3
    """

    def __init__(
        self,
        api_key: str,
        model: str = DEFAULT_RERANKER_LITELLM_SDK_MODEL,
        api_base: str | None = None,
        timeout: float = 60.0,
    ):
        """
        Initialize LiteLLM SDK cross-encoder client.

        Args:
            api_key: API key for the reranking provider
            model: Model name with provider prefix (e.g., "deepinfra/Qwen3-reranker-8B")
            api_base: Custom base URL for API (optional)
            timeout: Request timeout in seconds (default: 60.0)
        """
        self.api_key = api_key
        self.model = model
        self.api_base = api_base
        # NOTE(review): timeout is stored but never forwarded to the rerank
        # call below — confirm whether litellm.arerank accepts a timeout kwarg.
        self.timeout = timeout
        self._initialized = False
        # Lazily-imported litellm module reference; set in initialize() so the
        # dependency is only required when this provider is actually selected.
        self._litellm = None

    @property
    def provider_name(self) -> str:
        return "litellm-sdk"

    async def initialize(self) -> None:
        """Import litellm and mark the client ready. Idempotent."""
        if self._initialized:
            return

        try:
            import litellm
        except ImportError as e:
            # Chain the original error so the underlying missing-module
            # details survive in the traceback.
            raise ImportError(
                "litellm is required for LiteLLMSDKCrossEncoder. Install it with: pip install litellm"
            ) from e
        self._litellm = litellm  # Store reference for predict()

        api_base_msg = f" at {self.api_base}" if self.api_base else ""
        logger.info(f"Reranker: initializing LiteLLM SDK provider with model {self.model}{api_base_msg}")

        self._initialized = True
        logger.info("Reranker: LiteLLM SDK provider initialized")

    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
        """
        Score query-document pairs using the LiteLLM SDK.

        Args:
            pairs: List of (query, document) tuples to score

        Returns:
            List of relevance scores, one per input pair, in input order.
            Pairs whose score is absent from the provider response are
            left at 0.0.

        Raises:
            RuntimeError: If called before initialize().
        """
        if not self._initialized:
            raise RuntimeError("Reranker not initialized. Call initialize() first.")

        if not pairs:
            return []

        # Group pairs by query for efficient batching:
        # LiteLLM rerank expects one query with multiple documents, so one
        # API call per distinct query minimizes round trips.
        query_groups: dict[str, list[tuple[int, str]]] = {}
        for idx, (query, text) in enumerate(pairs):
            query_groups.setdefault(query, []).append((idx, text))

        all_scores = [0.0] * len(pairs)

        for query, indexed_texts in query_groups.items():
            texts = [text for _, text in indexed_texts]
            indices = [idx for idx, _ in indexed_texts]

            # Build kwargs for the rerank call; api_base only when configured.
            rerank_kwargs = {
                "model": self.model,
                "query": query,
                "documents": texts,
                "api_key": self.api_key,
            }
            if self.api_base:
                rerank_kwargs["api_base"] = self.api_base

            response = await self._litellm.arerank(**rerank_kwargs)

            # Map scores back to original positions.
            # Expected format: RerankResponse whose .results is a list of
            # TypedDicts carrying "index" and "relevance_score".
            if hasattr(response, "results") and response.results:
                for result in response.results:
                    # Results are TypedDicts, use dict-style access.
                    original_idx = result["index"]
                    score = result.get("relevance_score", result.get("score", 0.0))
                    # Cast defensively so the declared list[float] contract
                    # holds even if the provider returns e.g. ints.
                    all_scores[indices[original_idx]] = float(score)
            elif isinstance(response, list):
                # Direct list of scores (unlikely but defensive).
                for i, score in enumerate(response):
                    all_scores[indices[i]] = float(score)
            else:
                logger.warning(f"Unexpected response format from LiteLLM rerank: {type(response)}")

        return all_scores


def create_cross_encoder_from_env() -> CrossEncoderModel:
"""
Create a CrossEncoderModel instance based on configuration.
Expand Down Expand Up @@ -877,9 +999,20 @@ def create_cross_encoder_from_env() -> CrossEncoderModel:
api_key=config.reranker_litellm_api_key,
model=config.reranker_litellm_model,
)
elif provider == "litellm-sdk":
api_key = config.reranker_litellm_sdk_api_key
if not api_key:
raise ValueError(
f"{ENV_RERANKER_LITELLM_SDK_API_KEY} is required when {ENV_RERANKER_PROVIDER} is 'litellm-sdk'"
)
return LiteLLMSDKCrossEncoder(
api_key=api_key,
model=config.reranker_litellm_sdk_model,
api_base=config.reranker_litellm_sdk_api_base,
)
elif provider == "rrf":
return RRFPassthroughCrossEncoder()
else:
raise ValueError(
f"Unknown reranker provider: {provider}. Supported: 'local', 'tei', 'cohere', 'flashrank', 'litellm', 'rrf'"
f"Unknown reranker provider: {provider}. Supported: 'local', 'tei', 'cohere', 'flashrank', 'litellm', 'litellm-sdk', 'rrf'"
)
Loading