utilise buffer and existing structs more
daesu committed May 22, 2024
1 parent b5c4c99 commit c070673
Showing 3 changed files with 100 additions and 49 deletions.
14 changes: 14 additions & 0 deletions pkg/api/schemas/chat_stream.go
@@ -24,6 +24,20 @@ var (
 	UnknownError ErrorCode = "unknown_error"
 )
 
+type StreamingCacheEntry struct {
+	Key            string
+	Query          string
+	ResponseChunks []string
+	Complete       bool
+}
+
+type StreamingCacheEntryChunk struct {
+	Key      string
+	Index    int
+	Content  ChatStreamChunk
+	Complete bool
+}
+
 type StreamRequestID = string
 
 // ChatStreamRequest defines a message that requests a new streaming chat
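
The two new structs are designed to work as a pair: a StreamingCacheEntry records, in ResponseChunks, the cache keys of the StreamingCacheEntryChunk records that hold the actual chunk payloads. Below is a minimal sketch (not part of this commit) of how a fully cached stream could be replayed from them; replayCachedStream is a hypothetical helper, and it assumes chunk entries are stored as *schemas.StreamingCacheEntryChunk under the keys listed in ResponseChunks.

package example

import (
	"github.com/EinStack/glide/pkg/api/schemas"
	"github.com/EinStack/glide/pkg/cache"
)

// replayCachedStream walks a cached StreamingCacheEntry and collects the chunks
// referenced by its ResponseChunks keys, in order. It returns false if the entry
// is missing, incomplete, or holds an unexpected type.
func replayCachedStream(c *cache.MemoryCache, entryKey string) ([]schemas.ChatStreamChunk, bool) {
	cached, found := c.Get(entryKey)
	if !found {
		return nil, false
	}

	entry, ok := cached.(*schemas.StreamingCacheEntry)
	if !ok || !entry.Complete {
		return nil, false
	}

	chunks := make([]schemas.ChatStreamChunk, 0, len(entry.ResponseChunks))

	for _, chunkKey := range entry.ResponseChunks {
		cachedChunk, found := c.Get(chunkKey)
		if !found {
			return nil, false
		}

		chunk, ok := cachedChunk.(*schemas.StreamingCacheEntryChunk)
		if !ok {
			return nil, false
		}

		chunks = append(chunks, chunk.Content)
	}

	return chunks, true
}
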
50 changes: 13 additions & 37 deletions pkg/cache/memory_cache.go
@@ -1,51 +1,27 @@
 package cache
 
-import (
-	"fmt"
-	"sync"
-	"time"
-
-	"github.com/EinStack/glide/pkg/api/schemas"
-)
-
-type CacheEntry struct {
-	Response  schemas.ChatResponse
-	Timestamp time.Time
-}
+import "sync"
 
 type MemoryCache struct {
-	cache map[string]CacheEntry
-	mux   sync.Mutex
+	cache map[string]interface{}
+	lock  sync.RWMutex
 }
 
 func NewMemoryCache() *MemoryCache {
 	return &MemoryCache{
-		cache: make(map[string]CacheEntry),
-	}
-}
-
-func (m *MemoryCache) Get(key string) (schemas.ChatResponse, bool) {
-	m.mux.Lock()
-	defer m.mux.Unlock()
-	entry, exists := m.cache[key]
-	if !exists {
-		return schemas.ChatResponse{}, false
+		cache: make(map[string]interface{}),
 	}
-	return entry.Response, true
 }
 
-func (m *MemoryCache) Set(key string, response schemas.ChatResponse) {
-	m.mux.Lock()
-	defer m.mux.Unlock()
-	m.cache[key] = CacheEntry{
-		Response:  response,
-		Timestamp: time.Now(),
-	}
+func (m *MemoryCache) Get(key string) (interface{}, bool) {
+	m.lock.RLock()
+	defer m.lock.RUnlock()
+	val, found := m.cache[key]
+	return val, found
 }
 
-func (m *MemoryCache) All() {
-	m.mux.Lock()
-	defer m.mux.Unlock()
-
-	fmt.Println("%v", m.cache)
+func (m *MemoryCache) Set(key string, value interface{}) {
+	m.lock.Lock()
+	defer m.lock.Unlock()
+	m.cache[key] = value
 }
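
With the cache now storing interface{} values behind an RWMutex, value typing moves to the call sites: whatever a caller passes to Set must be asserted back to the same type on Get. A minimal usage sketch under that assumption (cacheRoundTrip and the "prompt-key" key are illustrative, not part of the commit):

package example

import (
	"github.com/EinStack/glide/pkg/api/schemas"
	"github.com/EinStack/glide/pkg/cache"
)

// cacheRoundTrip stores a chat response pointer and reads it back, asserting the
// concrete type on the way out. The RWMutex allows concurrent readers while writes
// still take the exclusive lock.
func cacheRoundTrip(resp *schemas.ChatResponse) (*schemas.ChatResponse, bool) {
	c := cache.NewMemoryCache()
	c.Set("prompt-key", resp)

	val, found := c.Get("prompt-key")
	if !found {
		return nil, false
	}

	cached, ok := val.(*schemas.ChatResponse)
	if !ok {
		return nil, false
	}

	return cached, true
}
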
85 changes: 73 additions & 12 deletions pkg/routers/router.go
@@ -3,6 +3,7 @@ package routers
 import (
 	"context"
 	"errors"
+	"fmt"
 	"log"
 
 	"github.com/EinStack/glide/pkg/cache"
@@ -78,7 +79,11 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem
 	cacheKey := req.Message.Content
 	if cachedResponse, found := r.cache.Get(cacheKey); found {
 		log.Println("found cached response and returning: ", cachedResponse)
[GitHub Actions / Static Checks warning on line 81 in pkg/routers/router.go: indent-error-flow: if block ends with a return statement, so drop this else and outdent its block (move short variable declaration to its own line if necessary) (revive)]
-		return &cachedResponse, nil
+		if response, ok := cachedResponse.(*schemas.ChatResponse); ok {
+			return response, nil
+		} else {
+			log.Println("Failed to cast cached response to ChatResponse")
+		}
 	}
 
 	retryIterator := r.retry.Iterator()
@@ -112,14 +117,13 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem
 				zap.String("provider", langModel.Provider()),
 				zap.Error(err),
 			)
-
 			continue
 		}
 
 		resp.RouterID = r.routerID
 
 		// Store response in cache
-		r.cache.Set(cacheKey, *resp)
+		r.cache.Set(cacheKey, resp)
 
 		return resp, nil
 	}
@@ -152,10 +156,43 @@ func (r *LangRouter) ChatStream(
 			req.Metadata,
 			&schemas.ErrorReason,
 		)
-
 		return
 	}
 
+	cacheKey := req.Message.Content
+	if streamingCacheEntry, found := r.cache.Get(cacheKey); found {
+		if entry, ok := streamingCacheEntry.(*schemas.StreamingCacheEntry); ok {
+			for _, chunkKey := range entry.ResponseChunks {
+				if cachedChunk, found := r.cache.Get(chunkKey); found {
+					if chunk, ok := cachedChunk.(*schemas.ChatStreamChunk); ok {
+						respC <- schemas.NewChatStreamChunk(
+							req.ID,
+							r.routerID,
+							req.Metadata,
+							chunk,
+						)
+					} else {
+						log.Println("Failed to cast cached chunk to ChatStreamChunk")
+					}
+				}
+			}
+
+			if entry.Complete {
+				return
+			}
+		} else {
+			log.Println("Failed to cast cached entry to StreamingCacheEntry")
+		}
+	} else {
+		streamingCacheEntry := &schemas.StreamingCacheEntry{
+			Key:            cacheKey,
+			Query:          req.Message.Content,
+			ResponseChunks: []string{},
+			Complete:       false,
+		}
+		r.cache.Set(cacheKey, streamingCacheEntry)
+	}
+
 	retryIterator := r.retry.Iterator()
 
 	for retryIterator.HasNext() {
@@ -183,6 +220,7 @@ func (r *LangRouter) ChatStream(
 			continue
 		}
 
+		buffer := []schemas.ChatStreamChunk{}
 		for chunkResult := range modelRespC {
 			err = chunkResult.Error()
 			if err != nil {
@@ -193,9 +231,6 @@
 					zap.Error(err),
 				)
 
-				// It's challenging to hide an error in case of streaming chat as consumer apps
-				// may have already used all chunks we streamed this far (e.g. showed them to their users like OpenAI UI does),
-				// so we cannot easily restart that process from scratch
 				respC <- schemas.NewChatStreamError(
 					req.ID,
 					r.routerID,
@@ -209,25 +244,52 @@
 			}
 
 			chunk := chunkResult.Chunk()
-
+			buffer = append(buffer, *chunk)
 			respC <- schemas.NewChatStreamChunk(
 				req.ID,
 				r.routerID,
 				req.Metadata,
 				chunk,
 			)
+
+			if len(buffer) >= 1048 { // Define bufferSize as per your requirement
+				chunkKey := fmt.Sprintf("%s-chunk-%d", cacheKey, len(buffer))
+				r.cache.Set(chunkKey, &schemas.StreamingCacheEntryChunk{
+					Key:     chunkKey,
+					Index:   len(buffer),
+					Content: *chunk,
+				})
+				streamingCacheEntry := schemas.StreamingCacheEntry{}
+				streamingCacheEntry.ResponseChunks = append(streamingCacheEntry.ResponseChunks, chunkKey)
+				buffer = buffer[:0] // Reset buffer
+				r.cache.Set(cacheKey, streamingCacheEntry)
+			}
 		}
 
+		if len(buffer) > 0 {
+			chunkKey := fmt.Sprintf("%s-chunk-%d", cacheKey, len(buffer))
+			r.cache.Set(chunkKey, &schemas.StreamingCacheEntryChunk{
+				Key:     chunkKey,
+				Index:   len(buffer),
+				Content: buffer[0], // Assuming buffer has at least one element
+			})
+			streamingCacheEntry := schemas.StreamingCacheEntry{}
+			streamingCacheEntry.ResponseChunks = append(streamingCacheEntry.ResponseChunks, chunkKey)
+			buffer = buffer[:0] // Reset buffer
+			r.cache.Set(cacheKey, streamingCacheEntry)
+		}
+
+		streamingCacheEntry := schemas.StreamingCacheEntry{}
+		streamingCacheEntry.Complete = true
+		r.cache.Set(cacheKey, streamingCacheEntry)
+
 		return
 	}
 
-	// no providers were available to handle the request,
-	// so we have to wait a bit with a hope there is some available next time
 	r.logger.Warn("No healthy model found to serve streaming chat request, wait and retry")
 
 	err := retryIterator.WaitNext(ctx)
 	if err != nil {
-		// something has cancelled the context
 		respC <- schemas.NewChatStreamError(
 			req.ID,
 			r.routerID,
@@ -241,7 +303,6 @@
 		}
 	}
 
-	// if we reach this part, then we are in trouble
 	r.logger.Error(
 		"No model was available to handle streaming chat request. " +
 			"Try to configure more fallback models to avoid this",
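
For reference, the per-chunk bookkeeping in the streaming path above can be pictured as a single helper. This is a simplified sketch rather than a faithful copy of the inline diff logic: flushChunks, its parameters, and the index handling are assumptions for illustration, while the "%s-chunk-%d" key format follows the commit.

package example

import (
	"fmt"

	"github.com/EinStack/glide/pkg/api/schemas"
	"github.com/EinStack/glide/pkg/cache"
)

// flushChunks stores each buffered chunk under a derived key, records that key on
// the stream's cache entry, and re-saves the entry so a later request can replay
// the chunks in order.
func flushChunks(c *cache.MemoryCache, entry *schemas.StreamingCacheEntry, buffer []schemas.ChatStreamChunk) {
	for _, chunk := range buffer {
		index := len(entry.ResponseChunks)
		chunkKey := fmt.Sprintf("%s-chunk-%d", entry.Key, index)

		c.Set(chunkKey, &schemas.StreamingCacheEntryChunk{
			Key:     chunkKey,
			Index:   index,
			Content: chunk,
		})

		entry.ResponseChunks = append(entry.ResponseChunks, chunkKey)
	}

	c.Set(entry.Key, entry)
}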
