Skip to content

chore(tests): add persistency and source manager tests #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ on:
pull_request:
branches:
- '**'


concurrency:
group: ci-${{ github.head_ref || github.ref }}-${{ github.repository }}
cancel-in-progress: true


jobs:
test:
runs-on: ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion rag/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func NewPersistentLocalAICollection(llmClient *openai.Client, apiURL, apiKey, co

// TODO: This does not work as we do not have .Reset().
// The problem is that LocalAI stores are not persistent either and do not allow upserts.
persistentKB.repopulate()
persistentKB.Repopulate()

return persistentKB
}
Expand Down
48 changes: 37 additions & 11 deletions rag/persistency.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/mudler/localrecall/pkg/chunk"
"github.com/mudler/localrecall/pkg/xlog"
"github.com/mudler/localrecall/rag/engine"
"github.com/mudler/localrecall/rag/types"
)

// CollectionState represents the persistent state of a collection
Expand Down Expand Up @@ -91,6 +92,13 @@ func NewPersistentCollectionKB(stateFile, assetDir string, store Engine, maxChun
return db, nil
}

func (db *PersistentKB) Search(s string, similarEntries int) ([]types.Result, error) {
db.Lock()
defer db.Unlock()

return db.Engine.Search(s, similarEntries)
}

func (db *PersistentKB) Reset() error {
db.Lock()
for f := range db.index {
Expand Down Expand Up @@ -120,11 +128,16 @@ func (db *PersistentKB) save() error {
return os.WriteFile(db.path, data, 0644)
}

// repopulate reinitializes the persistent knowledge base with the files that were added to it.
func (db *PersistentKB) repopulate() error {
func (db *PersistentKB) Count() int {
db.Lock()
defer db.Unlock()

return db.Engine.Count()
}

// repopulate reinitializes the persistent knowledge base with the files that were added to it.
func (db *PersistentKB) repopulate() error {

if err := db.Engine.Reset(); err != nil {
return fmt.Errorf("failed to reset engine: %w", err)
}
Expand All @@ -141,6 +154,13 @@ func (db *PersistentKB) repopulate() error {
return nil
}

func (db *PersistentKB) Repopulate() error {
db.Lock()
defer db.Unlock()

return db.repopulate()
}

// Store stores an entry in the persistent knowledge base.
func (db *PersistentKB) ListDocuments() []string {
db.Lock()
Expand Down Expand Up @@ -173,6 +193,10 @@ func (db *PersistentKB) Store(entry string, metadata map[string]string) error {
db.Lock()
defer db.Unlock()

return db.storeFile(entry, metadata)
}

func (db *PersistentKB) storeFile(entry string, metadata map[string]string) error {
fileName := filepath.Base(entry)

// copy file to assetDir (if it's a file)
Expand All @@ -194,18 +218,19 @@ func (db *PersistentKB) Store(entry string, metadata map[string]string) error {

func (db *PersistentKB) StoreOrReplace(entry string, metadata map[string]string) error {
db.Lock()
defer db.Unlock()

fileName := filepath.Base(entry)
_, ok := db.index[fileName]
db.Unlock()
// Check if we have it already in the index
if ok {
xlog.Info("Data already exists for entry", "entry", entry, "index", db.index)
if err := db.RemoveEntry(fileName); err != nil {
if err := db.removeFileEntry(fileName); err != nil {
return fmt.Errorf("failed to remove entry: %w", err)
}
}

return db.Store(entry, metadata)
return db.storeFile(entry, metadata)
}

func (db *PersistentKB) store(metadata map[string]string, files ...string) ([]engine.Result, error) {
Expand All @@ -231,8 +256,14 @@ func (db *PersistentKB) store(metadata map[string]string, files ...string) ([]en
return results, nil
}

// RemoveEntry removes an entry from the persistent knowledge base.
func (db *PersistentKB) RemoveEntry(entry string) error {
db.Lock()
defer db.Unlock()

return db.removeFileEntry(entry)
}

func (db *PersistentKB) removeFileEntry(entry string) error {

xlog.Info("Removing entry", "entry", entry)
if os.Getenv("LOCALRECALL_REPOPULATE_DELETE") != "true" {
Expand Down Expand Up @@ -261,25 +292,20 @@ func (db *PersistentKB) RemoveEntry(entry string) error {
}
}

db.Lock()

xlog.Info("Deleting entry from index", "entry", entry)
delete(db.index, entry)

xlog.Info("Removing entry from disk", "file", e)
os.Remove(e)
db.Unlock()
return db.save()
}

db.Lock()
for e := range db.index {
if e == entry {
os.Remove(filepath.Join(db.assetDir, e))
break
}
}
db.Unlock()

// TODO: this is suboptimal, but currently chromem does not support deleting single entities
return db.repopulate()
Expand Down
32 changes: 24 additions & 8 deletions rag/source_manager.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rag

import (
"context"
"fmt"
"os"
"path/filepath"
Expand All @@ -24,13 +25,18 @@ type SourceManager struct {
sources map[string][]ExternalSource // collection name -> sources
collections map[string]*PersistentKB // collection name -> collection
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
}

// NewSourceManager creates a new source manager
func NewSourceManager() *SourceManager {
ctx, cancel := context.WithCancel(context.Background())
return &SourceManager{
sources: make(map[string][]ExternalSource),
collections: make(map[string]*PersistentKB),
ctx: ctx,
cancel: cancel,
}
}

Expand Down Expand Up @@ -182,17 +188,27 @@ func (sm *SourceManager) Start() {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()

for range ticker.C {
sm.mu.RLock()
for collectionName, sources := range sm.sources {
collection := sm.collections[collectionName]
for _, source := range sources {
if time.Since(source.LastUpdate) >= source.UpdateInterval {
go sm.updateSource(collectionName, source, collection)
for {
select {
case <-sm.ctx.Done():
return
case <-ticker.C:
sm.mu.RLock()
for collectionName, sources := range sm.sources {
collection := sm.collections[collectionName]
for _, source := range sources {
if time.Since(source.LastUpdate) >= source.UpdateInterval {
go sm.updateSource(collectionName, source, collection)
}
}
}
sm.mu.RUnlock()
}
sm.mu.RUnlock()
}
}()
}

// Stop stops the background service
func (sm *SourceManager) Stop() {
sm.cancel()
}
34 changes: 34 additions & 0 deletions test/e2e/common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package e2e_test

import (
"time"

"github.com/sashabaranov/go-openai"
)

const (
// TestCollection is the default collection name used in tests
TestCollection = "foo"

// EmbeddingModel is the model used for embeddings in tests
EmbeddingModel = "granite-embedding-107m-multilingual"

// DefaultChunkSize is the default chunk size used in tests
DefaultChunkSize = 1000

// DefaultUpdateInterval is the default update interval for external sources
DefaultUpdateInterval = time.Hour

// TestTimeout is the default timeout for Eventually blocks
TestTimeout = 1 * time.Minute

// TestPollingInterval is the default polling interval for Eventually blocks
TestPollingInterval = 500 * time.Millisecond
)

// NewTestOpenAIConfig creates a new OpenAI config for testing
func NewTestOpenAIConfig() openai.ClientConfig {
config := openai.DefaultConfig("foo")
config.BaseURL = localAIEndpoint
return config
}
46 changes: 20 additions & 26 deletions test/e2e/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"os"
"path/filepath"
"time"

"github.com/mudler/localrecall/pkg/client"
. "github.com/onsi/ginkgo/v2"
Expand Down Expand Up @@ -64,88 +63,83 @@ var _ = Describe("API", func() {
Skip("Skipping E2E tests")
}

config := openai.DefaultConfig("foo")
config.BaseURL = localAIEndpoint

localAI = openai.NewClientWithConfig(config)
localAI = openai.NewClientWithConfig(NewTestOpenAIConfig())
localRecall = client.NewClient(localRecallEndpoint)

Eventually(func() error {

res, err := localAI.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
Model: "granite-embedding-107m-multilingual",
Model: EmbeddingModel,
Input: "foo",
})
if len(res.Data) == 0 {
return fmt.Errorf("no data")
}
return err
}, 5*time.Minute, time.Second).Should(Succeed())
}, TestTimeout, TestPollingInterval).Should(Succeed())

Eventually(func() error {
_, err := localRecall.ListCollections()

return err
}, 5*time.Minute, time.Second).Should(Succeed())
}, TestTimeout, TestPollingInterval).Should(Succeed())

localRecall.Reset(testCollection)
localRecall.Reset(TestCollection)
})

It("should create collections", func() {
err := localRecall.CreateCollection(testCollection)
err := localRecall.CreateCollection(TestCollection)
Expect(err).To(BeNil())

collections, err := localRecall.ListCollections()
Expect(err).To(BeNil())
Expect(collections).To(ContainElement(testCollection))
Expect(collections).To(ContainElement(TestCollection))
})

It("should search between documents", func() {
err := localRecall.CreateCollection(testCollection)
err := localRecall.CreateCollection(TestCollection)
Expect(err).ToNot(HaveOccurred())

tempContent(story1, localRecall)
tempContent(story2, localRecall)
expectContent("foo", "spiders", "spider", localRecall)
expectContent("foo", "heist", "the Great Pigeon Heist", localRecall)
expectContent(TestCollection, "spiders", "spider", localRecall)
expectContent(TestCollection, "heist", "the Great Pigeon Heist", localRecall)
})

It("should reset collections", func() {
err := localRecall.CreateCollection(testCollection)
err := localRecall.CreateCollection(TestCollection)
Expect(err).To(BeNil())

tempContent(story1, localRecall)
tempContent(story2, localRecall)

err = localRecall.Reset(testCollection)
err = localRecall.Reset(TestCollection)
Expect(err).To(BeNil())

docs, err := localRecall.Search(testCollection, "spiders", 1)
docs, err := localRecall.Search(TestCollection, "spiders", 1)
Expect(err).To(HaveOccurred())
Expect(docs).To(BeNil())
Expect(err.Error()).To(ContainSubstring("failed to search collection"))
})

It("should be able to delete documents", func() {
err := localRecall.CreateCollection(testCollection)
err := localRecall.CreateCollection(TestCollection)
Expect(err).ToNot(HaveOccurred())

tempContent(story1, localRecall)
fileName := tempContent(story2, localRecall)

entries, err := localRecall.ListEntries(testCollection)
entries, err := localRecall.ListEntries(TestCollection)
Expect(err).ToNot(HaveOccurred())
Expect(entries).To(HaveLen(2))

expectContent("foo", "spiders", "spider", localRecall)
expectContent("foo", "heist", "the Great Pigeon Heist", localRecall)
expectContent(TestCollection, "spiders", "spider", localRecall)
expectContent(TestCollection, "heist", "the Great Pigeon Heist", localRecall)

entry := fileName
entries, err = localRecall.DeleteEntry(testCollection, entry)
entries, err = localRecall.DeleteEntry(TestCollection, entry)
Expect(err).ToNot(HaveOccurred())
Expect(entries).To(HaveLen(1))

expectContent("foo", "heist", "the Great Pigeon Heist", localRecall)
expectContent(TestCollection, "heist", "the Great Pigeon Heist", localRecall)
})
})

Expand All @@ -167,7 +161,7 @@ func tempContent(content string, localRecall *client.Client) string {
_, err = ff.WriteString(content)
ExpectWithOffset(1, err).ToNot(HaveOccurred())

err = localRecall.Store(testCollection, ff.Name())
err = localRecall.Store(TestCollection, ff.Name())
ExpectWithOffset(1, err).ToNot(HaveOccurred())

return fileName
Expand Down
Loading