diff --git a/pkg/api/schemas/chat_stream.go b/pkg/api/schemas/chat_stream.go
index 983d2242..ed43b99d 100644
--- a/pkg/api/schemas/chat_stream.go
+++ b/pkg/api/schemas/chat_stream.go
@@ -24,6 +24,23 @@ var (
 	UnknownError ErrorCode = "unknown_error"
 )
 
+// StreamingCacheEntry tracks a cached streaming response. ResponseChunks
+// holds the cache keys of the individual chunks, in stream order.
+type StreamingCacheEntry struct {
+	Key            string
+	Query          string
+	ResponseChunks []string
+	Complete       bool
+}
+
+// StreamingCacheEntryChunk is a single cached chunk of a streaming response.
+type StreamingCacheEntryChunk struct {
+	Key      string
+	Index    int
+	Content  ChatStreamChunk
+	Complete bool
+}
+
 type StreamRequestID = string
 
 // ChatStreamRequest defines a message that requests a new streaming chat
diff --git a/pkg/cache/memory_cache.go b/pkg/cache/memory_cache.go
index f8d6133b..3d2045ed 100644
--- a/pkg/cache/memory_cache.go
+++ b/pkg/cache/memory_cache.go
@@ -1,51 +1,27 @@
 package cache
 
-import (
-	"fmt"
-	"sync"
-	"time"
-
-	"github.com/EinStack/glide/pkg/api/schemas"
-)
-
-type CacheEntry struct {
-	Response  schemas.ChatResponse
-	Timestamp time.Time
-}
+import "sync"
 
 type MemoryCache struct {
-	cache map[string]CacheEntry
-	mux   sync.Mutex
+	cache map[string]interface{}
+	lock  sync.RWMutex
 }
 
 func NewMemoryCache() *MemoryCache {
	return &MemoryCache{
-		cache: make(map[string]CacheEntry),
-	}
-}
-
-func (m *MemoryCache) Get(key string) (schemas.ChatResponse, bool) {
-	m.mux.Lock()
-	defer m.mux.Unlock()
-	entry, exists := m.cache[key]
-	if !exists {
-		return schemas.ChatResponse{}, false
+		cache: make(map[string]interface{}),
 	}
-	return entry.Response, true
 }
 
-func (m *MemoryCache) Set(key string, response schemas.ChatResponse) {
-	m.mux.Lock()
-	defer m.mux.Unlock()
-	m.cache[key] = CacheEntry{
-		Response:  response,
-		Timestamp: time.Now(),
-	}
+func (m *MemoryCache) Get(key string) (interface{}, bool) {
+	m.lock.RLock()
+	defer m.lock.RUnlock()
+	val, found := m.cache[key]
+	return val, found
 }
 
-func (m *MemoryCache) All() {
-	m.mux.Lock()
-	defer m.mux.Unlock()
-
-	fmt.Println("%v", m.cache)
+func (m *MemoryCache) Set(key string, value interface{}) {
+	m.lock.Lock()
+	defer m.lock.Unlock()
+	m.cache[key] = value
 }
diff --git a/pkg/routers/router.go b/pkg/routers/router.go
index a6ef9dff..00ed4f90 100644
--- a/pkg/routers/router.go
+++ b/pkg/routers/router.go
@@ -3,6 +3,7 @@ package routers
 import (
 	"context"
 	"errors"
+	"fmt"
 	"log"
 
 	"github.com/EinStack/glide/pkg/cache"
@@ -78,7 +79,11 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem
 	cacheKey := req.Message.Content
 	if cachedResponse, found := r.cache.Get(cacheKey); found {
 		log.Println("found cached response and returning: ", cachedResponse)
-		return &cachedResponse, nil
+		if response, ok := cachedResponse.(*schemas.ChatResponse); ok {
+			return response, nil
+		}
+
+		log.Println("failed to cast cached response to ChatResponse, ignoring cache")
 	}
 
 	retryIterator := r.retry.Iterator()
@@ -112,14 +117,13 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem
 				zap.String("provider", langModel.Provider()),
 				zap.Error(err),
 			)
-
 			continue
 		}
 
 		resp.RouterID = r.routerID
 
 		// Store response in cache
-		r.cache.Set(cacheKey, *resp)
+		r.cache.Set(cacheKey, resp)
 
 		return resp, nil
 	}
@@ -152,10 +156,52 @@ func (r *LangRouter) ChatStream(
 			req.Metadata,
 			&schemas.ErrorReason,
 		)
-
 		return
 	}
 
+	cacheKey := req.Message.Content
+
+	var streamingCacheEntry *schemas.StreamingCacheEntry
+
+	if cached, found := r.cache.Get(cacheKey); found {
+		if entry, ok := cached.(*schemas.StreamingCacheEntry); ok {
+			// Replay the chunks already cached for this query
+			for _, chunkKey := range entry.ResponseChunks {
+				if cachedChunk, found := r.cache.Get(chunkKey); found {
+					if chunk, ok := cachedChunk.(*schemas.StreamingCacheEntryChunk); ok {
+						respC <- schemas.NewChatStreamChunk(
+							req.ID,
+							r.routerID,
+							req.Metadata,
+							&chunk.Content,
+						)
+					} else {
+						log.Println("failed to cast cached chunk to StreamingCacheEntryChunk")
+					}
+				}
+			}
+
+			if entry.Complete {
+				return
+			}
+
+			// The cached stream is incomplete: keep the entry and append to it below
+			streamingCacheEntry = entry
+		} else {
+			log.Println("failed to cast cached entry to StreamingCacheEntry")
+		}
+	}
+
+	if streamingCacheEntry == nil {
+		streamingCacheEntry = &schemas.StreamingCacheEntry{
+			Key:            cacheKey,
+			Query:          req.Message.Content,
+			ResponseChunks: []string{},
+			Complete:       false,
+		}
+		r.cache.Set(cacheKey, streamingCacheEntry)
+	}
+
 	retryIterator := r.retry.Iterator()
 
 	for retryIterator.HasNext() {
@@ -183,6 +229,8 @@ func (r *LangRouter) ChatStream(
 				continue
 			}
 
+			const bufferSize = 1048 // cache flush threshold, in chunks; tune as needed
+			buffer := make([]schemas.ChatStreamChunk, 0, bufferSize)
 			for chunkResult := range modelRespC {
 				err = chunkResult.Error()
 				if err != nil {
@@ -193,9 +241,6 @@ func (r *LangRouter) ChatStream(
 						zap.Error(err),
 					)
 
-					// It's challenging to hide an error in case of streaming chat as consumer apps
-					// may have already used all chunks we streamed this far (e.g. showed them to their users like OpenAI UI does),
-					// so we cannot easily restart that process from scratch
 					respC <- schemas.NewChatStreamError(
 						req.ID,
 						r.routerID,
@@ -209,25 +254,54 @@ func (r *LangRouter) ChatStream(
 				}
 
 				chunk := chunkResult.Chunk()
-
+				buffer = append(buffer, *chunk)
 				respC <- schemas.NewChatStreamChunk(
 					req.ID,
 					r.routerID,
 					req.Metadata,
 					chunk,
 				)
+
+				// Flush the buffered chunks to the cache once the buffer is full
+				if len(buffer) >= bufferSize {
+					for _, buffered := range buffer {
+						chunkKey := fmt.Sprintf("%s-chunk-%d", cacheKey, len(streamingCacheEntry.ResponseChunks))
+						r.cache.Set(chunkKey, &schemas.StreamingCacheEntryChunk{
+							Key:     chunkKey,
+							Index:   len(streamingCacheEntry.ResponseChunks),
+							Content: buffered,
+						})
+						streamingCacheEntry.ResponseChunks = append(streamingCacheEntry.ResponseChunks, chunkKey)
+					}
+
+					buffer = buffer[:0] // reset the buffer
+					r.cache.Set(cacheKey, streamingCacheEntry)
+				}
 			}
 
+			// Flush any remaining chunks and mark the cached stream as complete
+			if len(buffer) > 0 {
+				for _, buffered := range buffer {
+					chunkKey := fmt.Sprintf("%s-chunk-%d", cacheKey, len(streamingCacheEntry.ResponseChunks))
+					r.cache.Set(chunkKey, &schemas.StreamingCacheEntryChunk{
+						Key:     chunkKey,
+						Index:   len(streamingCacheEntry.ResponseChunks),
+						Content: buffered,
+					})
+					streamingCacheEntry.ResponseChunks = append(streamingCacheEntry.ResponseChunks, chunkKey)
+				}
+			}
+
+			streamingCacheEntry.Complete = true
+			r.cache.Set(cacheKey, streamingCacheEntry)
+
 			return
 		}
 
-		// no providers were available to handle the request,
-		// so we have to wait a bit with a hope there is some available next time
 		r.logger.Warn("No healthy model found to serve streaming chat request, wait and retry")
 
 		err := retryIterator.WaitNext(ctx)
 		if err != nil {
-			// something has cancelled the context
 			respC <- schemas.NewChatStreamError(
 				req.ID,
 				r.routerID,
@@ -241,7 +315,6 @@ func (r *LangRouter) ChatStream(
 		}
 	}
 
-	// if we reach this part, then we are in trouble
 	r.logger.Error(
 		"No model was available to handle streaming chat request. " +
 			"Try to configure more fallback models to avoid this",
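
Reviewer note (not part of the diff): below is a minimal, self-contained sketch of the cache round trip the router code above depends on — store a *StreamingCacheEntry under the query key, store each chunk under a derived "<query>-chunk-<index>" key, then replay via type assertions. The one-field ChatStreamChunk and the query/chunk texts are stand-ins invented for the example; MemoryCache mirrors pkg/cache/memory_cache.go. The key invariant it demonstrates: an interface{} value must be asserted back to the exact pointer type it was stored as, or the lookup silently falls through.

package main

import (
	"fmt"
	"sync"
)

// Stand-in for schemas.ChatStreamChunk, reduced to one field for the example.
type ChatStreamChunk struct{ Content string }

// Mirrors the types added in pkg/api/schemas/chat_stream.go.
type StreamingCacheEntry struct {
	Key            string
	Query          string
	ResponseChunks []string
	Complete       bool
}

type StreamingCacheEntryChunk struct {
	Key      string
	Index    int
	Content  ChatStreamChunk
	Complete bool
}

// Mirrors pkg/cache/memory_cache.go.
type MemoryCache struct {
	cache map[string]interface{}
	lock  sync.RWMutex
}

func NewMemoryCache() *MemoryCache {
	return &MemoryCache{cache: make(map[string]interface{})}
}

func (m *MemoryCache) Get(key string) (interface{}, bool) {
	m.lock.RLock()
	defer m.lock.RUnlock()
	val, found := m.cache[key]
	return val, found
}

func (m *MemoryCache) Set(key string, value interface{}) {
	m.lock.Lock()
	defer m.lock.Unlock()
	m.cache[key] = value
}

func main() {
	cache := NewMemoryCache()
	cacheKey := "what is glide?" // hypothetical query

	// Writer side: store the entry as a pointer, then one record per chunk
	// under a derived key, appending each chunk key to the entry in order.
	entry := &StreamingCacheEntry{Key: cacheKey, Query: cacheKey}
	cache.Set(cacheKey, entry)

	for i, text := range []string{"Glide is ", "an LLM gateway."} {
		chunkKey := fmt.Sprintf("%s-chunk-%d", cacheKey, i)
		cache.Set(chunkKey, &StreamingCacheEntryChunk{
			Key:     chunkKey,
			Index:   i,
			Content: ChatStreamChunk{Content: text},
		})
		entry.ResponseChunks = append(entry.ResponseChunks, chunkKey)
	}

	entry.Complete = true
	cache.Set(cacheKey, entry)

	// Reader side: assert back to the pointer types that were stored.
	cached, found := cache.Get(cacheKey)
	if !found {
		return
	}

	if e, ok := cached.(*StreamingCacheEntry); ok && e.Complete {
		for _, chunkKey := range e.ResponseChunks {
			if raw, found := cache.Get(chunkKey); found {
				if chunk, ok := raw.(*StreamingCacheEntryChunk); ok {
					fmt.Print(chunk.Content.Content) // replays "Glide is an LLM gateway."
				}
			}
		}
		fmt.Println()
	}
}

Storing the entry as a pointer means appends to ResponseChunks are visible to later readers without a re-Set; the explicit re-Set after each flush keeps the behavior correct even if the cache is later swapped for one that serializes values (e.g. a Redis-backed implementation).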