WIP - Hybrid search #5

Draft: wants to merge 1 commit into main

32 changes: 22 additions & 10 deletions rag/collection.go
@@ -9,6 +9,7 @@ import (
"github.com/mudler/localrecall/pkg/xlog"
"github.com/mudler/localrecall/rag/engine"
"github.com/mudler/localrecall/rag/engine/localai"
"github.com/mudler/localrecall/rag/types"
"github.com/sashabaranov/go-openai"
)

@@ -22,10 +23,17 @@ func NewPersistentChromeCollection(llmClient *openai.Client, collectionName, dbP
os.Exit(1)
}

// Create a hybrid search engine with the ChromemDB engine
hybridEngine, err := engine.NewHybridSearchEngine(chromemDB, types.NewBasicReranker(), dbPath)
if err != nil {
xlog.Error("Failed to create hybrid search engine", err)
os.Exit(1)
}

persistentKB, err := NewPersistentCollectionKB(
filepath.Join(dbPath, fmt.Sprintf("%s%s.json", collectionPrefix, collectionName)),
filePath,
chromemDB,
hybridEngine,
maxChunkSize)
if err != nil {
xlog.Error("Failed to create PersistentKB", err)
@@ -40,10 +48,17 @@ func NewPersistentLocalAICollection(llmClient *openai.Client, apiURL, apiKey, co
laiStore := localai.NewStoreClient(apiURL, apiKey)
ragDB := engine.NewLocalAIRAGDB(laiStore, llmClient, embeddingModel)

// Create a hybrid search engine with the LocalAI engine
hybridEngine, err := engine.NewHybridSearchEngine(ragDB, types.NewBasicReranker(), dbPath)
if err != nil {
xlog.Error("Failed to create hybrid search engine", err)
os.Exit(1)
}

persistentKB, err := NewPersistentCollectionKB(
filepath.Join(dbPath, fmt.Sprintf("%s%s.json", collectionPrefix, collectionName)),
filePath,
ragDB,
hybridEngine,
maxChunkSize)
if err != nil {
xlog.Error("Failed to create PersistentKB", err)
@@ -59,18 +74,15 @@ func NewPersistentLocalAICollection(llmClient *openai.Client, apiURL, apiKey, co

// ListAllCollections lists all collections in the database
func ListAllCollections(dbPath string) []string {
collections := []string{}
files, err := os.ReadDir(dbPath)
if err != nil {
xlog.Error("Failed to read directory", err)
return nil
return collections
}

var collections []string
for _, file := range files {
if !file.IsDir() && filepath.Ext(file.Name()) == ".json" && strings.HasPrefix(file.Name(), collectionPrefix) {
collectionName := strings.TrimPrefix(file.Name(), collectionPrefix)
collectionName = strings.TrimSuffix(collectionName, ".json")
collections = append(collections, collectionName)
for _, f := range files {
if strings.HasPrefix(f.Name(), collectionPrefix) {
collections = append(collections, strings.TrimPrefix(strings.TrimSuffix(f.Name(), ".json"), collectionPrefix))
}
}

12 changes: 6 additions & 6 deletions rag/engine.go
@@ -1,12 +1,12 @@
package rag

import (
"github.com/mudler/localrecall/rag/interfaces"
"github.com/mudler/localrecall/rag/types"
)

type Engine interface {
Store(s string, meta map[string]string) error
Reset() error
Search(s string, similarEntries int) ([]types.Result, error)
Count() int
}
// Engine is an alias for interfaces.Engine
type Engine = interfaces.Engine

// Result is an alias for types.Result
type Result = types.Result
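
Because these are type aliases rather than newly defined types, rag.Engine is identical to interfaces.Engine and rag.Result to types.Result, so existing callers keep compiling unchanged. A minimal compile-time sketch of that interchangeability (names here are only for illustration):

package main

import (
	"github.com/mudler/localrecall/rag"
	"github.com/mudler/localrecall/rag/interfaces"
)

// This assignment compiles only because rag.Engine and interfaces.Engine are
// the same type; with two separately declared interfaces the two function
// types would differ and the assignment would fail.
var countDocs func(rag.Engine) int = func(e interfaces.Engine) int {
	return e.Count()
}

func main() {
	_ = countDocs
}
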
141 changes: 141 additions & 0 deletions rag/engine/fulltext.go
@@ -0,0 +1,141 @@
package engine

import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"sync"

"github.com/mudler/localrecall/rag/types"
)

// FullTextIndex manages the full-text search index
type FullTextIndex struct {
path string
documents map[string]string
mu sync.RWMutex
}

// NewFullTextIndex creates a new full-text index
func NewFullTextIndex(path string) (*FullTextIndex, error) {
index := &FullTextIndex{
path: path,
documents: make(map[string]string),
}

// Load existing index if it exists
if err := index.load(); err != nil {
return nil, fmt.Errorf("failed to load full-text index: %w", err)
}

return index, nil
}

// Store adds a document to the index
func (i *FullTextIndex) Store(id string, content string) error {
i.mu.Lock()
defer i.mu.Unlock()

i.documents[id] = content
return i.save()
}

// Delete removes a document from the index
func (i *FullTextIndex) Delete(id string) error {
i.mu.Lock()
defer i.mu.Unlock()

delete(i.documents, id)
return i.save()
}

// Reset clears the index
func (i *FullTextIndex) Reset() error {
i.mu.Lock()
defer i.mu.Unlock()

i.documents = make(map[string]string)
return i.save()
}

// Search performs full-text search on the index
func (i *FullTextIndex) Search(query string, maxResults int) []types.Result {
i.mu.RLock()
defer i.mu.RUnlock()

queryTerms := strings.Fields(strings.ToLower(query))
scoredResults := make([]types.Result, 0)

// Score all documents
for id, content := range i.documents {
contentLower := strings.ToLower(content)
score := float32(0)

// Simple term frequency scoring
for _, term := range queryTerms {
if strings.Contains(contentLower, term) {
score += 1.0
}
}

// Normalize score
if len(queryTerms) > 0 {
score = score / float32(len(queryTerms))
}

// Only include documents with a score > 0
if score > 0 {
scoredResults = append(scoredResults, types.Result{
ID: id,
Content: content,
FullTextScore: score,
})
}
}

// Sort by full-text score
for i := 0; i < len(scoredResults); i++ {
for j := i + 1; j < len(scoredResults); j++ {
if scoredResults[i].FullTextScore < scoredResults[j].FullTextScore {
scoredResults[i], scoredResults[j] = scoredResults[j], scoredResults[i]
}
}
}

// Return top maxResults results
if len(scoredResults) > maxResults {
scoredResults = scoredResults[:maxResults]
}

return scoredResults
}

// load reads the index from disk
func (i *FullTextIndex) load() error {
data, err := os.ReadFile(i.path)
if err != nil {
if os.IsNotExist(err) {
return nil // File doesn't exist yet, that's okay
}
return err
}

return json.Unmarshal(data, &i.documents)
}

// save writes the index to disk
func (i *FullTextIndex) save() error {
data, err := json.Marshal(i.documents)
if err != nil {
return err
}

// Ensure directory exists
if err := os.MkdirAll(filepath.Dir(i.path), 0755); err != nil {
return err
}

return os.WriteFile(i.path, data, 0644)
}
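
For illustration, a minimal sketch of using the new index on its own. The path and document IDs below are hypothetical; scoring is the fraction of query terms contained in each document, and only matching documents are returned.

package main

import (
	"fmt"
	"log"

	"github.com/mudler/localrecall/rag/engine"
)

func main() {
	// Hypothetical path; the index persists as a JSON map of id -> content.
	idx, err := engine.NewFullTextIndex("/tmp/fulltext.json")
	if err != nil {
		log.Fatal(err)
	}

	// IDs are arbitrary strings; the hybrid engine passes the chunk content
	// as both id and content.
	_ = idx.Store("doc-1", "notes about vector embeddings and chunking")
	_ = idx.Store("doc-2", "hybrid search combines semantic and keyword matching")

	// doc-2 contains all three query terms and scores 1.0; doc-1 contains
	// none of them and is filtered out.
	for _, r := range idx.Search("hybrid keyword search", 5) {
		fmt.Printf("%s  score=%.2f\n", r.ID, r.FullTextScore)
	}
}
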
121 changes: 121 additions & 0 deletions rag/engine/hybrid.go
@@ -0,0 +1,121 @@
package engine

import (
"fmt"
"path/filepath"

"github.com/mudler/localrecall/rag/interfaces"
"github.com/mudler/localrecall/rag/types"
)

// HybridSearchEngine combines semantic and full-text search
type HybridSearchEngine struct {
semanticEngine interfaces.Engine
reranker types.Reranker
fullTextIndex *FullTextIndex
}

// NewHybridSearchEngine creates a new hybrid search engine
func NewHybridSearchEngine(semanticEngine interfaces.Engine, reranker types.Reranker, dbPath string) (*HybridSearchEngine, error) {
// Create full-text index in the same directory as the semantic engine
fullTextIndex, err := NewFullTextIndex(filepath.Join(dbPath, "fulltext.json"))
if err != nil {
return nil, fmt.Errorf("failed to create full-text index: %w", err)
}

return &HybridSearchEngine{
semanticEngine: semanticEngine,
reranker: reranker,
fullTextIndex: fullTextIndex,
}, nil
}

// Store stores a document in both semantic and full-text indexes
func (h *HybridSearchEngine) Store(s string, metadata map[string]string) error {
// Store in semantic engine
if err := h.semanticEngine.Store(s, metadata); err != nil {
return err
}

// Store in full-text index
// Use the content as the ID since we don't have a better identifier
return h.fullTextIndex.Store(s, s)
}

// Reset resets both semantic and full-text indexes
func (h *HybridSearchEngine) Reset() error {
if err := h.semanticEngine.Reset(); err != nil {
return err
}
return h.fullTextIndex.Reset()
}

// Count returns the number of documents in the index
func (h *HybridSearchEngine) Count() int {
return h.semanticEngine.Count()
}

// Search performs hybrid search by combining semantic and full-text search results
func (h *HybridSearchEngine) Search(query string, similarEntries int) ([]types.Result, error) {
// Perform semantic search
semanticResults, err := h.semanticEngine.Search(query, similarEntries)
if err != nil {
return nil, fmt.Errorf("semantic search failed: %w", err)
}

// Perform full-text search on all documents
fullTextResults := h.fullTextIndex.Search(query, similarEntries)

// Combine results from both searches
combinedResults := h.combineResults(semanticResults, fullTextResults)

// Rerank the combined results
rerankedResults, err := h.reranker.Rerank(query, combinedResults)
if err != nil {
return nil, fmt.Errorf("reranking failed: %w", err)
}

return rerankedResults, nil
}

// combineResults combines semantic and full-text search results
func (h *HybridSearchEngine) combineResults(semanticResults, fullTextResults []types.Result) []types.Result {
// Create a map to track unique results by content
resultMap := make(map[string]types.Result)

// Add semantic results
for _, result := range semanticResults {
resultMap[result.Content] = result
}

// Add full-text results, combining scores if the same content exists
for _, result := range fullTextResults {
if existing, exists := resultMap[result.Content]; exists {
// If the content exists in both results, combine the scores
existing.FullTextScore = result.FullTextScore
existing.CombinedScore = (existing.Similarity + result.FullTextScore) / 2
resultMap[result.Content] = existing
} else {
// If it's a new result, just add it
result.CombinedScore = result.FullTextScore
resultMap[result.Content] = result
}
}

// Convert map back to slice
combinedResults := make([]types.Result, 0, len(resultMap))
for _, result := range resultMap {
combinedResults = append(combinedResults, result)
}

// Sort by combined score
for i := 0; i < len(combinedResults); i++ {
for j := i + 1; j < len(combinedResults); j++ {
if combinedResults[i].CombinedScore < combinedResults[j].CombinedScore {
combinedResults[i], combinedResults[j] = combinedResults[j], combinedResults[i]
}
}
}

return combinedResults
}
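
A sketch of wiring the hybrid engine end to end. The stubEngine below is a hypothetical in-memory stand-in for the chromem or LocalAI engine, used only so the example is self-contained; real callers pass one of those engines, as in rag/collection.go. The dbPath is also hypothetical.

package main

import (
	"fmt"
	"log"

	"github.com/mudler/localrecall/rag/engine"
	"github.com/mudler/localrecall/rag/types"
)

// stubEngine is a hypothetical in-memory semantic engine; it satisfies
// interfaces.Engine just enough to demonstrate the hybrid flow.
type stubEngine struct{ docs []string }

func (s *stubEngine) Store(content string, meta map[string]string) error {
	s.docs = append(s.docs, content)
	return nil
}

func (s *stubEngine) Reset() error { s.docs = nil; return nil }

func (s *stubEngine) Count() int { return len(s.docs) }

func (s *stubEngine) Search(query string, similarEntries int) ([]types.Result, error) {
	// No embeddings here: every stored document comes back with a flat score.
	results := []types.Result{}
	for _, d := range s.docs {
		results = append(results, types.Result{ID: d, Content: d, Similarity: 0.5})
	}
	if len(results) > similarEntries {
		results = results[:similarEntries]
	}
	return results, nil
}

func main() {
	// fulltext.json is created inside the given dbPath.
	hybrid, err := engine.NewHybridSearchEngine(&stubEngine{}, types.NewBasicReranker(), "/tmp/localrecall")
	if err != nil {
		log.Fatal(err)
	}

	_ = hybrid.Store("hybrid search combines vectors and keywords", nil)
	_ = hybrid.Store("an unrelated note about packaging", nil)

	results, err := hybrid.Search("keywords", 5)
	if err != nil {
		log.Fatal(err)
	}
	for _, r := range results {
		fmt.Printf("combined=%.2f  %s\n", r.CombinedScore, r.Content)
	}
}
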
11 changes: 11 additions & 0 deletions rag/interfaces/engine.go
@@ -0,0 +1,11 @@
package interfaces

import "github.com/mudler/localrecall/rag/types"

// Engine defines the interface for search engines
type Engine interface {
Store(s string, meta map[string]string) error
Reset() error
Search(s string, similarEntries int) ([]types.Result, error)
Count() int
}
22 changes: 22 additions & 0 deletions rag/types/reranker.go
@@ -0,0 +1,22 @@
package types

// Reranker defines the interface for reranking search results
type Reranker interface {
// Rerank takes a query and a list of results, and returns a reranked list
Rerank(query string, results []Result) ([]Result, error)
}

// BasicReranker implements a simple reranking strategy that combines semantic and full-text scores
type BasicReranker struct{}

// NewBasicReranker creates a new BasicReranker instance
func NewBasicReranker() *BasicReranker {
return &BasicReranker{}
}

// Rerank implements a simple reranking strategy that combines semantic and full-text scores
func (r *BasicReranker) Rerank(query string, results []Result) ([]Result, error) {
// For now, we'll just return the results as is
// In a real implementation, we would combine semantic and full-text scores
return results, nil
}
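
BasicReranker is currently a pass-through. A sketch of what a score-combining reranker could look like, assuming the Similarity and FullTextScore fields are already populated by the hybrid engine; the weights are arbitrary and would need tuning.

package types

import "sort"

// WeightedReranker blends semantic and full-text scores with fixed weights.
type WeightedReranker struct {
	SemanticWeight float32
	FullTextWeight float32
}

var _ Reranker = &WeightedReranker{}

// Rerank recomputes CombinedScore as a weighted sum and sorts descending by
// it; the query itself is unused in this simple strategy.
func (r *WeightedReranker) Rerank(query string, results []Result) ([]Result, error) {
	for i := range results {
		results[i].CombinedScore = r.SemanticWeight*results[i].Similarity +
			r.FullTextWeight*results[i].FullTextScore
	}
	sort.Slice(results, func(a, b int) bool {
		return results[a].CombinedScore > results[b].CombinedScore
	})
	return results, nil
}

Sorting descending by CombinedScore preserves the hybrid engine's expectation that the best matches come first.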