
Commit 3b21c27

feat: adding cached llm (#66)
1 parent b05b535 commit 3b21c27

5 files changed: +290 −4 lines


PERFORMANCE_OPTIMIZATIONS.md

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ New environment variables for tuning performance:
 
 - `AGENT_LLM_CACHE_SIZE` - LRU cache size for LLM responses (default: 1000)
 - `AGENT_LLM_CACHE_TTL` - Cache TTL in seconds (default: 300)
+- `AGENT_LLM_CACHE_PATH` - Path to cache file for persistence (default: .agent_cache.json)
 - `AGENT_CONCURRENT_OPS` - Max concurrent operations (default: 10)
 - `AGENT_BATCH_SIZE` - Batch size for batch operations (default: 50)
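For a quick illustration of how these variables are consumed (a minimal sketch, not part of the commit; the values shown are the documented defaults), the cache settings just need to be in the environment before the provider is constructed:

package main

import "os"

func main() {
	// Illustrative only: enable the LLM cache for this process.
	// These values mirror the documented defaults.
	os.Setenv("AGENT_LLM_CACHE_SIZE", "1000")              // LRU capacity; leaving it unset keeps caching off
	os.Setenv("AGENT_LLM_CACHE_TTL", "300")                // entry lifetime in seconds
	os.Setenv("AGENT_LLM_CACHE_PATH", ".agent_cache.json") // file used to persist the cache across runs

	// ... construct the LLM provider afterwards (see src/models/helper.go below).
}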

src/cache/lru_cache.go

Lines changed: 42 additions & 0 deletions
@@ -121,3 +121,45 @@ func HashKey(prompt string) string {
	h := sha256.Sum256([]byte(prompt))
	return hex.EncodeToString(h[:])
}

// Dump returns a copy of the cache entries, keyed by cache key, for persistence.
func (c *LRUCache) Dump() map[string]CacheEntry {
	c.mu.RLock()
	defer c.mu.RUnlock()

	dump := make(map[string]CacheEntry, len(c.items))
	for k, elem := range c.items {
		dump[k] = elem.Value.(*entry).value
	}
	return dump
}

// Restore populates the cache from a map of entries, dropping any that have
// expired and evicting down to capacity afterwards.
func (c *LRUCache) Restore(dump map[string]CacheEntry) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.lru.Init()
	c.items = make(map[string]*list.Element, c.capacity)

	for k, v := range dump {
		// Skip entries whose TTL elapsed while the dump was on disk.
		if time.Now().After(v.ExpiresAt) {
			continue
		}

		// Add to cache
		ent := &entry{key: k, value: v}
		elem := c.lru.PushFront(ent)
		c.items[k] = elem
	}

	// Enforce capacity by evicting from the least-recently-used end.
	for c.lru.Len() > c.capacity {
		oldest := c.lru.Back()
		if oldest != nil {
			c.lru.Remove(oldest)
			delete(c.items, oldest.Value.(*entry).key)
		}
	}
}
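For context, a minimal sketch of how Dump and Restore might round-trip through JSON on their own (not part of the commit; it assumes NewLRUCache, Set, and HashKey as used elsewhere in this change, and error handling is elided for brevity):

package main

import (
	"encoding/json"
	"os"
	"time"

	"github.com/Protocol-Lattice/go-agent/src/cache"
)

func main() {
	c := cache.NewLRUCache(100, 5*time.Minute)
	c.Set(cache.HashKey("hello"), "world")

	// Persist: Dump returns the live entries, which encode as plain JSON.
	f, _ := os.Create("snapshot.json") // illustrative path; errors ignored for brevity
	json.NewEncoder(f).Encode(c.Dump())
	f.Close()

	// Rebuild: Restore drops entries whose TTL has elapsed and re-enforces capacity.
	g, _ := os.Open("snapshot.json")
	var dump map[string]cache.CacheEntry
	json.NewDecoder(g).Decode(&dump)
	g.Close()

	restored := cache.NewLRUCache(100, 5*time.Minute)
	restored.Restore(dump)
}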

src/models/cached.go

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
package models

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"os"
	"strconv"
	"time"

	"github.com/Protocol-Lattice/go-agent/src/cache"
)

// CachedLLM wraps an Agent and caches Generate calls.
type CachedLLM struct {
	Agent    Agent
	Cache    *cache.LRUCache
	FilePath string
}

// NewCachedLLM creates a new CachedLLM wrapper.
func NewCachedLLM(agent Agent, size int, ttl time.Duration, filePath string) *CachedLLM {
	c := &CachedLLM{
		Agent:    agent,
		Cache:    cache.NewLRUCache(size, ttl),
		FilePath: filePath,
	}
	if filePath != "" {
		c.load()
	}
	return c
}

func (c *CachedLLM) load() {
	f, err := os.Open(c.FilePath)
	if err != nil {
		return // ignore errors (file not found, etc.)
	}
	defer f.Close()

	var dump map[string]cache.CacheEntry
	if err := json.NewDecoder(f).Decode(&dump); err == nil {
		c.Cache.Restore(dump)
	}
}

func (c *CachedLLM) save() {
	if c.FilePath == "" {
		return
	}
	dump := c.Cache.Dump()

	// Atomic write: write to a temp file, then rename it into place.
	tmp := c.FilePath + ".tmp"
	f, err := os.Create(tmp)
	if err != nil {
		return
	}

	if err := json.NewEncoder(f).Encode(dump); err != nil {
		f.Close()
		os.Remove(tmp)
		return
	}
	f.Close()
	os.Rename(tmp, c.FilePath)
}

// Generate checks the cache before calling the underlying agent.
func (c *CachedLLM) Generate(ctx context.Context, prompt string) (any, error) {
	key := cache.HashKey(prompt)
	if val, ok := c.Cache.Get(key); ok {
		return val, nil
	}

	res, err := c.Agent.Generate(ctx, prompt)
	if err != nil {
		return nil, err
	}

	c.Cache.Set(key, res)
	c.save()
	return res, nil
}

// GenerateWithFiles checks the cache (including file hashes) before calling the underlying agent.
func (c *CachedLLM) GenerateWithFiles(ctx context.Context, prompt string, files []File) (any, error) {
	// Build a cache key from the prompt plus every file's name, MIME type, and contents.
	h := sha256.New()
	h.Write([]byte(prompt))
	for _, f := range files {
		h.Write([]byte(f.Name))
		h.Write([]byte(f.MIME))
		h.Write(f.Data)
	}
	key := hex.EncodeToString(h.Sum(nil))

	if val, ok := c.Cache.Get(key); ok {
		return val, nil
	}

	res, err := c.Agent.GenerateWithFiles(ctx, prompt, files)
	if err != nil {
		return nil, err
	}

	c.Cache.Set(key, res)
	c.save()
	return res, nil
}

// TryCreateCachedLLM checks env vars and wraps the agent if caching is enabled.
func TryCreateCachedLLM(agent Agent) Agent {
	sizeStr := os.Getenv("AGENT_LLM_CACHE_SIZE")
	if sizeStr == "" {
		return agent
	}

	size, err := strconv.Atoi(sizeStr)
	if err != nil || size <= 0 {
		return agent
	}

	ttlStr := os.Getenv("AGENT_LLM_CACHE_TTL")
	ttl := 300 * time.Second // default: 5 minutes
	if ttlStr != "" {
		if sec, err := strconv.Atoi(ttlStr); err == nil && sec > 0 {
			ttl = time.Duration(sec) * time.Second
		}
	}

	path := os.Getenv("AGENT_LLM_CACHE_PATH")
	if path == "" {
		// Default to a file in the working directory when caching is enabled.
		path = ".agent_cache.json"
	}

	return NewCachedLLM(agent, size, ttl, path)
}
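As a usage sketch (not part of the commit): the wrapper satisfies the same Agent interface it wraps, so it can stand in anywhere an agent is used. The stub agent and the models import path below are illustrative assumptions:

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/Protocol-Lattice/go-agent/src/models" // assumed path, mirroring the cache package
)

// echoAgent is an illustrative stand-in for a real provider.
type echoAgent struct{}

func (echoAgent) Generate(ctx context.Context, prompt string) (any, error) {
	fmt.Println("underlying agent called")
	return "echo: " + prompt, nil
}

func (echoAgent) GenerateWithFiles(ctx context.Context, prompt string, files []models.File) (any, error) {
	return "echo with files", nil
}

func main() {
	ctx := context.Background()

	// 1000-entry cache, 5-minute TTL, persisted to .agent_cache.json.
	cached := models.NewCachedLLM(echoAgent{}, 1000, 5*time.Minute, ".agent_cache.json")

	out, _ := cached.Generate(ctx, "hello") // miss: the underlying agent runs
	fmt.Println(out)

	out, _ = cached.Generate(ctx, "hello") // hit: served from the cache, agent not called again
	fmt.Println(out)
}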

src/models/cached_test.go

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
package models

import (
	"context"
	"sync/atomic"
	"testing"
	"time"
)

type MockAgent struct {
	CallCount int32
}

func (m *MockAgent) Generate(ctx context.Context, prompt string) (any, error) {
	atomic.AddInt32(&m.CallCount, 1)
	return "mock response", nil
}

func (m *MockAgent) GenerateWithFiles(ctx context.Context, prompt string, files []File) (any, error) {
	atomic.AddInt32(&m.CallCount, 1)
	return "mock response with files", nil
}

func TestCachedLLM_Generate(t *testing.T) {
	mock := &MockAgent{}
	cached := NewCachedLLM(mock, 10, time.Minute, "")

	ctx := context.Background()
	prompt := "hello"

	// First call - should hit the agent
	_, err := cached.Generate(ctx, prompt)
	if err != nil {
		t.Fatalf("first call failed: %v", err)
	}
	if count := atomic.LoadInt32(&mock.CallCount); count != 1 {
		t.Errorf("expected 1 call, got %d", count)
	}

	// Second call - should hit the cache
	_, err = cached.Generate(ctx, prompt)
	if err != nil {
		t.Fatalf("second call failed: %v", err)
	}
	if count := atomic.LoadInt32(&mock.CallCount); count != 1 {
		t.Errorf("expected 1 call (cached), got %d", count)
	}

	// Different prompt - should hit the agent
	_, err = cached.Generate(ctx, "world")
	if err != nil {
		t.Fatalf("third call failed: %v", err)
	}
	if count := atomic.LoadInt32(&mock.CallCount); count != 2 {
		t.Errorf("expected 2 calls, got %d", count)
	}
}

func TestCachedLLM_GenerateWithFiles(t *testing.T) {
	mock := &MockAgent{}
	cached := NewCachedLLM(mock, 10, time.Minute, "")

	ctx := context.Background()
	prompt := "analyze"
	files := []File{{Name: "a.txt", Data: []byte("content")}}

	// First call - should hit the agent
	_, err := cached.GenerateWithFiles(ctx, prompt, files)
	if err != nil {
		t.Fatalf("first call failed: %v", err)
	}
	if count := atomic.LoadInt32(&mock.CallCount); count != 1 {
		t.Errorf("expected 1 call, got %d", count)
	}

	// Second call - same files, should hit the cache
	_, err = cached.GenerateWithFiles(ctx, prompt, files)
	if err != nil {
		t.Fatalf("second call failed: %v", err)
	}
	if count := atomic.LoadInt32(&mock.CallCount); count != 1 {
		t.Errorf("expected 1 call (cached), got %d", count)
	}

	// Different file content - should hit the agent
	files2 := []File{{Name: "a.txt", Data: []byte("different")}}
	_, err = cached.GenerateWithFiles(ctx, prompt, files2)
	if err != nil {
		t.Fatalf("third call failed: %v", err)
	}
	if count := atomic.LoadInt32(&mock.CallCount); count != 2 {
		t.Errorf("expected 2 calls, got %d", count)
	}
}
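One natural follow-up (not in this commit) would be a test for the persistence path itself. A sketch along the same lines, which would additionally need the path/filepath import:

// Sketch of a possible persistence test (not part of this commit).
func TestCachedLLM_Persistence(t *testing.T) {
	path := filepath.Join(t.TempDir(), "cache.json")

	mock := &MockAgent{}
	first := NewCachedLLM(mock, 10, time.Minute, path)
	if _, err := first.Generate(context.Background(), "hello"); err != nil {
		t.Fatalf("generate failed: %v", err)
	}

	// A second wrapper pointed at the same file should restore the entry
	// and answer without calling the underlying agent again.
	second := NewCachedLLM(mock, 10, time.Minute, path)
	if _, err := second.Generate(context.Background(), "hello"); err != nil {
		t.Fatalf("generate after restore failed: %v", err)
	}
	if count := atomic.LoadInt32(&mock.CallCount); count != 1 {
		t.Errorf("expected 1 underlying call after restore, got %d", count)
	}
}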

src/models/helper.go

Lines changed: 13 additions & 4 deletions
@@ -48,18 +48,27 @@ var (
 
 // NewLLMProvider returns a concrete Agent.
 func NewLLMProvider(ctx context.Context, provider string, model string, promptPrefix string) (Agent, error) {
+	var agent Agent
+	var err error
+
 	switch provider {
 	case "openai":
-		return NewOpenAILLM(model, promptPrefix), nil
+		agent = NewOpenAILLM(model, promptPrefix)
 	case "gemini", "google":
-		return NewGeminiLLM(ctx, model, promptPrefix)
+		agent, err = NewGeminiLLM(ctx, model, promptPrefix)
 	case "ollama":
-		return NewOllamaLLM(model, promptPrefix)
+		agent, err = NewOllamaLLM(model, promptPrefix)
 	case "anthropic", "claude":
-		return NewAnthropicLLM(model, promptPrefix), nil
+		agent = NewAnthropicLLM(model, promptPrefix)
 	default:
 		return nil, fmt.Errorf("unknown provider: %s", provider)
 	}
+
+	if err != nil {
+		return nil, err
+	}
+
+	return TryCreateCachedLLM(agent), nil
 }
 
 // sanitizeForGemini coerces edge cases again and filters to what Gemini will accept.
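With this wiring, callers of NewLLMProvider get caching transparently whenever AGENT_LLM_CACHE_SIZE is set (see PERFORMANCE_OPTIMIZATIONS.md above); nothing else at the call site changes. A minimal sketch, where the import path, model name, and prompt are illustrative and the provider still needs its usual credentials:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/Protocol-Lattice/go-agent/src/models" // assumed import path
)

func main() {
	ctx := context.Background()

	// With AGENT_LLM_CACHE_SIZE exported, the returned Agent is a CachedLLM
	// wrapping the concrete provider; without it, it is the provider itself.
	agent, err := models.NewLLMProvider(ctx, "openai", "gpt-4o-mini", "")
	if err != nil {
		log.Fatal(err)
	}

	// Identical prompts within the TTL are served from the cache.
	out, err := agent.Generate(ctx, "Summarize the latest release notes")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out)
}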
